In [4]:
import torch
import numpy as np
from torch import nn
from tqdm import tqdm
from torchvision import transforms, datasets, models
from torch.utils.data import random_split, DataLoader

In [2]:
# Run configuration: device selection, hyperparameters and input preprocessing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seed = 0
n_epochs = 1
img_size = 224
batch_size = 64
num_classes = 10
lr = 3e-4

# Seed the global RNGs so weight initialisation, the train/val split and
# loader shuffling start from the same state on every Restart & Run All.
torch.manual_seed(seed)
np.random.seed(seed)

# Preprocessing applied to every image: resize to img_size x img_size, then
# convert to a tensor.
T = transforms.Compose(
    [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
    ]
)
print(device)

cuda


In [5]:
# CIFAR10 train / test sets rooted at data/ (download=True fetches them when
# they are not already present), preprocessed with the transform T.
data = datasets.CIFAR10("data/", train=True, download=True, transform=T)
test_data = datasets.CIFAR10("data/", train=False, download=True, transform=T)

# Hold out 30% of the training set for validation. A dedicated fixed-seed
# generator makes the partition identical every time this cell runs; drawing
# from the global RNG would reshuffle it on a re-run and could move samples
# the model was validated on into the training subset.
val_len = int(0.3 * len(data))
split_rng = torch.Generator().manual_seed(0)
train_data, val_data = random_split(
    data, [len(data) - val_len, val_len], generator=split_rng
)

# Settings shared by all three loaders; only the training loader shuffles.
loader_kwargs = dict(batch_size=batch_size, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_data, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_data, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_data, shuffle=False, **loader_kwargs)

# Pull one batch to sanity-check dataset sizes and tensor shapes.
x, y = next(iter(train_loader))
print(len(train_data), len(val_data), len(test_data), x.shape, y.shape)

Files already downloaded and verified
Files already downloaded and verified
35000 15000 10000 torch.Size([64, 3, 224, 224]) torch.Size([64])


In [6]:
# ResNet-18 without pretrained weights; the final fully connected layer is
# replaced so the network emits num_classes outputs.
net = models.resnet18(pretrained=False)
net.fc = nn.Linear(net.fc.in_features, num_classes)
net.to(device)

# Smoke test with a single random image at the configured input size. Run it
# in eval mode under no_grad so the random input neither builds an autograd
# graph nor updates the BatchNorm running statistics, then restore train mode.
inp = torch.randn(1, 3, img_size, img_size).to(device)
net.eval()
with torch.no_grad():
    out = net(inp)
net.train()
print(out.shape)

torch.Size([1, 10])


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


In [7]:
# Training objective: cross-entropy between the network outputs and the labels.
loss_fn = nn.CrossEntropyLoss()
# Adam over every parameter of net, using the configured learning rate.
optimizer = torch.optim.Adam(params=net.parameters(), lr=lr)
def get_accuracy(preds, y):
    """Return the fraction of predictions whose arg-max class equals the target.

    Args:
        preds: tensor of per-class scores; the predicted class is taken as the
            arg-max along dimension 1.
        y: tensor of target class indices, compared element-wise with the
            predicted classes.

    Returns:
        A Python float: the number of matches divided by ``y.shape[0]``.
    """
    predicted_classes = torch.argmax(preds, dim=1)
    hits = predicted_classes.eq(y).sum().item()
    return hits / y.size(0)

In [8]:
def loop(net, loader, is_train, epoch=None):
    """Run one full pass over ``loader``, training or evaluating ``net``.

    Relies on the module-level ``device``, ``loss_fn``, ``optimizer`` and
    ``get_accuracy``. Progress is reported through a tqdm bar whose description
    shows the running mean of the per-batch loss and accuracy.

    Args:
        net: model to run; put in train mode when ``is_train`` is true and in
            eval mode otherwise.
        loader: sized iterable yielding ``(x, y)`` batches.
        is_train: when true, gradients are enabled and the optimizer takes one
            step per batch; when false, the pass is forward-only.
        epoch: optional epoch number for the progress bar. When ``None`` the
            split / epoch prefix is left out of the description.
    """
    net.train(is_train)
    losses = []
    accs = []
    split = 'train' if is_train else ' val '

    pbar = tqdm(loader, total=len(loader))
    for x, y in pbar:
        x = x.to(device)
        y = y.to(device)

        with torch.set_grad_enabled(is_train):
            preds = net(x)
            loss = loss_fn(preds, y)
            acc = get_accuracy(preds, y)
            losses.append(loss.item())
            # get_accuracy returns a plain Python float, which has no .item();
            # float() accepts both a float and a 0-dim tensor.
            accs.append(float(acc))

        if is_train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch is not None:
            pbar.set_description(f'{split}: epoch={epoch}, loss={np.mean(losses):.4f}, acc={np.mean(accs):.4f}')
        else:
            pbar.set_description(f'loss={np.mean(losses):.4f}, acc={np.mean(accs):.4f}')

In [9]:
# Each epoch: one optimisation pass over the training split, then a
# forward-only pass over the validation split.
for epoch in range(n_epochs):
    loop(net, train_loader, is_train=True, epoch=epoch)
    loop(net, val_loader, is_train=False, epoch=epoch)

train: epoch=0, loss=1.3293, acc=0.5191: 100%|██████████| 547/547 [03:15<00:00,  2.80it/s]
 val : epoch=0, loss=1.2758, acc=0.5424: 100%|██████████| 235/235 [00:32<00:00,  7.15it/s]


In [10]:
# Forward-only pass over the held-out test split (no epoch label on the bar).
loop(net, test_loader, is_train=False)

loss=1.2643, acc=0.5444: 100%|██████████| 157/157 [00:22<00:00,  7.09it/s]
