In [112]:
import torch
import torchvision
from tqdm import tqdm
import numpy as np

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [108]:
# MNIST train/test splits, converted to tensors by ToTensor and stored under DATA_ROOT.
DATA_ROOT = ""      # current working directory
BATCH_SIZE = 65
to_tensor = torchvision.transforms.ToTensor()

# download=True skips the download when the files are already present, and lets a
# fresh machine / fresh kernel run the notebook top to bottom.
train = torchvision.datasets.MNIST(DATA_ROOT, train=True, transform=to_tensor, download=True)
test = torchvision.datasets.MNIST(DATA_ROOT, train=False, transform=to_tensor, download=True)

trainset = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation results do not depend on order, so the test loader is not shuffled.
testset = torch.utils.data.DataLoader(test, batch_size=BATCH_SIZE, shuffle=False)


In [110]:
class Block(torch.nn.Module):
    """Residual block: conv3x3 -> ReLU -> BatchNorm -> conv3x3, added to a skip path, then ReLU.

    Args:
        stride: stride of the first 3x3 convolution (int or pair). A non-unit
            stride downsamples the feature map.
        channels: ``(in_channels, out_channels)``.

    When the stride is not 1 or the channel counts differ, the skip path is a
    strided 1x1 convolution so that ``skip + residual`` has matching shape.
    With the defaults (``stride=(1, 1)``, ``channels=(1, 1)``) the skip path is
    the identity and no extra parameters are created.
    """

    def __init__(self, stride=(1,1), channels=(1, 1)):
        super().__init__()
        in_channels, out_channels = channels
        # Conv2d accepts an int or a pair; normalise so the shape check below works.
        stride = (stride, stride) if isinstance(stride, int) else tuple(stride)

        # Honours the `stride` argument (downsampling happens here).
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=(3, 3),
                                     dilation=(1, 1),
                                     padding=(1, 1),
                                     stride=stride)

        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=(3, 3),
                                     dilation=(1, 1),
                                     padding=(1, 1),
                                     stride=(1, 1))

        self.relu = torch.nn.functional.relu
        self.btn = torch.nn.BatchNorm2d(out_channels)

        # Skip path. A 1x1 conv with the same stride produces the same H and W as
        # conv1 (3x3, padding 1), so the addition in forward() is always valid.
        # nn.Identity has no parameters, so the default block's state_dict is unchanged.
        if stride != (1, 1) or in_channels != out_channels:
            self.shortcut = torch.nn.Conv2d(in_channels,
                                            out_channels,
                                            kernel_size=(1, 1),
                                            stride=stride)
        else:
            self.shortcut = torch.nn.Identity()

    def forward(self, x):
        """Return ``relu(shortcut(x) + conv2(btn(relu(conv1(x)))))``."""
        identity = self.shortcut(x)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.btn(x)
        x = self.conv2(x)
        return self.relu(identity + x)

    

class ResNet(torch.nn.Module):
    """Small single-channel residual classifier producing 10 scores.

    The network has eleven default (1-channel) ``Block``s separated by 2x2
    max-pooling. A linear layer then maps the flattened 3x3 map to 10 outputs.

    ``forward`` returns raw logits with no activation. That is the input
    ``torch.nn.CrossEntropyLoss`` expects.
    """

    def __init__(self):
        super().__init__()
        self.b1 = Block()
        self.b2 = Block()
        self.b3 = Block()
        self.b4 = Block()
        self.b5 = Block()
        self.b6 = Block()
        self.b7 = Block()
        self.b8 = Block()
        self.b9 = Block()
        self.b10 = Block()
        self.b11 = Block()

        # NOTE(review): this layer is never used in forward(). It is kept only
        # so existing state_dicts (which contain conv.weight / conv.bias) still
        # load. Remove it together with any saved checkpoints.
        self.conv = torch.nn.Conv2d(1, 1,
                                    kernel_size=(3, 3),
                                    padding=(1, 1),
                                    stride=2,
                                    dilation=(1, 1))

        self.pool = torch.nn.MaxPool2d(2)
        self.ln = torch.nn.Linear(3*3, 10)
        # Kept as an attribute for compatibility; intentionally not applied to the logits.
        self.relu = torch.nn.functional.relu

    def forward(self, x):
        """Map a batch of single-channel images to (batch, 10) logits.

        The spatial sizes noted below assume 1x28x28 (MNIST) input. The final
        Linear(3*3, 10) requires the map to end up as 1x3x3.
        """
        x = self.pool(x)                                          # 28 -> 14
        for block in (self.b1, self.b2, self.b3, self.b4, self.b5):
            x = block(x)
        x = self.pool(x)                                          # 14 -> 7
        for block in (self.b6, self.b7, self.b8, self.b9):
            x = block(x)
        x = self.pool(x)                                          # 7 -> 3 (floor)
        x = self.b11(self.b10(x))
        # flatten(1) keeps the batch dimension explicit. view(-1, 9) would
        # silently re-batch a wrongly sized input instead of failing.
        x = x.flatten(start_dim=1)
        # Return raw logits. A ReLU here would clamp negative scores to 0 and
        # zero their gradient, so the loss could never push wrong classes down.
        return self.ln(x)

In [None]:
# Seed PyTorch's global RNG so weight initialisation is reproducible. It also
# fixes the DataLoader shuffling, which draws its seed from the same generator.
SEED = 0
torch.manual_seed(SEED)
model = ResNet().to(device)


In [84]:
# Classification loss plus momentum SGD over every model parameter.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=1e-4,
    momentum=0.9,
)

In [None]:
epochs = 20
training_loss = []   # mean per-batch training loss, one entry per epoch
training_acc = []    # fraction of training images classified correctly, per epoch
testing_loss = []    # mean per-batch loss on the test split, per epoch
testing_acc = []     # fraction of test images classified correctly, per epoch

for e in range(epochs):
    # Training pass
    model.train()
    train_loss = 0.0
    train_acc = 0.0
    for X, y in tqdm(trainset, desc=f"epoch {e + 1}/{epochs}"):
        # The model lives on `device`, so each batch must be moved there too.
        # Otherwise the forward pass fails whenever `device` is CUDA.
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()

        output = model(X)
        loss = criterion(output, y)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_acc += (output.argmax(dim=1) == y).sum().item()

    training_loss.append(train_loss / len(trainset))   # len(trainset) = number of batches
    training_acc.append(train_acc / len(train))         # len(train) = number of images

    # Evaluation pass on the held-out split. It uses no gradients, and BatchNorm
    # relies on its running statistics.
    model.eval()
    test_loss = 0.0
    test_acc = 0.0
    with torch.no_grad():
        for X, y in testset:
            X, y = X.to(device), y.to(device)
            output = model(X)
            test_loss += criterion(output, y).item()
            test_acc += (output.argmax(dim=1) == y).sum().item()

    testing_loss.append(test_loss / len(testset))
    testing_acc.append(test_acc / len(test))
        

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

# some graphs
plt.plot(training_loss)
plt.plot(training_acc)

In [None]:
# Predict a single test image and compare it with its true label.
n = 35

img, label = test[n]          # img: (C, H, W) tensor from ToTensor; label: class index
print(img.shape)

# eval() makes BatchNorm use its running statistics (and stop updating them);
# no_grad() skips building the autograd graph for this inference-only call.
model.eval()
with torch.no_grad():
    pred = model(img.unsqueeze(0).to(device))   # add the batch dimension
print(pred)
print("predicted:", pred.argmax(dim=1).item(), "true:", label)

plt.imshow(img.numpy()[0], cmap="gray")
plt.title(f"test[{n}]  true label = {label}")
plt.show()