In [None]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
import multiprocessing
from torch import autograd
from fastai.conv_learner import *
from fastai.transforms import TfmType
from fasterai.transforms import *
from fasterai.images import *
from fasterai.dataset import *
from fasterai.visualize import *
from fasterai.callbacks import *
from fasterai.loss import *
from fasterai.modules import *
from fasterai.training import *
from fasterai.generators import *
from fastai.torch_imports import *
from pathlib import Path
from itertools import repeat
import tensorboardX
# Make CUDA device index 3 the current device for everything that follows.
# The index is hard-coded; change it on machines with a different GPU layout.
torch.cuda.set_device(3)
# Use matplotlib's dark theme for every figure drawn in this notebook.
plt.style.use('dark_background')
# NOTE(review): presumably turns on cuDNN's per-shape algorithm search - confirm
# against the PyTorch version in use. The same flag is set again in the config cell.
torch.backends.cudnn.benchmark=True


In [None]:
def is_listy(x):
    "Return True when `x` is a `list` or a `tuple` (or a subclass of either)."
    sequence_types = (list, tuple)
    return isinstance(x, sequence_types)

In [None]:
def _hook_inner(m,i,o): 
    return o
    #return o if isinstance(o,Tensor) else o if is_listy(o) else list(o)

In [None]:
class Hook():
    """Create a hook on `m` with `hook_func`.

    `hook_func(module, input, output)` is called every time the hook fires and
    its return value is kept in `self.stored`.

    Parameters:
        m: module to register the hook on.
        hook_func: callable taking `(module, input, output)`.
        is_forward: register a forward hook when True, a backward hook otherwise.
        detach: when True, tensors are detached before `hook_func` sees them.

    The hook can be used as a context manager; it is removed on exit.
    """
    def __init__(self, m:nn.Module, hook_func, is_forward:bool=True, detach:bool=True):
        self.hook_func,self.detach,self.stored = hook_func,detach,None
        # Set before registering so `remove()` is valid from any state.
        self.removed = False
        f = m.register_forward_hook if is_forward else m.register_backward_hook
        self.hook = f(self.hook_fn)

    @staticmethod
    def _detach(x):
        "Detach `x`; lists and tuples are rebuilt eagerly with each item detached."
        if isinstance(x, (list, tuple)):
            # Build a real container rather than a generator: a generator can
            # be consumed only once and has no len() or indexing.
            items = [Hook._detach(o) for o in x]
            return tuple(items) if isinstance(x, tuple) else items
        # Entries such as `None` (which can appear among backward-hook gradients)
        # have nothing to detach and are passed through unchanged.
        return x.detach() if hasattr(x, 'detach') else x

    def hook_fn(self, module:nn.Module, input, output):
        "Applies `hook_func` to `module`, `input`, `output`."
        if self.detach:
            input, output = self._detach(input), self._detach(output)
        # Nothing is returned, so a forward hook never replaces the module output.
        self.stored = self.hook_func(module, input, output)

    def remove(self):
        "Remove the hook from the model (safe to call more than once)."
        if not self.removed:
            self.hook.remove()
            self.removed=True

    def __enter__(self, *args): return self
    def __exit__(self, *args): self.remove()

In [None]:
class Hooks():
    "Create several hooks on the modules in `ms` with `hook_func`."
    def __init__(self, ms, hook_func, is_forward:bool=True, detach:bool=True):
        # One `Hook` per module, kept in the same order as `ms`.
        registered = []
        for module in ms:
            registered.append(Hook(module, hook_func, is_forward, detach))
        self.hooks = registered

    def __getitem__(self,i:int)->Hook:
        return self.hooks[i]

    def __len__(self)->int:
        return len(self.hooks)

    def __iter__(self):
        return iter(self.hooks)

    @property
    def stored(self):
        "Values currently stored by each hook, in registration order."
        return [hook.stored for hook in self]

    def remove(self):
        "Remove the hooks from the model."
        for hook in self.hooks:
            hook.remove()

    def __enter__(self, *args):
        return self

    def __exit__ (self, *args):
        self.remove()

In [None]:
def requires_grad(m:nn.Module, b:bool=None)->[bool]:
    """Get or set `requires_grad` on the parameters of `m`.

    With `b` left as None, return `requires_grad` of the first parameter.
    Otherwise set `requires_grad=b` on every parameter and return None.
    A module without parameters always yields None.
    """
    params = list(m.parameters())
    if len(params) == 0:
        return None
    if b is not None:
        for param in params:
            param.requires_grad = b
        return None
    return params[0].requires_grad

In [None]:
# Elementwise loss shared by the pixel, feature and gram terms of `NewFeatureLoss`.
base_loss = F.l1_loss
# Feature extractor: the `.features` part of VGG16-BN, moved to the GPU in eval mode.
# NOTE(review): the positional `True` is presumably `pretrained=True` - confirm.
vgg_m = vgg16_bn(True).features.cuda().eval()
# Freeze the extractor: sets `requires_grad=False` on all of its parameters.
requires_grad(vgg_m, False)
# Index of the layer immediately before each MaxPool2d in `vgg_m`.
blocks = [i-1 for i,o in enumerate(children(vgg_m)) if isinstance(o,nn.MaxPool2d)]
# Cell output: the indices together with the layers they point at.
blocks, [vgg_m[i] for i in blocks]

In [None]:
def gram_matrix(x):
    """Per-sample Gram matrix of a 4-D tensor.

    `x` is unpacked as `(n, c, h, w)`; the result has shape `(n, c, c)` and is
    normalised by `c*h*w`.
    """
    batch, channels, height, width = x.size()
    # Flatten the two trailing dimensions so each channel becomes one row vector.
    flat = x.view(batch, channels, -1)
    products = flat.matmul(flat.transpose(2, 1))
    return products / (channels * height * width)

In [None]:
def hook_outputs(modules, detach:bool=True, grad:bool=False)->Hooks:
    "Return `Hooks` that store the outputs of all `modules`; backward hooks are used when `grad` is True."
    use_forward = not grad
    return Hooks(modules, _hook_inner, is_forward=use_forward, detach=detach)

In [None]:
class NewFeatureLoss(nn.Module):
    """Pixel + feature + gram-matrix loss computed on activations of `m_feat`.

    Parameters:
        m_feat: indexable feature network; it is called on both input and target.
        layer_ids: indices into `m_feat` of the layers whose outputs are compared.
        layer_wgts: one weight per hooked layer (`w` for the feature term,
            `w**2` for the gram term).
        gram_scale: extra multiplier applied to every gram term (default 5e3).

    After `forward`, `self.feat_losses` holds the individual terms and
    `self.metrics` maps `metric_names` ('pixel', 'feat_i', 'gram_i') to them.
    """
    def __init__(self, m_feat, layer_ids, layer_wgts, gram_scale:float=5e3):
        super().__init__()
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # detach=False: hooked activations stay attached to the autograd graph.
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.gram_scale = gram_scale
        self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))
              ] + [f'gram_{i}' for i in range(len(layer_ids))]

    def make_features(self, x, clone=False):
        "Run `m_feat` on `x` and return the hooked activations (cloned when `clone` is True)."
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def forward(self, input, target):
        "Return the sum of the pixel, weighted feature and weighted gram losses."
        # The hooks are overwritten by the next pass through `m_feat`, so the
        # target activations are cloned before the input pass.
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        self.feat_losses = [base_loss(input,target)]
        self.feat_losses += [base_loss(f_in, f_out)*w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out))*w**2 * self.gram_scale
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        # `__del__` also runs when `__init__` raised before `self.hooks` was
        # assigned (e.g. a bad index in `layer_ids`); there is nothing to remove then.
        hooks = getattr(self, 'hooks', None)
        if hooks is not None: hooks.remove()

In [None]:
class MseLoss(nn.Module):
    """Mean-squared-error loss scaled by a constant factor.

    Parameters:
        multiplier: value the MSE is multiplied by (default 1.0).
    """
    def __init__(self, multiplier:float=1.0):
        super().__init__()
        self.multiplier=multiplier
        self.fn=F.mse_loss

    def forward(self, input, target):
        "Return `F.mse_loss(input, target) * multiplier`."
        return self.fn(input,target)*self.multiplier
        

In [None]:
# Generator loss selection: leave exactly one `feat_loss` assignment active.
# Alternative: VGG feature/gram loss on `blocks[2:5]` with weights 5, 15, 2.
#feat_loss = NewFeatureLoss(vgg_m, blocks[2:5], [5,15,2])
# Active choice: plain MSE multiplied by 10.
feat_loss = MseLoss(multiplier=10)
# Alternative: `FeatureLoss` (not defined in this notebook; presumably from fasterai.loss).
#feat_loss = FeatureLoss(multiplier=1e2)

In [None]:
# --- Paths and run identity ---------------------------------------------------
# Training images (relative to the notebook's working directory).
IMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC/train')
# Run name; reused for the TensorBoard directory and as `save_base_name` of the schedules.
proj_id = 'SuperResResCrit11_mse10'
TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)

#pretrained in SuperResolutionTrainingPreTrain
gpath = Path('superres2x34_gen_pretrain.h5')
#pretrained critic path
dpath = IMAGENET.parent/('ResCriticPre11_crit_128.h5')
#dpath = IMAGENET.parent/('colorize_critic_192.h5')

# --- Learning rates -----------------------------------------------------------
# Critic: the same rate repeated three times.
# NOTE(review): presumably one entry per layer group - confirm against GANTrainSchedule.
c_lr=2e-4
c_lrs = np.array([c_lr,c_lr,c_lr])

# Generator: base rate is one fifth of the critic's, spread as base/100, base/10, base.
g_lr=c_lr/5
g_lrs = np.array([g_lr/100,g_lr/10,g_lr])

# --- Schedule parameters ------------------------------------------------------
# Also passed to the generator (`scale=`) and to the schedules (`reduce_x_scale=`).
scale=2
# The next two lists are passed together with per-stage `szs`/`bss` lists of the same length.
keep_pcts=[0.50,0.50]
gen_freeze_tos=[-1,0]
lrs_unfreeze_factor=0.05
# No extra input-side transforms.
x_tfms = []
extra_aug_tfms = [RandomLighting(0.1, 0.1, tfm_y=TfmType.PIXEL)]
# Same flag as in the import cell; repeated here.
torch.backends.cudnn.benchmark=True

In [None]:
# Generator: built on the GPU, then initialised from the checkpoint at `gpath`.
netG = Unet34_V2(nf_factor=2, scale=scale).cuda()
#netGVis = ModelVisualizationHook(TENSORBOARD_PATH, netG, 'netG')
load_model(netG, gpath)

# Critic: built on the GPU, then initialised from the checkpoint at `dpath`.
# NOTE(review): `nf=124` is an unusual width - confirm it matches the saved critic.
netD = DCCriticDense(ni=3, nf=124).cuda()
#netDVis = ModelVisualizationHook(TENSORBOARD_PATH, netD, 'netD')
# NOTE(review): `strict=False` presumably tolerates missing/unexpected keys in the
# checkpoint, so a partial load would go unnoticed - verify.
load_model(netD, dpath, strict=False)

In [None]:
# GAN trainer over the critic and generator; `feat_loss` is the only generator loss function passed.
trainer = GANTrainer(netD=netD, netG=netG, genloss_fns=[feat_loss] , max_critic_loops=1000)
# TensorBoard hook for the trainer, writing under TENSORBOARD_PATH.
# NOTE(review): `visual_iters=100` is presumably the iteration interval between visual logs - confirm.
trainerVis = GANVisualizationHook(TENSORBOARD_PATH, trainer, 'trainer', jupyter=False, visual_iters=100)

## Training

In [None]:
scheds=[]

# Keyword arguments that are identical for every stage. Keeping them in one
# place means a stage only spells out what actually differs (sizes, batch
# sizes, keep percentages, learning rates and freeze points).
stage_kwargs = dict(path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, save_base_name=proj_id,
    lrs_unfreeze_factor=lrs_unfreeze_factor, reduce_x_scale=scale)

# 128px, batch size 16: two schedules using the config-cell `keep_pcts` / `gen_freeze_tos`.
scheds.extend(GANTrainSchedule.generate_schedules(szs=[128,128], bss=[16,16], keep_pcts=keep_pcts,
    c_lrs=c_lrs, g_lrs=g_lrs, gen_freeze_tos=gen_freeze_tos, **stage_kwargs))

# --- Disabled 192px stages (kept for reference) ---
#c_lrs=c_lrs/4
#g_lrs=g_lrs/4

#unshock
#scheds.extend(GANTrainSchedule.generate_schedules(szs=[192], bss=[16], keep_pcts=[0.1],
    #c_lrs=c_lrs/10, g_lrs=g_lrs/10, gen_freeze_tos=[-1], **stage_kwargs))

#scheds.extend(GANTrainSchedule.generate_schedules(szs=[192], bss=[16], keep_pcts=[0.5],
    #c_lrs=c_lrs, g_lrs=g_lrs, gen_freeze_tos=[-1], **stage_kwargs))

#scheds.extend(GANTrainSchedule.generate_schedules(szs=[192], bss=[16], keep_pcts=[0.5],
    #c_lrs=c_lrs, g_lrs=g_lrs, gen_freeze_tos=[0], **stage_kwargs))


# --- Disabled 256px, batch size 8 stages (kept for reference) ---
#c_lrs=c_lrs/4
#g_lrs=g_lrs/4

#unshock
#scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[8], keep_pcts=[0.1],
    #c_lrs=c_lrs/10, g_lrs=g_lrs/10, gen_freeze_tos=[-1], **stage_kwargs))

#scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[8], keep_pcts=[0.5],
    #c_lrs=c_lrs, g_lrs=g_lrs, gen_freeze_tos=[-1], **stage_kwargs))

# 256px, batch size 4: first with gen_freeze_tos=[-1], then with gen_freeze_tos=[0].
scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[4], keep_pcts=[0.5],
    c_lrs=c_lrs, g_lrs=g_lrs, gen_freeze_tos=[-1], **stage_kwargs))

scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[4], keep_pcts=[0.5],
    c_lrs=c_lrs, g_lrs=g_lrs, gen_freeze_tos=[0], **stage_kwargs))



In [None]:
# Run every schedule collected in `scheds`, in order.
trainer.train(scheds=scheds)