In [1]:
from code.sepconvfull import model
import dataloader
import torch
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import discriminator
from tqdm import tqdm
from collections import defaultdict
import metrics
import numpy as np
import hyperopt
from hyperopt import hp, space_eval, Trials
import time
from itertools import product
import os
from utilities import ResultStore, EarlyStopping, get_sepconv

In [2]:
MODEL_FOLDER = 'models'

In [3]:
torch.manual_seed(42)
np.random.seed(42)

In [4]:
def objective(params):
    '''Minimize loss on validation set'''
    
    _, _, resultstore = train(params, n_epochs=5, verbose=False)
    
    losses = resultstore.results['valid']['L1_loss']
    max_epoch = np.max(list(losses.keys()))
    validation_loss = np.mean(losses[max_epoch])
    
    return validation_loss


def train(params, n_epochs, verbose=True):
    
    # init interpolation model
    G = get_sepconv(input_size = params['input_size'])
    G = G.cuda()
    G = G.eval()

    # init discriminator
    D = discriminator.Discriminator(input_size=params['input_size']).cuda()
    
    ds = dataloader.adobe240_dataset( quadratic = params['input_size'] == 4 )
    ds = dataloader.TransformedDataset(ds, crop_size=(128,128), normalize=True)

    N_train = int(len(ds) * 0.8)
    N_valid = len(ds)-N_train

    train, valid = torch.utils.data.random_split(ds, [N_train, N_valid])

    train_dl = torch.utils.data.DataLoader(train, batch_size=4, pin_memory=True, shuffle=True)
    valid_dl = torch.utils.data.DataLoader(valid, batch_size=4, pin_memory=True)

    optimizer_G = torch.optim.Adam(G.parameters(), lr=params['lr'], weight_decay=params['weight_decay'])#$, amsgrad=params['amsgrad'])
    optimizer_D = torch.optim.Adam(D.parameters(), lr=params['lr'], weight_decay=params['weight_decay'])#, amsgrad=params['amsgrad'])
    critereon = torch.nn.L1Loss()

    # metrics
    name = f'{int(time.time())}_{params["lr"]}_{params["weight_decay"]}_{params["wgan"]}_{params["input_size"]}'
    writer = SummaryWriter(f'runs/{name}') #TODO hp
    results = ResultStore(writer=writer, metrics = ['psnr', 'ie', 'L1_loss', 'accuracy', 'G_loss', 'D_loss'])
    early_stopping = EarlyStopping(results, patience=5, metric='L1_loss')
    
    for epoch in range(n_epochs):
        G.train()
        D.train()
        
        if verbose:
            pb = tqdm(desc=f'{epoch+1}/{n_epochs}', total=len(train_dl), leave=True, position=0)
        
        for i, (X, y) in enumerate(train_dl):
            X = X.cuda()
                
            # train generator               

                
            y = y.cuda()

            y_hat = G(X)
            l1_loss = critereon(y_hat, y)
            
            
            loss = l1_loss - D(X, y_hat).sigmoid().log().mean()
            # generator loss
            if params['wgan']:
                G_loss = l1_loss - D(X, y_hat).mean()
            else:
                G_loss = l1_loss - D(X, y_hat).sigmoid().log().mean()
                

            optimizer_G.zero_grad()
            G_loss.backward()
            optimizer_G.step()

            # train discriminator
            y_hat = y_hat.detach()
            
            
            if params['wgan']:
                D_loss = torch.mean( D(X, y_hat) - D(X, y) )
            else:
                D_loss = -torch.log(1 - D(X, y_hat).sigmoid()).mean() - D(X, y).sigmoid().log().mean()
            
            
            # compute psnr
            y_hat = (y_hat * 255).clamp(0,255)
            y = (y * 255).clamp(0,255)

            psnr = metrics.psnr(y_hat, y)
            ie = metrics.interpolation_error(y_hat, y)
            
            
            correct_preds = (D(X, y_hat).sigmoid().round() == 0).flatten().int().detach().cpu().tolist()
            correct_preds.extend((D(X, y).sigmoid().round() == 1).flatten().int().detach().cpu().tolist())
          
            results.store('train', epoch, {
                    'L1_loss':l1_loss.item(),
                    'psnr':psnr,
                    'ie':ie,
                    'accuracy':correct_preds,
                    'D_loss':D_loss.item(),
                    'G_loss':G_loss.item()
                })
    
            optimizer_D.zero_grad()
            D_loss.backward()
            optimizer_D.step()
            
            if params['wgan']:
                for p in D.parameters():
                    p.data.clamp_(-0.01, 0.01)
            
            if verbose: pb.update()
                
            if i == 50:
                break
                


        # update tensorboard
        results.write_tensorboard('train', epoch)


        G.eval()
        D.eval()
        with torch.no_grad():
            for i, (X, y) in enumerate(valid_dl):
                X = X.cuda()
                y = y.cuda()

                y_hat = G(X)        
                l1_loss = critereon(y_hat, y)        
                
                # generator loss
                if params['wgan']:
                    loss = l1_loss - D(X, y_hat).mean()
                else:
                    loss = l1_loss - D(X, y_hat).sigmoid().log().mean()
                
                y_hat = y_hat.detach()
                
                if params['wgan']:
                    D_loss = torch.mean( D(X, y_hat) - D(X, y) )
                else:
                    D_loss = -torch.log(1 - D(X, y_hat).sigmoid()).mean() - D(X, y).sigmoid().log().mean()
                
                # compute psnr
                y_hat = (y_hat * 255).clamp(0,255)
                y = (y * 255).clamp(0,255)

                psnr = metrics.psnr(y_hat, y)
                ie = metrics.interpolation_error(y_hat, y)
                
                correct_preds = (D(X, y_hat).sigmoid().round() == 0).flatten().int().detach().cpu().tolist()
                correct_preds.extend((D(X, y).sigmoid().round() == 1).flatten().int().detach().cpu().tolist())

                results.store('valid', epoch, {
                    'L1_loss':l1_loss.item(),
                    'psnr':psnr,
                    'ie':ie,
                    'accuracy':correct_preds,
                    'D_loss':D_loss.item(),
                    'G_loss':G_loss.item()
                })
                




        # update tensorboard
        results.write_tensorboard('valid', epoch)
        
        # save model if new best
        if early_stopping.new_best():
            filepath_out = os.path.join(MODEL_FOLDER, '{0}_{1}')
            torch.save(G, filepath_out.format('generator', name))
            torch.save(D, filepath_out.format('discriminator', name))
            
        
        if early_stopping.stop():
            break
            
        if i == 50:
            break
            
            
    # free memory
    del G
    del D
    torch.cuda.empty_cache()
    
        
    # save models
    return results




### HP

In [5]:
# trials = Trials()

# best_hp = hyperopt.fmin(
#     fn=objective, 
#     space=params,
#     algo=hyperopt.tpe.suggest,
#     trials = trials,
#     max_evals=10
# )

In [6]:
space_eval(params, best_hp)

NameError: name 'params' is not defined

### grid search

In [5]:
parameter_space = {
     'input_size': [2, 4],
     'lr': [1e-5, 1e-4],
     'weight_decay': [0],
     'wgan': [True, False],
    
}

In [6]:
param_combinations = product(*parameter_space.values())
param_combinations = [dict(zip(parameter_space, values)) for values in param_combinations]

In [7]:
param_combinations[0]

{'input_size': 2, 'lr': 1e-05, 'weight_decay': 0, 'wgan': True}

In [8]:
R = train(param_combinations[0], n_epochs=5)

1/5:   0%|▎                                                                           | 9/2394 [00:07<27:00,  1.47it/s]

KeyboardInterrupt: 

In [12]:
for parameters in param_combinations[7:]:
    print(time.ctime(), parameters)
    
    results = train(
        parameters,
        n_epochs=10,
        verbose=True
    )
    
    

Mon Apr 13 10:02:05 2020 {'input_size': 2, 'lr': 0.0001, 'weight_decay': 1e-05, 'wgan': False}


10/10: 100%|███████████████████████████████████████████████████████████████████████| 2413/2413 [27:22<00:00,  1.47it/s]