In [1]:
import torch 
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
Learning_rate = 1e-4
Num_epochs = 1000
Batch_size = 200
image_channel = 1
IMAGE_SIZE = 64

In [3]:
class Generator(nn.Module):
    def __init__(self, channel_noise, features_g = [128,128,256]):
        super(Generator,self).__init__()
        self.ip_net = nn.Linear(in_features=channel_noise, out_features=64*10*10) #32*4*4
        self.res1 = residual_block(in_channels=64,out_channels=features_g[0])
        self.up1 = Upsampler(factor=1/3)
        self.res2 = residual_block(in_channels=features_g[0],out_channels=features_g[1])
        self.up2 = Upsampler(factor=2/3)
        self.res3 = residual_block(in_channels=features_g[1],out_channels=features_g[2])
        self.up3 = Upsampler(factor=1.0)
        self.res4 = residual_block(in_channels = features_g[2], out_channels = 1)

    def forward(self, x1,x2):
        x1 = self.ip_net(x1)
        x1 = x1.view(-1,64,10,10)
        x1 = nn.ReLU()(self.res1(x1))
        x1 = self.up1([x1,x2])
        x1 = nn.ReLU()(self.res2(x1))
        x1 = self.up2([x1,x2])
        x1 = nn.ReLU()(self.res3(x1))
        x1 = self.up3([x1,x2])
        x1 = torch.tanh(self.res4(x1))
        return x1

class Upsampler(nn.Module):
    def __init__(self, factor):
        super(Upsampler, self).__init__()
        self.factor = factor
        
    def forward(self,x):
        input_img = x[0]
        size_ip = torch.squeeze(x[1][0])
        size_tup = [int(size_ip[0]*self.factor), int(size_ip[1]*self.factor)]
        return torch.nn.functional.interpolate(input_img, size = size_tup, mode='nearest')
    
    
class residual_block(nn.Module):
    def __init__(self, in_channels, out_channels, activation='relu'):
        super(residual_block,self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.activation = activation
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels,out_channels,kernel_size=5, stride=1, padding=2,bias='False'),
                                   nn.BatchNorm2d(out_channels))
        self.conv2 = nn.Sequential(nn.Conv2d(out_channels,out_channels,kernel_size=5, stride=1, padding=2, bias='False'),
                                   nn.BatchNorm2d(out_channels))
        self.relu = nn.ReLU(inplace=True)
        self.identity = nn.Conv2d(in_channels, out_channels, kernel_size= 1, stride = 1, padding=0)

    def forward(self,x):
        copy = self.relu(self.identity(x))
        y = self.relu(self.conv1(x))
        y = self.conv2(y)
        y += copy
        return self.relu(y)

In [4]:
def gradient_penalty(critic, real, fake, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * alpha + fake * (1 - alpha)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty

In [5]:
class dis_block(nn.Module):
    def __init__(self, in_channels, out_channels, kernel, stride, padding):
        super(dis_block,self).__init__()
        self.dis_net = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel, stride, padding, bias=False),
                                nn.InstanceNorm2d(out_channels, affine=True),
                                nn.LeakyReLU(0.2))
    def forward(self,x):
        return self.dis_net(x)

class Discriminator(nn.Module):
    def __init__(self, channels_img):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(nn.Conv2d(channels_img, 64, kernel_size=5, stride=1, padding=2),
                                nn.LeakyReLU(0.2),
                                dis_block(64,64,5,2,1),
                                 dis_block(64,128,5,2,1),
                                 dis_block(128,128,5,2,1),
                                 nn.AdaptiveAvgPool2d((1,1)),
                                )
        self.lin = nn.Linear(128,1)
        
    def forward(self,x):
        x = self.net(x)
        x = x.view(x.shape[0],-1)
        x = self.lin(x)
        # x = torch.sigmoid(self.lin(x))
#         x = self.lin(x)
        return x
        


In [6]:
def initialize_weights(model):
    # Initializes weights according to the DCGAN paper
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

In [7]:
transforms1 = transforms.Compose(
    [
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize(
           [0.5 for _ in range(image_channel)], [0.5 for _ in range(image_channel)]
       ),
    ]
)


transforms2 = transforms.Compose(
    [
        transforms.Resize((40,50)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(image_channel)], [0.5 for _ in range(image_channel)]
        ),
    ]
)


transforms3 = transforms.Compose(
    [
        transforms.Resize((60,90)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(image_channel)], [0.5 for _ in range(image_channel)]
        ),
    ]
)

In [8]:
dataset1 = datasets.MNIST(root="dataset/", transform=transforms1, download=True)
dataset2 = datasets.MNIST(root="dataset/", transform=transforms2, download=True)
dataset3 = datasets.MNIST(root="dataset/", transform=transforms3, download=True)

In [9]:
loader1 = DataLoader(dataset1, batch_size=Batch_size, shuffle=True)
loader2 = DataLoader(dataset2, batch_size=Batch_size, shuffle=True)
loader3 = DataLoader(dataset3, batch_size=Batch_size, shuffle=True)

In [10]:
gen = Generator(128).to(device)
dis = Discriminator(image_channel).to(device)

initialize_weights(gen)
initialize_weights(dis)

In [11]:
opt_gen = optim.Adam(gen.parameters(), lr = Learning_rate*0.8, betas=(0.0, 0.9))
opt_dis = optim.Adam(dis.parameters(), lr = Learning_rate*0.75, betas=(0.0, 0.9))

In [12]:
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
writer = SummaryWriter(f"logs/graphs")
step = 0

In [13]:
gen.train()
dis.train()

fixed_noise = torch.randn(32, 128).to(device)
# fixed_size = (torch.ones(32,2)*torch.tensor([32,32])).to(device)
for epoch in range(Num_epochs):
    batch_idx = 0
    for (real_1, _),(real_2, _),(real_3, _) in zip(loader1,loader2,loader3):
      for real in [real_1, real_2, real_3]:  
          real = real.to(device)
          cur_batch_size = real.shape[0]

          for _ in range(5):
            noise = torch.randn(Batch_size,128).to(device)
            size = [real.shape[-2], real.shape[-1]]
            size = torch.ones(Batch_size,2)*torch.tensor(size)


            fake = gen(noise,size)

            disc_real = dis(real).reshape(-1)
            disc_fake = dis(fake).reshape(-1)

            gp = gradient_penalty(dis, real, fake, device = device)
            loss_disc = (-(torch.mean(disc_real)-torch.mean(disc_fake))+10*gp)

            dis.zero_grad()
            loss_disc.backward(retain_graph=True)
            opt_dis.step()

          output = dis(fake).reshape(-1)
          loss_gen = -torch.mean(output)

          gen.zero_grad()
          loss_gen.backward()
          opt_gen.step()

          if batch_idx % 100 == 0:
              print(
                  f"Epoch [{epoch}/{Num_epochs}] Batch {batch_idx}/{len(loader1)} \
                    Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
              )

          with torch.no_grad():
              fake = gen(fixed_noise,size[:32].detach())
              img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
              img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
            
              torchvision.utils.save_image(img_grid_real,f'real_{epoch}.jpg')
              torchvision.utils.save_image(img_grid_fake,f'fake_{epoch}.jpg') 
          writer_real.add_image("Real", img_grid_real, global_step=step)
          writer_fake.add_image("Fake", img_grid_fake, global_step=step)
          writer.add_scalar('Loss/Discriminator',loss_disc, step)
          writer.add_scalar('Loss/Generator',loss_gen, step)
          batch_idx+=1
      step +=1
    if (epoch+1)%2==0:
        torch.save(gen, f'generator_{epoch}.pt')
        torch.save(dis, f'discriminator_{epoch}.pt')
      


Epoch [0/1000] Batch 0/300                     Loss D: 3.8464, loss G: 0.0601
Epoch [0/1000] Batch 100/300                     Loss D: -1.0453, loss G: 0.5664
Epoch [0/1000] Batch 200/300                     Loss D: -2.2722, loss G: 0.8590
Epoch [0/1000] Batch 300/300                     Loss D: -2.6493, loss G: 1.1972
Epoch [0/1000] Batch 400/300                     Loss D: -5.1867, loss G: 2.5158
Epoch [0/1000] Batch 500/300                     Loss D: -7.6547, loss G: 3.8163


KeyboardInterrupt: 