<p style="text-align: center; font-size:50px;">About SAGAN</p>

#### This is an attempt to replicate a Self Attention GAN, as introduced from this [paper](http://proceedings.mlr.press/v97/zhang19d/zhang19d.pdf)

<p style="text-align: center; font-size:30px;">Data</p>

In [46]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CelebA

device = torch.device('mps')

# Utility function to denormalize the images (since normalization would be applied to them)
def denorm(x):
    out = (x + 1) / 2
    return out.clamp_(0, 1)
  
# Fix a random latent input for samples
batch_size = 64
fixed_z = torch.randn(batch_size, 100).to(device)

# Define data transformer
img_transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])

# Read data and transform
dataset = CelebA(root='/Users/kimhyunbin/Documents/Python/My own project (Python)/PyTorch_Guide/data', download=True, split='all', transform=img_transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

Files already downloaded and verified


<p style="text-align: center; font-size:30px;">Model Architecture</p>

#### Constructing the self-attention module first.
#### Credits to [this](https://towardsdatascience.com/building-your-own-self-attention-gans-e8c9b9fe8e51) website.

![self-attention-image](https://miro.medium.com/v2/resize:fit:720/format:webp/1*GSWQYI3ZfYe-MlKQGjTFWA.png)

In [47]:
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class Self_Attn(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        
        # Construct the conv layers
        self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
        self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
        self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
        
        # Initialize gamma as 0
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax  = nn.Softmax(dim=-1)
        
    def forward(self,x):
        """
            inputs :
                x : input feature maps( Batch_size * Channels * Width * Height ) 
            returns :
                out : self attention value + input feature 
                attention: B * N * N (N is Width*Height)
        """
        m_batchsize,C,width ,height = x.size()
        
        proj_query  = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0,2,1) # batch * (width*height) * channels
        proj_key =  self.key_conv(x).view(m_batchsize, -1, width*height) # batch * channels * (width*height)
        energy =  torch.bmm(proj_query, proj_key) # batch matrix-matrix product
        # Read more about bmm function here (https://pytorch.org/docs/stable/generated/torch.bmm.html)
        
        attention = self.softmax(energy) # B * N * N
        proj_value = self.value_conv(x).view(m_batchsize, -1, width*height) # B * C * N
        out = torch.bmm(proj_value, attention.permute(0,2,1)) # batch matrix-matrix product
        out = out.view(m_batchsize,C,width,height) # B * C * W * H
        
        # Add attention weights onto input
        out = self.gamma*out + x
        return out, attention

#### Building Discriminator and Generator

![image](https://miro.medium.com/v2/resize:fit:720/format:webp/1*IZ9lvFpfDqyeE2OBmC0j4A.png)

#### Be aware that in the picture above, it is dealing with only 1 color channels. 
#### However, in my case, I am using the CelebA dataset which contains 3 color channels. 
#### This would require some changes in the architecture so please take a look at your own code and make sure the architecture matches your dataset.
#### If not, you will definitely run into errors when training.

In [48]:
from torch.nn.utils import spectral_norm

class Generator(nn.Module):

    def __init__(self, batch_size=64, attn=True, image_size=64, z_dim=100, conv_dim=64):
        super().__init__()
        self.attn = attn
        
        # Layer 1 turn 100 dims -> 512 dims, size 1 -> 4
        layer1 = []
        layer1.append(spectral_norm(nn.ConvTranspose2d(in_channels = z_dim, out_channels = conv_dim*8, kernel_size = 4)))
        layer1.append(nn.BatchNorm2d(conv_dim*8))
        layer1.append(nn.ReLU())
        self.l1 = nn.Sequential(*layer1)
        
        # Layer 2 turn 512 dims -> 256 dims, size 4 -> 8
        layer2 = []
        layer2.append(spectral_norm(nn.ConvTranspose2d(in_channels = conv_dim*8, out_channels = conv_dim*4, 
                                                      kernel_size = 4, stride = 2, padding = 1)))
        layer2.append(nn.BatchNorm2d(conv_dim*4))
        layer2.append(nn.ReLU())
        self.l2 = nn.Sequential(*layer2)
        
        # Layer 3 turn 256 dims -> 128 dims, size 8 -> 16
        layer3 = []
        layer3.append(spectral_norm(nn.ConvTranspose2d(in_channels = conv_dim*4, out_channels = conv_dim*2, 
                                                      kernel_size = 4, stride = 2, padding = 1)))
        layer3.append(nn.BatchNorm2d(conv_dim*2))
        layer3.append(nn.ReLU())
        self.l3 = nn.Sequential(*layer3)

        # Attn1 layer turn 128 dims -> 128 dims
        self.attn1 = Self_Attn(conv_dim*2)
        
        # Layer 4 turn 128 dims -> 64 dims, size 16 -> 32
        layer4 = []
        layer4.append(spectral_norm(nn.ConvTranspose2d(in_channels = conv_dim*2, out_channels = conv_dim, 
                                                      kernel_size = 4, stride = 2, padding = 1)))
        layer4.append(nn.BatchNorm2d(conv_dim))
        layer4.append(nn.ReLU())
        self.l4 = nn.Sequential(*layer4)
        
        # Attn2 layer turn 64 dims -> 64 dims
        self.attn2 = Self_Attn(conv_dim)
        
        # Layer 5 turn 64 dims -> 3 dims, size 32 -> 64
        layer5 = []
        layer5.append(nn.ConvTranspose2d(conv_dim, 3, 4, 2, 1))
        layer5.append(nn.Tanh())
        self.l5 = nn.Sequential(*layer5)
        

    def forward(self, z):
        # z is the input random matrix for generator
        z = z.view(z.size(0), z.size(1), 1, 1)
        out=self.l1(z)
        out=self.l2(out)
        out=self.l3(out)
        if self.attn == True:
            out,_ = self.attn1(out)
        out=self.l4(out)
        if self.attn == True:
            out,_ = self.attn2(out)
        out=self.l5(out)

        return out


In [49]:
class Discriminator(nn.Module):
    
    def __init__(self, batch_size=64, attn=True, image_size=64, conv_dim=64):
        super().__init__()
        self.attn = attn
        
        # Layer 1 turn 3 dims -> 64 dims, size 64 -> 32
        layer1 = []
        layer1.append(spectral_norm(nn.Conv2d(3, conv_dim, 4, 2, 1)))
        layer1.append(nn.LeakyReLU(0.1))
        curr_dim = conv_dim
        self.l1 = nn.Sequential(*layer1)
        
        # Layer 2 turn 64 dims -> 128 dims, size 32 -> 16
        layer2 = []
        layer2.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
        layer2.append(nn.LeakyReLU(0.1))
        curr_dim = curr_dim * 2
        self.l2 = nn.Sequential(*layer2)
        
        # Layer 3 turn 128 dims -> 256 dims, size 16 -> 8
        layer3 = []
        layer3.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
        layer3.append(nn.LeakyReLU(0.1))
        curr_dim = curr_dim * 2
        self.l3 = nn.Sequential(*layer3)
        
        # Attn1 layer remains the same dim and size
        self.attn1 = Self_Attn(curr_dim)
        
        # Layer 4 turn 256 dims -> 512 dims, size 8 -> 4
        layer4 = []
        layer4.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
        layer4.append(nn.LeakyReLU(0.1))
        curr_dim = curr_dim * 2
        self.l4 = nn.Sequential(*layer4)
        
        # Attn2 layer remains the same dim and size
        self.attn2 = Self_Attn(curr_dim)
        
        # Layer 5 turn 512 dims -> 1 dims, size 4 -> 1
        layer5 = []
        layer5.append(nn.Conv2d(curr_dim, 1, 4, 1, 0))
        self.l5 = nn.Sequential(*layer5)

    def forward(self, x):
        out = self.l1(x)
        out = self.l2(out)
        out = self.l3(out)
        if self.attn == True:
            out,_ = self.attn1(out)
        out = self.l4(out)
        if self.attn == True:
            out,_ = self.attn2(out)
        out = self.l5(out)

        return out.squeeze()

<p style="text-align: center; font-size:30px;">Training</p>

In [51]:
import torch.optim as optim
import torchvision.utils as vutils
from torchvision.utils import save_image
import os

# Training Loop
D = Discriminator().to(device)
G = Generator().to(device)

criterion = nn.BCELoss()

# Creating batch of latent vectors to visualize the progression of the generator
fixed_noise = fixed_z

real_label = 1.
fake_label = 0. 
lr = 0.0002
beta1 = 0.5

optimizerD = optim.Adam(D.parameters(), lr = lr, betas = (beta1, 0.999))
optimizerG = optim.Adam(G.parameters(), lr = lr, betas = (beta1, 0.999))

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0
num_epochs = 200
k = 1 # Hyperparameter according to the paper

# Make directory for samples and models
cwd = os.getcwd()
post='_attn'
if not os.path.exists(cwd+'/samples_celeba'+post):
    os.makedirs(cwd+'/samples_celeba'+post)

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, (data,_) in enumerate(dataloader):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        D_x = 0
        D_G_z1 = 0
        errD = 0
        for j in range(k):
        ## Train with all-real batch
            start_idx, end_idx = int(j*(batch_size/k)), int((j+1)*(batch_size/k))
            mini_data = data[start_idx:end_idx]
            D.zero_grad()
            # Format batch
            real_cpu = mini_data.to(device)
            b_size = real_cpu.size(0)
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            # Forward pass real batch through D
            output = D(real_cpu).view(-1)
            # Calculate loss on all-real batch
            errD_real = criterion(output, label)
            # Calculate gradients for D in backward pass
            errD_real.backward()
            D_x += output.mean().item()

            ## Train with all-fake batch
            # Generate batch of latent vectors
            noise = torch.randn(b_size, 100, device=device)
            # Generate fake image batch with G
            fake = G(noise)
            label.fill_(fake_label)
            # Classify all fake batch with D
            output = D(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = criterion(output, label)
            # Calculate the gradients for this batch, accumulated (summed) with previous gradients
            errD_fake.backward()
            D_G_z1 += output.mean().item()
            # Compute error of D as sum over the fake and the real batches
            errD += (errD_real + errD_fake)
            # Update D
            optimizerD.step()
        D_x /= k
        D_G_z1 /= k
        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        G.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = D(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 10 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 10 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = G(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
            if (iters % 100 == 0):
                fake_images= G(fixed_z)
                save_image(denorm(fake_images), os.path.join('./samples_celeba'+post, 'Epoch:{} | Iter:{}.png'.format(epoch, iters)))

        iters += 1

Starting Training Loop...
[0/200][0/3166]	Loss_D: 56.7116	Loss_G: 0.9777	D(x): -0.0038	D(G(z)): -0.0130 / 0.3765
[0/200][10/3166]	Loss_D: 1.3652	Loss_G: 0.5089	D(x): 0.6553	D(G(z)): 0.6025 / 0.6014
[0/200][20/3166]	Loss_D: 1.3504	Loss_G: 0.5562	D(x): 0.6223	D(G(z)): 0.5738 / 0.5737
[0/200][30/3166]	Loss_D: 1.3564	Loss_G: 0.5544	D(x): 0.6208	D(G(z)): 0.5747 / 0.5747
[0/200][40/3166]	Loss_D: 1.3835	Loss_G: 0.5582	D(x): 0.5974	D(G(z)): 0.5726 / 0.5726
[0/200][50/3166]	Loss_D: 1.3767	Loss_G: 0.5567	D(x): 0.6031	D(G(z)): 0.5733 / 0.5733
[0/200][60/3166]	Loss_D: 1.3672	Loss_G: 0.5553	D(x): 0.6089	D(G(z)): 0.5741 / 0.5741
[0/200][70/3166]	Loss_D: 1.3640	Loss_G: 0.5581	D(x): 0.6083	D(G(z)): 0.5726 / 0.5726
[0/200][80/3166]	Loss_D: 1.3983	Loss_G: 0.5561	D(x): 0.5889	D(G(z)): 0.5737 / 0.5737
[0/200][90/3166]	Loss_D: 1.3584	Loss_G: 0.6893	D(x): 0.5269	D(G(z)): 0.5033 / 0.5022
[0/200][100/3166]	Loss_D: 1.3511	Loss_G: 0.6886	D(x): 0.5312	D(G(z)): 0.5025 / 0.5025
[0/200][110/3166]	Loss_D: 1.3346	Los

KeyboardInterrupt: 