In [1]:
#hide
# Mount Google Drive; the next cell changes into a folder under this mount point.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [3]:
#hide
%load_ext autoreload
%autoreload 2

%matplotlib inline
%cd /content/gdrive/My Drive/Colab Notebooks

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
/content/gdrive/My Drive/Colab Notebooks


In [None]:
# Public names exported by the generated cycleganloss module.
__all__ = ["AdaptiveLoss", "CycleGanLossFunc"]

In [None]:
#hide
# !git clone https://github.com/prajwal-suresh13/dl_lib.git

In [None]:
from dl_lib.core.all import *
from dl_lib.cyclegan.cycleganmodel import *
from dl_lib.cyclegan.datadl import*

# Model

In [None]:
#hide
from torchvision.models import vgg16_bn
# Pretrained VGG16-BN convolutional stack, frozen and in eval mode, used below to inspect layer indices.
# `weights=` replaces the deprecated positional `pretrained` flag (torchvision >= 0.13);
# "IMAGENET1K_V1" is the checkpoint that `pretrained=True` loads.
vgg_m = vgg16_bn(weights="IMAGENET1K_V1").features.cuda().eval()
# Freeze every parameter at once; the trailing `;` stops Jupyter from printing the module repr.
vgg_m.requires_grad_(False);

  f"Using {sequence_to_str(tuple(keyword_only_kwargs.keys()), separate_last='and ')} as positional "
Downloading: "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth" to /root/.cache/torch/hub/checkpoints/vgg16_bn-6c64b313.pth


  0%|          | 0.00/528M [00:00<?, ?B/s]

In [None]:
#hide
blocks = [i-1 for i,o in enumerate(list(vgg_m.children())) if isinstance(o,nn.MaxPool2d)]
blocks, [vgg_m[i] for i in blocks]

([5, 12, 22, 32, 42],
 [ReLU(inplace=True),
  ReLU(inplace=True),
  ReLU(inplace=True),
  ReLU(inplace=True),
  ReLU(inplace=True)])

In [None]:
class AdaptiveLoss(nn.Module):
    """Wrap a loss `crit(output, target)` so it can be called with a boolean target.

    `forward(output, True)` compares `output` against an all-ones tensor and
    `forward(output, False)` against an all-zeros tensor. Both are created from
    `output` itself, so they share its shape, dtype and device.

    Args:
        crit: callable invoked as `crit(output, target_tensor, **kwargs)`,
            e.g. `F.mse_loss` or `F.binary_cross_entropy`.
    """

    def __init__(self, crit):
        super().__init__()
        self.crit = crit

    def forward(self, output, target, **kwargs):
        """Return `crit(output, ones_or_zeros, **kwargs)`.

        Args:
            output: tensor of predictions (any rank, including 0-dim).
            target: interpreted as a truth value; truthy selects ones, falsy selects zeros.
            **kwargs: forwarded unchanged to `crit` (e.g. `reduction="sum"`).
        """
        # Pass the shape as a single argument: unpacking `*output.size()` leaves no
        # argument at all for a 0-dim output, which `new_ones`/`new_zeros` reject.
        shape = output.size()
        targ = output.new_ones(shape) if target else output.new_zeros(shape)
        return self.crit(output, targ, **kwargs)

In [None]:
class CycleGanLossFunc(nn.Module):
    """Combined generator loss for CycleGAN training.

    `forward` returns the sum of:

    * identity loss    -- `lambda_idt * (lambda_A * L1(idt_A, real_A) + lambda_B * L1(idt_B, real_B))`;
    * adversarial loss -- `crit(D_A(fake_A), True) + crit(D_B(fake_B), True)`, i.e. the
                          generators are scored on how well they fool both discriminators;
    * cycle loss       -- `lambda_A * L1(cyc_A, real_A) + lambda_B * L1(cyc_B, real_B)`;
    * perceptual loss  -- optional; `lambda_perc * (P(real_A, cyc_A) + P(real_B, cyc_B))`,
                          where `P` is a `PerceptualLoss` built on frozen VGG16-BN features.

    The real batches are not passed to `forward`: call `set_input(real_A, real_B)`
    before each `forward` call.

    Args:
        cyclegan: model exposing discriminators `D_A` and `D_B`.
        lambda_A, lambda_B: weights of the domain-A / domain-B L1 terms.
        lambda_idt: extra multiplier applied to the identity terms.
        lambda_perc: weight of the perceptual term.
        lsgan: use MSE (least-squares GAN) instead of binary cross-entropy for the
            adversarial term.
        perceptual_loss: add the perceptual term (loads VGG16-BN onto CUDA).
        perc_layer_weights: per-layer weights forwarded to `PerceptualLoss`;
            `None` (the default) means `[5, 12, 2]`.
        perc_scale: `scale` argument forwarded to `PerceptualLoss`.
    """

    def __init__(self, cyclegan, lambda_A=10, lambda_B=10, lambda_idt=0.5, lambda_perc=5, lsgan=True,
                 perceptual_loss=True, perc_layer_weights=None, perc_scale=1.):
        super().__init__()
        # Build a fresh list per instance; a list literal as the default would be shared by all instances.
        if perc_layer_weights is None:
            perc_layer_weights = [5, 12, 2]
        self.cyclegan, self.l_A, self.l_B, self.l_idt, self.l_perc = cyclegan, lambda_A, lambda_B, lambda_idt, lambda_perc
        self.perceptual_loss, self.perc_layer_weights, self.perc_scale = perceptual_loss, perc_layer_weights, perc_scale
        self.crit = AdaptiveLoss(F.mse_loss if lsgan else F.binary_cross_entropy)
        if self.perceptual_loss:
            self.perceptual = self._build_perceptual()

    def _build_perceptual(self):
        """Create the PerceptualLoss on a frozen, pretrained VGG16-BN feature extractor."""
        from torchvision.models import vgg16_bn
        # `weights=` replaces the deprecated positional `pretrained` flag (torchvision >= 0.13);
        # "IMAGENET1K_V1" is the checkpoint that `pretrained=True` loads.
        vgg_m = vgg16_bn(weights="IMAGENET1K_V1").features.cuda().eval()
        vgg_m.requires_grad_(False)
        # Index of the layer just before each MaxPool2d (the closing activation of each stage);
        # the slice keeps stages 3-5, one per entry of perc_layer_weights' default.
        blocks = [i - 1 for i, layer in enumerate(vgg_m.children()) if isinstance(layer, nn.MaxPool2d)]
        return PerceptualLoss(vgg_m, blocks[2:5], self.perc_layer_weights, scale=self.perc_scale)

    def set_input(self, real_A, real_B):
        """Store the real domain-A and domain-B batches used by the next `forward` call."""
        self.real_A, self.real_B = real_A, real_B

    def forward(self, output, target):
        """Return the total generator loss for one batch.

        Args:
            output: 6-tuple `(fake_A, fake_B, cyc_A, cyc_B, idt_A, idt_B)` from the model.
            target: unused here; presumably accepted to match a `loss(output, target)`
                training-loop signature.

        The individual terms are also stored on `self` (`id_loss`, `gen_loss`, `cyc_loss`
        and, when the perceptual term is enabled, `perc_lossA`, `perc_lossB`, `perc_loss`),
        where other code can read them.
        """
        fake_A, fake_B, cyc_A, cyc_B, idt_A, idt_B = output

        # Identity loss
        self.id_loss = self.l_idt * (self.l_A * F.l1_loss(idt_A, self.real_A) + self.l_B * F.l1_loss(idt_B, self.real_B))

        # Generator (adversarial) loss: discriminator outputs on fakes are pushed towards "real".
        self.gen_loss = self.crit(self.cyclegan.D_A(fake_A), True) + self.crit(self.cyclegan.D_B(fake_B), True)

        # Cycle-consistency loss
        self.cyc_loss = self.l_A * F.l1_loss(cyc_A, self.real_A)
        self.cyc_loss += self.l_B * F.l1_loss(cyc_B, self.real_B)

        total_loss = self.id_loss + self.gen_loss + self.cyc_loss

        # Perceptual loss between each real batch and its reconstruction
        if self.perceptual_loss:
            self.perc_lossA = self.perceptual(self.real_A, cyc_A)
            self.perc_lossB = self.perceptual(self.real_B, cyc_B)
            self.perc_loss = (self.perc_lossA + self.perc_lossB) * self.l_perc
            total_loss = total_loss + self.perc_loss

        return total_loss

In [4]:
#hide
# Install `fire` (presumably a dependency of notebook2script.py -- TODO confirm), then
# convert this notebook into dl_lib/cyclegan/cycleganloss.py.
!pip install fire
!python dl_lib/notebook2script.py image_colorization/cyclegan/cycleganloss.ipynb dl_lib/cyclegan

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting fire
  Downloading fire-0.4.0.tar.gz (87 kB)
[K     |████████████████████████████████| 87 kB 6.5 MB/s 
Building wheels for collected packages: fire
  Building wheel for fire (setup.py) ... [?25l[?25hdone
  Created wheel for fire: filename=fire-0.4.0-py2.py3-none-any.whl size=115942 sha256=dd6a291ecf1c0a1859ce69f6fae669731c02dbd2be280bee3a258f7e0569e2da
  Stored in directory: /root/.cache/pip/wheels/8a/67/fb/2e8a12fa16661b9d5af1f654bd199366799740a85c64981226
Successfully built fire
Installing collected packages: fire
Successfully installed fire-0.4.0
Converted image_colorization/cyclegan/cycleganloss.ipynb to dl_lib/cyclegan/cycleganloss.py
