In [25]:
import numpy as np

import torch
import torch.utils.data
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torch.utils.data.sampler import SubsetRandomSampler

In [26]:
class ClassifierCNN(nn.Module):
    def __init__(self, class_num) -> None:
        super().__init__()
        self.class_num = class_num
        
        self.conv_net = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5),
            nn.BatchNorm2d(10),
            nn.ReLU(),
            nn.MaxPool2d(2),
            
            nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5),
            nn.BatchNorm2d(20),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        
        self.fc_net = nn.Sequential(
            nn.Linear(320, 50),
            nn.BatchNorm1d(50),
            nn.ReLU(),
            nn.Linear(50, self.class_num),
            nn.Softmax()
        )
    
    def forward(self, x):
        return self.conv_net(x).view(-1, 320) 
        
        

In [27]:
def weight_init(m):
    
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
        
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
    elif classname.find('Linear') != -1:
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.fill_(0)

In [28]:
def one_hot_embedding(labels, num_classes):
    y = torch.eye(num_classes)
    return y[labels]

In [30]:
epochs = 5
learning_rate = 0.01
batch_size = 100
valid_size = 500

loss_function = nn.BCELoss()


dataset = datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307, ), (0.3081, ))
                   ]))


num_train = len(dataset)

indices = list(range(num_train))
split = num_train-valid_size
np.random.shuffle(indices)


train_idx, valid_idx = indices[:split], indices[split:]
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)


train_loader = torch.utils.data.DataLoader(dataset,
                                           batch_size=batch_size, 
                                           sampler=train_sampler
                                           )
valid_loader = torch.utils.data.DataLoader(dataset,
                                           batch_size=batch_size, 
                                           sampler=valid_sampler
                                           )

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor()
                   ])),
    batch_size=batch_size,
)

train_loss_list = []
val_loss_list = []


In [23]:

# net = ClassifierCNN(class_num=10).cuda()
net = ClassifierCNN(class_num=10)
net.apply(weight_init)
net.train()


for epoch in range(epochs):
    for i, (X, t) in enumerate(train_loader):
        # X = X.view(-1, 784)
        # X = X.cuda()
        # t = one_hot_embedding(t, 10).cuda()
        
        t = one_hot_embedding(t, 10)
        Y = net(X)
        
        loss = loss_function(Y, t)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        
        # if i % 100 == 0:
        #     with torch.no_grad():
        #         val_100_loss = []
        #         for (X, t) in valid_loader:
        #             # X = X.view(-1, 784)
        #             
        #             Y = net(X)
        #             t = one_hot_embedding(t, 10)
        #             
        #             loss = loss_function(Y, t)
        #             
        #             
        #             val_100_loss.append(loss)
        #         
        #         train_loss_list.append(loss)
        #         val_loss_list.append(np.asarray(val_100_loss).sum()/len(valid_loader))
        #     print(f"[{i}/{len(train_loader)}][{epoch}/{epochs}] loss: {loss}")

RuntimeError: Expected 4-dimensional input for 4-dimensional weight 10 1 5 5, but got 2-dimensional input of size [100, 784] instead