In [1]:
import torch
import torchvision

In [8]:
from train.train import Trainer

In [4]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load CIFAR-10 and build the train / validation / test data loaders.

# note: autoaugment doesn't seem to work very well

from utility.data.preprocessing import Autoaugment_preprocess
transform = Autoaugment_preprocess(channels=3, resize_dim=(32,32), crop_dim=(28,28))


root = "/home/peppe/01_Study/01_University/Semester/2/Intro_to_ML/Project/data" # change this to your data directory

VAL_SIZE = 5000    # number of training images held out for validation
SPLIT_SEED = 42    # fixes the train/validation split across kernel restarts
BATCH_SIZE = 64
NUM_WORKERS = 2

# NOTE(review): the same `transform.transform` is applied to the training,
# validation and test data. If it performs random augmentation (as the class
# name suggests), evaluation metrics will be noisy -- confirm whether
# Autoaugment_preprocess offers a deterministic transform for evaluation.
full_trainset = torchvision.datasets.CIFAR10(root=root, train=True,
                                             download=True, transform=transform.transform)

# Without an explicit generator, random_split draws from the global RNG and the
# split differs on every run; a seeded generator makes it reproducible.
# The train size is derived from the dataset length instead of being hard-coded.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
trainset, valset = torch.utils.data.random_split(
    full_trainset,
    [len(full_trainset) - VAL_SIZE, VAL_SIZE],
    generator=split_generator,
)

testset = torchvision.datasets.CIFAR10(root=root, train=False,
                                       download=True, transform=transform.transform)

# Only the training loader shuffles; validation and test keep a fixed order.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=NUM_WORKERS)
valloader = torch.utils.data.DataLoader(valset, batch_size=BATCH_SIZE,
                                        shuffle=False, num_workers=NUM_WORKERS)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=NUM_WORKERS)

# Loaders bundled under the keys that are passed to Trainer(data_loaders=...).
data_loaders = {
    "train_loader": trainloader,
    "val_loader": valloader,
    "test_loader": testloader
}

In [15]:
# create simple model 

import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    """Small CNN classifier for 28x28 images.

    Two (3x3 conv -> ReLU -> 2x2 max-pool) stages followed by two fully
    connected layers. Each pooling stage halves the spatial size, so a 28x28
    input reaches the classifier head as a 32 x 7 x 7 feature map.

    Args:
        in_channels: number of channels of the input images (default 3).
        num_classes: number of output logits (default 10).
    """

    # (channels, height, width) expected by the classifier head.
    _FEATURE_SHAPE = (32, 7, 7)

    def __init__(self, in_channels=3, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return the class logits for a batch of images.

        Raises:
            RuntimeError: if the input's spatial size does not produce the
                32 x 7 x 7 feature map the classifier head was built for.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Guard the reshape below: view(-1, ...) infers the batch dimension, so
        # a wrong-sized input whose element count happens to be divisible by
        # 32*7*7 would otherwise be silently re-batched (mixing samples).
        if tuple(x.shape[-3:]) != self._FEATURE_SHAPE:
            raise RuntimeError(
                "SimpleCNN expects inputs that reduce to a feature map of shape "
                f"{self._FEATURE_SHAPE} (e.g. 28x28 images), got {tuple(x.shape[-3:])}"
            )
        x = x.view(-1, 32 * 7 * 7)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [16]:
from methods.SAM.sam import SAM
from torch.optim.lr_scheduler import StepLR
from utility.custom_scheduler import StepLR_SAM

# Baseline setup: a fresh SimpleCNN optimized with SGD (momentum + weight decay).
model = SimpleCNN()

optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=0.005,
)
# StepLR multiplies the learning rate by `gamma` every `step_size` scheduler steps.
scheduler = StepLR(optimizer=optimizer, step_size=1, gamma=0.01)

In [17]:
# SAM variant of the optimizer setup. SAM and StepLR_SAM are project classes
# (methods.SAM.sam / utility.custom_scheduler); presumably SAM wraps the given
# base optimizer class and forwards the remaining keyword arguments to it --
# confirm against their definitions.
base_optimizer_kwargs = {"lr": 0.1, "momentum": 0.9, "weight_decay": 0.005}

sam_optimizer = SAM(
    model.parameters(),
    torch.optim.SGD,
    rho=2,
    adaptive=True,
    **base_optimizer_kwargs,
)
SAM_scheduler = StepLR_SAM(sam_optimizer, step_size=1, gamma=0.01)

In [18]:
# Assemble the Trainer arguments in one place, then build the trainer.
# This run uses the plain SGD `optimizer` / `scheduler` pair.
trainer_config = {
    "data_loaders": data_loaders,
    "dataset_name": "CIFAR10",
    "model": model,
    "optimizer": optimizer,
    "loss_fn": nn.CrossEntropyLoss(label_smoothing=0.1),
    "device": device,
    "seed": 42,
    "exp_path": "/home/peppe/01_Study/01_University/Semester/2/Intro_to_ML/Project/Code/experiments",
    "exp_name": "test_1",
    "use_early_stopping": True,
    "scheduler": scheduler,
}

simple_net = Trainer(**trainer_config)

In [19]:
# Collect predictions from the trainer; the result is indexed with "y_true" and
# "y_pred" by the confusion-matrix cell below.
# NOTE(review): in notebook order this runs before simple_net.main(...), so these
# are presumably predictions of the untrained model -- confirm this is intended.
# NOTE(review): test=True presumably selects the test loader -- verify in Trainer.
preds = simple_net.get_predictions(test=True)

In [None]:
from utility.plot import plot_confusion_matrix

# Confusion matrix of the predictions collected above.
plot_confusion_matrix(
    preds["y_true"],
    preds["y_pred"],
    # what to plot
    type="normal",          # takes one of {"normal", "one_vs_all"}
    onv_vs_all_class=None,
    classes=None,           # select classes to make the confusion matrix, if None uses all
    # cell values
    normalize=None,         # takes one of {"true", "pred", "all"}
    fmt="d",                # set to ".2f" for percentages
    # appearance
    cmap="Blues",
    width=10,
    height=4,
)

In [None]:
# Run the trainer for two epochs.
# NOTE(review): log_interval=0.1 is presumably a fraction of an epoch between
# log writes -- confirm against Trainer.main.
simple_net.main(epochs=2, log_interval=0.1)

In [None]:
# Visualize the logged runs with TensorBoard inside the notebook.
# The --logdir below is the same directory that is passed to Trainer as exp_path;
# keep the two in sync when changing the experiment location.
%reload_ext tensorboard
%tensorboard --logdir=/home/peppe/01_Study/01_University/Semester/2/Intro_to_ML/Project/Code/experiments # experiment path