### Importing dependencies.

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torchvision.utils import save_image

### Defining the generative network.

In [7]:
class Generator(nn.Module):
    def __init__(self, z_dim=100, img_dim=1):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(z_dim, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, 512)
        self.fc4 = nn.Linear(512, img_dim)

    def forward(self, x):
        x = nn.functional.leaky_relu(self.fc1(x), 0.2)
        x = nn.functional.leaky_relu(self.fc2(x), 0.2)
        x = nn.functional.leaky_relu(self.fc3(x), 0.2)
        x = nn.functional.tanh(self.fc4(x))
        return x


### Defining the discriminative network.

In [8]:
class Discriminator(nn.Module):
    def __init__(self, img_dim=1):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(img_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 1)

    def forward(self, x):
        x = nn.functional.leaky_relu(self.fc1(x), 0.2)
        x = nn.functional.leaky_relu(self.fc2(x), 0.2)
        x = nn.functional.leaky_relu(self.fc3(x), 0.2)
        return x


### Loss Function.

In [10]:
criterion = nn.BCEWithLogitsLoss()


### Loading the training dataset.

In [11]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)


### Defining the GAN Model.

In [12]:
z_dim = 100
img_dim = 28*28

generator = Generator(z_dim=z_dim, img_dim=img_dim)
discriminator = Discriminator(img_dim=img_dim)

optimizer_g = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002)


### Training the Model.

In [13]:
num_epochs = 250
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(trainloader):
        batch_size = real_images.size(0)

        # Train the discriminator
        optimizer_d.zero_grad()
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)
        z = torch.randn(batch_size, z_dim)
        fake_images = generator(z)
        real_outputs = discriminator(real_images.view(batch_size, -1))
        fake_outputs = discriminator(fake_images.detach())
        d_loss_real = criterion(real_outputs, real_labels)
        d_loss_fake = criterion(fake_outputs, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_d.step()

        # Train the generator
        optimizer_g.zero_grad()
        z = torch.randn(batch_size, z_dim)
        fake_images = generator(z)
        fake_outputs = discriminator(fake_images)
        g_loss = criterion(fake_outputs, real_labels)
        g_loss.backward()
        optimizer_g.step()

        if i % 100 == 0:
            print('[%d/%d][%03d/%d]\tLoss_D: %.4f\tLoss_G: %.4f' % (epoch+1,
                  num_epochs, i+1, len(trainloader), d_loss.item(), g_loss.item()))

    # Generate a sample image
    with torch.no_grad():
        fake = generator(torch.randn(1, z_dim))
        fake = fake.view(28, 28)
        fake = (fake + 1) / 2
        save_image(fake, 'output/fake_%03d.png' % (epoch+1))

    # Save the model
    torch.save(generator.state_dict(), 'models/generator_%03d.pt' % (epoch+1))
    torch.save(discriminator.state_dict(),
               'models/discriminator_%03d.pt' % (epoch+1))


[1/250][001/469]	Loss_D: 1.3868	Loss_G: 0.6934
[1/250][101/469]	Loss_D: 0.9056	Loss_G: 0.6967
[1/250][201/469]	Loss_D: 0.7008	Loss_G: 0.6945
[1/250][301/469]	Loss_D: 0.6971	Loss_G: 0.6986
[1/250][401/469]	Loss_D: 0.6969	Loss_G: 0.7021
[2/250][001/469]	Loss_D: 0.7540	Loss_G: 0.7016
[2/250][101/469]	Loss_D: 1.4362	Loss_G: 0.7248
[2/250][201/469]	Loss_D: 1.4029	Loss_G: 0.6974
[2/250][301/469]	Loss_D: 1.4188	Loss_G: 0.7267
[2/250][401/469]	Loss_D: 1.3965	Loss_G: 0.7079
[3/250][001/469]	Loss_D: 1.3665	Loss_G: 0.7205
[3/250][101/469]	Loss_D: 1.0642	Loss_G: 0.6897
[3/250][201/469]	Loss_D: 1.4470	Loss_G: 0.6985
[3/250][301/469]	Loss_D: 1.4180	Loss_G: 0.6988
[3/250][401/469]	Loss_D: 1.3573	Loss_G: 0.7319
[4/250][001/469]	Loss_D: 0.8517	Loss_G: 0.7035
[4/250][101/469]	Loss_D: 1.3999	Loss_G: 0.7145
[4/250][201/469]	Loss_D: 1.3739	Loss_G: 0.7175
[4/250][301/469]	Loss_D: 0.7866	Loss_G: 0.6971
[4/250][401/469]	Loss_D: 1.4388	Loss_G: 0.6987
[5/250][001/469]	Loss_D: 1.4033	Loss_G: 0.6997
[5/250][101/4

KeyboardInterrupt: 