# Training a GAN with pretrained models

A DCGAN model is trained for image enhancement (super-resolution, "decrappification"). <br>The dataset used is the Flickr Image Dataset, available on Kaggle.
To train the model, synthetic degraded ("crappified") images are generated as in the kernel https://www.kaggle.com/greenahn/crappify-imgs<br>and saved to disk; together with the original high-resolution images, they form the training pairs.<br><br>
For more details, find the github repository at: https://github.com/nupam/GANs-for-Image-enhancement

Both the pretrained generator and the pretrained discriminator are loaded from disk; they are output files of the kernel<br> https://www.kaggle.com/greenahn/pretrain-gan-mse.<br>
They are then combined into a GAN and trained.

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
print(os.listdir("../input"))

from tqdm import tqdm_notebook as tqdm

import fastai
from fastai.vision import *
from fastai.callbacks import *
from fastai.utils.mem import *
from fastai.vision.gan import *
import gc
from torchvision.models import vgg16_bn

In [None]:
## Each crappy_* folder holds degraded ("crappified") copies of the images at one resolution;
## the degradation applied to each image was selected at random.
orig_path = Path('../input/flickrproc/hr/hr')                # high-resolution originals (targets)
fnames_df = pd.read_csv('../input/flickrproc/files.csv')    # image file names ('name' column)
bs = 16                                                      # base batch size
FOLDERS = {
    256: Path('../input/flickrproc/crappy_256/crappy/'),
    320: Path('../input/flickrproc/crappy_320/crappy/'),
}
FOLDERS

## Loading Generator

In [None]:
## Loading training data.
## With dummy=True only the first 32 images are used (handy for quick smoke tests).
def get_data(size=None, bs=None, folder=320, split=0.9, dummy=False):
    """Return an image-to-image DataBunch of (crappified input, high-res target) pairs.

    Inputs are read from ``FOLDERS[folder]`` using the ``name`` column of
    ``fnames_df``; each target is the file with the same name in ``orig_path``.
    Inputs and targets are both normalized with ``imagenet_stats``.

    Args:
        size: image size passed to ``.transform``. Required unless ``dummy``; in
            dummy mode ``None`` skips the resize transform entirely.
        bs: batch size. Required unless ``dummy`` (where it defaults to 1).
        folder: key of ``FOLDERS`` selecting the crappified image set. Used in
            both normal and dummy mode.
        split: fraction of rows (in file order) used for training; the remaining
            tail rows form the validation set. Ignored in dummy mode, which uses a
            seeded 20% random validation split.
        dummy: build a tiny dataset from the first 32 rows of ``fnames_df`` only.

    Returns:
        The DataBunch, with ``c`` set to 3 (presumably read by ``unet_learner`` as
        the number of output channels — verify).

    Raises:
        ValueError: if ``bs`` or ``size`` is missing outside dummy mode.
        KeyError: if ``folder`` is not a key of ``FOLDERS``.
    """
    def hr_target(x):
        # The target is the high-res original sharing the input's file name.
        return orig_path/Path(x).name

    if dummy:
        if bs is None:
            bs = 1
        lls = (ImageImageList.from_df(fnames_df.iloc[:32], path=FOLDERS[folder], cols='name')
               .split_by_rand_pct(0.2, seed=34)
               .label_from_func(hr_target))
        if size is not None:
            # No augmentation in dummy mode, only resizing (applied to targets too).
            lls = lls.transform([], size=size, tfm_y=True)
        data = lls.databunch(bs=bs).normalize(imagenet_stats, do_y=True)
        data.c = 3
        return data

    if bs is None:
        raise ValueError('Batchsize is not provided')
    if size is None:
        raise ValueError('Size of image is not provided')

    src = ImageImageList.from_df(fnames_df, path=FOLDERS[folder], cols='name')
    n_items = len(src.items)
    # Deterministic split: the last (1 - split) fraction of rows is the validation set.
    src = src.split_by_idx(np.arange(int(n_items*split), n_items))

    data = (src.label_from_func(hr_target)
               # tfm_y=True applies the same augmentation/resize to the target image.
               .transform(get_transforms(max_zoom=1.2), size=size, tfm_y=True)
               .databunch(bs=bs)
               .normalize(imagenet_stats, do_y=True))
    data.c = 3
    return data

In [None]:
# 128px generator data for the first GAN training phase.
data_gen = get_data(size=128, bs=bs)

In [None]:
wd = 1e-3                   # weight decay, shared by the generator, critic and GAN learners
y_range = (-3.,3.)          # U-Net output range; presumably chosen to cover imagenet-normalized pixels — verify
loss_gen = MSELossFlat()    # pixel-wise MSE against the high-res target
arch = models.resnet34      # U-Net encoder backbone

def create_gen_learner(data=None):
    """Create the U-Net generator learner.

    Args:
        data: DataBunch to train on. Defaults to the global ``data_gen`` (looked up
            at call time), so existing ``create_gen_learner()`` calls are unchanged.

    Returns:
        A ``unet_learner`` with a resnet34 encoder, blur, weight normalization,
        self-attention, MSE loss, and ``model_dir="/kaggle/working"`` for saves.
    """
    if data is None:
        data = data_gen
    return unet_learner(data, arch, wd=wd, blur=True, norm_type=NormType.Weight,
                        self_attention=True, y_range=y_range, loss_func=loss_gen,
                        model_dir="/kaggle/working")


In [None]:
# List the pretrained artifacts from the input kernel (gen_pre / critic-pre are loaded below).
!ls ../input/pretrain-gan-mse/

In [None]:
# Build the generator and load the MSE-pretrained weights.
# The absolute path overrides model_dir, so weights come from the read-only input dataset.
learn_gen = create_gen_learner()
learn_gen.load(Path('/kaggle/input/pretrain-gan-mse/gen_pre'))

### How good is the pretrained model?

In [None]:
# Visual check of the pretrained generator (before any adversarial training).
learn_gen.show_results(rows=20)

## Loading critic

In [None]:
def get_critic_data(bs, size=256, split=0.9, folder=320):
    """Build a 2-class DataBunch ('generated' vs 'original') for the critic.

    Each file name in ``fnames_df`` contributes two items: the crappified image
    from ``FOLDERS[folder]`` (label ``'generated'``) and the high-res original from
    ``orig_path`` (label ``'original'``). Rows from index ``int(split*len(fnames_df))``
    onward are validation for both copies, matching the tail split used by
    ``get_data``, so an image never appears in both training and validation.

    Args:
        bs: batch size.
        size: image size passed to ``.transform``.
        split: fraction of rows (in file order) used for training.
        folder: key of ``FOLDERS`` selecting the crappified set (default 320, the
            set used previously).

    Returns:
        A DataBunch normalized with ``imagenet_stats``, with ``c`` set to 3.
    """
    def labeler(x):
        # Items under a 'crappy' directory are degraded inputs; the rest are real HR images.
        return 'generated' if Path(x).parent.name == 'crappy' else 'original'

    df = fnames_df
    # A set gives O(1) membership; the valid func runs once per item (2 * len(df) calls).
    valid_names = set(df['name'].iloc[int(split*len(df)):])

    src = ImageList.from_df(df, path=FOLDERS[folder], cols='name')
    # add() mutates src in place, appending the originals (same names, different folder).
    src.add(items=ImageList.from_df(df, path=orig_path, cols='name'))

    src = src.split_by_valid_func(lambda x: Path(x).name in valid_names)
    data = (src.label_from_func(labeler)
               .transform(get_transforms(), size=size)
               .databunch(bs=bs)
               .normalize(imagenet_stats))

    # NOTE(review): this is 2-class data; c=3 mirrors get_data — confirm it is intended.
    data.c = 3
    return data

# 128px critic data, matching the 128px generator data of the first GAN phase.
data_critic = get_critic_data(bs, size=128)

In [None]:
# BCE-with-logits wrapped in AdaptiveLoss (presumably to expand the target to the critic's output shape — verify).
loss_critic = AdaptiveLoss(nn.BCEWithLogitsLoss())

def create_critic_learner(data, metrics):
    """Create a Learner around a fresh ``gan_critic()``.

    Args:
        data: DataBunch the critic is trained on. It was previously ignored in
            favour of the global ``data_critic``; it is now actually used.
        metrics: metric(s) passed to the Learner.

    Returns:
        A ``Learner`` using ``loss_critic``, the shared ``wd``, and
        ``model_dir="/kaggle/working"``.
    """
    return Learner(data, gan_critic(), metrics=metrics, loss_func=loss_critic, wd=wd,
                   model_dir="/kaggle/working")

learn_critic = create_critic_learner(data_critic, accuracy_thresh_expand)
# Relative path: presumably resolved against model_dir ("/kaggle/working"), i.e. /kaggle/input/... — verify.
learn_critic.load('../input/pretrain-gan-mse/critic-pre')

### GAN
**Putting both models together as a GAN**

In [None]:
# Release unreferenced Python objects first, then return cached CUDA blocks,
# and display the free GPU memory before building the GAN.
gc.collect()
torch.cuda.empty_cache()
gpu_mem_get_free()

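Training alternates adaptively between the discriminator (critic) and the generator.<br>The critic keeps training until its loss drops below 0.65; control then switches to the generator (for a single batch, since no generator threshold is set) before returning to the critic.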

In [None]:
# Stay on the critic until its loss falls below 0.65 (no generator threshold given).
switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
learn = GANLearner.from_learners(
    learn_gen,
    learn_critic,
    switcher=switcher,
    # Assumed (fastai v1 gan_loss_from_func): 1 x adversarial loss + 50 x the generator's MSE — verify.
    weights_gen=(1., 50.),
    show_img=True,
    # Adam with beta1 = 0, a common GAN setting.
    opt_func=partial(optim.Adam, betas=(0., 0.99)),
    wd=wd,
    model_dir="/kaggle/working",
    gen_first=True,
)
# The critic's learning rate is multiplied by 5 relative to the generator's.
learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))

In [None]:
lr = 0.0002  # base learning rate for GAN training (halved for the 256px phase)

In [None]:
# Phase 1: 8 epochs at 128px.
learn.fit(8, lr=lr)

In [None]:
# Inspect generator outputs after the 128px GAN phase.
learn.show_results(rows=20)

In [None]:
# Checkpoint both networks after the 128px phase (written under model_dir, /kaggle/working).
learn_gen.save('gen-128')
learn_critic.save('critic-128')

In [None]:
# Drop the GAN learner's 128px DataBunch before switching to 256px data.
# NOTE(review): data_gen (passed to unet_learner for learn_gen) still references the
# 128px DataBunch, so this del probably frees little; consider `del data_gen` if memory is tight.
del learn.data
gc.collect()
torch.cuda.empty_cache()
gpu_mem_get_free()

Increasing the image size to 256px (halving the batch size) and training again

In [None]:
# Phase 2 data: 256px images; batch size halved to fit the larger images in memory.
learn.data = get_data(size=256, bs=bs // 2)

In [None]:
# Phase 2: 7 epochs at 256px with half the base learning rate.
learn.fit(7, lr=lr / 2)

In [None]:
# Inspect generator outputs after the 256px phase (large figure for full-size previews).
learn.show_results(rows=10, figsize=(30, 100))

Saving models

In [None]:
# Final checkpoints after the 256px phase (written under model_dir, /kaggle/working).
learn_gen.save('gen-256')
learn_critic.save('critic-256')

In [None]:
# Export the trained generator learner to a single file, presumably for inference via fastai's load_learner.
learn_gen.export("/kaggle/working/export.pkl")

In [None]:
# List the current directory (presumably /kaggle/working) to confirm the saved weights and export.pkl.
!ls