# Here is an explanation of each step of the code:
# 1.	Import necessary modules: We start by importing the necessary modules, including PyTorch, torch.nn, the DataLoader class, and the transforms and datasets packages from torchvision (datasets provides CelebA).


In [1]:
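pip install torchvision==0.14.0

Note: you may need to restart the kernel to use updated packages.

# The command originally written here, `pip install torchvision=1.13`, failed with:
#   ERROR: Invalid requirement: 'torchvision=1.13'
#   Hint: = is not a valid operator. Did you mean == ?
# pip requires `==` to pin a version. torchvision has no 1.13 release; 0.14.x is the series published alongside torch 1.13 (confirm this is the intended version).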


In [7]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import torch.nn.utils.parametrizations
# Select the training device: use the GPU when CUDA is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")



# 2.	Define the device: We define the device to use for training. If a GPU is available, we use it, otherwise we use the CPU.

In [None]:
# Preprocessing pipeline applied, in order, to every image
transform = transforms.Compose(
    [
        # Resize with a single int: the shorter image side becomes 64 pixels
        transforms.Resize(64),
        # Keep the central 64x64 region
        transforms.CenterCrop(64),
        # Convert the image to a tensor
        transforms.ToTensor(),
        # Normalize each of the three channels with mean 0.5 and std 0.5
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

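# 3.	Define the transforms: We define a series of transformations to be applied to the images in the dataset. These transformations resize each image so that its shorter side is 64 pixels, center crop it to 64x64 pixels, convert it to a tensor, and normalize each channel with mean 0.5 and standard deviation 0.5 (mapping pixel values from [0, 1] to [-1, 1]).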

In [None]:
# CelebA training split stored under ./data; every image goes through
# `transform`, and download=True asks torchvision to fetch the data if needed
dataset = datasets.CelebA(
    root="./data",
    split="train",
    download=True,
    transform=transform,
)

# 4.	Define the dataset: We define the CelebA dataset, specifying the path to the dataset, the split to use (train), the transform to apply to the images, and whether to download the dataset if it doesn't already exist.

In [None]:
# Batch iterator over the dataset: shuffled batches of 128 images,
# loaded by 4 worker processes
dataloader = DataLoader(
    dataset=dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
)

# 5.	Define the dataloader: We define a dataloader that will be used to iterate over the dataset in batches during training. The dataloader takes the dataset as input, along with the batch size, whether to shuffle the dataset, and the number of workers to use for loading the data.

# 6.	Define the Self-Attention module: We define a module that implements self-attention, which is used in both the generator and the discriminator networks. The module takes a feature map as input and returns a feature map of the same shape with attention applied. It consists of three 1x1 convolutional layers that compute the query, key, and value, a softmax that turns the query-key scores into attention weights, and a learnable gamma parameter (initialised to zero) that scales the attention output before it is added back to the input.

In [None]:
# Define the Self-Attention module
class SelfAttentionCnn(nn.Module):
    """Self-attention over the spatial positions of a 4-D feature map.

    Every spatial position attends to every other position. The attended
    features are scaled by the learnable ``gamma`` and added to the input,
    so the output has the same shape as the input.

    Args:
        in_channels: Number of channels of the incoming feature map. The
            query and key projections use ``in_channels // 8`` channels.
    """

    def __init__(self, in_channels):
        super().__init__()

        reduced_channels = in_channels // 8

        # 1x1 convolutions producing the query, key and value projections
        self.query_conv = nn.Conv2d(in_channels, reduced_channels, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, reduced_channels, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)

        # Scale of the attention branch; it starts at zero, so a freshly
        # constructed module returns its input unchanged
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` for ``x`` of shape (N, C, H, W)."""
        n, c, h, w = x.size()
        positions = h * w

        # Flatten the spatial grid so each projection is (N, channels, H*W)
        projected_q = self.query_conv(x).view(n, -1, positions)
        projected_k = self.key_conv(x).view(n, -1, positions)
        projected_v = self.value_conv(x).view(n, -1, positions)

        # Pairwise query/key scores, shape (N, H*W, H*W); softmax over the last axis
        scores = torch.bmm(projected_q.permute(0, 2, 1), projected_k)
        weights = self.softmax(scores)

        # Weighted sum of the values, reshaped back to the input layout
        attended = torch.bmm(projected_v, weights.permute(0, 2, 1))
        attended = attended.view(n, c, h, w)

        # Residual connection with the learnable scale
        return self.gamma * attended + x

# The Self-Attention layer is a key component of the Self-Attention GAN model. It helps the generator to focus on important parts of the image and generate high-quality images.

#Next, the generator network is defined:

In [None]:
# Define the Generator network
class Generator(nn.Module):
    """Generator that maps a latent vector to a 3-channel 64x64 image.

    The latent vector is projected by a linear layer to a 1024x4x4 feature
    map. Four transposed convolutions (kernel 4, stride 2, padding 1) then
    double the spatial size each time (4 -> 8 -> 16 -> 32 -> 64). A
    self-attention block is applied on the 256-channel features, and the
    final activation is tanh.

    Args:
        z_dim: Size of the latent vector fed to the generator.
    """

    def __init__(self, z_dim):
        super(Generator, self).__init__()

        self.z_dim = z_dim

        # Spectral normalization (SN) is used as proposed by Miyato et al. (2018)
        spectral_norm = torch.nn.utils.parametrizations.spectral_norm

        self.linear = nn.Linear(z_dim, 4 * 4 * 1024)
        self.conv1 = spectral_norm(
            nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=4, stride=2, padding=1)
        )
        self.bn1 = nn.BatchNorm2d(512)
        self.conv2 = spectral_norm(
            nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1)
        )
        self.bn2 = nn.BatchNorm2d(256)
        self.self_attention = SelfAttentionCnn(in_channels=256)
        # NOTE(review): conv3 is the only transposed convolution without
        # spectral normalization - confirm this is intentional.
        self.conv3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1)
        # forward() normalizes the output of conv3, so bn3 must exist
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = spectral_norm(
            nn.ConvTranspose2d(in_channels=128, out_channels=3, kernel_size=4, stride=2, padding=1)
        )

        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, z):
        """Generate images from latent vectors ``z`` of shape (N, z_dim)."""
        out = self.linear(z)
        # Reshape the flat projection into a 1024-channel 4x4 feature map
        out = out.view(-1, 1024, 4, 4)
        out = self.relu(self.bn1(self.conv1(out)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.self_attention(out)
        out = self.relu(self.bn3(self.conv3(out)))
        # tanh keeps the output in [-1, 1]
        out = self.tanh(self.conv4(out))

        return out


In [None]:
# Define the Discriminator network
class Discriminator(nn.Module):
    """Discriminator that scores 3-channel images as real or fake.

    Four stride-2 convolutions (kernel 4, padding 1) halve the spatial size
    each time, with a self-attention block on the 256-channel features. A
    final 4x4 convolution followed by a sigmoid produces the score. For
    64x64 inputs the spatial sizes are 64 -> 32 -> 16 -> 8 -> 4 -> 1, so the
    output has shape (N, 1, 1, 1).
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        # Spectral normalization (SN) is used as proposed by Miyato et al. (2018)
        spectral_norm = torch.nn.utils.parametrizations.spectral_norm

        self.conv1 = spectral_norm(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=4, stride=2, padding=1)
        )
        self.conv2 = spectral_norm(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1)
        )
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = spectral_norm(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1)
        )
        self.bn3 = nn.BatchNorm2d(256)
        self.self_attention = SelfAttentionCnn(in_channels=256)
        self.conv4 = spectral_norm(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1)
        )
        self.bn4 = nn.BatchNorm2d(512)
        self.conv5 = spectral_norm(
            nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=0)
        )

        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid scores in [0, 1] for a batch of images ``x``."""
        # Leaky ReLU activations, as used in the original paper
        out = self.leaky_relu(self.conv1(x))
        out = self.leaky_relu(self.bn2(self.conv2(out)))
        out = self.leaky_relu(self.bn3(self.conv3(out)))
        out = self.self_attention(out)
        out = self.leaky_relu(self.bn4(self.conv4(out)))
        out = self.sigmoid(self.conv5(out))

        return out

In [None]:
# Define the loss function (binary cross-entropy on the discriminator's sigmoid output)
criterion = nn.BCELoss()

# Define the learning rate
lr = 0.0002

# Define the number of epochs
num_epochs = 200

# Define the generator and discriminator networks
z_dim = 100
G = Generator(z_dim=z_dim).to(device)
D = Discriminator().to(device)

# Define the optimizer for the generator and discriminator networks.
# Adam is reached through torch.optim because no `optim` name is imported in this file.
G_optimizer = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
D_optimizer = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

In [None]:
# Define the function to train the networks
def train(G, D, train_loader, criterion, G_optimizer, D_optimizer, z_dim, num_epochs):
    """Alternately train the discriminator and the generator.

    For every batch the discriminator is updated once on real images plus
    detached generated images, then the generator is updated once so that the
    discriminator scores its images as real. Losses are printed every 100
    batches.

    Args:
        G: Generator module; called with noise of shape (batch, z_dim).
        D: Discriminator module; its output is flattened to one score per image.
        train_loader: Sized iterable yielding ``(images, labels)`` batches; the
            labels are ignored.
        criterion: Loss called as ``criterion(scores, targets)`` with 1-D tensors.
        G_optimizer: Optimizer updating the generator's parameters.
        D_optimizer: Optimizer updating the discriminator's parameters.
        z_dim: Size of the latent noise vector.
        num_epochs: Number of passes over ``train_loader``.
    """
    # Work on the device the discriminator already lives on, so the function
    # does not depend on a module-level `device` variable
    device = next(D.parameters()).device
    num_batches = len(train_loader)

    for epoch in range(num_epochs):
        for i, (real_imag, _) in enumerate(train_loader):
            real_imag = real_imag.to(device)
            batch_size = real_imag.size(0)

            # 1-D targets that match the flattened discriminator output
            real_labels = torch.ones(batch_size, device=device)
            fake_labels = torch.zeros(batch_size, device=device)

            # Train discriminator with real images
            D_optimizer.zero_grad()
            real_scores = D(real_imag).reshape(-1)
            d_loss_real = criterion(real_scores, real_labels)

            # Train discriminator with fake images; detach so that no
            # gradient flows back into the generator during this step
            noise = torch.randn(batch_size, z_dim, device=device)
            fake_imag = G(noise)
            fake_scores = D(fake_imag.detach()).reshape(-1)
            d_loss_fake = criterion(fake_scores, fake_labels)

            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            D_optimizer.step()

            # Train generator: it is rewarded when the discriminator labels
            # its images as real
            G_optimizer.zero_grad()
            noise = torch.randn(batch_size, z_dim, device=device)
            fake_imag = G(noise)
            fake_scores = D(fake_imag).reshape(-1)
            g_loss = criterion(fake_scores, real_labels)
            g_loss.backward()
            G_optimizer.step()

            # Print losses periodically
            if i % 100 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}] Batch [{i+1}/{num_batches}] "
                      f"D_loss: {d_loss.item():.4f} G_loss: {g_loss.item():.4f}")


# The train function takes as input the generator, the discriminator, the dataloader, the loss criterion, the generator and discriminator optimizers, the latent dimension z_dim, and the number of epochs.
# The loss function and the optimizers are created outside the function and passed in.
# It then loops over the specified number of epochs and the batches in the dataloader.
# For each batch, it trains the discriminator with real and fake images and the generator with fake images.
# Every 100 batches it prints the discriminator and generator losses.
# The code for saving sample images is kept in the separate cell below.

In [None]:
# Save sample images
def save_sample_images(generator, fixed_noise, epoch, out_dir="sample_images"):
    """Generate images from ``fixed_noise`` and write them to one PNG file.

    Args:
        generator: Module called as ``generator(fixed_noise)`` to produce images.
        fixed_noise: Noise batch passed to the generator.
        epoch: Zero-based epoch index; the file is named ``epoch_{epoch+1}.png``.
        out_dir: Directory that receives the image; created if it is missing.
    """
    from pathlib import Path

    from torchvision.utils import save_image

    # save_image does not create missing directories
    Path(out_dir).mkdir(parents=True, exist_ok=True)

    # No gradients are needed for sampling
    with torch.no_grad():
        fake_imag = generator(fixed_noise).detach().cpu()
    save_image(fake_imag, f"{out_dir}/epoch_{epoch+1}.png", normalize=True)