In [1]:
from models.cnn import CNN
import os
import torch
import torch.nn as nn
from torch.optim import Adam, lr_scheduler
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from pgd import PGD

# ---------------------------------------------------------------------------
# Experiment setup: device, MNIST data, model, optimiser and the PGD attack
# used to craft adversarial examples for training and validation.
# ---------------------------------------------------------------------------

# Prefer the GPU, but fall back to the CPU so the script still runs on
# machines without CUDA instead of failing at `model.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
root = "./datasets"
# Download only when the dataset directory is missing. The flag was previously
# inverted (True only when the data already existed), so a fresh checkout
# failed with "Dataset not found".
download = not os.path.exists(os.path.join(root, "MNIST"))

model = CNN()
model.to(device)

# ToTensor() converts the PIL images to float tensors scaled to [0, 1].
trainset = datasets.MNIST(root=root, download=download, train=True, transform=transforms.ToTensor())
testset = datasets.MNIST(root=root, download=download, train=False, transform=transforms.ToTensor())
trainloader = DataLoader(trainset, shuffle=True, batch_size=1024)
testloader = DataLoader(testset, shuffle=False, batch_size=128)

loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=1e-3)

# Attack configuration, expressed in units of 1/255 of the [0, 1] pixel range.
# NOTE(review): the exact meaning of eps / step_size / batch_size is defined
# in the project's `pgd.PGD`; presumably eps is the perturbation budget and
# batch_size is an internal chunk size (the training loader feeds batches of
# 1024) — confirm against that class.
pgd = PGD(eps=60/255, step_size=20/255, max_iter=10, random_init=True, batch_size=128)

# ---------------------------------------------------------------------------
# Adversarial training loop.
#
# Every training batch is replaced by adversarial examples produced by
# `pgd.generate` and the model is optimised on those. On the last batch of
# each epoch the model is evaluated on adversarial versions of the first four
# test batches, and the weights are checkpointed whenever that adversarial
# validation accuracy matches or beats the best value seen so far.
# ---------------------------------------------------------------------------
epochs = 100
best_val_acc = 0
checkpoint_path = "./model_weights/mnist_cnn_adv.pt"
# torch.save does not create missing parent directories; make sure the target
# directory exists up front so the first checkpoint does not fail after a
# whole epoch of training.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

for e in range(epochs):
    with tqdm(trainloader,desc=f"{e+1}/{epochs} epochs") as t:
        # Per-epoch running statistics over the adversarial training batches.
        running_correct = 0
        running_loss = 0
        running_total = 0
        model.train()
        for i, (x,y) in enumerate(t):
            x_adv = pgd.generate(model,x,y,device=device)
            # Re-enable train mode after the attack.
            # NOTE(review): presumably pgd.generate may switch the model to
            # eval mode — confirm against pgd.PGD.
            model.train()
            y_dev = y.to(device)  # move the labels once, reuse below
            out = model(x_adv.to(device))
            loss = loss_fn(out,y_dev)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_correct += (out.max(dim=1)[1]==y_dev).sum().item()
            # loss is a batch mean; weight by batch size to get a sample mean.
            running_loss += loss.item()*x.size(0)
            running_total += x.size(0)
            if i < len(trainloader)-1:
                t.set_postfix({
                    "train_advacc": running_correct/running_total,
                    "train_advloss": running_loss/running_total
                })
            else:
                # Last training batch of the epoch: run a quick adversarial
                # validation pass before the progress bar closes.
                model.eval()
                val_correct = 0
                val_loss = 0
                val_total = 0
                # zip with range(4) limits validation to the first 4 test
                # batches (the loader is not shuffled, so always the same ones).
                for _, (x,y) in zip(range(4),testloader):
                    # The attack itself needs gradients, so it stays outside
                    # the no_grad context.
                    x_adv = pgd.generate(model,x,y,device=device)
                    y_dev = y.to(device)
                    with torch.no_grad():
                        out = model(x_adv.to(device))
                        loss = loss_fn(out,y_dev)
                    val_correct += (out.max(dim=1)[1]==y_dev).sum().item()
                    val_loss += loss.item()*x.size(0)
                    val_total += x.size(0)
                    t.set_postfix({
                        "train_advacc": running_correct/running_total,
                        "train_advloss": running_loss/running_total,
                        "val_advacc": val_correct/val_total,
                        "val_advloss": val_loss/val_total
                    })
                # ">=" means ties overwrite the checkpoint with the newer weights.
                if val_correct/val_total >= best_val_acc:
                    best_val_acc = val_correct/val_total
                    torch.save(model.state_dict(), checkpoint_path)

1/100 epochs: 100%|██████████| 59/59 [00:19<00:00,  3.09it/s, train_advacc=0.291, train_advloss=1.98, val_advacc=0.451, val_advloss=1.54]
2/100 epochs: 100%|██████████| 59/59 [00:19<00:00,  3.10it/s, train_advacc=0.521, train_advloss=1.38, val_advacc=0.562, val_advloss=1.17]
3/100 epochs: 100%|██████████| 59/59 [00:18<00:00,  3.12it/s, train_advacc=0.604, train_advloss=1.13, val_advacc=0.627, val_advloss=0.98] 
4/100 epochs: 100%|██████████| 59/59 [00:18<00:00,  3.15it/s, train_advacc=0.661, train_advloss=0.983, val_advacc=0.67, val_advloss=0.843] 
5/100 epochs: 100%|██████████| 59/59 [00:18<00:00,  3.11it/s, train_advacc=0.701, train_advloss=0.873, val_advacc=0.709, val_advloss=0.764]
6/100 epochs: 100%|██████████| 59/59 [00:18<00:00,  3.14it/s, train_advacc=0.732, train_advloss=0.79, val_advacc=0.762, val_advloss=0.668]
7/100 epochs: 100%|██████████| 59/59 [00:18<00:00,  3.14it/s, train_advacc=0.76, train_advloss=0.714, val_advacc=0.787, val_advloss=0.604]
8/100 epochs: 100%|████████

59/100 epochs: 100%|██████████| 59/59 [00:18<00:00,  3.15it/s, train_advacc=0.939, train_advloss=0.189, val_advacc=0.951, val_advloss=0.149] 
60/100 epochs: 100%|██████████| 59/59 [00:18<00:00,  3.15it/s, train_advacc=0.94, train_advloss=0.188, val_advacc=0.949, val_advloss=0.146] 
61/100 epochs: 100%|██████████| 59/59 [00:18<00:00,  3.15it/s, train_advacc=0.94, train_advloss=0.188, val_advacc=0.947, val_advloss=0.152] 
62/100 epochs: 100%|██████████| 59/59 [00:18<00:00,  3.16it/s, train_advacc=0.941, train_advloss=0.187, val_advacc=0.947, val_advloss=0.144] 
63/100 epochs: 100%|██████████| 59/59 [00:18<00:00,  3.15it/s, train_advacc=0.94, train_advloss=0.186, val_advacc=0.949, val_advloss=0.156] 
64/100 epochs: 100%|██████████| 59/59 [00:18<00:00,  3.15it/s, train_advacc=0.942, train_advloss=0.185, val_advacc=0.947, val_advloss=0.15]  
65/100 epochs: 100%|██████████| 59/59 [00:18<00:00,  3.15it/s, train_advacc=0.94, train_advloss=0.184, val_advacc=0.949, val_advloss=0.157] 
66/100 epo