In [1]:
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm

from data_loader import make_dataloaders
from model import Generator_Res_Unet
from trainer import MainModel
from utils import create_loss_meters, log_results, update_losses, visualize

In [2]:
# Seed every RNG the notebook uses (stdlib, NumPy, PyTorch) so runs repeat.
seed = 41
for _rng_seeder in (random.seed, np.random.seed, torch.manual_seed):
    _rng_seeder(seed)
del _rng_seeder


def get_default_device():
    """Return the device to run on: CUDA when available, otherwise the CPU."""
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(device_name)

In [5]:
def train_model(model, train_dl, valid_dl, epochs):
    """Train ``model`` for ``epochs`` passes over ``train_dl``.

    Each batch is fed with ``model.setup_input(data)`` and the model is
    updated with ``model.optimize()``. Loss meters from ``create_loss_meters``
    are updated after every batch and printed with ``log_results`` on the
    first, middle and last iteration of each epoch. After every epoch, each
    validation batch is rendered to ``Epoch_{e}-{idx}.png`` via ``visualize``.

    Args:
        model: object exposing ``setup_input`` and ``optimize``
            (e.g. ``MainModel``).
        train_dl: sized iterable of training batches; each batch is a dict
            with an ``"L"`` tensor whose first dimension is the batch size.
        valid_dl: iterable of validation batches passed to ``visualize``.
        epochs: number of passes over ``train_dl``.
    """
    length_train = len(train_dl)
    # Iterations at which the running losses are printed.
    log_points = {0, length_train // 2, length_train - 1}

    for e in range(epochs):
        # New meter dict each epoch.
        loss_meter_dict = create_loss_meters()

        # tqdm must wrap an iterable; passing the integer length raised TypeError.
        for i, data in enumerate(tqdm(train_dl, total=length_train)):
            model.setup_input(data)
            model.optimize()

            # Weight the meters by the number of samples in this batch.
            update_losses(model, loss_meter_dict, count=data["L"].size(0))

            if i in log_points:
                print(f"\nEpoch: [{e+1}/{epochs}], Iteration:[{i}/{length_train}]")
                log_results(loss_meter_dict)

        for idx, valid_data in enumerate(valid_dl):
            visualize(model, valid_data, save_name=f"Epoch_{e}-{idx}.png")

In [6]:
def pretrain_generator(net_G, train_dl, valid_dl, opt, criterion, epochs,
                       device=None, vis_model=None):
    """Warm up the generator alone with a supervised loss on ``"ab"`` targets.

    Args:
        net_G: generator ``nn.Module``; called on the ``"L"`` tensor and its
            output is compared with the ``"ab"`` tensor by ``criterion``.
        train_dl: iterable of training batches, each a dict with ``"L"`` and
            ``"ab"`` tensors.
        valid_dl: iterable of validation batches; only used when
            ``vis_model`` is given.
        opt: optimizer over ``net_G``'s parameters.
        criterion: loss module, e.g. ``nn.L1Loss()``.
        epochs: number of passes over ``train_dl``.
        device: device batches are moved to. Defaults to the device of
            ``net_G``'s first parameter instead of relying on a global.
        vis_model: optional wrapper model accepted by ``visualize`` (e.g. a
            ``MainModel`` built around ``net_G``). When ``None`` the
            per-epoch visualization is skipped.
    """
    if device is None:
        # Batches must live where the generator's weights live.
        device = next(net_G.parameters()).device

    for e in range(epochs):
        # Sample-weighted running mean of the loss over this epoch
        # (replaces the previously un-imported AverageMeter).
        loss_sum, n_seen = 0.0, 0

        for data in tqdm(train_dl):
            L, ab = data['L'].to(device), data['ab'].to(device)
            preds = net_G(L)
            loss = criterion(preds, ab)
            opt.zero_grad()
            loss.backward()
            opt.step()

            batch_size = L.size(0)
            loss_sum += loss.item() * batch_size
            n_seen += batch_size
            # NOTE(review): emptying the CUDA cache every step is slow; kept
            # to preserve the original memory behaviour.
            torch.cuda.empty_cache()

        avg_loss = loss_sum / n_seen if n_seen else float("nan")
        print(f"Epoch: [{e + 1}/{epochs}] L1 Loss: [{avg_loss:.5f}]")

        # Previously referenced an undefined global `model`; visualization now
        # needs an explicitly supplied wrapper.
        if vis_model is not None:
            for idx, valid_data in enumerate(valid_dl):
                visualize(vis_model, valid_data, save_name=f"Epoch_{e}-{idx}.png")

In [8]:
if __name__ == "__main__":
    # Inside a notebook kernel __name__ is "__main__", so this always runs;
    # later cells use the `device`, `train_dl` and `valid_dl` defined here.

    # os.path.join yields the native separator; a backslash literal such as
    # r"images\train" only resolves as a directory path on Windows.
    train_ = os.path.join("images", "train")
    valid_ = os.path.join("images", "valid")

    device = get_default_device()

    train_dl = make_dataloaders(path=train_)
    valid_dl = make_dataloaders(
        path=valid_, is_training=False, shuffle=False
    )

    print(f"Number of batches ::Train:: {len(train_dl)}, ::Valid:: {len(valid_dl)}")

Number of batches ::Train:: 109, ::Valid:: 2


In [11]:
# Warmup generator
# Set RUN_PRETRAIN = True to run the supervised pretraining and save the
# weights to "res18-unet.pt"; left False so a re-run reuses the saved file.
RUN_PRETRAIN = False
PRETRAIN_EPOCHS = 20

print("Pretraining  U-net Generator ResNet-18")
net_G = Generator_Res_Unet().get_model()
opt = optim.Adam(net_G.parameters(), lr=1e-4)
criterion = nn.L1Loss()

net_G.to(device)
criterion.to(device)

if RUN_PRETRAIN:
    pretrain_generator(net_G, train_dl, valid_dl, opt, criterion, PRETRAIN_EPOCHS)
    torch.save(net_G.state_dict(), "res18-unet.pt")

Pretraining  U-net Generator ResNet-18


L1Loss()

In [None]:
# train Gan model
# Build the generator the same way as when "res18-unet.pt" was saved
# (Generator_Res_Unet().get_model()) so the state-dict keys match the module.
GAN_EPOCHS = 20

net_G = Generator_Res_Unet().get_model()
# NOTE(review): torch.load unpickles the file; consider weights_only=True
# (torch>=1.13) if the checkpoint may come from an untrusted source.
net_G.load_state_dict(torch.load("res18-unet.pt", map_location=device))
model = MainModel(net_G=net_G)
train_model(model, train_dl, valid_dl, GAN_EPOCHS)