
In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

In [6]:
# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

Why Setting Random Seeds Is Important:

When training machine learning models, especially neural networks, randomness enters in several places: weight initialization, data shuffling, dropout and, in a GAN, the noise vectors sampled as generator input.
Without fixed seeds, each run of your code may produce slightly different results, which makes experiments hard to reproduce and compare.
Setting the seeds as in the cell above makes experiments reproducible: running the same code with the same seeds on the same data, hardware and library versions should give the same results, which is crucial for research, debugging, and sharing code with others. Note that `torch.manual_seed` seeds PyTorch's generators on all devices, while `np.random.seed` only covers NumPy. Some CUDA operations are nondeterministic, so bit-identical results on a GPU may additionally require `torch.use_deterministic_algorithms(True)`.
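A quick way to see the effect is to reseed and draw again: the same seed yields the same sequence of random numbers. The sketch below uses a hypothetical helper, `seed_everything`, which also seeds Python's built-in `random` module (the cell above does not) in case any code path relies on it.

```python
import random

import numpy as np
import torch


def seed_everything(seed: int) -> None:
    """Seed Python, NumPy and PyTorch RNGs (hypothetical helper, not part of the notebook)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds PyTorch's generators on all devices


seed_everything(42)
first_noise = torch.randn(4, 100)        # e.g. a batch of latent vectors for the generator
first_perm = np.random.permutation(10)   # e.g. a data-shuffling order

seed_everything(42)
second_noise = torch.randn(4, 100)
second_perm = np.random.permutation(10)

# Same seed, same random draws
print(torch.equal(first_noise, second_noise))   # True
print(np.array_equal(first_perm, second_perm))  # True
```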

In [7]:
# Define the Generator network: it maps a 100-dimensional noise vector to a flattened 28x28 image.
class Generator(nn.Module):  # nn.Module is the base class for all PyTorch neural network modules
    def __init__(self):
        super(Generator, self).__init__()
        # Three fully connected (linear) layers: 100 -> 256 -> 512 -> 784
        self.fc1 = nn.Linear(100, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 784)  # 784 = 28 * 28 pixels, flattened

    def forward(self, z):
        # z: latent noise of shape (batch_size, 100)
        x = torch.relu(self.fc1(z))
        x = torch.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))  # squash pixel values into (0, 1)
        return x  # shape (batch_size, 784)
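As a usage sketch, the generator can be fed a batch of noise vectors and its flat output reshaped into images. Sampling `z` from a standard normal with `torch.randn` is an assumption (a common choice, not confirmed by this section). The sigmoid keeps pixel values in [0, 1], which would match images loaded with `transforms.ToTensor()` and no extra normalization, again an assumption about how the notebook loads FashionMNIST. The class is repeated so the snippet runs on its own.

```python
import torch
import torch.nn as nn


class Generator(nn.Module):
    # Same architecture as the Generator cell above: 100 -> 256 -> 512 -> 784
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(100, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 784)

    def forward(self, z):
        x = torch.relu(self.fc1(z))
        x = torch.relu(self.fc2(x))
        return torch.sigmoid(self.fc3(x))


torch.manual_seed(42)
generator = Generator()

batch_size = 16
z = torch.randn(batch_size, 100)           # one 100-dim latent vector per sample (assumed N(0, 1))
with torch.no_grad():                      # inference only, no gradients needed
    flat = generator(z)                    # shape (16, 784), values in [0, 1]
images = flat.view(batch_size, 1, 28, 28)  # (N, C, H, W), e.g. for plt.imshow(images[0, 0])

print(flat.shape, images.shape)  # torch.Size([16, 784]) torch.Size([16, 1, 28, 28])
```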

In [8]:
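# Define the Discriminator network: it maps a flattened 28x28 image to the probability that it is real.
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # Three fully connected layers: 784 -> 512 -> 256 -> 1
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 1)  # single output for real-vs-fake (binary) classification

    def forward(self, x):
        # x: flattened images of shape (batch_size, 784)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))  # probability in (0, 1) that x is a real image
        return x  # shape (batch_size, 1)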
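Because the discriminator ends in a sigmoid, its output can be read as P(real) and paired with binary cross-entropy. The sketch below assumes `nn.BCELoss` is the training criterion (a common pairing with a sigmoid output, not confirmed by this section) and uses random tensors as hypothetical placeholders for real and generated batches. Real FashionMNIST batches from `ToTensor()` have shape (N, 1, 28, 28) and would need flattening with `.view(N, -1)` first.

```python
import torch
import torch.nn as nn


class Discriminator(nn.Module):
    # Same architecture as the Discriminator cell above: 784 -> 512 -> 256 -> 1
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return torch.sigmoid(self.fc3(x))


torch.manual_seed(42)
discriminator = Discriminator()
criterion = nn.BCELoss()  # binary cross-entropy on probabilities

# Hypothetical placeholders, flattened to 784 pixels in [0, 1]
real_batch = torch.rand(8, 784)  # stands in for real FashionMNIST images
fake_batch = torch.rand(8, 784)  # stands in for generator(z).detach() during the D update

real_scores = discriminator(real_batch)  # shape (8, 1): estimated P(real) per image
fake_scores = discriminator(fake_batch)

# Discriminator loss: real images labeled 1, generated images labeled 0
d_loss = (criterion(real_scores, torch.ones_like(real_scores))
          + criterion(fake_scores, torch.zeros_like(fake_scores)))

# Generator loss (non-saturating form): reward fakes that D labels as real
g_loss = criterion(fake_scores, torch.ones_like(fake_scores))

print(real_scores.shape, d_loss.item(), g_loss.item())
```

Labeling real images 1 and fakes 0 is what makes the single sigmoid output a "real vs. fake" classifier; in a full training loop the generator's output would be detached for the discriminator step so only the discriminator's weights are updated by `d_loss`.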