# 1. Install Library

In [1]:
!pip install torch torchvision



In [2]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.utils.data as Data
import matplotlib.pyplot as plt
import torch.optim as optim
from tqdm import tqdm

# 1.Load dataset


In [3]:
def load_mnist(batch_size):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

In [4]:
batch_size = 200
train_MNIST, test_MNIST = load_mnist(batch_size)

# 2. Create Fence GAN  accrodint got Architecture and hyperparameters on MNIST Dataset (Table 1)

Define Generatoro and Discriminator

In [5]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 128 * 7 * 7),  # Assuming latent size of 100
            nn.BatchNorm1d(128 * 7 * 7),
            nn.ReLU(True),
            nn.Unflatten(1, (128, 7, 7)),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 4, stride=2, padding=1),
            nn.Tanh()  # Output is a 28x28 image
        )

    def forward(self, z):
        return self.model(z)

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1024),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.model(img)

generator = Generator()
discriminator = Discriminator()

Define Loss

In [6]:
def encirclement_loss(discriminator_outputs, alpha=0.1):
    return torch.mean(torch.abs(discriminator_outputs - alpha))

def dispersion_loss(generated_images, beta=30):
    mean = torch.mean(generated_images, 0, keepdim=True)
    return beta * torch.mean((generated_images - mean) ** 2)

Define optimization

In [7]:
optimizer_G = optim.Adam(generator.parameters(), lr=2e-5, weight_decay=1e-4)
optimizer_D = optim.Adam(discriminator.parameters(), lr=1e-5, weight_decay=1e-4)
criterion = nn.BCELoss()

#Train model

In [9]:
epochs = 100

for epoch in tqdm(range(epochs)):
    for i, (imgs, _) in enumerate(train_MNIST):
        real_imgs = imgs
        real = torch.ones(imgs.size(0), 1)
        fake = torch.zeros(imgs.size(0), 1)

        # Train Discriminator
        optimizer_D.zero_grad()
        real_loss = criterion(discriminator(real_imgs), real)
        z = torch.randn(imgs.size(0), 100)
        fake_imgs = generator(z)
        fake_loss = criterion(discriminator(fake_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()
        g_loss = criterion(discriminator(fake_imgs), real)
        total_g_loss = g_loss + encirclement_loss(discriminator(fake_imgs)) + dispersion_loss(fake_imgs)
        total_g_loss.backward()
        optimizer_G.step()

100%|██████████| 100/100 [1:50:46<00:00, 66.47s/it]


# Evaluation


In [10]:
correct = 0
total = 0
with torch.no_grad():
    for images, _ in test_MNIST:

        images = images
        outputs = discriminator(images)
        predicted = (outputs > 0.5).float()  # Assuming the output is a probability of being real
        total += images.size(0)
        correct += (predicted == 1).sum().item()

        z = torch.randn(images.size(0), 100)  # 100 is the size of the noise vector
        fake_images = generator(z)
        fake_outputs = discriminator(fake_images)
        fake_predicted = (fake_outputs < 0.5).float()  # Fake images should be classified as 0
        correct += (fake_predicted == 1).sum().item()
        total += fake_images.size(0)

print(f'Accuracy of the discriminator on the test images: {100 * correct / total}%')

Accuracy of the discriminator on the test images: 99.005%


#Save model

In [13]:
model_path = 'generator'
optimizer_path = 'generator_opt'
torch.save(generator.state_dict(), model_path)
torch.save(optimizer_G.state_dict(), optimizer_path)

In [14]:
discriminator_model_path = 'discriminator'
discriminator_optimizer_path = 'discriminator_opt'

torch.save(discriminator.state_dict(), discriminator_model_path)
torch.save(optimizer_D.state_dict(), discriminator_optimizer_path)