<a href="https://colab.research.google.com/github/pydevcasts/MLHub/blob/master/SSGan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
!pip install torch torchvision pytorch-fid  torchinfo

Collecting pytorch-fid
  Downloading pytorch_fid-0.3.0-py3-none-any.whl.metadata (5.3 kB)
Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloadin

In [96]:
# Import necessary libraries
import os  # For interacting with the operating system
import torch  # Main PyTorch library
import numpy as np  # For numerical operations
import torch.nn as nn  # For building neural network components
import torch.optim as optim  # For optimization algorithms
import torch.nn.functional as F
import matplotlib.pyplot as plt  # For plotting images and graphs
import torch.autograd as autograd
from pytorch_fid import fid_score  # For calculating the Fréchet Inception Distance (FID)
from torch.utils.data import DataLoader  # For loading data in batches
from torchvision import datasets, transforms  # For datasets and image transformations

In [94]:
# Set the device to GPU if available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the dimensionality of the latent space for the generator
latent_dim = 64

# Set the batch size for training
batch_size = 128 #= [256, 128, 64, 32, 16, 8]

# Define the size of the images (e.g., for MNIST, images are 28x28 pixels)
image_size = 28

# Set the learning rate for the optimizer
lr = 0.0002

# Define the number of epochs for training the model
num_epochs = 20

In [98]:
# Define a series of transformations to be applied to the images
transform = transforms.Compose([
    transforms.Resize(image_size),  # Resize images to the specified image size (28x28)
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    transforms.Normalize((0.5,), (0.5,)),  # Normalize the images to have mean 0.5 and standard deviation 0.5
])

# Load the MNIST dataset
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Create a DataLoader for batching and shuffling the dataset
dataloader = DataLoader(dataset,
                        batch_size=batch_size,  # Set the batch size
                        shuffle=True,  # Shuffle the dataset for each epoch
                        drop_last=True  # Drop the last incomplete batch if it is smaller than batch_size
                        )

In [99]:
import torch.nn as nn
### Cell 4: Define Residual Block for Generator
class ResidualBlockG(nn.Module):
    def __init__(self, in_channels, out_channels, upsample=True):
        super().__init__()
        self.upsample = upsample  # Flag to determine whether to upsample

        # Define the main block of the residual block
        self.block = nn.Sequential(
            nn.BatchNorm2d(in_channels),  # Batch normalization
            nn.ReLU(),  # ReLU activation
            nn.Upsample(scale_factor=2, mode='nearest') if upsample else nn.Identity(),  # Upsampling if needed
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),  # First convolution
            nn.BatchNorm2d(out_channels),  # Batch normalization
            nn.ReLU(),  # ReLU activation
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)  # Second convolution
        )

        # Define the shortcut connection
        self.shortcut = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest') if upsample else nn.Identity(),  # Upsampling if needed
            # nn.Conv2d(in_channels, out_channels, kernel_size=1)  # 1x1 convolution for matching dimensions
        )

    def forward(self, x):
        # Return the sum of the main block output and the shortcut connection
        return self.block(x) + self.shortcut(x)


### Cell 5: Implement Generator with Architecture from Table 4
class SSGenerator(nn.Module):
    def __init__(self, latent_dim=128):
        super().__init__()
        # Initial layer to project the latent vector
        self.init_layer = nn.Sequential(
            nn.Linear(latent_dim, 256 * 4 * 4),  # Fully connected layer to expand latent vector
            nn.BatchNorm1d(256 * 4 * 4),  # Batch normalization
        )

        # Residual blocks to upscale and refine the generated images
        self.res_blocks = nn.Sequential(
            ResidualBlockG(256, 256, upsample=True),  # First residual block
            ResidualBlockG(256, 256, upsample=True),  # Second residual block
            ResidualBlockG(256, 256, upsample=True),   # Third residual block
            nn.ReLU()  # ReLU activation
        )

        # Final layers to produce the output image
        self.final_layers = nn.Sequential(
            # nn.BatchNorm2d(256),  # Batch normalization
            # nn.ReLU(),  # ReLU activation
            nn.Conv2d(256, 1, kernel_size=3, padding=1),  # Final convolution to reduce channels to 1 (grayscale image)
            nn.Tanh()  # Tanh activation to output values in the range [-1, 1]
        )

    def forward(self, z):
        # Forward pass through the generator
        x = self.init_layer(z)  # Pass latent vector through initial layer
        x = x.view(x.size(0), 256, 4, 4)  # Reshape to (batch_size, channels, height, width)
        x = self.res_blocks(x)  # Pass through residual blocks
        return self.final_layers(x)  # Pass through final layers to get output image

In [100]:
import torchinfo
img = SSGenerator()
torchinfo.summary(img,input_size=(1,128))

Layer (type:depth-idx)                   Output Shape              Param #
SSGenerator                              [1, 1, 32, 32]            --
├─Sequential: 1-1                        [1, 4096]                 --
│    └─Linear: 2-1                       [1, 4096]                 528,384
│    └─BatchNorm1d: 2-2                  [1, 4096]                 8,192
├─Sequential: 1-2                        [1, 256, 32, 32]          --
│    └─ResidualBlockG: 2-3               [1, 256, 8, 8]            --
│    │    └─Sequential: 3-1              [1, 256, 8, 8]            1,181,184
│    │    └─Sequential: 3-2              [1, 256, 8, 8]            --
│    └─ResidualBlockG: 2-4               [1, 256, 16, 16]          --
│    │    └─Sequential: 3-3              [1, 256, 16, 16]          1,181,184
│    │    └─Sequential: 3-4              [1, 256, 16, 16]          --
│    └─ResidualBlockG: 2-5               [1, 256, 32, 32]          --
│    │    └─Sequential: 3-5              [1, 256, 32, 32]      

In [101]:
import torchinfo

block = ResidualBlockG(256, 256, upsample=True)
torchinfo.summary(block, input_size=(1, 256, 4, 4))

Layer (type:depth-idx)                   Output Shape              Param #
ResidualBlockG                           [1, 256, 8, 8]            --
├─Sequential: 1-1                        [1, 256, 8, 8]            --
│    └─BatchNorm2d: 2-1                  [1, 256, 4, 4]            512
│    └─ReLU: 2-2                         [1, 256, 4, 4]            --
│    └─Upsample: 2-3                     [1, 256, 8, 8]            --
│    └─Conv2d: 2-4                       [1, 256, 8, 8]            590,080
│    └─BatchNorm2d: 2-5                  [1, 256, 8, 8]            512
│    └─ReLU: 2-6                         [1, 256, 8, 8]            --
│    └─Conv2d: 2-7                       [1, 256, 8, 8]            590,080
├─Sequential: 1-2                        [1, 256, 8, 8]            --
│    └─Upsample: 2-8                     [1, 256, 8, 8]            --
Total params: 1,181,184
Trainable params: 1,181,184
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 75.53
Input size (MB): 0.02
Forw

In [102]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# تعریف بلوک Residual برای Discriminator
class ResidualBlockD(nn.Module):
    def __init__(self, in_channels, out_channels, downsample=False):
        super().__init__()
        self.downsample = downsample  # Flag to determine if this is the first block

        # Block1
        self.block1 = nn.Sequential(
            nn.ReLU() if not downsample else nn.Identity(),  # ReLU (not for the first block)
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),  # Conv2D 3x3
            nn.utils.spectral_norm(nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)),  # Spectral Norm
            nn.ReLU(),  # ReLU
        )

        # Block2
        self.block2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),  # Conv2D 3x3
            nn.utils.spectral_norm(nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)),  # Spectral Norm
            nn.AvgPool2d(kernel_size=2, stride=2, padding=1),  # AvgPool2D 2x2
        )

        # Shortcut connection
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),  # Conv2D 1x1
            nn.AvgPool2d(kernel_size=2, stride=2, padding=1),  # AvgPool2D 2x2
        )

    def forward(self, x):
        # Block1
        out1 = self.block1(x)

        # Block2
        out2 = self.block2(out1)

        # Shortcut connection
        shortcut = self.shortcut(x)

        # Combine Block1 + Block2 with shortcut
        return out2 + shortcut


# تعریف Discriminator برای SSGAN
class SSDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()

        # Residual blocks to downsample and refine the input images
        self.res_blocks = nn.Sequential(
            ResidualBlockD(1, 128, downsample=True),  # First residual block (downsample)
            ResidualBlockD(128, 128),  # Second residual block (downsample)
            ResidualBlockD(128, 128),  # Third residual block (no downsample)
            ResidualBlockD(128, 128),  # Fourth residual block (no downsample)
        )

        # Final layers to produce the output
        self.final_layers = nn.Sequential(
            nn.Linear(128 * 3 * 3, 1),  # Fully connected layer to output a single value (real/fake)
            nn.Linear(128 * 3 * 3, 4),  # Fully connected layer for self-supervised task (e.g., rotation prediction)
        )

    def forward(self, x):
        # Forward pass through the discriminator
        x = self.res_blocks(x)  # Pass through residual blocks
        x = x.view(x.size(0), -1)  # Flatten the output
        real_fake_output = self.final_layers[0](x)  # Output for real/fake classification
        self_supervised_output = self.final_layers[1](x)  # Output for self-supervised task
        return real_fake_output, self_supervised_output




In [103]:
print("Summary for ResidualBlockD:")
residual_block = ResidualBlockD(in_channels=1, out_channels=128, downsample=False)
torchinfo.summary(residual_block, input_size=(128, 1, 28, 28))  # ورودی با ابعاد (batch_size, channels, height, width)

# نمایش خلاصه‌ی SSDiscriminator
print("\nSummary for SSDiscriminator:")
discriminator = SSDiscriminator()
torchinfo.summary(discriminator, input_size=(128, 1, 28, 28))  # ورودی با ابعاد (batch_size, channels, height, width)

Summary for ResidualBlockD:

Summary for SSDiscriminator:


Layer (type:depth-idx)                   Output Shape              Param #
SSDiscriminator                          [128, 1]                  --
├─Sequential: 1-1                        [128, 128, 3, 3]          --
│    └─ResidualBlockD: 2-1               [128, 128, 15, 15]        --
│    │    └─Sequential: 3-1              [128, 128, 28, 28]        148,864
│    │    └─Sequential: 3-2              [128, 128, 15, 15]        295,168
│    │    └─Sequential: 3-3              [128, 128, 15, 15]        256
│    └─ResidualBlockD: 2-2               [128, 128, 8, 8]          --
│    │    └─Sequential: 3-4              [128, 128, 15, 15]        295,168
│    │    └─Sequential: 3-5              [128, 128, 8, 8]          295,168
│    │    └─Sequential: 3-6              [128, 128, 8, 8]          16,512
│    └─ResidualBlockD: 2-3               [128, 128, 5, 5]          --
│    │    └─Sequential: 3-7              [128, 128, 8, 8]          295,168
│    │    └─Sequential: 3-8              [128, 128, 5, 

In [106]:

# تعریف Generator و Discriminator (با استفاده از کدهای قبلی)
generator = SSGenerator(latent_dim=128).to(device)
discriminator = SSDiscriminator().to(device)

# تعریف بهینه‌سازها
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# تعریف توابع خطا
criterion = nn.BCEWithLogitsLoss()  # برای تشخیص واقعی/جعلی
criterion_ss = nn.CrossEntropyLoss()  # برای وظیفه خودنظارتی (مثلاً پیش‌بینی چرخش)


In [None]:

# لیست‌ها برای ذخیره Loss‌ها
g_losses = []
d_losses = []

# آموزش مدل
num_epochs = 10
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(dataloader):
        real_images = real_images.to(device)
        batch_size = real_images.size(0)

        # (a) Train Discriminator
        discriminator.zero_grad()

        # 1. Train with real images
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # Discriminator output for real images
        real_output, real_ss_output = discriminator(real_images)
        d_loss_real = criterion(real_output, real_labels)

        # 2. Train with fake images
        z = torch.randn(batch_size, 128).to(device)  # Sample random latent vectors
        fake_images = generator(z)
        fake_output, fake_ss_output = discriminator(fake_images.detach())
        d_loss_fake = criterion(fake_output, fake_labels)

        # 3. Self-supervised task (e.g., rotation prediction)
        # Assuming rotation prediction with 4 classes (0°, 90°, 180°, 270°)
        rotation_labels = torch.randint(0, 4, (batch_size,)).to(device)
        d_loss_ss = criterion_ss(real_ss_output, rotation_labels)

        # Total discriminator loss
        d_loss = d_loss_real + d_loss_fake + d_loss_ss
        d_loss.backward()
        optimizer_D.step()

        # (b) Train Generator
        generator.zero_grad()

        # Generate fake images again
        z = torch.randn(batch_size, 128).to(device)
        fake_images = generator(z)
        fake_output, _ = discriminator(fake_images)

        # Generator loss
        g_loss = criterion(fake_output, real_labels)
        g_loss.backward()
        optimizer_G.step()

        # ذخیره Loss‌ها
        g_losses.append(g_loss.item())
        d_losses.append(d_loss.item())

        # Print losses every 100 steps
        if (i + 1) % 100 == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(dataloader)}], "
                  f"Loss_D: {d_loss.item():.4f}, Loss_G: {g_loss.item():.4f}")

    # Save generated images at the end of each epoch
    if (epoch + 1) % 10 == 0:
        with torch.no_grad():
            z = torch.randn(100, 128).to(device)  # Generate 100 random latent vectors
            generated_images = generator(z).cpu()  # Generate images and move to CPU

            # Save generated images
            fig, axs = plt.subplots(10, 10, figsize=(10, 10))
            for i in range(10):
                for j in range(10):
                    img_idx = i * 10 + j
                    axs[i, j].imshow(generated_images[img_idx].squeeze(), cmap='gray')
                    axs[i, j].axis('off')
            plt.savefig(f"generated_images_epoch_{epoch + 1}.png")
            plt.close(fig)


Epoch [1/10], Step [100/468], Loss_D: 1.4078, Loss_G: 6.4545
Epoch [1/10], Step [200/468], Loss_D: 1.3863, Loss_G: 7.5723
Epoch [1/10], Step [300/468], Loss_D: 1.3940, Loss_G: 6.5627
Epoch [1/10], Step [400/468], Loss_D: 1.3900, Loss_G: 7.5336
Epoch [2/10], Step [100/468], Loss_D: 1.3922, Loss_G: 7.8490
Epoch [2/10], Step [200/468], Loss_D: 1.3893, Loss_G: 8.2349


In [None]:

# رسم نمودار Loss‌ها
plt.figure(figsize=(10, 5))
plt.plot(g_losses, label="Generator Loss")
plt.plot(d_losses, label="Discriminator Loss")
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.title("Generator and Discriminator Loss During Training")
plt.legend()
plt.savefig("loss_plot.png")
plt.show()

In [None]:
# Create real images directory
os.makedirs("./real_images", exist_ok=True)
for i, (img, _) in enumerate(dataset):
    if i >= 5000:
        break
    plt.imsave(f"./real_images/{i}.png", img.squeeze().numpy(), cmap='gray')


fake_images = []
for _ in range(100):  # 100 * batch_size = 12800 > 5000
    z = torch.randn(batch_size, latent_dim).to(device)
    fake = generator(z).cpu().detach()
    print(fake.shape)
    fake_images.append(fake)
fake_images = torch.cat(fake_images)[:5000]  # انتخاب 5000 نمونه

# ذخیره در پوشه
os.makedirs("./fake_images_ssgan", exist_ok=True)
for i, img in enumerate(fake_images):
    plt.imsave(f"./fake_images_ssgan/{i}.png", img.squeeze(), cmap='gray')
# Verify paths
print(f"Real images: {len(os.listdir('./real_images'))}")
print(f"Fake images: {len(os.listdir('./fake_images_ssgan'))}")

# Calculate FID
fid_value = fid_score.calculate_fid_given_paths(
    ["./real_images", "./fake_images_ssgan"],
    batch_size=50,
    device=device,
    dims=2048
)
print(f"FID Score: {fid_value:.2f}")

