<font size='6'>**WGAN-GP (Wasserstein GAN with Gradient Penalty) Training**</font>

In [None]:
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as data
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision.io import read_image
from PIL import ImageFile
%matplotlib inline
ImageFile.LOAD_TRUNCATED_IMAGES = True


class Discriminator(nn.Module):
    """DCGAN-style convolutional critic for WGAN-GP.

    For 32x32 inputs the spatial size shrinks 32 -> 16 -> 8 -> 4 -> 1, so a batch
    of shape (B, C, 32, 32) is mapped to raw, unbounded scores of shape
    (B, 1, 1, 1). There is no final sigmoid: a Wasserstein critic outputs
    scores, not probabilities.

    Args:
        ngpu: number of GPUs requested; stored on the module, not used by it.
        in_channels: number of input image channels. Defaults to the
            module-level ``nc`` hyperparameter (looked up at construction time).
        features: base number of feature maps. Defaults to the module-level
            ``ndf`` hyperparameter (looked up at construction time).
    """

    def __init__(self, ngpu, in_channels=None, features=None):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        # Fall back to the notebook globals so `Discriminator(ngpu)` keeps working.
        if in_channels is None:
            in_channels = nc
        if features is None:
            features = ndf
        # Layer order/indices are kept fixed so existing state_dicts still load.
        self.main = nn.Sequential(
            # (B, C, 32, 32) -> (B, f, 16, 16)
            nn.Conv2d(in_channels, features, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.1, inplace=True),
            # -> (B, 2f, 8, 8)
            nn.Conv2d(features, features * 2, 4, 2, 1, bias=False),
            # Per-sample normalisation rather than BatchNorm; presumably chosen because
            # the WGAN-GP paper advises against batch-coupled norms in the critic.
            nn.InstanceNorm2d(features * 2, affine=True),
            nn.LeakyReLU(0.1, inplace=True),
            # NOTE(review): dropout makes the critic stochastic (also inside the
            # gradient penalty); kept to preserve layer indices.
            nn.Dropout(0.25),
            # -> (B, 8f, 4, 4)
            nn.Conv2d(features * 2, features * 8, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(features * 8, affine=True),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Dropout(0.25),
            # -> (B, 1, 1, 1): one score per image
            nn.Conv2d(features * 8, 1, 4, 1, 0, bias=False),
        )

    def forward(self, input):
        """Return critic scores for a batch of images."""
        return self.main(input)
class Generator(nn.Module):
    """DCGAN-style generator mapping latent noise to images in [-1, 1].

    Noise of shape (B, latent_dim, 1, 1) is upsampled by four transposed
    convolutions (1 -> 4 -> 8 -> 16 -> 32 spatially) to images of shape
    (B, channels, 32, 32). The final Tanh matches the [-1, 1] range produced by
    the dataset's ``Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`` transform.

    Args:
        ngpu: number of GPUs requested; stored on the module, not used by it.
        latent_dim: length of the noise vector. Defaults to the module-level
            ``nz`` hyperparameter (looked up at construction time).
        features: base number of feature maps. Defaults to the module-level
            ``ngf`` hyperparameter.
        channels: number of output image channels. Defaults to the
            module-level ``nc`` hyperparameter.
    """

    def __init__(self, ngpu, latent_dim=None, features=None, channels=None):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        # Fall back to the notebook globals so `Generator(ngpu)` keeps working.
        if latent_dim is None:
            latent_dim = nz
        if features is None:
            features = ngf
        if channels is None:
            channels = nc
        # Layer order/indices are kept fixed so existing state_dicts still load.
        self.main = nn.Sequential(
            # (B, z, 1, 1) -> (B, 4f, 4, 4)
            nn.ConvTranspose2d(latent_dim, features * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(features * 4),
            nn.ReLU(True),
            # -> (B, 2f, 8, 8)
            nn.ConvTranspose2d(features * 4, features * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features * 2),
            nn.ReLU(True),
            nn.Dropout(0.25),
            # -> (B, f, 16, 16)
            nn.ConvTranspose2d(features * 2, features, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features),
            nn.ReLU(True),
            # -> (B, channels, 32, 32), squashed to [-1, 1]
            nn.ConvTranspose2d(features, channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        """Return generated images for a batch of latent vectors."""
        return self.main(input)
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
# Network widths and data shape
ndf = 64      # base feature maps of the critic (Discriminator)
ngf = 64      # base feature maps of the Generator
nc = 3        # image channels
nz = 100      # latent noise vector length
imsize = 32   # images are resized + center-cropped to imsize x imsize

# Optimisation
epochs = 5
lr1 = 0.0002  # generator learning rate (optm_gen)
lr2 = 0.0002  # critic learning rate (optm_disc)
batch_size = 64
lgp = 10      # weight of the gradient-penalty term in the critic loss

# Hardware
ngpu = 1      # 0 forces CPU even when CUDA is available
wr = 2        # DataLoader worker processes
# ---------------------------------------------------------------------------
device = torch.device(
    "cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu"
)

def compute_gradient_penalty(discr, real, fake, device=None):
    """Return the WGAN-GP gradient penalty for the critic ``discr``.

    Points are sampled uniformly on straight lines between paired real and
    fake images, and the penalty is
    ``mean((||grad_x discr(x_hat)||_2 - 1) ** 2)`` over the batch.

    Args:
        discr: critic callable returning per-sample scores, e.g. of shape
            (B,) or (B, 1, 1, 1).
        real: batch of real images, shape (B, ...).
        fake: batch of generated images with the same shape as ``real``. It may
            still be attached to the generator graph; it is detached here so
            the penalty only produces gradients for the critic.
        device: device for the random mixing coefficients; defaults to
            ``real.device``.

    Returns:
        Scalar tensor, differentiable w.r.t. the critic's parameters
        (built with ``create_graph=True``).
    """
    if device is None:
        device = real.device
    batch = real.size(0)
    # One coefficient per sample drawn from U[0, 1] and broadcast over every
    # other dimension, so each x_hat lies on the real-fake segment.
    alpha_shape = (batch,) + (1,) * (real.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device, dtype=real.dtype)
    interpolated_images = (
        alpha * real.detach() + (1 - alpha) * fake.detach()
    ).requires_grad_(True)

    # Calculate critic scores on the interpolates
    mixed_scores = discr(interpolated_images)

    # Gradient of the scores w.r.t. the interpolated images; create_graph so
    # the penalty itself can be backpropagated into the critic's weights.
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
    )[0]
    gradient = gradient.reshape(batch, -1)
    gradient_norm = gradient.norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)

# Training images: ImageFolder layout under data_path (one sub-folder per class).
# Each image is resized, center-cropped to imsize x imsize, converted to a
# tensor in [0, 1] and normalised to [-1, 1], the Generator's Tanh range.
data_path = "/path/dataset_or_part_of_dataset"  # TODO: point at the real dataset
datas = data.ImageFolder(
    root=data_path,
    transform=transforms.Compose([
        transforms.Resize(imsize),
        transforms.CenterCrop(imsize),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps [0, 1] -> [-1, 1] per channel
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)
# `device` is defined once, together with the hyperparameters.
def weights_init(m):
    """Initialise one module in place, DCGAN style; intended for ``model.apply``.

    Modules whose class name contains "Conv" (e.g. Conv2d, ConvTranspose2d)
    get weights drawn from N(0, 0.02). Modules whose class name contains
    "BatchNorm" get weights from N(1, 0.02) and zero biases. Every other module
    (InstanceNorm2d, activations, Dropout, ...) is left untouched, and conv
    biases are never modified.
    """
    layer_name = type(m).__name__
    if "Conv" in layer_name:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if "BatchNorm" in layer_name:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, 0)
# TensorBoard writers for the real and generated image grids logged during
# training (run directories logs/real5 and logs/fake5).
writer_real = SummaryWriter(f"logs/real5")
writer_fake = SummaryWriter(f"logs/fake5")
# Build both networks on `device`, then apply weights_init (N(0, 0.02) conv
# weights, N(1, 0.02)/0 BatchNorm weight/bias). These draw from the global
# torch RNG, so keep this order if reproducibility with a fixed seed matters.
disc= Discriminator(ngpu).to(device)
gen= Generator(ngpu).to(device)
disc.apply(weights_init)
gen.apply(weights_init)
# Separate Adam optimisers: lr1 for the generator, lr2 for the critic.
optm_gen = torch.optim.Adam(gen.parameters(), lr=lr1, betas=(0.5,0.99))
optm_disc = torch.optim.Adam(disc.parameters(), lr=lr2, betas=(0.5, 0.99))
# NOTE(review): real_l / fake_l are not referenced in this notebook section
# (a WGAN critic uses no real/fake labels); possibly leftovers.
real_l= 1
fake_l= 0
# Histories filled by the training loop: per-iteration losses and image
# grids generated from fixed_noise.
gen_loss= []
disc_loss= []
img_list= []
step= 0  # global iteration counter, used as TensorBoard's global_step
eval_interval= 10  # NOTE(review): logging interval in iterations; confirm the loop reads it
# 64 latent vectors reused throughout training so samples are comparable over time.
fixed_noise = torch.randn(64, nz, 1, 1, device=device)
# Shuffled mini-batches of `datas`, loaded by `wr` worker processes.
dataloader = torch.utils.data.DataLoader(datas, batch_size=batch_size,shuffle=True, num_workers=wr)
disciter=5   # critic (discriminator) updates per generator update
# Training loop (WGAN-GP): for every real batch, update the critic `disciter`
# times, then update the generator once.
for epoch in range(epochs):
    for i, real_data in enumerate(dataloader, 0):
        real = real_data[0].to(device)
        # Separate name so the global `batch_size` hyperparameter is not
        # overwritten; the last batch of an epoch may be smaller.
        cur_batch = real.size(0)

        # ---- Train critic (Discriminator) ----
        # NOTE: the same real batch is reused for all `disciter` critic steps.
        for _ in range(disciter):
            noise = torch.randn(cur_batch, nz, 1, 1, device=device)
            fake = gen(noise)
            disc_real = disc(real).reshape(-1)
            # Detached: the Wasserstein term must not backpropagate into the generator.
            disc_fake = disc(fake.detach()).reshape(-1)
            gradpen = compute_gradient_penalty(disc, real, fake)
            loss_disc = -(torch.mean(disc_real) - torch.mean(disc_fake)) + lgp * gradpen
            disc.zero_grad()
            loss_disc.backward()
            optm_disc.step()

        # ---- Train generator ----
        # Fresh samples with their own graph, so no graph has to be retained
        # across the critic steps.
        noise = torch.randn(cur_batch, nz, 1, 1, device=device)
        gen_fake = disc(gen(noise)).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        loss_gen.backward()
        optm_gen.step()

        # Record losses (loss_disc is from the last critic step).
        gen_loss.append(loss_gen.item())
        disc_loss.append(loss_disc.item())

        last_epoch = epoch == epochs - 1
        log_now = step % eval_interval == 0
        if last_epoch or log_now:
            # One pass over the fixed latent vectors serves both outputs below.
            with torch.no_grad():
                fake = gen(fixed_noise).detach().cpu()
        if last_epoch:
            # NOTE(review): one grid per iteration of the final epoch; this can
            # use a lot of host memory on large datasets.
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
        if log_now:
            print(f"Epoch [{epoch}/{epochs}], Batch [{i}/{len(dataloader)}], "
                  f"Discriminator Loss: {loss_disc.item():.4f}, "
                  f"Generator Loss: {loss_gen.item():.4f}")
            # take out (up to) 32 examples
            img_grid_real = torchvision.utils.make_grid(real_data[0][:32], normalize=True)
            img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
            writer_real.add_image("Real", img_grid_real, global_step=step)
            writer_fake.add_image("Fake", img_grid_fake, global_step=step)
        step += 1

# Make sure buffered TensorBoard events reach disk.
writer_real.flush()
writer_fake.flush()

<font size='6'>**Visualize Loss**</font>

In [None]:
# Per-iteration loss curves (one point per generator update). The critic curve
# shows the loss of the last critic step in each iteration. The title is built
# from the actual run settings instead of hard-coded numbers.
plt.title(f"WGAN-GP Generator and Discriminator Loss During Training "
          f"({epochs} epochs, {len(datas)} samples)")
plt.plot(gen_loss, label="G")
plt.plot(disc_loss, label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()