In [None]:
from models.vgg import VGG16
import os
import torch
import torch.nn as nn
from torch.optim import SGD, lr_scheduler
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# --- Runtime / data configuration ---------------------------------------
# Prefer the GPU, but fall back to the CPU so the script still starts on a
# machine without CUDA instead of failing at the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
root = "./datasets"
# Fetch CIFAR-10 only when the extracted archive directory is missing.
# (The flag used to be computed inverted and was then ignored in favour of a
# hard-coded download=False, so a fresh checkout could never obtain the data.)
download = not os.path.exists(os.path.join(root, "cifar-10-batches-py"))

# Network under training, moved to the selected device.
model = VGG16()
model.to(device)

# Training-time augmentation: convert to a tensor first, then apply a random
# horizontal flip and a random 32x32 crop taken after 4 pixels of padding.
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomCrop(size=32, padding=4),
])

# Evaluation images are only converted to tensors (no augmentation).
transform_test = transforms.ToTensor()

trainset = datasets.CIFAR10(root=root, download=download, train=True, transform=transform_train)
testset = datasets.CIFAR10(root=root, download=download, train=False, transform=transform_test)
# Shuffle only the training data; evaluation order is kept fixed.
trainloader = DataLoader(trainset, shuffle=True, batch_size=128)
testloader = DataLoader(testset, shuffle=False, batch_size=256)

# Classification objective applied to the raw model outputs.
loss_fn = nn.CrossEntropyLoss()

# SGD with Nesterov momentum and weight decay over all model parameters.
optimizer = SGD(
    params=model.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=5e-4,
    nesterov=True,
)

# Scale the learning rate by 0.1 after every 100 calls to scheduler.step().
scheduler = lr_scheduler.StepLR(
    optimizer=optimizer,
    step_size=100,
    gamma=0.1,
)

def _evaluate(net, loader, criterion, target_device, max_batches):
    """Evaluate `net` on at most `max_batches` batches drawn from `loader`.

    Puts the network in eval mode (the caller is responsible for switching
    back to train mode) and runs without gradient tracking.

    Returns:
        A `(correct, loss_sum, total)` tuple: number of correct top-1
        predictions, sum of per-sample losses, and number of samples seen.
    """
    correct = 0
    loss_sum = 0
    total = 0
    net.eval()
    with torch.no_grad():
        # range() comes first in zip() so that no extra batch is pulled from
        # the loader once the batch budget is exhausted.
        for _, (x, y) in zip(range(max_batches), loader):
            x, y = x.to(target_device), y.to(target_device)
            out = net(x)
            n = y.size(0)
            correct += (out.argmax(dim=1) == y).sum().item()
            loss_sum += criterion(out, y).item() * n
            total += n
    return correct, loss_sum, total


epochs = 200
best_val_acc = 0
# NOTE(review): validation (and therefore best-checkpoint selection) only
# looks at the first `val_batches` batches of `testloader`, not the whole
# test set -- presumably to keep epochs fast; confirm this is intended.
val_batches = 4
checkpoint_path = "./model_weights/cifar_vgg_clean.pt"
# Create the checkpoint directory up front: otherwise the first torch.save()
# would fail only after a full epoch of training had already been spent.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

last_step = len(trainloader) - 1
for e in range(epochs):
    with tqdm(trainloader, desc=f"{e+1}/{epochs} epochs") as t:
        running_correct = 0
        running_loss = 0
        running_total = 0
        model.train()
        for i, (x, y) in enumerate(t):
            # Move the batch to the device once and reuse it below.
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = loss_fn(out, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_size = y.size(0)
            running_correct += (out.argmax(dim=1) == y).sum().item()
            running_loss += loss.item() * batch_size
            running_total += batch_size

            metrics = {
                "train_acc": running_correct / running_total,
                "train_loss": running_loss / running_total,
            }
            if i == last_step:
                # End-of-epoch work is done inside the final training
                # iteration on purpose: the progress bar is finalised as soon
                # as iteration ends, so the validation metrics must be
                # attached to it before that happens.
                scheduler.step()
                val_correct, val_loss, val_total = _evaluate(
                    model, testloader, loss_fn, device, val_batches
                )
                # Guard against an empty validation loader.
                val_seen = max(val_total, 1)
                val_acc = val_correct / val_seen
                metrics["val_acc"] = val_acc
                metrics["val_loss"] = val_loss / val_seen
                # Keep only the weights with the best validation accuracy.
                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    torch.save(model.state_dict(), checkpoint_path)
            t.set_postfix(metrics)

1/200 epochs: 100%|██████████| 391/391 [01:03<00:00,  6.20it/s, train_acc=0.22, train_loss=1.92, val_acc=0.291, val_loss=1.84]
2/200 epochs: 100%|██████████| 391/391 [01:12<00:00,  5.39it/s, train_acc=0.342, train_loss=1.67, val_acc=0.39, val_loss=1.68] 
3/200 epochs: 100%|██████████| 391/391 [01:06<00:00,  5.91it/s, train_acc=0.453, train_loss=1.4, val_acc=0.456, val_loss=1.51]
4/200 epochs: 100%|██████████| 391/391 [01:19<00:00,  4.93it/s, train_acc=0.58, train_loss=1.15, val_acc=0.508, val_loss=1.62]
5/200 epochs: 100%|██████████| 391/391 [01:02<00:00,  6.23it/s, train_acc=0.676, train_loss=0.959, val_acc=0.551, val_loss=1.52]
6/200 epochs: 100%|██████████| 391/391 [01:12<00:00,  5.37it/s, train_acc=0.722, train_loss=0.852, val_acc=0.619, val_loss=1.32]
7/200 epochs: 100%|██████████| 391/391 [00:54<00:00,  7.13it/s, train_acc=0.75, train_loss=0.775, val_acc=0.717, val_loss=0.872]
8/200 epochs: 100%|██████████| 391/391 [01:05<00:00,  5.93it/s, train_acc=0.769, train_loss=0.725, val_a

125/200 epochs: 100%|██████████| 391/391 [01:03<00:00,  6.13it/s, train_acc=0.975, train_loss=0.076, val_acc=0.912, val_loss=0.341]
126/200 epochs: 100%|██████████| 391/391 [00:46<00:00,  8.45it/s, train_acc=0.975, train_loss=0.077, val_acc=0.918, val_loss=0.301]
127/200 epochs: 100%|██████████| 391/391 [01:03<00:00,  6.16it/s, train_acc=0.976, train_loss=0.0763, val_acc=0.913, val_loss=0.327]
128/200 epochs: 100%|██████████| 391/391 [00:47<00:00,  8.15it/s, train_acc=0.975, train_loss=0.0773, val_acc=0.913, val_loss=0.318]
129/200 epochs: 100%|██████████| 391/391 [00:44<00:00,  8.73it/s, train_acc=0.976, train_loss=0.0735, val_acc=0.912, val_loss=0.299]
130/200 epochs: 100%|██████████| 391/391 [00:54<00:00,  7.21it/s, train_acc=0.977, train_loss=0.0724, val_acc=0.922, val_loss=0.26] 
131/200 epochs: 100%|██████████| 391/391 [01:11<00:00,  5.49it/s, train_acc=0.975, train_loss=0.0784, val_acc=0.916, val_loss=0.296]
132/200 epochs: 100%|██████████| 391/391 [01:24<00:00,  4.63it/s, train