# If you want to access the version you have already modified, click "Edit"
# If you want to access the original sample code, click "...", then click "Copy & Edit Notebook"

In [None]:
## This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    # Walk the read-only Kaggle input tree. Listing is switched off; re-enable
    # the print below to see every available input file.
    for filename in filenames:
        # print(os.path.join(dirname, filename))
        pass

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
_exp_name = "sample"  # experiment tag used in checkpoint / log file names

In [None]:
# Import necessary packages.
import numpy as np
import torch
import os
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
# "ConcatDataset" and "Subset" are possibly useful when doing semi-supervised learning.
from torch.utils.data import ConcatDataset, DataLoader, Subset, Dataset
from torchvision.datasets import DatasetFolder, VisionDataset

# This is for the progress bar.
from tqdm.auto import tqdm
import random
from torch.cuda import amp
from pathlib import Path
import pandas as pd

In [None]:
myseed = 6666  # set a random seed for reproducibility
# Trade cuDNN autotuning speed for deterministic convolution algorithms.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Seed every RNG the notebook has access to: Python's `random` (imported
# above but previously left unseeded), NumPy, and torch on CPU and all GPUs.
random.seed(myseed)
np.random.seed(myseed)
torch.manual_seed(myseed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(myseed)

## **Transforms**
Torchvision provides lots of useful utilities for image preprocessing, data wrapping as well as data augmentation.

Please refer to the official PyTorch (torchvision) documentation for details on the available transforms.

In [None]:
# Evaluation / test pipeline: no augmentation, only resize the PIL image to a
# fixed 256x256 and convert it to a tensor.
test_tfm = transforms.Compose(
    [transforms.Resize((256, 256)), transforms.ToTensor()]
)

# Training pipeline: upscale to 448x448, apply colour / flip / affine
# augmentation, then take a random 256x256 crop so tensors match the
# evaluation size. These random transforms could also be reused at test time
# to build an augmentation ensemble.
# NOTE(review): hue=0.5 is the largest hue jitter ColorJitter accepts (full
# hue randomisation), and a 256 crop of a 448 image shows food at a different
# scale than test_tfm's whole-image resize -- both worth revisiting.
train_tfm = transforms.Compose(
    [
        transforms.Resize((448, 448)),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.5),
        transforms.RandomHorizontalFlip(p=0.3),
        transforms.RandomVerticalFlip(p=0.3),
        transforms.RandomAffine(degrees=10, translate=(0.3, 0.3), scale=(0.5, 1.5), shear=10),
        transforms.RandomCrop(256),
        transforms.ToTensor(),
    ]
)


## **Datasets**
Each image's label is encoded in its file name (the number before the first underscore), so `__getitem__` loads the image and parses its label together.

In [None]:
class FoodDataset(Dataset):
    """Dataset over the ``.jpg`` images of one food-11 split directory.

    The label is parsed from the file name: the integer before the first
    underscore (e.g. ``3_120.jpg`` -> 3). Files without such a numeric prefix
    (the unlabeled test split) get label ``-1``.

    Args:
        path: directory containing the images.
        tfm: transform applied to each RGB PIL image.
        files: optional explicit list of file paths; when given, ``path`` is
            not listed.
    """

    def __init__(self, path, tfm=test_tfm, files=None):
        super().__init__()
        self.path = path
        if files is not None:
            self.files = files
        else:
            self.files = sorted(
                os.path.join(path, x) for x in os.listdir(path) if x.endswith(".jpg")
            )
        if self.files:
            print(f"One {path} sample", self.files[0])
        else:
            print(f"Warning: no .jpg files found for {path}")
        self.transform = tfm

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]
        # Close the file handle promptly. Force 3 channels so grayscale/RGBA
        # images still fit the model's 3-channel input convolution.
        with Image.open(fname) as img:
            im = self.transform(img.convert("RGB"))
        prefix = os.path.basename(fname).split("_")[0]
        try:
            label = int(prefix)
        except ValueError:
            label = -1  # test has no label
        return im, label


## model

In [None]:
class Bottleneck(nn.Module):
    """Three-convolution residual block: 1x1 reduce -> 3x3 -> 1x1 expand.

    The block outputs ``planes * expansion`` channels, and the 3x3 convolution
    carries ``stride``. If the shortcut's shape must change, the caller
    supplies a ``downsample`` module that maps the input to the output shape.
    """

    expansion = 4  # channel multiplier applied by the final 1x1 convolution

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        out_planes = planes * self.expansion
        # Attribute names define the state_dict keys stored in checkpoints.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)

        h = self.relu(self.bn1(self.conv1(x)))
        h = self.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))

        h += shortcut
        return self.relu(h)


class ResNet(nn.Module):
    """ResNet backbone plus linear classifier, parameterised by residual block type.

    Args:
        block: residual block class. It must expose an ``expansion`` class
            attribute and accept ``(inplanes, planes, stride, downsample)``.
        layers: number of blocks in each of the four stages.
        num_classes: output size of the final linear layer.

    Spatial size shrinks 32x in total: the stride-2 stem conv, the stride-2
    max-pool, and stride 2 in each of stages 2-4. Global pooling is adaptive,
    so any input that survives that reduction works (e.g. 256x256 -> 8x8 ->
    1x1), not only 256x256.
    """

    def __init__(self, block, layers, num_classes=1000):
        super(ResNet, self).__init__()
        self.inplanes = 64  # channels entering the next stage; advanced by _make_layer
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)  # /2
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # /2
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)  # /2
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)  # /2
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)  # /2
        # Adaptive pooling reduces any final feature map to 1x1. A fixed
        # AvgPool2d(8) only gives 1x1 for 256x256 inputs and otherwise breaks fc.
        # (Pooling has no parameters, so saved checkpoints stay compatible.)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block uses ``stride`` and, when needed, a 1x1-conv + BN
        projection shortcut.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)


def resnet18(pretrained=False, num_classes=11):
    """Construct the notebook's classifier: ``ResNet(Bottleneck, [2, 2, 2, 2])``.

    NOTE: despite the name, this is not torchvision's ResNet-18. That model uses
    two-convolution basic blocks. With three-convolution Bottleneck blocks this
    layout is closer to a ResNet-26, so ImageNet ResNet-18 weights do not fit it.

    Args:
        pretrained (bool): unsupported; must be False.
        num_classes (int): number of output classes (11 for food-11).

    Raises:
        NotImplementedError: if ``pretrained`` is True. No compatible
            pretrained checkpoint exists for this Bottleneck layout.
    """
    if pretrained:
        raise NotImplementedError(
            "pretrained weights are not available for this Bottleneck-based model"
        )
    return ResNet(Bottleneck, [2, 2, 2, 2], num_classes)

## dataloader

In [None]:
batch_size = 128
_dataset_dir = "../input/ml2022spring-hw3b/food11"
# Construct datasets. Labels are parsed from the image file names by FoodDataset.
train_set = FoodDataset(os.path.join(_dataset_dir, "training"), tfm=train_tfm)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
valid_set = FoodDataset(os.path.join(_dataset_dir, "validation"), tfm=test_tfm)
# No shuffling for validation. It gains nothing and makes per-batch metric
# averages depend on which samples land in the short final batch.
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)

## EMA

In [None]:
from copy import deepcopy
import numpy as np


class ExponentialMovingAverageModel:
    """Maintain an exponential-moving-average (EMA) copy of a model for inference.

    The copy is made once (deep copy, eval mode, gradients disabled). Each
    ``update`` call then pulls it toward the live model:

        ema = d * ema + (1 - d) * live,  d = decay_ratio * (1 - exp(-n / tot_epoch))

    Here ``n`` counts the updates performed so far. ``d`` starts near 0 and
    ramps toward ``decay_ratio``. Early on the EMA therefore tracks the live
    weights closely; later it averages over progressively longer histories.
    Despite its name, ``tot_epoch`` is the ramp's time constant measured in
    ``update`` calls.
    """

    def __init__(self, model, decay_ratio=0.9999, tot_epoch=200, update_num=0):
        self.ema = deepcopy(model).eval()
        self.update_num = update_num  # updates performed so far (non-zero to resume)

        def _decay(step):
            return decay_ratio * (1 - np.exp(-step / tot_epoch))

        self.get_decay_weight = _decay
        for param in self.ema.parameters():
            param.requires_grad_(False)

    def update(self, model):
        """Blend ``model``'s current state into the EMA copy, in place."""
        with torch.no_grad():
            self.update_num += 1
            d = self.get_decay_weight(self.update_num)
            live_state = model.state_dict()
            for key, ema_val in self.ema.state_dict().items():
                if not ema_val.dtype.is_floating_point:
                    continue  # non-floating-point entries are left untouched
                ema_val *= d
                ema_val += (1 - d) * live_state[key].detach()

In [None]:
import math  # needed by the 'cosine' and poly schedules below

device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize a model, and put it on the device specified.
model = resnet18().to(device)
criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005, weight_decay=1e-5) 
tot_epoch = 500
init_lr = 0.001
weight_decay = 1e-4
optimizer = torch.optim.SGD(model.parameters(), lr=init_lr, weight_decay=weight_decay)


def _lr_lambda(epoch, scheduler_type='linear'):
    """Return the LR multiplier for ``epoch``. LambdaLR passes only ``epoch``.

    scheduler_type:
        'linear' -- falls linearly from 1 at epoch 0 to ``lr_bias`` at epoch tot_epoch - 1.
        'cosine' -- half-cosine from 1 at epoch 0 to ``lr_bias`` at epoch tot_epoch.
        anything else -- polynomial decay ``(1 - epoch / tot_epoch) ** 0.9``.
    """
    lr_bias = 0.01  # larger lr_bias -> slower LR decay and a higher final LR after all epochs
    if scheduler_type == 'linear':
        return (1 - epoch / (tot_epoch - 1)) * (1. - lr_bias) + lr_bias
    elif scheduler_type == 'cosine':
        return ((1 + math.cos(epoch * math.pi / tot_epoch)) / 2) * (1. - lr_bias) + lr_bias  # cosine
    else:
        return math.pow(1 - epoch / tot_epoch, 0.9)


lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=_lr_lambda)

In [None]:
use_cuda = device == "cuda"  # autocast and loss scaling are only enabled on GPU
scaler = amp.GradScaler(enabled=use_cuda)  # mixed-precision loss scaler
ema_model = ExponentialMovingAverageModel(model, tot_epoch=tot_epoch)
# Gradient-accumulation setting consumed by train(); presumably the effective
# batch size in samples (256 = 2 x batch_size of 128) -- confirm against train().
accumulate = 256

## trainer

In [None]:
def _train_one_epoch(model, train_loader, criterion, optimizer, epoch, accum_steps):
    """Run one AMP training epoch with gradient accumulation; return (loss, acc).

    Uses the notebook globals ``device``, ``use_cuda``, ``scaler`` and
    ``ema_model``. Loss and accuracy are averaged over samples.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    n_batches = len(train_loader)
    for i, (imgs, labels) in enumerate(tqdm(train_loader)):
        # Global mini-batch counter, so accumulation windows may span epochs.
        cur_steps = epoch * n_batches + i + 1
        imgs, labels = imgs.to(device), labels.to(device)

        with amp.autocast(enabled=use_cuda):
            logits = model(imgs)
            loss = criterion(logits, labels)
        # Divide by the window length so the accumulated gradient is the
        # mean over the window rather than the sum.
        scaler.scale(loss / accum_steps).backward()

        if cur_steps % accum_steps == 0:
            # Clip the *unscaled* gradients right before stepping.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            ema_model.update(model)

        n = labels.size(0)
        loss_sum += loss.item() * n
        n_correct += (logits.argmax(dim=-1) == labels).sum().item()
        n_seen += n
    return loss_sum / max(n_seen, 1), n_correct / max(n_seen, 1)


def _evaluate(model, loader, criterion):
    """Evaluate ``model`` on ``loader``; return sample-weighted (loss, accuracy)."""
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.no_grad():
        for imgs, labels in tqdm(loader):
            imgs, labels = imgs.to(device), labels.to(device)
            logits = model(imgs)
            n = labels.size(0)
            loss_sum += criterion(logits, labels).item() * n
            n_correct += (logits.argmax(dim=-1) == labels).sum().item()
            n_seen += n
    return loss_sum / max(n_seen, 1), n_correct / max(n_seen, 1)


def train(model, train_loader, criterion, optimizer, lr_scheduler, n_epochs=100):
    """Train ``model`` with AMP, gradient accumulation and an EMA copy.

    Validates after every epoch and appends the validation line to
    ``./{_exp_name}_log.txt``. Whenever validation accuracy improves, saves the
    raw and EMA weights. Stops early after ``patience`` epochs without
    improvement.

    Relies on notebook globals: ``device``, ``use_cuda``, ``scaler``,
    ``ema_model``, ``accumulate`` (effective batch size in samples),
    ``batch_size``, ``valid_loader`` and ``_exp_name``.
    """
    patience = 30  # If no improvement in 'patience' epochs, early stop
    # Step the optimizer once every `accum_steps` mini-batches, so each update
    # sees roughly `accumulate` samples (256 / 128 -> every 2 batches).
    accum_steps = max(round(accumulate / batch_size), 1)
    log_path = f"./{_exp_name}_log.txt"

    stale = 0
    best_acc = 0

    for epoch in range(n_epochs):
        train_loss, train_acc = _train_one_epoch(
            model, train_loader, criterion, optimizer, epoch, accum_steps
        )

        # Print the information.
        lr = lr_scheduler.get_last_lr()[0]
        print(f"[ Train | {epoch + 1:03d}/{n_epochs:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}, lr = {lr:.2e}")

        # ---------- Validation ----------
        # NOTE(review): both checkpoints, including the EMA one, are selected by
        # the live model's accuracy. Evaluating ema_model.ema too might pick
        # better EMA weights.
        valid_loss, valid_acc = _evaluate(model, valid_loader, criterion)
        valid_msg = f"[ Valid | {epoch + 1:03d}/{n_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}"
        print(valid_msg)

        # Append this epoch's validation line to the log file.
        is_best = valid_acc > best_acc
        with open(log_path, "a") as log_file:
            print(valid_msg + (" -> best" if is_best else ""), file=log_file)

        # save models
        if is_best:
            print(f"Best model found at epoch {epoch}, saving model")
            torch.save(ema_model.ema.state_dict(), f"./ema_{_exp_name}_best.ckpt") # only save best to prevent output memory exceed error
            torch.save(model.state_dict(), f"./{_exp_name}_best.ckpt")
            best_acc = valid_acc
            stale = 0
            print(f"model_path: {Path(f'./{_exp_name}_best.ckpt').resolve()}")
        else:
            stale += 1
            if stale > patience:
                print(f"No improvment {patience} consecutive epochs, early stopping")
                break
        lr_scheduler.step()

In [None]:
train(
    model=model,
    train_loader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    n_epochs=tot_epoch,
)

## Testing and generating the prediction CSV

In [None]:
test_dir = os.path.join(_dataset_dir, "test")
test_set = FoodDataset(test_dir, tfm=test_tfm)
del test_dir
test_loader = DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,  # keep file order so predictions line up with submission Ids
    num_workers=0,
    pin_memory=True,
)

### Test-time augmentation (TTA)

In [None]:
def predict(model, model_path, test_loader, tta):
    """Load a checkpoint into ``model``, predict ``test_loader``, write a submission CSV.

    Args:
        model: network whose ``state_dict`` matches the checkpoint. Inference
            runs on whatever device the model's parameters live on.
        model_path: path of a checkpoint saved with ``torch.save(model.state_dict(), ...)``.
        test_loader: yields ``(images, _)`` batches in submission order; the
            second element is ignored.
        tta: if True, average the logits of the original, vertically flipped
            (dim 2) and horizontally flipped (dim 3) images.

    Writes ``submission_<checkpoint stem>.csv``. Its ``Id`` column holds
    zero-padded 4-digit, 1-based positions in loader order; its ``Category``
    column holds the argmax class.

    Returns:
        numpy array of shape (num_samples, num_classes) with the
        (TTA-averaged) logits. These are raw scores, not softmax probabilities.
    """
    # Map the checkpoint onto the model's device, so GPU-saved weights also
    # load on a CPU-only machine.
    run_device = next(model.parameters()).device
    model.load_state_dict(torch.load(model_path, map_location=run_device))
    model.eval()

    batch_logits = []
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(run_device)
            if tta:
                views = (
                    data,
                    torch.flip(data, dims=[2]),  # vertical flip
                    torch.flip(data, dims=[3]),  # horizontal flip
                )
                logits = torch.stack([model(v) for v in views]).mean(dim=0)
            else:
                logits = model(data)  # (bs, cls_num)
            batch_logits.append(logits.cpu().numpy())

    all_logits = np.concatenate(batch_logits, axis=0)
    # argmax over a 2-D array always yields a 1-D array, so tolist() returns
    # a list even when the final batch holds a single image.
    prediction = all_logits.argmax(axis=1).tolist()

    df = pd.DataFrame({
        "Id": [str(i).zfill(4) for i in range(1, len(prediction) + 1)],
        "Category": prediction,
    })
    df.to_csv(f"submission_{Path(model_path).stem}.csv", index=False)
    return all_logits

In [None]:
model = resnet18().to(device)
# EMA checkpoint written by train(). Derive its name from _exp_name instead
# of hard-coding "sample", so renaming the experiment keeps this cell working.
model_path = f"./ema_{_exp_name}_best.ckpt"
probs = predict(model, model_path, test_loader, True)

In [None]:
# model_best = Classifier().to(device)
# model_best.load_state_dict(torch.load(f"{_exp_name}_best.ckpt"))
# model_best.eval()
# prediction = []
# with torch.no_grad():
#     for data,_ in test_loader:
#         test_pred = model_best(data.to(device))
#         test_label = np.argmax(test_pred.cpu().data.numpy(), axis=1)
#         prediction += test_label.squeeze().tolist()

In [None]:
#create test csv
# def pad4(i):
#     return "0"*(4-len(str(i)))+str(i)
# df = pd.DataFrame()
# df["Id"] = [pad4(i) for i in range(1,len(test_set)+1)]
# df["Category"] = prediction
# df.to_csv("submission.csv",index = False)