<a href="https://colab.research.google.com/github/principwty/Machine-Learning-and-Deep-Learning/blob/main/GAN_Fashion_MNIST_Pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image

# Hyperparameters.
# Run on CUDA when PyTorch reports it as available; otherwise use the CPU.
device = (torch.device('cuda') if torch.cuda.is_available()
          else torch.device('cpu'))

# Optimisation settings (the same learning rate is used for both networks).
lr, batch_size, num_epochs = 2e-4, 128, 50

# Model dimensions.
latent_size = 100       # length of the noise vector fed to the generator
image_size = 28 * 28    # one 28x28 image, flattened to a vector
hidden_size = 256       # width of every hidden layer in both networks

In [None]:
# Preprocessing: ToTensor() yields floats in [0, 1]; Normalize then maps each
# value through (x - 0.5) / 0.5, giving the range [-1, 1] that matches the
# generator's tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Fashion-MNIST training split, stored under ./data (downloaded if missing).
train_dataset = torchvision.datasets.FashionMNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)

# Shuffled mini-batches of `batch_size` images for the training loop.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Define the generator network
class Generator(nn.Module):
    """Fully connected generator mapping noise vectors to flattened images.

    Layout: Linear -> ReLU -> Linear -> ReLU -> Linear -> tanh, so every
    output value lies in [-1, 1], the same range the dataset transform
    normalises real images to.

    Args:
        latent_dim: Length of the input noise vector. Defaults to the
            module-level ``latent_size`` hyperparameter.
        hidden_dim: Width of both hidden layers. Defaults to ``hidden_size``.
        out_dim: Number of output values per sample (a flattened image).
            Defaults to ``image_size``.
    """

    def __init__(self, latent_dim=None, hidden_dim=None, out_dim=None):
        super().__init__()
        # Resolve defaults when the model is built (not when the class is
        # defined) so Generator() keeps following the hyperparameter cell.
        if latent_dim is None:
            latent_dim = latent_size
        if hidden_dim is None:
            hidden_dim = hidden_size
        if out_dim is None:
            out_dim = image_size
        # Attribute names fc1..fc3 are kept so existing state_dicts still load.
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Map noise ``x`` (last dim ``latent_dim``) to samples in [-1, 1].

        Returns a tensor with the same leading dimensions as ``x`` and a
        last dimension of ``out_dim``.
        """
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return torch.tanh(self.fc3(x))

In [None]:
# Define the discriminator network
class Discriminator(nn.Module):
    """Fully connected discriminator scoring flattened images as real/fake.

    Layout: Linear -> ReLU -> Linear -> ReLU -> Linear -> sigmoid, producing
    one probability per sample, which is the form ``nn.BCELoss`` expects.

    Args:
        in_dim: Number of input values per sample (a flattened image).
            Defaults to the module-level ``image_size`` hyperparameter.
        hidden_dim: Width of both hidden layers. Defaults to ``hidden_size``.
    """

    def __init__(self, in_dim=None, hidden_dim=None):
        super().__init__()
        # Resolve defaults when the model is built (not when the class is
        # defined) so Discriminator() keeps following the hyperparameter cell.
        if in_dim is None:
            in_dim = image_size
        if hidden_dim is None:
            hidden_dim = hidden_size
        # Attribute names fc1..fc3 are kept so existing state_dicts still load.
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return the probability that each input in ``x`` is a real image.

        ``x`` has last dimension ``in_dim``; the result keeps the leading
        dimensions and has a last dimension of 1, with values in [0, 1].
        """
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return torch.sigmoid(self.fc3(x))

In [None]:
# Build both networks (generator first, then discriminator) on `device`.
generator, discriminator = Generator().to(device), Discriminator().to(device)

# Binary cross-entropy on the discriminator's sigmoid probabilities, and one
# Adam optimiser per network, both with the shared learning rate `lr`.
criterion = nn.BCELoss()
g_optimizer, d_optimizer = (
    optim.Adam(net.parameters(), lr=lr) for net in (generator, discriminator)
)

In [None]:
# Train the GAN: for every mini-batch, one discriminator update followed by
# one generator update.
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        # Number of images in *this* batch (the final batch of an epoch can be
        # smaller than `batch_size`). Kept under its own name so the loop does
        # not overwrite the global `batch_size` hyperparameter.
        cur_batch = images.size(0)
        images = images.view(cur_batch, -1).to(device)
        real_labels = torch.ones(cur_batch, 1, device=device)
        fake_labels = torch.zeros(cur_batch, 1, device=device)

        # ---- Discriminator update ----
        # Real images are labelled 1.
        outputs = discriminator(images)
        d_loss_real = criterion(outputs, real_labels)

        # Generated images are labelled 0. detach() stops this loss from
        # back-propagating into the generator: those gradients are not used
        # in this step and would only be zeroed again before the G update.
        noise = torch.randn(cur_batch, latent_size, device=device)
        fake_images = generator(noise)
        outputs = discriminator(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)

        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ---- Generator update ----
        # Fresh samples scored against the *real* labels: the generator is
        # rewarded when the discriminator mistakes its output for real data.
        noise = torch.randn(cur_batch, latent_size, device=device)
        fake_images = generator(noise)
        outputs = discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)

        # g_loss.backward() also accumulates gradients in the discriminator's
        # parameters; d_optimizer.zero_grad() clears them before the next
        # discriminator step.
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # Log progress and save a snapshot every 100 steps.
        if i % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}'
                  .format(epoch+1, num_epochs, i+1, len(train_loader), d_loss.item(), g_loss.item()))

            # The generator's tanh output lies in [-1, 1], but save_image
            # writes values in [0, 1] and clips anything below 0 to black.
            # Undo the (x - 0.5) / 0.5 normalisation before saving. The file
            # name only encodes the epoch, so the last snapshot of each epoch
            # overwrites the earlier ones.
            samples = fake_images.detach().reshape(fake_images.size(0), 1, 28, 28)
            save_image((samples + 1) / 2, 'generated_images-{}.png'.format(epoch+1))


Epoch [1/50], Step [1/469], d_loss: 1.3872, g_loss: 0.7056
Epoch [1/50], Step [101/469], d_loss: 0.3411, g_loss: 1.8949
Epoch [1/50], Step [201/469], d_loss: 0.1330, g_loss: 2.7505
Epoch [1/50], Step [301/469], d_loss: 0.0905, g_loss: 3.8671
Epoch [1/50], Step [401/469], d_loss: 0.1159, g_loss: 3.7903
Epoch [2/50], Step [1/469], d_loss: 0.0508, g_loss: 4.2418
Epoch [2/50], Step [101/469], d_loss: 0.0734, g_loss: 3.4324
Epoch [2/50], Step [201/469], d_loss: 0.0814, g_loss: 3.5016
Epoch [2/50], Step [301/469], d_loss: 0.0454, g_loss: 4.1276
Epoch [2/50], Step [401/469], d_loss: 0.0196, g_loss: 5.0892
Epoch [3/50], Step [1/469], d_loss: 0.0549, g_loss: 4.9826
Epoch [3/50], Step [101/469], d_loss: 0.0178, g_loss: 5.3346
Epoch [3/50], Step [201/469], d_loss: 0.0312, g_loss: 5.5372
Epoch [3/50], Step [301/469], d_loss: 0.0492, g_loss: 5.2041
Epoch [3/50], Step [401/469], d_loss: 0.0110, g_loss: 5.2085
Epoch [4/50], Step [1/469], d_loss: 0.0062, g_loss: 7.3365
Epoch [4/50], Step [101/469], d_