In [1]:
import torch
import tqdm
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import models, datasets, transforms

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 20
batch_size = 64
lr = 1e-3

In [3]:
device

device(type='cuda')

In [4]:
transform = transforms.Compose([transforms.ToTensor()])

In [5]:
train_data = datasets.CIFAR10("data/", train=True, download=True, transform=transform)
val_data = datasets.CIFAR10("data/", train=False, download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
train_batches = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_batches = DataLoader(val_data, batch_size=batch_size, shuffle=False)

In [7]:
len(train_data.classes)

10

In [8]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=(3, 3))
        self.fc1 = nn.Linear(32 * 30 * 30, 10)

        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.dropout(self.conv1(x))
        x = self.relu(x)

        x = x.view(x.shape[0], -1)
        x = self.fc1(x)

        return x

In [9]:
net = Net().to(device)

In [10]:
inp = torch.randn(1, 3, 32, 32).to(device)
output = net(inp)
output.shape

torch.Size([1, 10])

In [11]:
num_parameters = sum(p.numel() for p in net.parameters() if p.requires_grad)
num_parameters

288906

In [12]:
opt = torch.optim.Adam(net.parameters(), lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.1, patience=2, verbose=True)
loss_fn = nn.CrossEntropyLoss()

In [13]:
def get_accuracy(preds, y):
    preds = preds.argmax(dim=1, keepdim=True)
    correct = preds.squeeze(1).eq(y)
    acc = correct.sum() / torch.FloatTensor([y.shape[0]]).to(device)

    return acc.item()

In [14]:
def loop(net, batches, train):
    batch_losses = []
    batch_accs = []

    if train:
        print("Train Loop:")
        print("")
        net.train()

        for X, y in tqdm.tqdm(batches, total=len(batches)):
            X = X.to(device)
            y = y.to(device)

            preds = net(X)
            loss = loss_fn(preds, y)
            acc = get_accuracy(preds, y)

            opt.zero_grad()
            loss.backward()
            opt.step()

            batch_losses.append(loss.item())
            batch_accs.append(acc)

    else:
        print("Validation Loop:")
        print("")
        net.eval()

        with torch.no_grad():
            for X, y in tqdm.tqdm(batches, total=len(batches)):
                X = X.to(device)
                y = y.to(device)

                preds = net(X)
                loss = loss_fn(preds, y)
                acc = get_accuracy(preds, y)

                batch_losses.append(loss.item())
                batch_accs.append(acc) 

    print("")
    print("")
    
    return sum(batch_losses) / len(batch_losses), sum(batch_accs) / len(batch_accs)

In [15]:
for epoch in range(epochs):
    train_loss, train_acc = loop(net, train_batches, True)
    val_loss, val_acc = loop(net, val_batches, False)

    scheduler.step(train_loss)
    print(f"epoch: {epoch} | train_loss: {train_loss:.4f} | train_acc: {train_acc:.4f} | val_loss: {val_loss:.4f} | val_acc: {val_acc:.4f}")
    print("")

  1%|▏         | 11/782 [00:00<00:07, 106.46it/s]

Train Loop:



100%|██████████| 782/782 [00:06<00:00, 112.78it/s]
  9%|▉         | 14/157 [00:00<00:01, 137.66it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 139.32it/s]
  2%|▏         | 12/782 [00:00<00:06, 114.27it/s]



epoch: 0 | train_loss: 1.5199 | train_acc: 0.4682 | val_loss: 1.3216 | val_acc: 0.5304

Train Loop:



100%|██████████| 782/782 [00:06<00:00, 112.21it/s]
 10%|▉         | 15/157 [00:00<00:01, 141.55it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 136.90it/s]
  2%|▏         | 12/782 [00:00<00:06, 115.52it/s]



epoch: 1 | train_loss: 1.2682 | train_acc: 0.5576 | val_loss: 1.2408 | val_acc: 0.5700

Train Loop:



100%|██████████| 782/782 [00:06<00:00, 112.35it/s]
 10%|▉         | 15/157 [00:00<00:00, 142.05it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 133.13it/s]
  1%|▏         | 11/782 [00:00<00:07, 107.53it/s]



epoch: 2 | train_loss: 1.1850 | train_acc: 0.5886 | val_loss: 1.2469 | val_acc: 0.5703

Train Loop:



100%|██████████| 782/782 [00:06<00:00, 112.00it/s]
  9%|▉         | 14/157 [00:00<00:01, 138.78it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 135.36it/s]
  2%|▏         | 12/782 [00:00<00:06, 110.72it/s]



epoch: 3 | train_loss: 1.1301 | train_acc: 0.6084 | val_loss: 1.1741 | val_acc: 0.5942

Train Loop:



100%|██████████| 782/782 [00:07<00:00, 111.12it/s]
  9%|▉         | 14/157 [00:00<00:01, 136.17it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 133.80it/s]
  2%|▏         | 12/782 [00:00<00:06, 112.38it/s]



epoch: 4 | train_loss: 1.0786 | train_acc: 0.6244 | val_loss: 1.1530 | val_acc: 0.6015

Train Loop:



100%|██████████| 782/782 [00:07<00:00, 110.05it/s]
  9%|▉         | 14/157 [00:00<00:01, 139.12it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 134.35it/s]
  1%|▏         | 11/782 [00:00<00:07, 106.16it/s]



epoch: 5 | train_loss: 1.0449 | train_acc: 0.6385 | val_loss: 1.1417 | val_acc: 0.6038

Train Loop:



100%|██████████| 782/782 [00:07<00:00, 111.64it/s]
  9%|▉         | 14/157 [00:00<00:01, 134.95it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 137.96it/s]
  1%|▏         | 11/782 [00:00<00:07, 108.07it/s]



epoch: 6 | train_loss: 1.0038 | train_acc: 0.6525 | val_loss: 1.1500 | val_acc: 0.5954

Train Loop:



100%|██████████| 782/782 [00:06<00:00, 113.22it/s]
 10%|▉         | 15/157 [00:00<00:00, 142.21it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 139.23it/s]
  1%|▏         | 11/782 [00:00<00:07, 109.99it/s]



epoch: 7 | train_loss: 0.9764 | train_acc: 0.6609 | val_loss: 1.1369 | val_acc: 0.6115

Train Loop:



100%|██████████| 782/782 [00:06<00:00, 112.77it/s]
  9%|▉         | 14/157 [00:00<00:01, 136.53it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 138.19it/s]
  2%|▏         | 12/782 [00:00<00:06, 117.17it/s]



epoch: 8 | train_loss: 0.9502 | train_acc: 0.6703 | val_loss: 1.1353 | val_acc: 0.6046

Train Loop:



100%|██████████| 782/782 [00:06<00:00, 112.75it/s]
 10%|▉         | 15/157 [00:00<00:00, 143.53it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 138.26it/s]
  1%|▏         | 11/782 [00:00<00:07, 109.17it/s]



epoch: 9 | train_loss: 0.9344 | train_acc: 0.6759 | val_loss: 1.1440 | val_acc: 0.6052

Train Loop:



100%|██████████| 782/782 [00:06<00:00, 112.35it/s]
 10%|▉         | 15/157 [00:00<00:01, 141.13it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 139.26it/s]
  2%|▏         | 12/782 [00:00<00:06, 113.53it/s]



epoch: 10 | train_loss: 0.9081 | train_acc: 0.6847 | val_loss: 1.1265 | val_acc: 0.6102

Train Loop:



100%|██████████| 782/782 [00:06<00:00, 112.50it/s]
 10%|▉         | 15/157 [00:00<00:00, 142.11it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 138.16it/s]
  2%|▏         | 12/782 [00:00<00:06, 112.73it/s]



epoch: 11 | train_loss: 0.8923 | train_acc: 0.6909 | val_loss: 1.1251 | val_acc: 0.6120

Train Loop:



100%|██████████| 782/782 [00:06<00:00, 112.46it/s]
 10%|▉         | 15/157 [00:00<00:00, 143.23it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 137.41it/s]
  2%|▏         | 12/782 [00:00<00:07, 107.66it/s]



epoch: 12 | train_loss: 0.8771 | train_acc: 0.6942 | val_loss: 1.1487 | val_acc: 0.6114

Train Loop:



100%|██████████| 782/782 [00:06<00:00, 112.34it/s]
  9%|▉         | 14/157 [00:00<00:01, 138.57it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 136.02it/s]
  1%|▏         | 11/782 [00:00<00:07, 108.61it/s]



epoch: 13 | train_loss: 0.8703 | train_acc: 0.7000 | val_loss: 1.1801 | val_acc: 0.5921

Train Loop:



100%|██████████| 782/782 [00:07<00:00, 111.25it/s]
  8%|▊         | 13/157 [00:00<00:01, 128.18it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 135.54it/s]
  2%|▏         | 12/782 [00:00<00:06, 111.21it/s]



epoch: 14 | train_loss: 0.8568 | train_acc: 0.7026 | val_loss: 1.2051 | val_acc: 0.6015

Train Loop:



100%|██████████| 782/782 [00:06<00:00, 112.89it/s]
 10%|▉         | 15/157 [00:00<00:00, 143.08it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 139.45it/s]
  1%|▏         | 11/782 [00:00<00:07, 106.20it/s]



epoch: 15 | train_loss: 0.8465 | train_acc: 0.7044 | val_loss: 1.1365 | val_acc: 0.6236

Train Loop:



100%|██████████| 782/782 [00:06<00:00, 113.02it/s]
  9%|▉         | 14/157 [00:00<00:01, 136.61it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 137.83it/s]
  1%|▏         | 11/782 [00:00<00:07, 108.71it/s]



epoch: 16 | train_loss: 0.8348 | train_acc: 0.7114 | val_loss: 1.2085 | val_acc: 0.6074

Train Loop:



100%|██████████| 782/782 [00:06<00:00, 112.77it/s]
  9%|▉         | 14/157 [00:00<00:01, 136.66it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 136.90it/s]
  1%|▏         | 11/782 [00:00<00:07, 108.60it/s]



epoch: 17 | train_loss: 0.8295 | train_acc: 0.7113 | val_loss: 1.1427 | val_acc: 0.6213

Train Loop:



100%|██████████| 782/782 [00:06<00:00, 112.72it/s]
 10%|▉         | 15/157 [00:00<00:00, 142.72it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 136.86it/s]
  1%|▏         | 11/782 [00:00<00:07, 103.88it/s]



epoch: 18 | train_loss: 0.8191 | train_acc: 0.7143 | val_loss: 1.1382 | val_acc: 0.6204

Train Loop:



100%|██████████| 782/782 [00:06<00:00, 113.53it/s]
 10%|▉         | 15/157 [00:00<00:01, 141.51it/s]



Validation Loop:



100%|██████████| 157/157 [00:01<00:00, 136.82it/s]



epoch: 19 | train_loss: 0.8092 | train_acc: 0.7179 | val_loss: 1.1448 | val_acc: 0.6191




