In [1]:
import torch
import tqdm
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [2]:
# Run configuration: compute device plus the training hyperparameters used below.
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
epochs = 20  # number of train + validation passes
batch_size = 64  # samples per DataLoader batch
lr = 1e-3  # learning rate handed to the optimizer

In [3]:
# Display the selected device to confirm whether CUDA was picked up.
device

device(type='cuda')

In [4]:
# Input pipeline: tensor conversion only; no normalization or augmentation is applied.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)

In [None]:
# CIFAR-10 stored under data/ (fetched if missing). The train=False split
# serves as the validation set; both splits share the same transform.
train_data = datasets.CIFAR10(
    root="data/",
    train=True,
    transform=transform,
    download=True,
)
val_data = datasets.CIFAR10(
    root="data/",
    train=False,
    transform=transform,
    download=True,
)

In [6]:
# Batch iterators: training data is reshuffled every epoch, validation order is fixed.
train_batches = DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,
)
val_batches = DataLoader(
    dataset=val_data,
    batch_size=batch_size,
    shuffle=False,
)

In [7]:
# Number of target classes in the training set.
len(train_data.classes)

10

In [8]:
class Net(nn.Module):
    """Minimal image classifier: one 3x3 convolution feeding a linear head.

    The linear layer is sized for 32 feature maps of 30x30, i.e. the result of
    the unpadded 3x3 convolution on a 3-channel 32x32 input. It emits 10
    unnormalized class scores per sample.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3))
        self.fc1 = nn.Linear(in_features=32 * 30 * 30, out_features=10)

        # Stateless layers, registered as modules so train()/eval() toggles dropout.
        self.dropout = nn.Dropout(p=0.5)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a batch of images to a (batch, 10) tensor of class scores."""
        features = self.conv1(x)
        features = self.dropout(features)
        features = self.relu(features)

        # Collapse everything except the batch dimension for the linear head.
        batch = features.shape[0]
        flat = features.view(batch, -1)

        return self.fc1(flat)

In [9]:
# Build the model, then move its parameters onto the selected device.
net = Net()
net = net.to(device)

In [10]:
# Smoke test: push one random 3x32x32 image through the net and inspect the output shape.
inp = torch.randn((1, 3, 32, 32)).to(device)
output = net(inp)
output.size()

torch.Size([1, 10])

In [11]:
# Count the scalar weights that will receive gradient updates.
num_parameters = sum(
    param.numel()
    for param in net.parameters()
    if param.requires_grad
)
num_parameters

288906

In [12]:
# Training machinery: Adam over all model parameters, cross-entropy on the raw
# class scores, and a gradient scaler for the mixed-precision training step.
opt = torch.optim.Adam(params=net.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()

In [13]:
def get_accuracy(preds, y):
    """Return the fraction of samples whose top-scoring class equals the label.

    Args:
        preds: tensor of per-class scores, expected shape (batch, num_classes);
            the predicted class is the argmax over dimension 1.
        y: tensor of integer class labels, shape (batch,), on the same device
            as ``preds``.

    Returns:
        The accuracy as a Python float in [0, 1] (``nan`` for an empty batch).
    """
    predicted = preds.argmax(dim=1)
    # Averaging the 0/1 hit mask happens on whatever device the inputs live on,
    # so no global ``device`` lookup or extra host-to-device copy is needed.
    return predicted.eq(y).float().mean().item()

In [14]:
def loop(net, batches, train):
    """Run one full pass over ``batches`` and return the mean loss and accuracy.

    Relies on the notebook-level ``device``, ``loss_fn``, ``opt``, ``scaler``
    and ``get_accuracy``.

    Args:
        net: model to run; put into train or eval mode according to ``train``.
        batches: sized iterable of ``(X, y)`` tensor pairs (e.g. a DataLoader).
        train: when True, run the forward pass under CUDA autocast and update
            the weights through ``opt``/``scaler``; when False, evaluate in
            full precision with gradient tracking disabled.

    Returns:
        ``(mean_loss, mean_accuracy)`` as Python floats, averaged per sample
        over the whole pass.

    Raises:
        ValueError: if ``batches`` yields no samples.
    """
    print("Train Loop:" if train else "Validation Loop:")
    print("")
    net.train(train)  # train(False) is equivalent to eval()

    loss_total = 0.0
    acc_total = 0.0
    n_samples = 0

    # One loop serves both modes: gradients and autocast are only active when training.
    with torch.set_grad_enabled(train):
        for X, y in tqdm.tqdm(batches, total=len(batches)):
            X = X.to(device)
            y = y.to(device)

            with torch.cuda.amp.autocast(enabled=train):
                preds = net(X)
                loss = loss_fn(preds, y)
                acc = get_accuracy(preds, y)

            if train:
                opt.zero_grad()
                scaler.scale(loss).backward()
                scaler.step(opt)
                scaler.update()

            # loss and acc are per-batch means; weight them by the batch size so
            # a short final batch does not count as much as a full one.
            n = y.shape[0]
            loss_total += loss.item() * n
            acc_total += acc * n
            n_samples += n

    print("")
    print("")

    if n_samples == 0:
        raise ValueError("loop() received no samples: `batches` is empty")

    return loss_total / n_samples, acc_total / n_samples

In [15]:
# Each epoch: one optimisation pass over the training set, then one evaluation
# pass over the validation set, followed by a one-line metric summary.
for epoch in range(epochs):
    train_loss, train_acc = loop(net, batches=train_batches, train=True)
    val_loss, val_acc = loop(net, batches=val_batches, train=False)

    report = f"epoch: {epoch} | train_loss: {train_loss:.4f} | train_acc: {train_acc:.4f} | val_loss: {val_loss:.4f} | val_acc: {val_acc:.4f}"
    print(report)
    print("")

  0%|          | 0/782 [00:00<?, ?it/s]

Train Loop:



100%|██████████| 782/782 [00:08<00:00, 95.78it/s] 
  9%|▉         | 14/157 [00:00<00:01, 138.95it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 137.86it/s]
  1%|▏         | 11/782 [00:00<00:07, 105.35it/s]



epoch: 0 | train_loss: 1.5285 | train_acc: 0.4650 | val_loss: 1.3442 | val_acc: 0.5213

Train Loop:



100%|██████████| 782/782 [00:07<00:00, 101.12it/s]
  9%|▉         | 14/157 [00:00<00:01, 131.72it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 128.58it/s]
  1%|▏         | 10/782 [00:00<00:08, 93.25it/s]



epoch: 1 | train_loss: 1.2899 | train_acc: 0.5486 | val_loss: 1.2712 | val_acc: 0.5545

Train Loop:



100%|██████████| 782/782 [00:07<00:00, 99.48it/s] 
 10%|▉         | 15/157 [00:00<00:00, 142.74it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 140.30it/s]
  1%|▏         | 11/782 [00:00<00:07, 103.66it/s]



epoch: 2 | train_loss: 1.2220 | train_acc: 0.5777 | val_loss: 1.2415 | val_acc: 0.5655

Train Loop:



100%|██████████| 782/782 [00:08<00:00, 95.04it/s]
 10%|▉         | 15/157 [00:00<00:01, 140.76it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 130.78it/s]
  1%|▏         | 11/782 [00:00<00:07, 101.79it/s]



epoch: 3 | train_loss: 1.1644 | train_acc: 0.5935 | val_loss: 1.2051 | val_acc: 0.5790

Train Loop:



100%|██████████| 782/782 [00:07<00:00, 97.98it/s]
  7%|▋         | 11/157 [00:00<00:01, 109.16it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 121.90it/s]
  1%|          | 9/782 [00:00<00:08, 89.82it/s]



epoch: 4 | train_loss: 1.1224 | train_acc: 0.6083 | val_loss: 1.2163 | val_acc: 0.5729

Train Loop:



100%|██████████| 782/782 [00:07<00:00, 101.75it/s]
  9%|▉         | 14/157 [00:00<00:01, 139.02it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 131.03it/s]
  1%|          | 9/782 [00:00<00:08, 89.44it/s]



epoch: 5 | train_loss: 1.0865 | train_acc: 0.6223 | val_loss: 1.1838 | val_acc: 0.5920

Train Loop:



100%|██████████| 782/782 [00:08<00:00, 96.41it/s]
  8%|▊         | 12/157 [00:00<00:01, 119.90it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 135.15it/s]
  1%|▏         | 11/782 [00:00<00:07, 102.61it/s]



epoch: 6 | train_loss: 1.0496 | train_acc: 0.6342 | val_loss: 1.2007 | val_acc: 0.5814

Train Loop:



100%|██████████| 782/782 [00:08<00:00, 97.57it/s]
  8%|▊         | 12/157 [00:00<00:01, 114.38it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 122.67it/s]
  1%|▏         | 11/782 [00:00<00:07, 101.74it/s]



epoch: 7 | train_loss: 1.0243 | train_acc: 0.6448 | val_loss: 1.1600 | val_acc: 0.6019

Train Loop:



100%|██████████| 782/782 [00:07<00:00, 99.80it/s] 
 10%|▉         | 15/157 [00:00<00:00, 142.84it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 125.51it/s]
  1%|          | 9/782 [00:00<00:08, 87.67it/s]



epoch: 8 | train_loss: 0.9956 | train_acc: 0.6531 | val_loss: 1.1759 | val_acc: 0.5944

Train Loop:



100%|██████████| 782/782 [00:07<00:00, 101.51it/s]
  8%|▊         | 13/157 [00:00<00:01, 128.38it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 139.36it/s]
  1%|▏         | 11/782 [00:00<00:07, 102.87it/s]



epoch: 9 | train_loss: 0.9754 | train_acc: 0.6612 | val_loss: 1.2397 | val_acc: 0.5782

Train Loop:



100%|██████████| 782/782 [00:07<00:00, 99.97it/s]
  8%|▊         | 12/157 [00:00<00:01, 118.34it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 137.22it/s]
  1%|▏         | 10/782 [00:00<00:07, 99.56it/s]



epoch: 10 | train_loss: 0.9571 | train_acc: 0.6679 | val_loss: 1.1567 | val_acc: 0.6006

Train Loop:



100%|██████████| 782/782 [00:07<00:00, 100.37it/s]
  9%|▉         | 14/157 [00:00<00:01, 139.88it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 136.82it/s]
  1%|▏         | 10/782 [00:00<00:07, 99.87it/s]



epoch: 11 | train_loss: 0.9356 | train_acc: 0.6748 | val_loss: 1.1873 | val_acc: 0.5970

Train Loop:



100%|██████████| 782/782 [00:07<00:00, 101.30it/s]
  9%|▉         | 14/157 [00:00<00:01, 137.52it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 123.88it/s]
  1%|▏         | 10/782 [00:00<00:07, 99.17it/s]



epoch: 12 | train_loss: 0.9199 | train_acc: 0.6799 | val_loss: 1.2043 | val_acc: 0.5947

Train Loop:



100%|██████████| 782/782 [00:07<00:00, 99.13it/s] 
 10%|▉         | 15/157 [00:00<00:00, 142.84it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 138.66it/s]
  1%|▏         | 11/782 [00:00<00:07, 103.61it/s]



epoch: 13 | train_loss: 0.9031 | train_acc: 0.6841 | val_loss: 1.1801 | val_acc: 0.6027

Train Loop:



100%|██████████| 782/782 [00:08<00:00, 97.55it/s] 
 10%|▉         | 15/157 [00:00<00:00, 144.20it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 129.47it/s]
  1%|▏         | 10/782 [00:00<00:07, 96.75it/s]



epoch: 14 | train_loss: 0.8897 | train_acc: 0.6884 | val_loss: 1.1900 | val_acc: 0.5991

Train Loop:



100%|██████████| 782/782 [00:08<00:00, 97.39it/s]
  9%|▉         | 14/157 [00:00<00:01, 139.95it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 136.83it/s]
  1%|          | 9/782 [00:00<00:08, 89.56it/s]



epoch: 15 | train_loss: 0.8753 | train_acc: 0.6922 | val_loss: 1.1557 | val_acc: 0.6088

Train Loop:



100%|██████████| 782/782 [00:07<00:00, 98.85it/s]
 10%|▉         | 15/157 [00:00<00:00, 143.26it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 137.95it/s]
  1%|▏         | 10/782 [00:00<00:08, 91.00it/s]



epoch: 16 | train_loss: 0.8626 | train_acc: 0.6987 | val_loss: 1.1642 | val_acc: 0.6075

Train Loop:



100%|██████████| 782/782 [00:07<00:00, 98.51it/s] 
  9%|▉         | 14/157 [00:00<00:01, 139.77it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 141.23it/s]
  1%|▏         | 11/782 [00:00<00:07, 103.57it/s]



epoch: 17 | train_loss: 0.8519 | train_acc: 0.7019 | val_loss: 1.1869 | val_acc: 0.6052

Train Loop:



100%|██████████| 782/782 [00:07<00:00, 99.11it/s] 
 10%|▉         | 15/157 [00:00<00:00, 145.28it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 132.97it/s]
  1%|▏         | 10/782 [00:00<00:08, 91.76it/s]



epoch: 18 | train_loss: 0.8349 | train_acc: 0.7070 | val_loss: 1.1629 | val_acc: 0.6064

Train Loop:



100%|██████████| 782/782 [00:08<00:00, 96.24it/s]
  9%|▉         | 14/157 [00:00<00:01, 138.05it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 140.42it/s]



epoch: 19 | train_loss: 0.8292 | train_acc: 0.7099 | val_loss: 1.1834 | val_acc: 0.6039




