In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device,'\n')
class GAN():
    def __init__(self):
        self.img_shape = (1, 28, 28)
        self.latent_dim = 100

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator().to(device)
        self.discriminator_optimizer = optim.Adam(self.discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

        # Build the generator
        self.generator = self.build_generator().to(device)
        self.generator_optimizer = optim.Adam(self.generator.parameters(), lr=0.0002, betas=(0.5, 0.999))

        # Loss function
        self.adversarial_loss = nn.BCELoss().to(device)

    def build_generator(self):
        class Generator(nn.Module):
            def __init__(self, latent_dim, img_shape):
                super(Generator, self).__init__()
                self.img_shape = img_shape
                self.model = nn.Sequential(
                    nn.Linear(latent_dim, 256),
                    nn.LeakyReLU(0.2, inplace=True),
                    nn.BatchNorm1d(256, 0.8),
                    nn.Linear(256, 512),
                    nn.LeakyReLU(0.2, inplace=True),
                    nn.BatchNorm1d(512, 0.8),
                    nn.Linear(512, 1024),
                    nn.LeakyReLU(0.2, inplace=True),
                    nn.BatchNorm1d(1024, 0.8),
                    nn.Linear(1024, int(np.prod(self.img_shape))),
                    nn.Tanh()
                )

            def forward(self, z):
                img = self.model(z)
                img = img.view(img.size(0), *self.img_shape)
                return img

        return Generator(self.latent_dim, self.img_shape)

    def build_discriminator(self):
        class Discriminator(nn.Module):
            def __init__(self, img_shape):
                super(Discriminator, self).__init__()
                self.model = nn.Sequential(
                    nn.Linear(int(np.prod(img_shape)), 512),
                    nn.LeakyReLU(0.2, inplace=True),
                    nn.Linear(512, 256),
                    nn.LeakyReLU(0.2, inplace=True),
                    nn.Linear(256, 1),
                    nn.Sigmoid()
                )

            def forward(self, img):
                img_flat = img.view(img.size(0), -1)
                validity = self.model(img_flat)
                return validity

        return Discriminator(self.img_shape)
    def train(self, epochs, batch_size=1280, sample_interval=50):
        # Configure data loader
        train_G_interval=2
        os.makedirs("images", exist_ok=True)
        dataloader = torch.utils.data.DataLoader(
            datasets.MNIST(
                "../data/mnist",
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize([0.5], [0.5])
                ]),
            ),
            batch_size=batch_size,
            shuffle=True,
        )

        for epoch in range(epochs):
            for i, (imgs, _) in enumerate(dataloader):
                # Adversarial ground truths
                valid = torch.ones((imgs.size(0), 1), requires_grad=False).to(device)
                fake = torch.zeros((imgs.size(0), 1), requires_grad=False).to(device)

                # Configure input
                real_imgs = imgs.to(device)
                if(i%train_G_interval==0):
                    # -----------------
                    #  Train Generator
                    # -----------------
                    self.generator_optimizer.zero_grad()

                    # Sample noise as generator input
                    z = torch.randn((imgs.size(0), self.latent_dim)).to(device)

                    # Generate a batch of images
                    gen_imgs = self.generator(z)

                    # Loss measures generator's ability to fool the discriminator
                    g_loss = self.adversarial_loss(self.discriminator(gen_imgs), valid)

                    g_loss.backward()
                    self.generator_optimizer.step()


                # ---------------------
                #  Train Discriminator
                # ---------------------
                self.discriminator_optimizer.zero_grad()

                # Loss for real images
                real_loss = self.adversarial_loss(self.discriminator(real_imgs), valid)
                # Loss for fake images
                fake_loss = self.adversarial_loss(self.discriminator(gen_imgs.detach()), fake)
                # Total discriminator loss
                d_loss = (real_loss + fake_loss) / 2

                d_loss.backward()
                self.discriminator_optimizer.step()

                # Print the progress
                print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")

                # If at save interval, save generated samples
                if epoch % sample_interval == 0 and i == 0:
                    self.sample_images(epoch)
                    print(f"Epoch {epoch} completed")

    def sample_images(self, epoch):
        r, c = 5, 5
        z = torch.randn((r * c, self.latent_dim)).to(device)
        gen_imgs = self.generator(z)

        gen_imgs = 0.5 * gen_imgs + 0.5  # Rescale images 0 - 1

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt, 0].detach().cpu().numpy(), cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig(f"images/{epoch}.png")
        plt.close()


cuda 



In [6]:
import torch
print(torch.cuda.is_available())

True


# main

In [8]:
if __name__ == '__main__':
    gan = GAN()
    gan.train(epochs=30000, batch_size=1280, sample_interval=200)


[Epoch 0/30000] [Batch 0/47] [D loss: 0.7136646509170532] [G loss: 0.7057594656944275]
Epoch 0 completed
[Epoch 0/30000] [Batch 1/47] [D loss: 0.6180360317230225] [G loss: 0.7057594656944275]
[Epoch 0/30000] [Batch 2/47] [D loss: 0.5433326363563538] [G loss: 0.7045255899429321]
[Epoch 0/30000] [Batch 3/47] [D loss: 0.48193255066871643] [G loss: 0.7045255899429321]
[Epoch 0/30000] [Batch 4/47] [D loss: 0.4366365671157837] [G loss: 0.7046270966529846]
[Epoch 0/30000] [Batch 5/47] [D loss: 0.40259745717048645] [G loss: 0.7046270966529846]
[Epoch 0/30000] [Batch 6/47] [D loss: 0.3810967803001404] [G loss: 0.7055912017822266]
[Epoch 0/30000] [Batch 7/47] [D loss: 0.36660438776016235] [G loss: 0.7055912017822266]
[Epoch 0/30000] [Batch 8/47] [D loss: 0.35885435342788696] [G loss: 0.7076196670532227]
[Epoch 0/30000] [Batch 9/47] [D loss: 0.35259121656417847] [G loss: 0.7076196670532227]
[Epoch 0/30000] [Batch 10/47] [D loss: 0.34896114468574524] [G loss: 0.7103464007377625]
[Epoch 0/30000] [B

KeyboardInterrupt: 