# Training GAN with pretrained models

A DCGAN model is trained for image enhancement (super-resolution, "decrappify"). <br>The dataset used is the Flickr Image Dataset, available on Kaggle.
To train the model, synthetic low-quality data is generated as in the kernel https://www.kaggle.com/greenahn/crappify-imgs<br>and saved to disk; these images, paired with the high-resolution originals, are used to train the model.<br><br>
The model is trained with a loss function that incorporates features extracted by a VGG16 model, as in the paper on neural style transfer.<br>
For more details, find the github repository at: https://github.com/nupam/GANs-for-Image-enhancement

Both the pretrained generator and the pretrained discriminator are loaded from disk; they are output files of the kernel<br> https://www.kaggle.com/greenahn/pretrain-gan-feature-loss.<br>
They are then put together as a GAN and trained.

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
print(os.listdir("../input"))

from tqdm import tqdm_notebook as tqdm

import fastai
from fastai.vision import *
from fastai.callbacks import *
from fastai.utils.mem import *
from fastai.vision.gan import *
import gc
from torchvision.models import vgg16_bn

In [None]:
## These folders contain crappy images in different resolutions, produced with different (randomly selected) crappification logic.
## FOLDERS is keyed by the number in the folder name (256 or 320).
# Directory of the high-resolution target images.
orig_path = Path('../input/flickrproc/hr/hr')
# Table of image files; its 'name' column is what the data loaders below read.
fnames_df = pd.read_csv('../input/flickrproc/files.csv')
# Base batch size used when building the DataBunches below.
bs = 16
FOLDERS = {256:Path('../input/flickrproc/crappy_256/crappy/'), 320:Path('../input/flickrproc/crappy_320/crappy/'), }
FOLDERS

## Loading Generator

In [None]:
## Loading training data
def get_data(size=None, bs=None, folder=320, split=0.9, dummy=False):
    """Build the image-to-image DataBunch (crappy input -> high-res target).

    Inputs are read from ``FOLDERS[folder]`` using the ``name`` column of
    ``fnames_df``; each target is the file with the same name under
    ``orig_path``. Inputs and targets are both normalized with
    ``imagenet_stats``.

    Args:
        size: Image size passed to ``transform``. Required unless ``dummy``;
            in dummy mode ``None`` means no transform step is applied.
        bs: Batch size. Required unless ``dummy`` (where it defaults to 1).
        folder: Key into ``FOLDERS`` selecting the crappy-image directory.
            Honoured in dummy mode as well.
        split: Fraction of rows (from the start of ``fnames_df``) used for
            training; the remaining tail becomes the validation set.
            Ignored in dummy mode.
        dummy: If True, return a tiny dataset built from the first 32 rows
            only, with a random 20% validation split (seed 34) and no
            augmentation.

    Returns:
        The DataBunch, with ``c`` set to 3.

    Raises:
        ValueError: If ``bs`` or ``size`` is missing and ``dummy`` is False.
    """
    def to_target(x):
        # The target shares the input's file name, under the high-res directory.
        return orig_path/Path(x).name

    if dummy:
        if bs is None: bs = 1
        labelled = (ImageImageList.from_df(fnames_df.iloc[:32], path=FOLDERS[folder], cols='name')
                    .split_by_rand_pct(0.2, seed=34)
                    .label_from_func(to_target))
        # Only resize when a size was requested; no augmentation in dummy mode.
        if size is not None:
            labelled = labelled.transform([], size=size, tfm_y=True)
    else:
        # Validate before touching any data so misuse fails fast.
        if bs is None:
            raise ValueError('Batchsize is not provided')
        if size is None:
            raise ValueError('Size of image is not provided')

        src = ImageImageList.from_df(fnames_df, path=FOLDERS[folder], cols='name')
        n_items = src.items.shape[0]
        # The last (1 - split) fraction of the rows is held out for validation.
        valid_idx = np.arange(int(n_items*split), n_items)
        labelled = (src.split_by_idx(valid_idx)
                    .label_from_func(to_target)
                    .transform(get_transforms(max_zoom=1.2), size=size, tfm_y=True))

    data = labelled.databunch(bs=bs).normalize(imagenet_stats, do_y=True)
    data.c = 3
    return data

In [None]:
# Distance used for every term of FeatureLoss (pixel, feature and Gram).
base_loss = F.l1_loss
# Architecture handed to unet_learner for the generator.
arch = models.resnet34

# Convolutional part of VGG16-BN, moved to the GPU and put in eval mode.
# NOTE(review): the positional True is presumably the `pretrained` flag — confirm against torchvision.
vgg_m = vgg16_bn(True).features.cuda().eval()
# Turn off gradients for the VGG model.
requires_grad(vgg_m, False)
# Indices of the layers that sit immediately before each MaxPool2d in vgg_m.
blocks = [i-1 for i,o in enumerate(children(vgg_m)) if isinstance(o,nn.MaxPool2d)]

def gram_matrix(x):
    """Return the per-sample Gram matrix of a 4-D tensor.

    The two trailing dimensions of ``x`` (size ``(n, c, h, w)``) are
    flattened, each sample's ``(c, h*w)`` matrix is multiplied by its own
    transpose, and the result is divided by ``c*h*w``.
    The output has size ``(n, c, c)``.
    """
    batch, channels, height, width = x.size()
    flat = x.view(batch, channels, height * width)
    gram = flat @ flat.transpose(1, 2)
    return gram / (channels * height * width)

### Feature loss
As in the paper on neural style transfer, https://arxiv.org/abs/1508.06576.
An L1 pixel distance is also added to the loss.<br>
This helps prevent mode collapse and provides supervision for stable and faster training.

In [None]:
class FeatureLoss(nn.Module):
    """Pixel loss plus weighted feature and Gram-matrix losses.

    Activations of selected layers of ``m_feat`` are captured with forward
    hooks. The loss is the sum of:

    * ``base_loss(input, target)`` on the raw tensors,
    * ``base_loss`` between the hooked activations, times the layer weight,
    * ``base_loss`` between the Gram matrices of those activations, times
      ``weight**2 * 5e3``.

    After each forward pass the individual terms are available in
    ``self.metrics`` under the names listed in ``self.metric_names``.
    """
    def __init__(self, m_feat, layer_ids, layer_wgts):
        """Hook ``m_feat[i]`` for every ``i`` in ``layer_ids``.

        Args:
            m_feat: Indexable feature-extractor module.
            layer_ids: Indices into ``m_feat`` of the layers to hook.
            layer_wgts: One weight per hooked layer.

        Raises:
            ValueError: If ``layer_ids`` and ``layer_wgts`` differ in length.
        """
        super().__init__()
        # A mismatch would silently truncate the zip() in forward() and pair
        # metric names with the wrong loss values, so reject it up front.
        if len(layer_ids) != len(layer_wgts):
            raise ValueError(
                f'layer_ids and layer_wgts must have the same length, '
                f'got {len(layer_ids)} and {len(layer_wgts)}')
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # detach=False keeps the stored activations attached to the graph.
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))
              ] + [f'gram_{i}' for i in range(len(layer_ids))]

    def make_features(self, x, clone=False):
        """Run ``x`` through ``m_feat`` and return the hooked activations.

        With ``clone=True`` the activations are copied, so a later call does
        not overwrite them through the shared hook storage.
        """
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def forward(self, input, target):
        """Return the summed loss and record each term in ``self.metrics``."""
        # Target activations are cloned because the next call reuses the hooks.
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        self.feat_losses = [base_loss(input,target)]
        self.feat_losses += [base_loss(f_in, f_out)*w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out))*w**2 * 5e3
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        # `hooks` is missing if __init__ raised before registering them.
        hooks = getattr(self, 'hooks', None)
        if hooks is not None: hooks.remove()

In [None]:
# Generator loss: hooks the VGG layers at blocks[2:5], weighted 5, 15 and 2 respectively.
feat_loss = FeatureLoss(vgg_m, blocks[2:5], [5,15,2])

In [None]:
# Generator data: inputs from FOLDERS[320], image size 128, batch size bs.
gen_data = get_data(bs=bs, size=128, folder=320)

In [None]:
# Weight decay; also reused by the critic learner and the GAN learner further down.
wd = 1e-3
# Output range handed to unet_learner.
# NOTE(review): presumably chosen to cover imagenet-normalized pixel values — confirm.
y_range = (-3.,3.)

# U-Net generator on the `arch` backbone, trained with the feature loss; checkpoints go to /kaggle/working.
learn_gen = unet_learner(gen_data, arch, wd=wd, loss_func=feat_loss, callback_fns=LossMetrics, blur=True, norm_type=NormType.Weight, model_dir="/kaggle/working", y_range=y_range)
gc.collect();

In [None]:
# Load the pretrained generator checkpoint 'gen-256' from the input dataset.
learn_gen.load('../input/train-gan-l1-and-features/gen-256')

## How good is pretrained model?

In [None]:
# NOTE(review): this loads the same 'gen-256' checkpoint as the earlier load cell; it looks redundant — confirm.
learn_gen.load('../input/train-gan-l1-and-features/gen-256')

In [None]:
# gen = load_learner('../input/train-gan-l1-and-features/')
# gen.data = get_data(dummy=True)

In [None]:
from IPython.display import FileLink

def enhance(link):
    """Download an image, enhance it with `gen`, and return a link to the results.

    The image at `link` is fetched with wget, predicted twice (as-is and
    horizontally flipped, with the second prediction flipped back), and the
    two predictions are averaged. Writes original.jpg, 0.jpg, 1.jpg and
    avg.jpg to the working directory, archives every *.jpg into images.tar
    and returns a FileLink to that archive.

    NOTE(review): `gen` is only assigned in a commented-out cell of this
    notebook; calling this function as-is presumably raises NameError unless
    `gen` is defined elsewhere — confirm.
    """
    # IPython shell escapes. The rename assumes wget stored the file as `download`.
    !wget {link}
    !mv download pic.jpg
    img = open_image('pic.jpg')
    # NOTE(review): presumably rescales so the shorter side is 720 px — confirm resize_to semantics.
    img.resize((3, *resize_to(img, 720, use_min=True)))
    img.refresh()
    print(img.size)
    img.save('original.jpg')
    out0 = gen.predict(img)
    # Predict on the mirrored image, then mirror the prediction and the input back.
    img.flip_lr()
    out1 = gen.predict(img)
    out1[0].flip_lr()
    img.flip_lr()
    out0[0].save('0.jpg')
    out1[0].save('1.jpg')
    # Average the two predictions.
    temp = (out0[0].data + out1[0].data)/2

    temp = fastai.vision.image2np(temp)
    plt.figure(figsize = (20,20))
    plt.imshow(temp)
    plt.imsave( 'avg.jpg', temp)
    # -z gzip-compresses the archive even though it is named .tar.
    !tar -czf images.tar *.jpg
    return FileLink('images.tar')

In [None]:
#enhance('https://cloud.anupam.gq/index.php/s/ToQas9LjEConSHw/download')

In [None]:
# Display 10 rows of results from the loaded generator.
learn_gen.show_results(rows=10, figsize=(24, 100))

## Loading critic

In [None]:
def get_critic_data(bs, size=256, split=0.9):
    """Build the critic's classification DataBunch.

    Images listed in ``fnames_df['name']`` are read from both the 320
    crappy folder and ``orig_path``. An image is labelled ``'generated'``
    when its parent directory is named ``crappy`` and ``'original'``
    otherwise.

    Args:
        bs: Batch size.
        size: Image size passed to ``transform``.
        split: Fraction of rows (from the start of ``fnames_df``) used for
            training; files named in the remaining tail go to validation,
            in both their crappy and original versions.

    Returns:
        The DataBunch, normalized with ``imagenet_stats`` and with ``c``
        set to 3.
    """
    def labeler(x):
        return 'generated' if Path(x).parent.name == 'crappy' else 'original'

    df = fnames_df
    # A set keeps the per-image membership test O(1); with a list the split
    # below is quadratic in the number of files.
    valid_names = set(df['name'].iloc[int(split*len(df)):])

    src1 = ImageList.from_df(df, path = Path('../input/flickrproc/crappy_320')/'crappy', cols='name')
    src2 = ImageList.from_df(df, path = orig_path, cols='name')
    # Append the originals to the crappy images in place.
    src1.add(items=src2)

    src = src1.split_by_valid_func(lambda x : Path(x).name in valid_names)
    data = src.label_from_func(labeler)
    data = data.transform(get_transforms(), size=size).databunch(bs=bs).normalize(imagenet_stats)

    data.c = 3
    return data

In [None]:
# Critic data with batch size bs and image size 128.
data_critic = get_critic_data(bs, 128)

In [None]:
# Preview a batch of the critic's labelled images.
data_critic.show_batch()

In [None]:
# Critic loss: fastai's AdaptiveLoss wrapped around binary cross-entropy with logits.
loss_critic = AdaptiveLoss(nn.BCEWithLogitsLoss())
def create_critic_learner(data, metrics):
    """Return a critic ``Learner`` built on ``data``.

    The model comes from ``gan_critic()``; the loss is the module-level
    ``loss_critic`` and weight decay is the module-level ``wd``.
    Checkpoints are written to ``/kaggle/working``.

    Args:
        data: DataBunch the critic trains on. Previously this argument was
            ignored and the global ``data_critic`` was used instead.
        metrics: Metric or list of metrics passed to the ``Learner``.
    """
    return Learner(data, gan_critic(), metrics=metrics, loss_func=loss_critic, wd=wd, model_dir="/kaggle/working")
# Build the critic on the critic data and load its pretrained checkpoint 'critic-256'.
learn_critic = create_critic_learner(data_critic, accuracy_thresh_expand)
learn_critic.load('../input/train-gan-l1-and-features/critic-256')

### GAN
**Putting both models together as a GAN**

Training is done by adaptively switching between the discriminator and the generator.<br>The discriminator is trained until its loss drops below 0.65, after which training switches back to the generator.

In [None]:
# Switching policy between generator and critic, with a critic loss threshold of 0.65.
switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
# Combine the two learners into one GAN learner; gen_first=True starts with the generator.
# NOTE(review): weights_gen=(1.,50.) presumably weights the adversarial term vs. the generator's own loss — confirm.
learn = GANLearner.from_learners(learn_gen, learn_critic, weights_gen=(1.,50.), show_img=True, switcher=switcher,
                                 opt_func=partial(optim.Adam, betas=(0.,0.99)), wd=wd, model_dir="/kaggle/working", gen_first=True)
# Register the GANDiscriminativeLR callback with a learning-rate multiplier of 5.
learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))

In [None]:
# Swap the GAN learner's data for size-256 images from FOLDERS[320] at half the base batch size.
del learn.data
learn.data = get_data(256,bs//2, 320)
# Release the old data and cached GPU memory, then query the free GPU memory.
gc.collect()
torch.cuda.empty_cache()
gpu_mem_get_free()

In [None]:
# Display 20 rows of generator results before calling learn.fit.
learn_gen.show_results(rows=20, figsize=(30, 100))

In [None]:
# Train the GAN for 16 epochs with learning rate 1e-4.
learn.fit(16,1e-4)

In [None]:
# Display 20 rows of results from the GAN learner after training.
learn.show_results(rows=20, figsize=(30, 100))

Saving models

In [None]:
# Save both models as 'gen-256' and 'critic-256'; both learners were created with model_dir="/kaggle/working".
learn_gen.save('gen-256')
learn_critic.save('critic-256')

In [None]:
# Export the generator learner to /kaggle/working/export.pkl.
learn_gen.export("/kaggle/working/export.pkl")

In [None]:
# IPython shell escape: list the files in the current working directory.
!ls