# GANILLA model

> Defines the GANILLA model architecture.

In [None]:
# default_exp models.ganilla

In [None]:
#export
from fastai.vision.all import *
from fastai.basics import *
from typing import List
from fastai.vision.gan import *
from upit.models.cyclegan import *

In [None]:
#hide
from nbdev.showdoc import *

We use the generator that was introduced in the [GANILLA paper](https://arxiv.org/abs/2002.05638).

## Generator

In [None]:
#export
class BasicBlock_Ganilla(nn.Module):
    """Two-convolution block whose skip branch is concatenated, not added.

    The main branch applies two reflection-padded 3x3 convolutions, each
    followed by instance normalisation (ReLU and optional dropout sit between
    them). The skip branch is either the identity or a 1x1 projection. Both
    branches are concatenated along the channel axis and fused back to the
    block's output width by a final reflection-padded 3x3 convolution.

    Arguments:
        in_planes: number of channels of the incoming tensor.
        planes: number of channels produced by the block.
        use_dropout: dropout probability handed to `nn.Dropout`; a falsy value
            means no dropout layer is created.
        stride: stride of the first convolution (and of the projection).
    """
    expansion = 1

    def __init__(self, in_planes, planes, use_dropout, stride=1):
        super().__init__()
        projected = self.expansion * planes

        # Main branch, first half.
        self.rp1 = nn.ReflectionPad2d(1)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride,
                               padding=0, bias=False)
        self.bn1 = nn.InstanceNorm2d(planes)

        # The dropout layer only exists when a non-zero probability was given.
        self.use_dropout = use_dropout
        if use_dropout:
            self.dropout = nn.Dropout(use_dropout)

        # Main branch, second half.
        self.rp2 = nn.ReflectionPad2d(1)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=0, bias=False)
        self.bn2 = nn.InstanceNorm2d(planes)
        self.out_planes = planes

        # Skip branch: a projection is needed whenever the main branch changes
        # the resolution or the channel count; otherwise it is the identity.
        if stride != 1 or in_planes != projected:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, projected, kernel_size=1, stride=stride, bias=False),
                nn.InstanceNorm2d(projected),
            )
            fuse_in, fuse_out = projected * 2, projected
        else:
            self.shortcut = nn.Sequential()
            fuse_in, fuse_out = planes * 2, planes

        # Fuses the concatenated (main, skip) tensor back to the output width.
        self.final_conv = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(fuse_in, fuse_out, kernel_size=3, stride=1, padding=0, bias=False),
            nn.InstanceNorm2d(fuse_out),
        )

    def forward(self, x):
        """Run both branches on `x`, concatenate them and fuse the result."""
        main = self.rp1(x)
        main = self.conv1(main)
        main = self.bn1(main)
        main = F.relu(main)
        if self.use_dropout:
            main = self.dropout(main)
        main = self.rp2(main)
        main = self.conv2(main)
        main = self.bn2(main)

        skip = self.shortcut(x)
        fused = self.final_conv(torch.cat((main, skip), 1))
        return F.relu(fused)

In [None]:
#export
class PyramidFeatures(nn.Module):
    """Top-down feature pyramid that merges four encoder feature maps into one.

    Each input level is projected to `feature_size` channels by a 1x1
    convolution and scaled by its entry in `fpn_weights`. Starting from the
    coarsest level (C5), the running sum is upsampled by 2 and added to the
    next finer level. The final sum is upsampled once more and reduced to
    `feature_size / 2` channels by a reflection-padded 3x3 convolution.

    Arguments:
        C2_size, C3_size, C4_size, C5_size: channel counts of the four inputs.
        fpn_weights: indexable with four weights, applied in the order
            C5, C4, C3, C2.
        feature_size: channel width used inside the pyramid.
    """

    def __init__(self, C2_size, C3_size, C4_size, C5_size, fpn_weights, feature_size=128):
        super().__init__()

        self.sum_weights = fpn_weights

        def lateral(channels):
            # 1x1 projection of an encoder level to the pyramid width.
            return nn.Conv2d(channels, feature_size, kernel_size=1, stride=1, padding=0)

        def upsample():
            return nn.Upsample(scale_factor=2, mode='nearest')

        # Coarsest level first; each level owns a lateral conv and an upsampler.
        self.P5_1 = lateral(C5_size)
        self.P5_upsampled = upsample()

        self.P4_1 = lateral(C4_size)
        self.P4_upsampled = upsample()

        self.P3_1 = lateral(C3_size)
        self.P3_upsampled = upsample()

        self.P2_1 = lateral(C2_size)
        self.P2_upsampled = upsample()

        # Output head: pad by 1 so the 3x3 conv keeps the spatial size, and
        # halve the channel count.
        self.rp4 = nn.ReflectionPad2d(1)
        self.P2_2 = nn.Conv2d(int(feature_size), int(feature_size / 2),
                              kernel_size=3, stride=1, padding=0)

    def forward(self, inputs):
        """Merge `inputs = (C2, C3, C4, C5)` (finest to coarsest) into one map."""
        c2, c3, c4, c5 = inputs
        w = self.sum_weights

        top = self.P5_1(c5) * w[0]
        top_up = self.P5_upsampled(top)

        level4 = self.P4_1(c4) * w[1]
        level4 = top_up + level4
        level4_up = self.P4_upsampled(level4)

        level3 = self.P3_1(c3) * w[2]
        level3 = level3 + level4_up
        level3_up = self.P3_upsampled(level3)

        level2 = self.P2_1(c2) * w[3]
        # NOTE(review): the C2 lateral is scaled a second time, by w[2], so its
        # effective weight is w[3] * w[2]. This looks unintended but is kept
        # as-is because it changes outputs for non-default weights - confirm.
        level2 = level2 * w[2] + level3_up

        out = self.P2_upsampled(level2)
        out = self.rp4(out)
        return self.P2_2(out)

In [None]:
#export
class ResNet(nn.Module):
    """GANILLA generator: a ResNet-style encoder followed by a feature-pyramid decoder.

    The input goes through a 7x7 stem convolution and a stride-2 max-pool,
    then through four stages of `block`s. The four stage outputs are merged by
    `PyramidFeatures`, and a final 7x7 convolution followed by `tanh` produces
    the output image.

    Arguments:
        input_nc: number of channels of the input image.
        output_nc: number of channels of the generated image.
        ngf: number of filters produced by the stem convolution.
        use_dropout: forwarded unchanged to every block.
        fpn_weights: four weights forwarded to `PyramidFeatures`.
        block: block class; only `BasicBlock_Ganilla` is supported.
        layers: number of blocks in each of the four stages.

    Raises:
        NotImplementedError: if `block` is not `BasicBlock_Ganilla`.
    """

    def __init__(self, input_nc, output_nc, ngf, use_dropout, fpn_weights, block, layers):
        super().__init__()
        # Running channel count, read and updated by `_make_layer_ganilla`.
        self.inplanes = ngf

        # first conv
        # A 7x7 kernel needs 3 pixels of padding on each side to keep H and W.
        # The padding must not depend on the number of channels, otherwise the
        # spatial size changes whenever `input_nc != 3`.
        self.pad1 = nn.ReflectionPad2d(3)
        self.conv1 = nn.Conv2d(input_nc, ngf, kernel_size=7, stride=1, padding=0, bias=True)
        self.in1 = nn.InstanceNorm2d(ngf)
        self.relu = nn.ReLU(inplace=True)
        self.pad2 = nn.ReflectionPad2d(1)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)

        # Output layer
        # Same reasoning as `pad1`: 3 pixels of padding for the 7x7 kernel.
        self.pad3 = nn.ReflectionPad2d(3)
        # 64 input channels = half of the default `feature_size` (128) of
        # `PyramidFeatures`, which is what the pyramid outputs.
        self.conv2 = nn.Conv2d(64, output_nc, 7)
        self.tanh = nn.Tanh()

        if block != BasicBlock_Ganilla:
            # Library code must not terminate the interpreter; raise instead.
            raise NotImplementedError("This block type is not supported")

        # residuals
        self.layer1 = self._make_layer_ganilla(block, 64, layers[0], use_dropout, stride=1)
        self.layer2 = self._make_layer_ganilla(block, 128, layers[1], use_dropout, stride=2)
        self.layer3 = self._make_layer_ganilla(block, 128, layers[2], use_dropout, stride=2)
        self.layer4 = self._make_layer_ganilla(block, 256, layers[3], use_dropout, stride=2)

        # Channel count of each stage, read from the last block of the stage.
        fpn_sizes = [stage[-1].conv2.out_channels
                     for stage in (self.layer1, self.layer2, self.layer3, self.layer4)]

        self.fpn = PyramidFeatures(fpn_sizes[0], fpn_sizes[1], fpn_sizes[2], fpn_sizes[3], fpn_weights)

    def _make_layer_ganilla(self, block, planes, blocks, use_dropout, stride=1):
        """Build one stage of `blocks` blocks; only the first one uses `stride`."""
        strides = [stride] + [1] * (blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.inplanes, planes, use_dropout, block_stride))
            # The next block consumes what this one produces.
            self.inplanes = planes * block.expansion
        return nn.Sequential(*layers)

    def freeze_bn(self):
        '''Freeze BatchNorm layers.

        Puts every `nn.BatchNorm2d` submodule in eval mode. The layers created
        by this class itself are `nn.InstanceNorm2d`, which are not affected.
        '''
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.eval()

    def forward(self, inputs):
        """Translate a batch of images; the output passes through `tanh`."""
        # Stem: 7x7 conv, instance norm, ReLU, then a stride-2 max-pool.
        x = self.pad1(inputs)
        x = self.conv1(x)
        x = self.in1(x)
        x = self.relu(x)
        x = self.pad2(x)
        x = self.maxpool(x)

        # Encoder stages.
        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        out = self.fpn([x1, x2, x3, x4])  # use all resnet layers

        # Output head.
        out = self.pad3(out)
        out = self.conv2(out)
        out = self.tanh(out)

        return out

In [None]:
#export
def init_weights(net, init_type='normal', gain=0.02):
    """Re-initialise the weights of `net` in place.

    Convolution and linear layers (matched by class name) get their weight
    drawn according to `init_type` and their bias, if any, set to zero.
    `BatchNorm2d` layers get weight ~ N(1, gain) and bias 0; layers built with
    `affine=False` have no weight or bias and are left untouched.

    Arguments:
        net: module to initialise; visited recursively through `net.apply`.
        init_type: one of 'normal', 'xavier', 'kaiming' or 'orthogonal'.
        gain: standard deviation for 'normal', gain for 'xavier' and
            'orthogonal'; not used by 'kaiming'.

    Raises:
        NotImplementedError: if `init_type` is not one of the supported names
            and `net` contains a convolution or linear layer.
    """
    def init_func(m):
        classname = m.__class__.__name__
        # A parameter registered as None (e.g. `bias=False`, `affine=False`)
        # still satisfies `hasattr`, so test the value itself.
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if weight is not None and ('Conv' in classname or 'Linear' in classname):
            if init_type == 'normal':
                torch.nn.init.normal_(weight.data, 0.0, gain)
            elif init_type == 'xavier':
                torch.nn.init.xavier_normal_(weight.data, gain=gain)
            elif init_type == 'kaiming':
                torch.nn.init.kaiming_normal_(weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                torch.nn.init.orthogonal_(weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if bias is not None:
                torch.nn.init.constant_(bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            if weight is not None:
                torch.nn.init.normal_(weight.data, 1.0, gain)
            if bias is not None:
                torch.nn.init.constant_(bias.data, 0.0)

    net.apply(init_func)

In [None]:
#export
def ganilla_generator(input_nc, output_nc, ngf, drop, fpn_weights=[1.0, 1.0, 1.0, 1.0], init_type='normal', gain=0.02, **kwargs):
    """Constructs a ResNet-18 GANILLA generator.

    Builds a `ResNet` made of `BasicBlock_Ganilla` blocks with two blocks per
    stage, then initialises its weights with `init_weights`.

    Arguments:
        input_nc: number of channels of the input image.
        output_nc: number of channels of the generated image.
        ngf: number of filters of the first convolution.
        drop: dropout value forwarded to every block.
        fpn_weights: four weights for the feature pyramid (never modified here).
        init_type: initialisation scheme passed to `init_weights`.
        gain: gain passed to `init_weights`.
        **kwargs: forwarded to the `ResNet` constructor.

    Returns:
        The initialised generator module.
    """
    model = ResNet(input_nc, output_nc, ngf, drop, fpn_weights, BasicBlock_Ganilla, [2, 2, 2, 2], **kwargs)
    # Honour the caller's choice; 'normal' used to be hard-coded here, which
    # silently ignored the `init_type` argument.
    init_weights(model, init_type=init_type, gain=gain)
    return model

Let's test for a few things:
1. The generator can indeed be initialized correctly
2. A random image can be passed into the model successfully with the correct size output

First let's create a random batch:

In [None]:
# Dummy batch: 4 random tensors shaped like 3-channel 256x256 images.
img1 = torch.randn(4,3,256,256)

In [None]:
# Generator with 3 input channels, 3 output channels, ngf=64 and drop=0.5.
m = ganilla_generator(3,3,64,0.5)
# Only the output shape is of interest, so skip gradient tracking.
with torch.no_grad():
    out1 = m(img1)
# Last expression of the cell: its value is shown as the cell output.
out1.shape

torch.Size([4, 3, 256, 256])

## Full model

We group two discriminators and two generators in a single model, then a `Callback` (defined in `03_train.cyclegan.ipynb`) will take care of training them properly. The discriminator and training loop are the same as in CycleGAN.

In [None]:
#export
class GANILLA(nn.Module):
    """
    GANILLA model. \n
    When called with a batch of real images from each of the two domains, returns the fake images generated for the opposite domains. 
    It also returns the identity images, obtained by feeding each generator an image that already belongs to its output domain (needed for identity loss).

    Attributes: \n
    `G_A` (`nn.Module`): takes real input B and generates fake input A \n
    `G_B` (`nn.Module`): takes real input A and generates fake input B \n
    `D_A` (`nn.Module`): trained to make the difference between real input A and fake input A \n
    `D_B` (`nn.Module`): trained to make the difference between real input B and fake input B \n
    """
    def __init__(self, ch_in:int=3, ch_out:int=3, n_features:int=64, disc_layers:int=3, lsgan:bool=True, 
                 drop:float=0., norm_layer:nn.Module=None, fpn_weights:list=[1.0, 1.0, 1.0, 1.0], init_type:str='normal', gain:float=0.02,**kwargs):
        """
        Constructor for GANILLA model.
        
        Arguments: \n
        `ch_in` (`int`): Number of input channels (default=3) \n
        `ch_out` (`int`): Number of output channels (default=3) \n
        `n_features` (`int`): Number of features, passed to both the discriminators and the generators (default=64) \n
        `disc_layers` (`int`): Number of discriminator layers (default=3) \n
        `lsgan` (`bool`): LSGAN training objective (output unnormalized float) or not? (default=True) \n
        `drop` (`float`): Level of dropout (default=0) \n
        `norm_layer` (`nn.Module`): Type of normalization layer to use in the discriminator (default=None) \n
        `fpn_weights` (`list`): Weights for feature pyramid network (default=[1.0, 1.0, 1.0, 1.0]) \n
        `init_type` (`str`): Type of initialization (default='normal') \n
        `gain` (`float`): Gain for initialization (default=0.02) \n
        `**kwargs`: Extra keyword arguments forwarded to `ganilla_generator`
        """
        super().__init__()

        # With the LSGAN objective the discriminators are built without a
        # final sigmoid.
        use_sigmoid = not lsgan

        # Discriminators are created first, then the generators.
        self.D_A = discriminator(ch_in, n_features, disc_layers, norm_layer, sigmoid=use_sigmoid)
        self.D_B = discriminator(ch_in, n_features, disc_layers, norm_layer, sigmoid=use_sigmoid)

        # Both generators share the same configuration.
        generator_args = (ch_in, ch_out, n_features, drop, fpn_weights, init_type, gain)
        self.G_A = ganilla_generator(*generator_args, **kwargs)
        self.G_B = ganilla_generator(*generator_args, **kwargs)

    def forward(self, input):
        """Forward function for GANILLA model. The input is a tuple of a batch of real images from both domains A and B."""
        real_A, real_B = input

        # Translations into the opposite domain.
        fake_A = self.G_A(real_B)
        fake_B = self.G_B(real_A)

        # Each generator applied to its own output domain, for the identity loss.
        idt_A = self.G_A(real_A)
        idt_B = self.G_B(real_B)

        return [fake_A, fake_B, idt_A, idt_B]

In [None]:
# Render the class-level documentation of GANILLA as a level-3 heading.
show_doc(GANILLA,title_level=3)

<h3 id="GANILLA" class="doc_header"><code>class</code> <code>GANILLA</code><a href="" class="source_link" style="float:right">[source]</a></h3>

> <code>GANILLA</code>(**`ch_in`**:`int`=*`3`*, **`ch_out`**:`int`=*`3`*, **`n_features`**:`int`=*`64`*, **`disc_layers`**:`int`=*`3`*, **`lsgan`**:`bool`=*`True`*, **`drop`**:`float`=*`0.0`*, **`norm_layer`**:`Module`=*`None`*, **`fpn_weights`**:`list`=*`[1.0, 1.0, 1.0, 1.0]`*, **`init_type`**:`str`=*`'normal'`*, **`gain`**:`float`=*`0.02`*, **\*\*`kwargs`**) :: `Module`

GANILLA model. 

When called, takes in input batch of real images from both domains and outputs fake images for the opposite domains (with the generators). 
Also outputs identity images after passing the images into generators that outputs its domain type (needed for identity loss).

Attributes: 

`G_A` (`nn.Module`): takes real input B and generates fake input A 

`G_B` (`nn.Module`): takes real input A and generates fake input B 

`D_A` (`nn.Module`): trained to make the difference between real input A and fake input A 

`D_B` (`nn.Module`): trained to make the difference between real input B and fake input B 

In [None]:
# Render the constructor documentation (signature and argument list).
show_doc(GANILLA.__init__)

<h4 id="GANILLA.__init__" class="doc_header"><code>GANILLA.__init__</code><a href="__main__.py#L14" class="source_link" style="float:right">[source]</a></h4>

> <code>GANILLA.__init__</code>(**`ch_in`**:`int`=*`3`*, **`ch_out`**:`int`=*`3`*, **`n_features`**:`int`=*`64`*, **`disc_layers`**:`int`=*`3`*, **`lsgan`**:`bool`=*`True`*, **`drop`**:`float`=*`0.0`*, **`norm_layer`**:`Module`=*`None`*, **`fpn_weights`**:`list`=*`[1.0, 1.0, 1.0, 1.0]`*, **`init_type`**:`str`=*`'normal'`*, **`gain`**:`float`=*`0.02`*, **\*\*`kwargs`**)

Constructor for GANILLA model.

Arguments: 

`ch_in` (`int`): Number of input channels (default=3) 

`ch_out` (`int`): Number of output channels (default=3) 

`n_features` (`int`): Number of input features (default=64) 

`disc_layers` (`int`): Number of discriminator layers (default=3) 

`lsgan` (`bool`): LSGAN training objective (output unnormalized float) or not? (default=True) 

`drop` (`float`): Level of dropout (default=0) 

`norm_layer` (`nn.Module`): Type of normalization layer to use in the discriminator (default=None)
`fpn_weights` (`list`): Weights for feature pyramid network (default=[1.0, 1.0, 1.0, 1.0]) 

`init_type` (`str`): Type of initialization (default='normal') 

`gain` (`float`): Gain for initialization (default=0.02)

In [None]:
# Render the documentation of the forward pass.
show_doc(GANILLA.forward)

<h4 id="GANILLA.forward" class="doc_header"><code>GANILLA.forward</code><a href="__main__.py#L43" class="source_link" style="float:right">[source]</a></h4>

> <code>GANILLA.forward</code>(**`input`**)

Forward function for CycleGAN model. The input is a tuple of a batch of real images from both domains A and B.

### Quick model tests

Again, let's check that the model can be called successfully and outputs the correct shapes.


In [None]:
# Full model with default arguments, plus one random batch per domain
# (4 tensors of shape 3x256x256 each).
ganilla_model = GANILLA()
img1 = torch.randn(4,3,256,256)
img2 = torch.randn(4,3,256,256)

In [None]:
%%time
# One forward pass on the (domain A, domain B) pair, without gradient tracking.
with torch.no_grad(): ganilla_output = ganilla_model((img1,img2))

CPU times: user 18.4 s, sys: 3.26 s, total: 21.6 s
Wall time: 21.6 s


In [None]:
# The model returns four batches: fake_A, fake_B, idt_A and idt_B.
test_eq(len(ganilla_output),4)
# Every returned batch must have the same shape as the input batches.
for output_batch in ganilla_output:
    test_eq(output_batch.shape,img1.shape)

In [None]:
#hide
from nbdev.export import notebook2script
# Convert the project's notebooks (see the "Converted ..." lines printed below).
notebook2script()

Converted 01_models.cyclegan.ipynb.
Converted 01b_models.junyanz.ipynb.
Converted 02_data.unpaired.ipynb.
Converted 03_train.cyclegan.ipynb.
Converted 04_inference.cyclegan.ipynb.
Converted 05_metrics.ipynb.
Converted 06_tracking.wandb.ipynb.
Converted 07_models.dualgan.ipynb.
Converted 08_train.dualgan.ipynb.
Converted 09_models.ganilla.ipynb.
Converted index.ipynb.
