# A PyTorch Implementation of SinGAN

In [None]:
from typing import Optional, Tuple
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import io, transforms
import math


In [5]:
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU(0.2) building block.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Side length of the (square) convolution kernel.
        stride: Stride of the convolution.
        padding: Padding added by the convolution on every side.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int,
                 padding: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.norm = nn.BatchNorm2d(num_features=out_channels)
        # In-place activation: overwrites the normalised tensor rather than
        # allocating a new one.
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self,
                x: torch.Tensor) -> torch.Tensor:
        """Apply convolution, batch normalisation and leaky ReLU, in that order."""
        features = self.conv(x)
        features = self.norm(features)
        return self.relu(features)

In [None]:
class Discriminator(nn.Module):
    """Fully convolutional discriminator built from ``ConvBlock`` layers.

    The input is zero-padded by 5 pixels on every side and then passed through
    a head block, ``num_layers`` body blocks and a tail block that reduces the
    feature map to a single channel.

    Args:
        in_channels: Channels of the input image.
        out_channels: Feature width used by the head and body blocks.
        kernel_size: Kernel size of every convolution.
        padding: Padding of every convolution.
        num_layers: Number of blocks in the body.
    """

    def __init__(self,
                 in_channels: int = 3,
                 out_channels: int = 32,
                 kernel_size: int = 3,
                 padding: int = 1,
                 num_layers: int = 3) -> None:
        super().__init__()
        self.head = ConvBlock(in_channels, out_channels, kernel_size, 1, padding)
        # nn.Sequential takes modules as separate positional arguments; handing
        # it a bare generator raises "generator is not a Module subclass", so
        # the generated blocks must be unpacked.
        self.body = nn.Sequential(*(
            ConvBlock(out_channels, out_channels, kernel_size, 1, padding)
            for _ in range(num_layers)
        ))
        # NOTE(review): the tail is a full ConvBlock, so the returned scores
        # pass through BatchNorm and LeakyReLU — confirm this is intended.
        self.tail = ConvBlock(out_channels, 1, kernel_size, 1, padding)
        self.pad = nn.ZeroPad2d(5)

    def forward(self,
                x: torch.Tensor) -> torch.Tensor:
        """Return the single-channel score map for ``x``.

        With the default 3x3 kernels and padding of 1 every convolution keeps
        the spatial size, so the map is 10 pixels larger than ``x`` in height
        and width because of the initial zero padding.
        """
        x = self.pad(x)
        x = self.head(x)
        x = self.body(x)
        return self.tail(x)

In [None]:
class Generator(nn.Module):
    """Residual, fully convolutional generator built from ``ConvBlock`` layers.

    The input is zero-padded by 5 pixels on every side, passed through a head
    block, ``num_layers`` body blocks and a tail (``ConvBlock`` followed by
    ``Tanh``); the result is added to the original input.

    Args:
        in_channels: Channels of the input (and of the output).
        out_channels: Feature width used by the head and body blocks.
        kernel_size: Kernel size of every convolution.
        padding: Padding of every convolution.
        num_layers: Number of blocks in the body.
    """

    def __init__(self,
                 in_channels: int = 3,
                 out_channels: int = 32,
                 kernel_size: int = 3,
                 padding: int = 1,
                 num_layers: int = 3) -> None:
        super().__init__()
        self.head = ConvBlock(in_channels, out_channels, kernel_size, 1, padding)
        # nn.Sequential takes modules as separate positional arguments; handing
        # it a bare generator raises "generator is not a Module subclass", so
        # the generated blocks must be unpacked.
        self.body = nn.Sequential(*(
            ConvBlock(out_channels, out_channels, kernel_size, 1, padding)
            for _ in range(num_layers)
        ))
        self.tail = nn.Sequential(
            ConvBlock(out_channels, in_channels, kernel_size, 1, padding),
            nn.Tanh()
        )
        self.pad = nn.ZeroPad2d(5)

    def forward(self,
                x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the network's residual, at the spatial size of ``x``.

        Raises:
            ValueError: If the convolution stack shrinks the padded input below
                the size of ``x`` (the residual could not be aligned).
        """
        residual = self.tail(self.body(self.head(self.pad(x))))

        # The skip connection must use the *unpadded* input: adding the padded
        # tensor would return an image 10 px larger than the input, with a
        # border that is only zero padding plus network output. Centre-crop the
        # network output back to the input size instead.
        height, width = x.shape[-2], x.shape[-1]
        extra_h = residual.shape[-2] - height
        extra_w = residual.shape[-1] - width
        if extra_h < 0 or extra_w < 0:
            raise ValueError(
                "convolution stack output is smaller than the input; "
                "increase `padding` or reduce `num_layers`/`kernel_size`"
            )
        top, left = extra_h // 2, extra_w // 2
        residual = residual[..., top:top + height, left:left + width]
        return x + residual

In [None]:
# Scratch cell: draw a (3, 4, 5) standard-normal tensor and display its mean.
b = torch.randn((3, 4, 5))
torch.mean(b)

In [8]:
# Resolution pyramid: start at the coarsest side length and grow by
# `scale_factor` per scale until the finest side length is reached.
img_size_min = 25
img_size_max = 250
scale_factor = 4 / 3
# Number of up-scaling steps between the coarsest and finest scale: the
# truncated log base `scale_factor` of the size ratio (8 for these values).
number_of_scales = int(math.log(img_size_max / img_size_min, scale_factor))
# One rounded side length per scale, coarsest first.
size_list = [
    round(img_size_min * scale_factor ** level)
    for level in range(number_of_scales + 1)
]
size_list

[25, 33, 44, 59, 79, 105, 140, 187, 250]

In [9]:
# Preview of a per-scale feature width: starts at 32 and doubles every 4 scales
# (indices 0-3 -> 32, 4-7 -> 64, 8 -> 128).
[32 * pow(2, i // 4) for i in range(9)]

[32, 32, 32, 32, 64, 64, 64, 64, 128]

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Base feature width; it doubles every 4 scales (32, 32, 32, 32, 64, ...).
n_features = 32
# One generator / discriminator pair per scale of the resolution pyramid.
generators = [
    Generator(out_channels=n_features * pow(2, i // 4)).to(device)
    for i in range(number_of_scales + 1)
]
discriminators = [
    Discriminator(out_channels=n_features * pow(2, i // 4)).to(device)
    for i in range(number_of_scales + 1)
]
lr = 5e-4  # Adam learning rate
beta1 = 0.5  # first Adam beta
# Multiplicative LR decay applied by MultiStepLR at its milestone. This was 0,
# which sets the learning rate to exactly 0 at epoch 1600 and turns the
# remaining 400 of the 2000 epochs into no-ops; 0.1 decays it instead.
gamma = 0.1
max_epochs = 2000  # epochs per stage
d_iter = g_iter = 3  # discriminator / generator inner iterations per epoch
# Not referenced by the visible code; presumably a loss weight — NOTE(review): confirm.
alpha = 10.0

In [None]:
# Load the single training image and move it to the training device.
# NOTE(review): `path` looks like a directory rather than an image file;
# read_image needs the path of one image — confirm the file name.
# NOTE(review): read_image is documented to return a uint8 (C, H, W) tensor;
# bilinear F.interpolate needs a floating-point (N, C, H, W) tensor.
path = '../data/'
img = io.read_image(path)
img = img.to(device)

In [None]:
# update stage by stage
# Bilinear F.interpolate needs a floating-point (N, C, H, W) tensor, so bring
# the training image into that layout once, before the stage loop.
real_full = img if img.dim() == 4 else img.unsqueeze(0)
if not torch.is_floating_point(real_full):
    # Integer image (presumably uint8 in [0, 255]) -> float in [-1, 1], the
    # range of the generator's Tanh — NOTE(review): confirm the value range.
    real_full = real_full.float() / 127.5 - 1.0

# Final output of the previous (coarser) stage; feeds the next stage.
prev_output = None

for stage, (generator, discriminator, size) in enumerate(zip(generators, discriminators, size_list)):
    optim_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
    optim_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
    scheduler_g = torch.optim.lr_scheduler.MultiStepLR(optim_g, milestones=[1600], gamma=gamma)
    scheduler_d = torch.optim.lr_scheduler.MultiStepLR(optim_d, milestones=[1600], gamma=gamma)

    # Target image for this stage, resized to `size` x `size`.
    image = F.interpolate(real_full, size, mode='bilinear', align_corners=True)
    if stage == 0:
        # The coarsest stage starts from pure noise.
        x = torch.randn(*image.shape, device=device)
    else:
        # Later stages were fed `None` before, which crashed the generator.
        # NOTE(review): feeding the up-sampled output of the previous stage is
        # an assumption about the intended design — confirm.
        x = F.interpolate(prev_output, size, mode='bilinear', align_corners=True)

    # `max_epochs` is an int; iterate over a range of it.
    for epoch in range(max_epochs):
        # update Discriminator
        for _ in range(d_iter):
            # The discriminator needs an input; score the real image.
            # TODO: no discriminator loss is defined yet, so `y` is unused and
            # optim_d / scheduler_d are not stepped.
            y = discriminator(image)

        # update Generator
        for _ in range(g_iter):
            optim_g.zero_grad()
            fake = generator(x)
            # Assumes generator(x) has the same shape as `image` — NOTE(review).
            loss = F.mse_loss(image, fake)
            loss.backward()
            optim_g.step()

        # One scheduler step per epoch so the milestone at 1600 is reached.
        scheduler_g.step()

    # Keep this stage's result (without gradients) as the next stage's input.
    with torch.no_grad():
        prev_output = generator(x)
            

In [23]:
# Scratch check of bitwise NOT on ints: ~n equals -n - 1, so this is -4.
~ 3

-4

In [25]:
# Scratch check: str.isdigit() returns a bool, and bool is an int subclass,
# so True << 1 evaluates to 2.
"3".isdigit() << 1

2