# HW3 Image Classification
## We strongly recommend running this homework on Kaggle
https://www.kaggle.com/c/ml2022spring-hw3b/code?competitionId=34954&sortBy=dateCreated

In [None]:
# Experiment name; used below as the filename prefix of the saved checkpoints.
_exp_name = "resnet18"
# Show the GPU(s) visible to this notebook session.
!nvidia-smi

In [None]:
# torchinfo provides `summary`, which is imported below and used to print the model.
!pip install torchinfo

## Get Data
Note: if the links are dead, download the data directly from Kaggle and upload it to the workspace, or use the Kaggle API to download it straight into Colab.


In [None]:
# #! wget https://www.dropbox.com/s/6l2vcvxl54b0b6w/food11.zip
# ! wget -O food11.zip "https://github.com/virginiakm1988/ML2022-Spring/blob/main/HW03/food11.zip?raw=true"
# ! unzip food11.zip

In [None]:
# Import necessary packages.
import pandas as pd
import numpy as np
import torch
import os
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from PIL import Image
# "ConcatDataset" and "Subset" are possibly useful when doing semi-supervised learning.
from torch.utils.data import ConcatDataset, DataLoader, Subset, Dataset, random_split
from torchvision.datasets import DatasetFolder, VisionDataset
from torchinfo import summary

# This is for the progress bar.
from tqdm.auto import tqdm
import random

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
cuda = torch.cuda.is_available()
device = torch.device('cuda:0') if cuda else torch.device('cpu')
# Tensor constructor matching the selected device.
FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
device

In [None]:
myseed = 3  # set a random seed for reproducibility

# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Seed every random number generator the notebook imports: Python's `random`
# (previously left unseeded), NumPy, and PyTorch (CPU and all CUDA devices).
random.seed(myseed)
np.random.seed(myseed)
torch.manual_seed(myseed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(myseed)

In [None]:
import matplotlib.pyplot as plt

def no_axis_show(img, title='', cmap=None):
    """Draw `img` on the current axes with both axes hidden and `title` on top.

    Args:
        img: image data accepted by `plt.imshow`.
        title: text placed above the image.
        cmap: optional colormap forwarded to `plt.imshow`.
    """
    # "nearest" interpolation shows the pixels without smoothing.
    image = plt.imshow(img, interpolation='nearest', cmap=cmap)
    # Hide the x and y axes of the axes object the image was drawn on.
    for axis in (image.axes.get_xaxis(), image.axes.get_yaxis()):
        axis.set_visible(False)
    plt.title(title)

In [None]:
# Human-readable class names; the position in the list is the numeric label
# used as the prefix of the training file names below.
titles = ["Bread", "Dairy product", "Dessert", "Egg", "Fried food", "Meat", "Noodles/Pasta", "Rice", "Seafood", "Soup", "Vegetable/Fruit"]

# Show images `<label>_0.jpg` .. `<label>_9.jpg` of every class, one row per class.
for i, title in enumerate(titles):
    plt.figure(figsize=(18, 18))
    for j in range(10):
        plt.subplot(1, 10, j+1)
        # no_axis_show draws in place and returns nothing, so its result is not kept.
        no_axis_show(plt.imread(f'../input/ml2022spring-hw3b/food11/training/{i}_{j}.jpg'), title=title)
    plt.show()

## Hyperparameter

In [None]:
# The number of training epochs and patience.
n_epochs = 200
patience = 50 # If no improvement in 'patience' epochs, early stop

_dataset_dir = "../input/ml2022spring-hw3b/food11"

batch_size = 32
valid_ratio = 0.1
lr = 0.001
weight_decay = 2e-5

## **Transforms**
Torchvision provides lots of useful utilities for image preprocessing, data wrapping as well as data augmentation.

Please refer to PyTorch official website for details about different transforms.

In [None]:
# Per-channel normalization statistics. They are kept on `device` because the
# preview cell further down de-normalizes batches that it has moved to `device`.
MEAN = torch.tensor([0.485, 0.456, 0.406]).to(device)
STD = torch.tensor([0.229, 0.224, 0.225]).to(device)

# The transforms below run on the tensors produced by ToTensor, so Normalize is
# given plain Python floats instead of the device tensors. The values are the
# same, but no device-to-host copy is needed per image and the dataset code
# does not have to touch CUDA.
_NORM_MEAN = MEAN.tolist()
_NORM_STD = STD.tolist()

# Normally, we don't need augmentations in testing and validation.
# All we need here is to resize the PIL image and transform it into a Tensor.
test_tfm = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

# However, it is also possible to use augmentation in the testing phase.
# You may use train_tfm to produce a variety of images and then test using ensemble methods
train_tfm = transforms.Compose([
    # add some useful transform or augmentation here, according to your experience in HW3.
    transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),  # random crop, resized to 256x256
#     transforms.Resize(256),  # You can change this
    transforms.CenterCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

## **Datasets**
The data is labelled by file name, so the image and its label are both loaded inside `__getitem__`.

In [None]:
class FoodDataset(Dataset):
    """Image dataset whose labels are encoded in the file names.

    A file named `<label>_<anything>.jpg` gets the integer `<label>`; any other
    name (for example the unlabelled test images) gets the label -1.

    Args:
        paths: directories that are scanned for `.jpg` files.
        tfm: callable applied to every opened PIL image.
        files: optional explicit list of file paths; when given it is used
            instead of scanning `paths`.
    """

    def __init__(self,paths, tfm, files = None):
        # `super(FoodDataset).__init__()` built an unbound super object and never
        # ran the parent initializer; the zero-argument form does.
        super().__init__()
        self.paths = paths
        if files is not None:
            self.files = files
        else:
            self.files = sorted([os.path.join(path,x) for path in paths for x in os.listdir(path) if x.endswith(".jpg")])
        # Only show a sample when there is one, so an empty folder does not crash here.
        if self.files:
            print(f"One {paths} sample",self.files[0])
        self.transform = tfm

    def __len__(self):
        return len(self.files)

    @staticmethod
    def _label_from_name(fname):
        """Return the integer prefix of the file name, or -1 when there is none."""
        # basename() honours the platform's path separator, unlike split("/").
        prefix = os.path.basename(fname).split("_")[0]
        try:
            return int(prefix)
        except ValueError:
            return -1 # test has no label

    def __getitem__(self,idx):
        fname = self.files[idx]
        im = Image.open(fname)
        im = self.transform(im)
        return im, self._label_from_name(fname)

In [None]:
def train_valid_split(data_set, valid_ratio, seed):
    """Randomly partition `data_set` into a training part and a validation part.

    Args:
        data_set: any dataset supporting `len()`.
        valid_ratio: fraction of the samples that goes to the validation part
            (the count is truncated to an integer).
        seed: seed of the generator driving the split, so the same seed gives
            the same partition.

    Returns:
        The pair `(train_set, valid_set)`.
    """
    total = len(data_set)
    n_valid = int(valid_ratio * total)
    lengths = [total - n_valid, n_valid]
    rng = torch.Generator().manual_seed(seed)
    train_part, valid_part = random_split(data_set, lengths, generator=rng)
    return train_part, valid_part

In [None]:
# Load the "training" and "validation" folders with the same transform,
# pool them, and draw a fresh seeded train/validation split from the pool.
# NOTE(review): both parts share `train_tfm`, so validation images are
# augmented too — confirm this is intended.
data_set1, data_set2 = (
    FoodDataset([os.path.join(_dataset_dir, split)], tfm=train_tfm)
    for split in ("training", "validation")
)

data_set = ConcatDataset([data_set1, data_set2])
train_set, valid_set = train_valid_split(data_set, valid_ratio, myseed)
train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)
valid_loader = DataLoader(
    valid_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)

print("train:", len(train_set))
print("valid:", len(valid_set))

In [None]:
def tensor_show(imgs, size_inches=(15, 10)):
    """Show one image, or a list of images, side by side in a single row.

    Args:
        imgs: a single image or a list of images; each one is converted with
            torchvision's `to_pil_image` before being drawn.
        size_inches: `(width, height)` of the whole figure.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    # squeeze=False keeps `axs` two-dimensional even for a single image.
    fig, axs = plt.subplots(ncols=len(images), squeeze=False)
    fig.set_size_inches(*size_inches)
    for ax, tensor in zip(axs[0], images):
        pil_img = torchvision.transforms.functional.to_pil_image(tensor)
        ax.imshow(np.asarray(pil_img))
        # Remove ticks and tick labels on both axes.
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    plt.show()

In [None]:
import torchvision


def recover_from_normalize(img):
    """Undo the per-channel normalization of a (channels, H, W) image tensor."""
    # MEAN/STD hold one value per channel; the added axes broadcast them over H and W.
    return img * STD[:, None, None] + MEAN[:, None, None]


# Preview the first 30 images of one training batch and one validation batch.
# The built-in next() is used because the iterator's `.next()` alias no longer
# exists in recent PyTorch releases.
imgs = next(iter(train_loader))[0][:30]
grid = torchvision.utils.make_grid([recover_from_normalize(img.to(device)) for img in imgs], nrow=10)
tensor_show(grid, size_inches=(15, 5))

imgs = next(iter(valid_loader))[0][:30]
grid = torchvision.utils.make_grid([recover_from_normalize(img.to(device)) for img in imgs], nrow=10)
tensor_show(grid, size_inches=(15, 5))

## Model

In [None]:
class Classifier(nn.Module):
    """Plain five-stage CNN followed by a three-layer MLP with 11 outputs.

    Each stage is Conv(3x3) -> BatchNorm -> ReLU -> MaxPool(2), so the spatial
    size is halved five times. The feature map is then pooled to a fixed 4x4
    grid, so the fully-connected head works for input sizes other than 128x128
    (the transforms in this notebook produce 224x224 crops). For a 128x128
    input the feature map is already 4x4 and the extra pooling is an identity.
    """

    def __init__(self):
        super(Classifier, self).__init__()
        # torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        # torch.nn.MaxPool2d(kernel_size, stride, padding)
        # The shapes in the comments below are for an input of [3, 128, 128].
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),  # [64, 128, 128]
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),      # [64, 64, 64]

            nn.Conv2d(64, 128, 3, 1, 1), # [128, 64, 64]
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),      # [128, 32, 32]

            nn.Conv2d(128, 256, 3, 1, 1), # [256, 32, 32]
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),      # [256, 16, 16]

            nn.Conv2d(256, 512, 3, 1, 1), # [512, 16, 16]
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),       # [512, 8, 8]

            nn.Conv2d(512, 512, 3, 1, 1), # [512, 8, 8]
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),       # [512, 4, 4]
        )
        # Parameter-free, so checkpoints saved before this layer existed still load.
        self.pool = nn.AdaptiveAvgPool2d((4, 4))
        self.fc = nn.Sequential(
            nn.Linear(512*4*4, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 11)
        )

    def forward(self, x):
        """Map a batch of images [N, 3, H, W] to class logits [N, 11]."""
        out = self.cnn(x)
        out = self.pool(out)
        # Flatten everything except the batch dimension.
        out = out.flatten(1)
        return self.fc(out)

### 自實現 ResNet18

In [None]:
# https://zhuanlan.zhihu.com/p/157134695

class ResBlock(nn.Module):
    """Residual block: two 3x3 convolutions plus a shortcut connection.

    Args:
        inchannel: number of input channels.
        outchannel: number of output channels.
        stride: stride of the first convolution (and of the shortcut projection).
    """

    def __init__(self, inchannel, outchannel, stride=1):
        super().__init__()
        # Main branch: the two consecutive convolutions of the residual block.
        main_branch = [
            nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
        ]
        self.left = nn.Sequential(*main_branch)

        # The shortcut must produce the same shape as the main branch. When the
        # stride or the channel count changes it becomes a 1x1 projection;
        # otherwise it is an empty Sequential, i.e. the identity.
        keeps_shape = stride == 1 and inchannel == outchannel
        if keeps_shape:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(outchannel),
            )

    def forward(self, x):
        # Add the (possibly projected) input to the output of the two
        # convolutions, then apply the final ReLU.
        main = self.left(x)
        return nn.functional.relu(main + self.shortcut(x))

In [None]:
class ResNet18(nn.Module):
    """ResNet-18 style network assembled from a residual block class.

    Args:
        ResBlock: block class, called as `ResBlock(in_channels, out_channels, stride)`.
        num_classes: number of output logits.

    The final feature map is pooled to a fixed 4x4 grid, so the linear head
    (512 * 4 * 4 = 8192 inputs) works for input sizes other than 128x128. For a
    128x128 input this matches the previous fixed `avg_pool2d(out, 4)`.
    """

    def __init__(self, ResBlock, num_classes=11):
        super(ResNet18, self).__init__()
        # Channel count of the tensor entering the next block; updated by make_layer.
        self.inchannel = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.layer1 = self.make_layer(ResBlock, 64, 2, stride=1)
        self.layer2 = self.make_layer(ResBlock, 128, 2, stride=2)
        self.layer3 = self.make_layer(ResBlock, 256, 2, stride=2)
        self.layer4 = self.make_layer(ResBlock, 512, 2, stride=2)
        self.fc = nn.Linear(512 * 4 * 4, num_classes)
#         self.dropout = nn.Dropout(p=0.25)  # dropout

    def make_layer(self, block, channels, num_blocks, stride):
        """Stack `num_blocks` blocks; only the first one uses `stride`."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.inchannel, channels, block_stride))
            self.inchannel = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of images [N, 3, H, W] to logits [N, num_classes]."""
        # The shapes in the comments are for an input of [3, 128, 128].
        out = self.conv1(x) # [64, 128, 128]
        out = self.layer1(out) # [64, 128, 128]
        out = self.layer2(out) # [128, 64, 64]
        out = self.layer3(out) # [256, 32, 32]
        out = self.layer4(out) # [512, 16, 16]
        # Pool to a fixed 4x4 grid regardless of the input resolution.
        out = nn.functional.adaptive_avg_pool2d(out, 4) # [512, 4, 4]
#         out = self.dropout(out) # dropout
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

In [None]:
class EasyNet(nn.Module):
    """Minimal baseline: one 3-channel convolution, average pooling, one linear layer.

    The feature map is pooled to a fixed 32x32 grid, so the linear layer
    (3 * 32 * 32 inputs) works for input sizes other than 128x128. For a
    128x128 input this matches the previous fixed `avg_pool2d(out, 4)`.
    """

    def __init__(self):
        super(EasyNet, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(3),
            nn.ReLU()
        )
        self.fc = nn.Linear(3*32*32, 11)

    def forward(self, x):
        """Map a batch of images [N, 3, H, W] to class logits [N, 11]."""
        out = self.conv1(x) # same spatial size as the input
        out = nn.functional.adaptive_avg_pool2d(out, 32) # [3, 32, 32]
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

## Warmup

In [None]:
def get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
    lowerbound: float = 0.003,
):
    """
    Create a schedule whose learning-rate multiplier rises linearly during a warmup period and then follows a
    cosine curve, never dropping below ``lowerbound``.

    The multiplier is applied to the initial lr set in the optimizer. Because of the floor, the learning rate
    starts at ``lowerbound * lr`` (not 0) and ends at ``lowerbound * lr`` (not 0) with the default settings.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
        The optimizer for which to schedule the learning rate.
        num_warmup_steps (:obj:`int`):
        The number of steps for the warmup phase.
        num_training_steps (:obj:`int`):
        The total number of training steps.
        num_cycles (:obj:`float`, `optional`, defaults to 0.5):
        The number of waves in the cosine schedule (the default is to just decrease from the max value
        following a half-cosine).
        last_epoch (:obj:`int`, `optional`, defaults to -1):
        The index of the last epoch when resuming training.
        lowerbound (:obj:`float`, `optional`, defaults to 0.003):
        The smallest multiplier the schedule may return, in both the warmup and the cosine phase.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    # Imported here so the function does not depend on an `import math` that
    # appears only further down the file.
    import math

    def lr_lambda(current_step):
        # Warmup: linear ramp from 0 to 1, clamped from below.
        if current_step < num_warmup_steps:
            return max(lowerbound, float(current_step) / float(max(1, num_warmup_steps)))
        # Decay: fraction of the post-warmup steps completed so far.
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps)
        )
        return max(
            lowerbound, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
        )

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)

In [None]:
import math

def show_plot(total):
    """Plot the learning rate produced over `total` scheduler steps.

    A throw-away optimizer over a tiny linear layer is stepped `total` times
    with a warmup of `total // 10` steps; the recorded learning rates are
    plotted and the last five and first five values are printed.
    """
    probe_optimizer = torch.optim.SGD(torch.nn.Linear(2, 1).parameters(), lr=lr)
    probe_scheduler = get_cosine_schedule_with_warmup(probe_optimizer, total//10, total)

    lrs = []
    for _ in range(total):
        probe_optimizer.step()
        current_lr = probe_optimizer.param_groups[0]["lr"]
        lrs.append(current_lr)
        probe_scheduler.step()

    plt.plot(range(total), lrs)

    tail, head = lrs[-5:], lrs[:5]
    print(tail)
    print(head)

# The training loop steps the scheduler once per batch, so the schedule
# covers (batches per epoch) * (epochs) steps.
total_steps = len(train_loader) * n_epochs
show_plot(total_steps)

## Train

In [None]:
# %%script false --no-raise-error

# Choose the network to train: the commented lines are the alternatives
# defined in this notebook, the active line loads torchvision's resnet18
# (randomly initialized, 11 outputs) through torch.hub.
# NOTE(review): `Classifier` on the first commented line is the class itself,
# not an instance — it would need `Classifier()` before `.to(device)`.
# modelSel = Classifier
# modelSel = ResNet18(ResBlock, 11)
modelSel = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False, num_classes=11)

# Initialize a model, and put it on the device specified.
model = modelSel.to(device)
# model.load_state_dict(torch.load(f"../input/ml2022hw3tmp/version7_best.ckpt"))

# For the classification task, we use cross-entropy as the measurement of performance.
criterion = nn.CrossEntropyLoss()

# Initialize optimizer, you may fine-tune some hyperparameters such as learning rate on your own.
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

# One scheduler step per training batch; the first tenth of all steps is warmup.
total_steps = len(train_loader) * n_epochs
scheduler = get_cosine_schedule_with_warmup(optimizer, total_steps//10, total_steps)

In [None]:
# %%script false --no-raise-error

# Initialize trackers, these are not parameters and should not be changed
stale = 0      # consecutive epochs without a new best validation accuracy
best_acc = 0   # best validation accuracy seen so far

for epoch in range(n_epochs):

    # ---------- Training ----------
    # Make sure the model is in train mode before training.
    model.train()

    # Per-batch loss and accuracy of this epoch, stored as plain Python floats.
    train_loss = []
    train_accs = []

    for batch in tqdm(train_loader):

        # A batch consists of image data and corresponding labels.
        # Both are moved to the model's device once and reused below.
        imgs, labels = batch
        imgs = imgs.to(device)
        labels = labels.to(device)

        # Forward the data.
        logits = model(imgs)

        # Calculate the cross-entropy loss.
        # We don't need to apply softmax before computing cross-entropy as it is done automatically.
        loss = criterion(logits, labels)

        # Gradients stored in the parameters in the previous step should be cleared out first.
        optimizer.zero_grad()

        # Compute the gradients for parameters.
        loss.backward()

        # Clip the gradient norms for stable training.
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)

        # Update the parameters with computed gradients.
        optimizer.step()

        # Update learning rate (the schedule is defined per batch, not per epoch).
        scheduler.step()

        # Compute the accuracy for current batch.
        acc = (logits.argmax(dim=-1) == labels).float().mean()

        # Record the loss and accuracy.
        train_loss.append(loss.item())
        train_accs.append(acc.item())

    train_loss = sum(train_loss) / len(train_loss)
    train_acc = sum(train_accs) / len(train_accs)

    # Print the information.
    print(f"[ Train | {epoch + 1:03d}/{n_epochs:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}, lr = {optimizer.param_groups[0]['lr']:.6f}")

    # ---------- Validation ----------
    # Make sure the model is in eval mode so that some modules like dropout are disabled and work normally.
    model.eval()

    # Per-batch loss and accuracy on the validation set.
    valid_loss = []
    valid_accs = []

    # Iterate the validation set by batches.
    for batch in tqdm(valid_loader):

        # A batch consists of image data and corresponding labels.
        imgs, labels = batch
        imgs = imgs.to(device)
        labels = labels.to(device)

        # We don't need gradient in validation.
        # Using torch.no_grad() accelerates the forward process.
        with torch.no_grad():
            logits = model(imgs)
            # We can still compute the loss (but not the gradient).
            loss = criterion(logits, labels)

        # Compute the accuracy for current batch.
        acc = (logits.argmax(dim=-1) == labels).float().mean()

        # Record the loss and accuracy.
        valid_loss.append(loss.item())
        valid_accs.append(acc.item())

    # The average loss and accuracy for entire validation set is the average of the recorded values.
    valid_loss = sum(valid_loss) / len(valid_loss)
    valid_acc = sum(valid_accs) / len(valid_accs)

    # update logs
    improved = valid_acc > best_acc
    valid_msg = f"[ Valid | {stale + 1:03d}/{epoch + 1:03d}/{n_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}"
    print((valid_msg + " -> best") if improved else valid_msg)

    # save models
    if improved:
        print(f"Best model found at epoch {epoch + 1}, saving model {_exp_name}_best.ckpt")
        torch.save(model.state_dict(), f"{_exp_name}_best.ckpt") # only save best to prevent output memory exceed error
        best_acc = valid_acc
        stale = 0
    else:
        stale += 1
        if stale > patience:
            print(f"No improvment {patience} consecutive epochs, early stopping")
            break

# Always keep the weights of the final epoch as well.
torch.save(model.state_dict(), f"{_exp_name}_model_last.ckpt")

### DML

In [None]:
%%script false --no-raise-error

# Initialize a model, and put it on the device specified.
# Two independently initialized resnet18 networks are created; the loop below
# trains them side by side, each with its own optimizer and scheduler.
model1 = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False, num_classes=11).to(device)
model2 = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False, num_classes=11).to(device)

# For the classification task, we use cross-entropy as the measurement of performance.
criterion = nn.CrossEntropyLoss()

# Initialize optimizer, you may fine-tune some hyperparameters such as learning rate on your own.
optimizer1 = torch.optim.Adam(model1.parameters(), lr=lr, weight_decay=weight_decay)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=lr, weight_decay=weight_decay)

# Same schedule for both networks: one step per batch, first tenth is warmup.
total_steps = len(train_loader) * n_epochs
scheduler1 = get_cosine_schedule_with_warmup(optimizer1, total_steps//10, total_steps)
scheduler2 = get_cosine_schedule_with_warmup(optimizer2, total_steps//10, total_steps)

In [None]:
%%script false --no-raise-error
import torch.nn.functional as F

def loss_fn_dml(input, target):
    """KL divergence between the softmax of `target` and the softmax of `input`.

    Both arguments are logits with the classes on dimension 1. The value is
    summed over classes and divided by the batch size ('batchmean').
    """
    log_probs = F.log_softmax(input, dim=1)
    peer_probs = F.softmax(target, dim=1)
    # `peer_probs` holds probabilities, not log-probabilities.
    return F.kl_div(log_probs, peer_probs, reduction='batchmean', log_target=False)

In [None]:
%%script false --no-raise-error

# Initialize trackers, these are not parameters and should not be changed
stale = 0      # consecutive epochs without a new best validation accuracy
best_acc = 0   # best validation accuracy reached by either model

for epoch in range(n_epochs):

    # ---------- Training ----------
    # Make sure the model is in train mode before training.
    model1.train()
    model2.train()

    # These are used to record information in training.
    train_loss1 = []
    train_accs1 = []
    train_loss2 = []
    train_accs2 = []

    for batch in tqdm(train_loader):

        # A batch consists of image data and corresponding labels.
        imgs, labels = batch
        imgs = imgs.to(device)
        labels = labels.to(device)
        # print(imgs.shape, labels.shape)

        # Forward the data. (Make sure data and model are on the same device.)
        logits1 = model1(imgs)
        logits2 = model2(imgs)

        # Calculate the cross-entropy loss.
        # We don't need to apply softmax before computing cross-entropy as it is done automatically.
        # Each model additionally gets a KL term towards the other model's
        # prediction; the peer logits are detached, so no gradient flows into the peer.
        loss1 = criterion(logits1, labels) + loss_fn_dml(logits1, logits2.detach())
        loss2 = criterion(logits2, labels) + loss_fn_dml(logits2, logits1.detach())

        # Update model1: clear gradients, backpropagate, clip, step optimizer and scheduler.
        optimizer1.zero_grad()
        loss1.backward()
        grad_norm = nn.utils.clip_grad_norm_(model1.parameters(), max_norm=10)
        optimizer1.step()
        scheduler1.step()

        # Update model2 in the same way.
        optimizer2.zero_grad()
        loss2.backward()
        grad_norm = nn.utils.clip_grad_norm_(model2.parameters(), max_norm=10)
        optimizer2.step()
        scheduler2.step()

        # Compute the accuracy for current batch.
        acc1 = (logits1.argmax(dim=-1) == labels).float().mean()
        acc2 = (logits2.argmax(dim=-1) == labels).float().mean()

        # Record the loss and accuracy.
        train_loss1.append(loss1.item())
        train_accs1.append(acc1)
        train_loss2.append(loss2.item())
        train_accs2.append(acc2)

    train_loss1 = sum(train_loss1) / len(train_loss1)
    train_acc1 = sum(train_accs1) / len(train_accs1)
    train_loss2 = sum(train_loss2) / len(train_loss2)
    train_acc2 = sum(train_accs2) / len(train_accs2)

    # Print the information.
    print(f"[ Train | {epoch + 1:03d}/{n_epochs:03d} ] loss = {train_loss1:.5f}|{train_loss2:.5f}, acc = {train_acc1:.5f}|{train_acc2:.5f}, lr = {optimizer1.param_groups[0]['lr']:.6f}|{optimizer2.param_groups[0]['lr']:.6f}")

    # ---------- Validation ----------
    # Make sure the model is in eval mode so that some modules like dropout are disabled and work normally.
    model1.eval()
    model2.eval()

    # These are used to record information in validation.
    valid_loss1 = []
    valid_accs1 = []
    valid_loss2 = []
    valid_accs2 = []

    # Iterate the validation set by batches.
    for batch in tqdm(valid_loader):

        # A batch consists of image data and corresponding labels.
        imgs, labels = batch
        imgs = imgs.to(device)
        labels = labels.to(device)

        # We don't need gradient in validation.
        # Using torch.no_grad() accelerates the forward process.
        with torch.no_grad():
            logits1 = model1(imgs)
            logits2 = model2(imgs)

        # We can still compute the loss (but not the gradient).
        # Only the cross-entropy is reported here; the KL term is not included.
        loss1 = criterion(logits1, labels)
        loss2 = criterion(logits2, labels)

        # Compute the accuracy for current batch.
        acc1 = (logits1.argmax(dim=-1) == labels).float().mean()
        acc2 = (logits2.argmax(dim=-1) == labels).float().mean()

        # Record the loss and accuracy.
        valid_loss1.append(loss1.item())
        valid_accs1.append(acc1)
        valid_loss2.append(loss2.item())
        valid_accs2.append(acc2)

    # The average loss and accuracy for entire validation set is the average of the recorded values.
    valid_loss1 = sum(valid_loss1) / len(valid_loss1)
    valid_acc1 = sum(valid_accs1) / len(valid_accs1)
    valid_loss2 = sum(valid_loss2) / len(valid_loss2)
    valid_acc2 = sum(valid_accs2) / len(valid_accs2)

    # update logs
    if max(valid_acc1, valid_acc2) > best_acc:
        print(f"[ Valid | {stale + 1:03d}/{epoch + 1:03d}/{n_epochs:03d} ] loss = {valid_loss1:.5f}|{valid_loss2:.5f}, acc = {valid_acc1:.5f}|{valid_acc2:.5f} -> best")
    else:
        print(f"[ Valid | {stale + 1:03d}/{epoch + 1:03d}/{n_epochs:03d} ] loss = {valid_loss1:.5f}|{valid_loss2:.5f}, acc = {valid_acc1:.5f}|{valid_acc2:.5f}")


    # save models
    # Whichever model reaches the new best accuracy is written to
    # "<exp>_best.ckpt"; the other one goes to "<exp>_other.ckpt".
    if max(valid_acc1, valid_acc2) > best_acc:
        best_acc = max(valid_acc1, valid_acc2)
        if valid_acc1 == best_acc:
            print(f"Best model found at epoch {epoch + 1}, saving model1 {_exp_name}_best.ckpt")
            torch.save(model1.state_dict(), f"{_exp_name}_best.ckpt") # only save best to prevent output memory exceed error
            torch.save(model2.state_dict(), f"{_exp_name}_other.ckpt") # only save best to prevent output memory exceed error
        else:
            print(f"Best model found at epoch {epoch + 1}, saving model2 {_exp_name}_best.ckpt")
            torch.save(model2.state_dict(), f"{_exp_name}_best.ckpt") # only save best to prevent output memory exceed error
            torch.save(model1.state_dict(), f"{_exp_name}_other.ckpt") # only save best to prevent output memory exceed error

        stale = 0
    else:
        stale += 1
        if stale > patience:
            print(f"No improvment {patience} consecutive epochs, early stopping")
            break

# Keep the final-epoch weights of both models as well.
torch.save(model1.state_dict(), f"{_exp_name}_model1_last.ckpt")
torch.save(model2.state_dict(), f"{_exp_name}_model2_last.ckpt")

## Testing and generating the prediction CSV

In [None]:
class FoodDatasetTest(Dataset):
    """Test-time dataset returning one plain view and `n` extra views per image.

    Args:
        paths: directories that are scanned for `.jpg` files.
        tfms: pair of callables; `tfms[0]` produces the single first view and
            `tfms[1]` is applied `n` times to produce the extra views.
        files: optional explicit list of file paths; when given it is used
            instead of scanning `paths`.
        n: number of extra views generated per image.
    """

    def __init__(self, paths, tfms, files=None, n=5):
        # The previous `super(FoodDataset).__init__()` named a different class
        # and built an unbound super object, so the parent initializer never ran.
        super().__init__()
        self.paths = paths
        if files is not None:
            self.files = files
        else:
            self.files = sorted([os.path.join(path,x) for path in paths for x in os.listdir(path) if x.endswith(".jpg")])
        # Only show a sample when there is one, so an empty folder does not crash here.
        if self.files:
            print(f"One {paths} sample",self.files[0])
        self.transforms = tfms
        self.n = n

    def __len__(self):
        return len(self.files)

    def __getitem__(self,idx):
        """Return `(first_view, [extra_view_1, ..., extra_view_n])` for image `idx`."""
        fname = self.files[idx]
        im = Image.open(fname)
        im_test = self.transforms[0](im)

        # Re-apply the second transform n times to the same opened image.
        im_train = [self.transforms[1](im) for _ in range(self.n)]

        return im_test, im_train


In [None]:
# Each test image yields one `test_tfm` view plus 10 `train_tfm` views.
test_set = FoodDatasetTest([os.path.join(_dataset_dir,"test")], tfms=[test_tfm, train_tfm], n=10)
# shuffle=False keeps the predictions in the sorted file order of the dataset.
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)

In [None]:
def avg_train_tfm_pred(model, imgs):
    """Average the model's outputs over several views of the same batch.

    Args:
        model: callable mapping an image batch to logits.
        imgs: non-empty sequence of image batches, one per view; each is moved
            to the global `device` before the forward pass.

    Returns:
        The element-wise mean of `model(view)` over all views, computed
        without gradient tracking. The number of classes is taken from the
        model output instead of being fixed at 11.

    Raises:
        ValueError: if `imgs` is empty.
    """
    if len(imgs) == 0:
        raise ValueError("imgs must contain at least one batch")

    total = None
    with torch.no_grad():
        for img in imgs:
            out = model(img.to(device))
            total = out if total is None else total + out
        total = total / len(imgs)

    return total

In [None]:
# Rebuild the architecture that was trained, then load its saved weights.
# model_best = ResNet18(ResBlock, 11).to(device)
model_best = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False, num_classes=11).to(device)
# model_best.load_state_dict(torch.load(f"{_exp_name}_best.ckpt", map_location=device))
# The file name is derived from `_exp_name`, matching the name the training
# cell saves under, so renaming the experiment does not break this cell.
model_best.load_state_dict(torch.load(f"{_exp_name}_model_last.ckpt", map_location=device))

In [None]:
# Print a torchinfo summary of the loaded model for one batch of 224x224 RGB inputs.
summary(model_best, (batch_size, 3, 224, 224), device=device)

In [None]:
# Predict a class for every test image by adding the logits of the plain view
# to the averaged logits of the augmented views.
model_best.eval()
prediction = []
with torch.no_grad():
    for im_test, im_train in tqdm(test_loader):
        im_test = im_test.to(device)
        test_pred = model_best(im_test) + avg_train_tfm_pred(model_best, im_train)
        # argmax over axis 1 already has shape (batch,). It is not squeezed,
        # because a final batch of size 1 would collapse to a scalar and
        # tolist() would then return a bare int instead of a list.
        test_label = np.argmax(test_pred.cpu().numpy(), axis=1)
        prediction.extend(test_label.tolist())

In [None]:
#create test csv
def pad4(i):
    """Return `i` as a string, left-padded with zeros to at least 4 characters."""
    return str(i).zfill(4)
# Build the submission table: 4-digit, 1-based ids in dataset order, paired
# with the predicted class of each image.
df = pd.DataFrame({
    "Id": [pad4(i) for i in range(1, len(test_set) + 1)],
    "Category": prediction,
})
# The best validation accuracy becomes part of the output file name.
df.to_csv(f"submission_{best_acc}.csv", index=False)