# TRAINING and RESULT on Biovid
- This notebook walks through the search for the lambda value that maximizes performance on the classification task.
- The mutual information is estimated with KNIFE: https://openreview.net/pdf?id=a43otnDilz2

In [None]:
import os
import torch
import pandas as pd
import torchvision
from torchvision.models import resnet50, ResNet50_Weights,resnet18, ResNet18_Weights
import matplotlib.pyplot as plt
from utils import data_augm,data_adapt
from knife import KNIFE
from log_reader import Logs
from Dataset_processing.Biovid import gen_dataframe
from data_loader import Dataset_Biovid_image_binary_class
import torch
from tqdm import tqdm
from MI_training import train,Classif

# GPU selection. Every listed index is made the current CUDA device in turn and
# has its cached memory released; models and tensors are then sent to the first
# index of the list.
device_ids = [1]
for gpu_index in device_ids:
    torch.cuda.set_device(gpu_index)
    torch.cuda.empty_cache()
    torch.cuda.get_device_name()  # the returned name is not used
device = 'cuda:' + str(device_ids[0])

### Modify and complete 
Change the paths and the hyperparameters in the cell below to match your setup.

In [None]:
# --- Paths: replace the '...' placeholders with your own locations -----------
Biovid_img = '.../Biovid/sub_red_classes_img'
biovid_annot_train = '.../your_folder/Biovid_binary/train.csv'
biovid_annot_test = '.../your_folder/Biovid_binary/test.csv'
# Used as a raw string prefix (save_path + 'encoder.pt'), so keep the trailing '/'.
save_path = '.../your_folder/models/experience_lambda/'

# --- Data ---------------------------------------------------------------------
BATCH_SIZE = 200   # batch size of both data loaders
RESOLUTION = 112   # side length handed to the image transforms
nb_ID = 61         # number of identity classes (ID classifier output / one-hot width)
FOLD = 8           # forwarded as nb_fold to the training dataset

# --- Optimisation ---------------------------------------------------------------
LEARNING_RATE = 0.01               # pre-training of the two classification heads
LEARNING_RATE_FINETUNE = 0.000005  # per-lambda fine-tuning
EPOCH_PRETRAIN = 10
EPOCH_FINETUNE = 20

# Lambda values swept by the main loop: 44 values, 0.0 to 2.15 in steps of 0.05.
LAMBDA = [round(i*0.05,2) for i in range(44)]

# Keyword arguments of the KNIFE estimator. zc_dim follows nb_ID because the
# estimator is fed one-hot identity vectors of that width.
# NOTE(review): zd_dim=1000 presumably matches the ResNet output width — confirm.
arg_MI = {'zd_dim':1000, 'zc_dim':nb_ID, 'hidden_state':100, 'layers':3, 'nb_mixture':10, 'tri':False}

### Compute new annotation
- follow the official Biovid part A split Train/Validation
- add ID and video_ID


In [None]:
# Create the annotation split only when the training CSV does not exist yet.
# A plain existence check replaces the former bare `except:`, which treated
# *any* error (including a keyboard interrupt or a malformed CSV) as "file
# missing" and silently regenerated the split.
if not os.path.isfile(biovid_annot_train):
    # Derive the folders from the configured paths instead of slicing a fixed
    # number of characters, so renamed folders/files keep working.
    biovid_root = os.path.dirname(os.path.normpath(Biovid_img))  # parent folder of the image folder
    annot_dir = os.path.dirname(biovid_annot_train)              # folder holding the annotation CSVs

    # Subject identifiers handed to gen_dataframe.
    # NOTE(review): presumably the subjects held out for validation in the
    # official Biovid part A split — confirm against gen_dataframe.
    validation_subjects = [
        '100914_m_39', '101114_w_37',
        '082315_w_60', '083114_w_55',
        '083109_m_60', '072514_m_27',
        '080309_m_29', '112016_m_25',
        '112310_m_20', '092813_w_24',
        '112809_w_23', '112909_w_20',
        '071313_m_41', '101309_m_48',
        '101609_m_36', '091809_w_43',
        '102214_w_36', '102316_w_50',
        '112009_w_43', '101814_m_58',
        '101908_m_61', '102309_m_61',
        '112209_m_51', '112610_w_60',
        '112914_w_51', '120514_w_56',
    ]
    gen_dataframe(os.path.join(biovid_root, 'sub_two_labels.txt'), annot_dir, validation_subjects)
    print('split was created')

### Prepare datasets and data loaders

In [None]:
# Image transforms: augmentation pipeline for training, plain adaptation for
# evaluation, plus a bare resize to RESOLUTION x RESOLUTION.
tr = data_augm(RESOLUTION)
tr_test = data_adapt(RESOLUTION)
tr_size = torchvision.transforms.Resize((RESOLUTION, RESOLUTION), antialias=True)

# Training split: augmented images, FOLD folds, nothing preloaded in memory.
dataset_train = Dataset_Biovid_image_binary_class(
    Biovid_img,
    biovid_annot_train,
    transform=tr.transform,
    IDs=None,
    nb_image=None,
    nb_fold=FOLD,
    preload=False,
)
loader_train = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,   # the last incomplete batch is discarded
    num_workers=0,    # num_workers=0 if preload is False
)

# Validation split: evaluation transform, every sample kept.
dataset_test = Dataset_Biovid_image_binary_class(
    Biovid_img,
    biovid_annot_test,
    transform=tr_test.transform,
    IDs=None,
    nb_image=None,
    preload=False,
)
loader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

### Create Models, Losses and optimizers
- ResNet18 is pretrained on ImageNet and frozen

In [None]:
# Networks: ImageNet-initialised ResNet18 encoder, a single-output affect (pain)
# head, an identity head with nb_ID outputs, and the KNIFE estimator.
model_resnet = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1).to(device)
model_affect = Classif(1, False).to(device)
model_ID = Classif(nb_ID).to(device)
MI = KNIFE(**arg_MI).to(device)

# One optimizer drives both classification heads (the encoder's parameters are
# not part of it); the estimator has an optimizer of its own.
optimizer = torch.optim.Adam(
    [*model_affect.parameters(), *model_ID.parameters()],
    lr=LEARNING_RATE,
)
optimizer_MI = torch.optim.Adam(MI.parameters(), lr=0.01)

# Both criteria sum over the batch; per-sample averages are obtained later by
# dividing by the number of samples seen.
loss_BCE = torch.nn.BCELoss(reduction='sum')            # affect head
loss_CE = torch.nn.CrossEntropyLoss(reduction='sum')    # identity head

### Pretrain MI, Affect_Classifier and ID_Classifier
- The ID classifier does not influence the ResNet representation: it is fed `encoded_img.detach()`

In [None]:
# Pre-training stage.
# The encoder is only run forward (under no_grad); the affect head, the ID head
# and the KNIFE estimator are trained on its output. After every epoch the four
# state dicts are written to disk so the lambda sweep can start from them.

# Per-epoch history. The 'loss_acc_*' entries hold accuracies in percent.
dic_log = {'loss_CE_train':[],'loss_CE_val':[],'loss_acc_train':[],'loss_acc_val':[],'loss_acc_ID_val': [],'MI':[]}

# save_path is used as a raw prefix for the checkpoint names; create its folder
# up front so the first torch.save does not fail after a full epoch of training.
checkpoint_dir = os.path.dirname(save_path)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(EPOCH_PRETRAIN):
    # NOTE(review): presumably re-draws the training samples for the new epoch —
    # confirm in Dataset_Biovid_image_binary_class.reset.
    dataset_train.reset()
    loss_task_tot = 0
    elem_sum = 0
    true_response_affect = 0
    true_response_ID = 0
    MI_loss_tot = 0
    # NOTE(review): train() leaves the encoder's BatchNorm layers in training
    # mode, so their running statistics are still updated by the forward passes
    # below even though gradients are disabled. Use model_resnet.eval() here if
    # the encoder is meant to stay strictly frozen — confirm intent.
    model_resnet.train()
    model_affect.train()
    model_ID.train()
    loop_train = tqdm(loader_train,colour='BLUE')
    loop_train.set_description(f"Epoch [{epoch}/{EPOCH_PRETRAIN}] training")
    for pack in loop_train:

        img_tensor = pack[0].to(device)
        pain_tensor = pack[1].float().to(device)
        ID_tensor = pack[2].to(device)
        # One-hot width follows nb_ID so it stays in sync with the ID head.
        ID_tensor_one_hot = torch.nn.functional.one_hot(ID_tensor,nb_ID).float()
        # Encoder forward pass only: no gradient reaches the ResNet.
        with torch.no_grad():
            encoded_img  = model_resnet(img_tensor)

        # UPDATE MI: fit the estimator on (representation, identity) pairs.
        loss_MI = MI.loss(encoded_img.detach(),ID_tensor_one_hot)
        optimizer_MI.zero_grad()
        loss_MI.backward()
        optimizer_MI.step()
        MI_loss_tot += float(loss_MI)

        # TASK Affect
        output = model_affect(encoded_img)
        loss_task_affect = loss_BCE(output,pain_tensor)
        true_response_affect += float(torch.sum(output.round() == pain_tensor))

        # Task ID
        output = model_ID(encoded_img.detach())
        loss_task_ID = loss_CE(output,ID_tensor)
        true_response_ID += float(torch.sum(output.max(dim=-1)[1] == ID_tensor))

        # Each loss depends on its own head only, so one backward pass over the
        # sum gives every head exactly its own gradient.
        loss =  loss_task_ID + loss_task_affect
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        elem_sum += img_tensor.shape[0]
        loss_task_tot += float(loss_task_affect)

        loop_train.set_postfix(loss_task = loss_task_tot/elem_sum,accuracy_pain=true_response_affect/elem_sum*100,accuracy_ID=true_response_ID/elem_sum*100,MI=MI_loss_tot/elem_sum)

    # ---- Validation ----
    model_resnet.eval()
    model_affect.eval()
    model_ID.eval()

    loss_task_val = 0
    elem_sum_val = 0
    true_response_affect_val  =0
    true_response_ID_val = 0
    loop_test = tqdm(loader_test,colour='GREEN')
    loop_test.set_description("Test")
    for pack in loop_test:
        img_tensor = pack[0].to(device)
        pain_tensor = pack[1].float().to(device)
        ID_tensor = pack[2].to(device)

        with torch.no_grad():
            encoded_img  = model_resnet(img_tensor)
            # TASK Affect
            output = model_affect(encoded_img)
            loss_task_affect_val = loss_BCE(output,pain_tensor)
            true_response_affect_val += float(torch.sum(output.round() == pain_tensor))

            # Task ID (no detach needed: nothing is tracked under no_grad)
            output = model_ID(encoded_img)
            loss_task_ID_val = loss_CE(output,ID_tensor)
            true_response_ID_val += float(torch.sum(output.max(dim=-1)[1] == ID_tensor))

        elem_sum_val += img_tensor.shape[0]
        loss_task_val += float(loss_task_affect_val)
        loop_test.set_postfix(loss_task = loss_task_val/elem_sum_val,accuracy_pain=true_response_affect_val/elem_sum_val*100,accuracy_ID=true_response_ID_val/elem_sum_val*100)

    # ---- Epoch summary (per-sample averages / percentages) ----
    dic_log['loss_CE_train'].append(loss_task_tot/elem_sum)
    dic_log['MI'].append(MI_loss_tot/elem_sum)
    dic_log['loss_acc_train'].append(true_response_affect/elem_sum*100)
    dic_log['loss_CE_val'].append(loss_task_val/elem_sum_val)
    dic_log['loss_acc_val'].append(true_response_affect_val/elem_sum_val*100)
    dic_log['loss_acc_ID_val'].append(true_response_ID_val/elem_sum_val*100)

    # Overwrite the checkpoints; the last epoch is what the lambda sweep loads.
    torch.save(model_resnet.state_dict(),save_path+'encoder.pt')
    torch.save(model_ID.state_dict(),save_path+'ID.pt')
    torch.save(model_affect.state_dict(),save_path+'Affect.pt')
    torch.save(MI.state_dict(),save_path+'MI.pt')

### Main loop
- one training run for each lambda value
- all networks are unfrozen

In [None]:
# Lambda sweep: every run restarts from the same pre-trained checkpoints so that
# the runs differ only by the value of lambda passed to train().
for lamb in LAMBDA:
    # weights=None: only the architecture is needed, since every parameter and
    # buffer is overwritten by the checkpoint loaded on the next line. This
    # avoids re-loading the ImageNet weights once per lambda value.
    model_resnet = resnet18(weights=None).to(device)
    # map_location keeps the load working when the checkpoint was written from
    # a different GPU index than the one selected for this session.
    model_resnet.load_state_dict(torch.load(save_path+'encoder.pt', map_location=device))

    model_affect = Classif(1,False).to(device)
    model_affect.load_state_dict(torch.load(save_path+'Affect.pt', map_location=device))

    model_ID = Classif(nb_ID).to(device)
    model_ID.load_state_dict(torch.load(save_path+'ID.pt', map_location=device))

    MI = KNIFE(**arg_MI).to(device)
    MI.load_state_dict(torch.load(save_path+'MI.pt', map_location=device))

    # Fine-tune for EPOCH_FINETUNE epochs at the small fine-tuning learning rate.
    train(model_resnet,MI,model_affect,model_ID,loader_train,loader_test,device,lamb=lamb,EPOCH=EPOCH_FINETUNE,lr=LEARNING_RATE_FINETUNE,save_path=save_path)

### Results

In [None]:
# Build the log reader over save_path for the swept lambda values and plot it.
# NOTE(review): presumably draws one curve per selected lambda, coloured with
# the 'brg' colormap — confirm in log_reader.Logs.
L = Logs(save_path,select=LAMBDA,colormap='brg')
L.plot()
# Last expression of the cell: the notebook displays the value of L.lamb.
L.lamb

In [None]:
# Draw the 'loss_acc_val' metric from the Logs object created in the previous cell.
# NOTE(review): degree=2 / res=300 presumably set a polynomial fit degree and
# the sampling resolution of the drawn curve — confirm in log_reader.Logs.curve.
L.curve('loss_acc_val',degree=2,res=300,color='orange')
# Alternative view (disabled): same call for the 'loss_acc_ID_train' metric.
#L.curve('loss_acc_ID_train',degree=-1,res=300,color='green')