In [1]:
from code.sepconvfull import model
import dataloader
import torch
from torch.utils.tensorboard import SummaryWriter
from collections import OrderedDict
import matplotlib.pyplot as plt
import discriminator
from tqdm import tqdm
from collections import defaultdict
import metrics
import numpy as np

In [2]:
params = {
    'lr': 1e-4,
    'weight_decay': 0,
    'amsgrad':False,
    'loss': 'normal', # or wasserstein,
    'input_size': 2 # or 4
    
}

In [25]:
N_EPOCHS=10
writer = SummaryWriter('runs/adverserial_cond_4')

In [26]:
def convert_weights(weights):
    w = OrderedDict()
    for key in weights:
        new_key = 'get_kernel.'+key
        w[new_key] = weights[key]
        
    return w

In [27]:
# init interpolation model
sepconv = model.SepConvNet(kernel_size=51)

weights = torch.load('code/sepconv/network-l1.pytorch')
weights = convert_weights(weights)

sepconv.load_state_dict(weights)
opt = torch.optim.Adam(sepconv.parameters())

# init discriminator
disc_model = discriminator.Discriminator()

In [28]:
sepconv = sepconv.cuda()
D = disc_model.cuda()

In [29]:
def write_tensorboard_train(writer, metrics):
    writer.add_scalar('Loss/train', np.mean(train_loss[epoch]), epoch)
    writer.add_scalar('PSNR/train', np.mean(train_psnr[epoch]), epoch)
    writer.add_scalar('IE/train', np.mean(train_ie[epoch]), epoch)
    writer.add_scalar('Accuracy/train', np.mean(train_correct[epoch]), epoch)
    
    writer.add_histogram('Loss/train_hist', np.array(train_loss[epoch]), epoch)
    writer.add_histogram('PSNR/train_hist', np.array(train_psnr[epoch]), epoch)
    
def write_tensorboard_valid(writer, metrics):
    writer.add_scalar('Loss/valid', np.mean(valid_loss[epoch]), epoch)
    writer.add_scalar('PSNR/valid', np.mean(valid_psnr[epoch]), epoch)
    writer.add_scalar('IE/valid', np.mean(valid_ie[epoch]), epoch)
    writer.add_scalar('Accuracy/valid', np.mean(valid_correct[epoch]), epoch)
    
    writer.add_histogram('Loss/valid_hist', np.array(valid_loss[epoch]), epoch)
    writer.add_histogram('PSNR/valid_hist', np.array(valid_psnr[epoch]), epoch)

In [30]:
class ResultStore:
    
    def __init__(self, folds=['train', 'valid'], metrics=['psnr', 'ie', 'loss', 'accuracy'], writer=None):
        self.folds = folds
        self.metrics = metrics
        self.results = dict()
        self.writer = writer
        
        for fold in self.folds:
            self.results[fold] = dict()
            for metric in self.metrics:
                self.results[fold][metric] = defaultdict(list)
        
    def store(self, fold, metric, epoch, value):
        self.results[fold][metric][epoch].extend(value)
        
    def write_tensorboard(self, fold, epoch):
        for metric in self.metrics:
#             print(self.results[fold][metric][epoch])
            mean = np.mean(self.results[fold][metric][epoch])
            
            self.writer.add_scalar(f'{metric}/{fold}', mean, epoch)
            self.writer.add_histogram(f'{metric}/{fold}_hist', np.array(self.results[fold][metric][epoch]), epoch)
        
        



In [40]:
metrics.ssim(y_hat, y)

tensor([0.9941, 0.9966, 0.9948, 0.9981], device='cuda:0')

In [31]:
ds = dataloader.adobe240_dataset()
ds = dataloader.TransformedDataset(ds, crop_size=(128,128))

N_train = int(len(ds) * 0.8)
N_valid = len(ds)-N_train

train, valid = torch.utils.data.random_split(ds, [N_train, N_valid])

train_dl = torch.utils.data.DataLoader(train, batch_size=2, shuffle=True, pin_memory=True)
valid_dl = torch.utils.data.DataLoader(valid, batch_size=4, pin_memory=True)

optimizer_G = torch.optim.Adam(sepconv.parameters(), lr=params['lr'], weight_decay=params['weight_decay'], amsgrad=params['amsgrad'])
optimizer_D = torch.optim.Adam(D.parameters(), lr=params['lr'], weight_decay=params['weight_decay'], amsgrad=params['amsgrad'])
critereon = torch.nn.L1Loss()

# metrics
# train_loss = defaultdict(list)
# valid_loss = defaultdict(list)
# train_psnr = defaultdict(list)
# valid_psnr = defaultdict(list)
# train_ie = defaultdict(list)
# valid_ie = defaultdict(list)
# train_correct = defaultdict(list)
# valid_correct = defaultdict(list)
R = ResultStore(writer=writer)

for epoch in range(N_EPOCHS):
    sepconv.train()
    D.train()
    for i, ((x1, x2), y) in enumerate(tqdm(train_dl, total=len(train_dl), desc=f'{epoch+1}/{N_EPOCHS}')):
        x1 = x1.cuda() / 255.
        x2 = x2.cuda() / 255.
        y = y.cuda() / 255.
        
        
        y_hat = sepconv(x1, x2)
        
        l1_loss = critereon(y_hat, y)
        

        
        loss = l1_loss - D(x1, x2, y_hat).sigmoid().mean()
        
        R.store('train', 'loss', epoch, [loss.item()])
        
        # compute psnr
        y_hat = (y_hat * 255).clamp(0,255)
        y = (y * 255).clamp(0,255)
        
        psnr = metrics.psnr(y_hat, y)
        psnr = psnr.detach().cpu().tolist()
#         train_psnr[epoch].extend(psnr)
        R.store('train', 'psnr', epoch, psnr)
        
        ie = metrics.interpolation_error(y_hat, y)
        ie = ie.detach().cpu().tolist()
#         train_ie[epoch].extend(ie)
        R.store('train', 'ie', epoch, ie)
        
        
        optimizer_G.zero_grad()
        l1_loss.backward()
        optimizer_G.step()
        
        
        
        # train discriminator
        y_hat = y_hat.detach()
        
        for p in D.parameters():
            p.data.clamp_(-0.01, 0.01)
        
        D_loss = D(x1, x2, y_hat).sigmoid().mean() - D(x1, x2, y).sigmoid().mean()
        
        correct_preds = (D(x1, x2, y_hat).sigmoid().round() == 0).flatten().int().detach().cpu().tolist()
        correct_preds.extend((D(x1, x2, y).sigmoid().round() == 1).flatten().int().detach().cpu().tolist())
#         train_correct[epoch].extend(correct_preds)
        R.store('train', 'accuracy', epoch, correct_preds)
        
        optimizer_D.zero_grad()
        D_loss.backward()
        optimizer_D.step()
        
        if i == 5:
            break
        
    # update tensorboard
    R.write_tensorboard('train', epoch)
    

        
    
    sepconv.eval()
    D.eval()
    with torch.no_grad():
        for i, ((x1, x2), y) in enumerate(valid_dl):
            x1 = x1.cuda() / 255.
            x2 = x2.cuda() / 255.
            y = y.cuda() / 255.

            y_hat = sepconv(x1, x2)        
            l1_loss = critereon(y_hat, y)        
            loss = l1_loss - D(x1, x2, y_hat).sigmoid().mean()

            R.store('valid', 'loss', epoch, [loss.item()])

            
            # compute psnr
            y_hat = (y_hat * 255).clamp(0,255)
            y = (y * 255).clamp(0,255)

            psnr = metrics.psnr(y_hat, y)
            psnr = psnr.detach().cpu().tolist()
            R.store('valid', 'psnr', epoch, psnr)
            
            ie = metrics.interpolation_error(y_hat, y)
            ie = ie.detach().cpu().tolist()
            R.store('valid', 'ie', epoch, ie)

            
            y_hat = y_hat.detach()
        
            D_loss = D(x1, x2, y_hat).sigmoid().mean() - D(x1, x2, y).sigmoid().mean()

            correct_preds = (D(x1, x2, y_hat).sigmoid().round() == 0).flatten().int().detach().cpu().tolist()
            correct_preds.extend((D(x1, x2, y).sigmoid().round() == 1).flatten().int().detach().cpu().tolist())
            R.store('valid', 'accuracy', epoch, correct_preds)
            
            if i == 5:
                break
            
    # update tensorboard
    R.write_tensorboard('valid', epoch)
        
# save models



1/10:   0%|                                                                           | 5/4826 [00:02<33:10,  2.42it/s]
2/10:   0%|                                                                           | 5/4826 [00:02<32:50,  2.45it/s]
3/10:   0%|                                                                           | 5/4826 [00:02<32:12,  2.49it/s]
4/10:   0%|                                                                           | 5/4826 [00:02<32:40,  2.46it/s]
5/10:   0%|                                                                           | 5/4826 [00:02<32:34,  2.47it/s]
6/10:   0%|                                                                           | 5/4826 [00:02<33:21,  2.41it/s]
7/10:   0%|                                                                           | 5/4826 [00:02<33:46,  2.38it/s]
8/10:   0%|                                                                           | 5/4826 [00:02<32:47,  2.45it/s]
9/10:   0%|                             

In [None]:
# writer.add_image('output', y_hat[0], 0)

In [None]:
# disc_model(y_hat).sigmoid().mean()

In [13]:
D(y_hat).sigmoid()

tensor([[0.],
        [0.]], device='cuda:0', grad_fn=<SigmoidBackward>)

### Eval

In [None]:
def evaluate_model(model, dl):
    
    results = defaultdict(list)
    model.eval()
    
    with torch.no_grad():    
        for i, ((x1, x2), y) in enumerate(dl):
            x1 = x1.permute(0,3,1,2).cuda() / 255.
            x2 = x2.permute(0,3,1,2).cuda() / 255.
            y = y.permute(0,3,1,2).cuda()

            y_hat = model(x1, x2)
            
            y_hat = (y_hat * 255).clamp(0,255)

            psnr = metrics.psnr(y_hat, y)
            ie = metrics.interpolation_error(y_hat, y)
            
            psnr = psnr.detach().cpu().tolist()
            ie = ie.detach().cpu().tolist()

            results['psnr'].extend(psnr)
            results['ie'].extend(ie)
        
    return results
        
        

In [None]:
%%time
# t = dataloader2.Transformer(random_crop=False)
ds = dataloader.adobe240_dataset(transformer=None)

N_train = int(len(ds) * 0.8)
N_test = len(ds)-N_train

# _, test = torch.utils.data.random_split(ds, [N_train, N_test])

test_dl = torch.utils.data.DataLoader(ds, batch_size=2, shuffle=False)
m = evaluate_model(sepconv, test_dl)

np.mean(m['psnr']), np.mean(m['ie'])

### test subset

In [None]:
import dataloader2

In [None]:
dataset = dataloader2.adobe240_dataset()
# dataset = dataloader2.TransformedDataset(dataset, crop_size=(512, 512), h_flip_prob=1)

In [None]:
N_train = int(len(dataset) * 0.8)
N_valid = int(len(dataset) * 0.1)
N_test = len(dataset)-N_train-N_valid

train, valid, test = torch.utils.data.random_split(dataset, [N_train, N_valid, N_test])

In [None]:
train = dataloader2.TransformedDataset(train, crop_size=(512, 512), h_flip_prob=1)
valid = dataloader2.TransformedDataset(valid, crop_size=(512, 512), h_flip_prob=1)

In [None]:
train_dl = torch.utils.data.DataLoader(train)
valid_dl = torch.utils.data.DataLoader(valid)
test_dl = torch.utils.data.DataLoader(test)