In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch
from torchvision import models


In [22]:
import torch.nn as nn
import torch.nn.functional as F
import torch
from torchvision import models
import sys

## Factors for discriminator and generator channel changes
factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]


class WSConv2d(nn.Module):          #Weighted Scaled convolutional layers
    '''
    Weighted scaled Conv2d (Equalized Learning Rate - Section 4.1 of Notes)
    '''
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,gain=2):
        super(WSConv2d,self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.scale = (gain/(in_channels*(kernel_size**2)))**0.5
        self.bias = self.conv.bias
        self.conv.bias = None

        # initialize conv layer
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)


    def forward(self,x):
        return self.conv(x*self.scale) + self.bias.view(1,self.bias.shape[0],1,1)


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, norm_layer='PixelNorm', use_pixelnorm=True):
        super(ConvBlock, self).__init__()
        self.use_pn = use_pixelnorm
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()

    def forward(self, x):
        x = self.leaky(self.conv1(x))
        x = self.pn(x) if self.use_pn else x
        x = self.leaky(self.conv2(x))
        x = self.pn(x) if self.use_pn else x
        return x

class PixelNorm(nn.Module):
    def __init__(self):
        super(PixelNorm, self).__init__()
        self.epsilon = 1e-8

    def forward(self, x):
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)


class Generator_ProGAN(nn.Module):
    def __init__(self, z_dim, in_channels, img_channels=3, norm_layer = 'PixelNorm', factors =[]):
        super(Generator_ProGAN, self).__init__()
        # self.norm_layer = getattr(module_norm, norm_layer)/
        # initial takes 1x1 -> 4x4
        self.initial = nn.Sequential(
            PixelNorm(),
            nn.ConvTranspose2d(z_dim, in_channels, 4, 1, 0),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            PixelNorm(),
        )

        self.initial_rgb = WSConv2d(
            in_channels, img_channels, kernel_size=1, stride=1, padding=0
        )
        self.prog_blocks, self.rgb_layers = (
            nn.ModuleList([]),
            nn.ModuleList([self.initial_rgb]),
        )

        for i in range(
            len(factors) - 1
        ):  # -1 to prevent index error because of factors[i+1]
            conv_in_c = int(in_channels * factors[i])
            print('i : ', i, ' | conv_in_c : ', conv_in_c)
            conv_out_c = int(in_channels * factors[i + 1])
            print('i : ', i, ' | conv_out_c : ', conv_out_c)
            self.prog_blocks.append(ConvBlock(conv_in_c, conv_out_c))
            self.rgb_layers.append(
                WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
        # alpha should be scalar within [0, 1], and upscale.shape == generated.shape
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

    def forward(self, x, alpha, steps):     # steps=0 (4x4), steps=1 (8x8), ...
        print('steps : ', steps)
        print('x.shape : ', x.shape)
        out = self.initial(x)   # 4x4
        print('out.shape : ', out.shape)

        if steps == 0:
            return self.initial_rgb(out)

        for step in range(steps):
            print('step:',step)
            upscaled = F.interpolate(out, scale_factor=2, mode="nearest")
            print('upscaled : ', upscaled.shape)
            # print(self.prog_blocks[step])
            
            out = self.prog_blocks[step](upscaled)
            print('out.shape : ', out.shape)

        # The number of channels in upscale will stay the same, while
        # out which has moved through prog_blocks might change. To ensure
        # we can convert both to rgb we use different rgb_layers
        # (steps-1) and steps for upscaled, out respectively

        final_upscaled = self.rgb_layers[steps - 1](upscaled)
        print('final_upscaled : ', final_upscaled.shape)
        final_out = self.rgb_layers[steps](out)
        print('final_out : ', final_out.shape)
        return self.fade_in(alpha, final_upscaled, final_out)



class Discriminator_ProGAN(nn.Module):
    def __init__(self, z_dim, in_channels, img_channels=3, factors =[]):
        super(Discriminator_ProGAN,self).__init__()

        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)
        # Work Backwords form factor
        for i in range(len(factors)-1, 0,-1):
            conv_in = int(in_channels*factors[i])
            conv_out = int(in_channels*factors[i-1])
            self.prog_blocks.append(ConvBlock(conv_in,conv_out,use_pixelnorm=False))
            self.rgb_layers.append(WSConv2d(img_channels,conv_in,kernel_size=1,stride=1,padding=0))
        
        
        # RGB layer for 4x4 input size, to "mirror" the generator initial_rgb
        self.initial_rgb = WSConv2d(
            img_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.rgb_layers.append(self.initial_rgb)

        # DownSampling
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

        # Block for 4x4 input size
        self.final_block = nn.Sequential(
            # +1 to in_channels because we concatenate from MiniBatch std
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(
                in_channels, 1, kernel_size=1, padding=0, stride=1
            ),
        )

    def fade_in(self, alpha, downscaled, out):
        """Used to fade in downscaled using avg pooling and output from CNN"""
        # alpha should be scalar within [0, 1], and upscale.shape == generated.shape
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        batch_statistics = (
            torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        )
        # we take the std for each example (across all channels, and pixels) then we repeat it
        # for a single channel and concatenate it with the image. In this way the discriminator
        # will get information about the variation in the batch/image
        return torch.cat([x, batch_statistics], dim=1)


    def forward(self, x, alpha, steps):
        print('Discriminatpor forward')
        cur_step = len(self.prog_blocks)-steps
        print(cur_step)
        # Convert from rgb as initial step
        print('x: ', x.shape)
        print(self.rgb_layers[cur_step])
        out = self.leaky(self.rgb_layers[cur_step](x))
        
        print('OUT: ', out.shape)

        if steps==0:
            out = self.minibatch_std(out)
            print('step0 out : ', out.shape)
            out = self.final_block(out)
            print('step0 out : ', out.shape)
            return out.view(out.shape[0],-1)

        downscaled = self.leaky(self.rgb_layers[cur_step+1](self.avg_pool(x)))
        print('downscaled ; ', downscaled.shape)
        out = self.avg_pool(self.prog_blocks[cur_step](out))
        print('out ; ', out.shape)
        out = self.fade_in(alpha, downscaled, out)
        print('out ; ', out.shape)

        for step in range(cur_step+1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            print('out ; ', out.shape)
            out = self.avg_pool(out)
            print('out ; ', out.shape)

        out = self.minibatch_std(out)
        print('out ; ', out.shape)
        return self.final_block(out).view(out.shape[0], -1)


In [23]:
Z_DIM = 100
IN_CHANNELS = 256
gen = Generator_ProGAN(Z_DIM, IN_CHANNELS, img_channels=3,factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32])
critic = Discriminator_ProGAN(Z_DIM, IN_CHANNELS, img_channels=3,factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32])
from math import log2

i :  0  | conv_in_c :  256
i :  0  | conv_out_c :  256
i :  1  | conv_in_c :  256
i :  1  | conv_out_c :  256
i :  2  | conv_in_c :  256
i :  2  | conv_out_c :  256
i :  3  | conv_in_c :  256
i :  3  | conv_out_c :  128
i :  4  | conv_in_c :  128
i :  4  | conv_out_c :  64
i :  5  | conv_in_c :  64
i :  5  | conv_out_c :  32
i :  6  | conv_in_c :  32
i :  6  | conv_out_c :  16
i :  7  | conv_in_c :  16
i :  7  | conv_out_c :  8


In [24]:
for img_size in [4, 8, 16, 32, 64, 128, 256, 512, 1024]:
    num_steps = int(log2(img_size / 4))
    print(num_steps, img_size)

0 4
1 8
2 16
3 32
4 64
5 128
6 256
7 512
8 1024


In [25]:
critic

Discriminator_ProGAN(
  (prog_blocks): ModuleList(
    (0): ConvBlock(
      (conv1): WSConv2d(
        (conv): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (conv2): WSConv2d(
        (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (leaky): LeakyReLU(negative_slope=0.2)
      (pn): PixelNorm()
    )
    (1): ConvBlock(
      (conv1): WSConv2d(
        (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (conv2): WSConv2d(
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (leaky): LeakyReLU(negative_slope=0.2)
      (pn): PixelNorm()
    )
    (2): ConvBlock(
      (conv1): WSConv2d(
        (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (conv2): WSConv2d(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)

In [26]:
num_steps = 0
img_size = 4
x = torch.randn((1, Z_DIM, 1, 1))
z = gen(x, 0.5, steps=num_steps)
print(z.shape)
assert z.shape == (1, 3, img_size, img_size)
out = critic(z, alpha=0.5, steps=num_steps)
assert out.shape == (1, 1)
print(f"Success! At img size: {img_size}")

steps :  0
x.shape :  torch.Size([1, 100, 1, 1])
out.shape :  torch.Size([1, 256, 4, 4])
torch.Size([1, 3, 4, 4])
Discriminatpor forward
8
x:  torch.Size([1, 3, 4, 4])
WSConv2d(
  (conv): Conv2d(3, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
OUT:  torch.Size([1, 256, 4, 4])
step0 out :  torch.Size([1, 257, 4, 4])
step0 out :  torch.Size([1, 1, 1, 1])
Success! At img size: 4


In [30]:
for i in range(2):
    print(i)

0
1


In [27]:
def init_obj(self, name, module, *args, **kwargs):
    """
    Finds a function handle with the name given as 'type' in config, and returns the
    instance initialized with corresponding arguments given.

    `object = config.init_obj('name', module, a, b=1)`
    is equivalent to
    `object = module.name(a, b=1)`
    """
    print(module)
    module_name = self[name]['type']
    print('module name : ',module_name)
    module_args = dict(self[name]['args'])

    ## Change 'Beta' type to TUPLE for ADAM OPTIMIZER
    if 'optimizer' in module_name and 'betas' in module_args.keys():
        module_args['betas'] = tuple(module_args['betas'])

    print(module_args)
    assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
    module_args.update(kwargs)
    print(module_args)
    return getattr(module, module_name)(*args, **module_args)


In [33]:
from data_loader import KNUskinDataLoader_ProGAN

ModuleNotFoundError: No module named 'data_loader'

In [62]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [63]:
model = ResNet2d(ResBlk2d, inplanes= 64, kernel_size=3, stride=2, num_classes=2, dropout=0.5).to(device)
x = torch.randn(3, 3, 224, 224).to(device)

-------------------------------------------
block id 0
kernel_size =  3
inplanes =  64
outplanes =  64
stride =  2
-------------------------------------------
-------------------------------------------
block id 1
kernel_size =  3
inplanes =  64
outplanes =  128
stride =  2
-------------------------------------------
-------------------------------------------
block id 2
kernel_size =  3
inplanes =  128
outplanes =  256
stride =  2
-------------------------------------------
-------------------------------------------
block id 3
kernel_size =  3
inplanes =  256
outplanes =  512
stride =  2
-------------------------------------------


In [65]:
output = model(x)

x :  torch.Size([3, 3, 224, 224]) aux :  torch.Size([3, 150528])
conv1 :  torch.Size([3, 64, 112, 112])
-------------------
layer0 torch.Size([3, 64, 56, 56])
identity =  torch.Size([3, 64, 56, 56])
X_shape =  torch.Size([3, 64, 56, 56])
conv1 shape =  torch.Size([3, 64, 28, 28])
conv2 shape =  torch.Size([3, 64, 28, 28])
 out =  torch.Size([3, 64, 28, 28])
downsample =  Conv2d(64, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
identity2 =  torch.Size([3, 64, 28, 28])
torch.Size([3, 64, 28, 28])
dropout out =  torch.Size([3, 64, 28, 28])
identity =  torch.Size([3, 64, 28, 28])
X_shape =  torch.Size([3, 64, 28, 28])
conv1 shape =  torch.Size([3, 64, 28, 28])
conv2 shape =  torch.Size([3, 64, 28, 28])
 out =  torch.Size([3, 64, 28, 28])
downsample =  None
identity2 =  torch.Size([3, 64, 28, 28])
torch.Size([3, 64, 28, 28])
dropout out =  torch.Size([3, 64, 28, 28])
-------------------
layer1 torch.Size([3, 64, 28, 28])
identity =  torch.Size([3, 64, 28, 28])
X_shape =  torch.Size([3,