In [3]:
import torch
import torch.nn as nn

In [4]:
class myLenet(nn.Module):
  """LeNet-style CNN: two conv -> max-pool -> ReLU stages, then two linear layers.

  ``FC1`` expects 16*5*5 flattened features, which is what the two
  (5x5 conv, 2x2 pool) stages produce for 32x32 inputs; other spatial
  sizes will fail at ``FC1``.

  Args:
    imdim: number of input channels (default 1, the original behaviour).
    num_classes: number of output logits (default 10, the original behaviour).
  """

  def __init__(self, imdim=1, num_classes=10):
    super().__init__()
    # Stage 1: imdim -> 6 feature maps, 5x5 kernel, then 2x2 down-sampling.
    self.layer1 = nn.Conv2d(in_channels=imdim,out_channels=6,kernel_size=5,stride=1)
    self.pool1 = nn.MaxPool2d(kernel_size=2,stride=2)
    # NOTE: the `tan*` attributes are ReLU activations despite their name;
    # the attribute names are kept so existing code that reads them still works.
    self.tan1 = nn.ReLU(inplace=True)
    # Stage 2: 6 -> 16 feature maps, 5x5 kernel, then 2x2 down-sampling.
    self.layer2 = nn.Conv2d(6,16,5)
    self.pool2 = nn.MaxPool2d(2,2)
    self.tan2 = nn.ReLU(inplace=True)
    # Classifier head on the flattened feature maps.
    self.flatten = nn.Flatten()
    self.FC1 = nn.Linear(16*5*5,84)
    self.tan3 = nn.ReLU(inplace=True)
    self.FC2 = nn.Linear(84,num_classes)

  def forward(self,x):
    """Return raw (un-normalized) class scores of shape (batch, num_classes)."""
    x = self.tan1(self.pool1(self.layer1(x)))
    x = self.tan2(self.pool2(self.layer2(x)))
    x = self.flatten(x)
    x = self.tan3(self.FC1(x))
    x = self.FC2(x)

    return x


In [5]:
from torchsummary import summary

# LeNet for CIFAR-10

In [8]:
import os

import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt

from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torch.optim import Adam
import torchvision
from torchsummary import summary

# Seed PyTorch's global random number generator with a fixed value.
torch.manual_seed(1)

# Use the GPU when PyTorch reports one as available, otherwise the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

In [9]:
# Per-image preprocessing: convert to a tensor, then normalize each of the
# three channels with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])

In [10]:
batch_size = 256

# Training split (downloaded to ./data if missing); batches are shuffled.
trainset = CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2)

# Test split; batches are served in a fixed order.
testset = CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2)

# Human-readable class names -- presumably in CIFAR-10 label-index order (TODO confirm).
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 81903104.62it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


# Model

In [11]:
class lenetCNN(nn.Module):
  """Wider LeNet-style CNN: two conv -> max-pool -> ReLU stages, then two linear layers.

  ``FC1`` expects 128*5*5 flattened features, which is what the two
  (5x5 conv, 2x2 pool) stages produce for 32x32 inputs; other spatial
  sizes will fail at ``FC1``.

  Args:
    imdim: number of input channels (default 3).
    num_classes: number of output logits (default 10).
  """

  def __init__(self, imdim=3, num_classes=10):
    super().__init__()
    # Stage 1: imdim -> 64 feature maps.
    self.layer1 = nn.Conv2d(imdim,64,5)
    self.mp1 = nn.MaxPool2d(2)
    self.ac1 = nn.ReLU(inplace=True)
    # Stage 2: 64 -> 128 feature maps.
    self.layer2 = nn.Conv2d(64,128,5)
    self.mp2 = nn.MaxPool2d(2)
    self.ac2 = nn.ReLU(inplace=True)
    # Classifier head on the flattened feature maps.
    self.flatten = nn.Flatten()
    self.FC1 = nn.Linear(128*5*5,1024)
    self.ac3 = nn.ReLU()
    # Honour `num_classes`; the output width was previously hard-coded to 10,
    # so the constructor argument was silently ignored.
    self.FC2 = nn.Linear(1024,num_classes)

  def forward(self,x):
    """Return raw (un-normalized) class scores of shape (batch, num_classes)."""
    x = self.ac1(self.mp1(self.layer1(x)))
    x = self.ac2(self.mp2(self.layer2(x)))
    x = self.flatten(x)
    x = self.ac3(self.FC1(x))
    x = self.FC2(x)
    return x


In [15]:
# Build the network and move it to the target device right away, so that any
# optimizer constructed afterwards is bound to the device-resident parameters.
model = lenetCNN().to(device)

# Train

In [16]:
# Cross-entropy loss applied to the model's raw output scores.
criterion = nn.CrossEntropyLoss()

# Adam over every parameter of `model`, with a learning rate of 0.001.
optimizer = Adam(params=model.parameters(), lr=0.001)

In [17]:
# Number of passes over the training data.
max_epoch = 10

# Empty containers, presumably meant to hold per-epoch loss and accuracy values.
losses_total, accuracy = [], []

In [19]:
# Place the network on the selected device before training.
model = model.to(device=device)

In [20]:
# Train for `max_epoch` passes over `trainloader`. After each epoch the mean
# batch loss and the training accuracy (in percent) are appended to the
# `losses_total` / `accuracy` history lists and printed.
for epoch in range(max_epoch):
  model.train()
  n_correct, n_seen = 0, 0
  batch_losses = []
  for inputs, labels in trainloader:
    inputs = inputs.to(device)
    labels = labels.to(device)

    optimizer.zero_grad()
    logits = model(inputs)
    loss = criterion(logits, labels)
    loss.backward()
    optimizer.step()

    batch_losses.append(loss.item())

    # Running training accuracy: compare the highest-scoring class to the label.
    # `detach()` replaces the legacy `.data` access.
    predicted = logits.detach().argmax(dim=1)
    n_correct += (predicted == labels).sum().item()
    n_seen += labels.size(0)

  # These are *training* metrics, measured while the weights are being updated.
  epoch_acc = 100 * n_correct / n_seen
  epoch_loss = sum(batch_losses) / len(batch_losses)

  # Record the history instead of overwriting the `accuracy` list with a float.
  accuracy.append(epoch_acc)
  losses_total.append(epoch_loss)
  print(f"epoch : {epoch} | acc : {epoch_acc} | loss : {epoch_loss}")


epoch : 0 | acc : 45.07 | loss : 1.5018063187599182
epoch : 1 | acc : 60.342 | loss : 1.1077983324625054
epoch : 2 | acc : 67.896 | loss : 0.9082521008593696
epoch : 3 | acc : 72.72 | loss : 0.781157235405883
epoch : 4 | acc : 76.012 | loss : 0.6803731620311737
epoch : 5 | acc : 79.37 | loss : 0.585021584313743
epoch : 6 | acc : 83.14 | loss : 0.48300820953991946
epoch : 7 | acc : 86.26 | loss : 0.39195807567056345
epoch : 8 | acc : 89.694 | loss : 0.3026476597451434
epoch : 9 | acc : 91.994 | loss : 0.23697072001440184
