# VGG5 SpinalNet (for MNIST)

In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class VGG(nn.Module):
    """Convolutional feature extractor built from four conv/BN/ReLU stages.

    Every stage stacks 3x3 convolutions (stride 1, padding 1), each followed
    by batch normalisation and ReLU, and finishes with a 2x2 max-pool that
    halves the spatial resolution.  The channel widths of the four stages are
    1 -> 64 -> 128 -> 256 -> 256.

    Args:
        num_classes: Accepted for interface compatibility only; this module
            has no classification head and never reads the value.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.l1 = self.two_conv_pool(1, 64, 64)
        self.l2 = self.two_conv_pool(64, 128, 128)
        self.l3 = self.three_conv_pool(128, 256, 256, 256)
        self.l4 = self.three_conv_pool(256, 256, 256, 256)

    @staticmethod
    def _build_stage(in_channels, *widths):
        """Build one stage: a conv/BN/ReLU triple per width, then a max-pool."""
        layers = []
        previous = in_channels
        for width in widths:
            layers.append(
                nn.Conv2d(previous, width, kernel_size=3, stride=1, padding=1)
            )
            layers.append(nn.BatchNorm2d(width))
            layers.append(nn.ReLU(inplace=True))
            previous = width
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        stage = nn.Sequential(*layers)

        # Re-initialise only after the whole stage exists, so random numbers
        # are drawn in the same order as when the layers were listed by hand.
        for module in stage.children():
            if isinstance(module, nn.Conv2d):
                # Normal init with std = sqrt(2 / (kernel area * out channels)).
                fan_out = (
                    module.kernel_size[0]
                    * module.kernel_size[1]
                    * module.out_channels
                )
                nn.init.normal_(module.weight, 0.0, math.sqrt(2.0 / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
        return stage

    def two_conv_pool(self, in_channels, f1, f2):
        """Return a stage with two convolutions (``f1``, ``f2`` filters) and a pool."""
        return self._build_stage(in_channels, f1, f2)

    def three_conv_pool(self, in_channels, f1, f2, f3):
        """Return a stage with three convolutions (``f1``..``f3`` filters) and a pool."""
        return self._build_stage(in_channels, f1, f2, f3)

    def forward(self, x):
        """Run ``x`` through the four stages in order and return the feature map."""
        for stage in (self.l1, self.l2, self.l3, self.l4):
            x = stage(x)
        return x

class SpinalVGG5(nn.Module):
    """VGG feature extractor followed by a four-layer spinal classifier head.

    The flattened VGG features are cut into two halves.  The four hidden
    layers take the first and second half alternately, and every layer after
    the first also receives the previous layer's output.  The outputs of all
    four layers are concatenated and projected to ``num_classes`` logits.

    The head expects 256 flattened features per sample (two halves of 128).
    """

    def __init__(self, num_classes = 10):
        super().__init__()
        self.vgg = VGG(num_classes)

        feat_dim = 256
        self.half_dim = feat_dim // 2

        def spinal_layer(in_features):
            # Dropout -> Linear -> BatchNorm -> GELU, always feat_dim wide.
            return nn.Sequential(
                nn.Dropout(p=0.5),
                nn.Linear(in_features, feat_dim),
                nn.BatchNorm1d(feat_dim),
                nn.GELU(),
            )

        # The first layer sees only half of the features; the others see a
        # half plus the feat_dim-wide output of the layer before them.
        chained_width = self.half_dim + feat_dim
        self.linear1 = spinal_layer(self.half_dim)
        self.linear2 = spinal_layer(chained_width)
        self.linear3 = spinal_layer(chained_width)
        self.linear4 = spinal_layer(chained_width)
        self.proj = nn.Linear(4 * feat_dim, num_classes)

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        features = self.vgg(x)
        features = features.view(features.size(0), -1)

        first_half = features[:, :self.half_dim]
        second_half = features[:, self.half_dim:]

        outputs = [self.linear1(first_half)]
        chain = (
            (self.linear2, second_half),
            (self.linear3, first_half),
            (self.linear4, second_half),
        )
        for layer, half in chain:
            outputs.append(layer(torch.cat([half, outputs[-1]], dim=1)))

        return self.proj(torch.cat(outputs, dim=1))


# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# Ten-class SpinalVGG5, with its parameters moved to the selected device.
model = SpinalVGG5(10)
model = model.to(DEVICE)


# Load Dataset and DataLoader

In [2]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Convert each image to a tensor, then normalise with mean 0.5 and std 0.5.
mnist_transform = [transforms.ToTensor(), transforms.Normalize(0.5, 0.5)]

# MNIST training split, downloaded into ./data when it is not already there.
mnist = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose(mnist_transform),
)

# Hold out 1,000 samples for validation; the other 59,000 are used to train.
dataset, valid_dataset = torch.utils.data.random_split(
    mnist,
    [59000, 1000],
)

# Both loaders use batches of 128 and reshuffle every epoch.
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
val_dataloader = DataLoader(valid_dataset, batch_size=128, shuffle=True)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


# Train

In [3]:
from tqdm.notebook import tqdm
import numpy as np
# Cross-entropy is applied directly to the logits returned by the model.
criterion = nn.CrossEntropyLoss()

optim = torch.optim.Adam(params=model.parameters(), lr=1e-4)

for epoch in range(80):
    print("EPOCH : {}".format(epoch + 1))

    # ---- Training pass: one optimiser step per mini-batch. ----
    losses = []
    model.train()
    for img, label in tqdm(dataloader):
        img = img.to(DEVICE)
        label = label.to(DEVICE)

        pred = model(img)
        loss = criterion(pred, label)
        losses.append(loss.item())

        optim.zero_grad()
        loss.backward()
        optim.step()

    # ---- Validation pass: count correct predictions, no gradients. ----
    model.eval()
    scores = 0
    all_data = 0
    with torch.no_grad():
        for img, label in tqdm(val_dataloader):
            img = img.to(DEVICE)
            label = label.to(DEVICE)

            pred = model(img)

            # Softmax is monotonic, so the arg-max of the logits is already
            # the predicted class; skipping it also avoids ties that rounding
            # inside softmax could introduce.
            arg = pred.argmax(dim=1)
            scores += (arg == label).sum().item()
            all_data += len(img)

    # Mean training loss of the epoch and validation accuracy in [0, 1].
    print("Loss : %f \t Score : %f"% (np.mean(losses), scores/all_data))



EPOCH : 1


  0%|          | 0/461 [00:00<?, ?it/s]

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.303420 	 Score : 0.993000
EPOCH : 2


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.037222 	 Score : 0.994000
EPOCH : 3


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.020868 	 Score : 0.983000
EPOCH : 4


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.012844 	 Score : 0.990000
EPOCH : 5


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.012799 	 Score : 0.996000
EPOCH : 6


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.009013 	 Score : 0.995000
EPOCH : 7


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.010499 	 Score : 0.993000
EPOCH : 8


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.008455 	 Score : 0.995000
EPOCH : 9


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.006892 	 Score : 0.992000
EPOCH : 10


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.004820 	 Score : 0.991000
EPOCH : 11


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.005815 	 Score : 0.993000
EPOCH : 12


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.007244 	 Score : 0.995000
EPOCH : 13


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.004890 	 Score : 0.990000
EPOCH : 14


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.005413 	 Score : 0.987000
EPOCH : 15


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.004098 	 Score : 0.994000
EPOCH : 16


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.006059 	 Score : 0.989000
EPOCH : 17


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.003304 	 Score : 0.994000
EPOCH : 18


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.004258 	 Score : 0.994000
EPOCH : 19


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.001963 	 Score : 0.997000
EPOCH : 20


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.004500 	 Score : 0.991000
EPOCH : 21


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.004409 	 Score : 0.997000
EPOCH : 22


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.004897 	 Score : 0.993000
EPOCH : 23


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.004509 	 Score : 0.992000
EPOCH : 24


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.001879 	 Score : 0.993000
EPOCH : 25


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.005281 	 Score : 0.997000
EPOCH : 26


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.001934 	 Score : 0.996000
EPOCH : 27


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.002417 	 Score : 0.995000
EPOCH : 28


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.002821 	 Score : 0.982000
EPOCH : 29


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.002993 	 Score : 0.993000
EPOCH : 30


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.003226 	 Score : 0.993000
EPOCH : 31


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.003206 	 Score : 0.994000
EPOCH : 32


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.002971 	 Score : 0.996000
EPOCH : 33


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.003070 	 Score : 0.994000
EPOCH : 34


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.003132 	 Score : 0.994000
EPOCH : 35


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.001257 	 Score : 0.997000
EPOCH : 36


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.000687 	 Score : 0.994000
EPOCH : 37


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.004666 	 Score : 0.997000
EPOCH : 38


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.002988 	 Score : 0.996000
EPOCH : 39


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.002511 	 Score : 0.998000
EPOCH : 40


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.000728 	 Score : 0.997000
EPOCH : 41


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.000698 	 Score : 1.000000
EPOCH : 42


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.003159 	 Score : 0.998000
EPOCH : 43


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.001587 	 Score : 0.998000
EPOCH : 44


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.001840 	 Score : 0.998000
EPOCH : 45


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.002281 	 Score : 0.998000
EPOCH : 46


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.003033 	 Score : 0.996000
EPOCH : 47


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.001849 	 Score : 0.997000
EPOCH : 48


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.000208 	 Score : 0.998000
EPOCH : 49


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.000242 	 Score : 0.997000
EPOCH : 50


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.003915 	 Score : 0.994000
EPOCH : 51


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.002452 	 Score : 0.997000
EPOCH : 52


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.001514 	 Score : 0.996000
EPOCH : 53


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.001448 	 Score : 0.995000
EPOCH : 54


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.000604 	 Score : 0.997000
EPOCH : 55


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.001759 	 Score : 0.998000
EPOCH : 56


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.001171 	 Score : 0.996000
EPOCH : 57


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.002958 	 Score : 0.997000
EPOCH : 58


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.002081 	 Score : 0.996000
EPOCH : 59


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.000726 	 Score : 0.996000
EPOCH : 60


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.000085 	 Score : 0.998000
EPOCH : 61


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.000265 	 Score : 0.998000
EPOCH : 62


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.000798 	 Score : 0.998000
EPOCH : 63


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.005710 	 Score : 0.998000
EPOCH : 64


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.000812 	 Score : 0.999000
EPOCH : 65


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.000194 	 Score : 0.999000
EPOCH : 66


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.001352 	 Score : 0.996000
EPOCH : 67


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.001602 	 Score : 0.998000
EPOCH : 68


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.002718 	 Score : 0.997000
EPOCH : 69


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.000959 	 Score : 0.995000
EPOCH : 70


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.001237 	 Score : 0.996000
EPOCH : 71


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.001992 	 Score : 0.995000
EPOCH : 72


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.000802 	 Score : 0.998000
EPOCH : 73


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.000086 	 Score : 0.998000
EPOCH : 74


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.001233 	 Score : 0.998000
EPOCH : 75


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.003183 	 Score : 0.996000
EPOCH : 76


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.000524 	 Score : 0.997000
EPOCH : 77


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.000945 	 Score : 0.997000
EPOCH : 78


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.002110 	 Score : 0.992000
EPOCH : 79


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.001434 	 Score : 0.994000
EPOCH : 80


  0%|          | 0/461 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Loss : 0.001026 	 Score : 0.999000
