## This notebook trains the regressor model for both human and cartoon images. For demonstration purposes, it only trains on the human images. To train the model on cartoon images as well, change all variables containing "human" to "cartoon".

In [None]:
import torch
from dataset import Human2CartoonDataset
import sys
from utils import save_checkpoint, load_checkpoint, calculate_error_norm
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import config
from tqdm import tqdm
import numpy as np
from model_R import Regressor
from PIL import Image
import os
from torch.utils.data import Dataset
import torch
import numpy as np
import pandas as pd
from torchsummary import summary
import torch
import torch.utils.data
import tensorflow as tf
import matplotlib.pyplot as plt
from torchsummary import summary as summary

In [None]:
def train_fn(reg_H, loader, opt_H, mse, *, scaler=None, domain="human"):
    """Run one training epoch of a landmark regressor.

    Args:
        reg_H: Regressor being trained; its parameters are expected to live on
            ``config.DEVICE`` already.
        loader: DataLoader yielding
            ``(human, cartoon, landmark_human, landmark_cartoon)`` batches.
        opt_H: Optimizer over ``reg_H``'s parameters.
        mse: Loss callable comparing predicted and target landmarks.
        scaler: Optional ``torch.cuda.amp.GradScaler``. When omitted, a new one
            is created for this epoch (enabled only on CUDA) so float16
            gradients produced under autocast do not underflow. Pass a
            persistent scaler to keep its calibrated scale across epochs.
        domain: ``"human"`` (default) trains on the human images/landmarks,
            ``"cartoon"`` on the cartoon images/landmarks.

    Returns:
        Sum of the per-batch loss values over the epoch.

    Raises:
        ValueError: If ``domain`` is not ``"human"`` or ``"cartoon"``.
    """
    if domain not in ("human", "cartoon"):
        raise ValueError(f"domain must be 'human' or 'cartoon', got {domain!r}")
    # Positions of the image and landmark tensors inside each batch tuple.
    image_idx, landmark_idx = (0, 2) if domain == "human" else (1, 3)

    use_cuda = torch.device(config.DEVICE).type == "cuda"
    if scaler is None:
        # With enabled=False the scaler is a pass-through (plain backward/step).
        scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

    # Ensure training-mode behaviour (dropout/batch-norm) even if the model
    # was left in eval mode elsewhere.
    reg_H.train()

    loop = tqdm(loader, leave=True)
    running_loss = 0.0
    for batch in loop:
        # Only transfer the tensors this domain actually uses.
        images = batch[image_idx].to(config.DEVICE)
        landmarks = batch[landmark_idx].to(config.DEVICE)

        with torch.cuda.amp.autocast(enabled=use_cuda):
            # Maps [-1, 1] inputs back to [0, 255]; assumes config.transforms
            # normalizes with mean=0.5, std=0.5 — TODO confirm.
            preds = reg_H((images * 0.5 + 0.5) * 255)
            loss = mse(preds.float(), landmarks.float())

        opt_H.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(opt_H)
        scaler.update()

        # Single host sync per batch.
        loss_value = loss.item()
        running_loss += loss_value
        loop.set_postfix(L=loss_value)
    return running_loss


In [None]:
def main(num_epochs=2, learning_rate=1e-5, checkpoint_file="R_H.pth.tar"):
    """Train the human-image landmark regressor and plot the per-epoch loss.

    Builds a ``Regressor``, an Adam optimizer and a DataLoader over the
    training split under ``config.TRAIN_DIR``. It resumes from
    ``checkpoint_file`` when that file exists, then runs ``num_epochs`` epochs
    of ``train_fn`` and plots the mean per-batch loss of each epoch.

    Args:
        num_epochs: Number of passes over the training data.
        learning_rate: Adam learning rate; also forwarded to ``load_checkpoint``.
        checkpoint_file: Checkpoint to resume from, if present.

    Returns:
        List with the mean per-batch loss of each epoch.
    """
    reg_H = Regressor().to(config.DEVICE)

    opt_H = optim.Adam(
        reg_H.parameters(),
        lr=learning_rate,
        betas=(0.5, 0.999),
    )
    dataset = Human2CartoonDataset(
        root_human=config.TRAIN_DIR + "/trainA",
        root_cartoon=config.TRAIN_DIR + "/trainB",
        root_landmarks_human=config.TRAIN_DIR + "/trainA_human_landmarks.xlsx",
        root_landmarks_cartoon=config.TRAIN_DIR + "/trainB_cartoon_landmarks.xlsx",
        transform=config.transforms,
    )
    # NOTE(review): training loaders usually shuffle; kept False as originally written.
    loader = DataLoader(
        dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=False,
        num_workers=config.NUM_WORKERS,
        pin_memory=True,
    )

    mse = nn.MSELoss()
    # Resume only when a checkpoint exists so a first run does not crash.
    if os.path.isfile(checkpoint_file):
        load_checkpoint(
            checkpoint_file, reg_H, opt_H, learning_rate,
        )
    else:
        print(f"No checkpoint found at {checkpoint_file!r}; training from scratch.")

    # train_fn returns a sum of per-batch mean losses, so normalize by the
    # number of batches (not a hard-coded sample count) to get the mean loss.
    num_batches = max(len(loader), 1)
    hold_loss = []
    for epoch in range(num_epochs):
        print('Epoch : ', epoch)
        running_loss = train_fn(reg_H, loader, opt_H, mse)
        hold_loss.append(running_loss / num_batches)
        print('Loss at this epoch : ', hold_loss[-1])

    # NOTE(review): trained weights are never saved; save_checkpoint is
    # imported but unused — call it here if the result should persist.
    plt.plot(np.arange(len(hold_loss)), np.array(hold_loss))
    plt.xlabel("Epoch")
    plt.ylabel("Mean batch MSE loss")
    return hold_loss
    
    
# Entry point: run training when executed as a script. A Jupyter kernel also
# sets __name__ to "__main__", so running this cell starts training as well.
if __name__ == "__main__":
    main()