# Train the Model

This notebook contains code to train the model for FBA Matting

In [1]:
import pytorch_ssim
import torch
from torch.autograd import Variable
from torch import optim
import torch.nn as nn
import os
import cv2
import numpy as np
import sys
from PIL import Image, ImageEnhance

In [2]:
# Dataset Paths
IMAGE = "./Image"
ALPHA = "./Alpha"
TRIMAP = "./Trimap"

print("Total Images :", len(os.listdir(IMAGE)))

Total Images : 41245


In [3]:
#Custom Imports
#RAdam Optimizer 
from radam import RAdam
# Dataset Loader
from dataloader import DataLoader

# Initialize Model

In [4]:
from demo import np_to_torch,scale_input
from networks.models import build_model
from networks.transforms import trimap_transform, groupnorm_normalise_image

In [5]:
class Args:
  encoder = 'resnet50_GN_WS'
  decoder = 'fba_decoder'
  weights = 'default'
args=Args()
model = build_model(args)

modifying input layer to accept 11 channels


# Training Process

In [6]:
optimizer = RAdam(model.parameters())
dataset = DataLoader(IMAGE, TRIMAP, ALPHA)
dataset_len = len(dataset)

In [7]:
# Initialize the losses here
mse_loss_criteria = nn.MSELoss()
sad_loss_criteria = nn.L1Loss()
margin_ranking_loss_criteria = nn.MarginRankingLoss(margin=1.0)

In [8]:
import pickle
import time

def check_epochs(file_names):
    """
    Returns only filenames that contains "epoch_" in them
    """
    if len(file_names) ==0:
        return []
    epoch_names = []
    for names in file_names:
        if "epoch_" in names:
            epoch_names.append(names)
    return epoch_names

def load_saved_data():
    last_saved_model = sorted(check_epochs(os.listdir("./logs/")))[-1]
    print("Loading saved data from {}".format(last_saved_model))
    modelFile = open("./logs/"+last_saved_model, 'rb')
    modelPickle = pickle.load(modelFile)
    modelFile.close()
    saved_model = modelPickle['model']
    saved_lowest = modelPickle['lowest']        
    saved_dataset = modelPickle['dataset']        
    saved_epoch = modelPickle['epoch']
    saved_loss = modelPickle['loss']
    return (saved_model, saved_lowest, saved_dataset, saved_epoch, saved_loss)


def save_data(model, lowest, dataset, epoch, loss):
    """
    Save the data
    """
    epoch_name = "epoch_{:09d}".format(epoch)
    model_name = "model_{:09d}".format(epoch)
    epoch_data = {}
    epoch_data['model'] = model
    epoch_data['lowest'] = lowest
    epoch_data['dataset'] = dataset
    epoch_data['epoch'] = epoch
    epoch_data['loss'] = loss
    
    dbfile = open('./logs/'+epoch_name, 'ab')
    pickle.dump(epoch_data, dbfile)
    dbfile.close()
    torch.save(model.state_dict(), "./Saved Models/"+model_name)
    print("Saved data for epoch : {}".format(epoch))
    

In [9]:
# import pickle
# # Redundant functions
# # No need
# def log(model, loss, epoch, dataset, lowest_loss, loss_list,save=False):
#     """
#     This functions saves the data regarding different epochs
#     Data is saved after 10 epochs, and for the lowest loss value
#     if save is True, this function is run after predicting output
#     loss_list = dict with values "mse", "sad", "mr"
#     """
# #     First check if epoch is 0
#     main= None
#     if epoch == 0 and not save:
#         # Restore or load from main file and then start
#         if os.path.exists("./logs/main"):
#             mainFile = open("./logs/main", 'rb')
#             mainPickle = pickle.load(mainFile)
#             loaded_dataset = mainPickle['dataset']
#             loaded_model = mainPickle['model']
#             loaded_lowest = mainPickle['lowest']
#             loaded_epoch = mainPickle['epoch']
#             loaded_loss_list = mainPickle['loss_list']
#             print("Successfully loaded Data")
#             mainFile.close()
#             return (loaded_model, loaded_dataset, loaded_lowest, loaded_epoch, loaded_loss_list)
#         else:
#             # First time running model
#             print("Running for first time")
#             return (model,dataset,lowest_loss, epoch, loss_list)
#     elif epoch == 0 and save:
#         # In this case it will save a main file 
#         mainData = {}
#         mainData['dataset'] = dataset
#         mainData['model'] = model
#         mainData['lowest'] = loss
#         mainData['epoch'] = epoch
#         mainData['loss_list'] = loss_list
#         dbfile = open('./logs/main', 'ab')
#         pickle.dump(mainData, dbfile)
#         dbfile.close() 
#         return None
#     elif epoch>0 and not save:
#         # If its not the first epoch, not need to do anything
#         return (model,dataset,lowest_loss, epoch)
#     elif epoch>0 and save:
#         # Save data according to epochs
#         mainFile = open("./logs/main", 'rb')
#         mainPickle = pickle.load(mainFile)
#         mainFile.close()
#         if loss < lowest_loss:
#             mainData['lowest'] = loss
#         if epoch%10 == 0:
#             # Each multiple of 10 will have it own data save
#             mainData['epoch'] = epoch
#             mainData['loss_list'] = loss_list
#             dbfile = open('./logs/main', 'ab')
#             pickle.dump(mainData, dbfile)
#             dbfile.close()
            
#             file_name = "epoch_{}".format(epoch)
#             epochData = {}
#             epochData['dataset'] = dataset
#             epochData['model'] = dataset
#             epochData['loss'] = dataset
#             epochData['loss_list'] = dataset
#             epochData['epoch'] = epoch
            
#             dbfile = open('./logs/'+file_name, 'ab')
#             pickle.dump(epochData, dbfile)
#             dbfile.close()
#             print("Data saved for epoch :", i)
#         return None
            
            
            

In [None]:
#Loss value is stored in this list
loss_values={}
loss_values['mse'] = []
loss_values['sad'] = []
loss_values['mr'] = []
loss_values['total'] = []
lowest = 99999999999999999999
i=0
for _ in range(dataset_len):
    start_time = time.time()
    epoch = i
    if i ==0 and len(check_epochs(os.listdir("./logs/"))) !=0:
        # Load the data
        model, lowest, dataset, i, loss_values = load_saved_data()
        epoch = i
        
    elif i==0:
        print("No saved data found")
    print("Dataset {} out of {}".format(i, dataset_len))
    try:
        sample = dataset[i]
    except ValueError:
        continue
    optimizer.zero_grad()
    image_np, trimap_np, output = sample['image'], sample['trimap'], sample['output']
    h, w = trimap_np.shape[:2]

    image_scale_np = scale_input(image_np, 1.0, cv2.INTER_LANCZOS4)
    trimap_scale_np = scale_input(trimap_np, 1.0, cv2.INTER_LANCZOS4)

    image_torch = np_to_torch(image_scale_np)
    trimap_torch = np_to_torch(trimap_scale_np)
    output = torch.from_numpy(output)[None, :, :, :].float().cuda()
    trimap_transformed_torch = np_to_torch(trimap_transform(trimap_scale_np))
    image_transformed_torch = groupnorm_normalise_image(image_torch.clone(), format='nchw')

    pred_output = model(image_torch, trimap_torch, image_transformed_torch, trimap_transformed_torch)
    tranposed_output = cv2.resize(pred_output[0].cpu().detach().numpy().transpose((1, 2, 0)), (w, h), cv2.INTER_LANCZOS4)
    # 0 --> Alpha
    # 1-4 --> Fg
    # 4:7 --> Bg

    #mse_loss 
    mse_loss = mse_loss_criteria(pred_output, output)

    #sad loss
    sad_loss = sad_loss_criteria(pred_output, output)

    

    #margin ranking loss
    target = torch.ones(1, 1).to(device='cuda')
    mr_loss = margin_ranking_loss_criteria.forward(pred_output,output, target)
    #overall loss
    loss = 0.01*mse_loss+sad_loss+mr_loss
    loss_values['mse'].append(0.01*mse_loss.item())
    loss_values['sad'].append(sad_loss.item())
    loss_values['mr'].append(mr_loss.item())
    loss_values['total'].append(loss.item())
    print("Total Loss :",loss.item())
    if loss.item()<lowest:
        print("Lowest Loss:{} at {}".format(loss, i))
        lowest = loss.item()
    if epoch%100 == 0:
        # Save the data for this epoch
         save_data(model, lowest, dataset, i, loss_values)
    end_time = time.time()
    print("Time Taken for epoch :", format(end_time-start_time))
    loss.backward()
    optimizer.step()
    print()
    i+=1

Loading saved data from epoch_000009100
Dataset 9100 out of 41245
Total Loss : 149.88914489746094
Saved data for epoch : 9100
Time Taken for epoch : 34.712388038635254

Dataset 9101 out of 41245
Total Loss : 138.00808715820312
Time Taken for epoch : 3.670738697052002

Dataset 9102 out of 41245
Total Loss : 159.649169921875
Time Taken for epoch : 1.7401096820831299

Dataset 9103 out of 41245
Total Loss : 35.949951171875
Time Taken for epoch : 3.675778865814209

Dataset 9104 out of 41245
Total Loss : 238.75282287597656
Time Taken for epoch : 11.228214025497437

Dataset 9105 out of 41245
Total Loss : 94.8265380859375
Time Taken for epoch : 12.834595203399658

Dataset 9106 out of 41245
Total Loss : 108.29341888427734
Time Taken for epoch : 1.8195834159851074

Dataset 9107 out of 41245
Total Loss : 51.69748306274414
Time Taken for epoch : 13.966834545135498

Dataset 9108 out of 41245
Total Loss : 58.77845001220703
Time Taken for epoch : 1.8374710083007812

Dataset 9109 out of 41245
Dataset 