In [None]:
import fastai
from fastai.vision import *
from fastai.callbacks import *
from fastai.vision.gan import *
from torchvision.models import vgg16_bn
from fastai.utils.mem import *

## Path


In [None]:
# Dataset root; the high- and low-resolution image folders are derived from it
# (the same way `path_gen` is built later) so the three locations cannot drift apart.
path = Path('/content/gdrive/My Drive/mit')
path_hr = path/'high'
path_lr = path/'low'

## Architecture

In [None]:
# Backbone architecture handed to `unet_learner` when the generator is created.
arch = models.resnet34

In [None]:
def get_data(bs,size):
    """Build the low-res -> high-res image-to-image data bunch.

    Inputs are read from `path_lr`; the target for each input is the file with
    the same name under `path_hr`. 30% of the items are held out for validation
    using a fixed seed (42).

    bs:   batch size passed to `databunch`.
    size: image size passed to `transform` (applied to inputs and targets).
    """
    def _matching_hr_file(lr_file):
        # The target shares its file name with the low-res input.
        return path_hr/lr_file.name

    items = ImageImageList.from_folder(path_lr)
    split_items = items.split_by_rand_pct(0.3, seed=42)
    labelled = split_items.label_from_func(_matching_hr_file)
    # tfm_y=True / do_y=True: the same transforms and normalization are applied to the targets.
    augmented = labelled.transform(get_transforms(max_zoom=2.), size=size, tfm_y=True)
    bunch = augmented.databunch(bs=bs, num_workers=0)
    data = bunch.normalize(imagenet_stats, do_y=True)

    data.c = 3
    return data

## 64px

In [None]:
# Batch size and image size for the first (64px) training stage.
bs,size=20,64

In [None]:
# Data bunch for the 64px stage, built from the batch size / image size set above.
data = get_data(bs,size)

In [None]:
# Take the `.data` tensor of the second element of the first validation item
# (presumably the high-res target image — confirm against fastai's dataset indexing).
t = data.valid_ds[0][1].data
# Stack it with itself to get a 2-element batch for trying out `gram_matrix`.
t = torch.stack([t,t])

In [None]:
def gram_matrix(x):
    """Return the per-sample Gram matrix of a 4-D tensor.

    `x` is unpacked as (n, c, h, w). Each sample is flattened to (c, h*w),
    multiplied by its own transpose to give a (c, c) matrix, and the result is
    divided by c*h*w. The output has shape (n, c, c).
    """
    batch, channels, height, width = x.shape
    flat = x.view(batch, channels, -1)
    scale = channels * height * width
    return torch.matmul(flat, flat.transpose(1, 2)) / scale

In [None]:
# Sanity check: display the Gram matrices of the 2-element batch built above.
gram_matrix(t)


In [None]:
# Elementwise loss used by `FeatureLoss` for its pixel, feature and Gram terms.
base_loss = F.l1_loss

In [None]:
# Feature part of VGG16-BN (the `True` argument presumably requests pretrained
# weights — confirm against torchvision), moved to the GPU and put in eval mode.
vgg_m = vgg16_bn(True).features.cuda().eval()
# Turn off gradients for the VGG weights; it is used only as a fixed feature extractor for the loss.
requires_grad(vgg_m, False)

In [None]:
# Indices of the layers that sit immediately before each MaxPool2d in `vgg_m`.
blocks = [i-1 for i,o in enumerate(children(vgg_m)) if isinstance(o,nn.MaxPool2d)]
# Display the indices together with the layers they point at.
blocks, [vgg_m[i] for i in blocks]

In [None]:
class FeatureLoss(nn.Module):
    """Loss combining a pixel term with feature and Gram-matrix terms.

    The outputs of selected layers of `m_feat` are captured with hooks. For an
    (input, target) pair the loss is the sum of:
      * `base_loss(input, target)` (reported as 'pixel'),
      * `base_loss` between the hooked activations, scaled by each layer weight
        (reported as 'feat_i'),
      * `base_loss` between the Gram matrices of those activations, scaled by
        the squared layer weight times 5e3 (reported as 'gram_i').

    m_feat:     indexable feature network whose layers are hooked.
    layer_ids:  indices into `m_feat` of the layers to hook.
    layer_wgts: one weight per hooked layer.
    """
    def __init__(self, m_feat, layer_ids, layer_wgts):
        super().__init__()
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # detach=False keeps the stored activations attached to the graph.
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))
              ] + [f'gram_{i}' for i in range(len(layer_ids))]

    def make_features(self, x, clone=False):
        "Run `x` through `m_feat` and return the hooked activations (cloned if `clone`)."
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def forward(self, input, target):
        "Return the summed loss; per-term values are kept in `self.feat_losses` / `self.metrics`."
        # Target features are cloned, presumably because the next pass through
        # `m_feat` replaces what the hooks have stored.
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        self.feat_losses = [base_loss(input,target)]
        self.feat_losses += [base_loss(f_in, f_out)*w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out))*w**2 * 5e3
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        # `hooks` may be missing if construction failed before they were
        # registered; in that case there is nothing to remove.
        hooks = getattr(self, 'hooks', None)
        if hooks is not None: hooks.remove()

In [None]:
# Feature loss over three of the pre-MaxPool VGG layers (blocks[2:5]), with one weight per layer.
feat_loss = FeatureLoss(vgg_m, blocks[2:5], [5,15,2])

In [None]:
# Weight decay passed to the generator learner and, later, to the GAN learner.
wd = 1e-3
# (min, max) pair — presumably the intended output range of the generator; confirm where it is consumed.
y_range = (-3.,3.)

In [None]:
def create_gen_learner():
    """Create the generator: a U-Net learner on the current global `data`.

    Uses the global `arch` backbone, weight decay `wd`, output range `y_range`
    and `feat_loss` as the loss function, with blur, spectral normalization and
    self-attention enabled. `LossMetrics` is attached as a callback so the
    individual terms of the loss are reported.
    """
    return unet_learner(data, arch, wd=wd, blur=True, norm_type=NormType.Spectral,
                        self_attention=True, y_range=y_range,
                        loss_func=feat_loss, callback_fns=LossMetrics)
gc.collect();

In [None]:
# Generator learner for the 64px stage (built on the current `data`).
learn_gen = create_gen_learner()

In [None]:
lr = 1e-3
epoch = 1
def do_fit(save_name, lrs=None, pct_start=0.9, epochs=None):
    """Run one-cycle training on `learn_gen`, save the weights and show a sample.

    save_name: name under which the weights are saved.
    lrs:       learning rate(s) for `fit_one_cycle`; defaults to `slice(lr)`
               using the value of the global `lr` at call time.
    pct_start: forwarded to `fit_one_cycle`.
    epochs:    number of epochs; defaults to the current global `epoch`.
    """
    # Resolve defaults at call time so that reassigning `lr` / `epoch` in a
    # later cell actually takes effect.
    if lrs is None: lrs = slice(lr)
    if epochs is None: epochs = epoch
    learn_gen.fit_one_cycle(epochs, lrs, pct_start=pct_start)
    learn_gen.save(save_name)
    learn_gen.show_results(rows=1, imgsize=5)

In [None]:
# First 64px pass: one-cycle fit with slice(lr*10) as the learning rate; weights saved as '1a'.
do_fit('1a', slice(lr*10))


In [None]:
# Unfreeze the generator before the second 64px pass.
learn_gen.unfreeze()

In [None]:
epoch = 1
# Second 64px pass on the unfrozen model, saved as '1b'. The slice runs from a
# small rate up to `lr`; the previous start of 1e-01 was larger than the stop
# value, which put the biggest learning rate at the low end of the range.
do_fit('1b', slice(1e-5,lr))


## 128px

In [None]:
# Switch to 128px images with batch size 8 and point the existing learner at the new data.
data = get_data(8,128)
learn_gen.data = data
# Freeze again for the first pass at the new resolution.
learn_gen.freeze()
gc.collect()
# Restart from the weights saved at the end of the 64px stage.
learn_gen.load('1b');

In [None]:
# First 128px pass with do_fit's default learning rate; weights saved as '2a'.
epoch =1
do_fit('2a')

In [None]:
# Unfreeze the generator before the second 128px pass.
learn_gen.unfreeze()

In [None]:
# Second 128px pass with slice(1e-02) and pct_start=0.3; weights saved as '2b'.
epoch = 1
do_fit('2b', slice(1e-02), pct_start=0.3)

## 192px

In [None]:
# Switch to 192px images with batch size 4 and point the existing learner at the new data.
data = get_data(4,192)
learn_gen.data = data
# Freeze again for the first pass at the new resolution.
learn_gen.freeze()
gc.collect()
# Restart from the weights saved at the end of the 128px stage.
learn_gen.load('2b');

In [None]:
# First 192px pass: 4 epochs; weights saved as '3a'.
epoch = 4
# Base learning rate (same value as the one set when `do_fit` was defined).
lr = 1e-03
do_fit('3a')

In [None]:
# Unfreeze before the second 192px pass, mirroring the 64px and 128px stages.
# The learner was already frozen when this stage was set up, so calling
# freeze() again here did not change which layers are trained.
learn_gen.unfreeze()

In [None]:
# Second 192px pass: 4 epochs with slice(1e-07) and pct_start=0.3; weights saved as '3b'.
epoch = 4
do_fit('3b', slice(1e-07), pct_start=0.3)

## Save Generated Images

In [None]:
# Reload the final generator weights.
learn_gen.load('3b');
# Folder (under the dataset root) where generated images are written; the
# folder name is reused later as a class name for the critic data.
name_gen = 'image_gen'
path_gen = path/name_gen

In [None]:
# Create the output folder; exist_ok=True makes re-running this cell harmless.
path_gen.mkdir(exist_ok=True)

In [None]:
def save_preds(dl):
    """Save the generator's prediction for every item served by `dl`.

    Batches are run through `learn_gen.pred_batch` with reconstruct=True and
    each resulting image is written to `path_gen` under the file name of the
    dataset item at the same running position in `dl.dataset.items`.
    """
    source_items = dl.dataset.items

    # Lazily flatten the batches into a stream of images so that a single
    # running index lines predictions up with the dataset items.
    generated = (img
                 for batch in dl
                 for img in learn_gen.pred_batch(batch=batch, reconstruct=True))

    for position, img in enumerate(generated):
        img.save(path_gen/source_items[position].name)

In [None]:
# Write generated images for the items served by `data.fix_dl`
# (presumably the training set in a fixed, unshuffled order — confirm against fastai's DataBunch).
save_preds(data.fix_dl)

In [None]:
  
from fastai.core import *
from fastai.torch_core import *
from fastai.vision import *
from fastai.vision.gan import AdaptiveLoss, accuracy_thresh_expand

# Keyword arguments shared by every convolution block of the critic.
_conv_args = {'leaky': 0.2, 'norm_type': NormType.Spectral}


def _conv(ni: int, nf: int, ks: int = 3, stride: int = 1, **kwargs):
    "Return a `conv_layer` from `ni` to `nf` channels using the shared `_conv_args` plus any `kwargs`."
    layer = conv_layer(ni, nf, ks=ks, stride=stride, **_conv_args, **kwargs)
    return layer


def custom_gan_critic(
    n_channels: int = 3, nf: int = 256, n_blocks: int = 3, p: float = 0.15
):
    """Critic to train a `GAN`.

    n_channels: number of channels of the input images.
    nf:         number of filters of the first convolution; doubled after each block.
    n_blocks:   number of (conv, dropout, strided conv) blocks after the stem.
    p:          dropout probability inside the blocks; the stem uses `p / 2`.

    Returns an `nn.Sequential` ending in a 1-channel convolution without
    activation followed by `Flatten`.
    """
    # Stem: strided convolution followed by light dropout.
    layers = [_conv(n_channels, nf, ks=4, stride=2), nn.Dropout2d(p / 2)]
    for i in range(n_blocks):
        layers += [
            _conv(nf, nf, ks=3, stride=1),
            nn.Dropout2d(p),
            # Self-attention is enabled only in the first block.
            _conv(nf, nf * 2, ks=4, stride=2, self_attention=(i == 0)),
        ]
        nf *= 2
    # Head: one more convolution, then a 1-channel output without activation.
    layers += [
        _conv(nf, nf, ks=3, stride=1),
        _conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),
        Flatten(),
    ]
    return nn.Sequential(*layers)


def colorize_crit_learner(
    data: ImageDataBunch,
    loss_critic=AdaptiveLoss(nn.BCEWithLogitsLoss()),
    nf: int = 256,
) -> Learner:
    """Wrap a `custom_gan_critic` in a `Learner`.

    data:        data bunch the critic is trained on.
    loss_critic: loss function for the critic.
    nf:          initial number of filters of the critic.

    The learner reports `accuracy_thresh_expand` and uses a weight decay of 1e-3.
    """
    critic = custom_gan_critic(nf=nf)
    learner_options = dict(
        metrics=accuracy_thresh_expand,
        loss_func=loss_critic,
        wd=1e-3,
    )
    return Learner(data, critic, **learner_options)

## Train Critic

In [None]:
# Drop the reference to the generator learner and collect garbage before building the critic.
learn_gen=None
gc.collect()

In [None]:
def get_crit_data(classes, bs, size):
    """Build the classification data bunch used to train the critic.

    Images are read from the sub-folders of `path` named in `classes` and
    labelled by folder; 10% are held out for validation using a fixed seed (42).

    classes: folder names to include, also used as the class list.
    bs:      batch size.
    size:    image size passed to `transform`.
    """
    items = ImageList.from_folder(path, include=classes)
    split_items = items.split_by_rand_pct(0.1, seed=42)
    labelled = split_items.label_from_folder(classes=classes)
    augmented = labelled.transform(get_transforms(max_zoom=2.), size=size)
    data = augmented.databunch(bs=bs).normalize(imagenet_stats)
    data.c = 3
    return data

In [None]:
# Critic data at batch size 4 and 192px; the two classes are the generated-image folder and 'high'.
bs =4
size =192
data_crit = get_crit_data([name_gen, 'high'], bs=bs, size=size)

In [None]:
# Critic learner used for pre-training.
learn_critic = colorize_crit_learner(data=data_crit, nf=256)

In [None]:
# Pre-train the critic for one epoch with a learning rate of 1e-02.
learn_critic.fit_one_cycle(1, 1e-02)

In [None]:
# Save the pre-trained critic weights; they are reloaded when the GAN is assembled.
learn_critic.save('critic-pre2')


## GAN (Generative adversarial networks)

In [None]:
# Drop every learner reference before assembling the GAN. The pre-trained
# critic is bound to `learn_critic`, so that name must be cleared as well or
# its model stays alive through the collection below.
learn_critic=None
learn_crit=None
learn_gen=None
gc.collect()

In [None]:
# Rebuild the critic data (generated-image folder vs 'high') for the GAN stage.
data_crit = get_crit_data([name_gen, 'high'], bs=bs, size=size)

In [None]:
# Fresh critic learner initialised from the pre-trained 'critic-pre2' weights.
learn_crit = colorize_crit_learner(data=data_crit, nf=256).load('critic-pre2')

In [None]:
# Fresh generator learner initialised from the final '3b' weights.
learn_gen = create_gen_learner().load('3b')

In [None]:
# Switcher configured with critic_thresh=0.65 (presumably the critic loss level at
# which training hands over to the generator — confirm against fastai's AdaptiveGANSwitcher).
switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
# Combine generator and critic into one GAN learner, using Adam with betas=(0., 0.99)
# and the same weight decay as the generator.
learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.,50.), show_img=True, switcher=switcher,
                                 opt_func=partial(optim.Adam, betas=(0.,0.99)), wd=wd)
# Attach GANDiscriminativeLR with mult_lr=5. (presumably scales the critic's learning rate — confirm).
learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))

In [None]:
# Learning rate for GAN training.
lr = 2e-5

In [None]:
# Train the GAN for one epoch at the learning rate set above.
learn.fit(1,lr)