In [1]:
import dataloader
import numpy as np
import network
import time
import torch
import torch.optim as optim

In [2]:
dataloader.args['train_scales'] = 2
dataloader.args['batch_size'] = 1
dataloader.args['train'] = 'data/*.jpg'

In [3]:
loader = dataloader.DataLoader()

In [4]:
# network
enhancer = network.Enhancer()

In [5]:
# pretrain
pretrain_params = {
    'smoothness-weight' : 1e7,
    'adversary-weight' : 0.0,
    'generator-start' : 0,
    'discriminator-start' : 0, #1
    'adversarial-start' : 1, #2
    'perceptual-weight' : 1e0,
    'epochs' : 2, #50
    'epoch-size' : 2, #72
    'batch-size' : 1, #15
    'image-size' : 192,
    'zoom' : 2,
    'learning-rate': 1e-4,
    'discriminator-size' : 64
}

In [18]:
t = torch.autograd.Variable(torch.Tensor(2, 1, 24, 24).random_())
t.sum(1).sum(1).sum(1).data.numpy()

array([  4.79335885e+09,   4.97336627e+09], dtype=float32)

In [15]:
def train(enhancer, mode, param):
    seed_size = param['image-size'] // param['zoom']
    images = np.zeros((param['batch-size'], 3, param['image-size'], param['image-size']), dtype=np.float32)
    seeds = np.zeros((param['batch-size'], 3, seed_size, seed_size), dtype=np.float32)
    
    loader.copy(images, seeds)
    # initial lr
    lr = network.decay_learning_rate(param['learning-rate'], 75, 0.5)
    
    #optimizer for generator
    opt_gen = optim.Adam(enhancer.generator.parameters(), lr = 0)
    
    try:
        average, start = None, time.time()
        for epoch in range(param['epochs']):
            adversary_weight = 5e2
            
            total, stats = None, None
            
            l_r = next(lr)
            network.update_optimizer_lr(opt_gen, l_r)
#             network.update_optimizer_lr(opt_disc, l_r)
            
            for step in range(param['epoch-size']):
                enhancer.zero_grad()
                loader.copy(images, seeds)
                
                # run full network once
                gen_out, c12, c22, c32, c52, disc_out = enhancer(images, seeds)
                
                # clone discriminator on the full network
                disc = enhancer.discriminator_clone()
                disc.zero_grad()
                
                # optimizer for discriminator
                opt_disc = optim.Adam(disc.parameters(), lr=l_r)
                opt_disc.state['step'] = epoch * param['epochs'] + step
                
                # output of new cloned network (maybe you can assert it to equal disc_out)
                disc_out2 = disc(c12.detach(), c22.detach(), c32.detach())
                disc_out_mean = disc_out2.sum(1).sum(1).sum(1).data.numpy()
                stats = stats + disc_out_mean if stats is not None else disc_out_mean 
                
                # compute generator loss
                if mode == 'pretrain':
                    gen_loss = network.loss_perceptual(c22[:param['batch-size']], c22[param['batch-size']:]) * param['perceptual-weight'] \
                        + network.loss_total_variation(gen_out) * param['smoothness-weight'] \
                        + network.loss_adversarial(disc_out[1:]) * adversary_weight
                else:
                    gen_loss = network.loss_perceptual(c52[:param['batch-size']], c52[param['batch-size']:]) * param['perceptual-weight'] \
                        + network.loss_total_variation(gen_out) * param['smoothness-weight'] \
                        + network.loss_adversarial(disc_out[1:]) * adversary_weight
                
                # compute discriminator loss
                disc_loss = network.loss_discriminator(disc_out2[:param['batch-size']], disc_out2[param['batch-size']:])
                
                total = total + gen_loss.data.numpy() if total is not None else gen_loss.data.numpy()
                
                average = gen_loss.data.numpy() if average is None else average * 0.95 + 0.05 * gen_loss.data.numpy()
                print('↑' if gen_loss.data.numpy() > average else '↓', end='', flush=True)
            
                # update parameters step
                
                gen_loss.backward()
                disc_loss.backward()
                
                opt_gen.step()
                opt_disc.step()
                
                # rebuild real discriminator from clone
                enhancer.assign_back_discriminator(disc)
            
            total /= param['epoch-size']
            stats /= param['epoch-size']
            
            print('\nGenerator Loss: ')
            print(total)
            
            real, fake = stats[:param['batch-size']], stats[param['batch-size']:]
            print('  - discriminator', real.mean(), len(np.where(real > 0.5)[0]),
                                       fake.mean(), len(np.where(fake < -0.5)[0]))
            
            if epoch == param['adversarial-start'] - 1:
                print('  - generator now optimizing against discriminator.')
                adversary_weight = param['adversary-weight']
                
            # Then save every several epochs
#             if epoch % 10 == 0:
#                 enhancer.save('model/model.pth')
                
    except KeyboardInterrupt:
        pass

In [16]:
train(enhancer, 'pretrain', pretrain_params)

↓↓
Generator Loss: 
[ 72319.03125]
  - discriminator -388.137 0 498.292 0
  - generator now optimizing against discriminator.
↓↓
Generator Loss: 
[ 48386.765625]
  - discriminator -527.849 0 639.15 0


In [44]:
# pretrain
train_params = {
    'smoothness-weight' : 2e4,
    'adversary-weight' : 1e3,
    'generator-start' : 5,
    'discriminator-start' : 0,
    'adversarial-start' : 5, 
    'perceptual-weight' : 1e0,
    'epochs' : 250,
    'epoch-size' : 72,
    'batch-size' : 15,
    'image-size' : 192,
    'zoom' : 2,
    'learning-rate': 1e-4,
    'discriminator-size' : 64
}

In [None]:
# for training do
# enhancer.create_new_discriminator(64)
# after that call train

In [45]:
# save pickle
import pickle

In [48]:
torch.save(enhancer.state_dict(), 'test_model.pth')

In [2]:
new_enhancer = network.Enhancer()
new_enhancer.load_state_dict(torch.load('test_model.pth'))

In [3]:
new_enhancer.create_new_discriminator(64)

In [6]:
opt_ = optim.Adam(enhancer.parameters(), lr=0.1)
for param_group in opt_.param_groups:
    print(param_group.keys())

dict_keys(['params', 'lr', 'betas', 'eps', 'weight_decay'])


In [13]:
opt_.state['step'] = 1
opt_.state

defaultdict(dict, {'step': 1})

defaultdict(dict, {})