# Implementation Guidelines of Sample Code (Pytorch)

    See the notes in the markdown block above each code block, as well as the # TODO annotations in the code. :D

# Usage guideline of Jupyter Notebook (If needed)

    Installation   : https://jupyter.org/install  
    User Document  : https://jupyter-notebook.readthedocs.io/en/latest/user-documentation.html

# Test Environment (Recommended)

    At test time, we will evaluate your submitted code with the library versions listed below.  
    It is therefore highly recommended to use exactly these package versions.

    test environment : pytorch

### Packages
    python   : 3.8.17  
    torch    : 2.0.1   
    skimage  : 0.21.0  
    cv2      : 4.8.0

# Import libraries (Do not change!)

In [None]:
import os
import sys
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
import cv2
from torch.utils.data import DataLoader
from skimage import io
import pandas as pd
import matplotlib.pyplot as plt
import math
import copy
import time
import PIL
import pickle

# Split dataset (Do not change!)

### Notice 1
    This function splits your dataset of 1000 classes into 10 groups of 100 classes each.    
    It only needs to be run once, at the beginning, to split your dataset for continual learning.   
    *Again, you do not need to call this function on every training run if you have already split your dataset into 10 groups.

    Notice the annotation codes below. (You can see this codes in 'main' block.)

```python
        parser = argparse.ArgumentParser()   
        # Change this to 'False' after dividing your dataset into 10 groups.
        parser.add_argument('--div_data',   default = True)  
        args = parser.parse_args(args=[])  
```

### Notice 2
    We resize all input images to a constant size of 128x128.   
    Until further notice, use this constant size. 

```python
        # Split input data.  
        for i in range(start, end):
            for img_idx in range(0, 130):
                path = os.path.join(dir, str(i))
                path = path + '/' +str(img_idx)+'.png'
                img = io.imread(path)
                img = cv2.resize(img, (128, 128))  # resize image into 128 x 128 
                x_train.append(img)
```


In [None]:
def train_split(validation_num):
    """Split the raw image folders into 10 groups and save each group as .npy files.

    For group g (1..10) the class folders 100*(g-1)+1 .. 100*g under the dataset
    root are read. Inside every class folder, images 0.png .. (149-validation_num).png
    become training samples and the last `validation_num` images (up to 149.png)
    become validation samples. Every image is resized to 128x128 and labelled with
    its folder number (labels therefore start at 1).

    Files written per group:
        train_data/x_data_<g>.npy, train_data/y_data_<g>.npy
        valid_data/x_data_<g>.npy, valid_data/y_data_<g>.npy

    validation_num : number of images per class (out of 150) held out for validation
    """
    images_per_class  = 150
    classes_per_group = 100
    num_groups        = 10

    if not 0 <= validation_num <= images_per_class:
        raise ValueError(
            f"validation_num must be between 0 and {images_per_class}, got {validation_num}"
        )

    # TODO : set dataset path
    # TODO : We recommend placing your code and training dataset in the same location.
    data_dir = './Koh_Young_AI_data/'

    # TODO : Define train data and valid data directory path.
    # TODO : Recommends not to change these directory paths.
    train_save_dir = 'train_data'
    valid_save_dir = 'valid_data'
    os.makedirs(train_save_dir, exist_ok=True)
    os.makedirs(valid_save_dir, exist_ok=True)

    # Images with an index below this threshold go to the training split.
    num_train = images_per_class - validation_num

    for div_idx in range(num_groups):  # Div into 10 groups
        x_train, y_train = [], []
        x_valid, y_valid = [], []

        # Class folders are numbered from 1. `end` is exclusive so that each
        # group really contains 100 classes (e.g. 1..100, 101..200, ...).
        start = classes_per_group * div_idx + 1
        end   = start + classes_per_group

        for class_idx in range(start, end):
            class_dir = os.path.join(data_dir, str(class_idx))
            label     = np.array([class_idx])
            for img_idx in range(images_per_class):
                img = io.imread(os.path.join(class_dir, f'{img_idx}.png'))
                img = cv2.resize(img, (128, 128))  # resize image into 128 x 128
                if img_idx < num_train:
                    x_train.append(img)
                    y_train.append(label)
                else:
                    x_valid.append(img)
                    y_valid.append(label)

        # Convert list to numpy
        x_train = np.array(x_train)
        y_train = np.array(y_train)
        x_valid = np.array(x_valid)
        y_valid = np.array(y_valid)

        # TODO : Save train/valid data (np.save appends the '.npy' suffix)
        np.save(os.path.join(train_save_dir, f'x_data_{div_idx+1}'), x_train)
        np.save(os.path.join(train_save_dir, f'y_data_{div_idx+1}'), y_train)
        np.save(os.path.join(valid_save_dir, f'x_data_{div_idx+1}'), x_valid)
        np.save(os.path.join(valid_save_dir, f'y_data_{div_idx+1}'), y_valid)

        print(f" ===================== Done in {div_idx} ===================== ")

# Define Dataloader (Do not change!)

    You can define your own dataset with the torch.utils.data.Dataset API and wrap it in a DataLoader.  
    This usually helps to reduce the computational burden when dealing with high-dimensional data, such as images.  

    reference url : https://pytorch.org/tutorials/beginner/basics/data_tutorial.html


In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset over two parallel, index-aligned collections (inputs and labels)."""

    def __init__(self, x_data, y_data, device):
        # Everything is stored as given; `device` is kept on the instance but
        # is not used by the methods of this class.
        self.x_data, self.y_data, self.device = x_data, y_data, device

    def __len__(self):
        # The dataset length is taken from the input collection.
        return len(self.x_data)

    def __getitem__(self, idx):
        """Return the (input, label) pair stored at position `idx` as torch tensors."""
        image = torch.FloatTensor(self.x_data[idx])
        label = torch.LongTensor(self.y_data[idx])
        # .transpose(0, 2) : swaps the first and the third axis, so an array laid
        #                    out as (a, b, c) comes back as (c, b, a).
        # .squeeze(0)      : removes axis 0 when it has length 1, e.g. a label of
        #                    shape (1,) becomes a 0-dim tensor.
        return image.transpose(0, 2), label.squeeze(0)

def load_train_data(class_num):
    """Load the training arrays of one group from disk.

    class_num : zero-based group index; the files read are
                ./train_data/x_data_<class_num+1>.npy and
                ./train_data/y_data_<class_num+1>.npy

    Returns the tuple (x_data, y_data) exactly as stored in the two files.
    """
    # TODO : set 'class_path' with your train_data path.
    class_path = './train_data/'
    file_idx   = class_num + 1  # files on disk are numbered starting from 1

    loaded = [
        np.load(f'{class_path}{prefix}{file_idx}.npy', allow_pickle=True)
        for prefix in ('x_data_', 'y_data_')
    ]
    return loaded[0], loaded[1]

def load_valid_data(class_num):
    """Load the validation arrays of one group from disk.

    class_num : zero-based group index; the files read are
                ./valid_data/x_data_<class_num+1>.npy and
                ./valid_data/y_data_<class_num+1>.npy

    Returns the tuple (x_data, y_data) exactly as stored in the two files.
    """
    # TODO : set 'class_path' with your valid_data path.
    class_path = './valid_data/'
    file_idx   = class_num + 1  # files on disk are numbered starting from 1

    loaded = [
        np.load(f'{class_path}{prefix}{file_idx}.npy', allow_pickle=True)
        for prefix in ('x_data_', 'y_data_')
    ]
    return loaded[0], loaded[1]

# Define training function (You can modify this part!)

    Put your model into train mode with 'model.train()'.   

    useful reference : https://wikidocs.net/195118

In [None]:
def train_model(model, x_train, optimizer, num_epochs, train_data_loader, criterion):
    """
    Train `model` on one task and return it together with the last epoch's accuracy.

    model             : your customized model; must provide before_task(),
                        get_cost(out, y) and after_task(loader)
    x_train           : input data for training (only its length is used, as the
                        denominator of the reported accuracy)
    optimizer         : optimizer
    num_epochs        : number of passes over train_data_loader
    train_data_loader : dataloader of training dataset
    criterion         : unused here; the loss comes from model.get_cost()

    Returns (model, accuracy) where accuracy is a 0-dim float tensor holding
    correct_predictions / len(x_train) for the final epoch (0.0 if no epoch ran).
    Note: if the loader drops samples (e.g. drop_last=True), the dropped samples
    still count in the denominator.
    """
    # Run on whatever device the model lives on instead of relying on a
    # module-level `device` variable.
    first_param = next(model.parameters(), None)
    run_device  = first_param.device if first_param is not None else torch.device('cpu')

    model.train()                     # Set train mode.
    model.before_task()

    num_samples = len(x_train)
    accuracy    = torch.tensor(0.0)   # defined even when num_epochs == 0
    for epoch in range(num_epochs):
        correct     = 0    # number of correct predictions in this epoch
        total_cost  = 0.0  # sum of the per-batch costs in this epoch
        num_batches = 0
        for x, y in train_data_loader:
            x, y = x.to(run_device), y.to(run_device)
            out = model(x)                       # Inference
            _, preds = torch.max(out, 1)         # preds : Predicted class
            cost = model.get_cost(out, y)

            # Optimize process.
            optimizer.zero_grad()
            cost.backward()
            optimizer.step()

            # .item() converts to a plain Python number, so the running totals
            # do not keep every batch's autograd graph alive.
            total_cost  += cost.item()
            correct     += (preds == y).sum().item()
            num_batches += 1

        avg_cost = total_cost / num_batches if num_batches else 0.0
        accuracy = torch.tensor(correct / num_samples if num_samples else 0.0)
        print(f" # - EPOCHS {epoch + 1} / {num_epochs} | AvgCost {avg_cost} | Accuracy : {accuracy.item()} - #")

    model.after_task(train_data_loader)

    # Return trained model and accuracy.
    return model, accuracy

# Define validation function (Do not change!)

    Put your model into eval mode with 'model.eval()' or 'model.train(False)'.

In [None]:
def validation(model, x_valid, valid_data_loader, criterion):
    """
    Evaluate `model` on the validation loader and return its accuracy.

    model             : your customized model
    x_valid           : input data for validation (only its length is used, as
                        the denominator of the accuracy)
    valid_data_loader : dataloader of valid dataset
    criterion         : loss function; the cost of the last batch is printed

    Returns correct_predictions / len(x_valid).
    """
    # Run on whatever device the model lives on instead of relying on a
    # module-level `device` variable.
    first_param = next(model.parameters(), None)
    run_device  = first_param.device if first_param is not None else torch.device('cpu')

    model.eval() # Set eval mode

    acc  = 0
    cost = None  # stays None when the loader yields no batch

    # No gradients are needed for evaluation; disabling autograd avoids building
    # a graph for every batch.
    with torch.no_grad():
        for x, y in valid_data_loader:
            y   = y.to(run_device)
            out = model(x.to(run_device))
            _, preds = torch.max(out, 1)
            cost = criterion(out, y)
            acc += torch.sum(preds.cpu() == y.cpu())
    print(f" # - ValidCost {cost} | Accuracy : {acc / len(x_valid)} - #")

    # Return Accuracy
    return acc/len(x_valid)

# Define your model and hyperparameter (You can modify this part!)

    This is the pivotal part of the competition.
    We provide a simple CNN model as an example. 
    Go make your own model!         

In [None]:
class SimpleCNN(nn.Module):
    """ResNet-18 classifier trained task by task with an EWC-style penalty.

    Usage protocol (as driven by the training function in this file):
        before_task()        -> once before training on a new task
        get_cost(out, y)     -> loss for every training batch
        after_task(loader)   -> once after training on the task; updates the
                                Fisher estimate and the parameter snapshot

    The Fisher estimate (`fisher`) and the parameter snapshot (`mean`) are plain
    dicts keyed by parameter name; they are not registered as buffers.
    """

    def __init__(self, in_channel, num_class, classes_per_task=100):
        """
        in_channel       : accepted for interface compatibility; not used by this
                           implementation
        num_class        : size of the final fully connected layer
        classes_per_task : number of new classes introduced by every task
        """
        super().__init__()
        from torchvision.models import resnet18
        self.backbone = resnet18()
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, num_class)

        self.fisher = None       # per-parameter Fisher estimate, set by after_task()
        self.mean = None         # per-parameter snapshot, set by after_task()
        self.lamb = 80           # weight of the EWC penalty in get_cost()
        self.fishermax = 0.0001  # upper clamp applied to every Fisher entry
        self.classes_per_task = classes_per_task
        self._known_classes = 0  # classes seen in already finished tasks
        self._total_classes = 0  # classes seen including the current task
        self._cur_task = -1      # index of the current task, -1 before the first

    def forward(self, x):
        # Scale inputs by 1/255 before the backbone (maps 0..255 pixel values to 0..1).
        x = x/255.0
        out = self.backbone(x)

        return out

    def before_task(self):
        """Advance to the next task and extend the class range by one task."""
        self._cur_task += 1
        self._total_classes = self._known_classes + self.classes_per_task

    def after_task(self, train_loader):
        """Update the Fisher estimate and snapshot the parameters after a task.

        For the first task the new estimate is stored as is. Afterwards the old
        and the new estimate are blended, weighting the old one by
        known_classes / total_classes.
        """
        if self.fisher is None:
            self.fisher = self.getFisherDiagonal(train_loader)
        else:
            alpha = self._known_classes / self._total_classes
            new_fisher = self.getFisherDiagonal(train_loader)
            for n, new in new_fisher.items():
                old = self.fisher[n]
                # Only the leading rows that existed in the old estimate are blended.
                k = len(old)
                new[:k] = alpha * old + (1 - alpha) * new[:k]
            self.fisher = new_fisher
        self.mean = {
            n: p.clone().detach()
            for n, p in self.named_parameters()
            if p.requires_grad
        }
        self._known_classes = self._total_classes

    def getFisherDiagonal(self, train_loader):
        """Return {parameter name: mean squared batch gradient, clamped to fishermax}.

        The gradient is that of the cross-entropy loss on each batch of
        `train_loader`; the squared gradients are averaged over the batches.
        The model is switched to train mode for this computation.
        """
        trainable = {n: p for n, p in self.named_parameters() if p.requires_grad}
        fisher = {n: torch.zeros_like(p) for n, p in trainable.items()}

        # Use the device the parameters live on rather than a module-level global.
        first_param = next(self.parameters(), None)
        run_device = first_param.device if first_param is not None else torch.device('cpu')

        self.train()
        num_batches = 0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(run_device), targets.to(run_device)
            logits = self(inputs)
            loss = torch.nn.functional.cross_entropy(logits, targets)
            # Clear stale gradients so p.grad holds this batch's gradient only.
            self.zero_grad()
            loss.backward()
            for n, p in trainable.items():
                if p.grad is not None:
                    fisher[n] += p.grad.detach().pow(2)
            num_batches += 1
        # Do not leave the last batch's gradients behind.
        self.zero_grad()

        divisor = max(num_batches, 1)  # an empty loader yields an all-zero estimate
        return {
            n: torch.clamp(f / divisor, max=self.fishermax)
            for n, f in fisher.items()
        }

    def get_cost(self, logits, targets):
        """Training loss for one batch.

        First task : cross-entropy over all logits.
        Later tasks: cross-entropy restricted to the logits of the current task's
                     classes (targets shifted accordingly) plus lamb * EWC penalty.
        """
        if self._cur_task == 0:
            return torch.nn.functional.cross_entropy(logits, targets)

        loss_clf = torch.nn.functional.cross_entropy(
            logits[:, self._known_classes:], targets - self._known_classes
        )
        return loss_clf + self.lamb * self.compute_ewc()

    def compute_ewc(self):
        """Return sum over parameters of fisher * (param - snapshot)^2 / 2.

        Returns 0 when no Fisher estimate has been stored yet.
        """
        if self.fisher is None:
            return 0
        loss = 0
        for n, p in self.named_parameters():
            if n in self.fisher:
                anchor = self.mean[n]
                loss += torch.sum(self.fisher[n] * (p[: len(anchor)] - anchor).pow(2)) / 2
        return loss

# Run on the GPU when torch reports one as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Build the network with a 1000-way output layer and move it to the chosen device.
model = SimpleCNN(in_channel=3, num_class=1000)
model = model.to(device)

# TODO : Set your hyperparameters
random_seed    = 555   # seed passed to the random number generators
validation_num = 20    # of the 150 images per class, how many are held out for validation
batch_size     = 500
num_epochs     = 25
learning_rate  = 0.1

# Define criterion.
criterion = nn.CrossEntropyLoss()
# Plain SGD over all model parameters.
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

# Incremental Learning. (Do not change!)

### WARNING:
    The training and validation datasets MUST both be prepared properly beforehand.  
    If they are not, your submitted code will be rejected immediately.

In [None]:

# --div_data  : split your data or not.
parser = argparse.ArgumentParser()
parser.add_argument('--div_data',   default = False)  # Change this to 'False' after dividing your dataset into 10 groups.
args = parser.parse_args(args=[])

# TODO : Saving trained model in this location. Don't change this path.
save_dir = './result'
os.makedirs(save_dir, exist_ok=True)

# TODO : Seed
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# TODO : Split dataset according to argument '--div_data'
if args.div_data:
    train_split(validation_num)


# 1. training      : train each 100 classes sequentially with respect to 1000 output classes.
#     training class   === 1-100 -> 101-200 -> 201-300 -> 301-400 -> ... -> 901-1000
#
# 2. validation    : validate each trained model.
#     validation class === 1-100 -> 1-200 -> 1-300 -> ... -> 1-1000
#
# 3. model save    : saves the model after the last task.

for div_idx in range(10):

    # TODO : Load your train and validation data
    x_train, y_train = load_train_data(div_idx)
    x_valid, y_valid = load_valid_data(div_idx)

    # in case of training 1  -100 classes, validate on 1-100 classes
    # in case of training 101-200 classes, validate on 1-200 classes
    # in case of training 201-300 classes, validate on 1-300 classes
    # and so on...
    #
    # x_val_tmp / y_val_tmp accumulate the validation data of every task seen
    # so far (labels are kept in their original, 1-based form here).
    if div_idx == 0:
        x_val_tmp = x_valid
        y_val_tmp = y_valid
    else:
        x_val_tmp = np.concatenate((x_val_tmp, x_valid), axis = 0)
        y_val_tmp = np.concatenate((y_val_tmp, y_valid), axis = 0)
        x_valid   = x_val_tmp
        y_valid   = y_val_tmp

    # TODO : let the label starts from 0 to match the output index of model prediction. (Currently the label starts from 1.)
    y_train = y_train - 1
    y_valid = y_valid - 1

    # TODO : Define dataset and dataloader
    train_dataset     = CustomDataset(x_train, y_train, device)
    valid_dataset     = CustomDataset(x_valid, y_valid, device)
    train_data_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    # Validation must see every sample exactly once: the accuracy is divided by
    # len(x_valid), so dropping the last incomplete batch would understate it.
    valid_data_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

    # TODO : train and validate
    trained_model, acc_train = train_model(model, x_train, optimizer, num_epochs, train_data_loader, criterion)
    acc_valid                = validation(trained_model, x_valid, valid_data_loader, criterion)

    # Only the model obtained after the final task is written to disk.
    if div_idx == 9:
        MODEL_SAVE_FOLDER_PATH = './model_save/'
        os.makedirs(MODEL_SAVE_FOLDER_PATH, exist_ok=True)
        model_path = MODEL_SAVE_FOLDER_PATH + 'continual_model.pt'
        # TODO : save trained model in 'save_model_path'
        torch.save(trained_model.state_dict(), model_path)

    print(f'{str(div_idx)} Iteration Done.')