<a href="https://colab.research.google.com/github/pooriaazami/deep_learning_class_notebooks/blob/main/CIFAR10_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

source: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as T

from torch.utils.data import DataLoader

from tqdm.notebook import tqdm

In [2]:
# Hyper-parameters and runtime configuration; everything a reader may want to tune.
EPOCHS = 10        # number of full passes over the training set
BATCH_SIZE = 32    # samples per mini-batch for both data loaders
LR = 1e-3          # learning rate handed to the optimizer
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'  # prefer the GPU when one is visible

# Weight initialisation and batch shuffling are random; seed the global
# PyTorch generator so that "Restart & Run All" gives repeatable runs.
SEED = 42
torch.manual_seed(SEED)

In [30]:
# Per-sample preprocessing applied by the datasets:
#   1. convert the image to a tensor,
#   2. normalise each of the three channels with mean 0.5 and std 0.5.
transforms = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [4]:
# CIFAR-10 training split, stored in (and downloaded to, if missing) the
# current directory; every sample goes through `transforms`.
train_dataset = torchvision.datasets.CIFAR10(
    root='.', train=True, download=True, transform=transforms
)

# Held-out test split, same storage location and preprocessing.
test_dataset = torchvision.datasets.CIFAR10(
    root='.', train=False, download=True, transform=transforms
)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Training batches are reshuffled on every pass over the loader.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Accuracy does not depend on sample order, so the test set is iterated in a
# fixed order; this keeps evaluation deterministic and comparable across runs.
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [6]:
class Network(nn.Module):
  """Small VGG-style CNN: three conv/conv/max-pool stages followed by two linear layers.

  Every convolution uses a 3x3 kernel with padding 1 (spatial size preserved)
  and every stage ends with a 2x2 max-pool (spatial size halved). The first
  linear layer expects 1024 features, i.e. 64 channels x 4 x 4, which is what
  a 3-channel 32x32 input produces after the three pooling steps.

  Args:
    num_classes: number of output logits (defaults to 10).
  """

  def __init__(self, num_classes: int = 10):
    super().__init__()

    # Stage 1: 3 -> 16 channels
    self.conv_1_1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
    self.conv_1_2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1)

    # Stage 2: 16 -> 32 channels
    self.conv_2_1 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
    self.conv_2_2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)

    # Stage 3: 32 -> 64 channels
    self.conv_3_1 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
    self.conv_3_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)

    # Classifier head
    self.fc1 = nn.Linear(in_features=1024, out_features=128)
    self.fc2 = nn.Linear(in_features=128, out_features=num_classes)

  def forward(self, x):
    """Return raw class logits of shape (batch, num_classes) for an image batch `x`.

    Raises:
      RuntimeError: if the flattened feature size does not match `fc1`
        (for example when the input images are not 32x32).
    """
    x = F.relu(self.conv_1_1(x))
    x = F.relu(self.conv_1_2(x))
    x = F.max_pool2d(x, kernel_size=2)

    x = F.relu(self.conv_2_1(x))
    x = F.relu(self.conv_2_2(x))
    x = F.max_pool2d(x, kernel_size=2)

    x = F.relu(self.conv_3_1(x))
    x = F.relu(self.conv_3_2(x))
    x = F.max_pool2d(x, kernel_size=2)

    # Flatten everything except the batch dimension. Unlike `view(-1, 1024)`,
    # this can never fold surplus features into the batch axis, so a wrongly
    # sized input fails loudly at `fc1` instead of silently changing the
    # number of rows in the output.
    x = torch.flatten(x, start_dim=1)

    x = F.relu(self.fc1(x))
    x = self.fc2(x)

    return x

In [7]:
# Build the model, then move its parameters to the configured device.
network = Network()
network = network.to(DEVICE)

In [8]:
# Adam over every parameter of the network, using the configured learning rate.
optimizer = optim.Adam(params=network.parameters(), lr=LR)

In [9]:
# Cross-entropy on raw logits, averaged over the samples of a batch.
loss_function = nn.CrossEntropyLoss(reduction='mean')

In [10]:
# Training: EPOCHS passes over the training loader, reporting the summed
# batch loss and the training accuracy after each pass.
for i in range(1, EPOCHS+1):
  # Make sure the model is in training mode at the start of every epoch,
  # even if an evaluation cell switched it to eval mode in between.
  network.train()

  total_loss = .0      # sum of the per-batch loss values seen this epoch
  total_corrects = 0   # number of correctly classified training samples
  print(f'Epoch {i}:')
  for x, y in tqdm(train_dataloader):
    x = x.to(DEVICE)
    y = y.to(DEVICE)

    optimizer.zero_grad()

    preds = network(x)
    loss = loss_function(preds, y)
    loss.backward()
    optimizer.step()

    # Compare on the device and pull out plain Python numbers; this avoids
    # copying whole prediction/label tensors to the CPU on every batch.
    total_corrects += (preds.argmax(dim=-1) == y).sum().item()
    total_loss += loss.item()

  accuracy = total_corrects / len(train_dataset) * 100
  print(f'loss: {total_loss:.2f} accuracy: {accuracy:.2f} %')

Epoch 1:


  0%|          | 0/1563 [00:00<?, ?it/s]

loss: 2552.98 accuracy: 39.93 %
Epoch 2:


  0%|          | 0/1563 [00:00<?, ?it/s]

loss: 1911.36 accuracy: 55.69 %
Epoch 3:


  0%|          | 0/1563 [00:00<?, ?it/s]

loss: 1609.17 accuracy: 63.40 %
Epoch 4:


  0%|          | 0/1563 [00:00<?, ?it/s]

loss: 1405.82 accuracy: 68.17 %
Epoch 5:


  0%|          | 0/1563 [00:00<?, ?it/s]

loss: 1253.16 accuracy: 71.67 %
Epoch 6:


  0%|          | 0/1563 [00:00<?, ?it/s]

loss: 1131.23 accuracy: 74.37 %
Epoch 7:


  0%|          | 0/1563 [00:00<?, ?it/s]

loss: 1036.97 accuracy: 76.59 %
Epoch 8:


  0%|          | 0/1563 [00:00<?, ?it/s]

loss: 944.41 accuracy: 78.57 %
Epoch 9:


  0%|          | 0/1563 [00:00<?, ?it/s]

loss: 882.76 accuracy: 79.71 %
Epoch 10:


  0%|          | 0/1563 [00:00<?, ?it/s]

loss: 813.16 accuracy: 81.52 %


In [29]:
# Evaluation on the held-out test set.
# Switch to inference mode first: a no-op for the layers used by this network,
# but it keeps the cell correct if dropout or batch-norm layers are added.
network.eval()
with torch.no_grad():  # no gradients are needed for scoring
  total_corrects = 0
  for x, y in tqdm(test_dataloader):
    x = x.to(DEVICE)
    y = y.to(DEVICE)

    preds = network(x)
    # Count matches on the device and add them up as a plain Python int.
    total_corrects += (preds.argmax(dim=-1) == y).sum().item()

  accuracy = total_corrects / len(test_dataset) * 100
  print(f'Test accuracy: {accuracy:.2f} %')

  0%|          | 0/313 [00:00<?, ?it/s]

Test accuracy: 71.79 %


In [12]:
# Persist only the parameters (state dict) of the network to 'model.pth'.
torch.save(obj=network.state_dict(), f='model.pth')

In [15]:
# Drop the in-memory model so the next cell has to rebuild it from the saved file.
del network

In [28]:
# Rebuild the architecture and restore the parameters saved in 'model.pth'.
network = Network()
# `map_location` remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU machine still loads on a CPU-only one.
network.load_state_dict(torch.load('model.pth', map_location=DEVICE))
network = network.to(DEVICE)

In [25]:
# list(filter(lambda x: 'state_dict' in x, dir(network)))