In [3]:
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt

# Hyperparameters
class opt:
    """Training configuration, used as a plain attribute namespace (never instantiated)."""
    img_size = 28        # side length images are resized to
    latent_dim = 100     # size of the generator's noise input z
    batch_size = 64
    lr = 0.00005         # RMSprop learning rate for both networks
    n_epochs = 800
    n_critic = 5         # Train Discriminator n_critic times per Generator iteration
    clip_value = 0.01    # Clipping value for Discriminator weights
    seed = 0             # RNG seed so runs are reproducible

# Shape of one image tensor: (channels, height, width); MNIST is single-channel.
img_shape = (1, opt.img_size, opt.img_size)

# Seed every RNG the notebook draws from: numpy samples the noise z, torch
# initialises the network weights and shuffles the DataLoader.
np.random.seed(opt.seed)
torch.manual_seed(opt.seed)

# Generator model
class Generator(nn.Module):
    """MLP generator mapping a batch of noise vectors to images in [-1, 1].

    Args:
        latent_dim: size of the input noise vector; defaults to ``opt.latent_dim``.
        out_shape: shape of one generated image (C, H, W); defaults to ``img_shape``.

    Defaults are resolved at construction time, so ``Generator()`` builds exactly
    the same network as before while the class can be reused with other sizes.
    """

    def __init__(self, latent_dim=None, out_shape=None):
        super(Generator, self).__init__()
        if latent_dim is None:
            latent_dim = opt.latent_dim
        if out_shape is None:
            out_shape = img_shape
        self.out_shape = tuple(out_shape)

        def block(in_feat, out_feat, normalize=True):
            """Return [Linear, (BatchNorm1d), LeakyReLU] as a list of layers."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NOTE(review): BatchNorm1d's second positional parameter is `eps`,
                # so this sets eps=0.8 (PyTorch's default is 1e-5). If momentum=0.8
                # was intended, change it here -- doing so alters training behaviour.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(self.out_shape))),
            # Tanh output range [-1, 1] matches real images after Normalize([0.5], [0.5]).
            nn.Tanh(),
        )

    def forward(self, z):
        """Map noise of shape (B, latent_dim) to images of shape (B, *out_shape)."""
        img = self.model(z)
        return img.view(img.shape[0], *self.out_shape)

# Discriminator model
class Discriminator(nn.Module):
    """MLP critic returning one unbounded real-valued score per image (no sigmoid).

    Args:
        in_shape: shape of one input image (C, H, W); defaults to ``img_shape``.
            The input width is the product of this shape, the same way the
            Generator derives its output width, rather than assuming a
            single-channel ``img_size ** 2`` image.
    """

    def __init__(self, in_shape=None):
        super(Discriminator, self).__init__()
        if in_shape is None:
            in_shape = img_shape
        in_features = int(np.prod(in_shape))
        self.model = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            # Raw linear score: the training loop takes mean differences of these.
            nn.Linear(256, 1),
        )

    def forward(self, img):
        """Flatten a batch of images and return scores of shape (B, 1)."""
        img_flat = img.view(img.shape[0], -1)
        return self.model(img_flat)

# Build both networks (generator first), each with its own RMSprop optimizer at opt.lr.
generator, discriminator = Generator(), Discriminator()
generator_optimizer, discriminator_optimizer = (
    optim.RMSprop(net.parameters(), lr=opt.lr) for net in (generator, discriminator)
)

# MNIST training split: resize to opt.img_size, convert to a tensor, then
# Normalize([0.5], [0.5]) maps [0, 1] pixels to [-1, 1] (the Generator's Tanh range).
os.makedirs('mnist_data', exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        'mnist_data',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.Resize(opt.img_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Output folders the training loop writes checkpoints and sample grids into.
for _out_dir in ("checkpoints", "generated_images"):
    os.makedirs(_out_dir, exist_ok=True)

# Per-epoch loss log; mode "w" truncates the log from any previous run.
log_file = open("training_log.txt", "w")

# Use the GPU when present: move both networks there, and pick the matching
# legacy tensor constructor `Tensor` that the training loop uses for its inputs.
cuda = torch.cuda.is_available()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
if cuda:
    for _net in (generator, discriminator):
        _net.cuda()

# Training
CHECKPOINT_EVERY = 25  # epochs between periodic checkpoint saves

start_time = time.time()
# One entry per epoch: the loss from the most recent critic / generator step of
# that epoch (not an epoch average).
G_losses, D_losses = [], []

# Fixed noise for the per-epoch 10x10 sample grid, so every saved PNG shows the
# same latent points and progress is comparable across epochs.
fixed_z = Tensor(np.random.normal(0, 1, (100, opt.latent_dim)))

try:
    for epoch in range(opt.n_epochs):
        for i, (imgs, _) in enumerate(dataloader):
            real_imgs = imgs.type(Tensor)

            # Train Discriminator (critic): minimise mean D(fake) - mean D(real).
            discriminator_optimizer.zero_grad()
            z = Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))
            fake_imgs = generator(z).detach()  # keep the generator out of the critic's graph
            d_loss = torch.mean(discriminator(fake_imgs)) - torch.mean(discriminator(real_imgs))
            d_loss.backward()
            discriminator_optimizer.step()

            # Clip critic weights to [-clip_value, clip_value] after every update.
            with torch.no_grad():
                for p in discriminator.parameters():
                    p.clamp_(-opt.clip_value, opt.clip_value)

            # Train Generator every n_critic iterations, reusing this batch's z.
            if i % opt.n_critic == 0:
                generator_optimizer.zero_grad()
                g_loss = -torch.mean(discriminator(generator(z)))
                g_loss.backward()
                generator_optimizer.step()

        # Read each scalar once and reuse it for the history, console and log file.
        d_value, g_value = d_loss.item(), g_loss.item()
        G_losses.append(g_value)
        D_losses.append(d_value)

        print(f"[Epoch {epoch}/{opt.n_epochs}] [D loss: {d_value}] [G loss: {g_value}]")
        log_file.write(f"Epoch {epoch}/{opt.n_epochs}, D loss: {d_value}, G loss: {g_value}, Time elapsed: {time.time() - start_time}s\n")
        log_file.flush()

        # Save periodically AND after the final epoch; otherwise the trained model
        # is lost (e.g. with n_epochs=800 the last periodic save is epoch 775).
        if epoch % CHECKPOINT_EVERY == 0 or epoch == opt.n_epochs - 1:
            torch.save(generator.state_dict(), f"checkpoints/generator_epoch_{epoch}.pth")
            torch.save(discriminator.state_dict(), f"checkpoints/discriminator_epoch_{epoch}.pth")

        # Sample grid; no_grad avoids building an autograd graph for images that are only saved.
        with torch.no_grad():
            gen_imgs = generator(fixed_z)
        save_image(gen_imgs, f"generated_images/{epoch}.png", nrow=10, normalize=True)
finally:
    # Close the log even when training is interrupted (e.g. KeyboardInterrupt).
    log_file.close()

# Wall-clock duration, measured from start_time set just before the training loop.
end_time = time.time()
total_training_time = end_time - start_time
print(f"Training finished. Total time: {total_training_time:.2f}s")

# Loss curves (one point per epoch for each network), written to disk, then shown.
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Generator and Discriminator Loss During Training")
for _series, _label in ((G_losses, "G"), (D_losses, "D")):
    ax.plot(_series, label=_label)
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.legend()
fig.savefig("training_loss_plot.png")
plt.show()


[Epoch 0/800] [D loss: -0.7494516372680664] [G loss: -15.359916687011719]
[Epoch 1/800] [D loss: -0.054233551025390625] [G loss: -13.150790214538574]
[Epoch 2/800] [D loss: -0.001461029052734375] [G loss: -8.527275085449219]
[Epoch 3/800] [D loss: -0.028022101148962975] [G loss: 0.051956500858068466]
[Epoch 4/800] [D loss: -0.46532005071640015] [G loss: -0.12578758597373962]
[Epoch 5/800] [D loss: -0.4571816921234131] [G loss: -1.3220010995864868]
[Epoch 6/800] [D loss: -0.5442802906036377] [G loss: -2.91330885887146]
[Epoch 7/800] [D loss: -0.2949094772338867] [G loss: -1.89393150806427]
[Epoch 8/800] [D loss: -0.31762146949768066] [G loss: -2.615656852722168]
[Epoch 9/800] [D loss: -0.7063255310058594] [G loss: -0.7459738254547119]
[Epoch 10/800] [D loss: -0.2276151180267334] [G loss: -2.3458900451660156]
[Epoch 11/800] [D loss: -0.4143112897872925] [G loss: -1.4876806735992432]
[Epoch 12/800] [D loss: -0.41095876693725586] [G loss: -1.5377445220947266]
[Epoch 13/800] [D loss: -0.385

KeyboardInterrupt: 

In [4]:
# Colab only: mount Google Drive at /content/drive so the next cell can copy
# training outputs into MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


mkdir: cannot create directory ‘/content/Outputs’: File exists


In [None]:
# Copy training outputs to Google Drive.
# NOTE(review): the training cell writes checkpoints/, generated_images/,
# training_log.txt and training_loss_plot.png relative to the working directory,
# not to /content/outputs. The mkdir error shown above mentions '/content/Outputs'
# (different case), which suggests a since-deleted cell. Confirm the source path
# exists on a fresh kernel. The destination name also says 640 epochs while
# opt.n_epochs is 800.
!cp -r /content/outputs /content/drive/MyDrive/WGAN-MNIST-DIGIT-640-epochs/