# import libs, load data

In [1]:
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

# reproducibility
torch.set_deterministic(True)
torch.manual_seed(0)

BATCH_SIZE = 64

# resize: LeNet accepts 32 x 32 
train_dataset = datasets.MNIST(root='../data/MNIST',
                               train=True,
                               download=True,
                               transform=transforms.Compose([transforms.ToTensor(),
                                                            transforms.Resize((32, 32))]))
test_dataset = datasets.MNIST(root='../data/MNIST',
                              train=False,
                              download=True,
                              transform=transforms.Compose([transforms.ToTensor(),
                                                            transforms.Resize((32, 32))]))

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True)
test_loader = DataLoader(dataset=test_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting ../data/MNIST/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting ../data/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting ../data/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting ../data/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


# LeNet5

In [2]:
class LeNet5(nn.Module):
    def __init__(self, finetune, mu=0, sigma=0.1):
        super(LeNet5, self).__init__()
        self.mu = mu
        self.sigma = sigma
        self.conv1 = nn.Conv2d(1, 6, (5, 5)) 
        self.max_pool1 = nn.MaxPool2d((2, 2))
        self.conv2 = nn.Conv2d(6, 16, (5, 5))
        self.max_pool2 = nn.MaxPool2d((2, 2))
        self.fc1 = nn.Linear(400, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        # weight initialization
        self._init_weight()
        
    def _init_weight(self):
        for conv in [self.conv1, self.conv2]:
            torch.nn.init.trunc_normal_(conv.weight, self.mu, self.sigma)
            torch.nn.init.zeros_(conv.bias)
        for fc in [self.fc1, self.fc2, self.fc3]:
            torch.nn.init.trunc_normal_(fc.weight, self.mu, self.sigma)
            torch.nn.init.zeros_(fc.bias)
        
    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = self.max_pool1(out)
        out = F.relu(self.conv2(out))
        out = self.max_pool2(out)
        out = out.view(out.shape[0], -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out




# Necessary Functions

In [3]:
# train func
def train(model, train_loader, loss_list, optimizer, criterion, log_interval=400):
    model.train()
    train_loss = 0
    
    for batch_idx, (image, label) in enumerate(train_loader):
        output = model(image)
        optimizer.zero_grad()
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        loss_list.append(loss.item())
        
        if batch_idx % log_interval == 0:
            print("Batch Idx: {}/{}, Train Loss: {:.6f}".format(batch_idx, len(train_loader), loss.item()))
        
    train_loss /= len(train_loader)
    
    return train_loss

# evaluate(test) func
def evaluate(model, test_loader, criterion):
    model.eval()
    val_loss = 0
    correct = 0
    
    for image, label in test_loader:
        with torch.no_grad():
            output = model(image)
            val_loss += criterion(output, label).item()
            pred = torch.max(output, 1)[1]
            correct += (pred==label).sum().item()
    
    val_loss /= len(test_loader)
    acc = (correct/len(test_loader.dataset))*100
    
    return val_loss, acc

# Train

In [4]:
# init
model_origin = LeNet5(finetune=False)
optimizer_origin = optim.Adam(model_origin.parameters())
criterion = nn.CrossEntropyLoss()

print(model_origin)

LeNet5(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (max_pool1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (max_pool2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


In [5]:
EPOCHS = 10
train_losses_origin = []
for epoch in range(1, EPOCHS+1):
    loss = train(model_origin, train_loader, train_losses_origin, optimizer_origin, criterion)
    print("Epoch: {}, Train Loss: {:.6f}".format(epoch, loss))
    print("=" * 50)

Batch Idx: 0/938, Train Loss: 2.356592
Batch Idx: 400/938, Train Loss: 0.123616
Batch Idx: 800/938, Train Loss: 0.038984
Epoch: 1, Train Loss: 0.221470
Batch Idx: 0/938, Train Loss: 0.041801
Batch Idx: 400/938, Train Loss: 0.082240
Batch Idx: 800/938, Train Loss: 0.061273
Epoch: 2, Train Loss: 0.068302
Batch Idx: 0/938, Train Loss: 0.013224
Batch Idx: 400/938, Train Loss: 0.049334
Batch Idx: 800/938, Train Loss: 0.071718
Epoch: 3, Train Loss: 0.050021
Batch Idx: 0/938, Train Loss: 0.022275
Batch Idx: 400/938, Train Loss: 0.024211
Batch Idx: 800/938, Train Loss: 0.083071
Epoch: 4, Train Loss: 0.040342
Batch Idx: 0/938, Train Loss: 0.017178
Batch Idx: 400/938, Train Loss: 0.016958
Batch Idx: 800/938, Train Loss: 0.011800
Epoch: 5, Train Loss: 0.032254
Batch Idx: 0/938, Train Loss: 0.002418
Batch Idx: 400/938, Train Loss: 0.006253
Batch Idx: 800/938, Train Loss: 0.012672
Epoch: 6, Train Loss: 0.027776
Batch Idx: 0/938, Train Loss: 0.007485
Batch Idx: 400/938, Train Loss: 0.079277
Batch Id

In [6]:
val_loss, val_acc = evaluate(model_origin, test_loader, criterion)
print("Validation Loss: {:.6f}, Validation Accuracy: {:.6f}%".format(val_loss, val_acc))

Validation Loss: 0.032983, Validation Accuracy: 98.970000%
