In [1]:
from code.sepconvfull import model
import dataloader
import torch
from torch.utils.tensorboard import SummaryWriter
from collections import OrderedDict
import matplotlib.pyplot as plt
import discriminator
from tqdm import tqdm
from collections import defaultdict
import metrics
import numpy as np
import hyperopt
from hyperopt import hp, space_eval, Trials
import time

In [2]:
torch.manual_seed(42)
np.random.seed(42)

In [3]:
params = {
    'lr': hp.choice('lr', [1e-5, 1e-4, 5e-4]),
    'weight_decay': hp.choice('weight_decay', [0, 1e-4, 5e-4]),
    'amsgrad':False,
    'loss': 'normal', # or wasserstein,
    'input_size': 2 # or 4   
}

In [4]:
def convert_weights(weights):
    w = OrderedDict()
    for key in weights:
        new_key = 'get_kernel.'+key
        w[new_key] = weights[key]
        
    return w

In [5]:
# # init interpolation model
# sepconv = model.SepConvNet(kernel_size=51)

# weights = torch.load('code/sepconv/network-l1.pytorch')
# weights = convert_weights(weights)

# sepconv.load_state_dict(weights)
# opt = torch.optim.Adam(sepconv.parameters())

# # init discriminator
# disc_model = discriminator.Discriminator()

# sepconv = sepconv.cuda()
# D = disc_model.cuda()

In [6]:
class ResultStore:
    
    def __init__(self, folds=['train', 'valid'], metrics=['psnr', 'ie', 'loss', 'accuracy'], writer=None):
        self.folds = folds
        self.metrics = metrics
        self.results = dict()
        self.writer = writer
        
        for fold in self.folds:
            self.results[fold] = dict()
            for metric in self.metrics:
                self.results[fold][metric] = defaultdict(list)
        
    def store(self, fold, epoch, value_dict):
        for metric, value in value_dict.items():
            if isinstance(value, list):
                self.results[fold][metric][epoch].extend(value)
            else:
                self.results[fold][metric][epoch].append(value)
        
    def write_tensorboard(self, fold, epoch):
        for metric in self.metrics:
            mean = np.mean(self.results[fold][metric][epoch])
            
            self.writer.add_scalar(f'{metric}/{fold}', mean, epoch)
            self.writer.add_histogram(f'{metric}/{fold}_hist', np.array(self.results[fold][metric][epoch]), epoch)

In [7]:
def objective(params):
    '''Minimize loss on validation set'''
    
    _, _, resultstore = train(params, n_epochs=5, verbose=False)
    
    losses = resultstore.results['valid']['loss']
    max_epoch = np.max(list(losses.keys()))
    validation_loss = np.mean(losses[max_epoch])
    
    return validation_loss


def train(params, n_epochs, verbose=True):
    
    # init interpolation model
    sepconv = model.SepConvNet(kernel_size=51)

    weights = torch.load('code/sepconv/network-l1.pytorch')
    weights = convert_weights(weights)

    sepconv.load_state_dict(weights)
    opt = torch.optim.Adam(sepconv.parameters())

    # init discriminator
    disc_model = discriminator.Discriminator()

    sepconv = sepconv.cuda()
    D = disc_model.cuda()
    
    
    ds = dataloader.adobe240_dataset()
    ds = dataloader.TransformedDataset(ds, crop_size=(128,128))

    N_train = int(len(ds) * 0.8)
    N_valid = len(ds)-N_train

    train, valid = torch.utils.data.random_split(ds, [N_train, N_valid])

    train_dl = torch.utils.data.DataLoader(train, batch_size=2, shuffle=True, pin_memory=True)
    valid_dl = torch.utils.data.DataLoader(valid, batch_size=4, pin_memory=True)

    optimizer_G = torch.optim.Adam(sepconv.parameters(), lr=params['lr'], weight_decay=params['weight_decay'], amsgrad=params['amsgrad'])
    optimizer_D = torch.optim.Adam(D.parameters(), lr=params['lr'], weight_decay=params['weight_decay'], amsgrad=params['amsgrad'])
    critereon = torch.nn.L1Loss()

    # metrics
    name = f'{int(time.time())}_{params["lr"]}_{params["weight_decay"]}'
    writer = SummaryWriter(f'runs/{name}') #TODO hp
    R = ResultStore(writer=writer)
    
    for epoch in range(n_epochs):
        sepconv.train()
        D.train()
        for i, ((x1, x2), y) in enumerate(train_dl):
            
            if verbose:
                pb = tqdm(desc=f'{epoch+1}/{n_epochs}', total=len(train_dl))    
                
            x1 = x1.cuda() / 255.
            x2 = x2.cuda() / 255.
            y = y.cuda() / 255.

            y_hat = sepconv(x1, x2)
            l1_loss = critereon(y_hat, y)

            loss = l1_loss - D(x1, x2, y_hat).sigmoid().mean()

            # compute psnr
            y_hat = (y_hat * 255).clamp(0,255)
            y = (y * 255).clamp(0,255)

            psnr = metrics.psnr(y_hat, y)
            ie = metrics.interpolation_error(y_hat, y)
            
            optimizer_G.zero_grad()
            l1_loss.backward()
            optimizer_G.step()

            # train discriminator
            y_hat = y_hat.detach()
            
            if params['wgan']:
                for p in D.parameters():
                    p.data.clamp_(-0.01, 0.01)

            D_loss = D(x1, x2, y_hat).sigmoid().mean() - D(x1, x2, y).sigmoid().mean()

            correct_preds = (D(x1, x2, y_hat).sigmoid().round() == 0).flatten().int().detach().cpu().tolist()
            correct_preds.extend((D(x1, x2, y).sigmoid().round() == 1).flatten().int().detach().cpu().tolist())
          
            R.store('train', epoch, {'loss':loss.item(), 'psnr':psnr, 'ie':ie, 'accuracy':correct_preds})
    
            optimizer_D.zero_grad()
            D_loss.backward()
            optimizer_D.step()
            
            if verbose: pb.update()

            if i == 50:
                break

        # update tensorboard
        R.write_tensorboard('train', epoch)


        sepconv.eval()
        D.eval()
        with torch.no_grad():
            for i, ((x1, x2), y) in enumerate(valid_dl):
                x1 = x1.cuda() / 255.
                x2 = x2.cuda() / 255.
                y = y.cuda() / 255.

                y_hat = sepconv(x1, x2)        
                l1_loss = critereon(y_hat, y)        
                loss = l1_loss - D(x1, x2, y_hat).sigmoid().mean()

                # compute psnr
                y_hat = (y_hat * 255).clamp(0,255)
                y = (y * 255).clamp(0,255)

                psnr = metrics.psnr(y_hat, y)
                ie = metrics.interpolation_error(y_hat, y)
                
                y_hat = y_hat.detach()
                D_loss = D(x1, x2, y_hat).sigmoid().mean() - D(x1, x2, y).sigmoid().mean()

                correct_preds = (D(x1, x2, y_hat).sigmoid().round() == 0).flatten().int().detach().cpu().tolist()
                correct_preds.extend((D(x1, x2, y).sigmoid().round() == 1).flatten().int().detach().cpu().tolist())

                R.store('valid', epoch, {'loss':loss.item(), 'psnr':psnr, 'ie':ie, 'accuracy':correct_preds})

                if i == 50:
                    break

        # update tensorboard
        R.write_tensorboard('valid', epoch)
        
    # save models
    return sepconv, D, R




### HP

In [8]:
trials = Trials()

best_hp = hyperopt.fmin(
    fn=objective, 
    space=params,
    algo=hyperopt.tpe.suggest,
    trials = trials,
    max_evals=10
)

100%|█████████████████████████████████████████████████| 10/10 [40:55<00:00, 243.67s/it, best loss: -0.9863524857689353]


In [10]:
space_eval(params, best_hp)

{'amsgrad': False,
 'input_size': 2,
 'loss': 'normal',
 'lr': 0.0001,
 'weight_decay': 0.0005}