# **SkinDeep**

In [None]:
import fastai
from fastai.vision import *
from fastai.callbacks import *
from fastai.vision.gan import *
from torchvision.models import vgg16_bn
from fastai.utils.mem import *

# **Path**

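**Each `path_lr*` / `path_hr*` pair below points to a different image dataset: inputs are read from `path_lr*`, and each input's target is the file with the same name in the matching `path_hr*`.**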

In [None]:
# Every dataset lives in a sub-folder of this Google Drive directory.
path = Path('/content/gdrive/MyDrive/Tattoo')

# One (target, input) folder pair per dataset. get_data reads inputs from a
# path_lr* folder and looks up each target by file name in the paired path_hr*.
path_hr,  path_lr  = path/'tattoo-r half',  path/'tattoo half'
path_hr1, path_lr1 = path/'Split Tattoo-R', path/'Split Tattoo'
path_hr2, path_lr2 = path/'Tattoo-R',       path/'Tattoo'
path_hr3, path_lr3 = path/'tat-r',          path/'tat'
path_hr4, path_lr4 = path/'tatt-R',         path/'tatt'
path_hr5, path_lr5 = path/'new-tattoo-r',   path/'new-tattoo'
path_hr6, path_lr6 = path/'fresh tattoo-r', path/'fresh tattoo'

In [None]:
# Stage-1 batch size and image size (later stages move to larger images).
bs,size=10,64
# Backbone architecture for the U-Net generator's encoder (used by create_gen_learner).
arch = models.resnet34

In [None]:
# Inputs for the first dataset (path_lr); 30% held out at random for validation, with a fixed seed.
src = ImageImageList.from_folder(path_lr).split_by_rand_pct(0.3, seed=42)


In [None]:
# Extra augmentation on top of fastai's defaults: a random perspective warp
# (magnitude 0.25) applied with probability 0.25. Stored as a list because
# `get_transforms(xtra_tfms=...)` takes a collection of transforms; a single
# transform wrapped in parentheses is not a tuple.
additional_aug = [perspective_warp(magnitude=0.25, p=0.25)]

def get_data(bs,size):
    """Return a DataBunch for the path_lr / path_hr dataset.

    Each input in the global `src` is labelled with the file of the same name
    in `path_hr`. Augmentations (including `additional_aug`) are applied to
    input and target alike (`tfm_y=True`), images are resized to `size` and
    batched by `bs`, and both sides are normalised with ImageNet statistics.

    Args:
        bs: batch size.
        size: image size passed to `transform`.

    Returns:
        The DataBunch, with `c` set to 3 (output channels for the generator).
    """
    data = (src.label_from_func(lambda x: path_hr/x.name)
            .transform(get_transforms(max_zoom=2., max_warp=0.25, max_lighting=0.5,
                                      xtra_tfms=additional_aug),
                       size=size, tfm_y=True)
            # num_workers=0: load batches in the main process.
            .databunch(bs=bs, num_workers=0)
            .normalize(imagenet_stats, do_y=True))
    data.c = 3
    return data

In [None]:
# Stage-1 DataBunch (bs=10, size=64).
data = get_data(bs,size)


In [None]:
# Preview a few validation input/target pairs.
data.show_batch(ds_type=DatasetType.Valid, rows=2, figsize=(9,9))

In [None]:
# Target tensor of the first validation item, duplicated into a batch of 2 so
# gram_matrix (which unpacks n, c, h, w) can be sanity-checked on it.
t = data.valid_ds[0][1].data
t = torch.stack([t,t])

In [None]:
def gram_matrix(x):
    """Return the per-sample Gram matrix of the feature maps in `x`.

    `x` has shape (n, c, h, w). The result has shape (n, c, c): entry [b, i, j]
    is the dot product of channels i and j of sample b, divided by c*h*w.
    """
    batch, chans, height, width = x.size()
    flat = x.view(batch, chans, -1)          # (n, c, h*w)
    norm = chans * height * width
    return flat.matmul(flat.transpose(1, 2)) / norm

In [None]:
# Sanity check: one (c x c) Gram matrix per stacked copy.
gram_matrix(t)

In [None]:
# Distance used for every FeatureLoss term (pixel, feature and Gram): L1 / mean absolute error.
base_loss = F.l1_loss


In [None]:
# VGG16-BN convolutional layers (pretrained weights requested via the positional True),
# moved to the GPU, put in eval mode and frozen: used only as a fixed feature extractor for the loss.
vgg_m = vgg16_bn(True).features.cuda().eval()
requires_grad(vgg_m, False)

In [None]:
# Index of the layer immediately before each MaxPool2d (the last layer before each downsampling);
# FeatureLoss hooks a subset of these. The second line only displays the selected layers.
blocks = [i-1 for i,o in enumerate(children(vgg_m)) if isinstance(o,nn.MaxPool2d)]
blocks, [vgg_m[i] for i in blocks]

In [None]:
class FeatureLoss(nn.Module):
    """Perceptual loss: pixel L1 + weighted feature L1 + weighted Gram-matrix L1.

    Activations are captured with forward hooks on the layers of `m_feat`
    selected by `layer_ids`, for both the prediction and the target.

    Args:
        m_feat: indexable feature-extractor module (here the frozen VGG `features`).
        layer_ids: indices into `m_feat` whose outputs are compared.
        layer_wgts: one weight `w` per entry of `layer_ids`; the feature term is
            scaled by `w` and the Gram term by `w**2 * 5e3`.
    """
    def __init__(self, m_feat, layer_ids, layer_wgts):
        super().__init__()
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # detach=False: gradients must flow back through the hooked activations.
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        # Metric names, in the same order as self.feat_losses (read via self.metrics).
        self.metric_names = (['pixel']
                             + [f'feat_{i}' for i in range(len(layer_ids))]
                             + [f'gram_{i}' for i in range(len(layer_ids))])

    def make_features(self, x, clone=False):
        """Run `x` through `m_feat` and return the hooked activations.

        The hooks' stored outputs are replaced on the next forward pass, so
        pass `clone=True` when the result must survive another call.
        """
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def forward(self, input, target):
        """Return the summed loss; individual terms are kept in `self.metrics`."""
        # Target features first, cloned, because the input pass reuses the hook storage.
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        self.feat_losses = [base_loss(input, target)]
        self.feat_losses += [base_loss(f_in, f_out) * w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out)) * w**2 * 5e3
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        # Remove the forward hooks from m_feat. Guarded because __init__ may have
        # failed (e.g. an invalid index in layer_ids) before self.hooks was set,
        # which would otherwise raise AttributeError during garbage collection.
        hooks = getattr(self, 'hooks', None)
        if hooks is not None:
            hooks.remove()

In [None]:
# Perceptual loss on three pre-pool VGG layers (blocks[2:5]) with per-layer weights 5, 15 and 2.
feat_loss = FeatureLoss(vgg_m, blocks[2:5], [5,15,2])


In [None]:
# Weight decay for the generator learner.
wd = 1e-3
# Bounds for the generator's output (the y_range argument of unet_learner).
y_range = (-3.,3.)

In [None]:
def create_gen_learner():
    """Create the U-Net generator learner trained with the perceptual loss.

    Reads the globals `data`, `arch`, `wd`, `y_range` and `feat_loss`. The
    LossMetrics callback reports the individual loss terms that `feat_loss`
    records alongside the total.
    """
    # Use the shared y_range setting instead of repeating the literal bounds.
    return unet_learner(data, arch, wd=wd, blur=True, norm_type=NormType.Spectral,
                        self_attention=True, y_range=y_range,
                        loss_func=feat_loss, callback_fns=LossMetrics)
# Free memory left over from earlier cells.
gc.collect();

In [None]:
# Build the generator learner on the stage-1 data.
learn_gen = create_gen_learner()


In [None]:
# Learning-rate finder: plot loss against learning rate to choose lr for the first stage.
learn_gen.lr_find()
learn_gen.recorder.plot(suggestion =True)

In [None]:
lr = 8.32E-02
epoch = 5
def do_fit(save_name, lrs=None, pct_start=0.9):
    """Train `learn_gen` with one-cycle for `epoch` epochs, save it, and show a sample.

    Args:
        save_name: checkpoint name passed to `learn_gen.save`.
        lrs: learning rate(s) for `fit_one_cycle`. Defaults to `slice(lr)`,
            using the global `lr` as it is *when do_fit is called*, so a cell
            that sets `lr` and then calls `do_fit(name)` trains at that value.
        pct_start: fraction of the cycle spent increasing the learning rate.

    Reads the globals `learn_gen`, `epoch` and, when `lrs` is None, `lr`.
    """
    # A `slice(lr)` default in the signature would be evaluated only once, when
    # this def runs, freezing lr at its value here for every later call.
    if lrs is None:
        lrs = slice(lr)
    learn_gen.fit_one_cycle(epoch, lrs, pct_start=pct_start)
    learn_gen.save(save_name)
    learn_gen.show_results(rows=1, imgsize=5)

In [None]:
# Stage 1a (size 64, before unfreezing): peak learning rate is lr*10.
# NOTE(review): with lr = 8.32e-2 this is ~0.83, unusually high — confirm it is intended.
do_fit('da', slice(lr*10))


In [None]:
learn_gen.show_results(rows=20)

In [None]:
# Make every layer group trainable for fine-tuning.
learn_gen.unfreeze()

In [None]:
learn_gen.lr_find()
learn_gen.recorder.plot(suggestion =True)

In [None]:
# Stage 1b (unfrozen): discriminative LRs from 1.32e-6 (first layer group) up to lr (last).
epoch = 5
do_fit('db', slice(1.32E-06,lr))
# NOTE(review): 1e-05 — presumably an alternative upper LR that was tried; confirm or remove.

In [None]:
learn_gen.show_results(rows=10)

In [None]:
# Stage 2: larger images (128 px, batch 8). Swap in the new data, re-freeze,
# and restart from the 'db' checkpoint.
data = get_data(8,128)
learn_gen.data = data
learn_gen.freeze()
gc.collect()
learn_gen.load('db');

In [None]:
learn_gen.lr_find()
learn_gen.recorder.plot(suggestion =True)

In [None]:
# Stage 2a (size 128, frozen): single learning rate lr.
epoch =6
lr = 3.98E-06
do_fit('db2',slice(lr))

In [None]:
learn_gen.show_results(rows=20)

In [None]:
learn_gen.unfreeze()

In [None]:
learn_gen.lr_find()
learn_gen.recorder.plot(suggestion =True)

In [None]:
# Stage 2b (unfrozen): LRs 1.91e-6 .. 1e-4, with a shorter warm-up (pct_start=0.3 vs the default 0.9).
epoch = 5
do_fit('db3', slice(1.91E-06,1e-4), pct_start=0.3)

In [None]:
learn_gen.show_results(rows=10)

In [None]:
# Stage 3: 192 px images (batch 10), re-frozen, restarted from the 'db3' checkpoint.
data = get_data(10,192)
learn_gen.data = data
learn_gen.freeze()
gc.collect()
learn_gen.load('db3');

In [None]:
learn_gen.lr_find()
learn_gen.recorder.plot(suggestion =True)

In [None]:
# Frozen stage at 192 px. Pass lr explicitly so this stage trains at the value
# set here, independent of do_fit's default `lrs`.
epoch =5
lr = 6.92E-06
do_fit('db4', slice(lr))

In [None]:
learn_gen.unfreeze()

In [None]:
learn_gen.lr_find()
learn_gen.recorder.plot(suggestion =True)

In [None]:
# Unfrozen fine-tune at 192 px: last stage on the first dataset.
epoch = 5
do_fit('db5', slice(6.31E-07,1e-5), pct_start=0.3)

In [None]:
# Second dataset: inputs from path_lr2, 20% random validation split with a fixed seed.
src = ImageImageList.from_folder(path_lr2).split_by_rand_pct(0.2, seed=42)


In [None]:
def get_data(bs,size):
    """Return a DataBunch for the path_lr2 / path_hr2 dataset.

    Inputs come from the global `src`; each is labelled with the same-named file
    in `path_hr2`. Input and target get the same augmentation (`tfm_y=True`),
    are resized to `size`, batched by `bs`, and normalised with ImageNet stats.
    The returned DataBunch has `c` set to 3.
    """
    def same_name_target(item):
        return path_hr2/item.name

    labelled = src.label_from_func(same_name_target)
    augmented = labelled.transform(
        get_transforms(max_zoom=2., max_warp=0.25, max_lighting=0.5),
        size=size, tfm_y=True)
    batched = augmented.databunch(bs=bs, num_workers=0)
    data = batched.normalize(imagenet_stats, do_y=True)
    # Three output channels for the generator.
    data.c = 3
    return data

In [None]:
# Restart at 64 px (batch 20) on the second dataset: re-freeze and load the 'db5' checkpoint.
data = get_data(20,64)
learn_gen.data = data
learn_gen.freeze()
gc.collect()
learn_gen.load('db5');

In [None]:
learn_gen.lr_find()
learn_gen.recorder.plot(suggestion =True)

In [None]:
# Frozen stage at 64 px on the second dataset. Pass lr explicitly so this stage
# trains at the value set here, independent of do_fit's default `lrs`.
epoch = 5
lr = 3.63E-03
do_fit('db6', slice(lr))

In [None]:
learn_gen.unfreeze()

In [None]:
learn_gen.lr_find()
learn_gen.recorder.plot(suggestion =True)

In [None]:
# Unfrozen fine-tune at 64 px on the second dataset.
epoch = 5
do_fit('db7', slice(1.10E-06,1e-4), pct_start=0.3)

In [None]:
# Grow to 128 px (batch 10): re-freeze and load the 'db7' checkpoint.
data = get_data(10,128)
learn_gen.data = data
learn_gen.freeze()
gc.collect()
learn_gen.load('db7');

In [None]:
learn_gen.lr_find()
learn_gen.recorder.plot(suggestion =True)

In [None]:
# Frozen stage at 128 px on the second dataset. Pass lr explicitly so this stage
# trains at the value set here, independent of do_fit's default `lrs`.
epoch = 5
lr = 1.10E-06
do_fit('db8', slice(lr))

In [None]:
learn_gen.unfreeze()

In [None]:
learn_gen.lr_find()
learn_gen.recorder.plot(suggestion =True)

In [None]:
# Unfrozen fine-tune at 128 px: last stage on the second dataset.
epoch = 5
do_fit('db9', slice(1.10E-06,1e-4), pct_start=0.3)

In [None]:
# Third dataset: inputs from path_lr3, 20% random validation split with a fixed seed.
src = ImageImageList.from_folder(path_lr3).split_by_rand_pct(0.2, seed=42)


In [None]:
# Extra augmentation on top of fastai's defaults: a random perspective warp
# (magnitude 0.25) applied with probability 0.25. Stored as a list because
# `get_transforms(xtra_tfms=...)` takes a collection of transforms; a single
# transform wrapped in parentheses is not a tuple.
additional_aug = [perspective_warp(magnitude=0.25, p=0.25)]

def get_data(bs,size):
    """Return a DataBunch for the path_lr3 / path_hr3 dataset.

    Each input in the global `src` is labelled with the file of the same name
    in `path_hr3`. Augmentations (including `additional_aug`) are applied to
    input and target alike (`tfm_y=True`), images are resized to `size` and
    batched by `bs`, and both sides are normalised with ImageNet statistics.

    Args:
        bs: batch size.
        size: image size passed to `transform`.

    Returns:
        The DataBunch, with `c` set to 3 (output channels for the generator).
    """
    data = (src.label_from_func(lambda x: path_hr3/x.name)
            .transform(get_transforms(max_zoom=2., max_warp=0.25, max_lighting=0.5,
                                      xtra_tfms=additional_aug),
                       size=size, tfm_y=True)
            # num_workers=0: load batches in the main process.
            .databunch(bs=bs, num_workers=0)
            .normalize(imagenet_stats, do_y=True))
    data.c = 3
    return data

In [None]:
# Third dataset at 128 px (batch 10): re-freeze and load the 'db9' checkpoint.
data = get_data(10,128)
learn_gen.data = data
learn_gen.freeze()
gc.collect()
learn_gen.load('db9');

In [None]:
learn_gen.lr_find()
learn_gen.recorder.plot(suggestion =True)

In [None]:
# Frozen stage at 128 px on the third dataset. Pass lr explicitly so this stage
# trains at the value set here, independent of do_fit's default `lrs`.
epoch = 5
lr = 2.29E-06
do_fit('db10', slice(lr))

In [None]:
learn_gen.unfreeze()

In [None]:
learn_gen.lr_find()
learn_gen.recorder.plot(suggestion =True)

In [None]:
# Unfrozen fine-tune at 128 px on the third dataset.
epoch = 5
do_fit('db11', slice(2.29E-06,1e-4), pct_start=0.3)

In [None]:
# Grow to 192 px (batch 10): re-freeze and load the 'db11' checkpoint.
data = get_data(10,192)
learn_gen.data = data
learn_gen.freeze()
gc.collect()
learn_gen.load('db11');

In [None]:
learn_gen.lr_find()
learn_gen.recorder.plot(suggestion =True)

In [None]:
# Frozen stage at 192 px on the third dataset. Pass lr explicitly so this stage
# trains at the value set here, independent of do_fit's default `lrs`.
epoch = 5
lr = 2.29E-06
do_fit('db12', slice(lr))

In [None]:
learn_gen.unfreeze()

In [None]:
learn_gen.lr_find()
learn_gen.recorder.plot(suggestion =True)

In [None]:
# Unfrozen fine-tune at 192 px: last stage on the third dataset.
epoch = 5
do_fit('db13', slice(9.12E-07,1e-5), pct_start=0.3)

In [None]:
# Fourth dataset: inputs from path_lr4, 20% random validation split with a fixed seed.
src = ImageImageList.from_folder(path_lr4).split_by_rand_pct(0.2, seed=42)


In [None]:
def get_data(bs,size):
    """Return a DataBunch for the path_lr4 / path_hr4 dataset.

    Inputs come from the global `src`; each is labelled with the same-named file
    in `path_hr4`. Input and target get the same augmentation (`tfm_y=True`),
    are resized to `size`, batched by `bs`, and normalised with ImageNet stats.
    The returned DataBunch has `c` set to 3.
    """
    def target_for(item):
        return path_hr4/item.name

    labelled = src.label_from_func(target_for)
    augmented = labelled.transform(
        get_transforms(max_zoom=2., max_warp=0.25, max_lighting=0.5),
        size=size, tfm_y=True)
    batched = augmented.databunch(bs=bs, num_workers=0)
    data = batched.normalize(imagenet_stats, do_y=True)
    # Three output channels for the generator.
    data.c = 3
    return data

In [None]:
# Fourth dataset at 128 px (batch 10): re-freeze and load the 'db13' checkpoint.
data = get_data(10,128)
learn_gen.data = data
learn_gen.freeze()
gc.collect()
learn_gen.load('db13');

In [None]:
learn_gen.lr_find()
learn_gen.recorder.plot(suggestion =True)

In [None]:
# Frozen stage at 128 px on the fourth dataset. Pass lr explicitly so this stage
# trains at the value set here, independent of do_fit's default `lrs`.
epoch = 5
lr = 3.31E-06
do_fit('db14', slice(lr))

In [None]:
learn_gen.unfreeze()

In [None]:
learn_gen.lr_find()
learn_gen.recorder.plot(suggestion =True)

In [None]:
# Unfrozen fine-tune at 128 px on the fourth dataset.
epoch = 5
do_fit('db15', slice(2.75E-06,1e-3), pct_start=0.3)

In [None]:
# Grow to 192 px (batch 10): re-freeze and load the 'db15' checkpoint.
data = get_data(10,192)
learn_gen.data = data
learn_gen.freeze()
gc.collect()
learn_gen.load('db15');

In [None]:
learn_gen.lr_find()
learn_gen.recorder.plot(suggestion =True)

In [None]:
# Frozen stage at 192 px on the fourth dataset. Pass lr explicitly so this stage
# trains at the value set here, independent of do_fit's default `lrs`.
epoch = 5
lr = 3.98E-06
do_fit('db16', slice(lr))

In [None]:
learn_gen.unfreeze()

In [None]:
learn_gen.lr_find()
learn_gen.recorder.plot(suggestion =True)

In [None]:
# Unfrozen fine-tune at 192 px: last stage on the fourth dataset.
epoch = 5
do_fit('db17', slice(6.31E-07,1e-5), pct_start=0.3)

In [None]:
# Next dataset: inputs from path_lr6, 20% random validation split with a fixed seed.
# NOTE(review): path_lr6 is used before path_lr5 — confirm this ordering is intentional.
src = ImageImageList.from_folder(path_lr6).split_by_rand_pct(0.2, seed=42)


In [None]:
def get_data(bs,size):
    """Return a DataBunch for the path_lr6 / path_hr6 dataset.

    Inputs come from the global `src`; each is labelled with the same-named file
    in `path_hr6`. Input and target get the same augmentation (`tfm_y=True`),
    are resized to `size`, batched by `bs`, and normalised with ImageNet stats.
    The returned DataBunch has `c` set to 3.
    """
    def paired_target(item):
        return path_hr6/item.name

    labelled = src.label_from_func(paired_target)
    augmented = labelled.transform(
        get_transforms(max_zoom=2., max_warp=0.25, max_lighting=0.5),
        size=size, tfm_y=True)
    batched = augmented.databunch(bs=bs, num_workers=0)
    data = batched.normalize(imagenet_stats, do_y=True)
    # Three output channels for the generator.
    data.c = 3
    return data

In [None]:
# path_lr6 dataset at 192 px (batch 10): re-freeze and load the 'db17' checkpoint.
data = get_data(10,192)
learn_gen.data = data
learn_gen.freeze()
gc.collect()
learn_gen.load('db17');

In [None]:
learn_gen.lr_find()
learn_gen.recorder.plot(suggestion =True)

In [None]:
# Frozen stage at 192 px on the path_lr6 dataset. Pass lr explicitly so this stage
# trains at the value set here, independent of do_fit's default `lrs`.
epoch = 5
lr = 2.09E-05
do_fit('db18', slice(lr))

In [None]:
learn_gen.unfreeze()

In [None]:
learn_gen.lr_find()
learn_gen.recorder.plot(suggestion =True)

In [None]:
# Unfrozen fine-tune at 192 px on the path_lr6 dataset.
epoch = 5
do_fit('db19', slice(1.91E-06,1e-4), pct_start=0.3)

In [None]:
# Grow to 256 px (batch 9): re-freeze and load the 'db19' checkpoint.
data = get_data(9,256)
learn_gen.data = data
learn_gen.freeze()
gc.collect()
learn_gen.load('db19');

In [None]:
learn_gen.lr_find()
learn_gen.recorder.plot(suggestion =True)

In [None]:
# Frozen stage at 256 px on the path_lr6 dataset. Pass lr explicitly so this stage
# trains at the value set here, independent of do_fit's default `lrs`.
epoch = 5
lr = 6.31E-07
do_fit('db20', slice(lr))

In [None]:
learn_gen.unfreeze()

In [None]:
learn_gen.lr_find()
learn_gen.recorder.plot(suggestion =True)

In [None]:
# Unfrozen fine-tune at 256 px: last stage on the path_lr6 dataset.
epoch = 5
do_fit('db21', slice(1.58E-06,1e-4), pct_start=0.3)

In [None]:
# Final dataset: inputs from path_lr5, 20% random validation split with a fixed seed.
src = ImageImageList.from_folder(path_lr5).split_by_rand_pct(0.2, seed=42)


In [None]:
def get_data(bs,size):
    """Return a DataBunch for the path_lr5 / path_hr5 dataset.

    Inputs come from the global `src`; each is labelled with the same-named file
    in `path_hr5`. Input and target get the same augmentation (`tfm_y=True`),
    are resized to `size`, batched by `bs`, and normalised with ImageNet stats.
    The returned DataBunch has `c` set to 3.
    """
    def matching_target(item):
        return path_hr5/item.name

    labelled = src.label_from_func(matching_target)
    augmented = labelled.transform(
        get_transforms(max_zoom=2., max_warp=0.25, max_lighting=0.5),
        size=size, tfm_y=True)
    batched = augmented.databunch(bs=bs, num_workers=0)
    data = batched.normalize(imagenet_stats, do_y=True)
    # Three output channels for the generator.
    data.c = 3
    return data

In [None]:
# Final dataset at 320 px (batch 5): re-freeze and load the 'db21' checkpoint.
data = get_data(5,320)
learn_gen.data = data
learn_gen.freeze()
gc.collect()
learn_gen.load('db21');

In [None]:
# Preview a few validation input/target pairs of the final dataset.
data.show_batch(ds_type=DatasetType.Valid, rows=2, figsize=(9,9))

In [None]:
learn_gen.lr_find()
learn_gen.recorder.plot(suggestion =True)

In [None]:
# Frozen stage at 320 px on the final dataset. Pass lr explicitly so this stage
# trains at the value set here, independent of do_fit's default `lrs`.
epoch = 6
lr = 6.31E-07
do_fit('db22', slice(lr))

In [None]:
learn_gen.unfreeze()

In [None]:
learn_gen.lr_find()
learn_gen.recorder.plot(suggestion =True)

In [None]:
# Final unfrozen fine-tune at 320 px; the resulting weights are saved as 'db23'.
epoch = 6
do_fit('db23', slice(3.63E-05,1e-3), pct_start=0.3)

In [None]:
# Visual check of the final model's predictions (30 rows).
learn_gen.show_results(rows=30)