In [None]:
import intel_extension_for_pytorch as ipex
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torch import optim
import torch.nn as nn
import torch
from tqdm import tqdm
from typing import Tuple
import imageio
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

### Dataset

In [None]:
BATCH_SIZE = 64
# ToTensor gives [0, 1]; Normalize(0.5, 0.5) maps each channel to [-1, 1],
# the same range as the generator's tanh output.
transformer = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([.5, .5, .5], [.5, .5, .5])])

trainset = CIFAR10(root='../data', train=True,  transform=transformer, download=True)
testset  = CIFAR10(root='../data', train=False, transform=transformer, download=True)

trainloader = DataLoader(trainset, BATCH_SIZE, shuffle=True)
# Evaluation does not need shuffling; a fixed order keeps test losses comparable
# between epochs.
testloader  = DataLoader(testset, BATCH_SIZE, shuffle=False)

# class id -> class name, used for plot titles
classes = dict(enumerate(trainset.classes))
classes

In [None]:
def show_imgs(imgs: torch.Tensor, labels: torch.Tensor):
    """Display a batch of images in a 4-column grid, each titled with its class name.

    Args:
        imgs: image batch laid out as (N, C, H, W) with values in [-1, 1]
            (the range of the dataset transform and of the generator's tanh).
        labels: N integer class ids, looked up in the global ``classes`` dict.
    """
    # (N, C, H, W) -> (N, H, W, C) for imshow, then [-1, 1] -> [0, 1]
    imgs = imgs.cpu().detach().clone().permute(0, 2, 3, 1)
    imgs = ((imgs + 1) / 2).clamp(0, 1)
    labels = labels.cpu().detach().clone()
    n = len(labels)
    cols = 4
    # ceil(n / cols): `n // cols` dropped a partial last row (and gave 0 rows for n < 4)
    rows = max(1, (n + cols - 1) // cols)
    plt.figure(figsize=(8, 2 * rows))
    for i in range(n):
        plt.subplot(rows, cols, i + 1)
        # index i, not i - 1: the old offset rotated the grid by one image
        plt.imshow(imgs[i])
        plt.title(classes[labels[i].item()])
        plt.axis('off')
    plt.show()

In [None]:
# Preview the first few images of one shuffled training batch.
# `imgs` is reused later (its shape is passed to the Discriminator), so this
# cell has to run before "Model Initialization".
imgs, labels = next(iter(trainloader))
first_eight = slice(0, 8)
show_imgs(imgs[first_eight], labels[first_eight])

### Model

In [None]:
class Generator(nn.Module):
    """Class-conditional generator producing 3x32x32 images in [-1, 1].

    The learned embedding of each label (size ``latent_dim``) is perturbed with
    standard-normal noise and decoded by four "bilinear upsample -> 3x3 conv"
    stages (1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32). The final ``Tanh`` matches the
    [-1, 1] range that ``Normalize(0.5, 0.5)`` gives the real images.

    Args:
        latent_dim: size of the class embedding and of the added noise vector.
        n_classes: number of conditioning classes.
        bn_momentum: BatchNorm running-stat momentum in PyTorch's convention,
            i.e. the weight of the *new* batch statistic. The previous value 0.9
            looks like a Keras setting (where 0.9 weights the *old* running value);
            the PyTorch equivalent is 0.1, which is also the library default.
    """

    def __init__(self, latent_dim: int, n_classes: int, bn_momentum: float = 0.1):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.n_classes = n_classes
        self.embeddings = nn.Embedding(num_embeddings=n_classes,
                                       embedding_dim=latent_dim)

        def up_conv(size, in_channels, out_channels):
            # nn.Upsample(mode='bilinear', align_corners=True) is the documented
            # equivalent of the deprecated nn.UpsamplingBilinear2d. Exactly one
            # module per slot keeps Sequential indices (state_dict keys) unchanged.
            return [nn.Upsample(size=size, mode='bilinear', align_corners=True),
                    nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=3,
                              stride=1, padding='same',
                              bias=False)]

        self.model = nn.Sequential(
            nn.Unflatten(dim=-1, unflattened_size=(latent_dim, 1, 1)),
            # block 1: 1x1 -> 4x4
            *up_conv((4, 4), latent_dim, 256),
            nn.BatchNorm2d(num_features=256, momentum=bn_momentum),
            nn.LeakyReLU(negative_slope=0.15),
            # block 2: 4x4 -> 8x8
            *up_conv((8, 8), 256, 128),
            nn.BatchNorm2d(num_features=128, momentum=bn_momentum),
            nn.LeakyReLU(negative_slope=0.15),
            # block 3: 8x8 -> 16x16
            *up_conv((16, 16), 128, 64),
            nn.BatchNorm2d(num_features=64, momentum=bn_momentum),
            nn.LeakyReLU(negative_slope=0.15),
            # output: 16x16 -> 32x32, 3 channels (upsample + conv, not a transposed conv)
            *up_conv((32, 32), 64, 3),
            nn.Tanh())

    def forward(self, labels):
        """Generate one image per label; fresh noise is drawn on every call."""
        cls_embds = self.embeddings(labels)
        eps = torch.randn_like(cls_embds)
        return self.model(cls_embds + eps)

In [None]:
class Discriminator(nn.Module):
    """Class-conditional WGAN critic.

    Each label's learned embedding (C*H*W values) is reshaped to the image shape
    and concatenated to the image as C extra channels. Three 3x3 stride-2 convs
    are followed by a conv whose kernel covers the remaining feature map, so the
    output is one unbounded score per image with shape (N, 1).

    Args:
        n_classes: number of conditioning classes.
        img_shape: (C, H, W) of the input images. Accepts a tuple/list/torch.Size
            or a 1-D tensor (the notebook passes ``torch.tensor(img.shape)``).
        bn_momentum: BatchNorm running-stat momentum in PyTorch's convention
            (weight of the new batch statistic). 0.1 is the PyTorch default and
            corresponds to the Keras-style 0.9 used previously.
    """

    def __init__(self, n_classes: int, img_shape: Tuple, bn_momentum: float = 0.1):
        super(Discriminator, self).__init__()
        self.n_classes = n_classes
        # int() works for plain ints and for 0-d tensor elements alike
        channels, height, width = (int(d) for d in img_shape)
        self.img_shape = (channels, height, width)
        self.embeddings = nn.Embedding(num_embeddings=n_classes,
                                       embedding_dim=channels * height * width)

        def halved(size):
            # a 3x3, stride-2, padding-1 conv maps a spatial size n to ceil(n / 2)
            return (size + 1) // 2

        # spatial size left after the three stride-2 convs (4x4 for 32x32 inputs)
        head_kernel = (halved(halved(halved(height))), halved(halved(halved(width))))

        # NOTE(review): the WGAN-GP paper advises against BatchNorm in the critic
        # (it couples samples in a batch while the penalty is per-sample) and uses
        # LayerNorm instead. Kept as-is so existing checkpoints still load.
        self.model = nn.Sequential(
            # image channels + embedding channels
            nn.Conv2d(in_channels=2 * channels,
                      out_channels=64,
                      kernel_size=3,
                      stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.15),
            nn.Dropout2d(p=0.1),

            nn.Conv2d(in_channels=64,
                      out_channels=128,
                      kernel_size=3,
                      stride=2, padding=1),
            nn.BatchNorm2d(num_features=128, momentum=bn_momentum),
            nn.LeakyReLU(negative_slope=0.15),
            nn.Dropout2d(p=0.1),

            nn.Conv2d(in_channels=128,
                      out_channels=256,
                      kernel_size=3,
                      stride=2, padding=1),
            nn.BatchNorm2d(num_features=256, momentum=bn_momentum),
            nn.LeakyReLU(negative_slope=0.15),
            nn.Dropout2d(p=0.1),

            # collapse the remaining feature map to a single score per image
            nn.Conv2d(in_channels=256,
                      out_channels=1,
                      kernel_size=head_kernel,
                      stride=1, padding=0),
            nn.Flatten())

    def forward(self, imgs, labels):
        """Score ``imgs`` (N, C, H, W) conditioned on integer ``labels`` (N,); returns (N, 1)."""
        embds = self.embeddings(labels).view(imgs.shape)
        concat = torch.cat((imgs, embds), dim=1)
        return self.model(concat)

In [None]:
class CWGAN_GP(nn.Module):
    """Conditional WGAN-GP: holds generator, critic, their optimizers and the GP weight.

    Args:
        generator: module mapping integer labels to images.
        discriminator: critic module called as ``discriminator(imgs, labels)``.
        LAMBDA_GP: weight of the gradient penalty in the critic loss.
    """

    def __init__(self, generator, discriminator, LAMBDA_GP: float=10):
        super(CWGAN_GP, self).__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.LAMBDA_GP = LAMBDA_GP

    def forward(self, labels):
        """Sample one image per label without recording autograd history."""
        with torch.inference_mode():
            out = self.generator(labels)
        return out.detach()

    # NOTE: recent PyTorch versions define nn.Module.compile (a torch.compile
    # wrapper); this Keras-style method intentionally overrides it to attach optimizers.
    def compile(self, g_optimizer, d_optimizer):
        """Attach the generator and critic optimizers used by the training loop."""
        self.g_optimizer = g_optimizer
        self.d_optimizer = d_optimizer

    def gradient_penalty(self, real_imgs, fake_imgs, labels):
        """Mean of (||grad_x D(x_hat, labels)||_2 - 1)^2 over random real/fake interpolates.

        Follows WGAN-GP (Gulrajani et al., 2017, Algorithm 1): one coefficient
        alpha ~ U[0, 1] per sample, x_hat = real + alpha * (fake - real).
        """
        batch_size = real_imgs.shape[0]
        # One alpha per sample, broadcast over all remaining dims. The previous
        # torch.randn draw (normal, per sample *and* channel) produced points off
        # the real-fake segment.
        alpha_shape = (batch_size,) + (1,) * (real_imgs.dim() - 1)
        alpha = torch.rand(alpha_shape, device=real_imgs.device, dtype=real_imgs.dtype)
        # A detached leaf that requires grad: the penalty only needs gradients
        # w.r.t. the interpolate, and this works whether or not fake_imgs
        # carries autograd history.
        interpolation = (real_imgs + alpha * (fake_imgs - real_imgs)).detach().requires_grad_(True)
        preds = self.discriminator(interpolation, labels)
        gradients = torch.autograd.grad(
            outputs=preds,
            inputs=interpolation,
            grad_outputs=torch.ones_like(preds),
            create_graph=True,   # the penalty itself is backpropagated into the critic
            retain_graph=True)[0]

        grad_norm = gradients.flatten(start_dim=1).norm(2, dim=1)
        return ((grad_norm - 1) ** 2).mean()

### Model Initialization

In [None]:
EPOCHS = 300
LATENT_DIM = 200
DG_TRAIN_RATIO = 5   # critic updates per generator update
LAMBDA_GP = 10       # weight of the gradient penalty in the critic loss
SEED = 42

torch.manual_seed(SEED)  # reproducible weight init and noise sampling

device = 'xpu' if ipex.xpu.is_available() else 'cpu'

generator = Generator(LATENT_DIM, len(classes))
# `imgs` is the preview batch from the Dataset section; only its shape is used
discriminator = Discriminator(len(classes), torch.tensor(imgs[0].shape))

# NOTE(review): these are PyTorch's Adam defaults; the WGAN-GP paper (Algorithm 1)
# uses lr=1e-4, betas=(0.0, 0.9).
g_optimizer = optim.Adam(generator.parameters())
d_optimizer = optim.Adam(discriminator.parameters())

model = CWGAN_GP(generator, discriminator, LAMBDA_GP)
model.compile(g_optimizer, d_optimizer)
# Only query the XPU name when one was found; on the CPU fallback the call may fail.
print(ipex.xpu.get_device_name() if device == 'xpu' else device)

### Training Implementation Scheme

![Training implementation scheme](attachment:image.png)

In [None]:
def train_step(model: nn.Module, trainloader: DataLoader, 
               device: torch.device, DG_TRAIN_RATIO:int=3):
    """Run one epoch: update the critic every batch, the generator every DG_TRAIN_RATIO batches.

    Args:
        model: a CWGAN_GP whose ``compile`` set ``g_optimizer``/``d_optimizer``;
            it is updated in place and must already live on ``device``.
        trainloader: yields ``(real_imgs, labels)`` batches.
        device: device the batches are moved to.
        DG_TRAIN_RATIO: number of critic updates per generator update.

    Returns:
        dict of epoch means: ``d_loss``, ``d_wass_loss`` and ``d_gp`` averaged over
        all batches; ``g_loss`` averaged over the generator updates actually taken.
    """
    model.train()
    if not getattr(model, '_ipex_optimized', False):
        # Optimize each sub-network with its own optimizer, in place, once per model.
        # The old per-epoch `model, opt = ipex.optimize(model, optimizer=opt)` relied on
        # IPEX's default inplace=False, which returns copies; rebinding the local name
        # trained those copies, while the caller's model (used for testing, GIF frames
        # and checkpoints) was left untouched.
        model.generator, model.g_optimizer = ipex.optimize(
            model.generator, optimizer=model.g_optimizer, inplace=True)
        model.discriminator, model.d_optimizer = ipex.optimize(
            model.discriminator, optimizer=model.d_optimizer, inplace=True)
        model._ipex_optimized = True

    loss = {'d_loss':.0, 'g_loss':.0, 'd_wass_loss':.0, 'd_gp':.0}
    n_g_steps = 0

    for i, (real_imgs, labels) in enumerate(tqdm(trainloader)):
        real_imgs = real_imgs.to(device)
        labels = labels.to(device)

        # *******************
        # Train Discriminator
        # *******************

        # also clears the critic grads left over from the last generator backprop
        model.d_optimizer.zero_grad()

        # generated images conditioned on labels
        fake_imgs = model.generator(labels)

        # critic scores for fake & real images; fakes are detached so only the
        # critic receives gradients here
        d_fake_logits = model.discriminator(fake_imgs.detach(), labels)
        d_real_logits = model.discriminator(real_imgs, labels)
        d_gp = model.gradient_penalty(
            real_imgs, fake_imgs.detach().clone().requires_grad_(True), labels)
        d_wass_loss = d_fake_logits.mean() - d_real_logits.mean()
        d_loss = d_wass_loss + d_gp * model.LAMBDA_GP

        d_loss.backward()
        model.d_optimizer.step()

        loss['d_loss'] += d_loss.item()
        loss['d_gp'] += d_gp.item()
        loss['d_wass_loss'] += d_wass_loss.item()

        if i % DG_TRAIN_RATIO == 0:
            # ***************
            # Train Generator
            # ***************

            # reuses this batch's fake_imgs (generator graph still attached) and
            # scores them with the just-updated critic
            model.g_optimizer.zero_grad()
            d_fake_logits = model.discriminator(fake_imgs, labels)
            g_loss = -torch.mean(d_fake_logits)

            loss['g_loss'] += g_loss.item()
            n_g_steps += 1

            g_loss.backward()
            model.g_optimizer.step()

    # max(..., 1) guards against an empty loader
    n_batches = max(len(trainloader), 1)
    for k in ('d_loss', 'd_wass_loss', 'd_gp'):
        loss[k] /= n_batches
    # exact mean over generator updates (previously approximated as sum / n_batches * ratio)
    loss['g_loss'] /= max(n_g_steps, 1)
    return loss

In [None]:
def test_step(model: nn.Module, testloader: DataLoader,
              device: torch.device):
    """Evaluate generator and critic losses on ``testloader`` without updating anything.

    The model is put in eval mode (BatchNorm uses running stats and is not
    updated from test data; dropout is off). Its previous train/eval mode is
    restored afterwards, even if evaluation raises.

    Returns:
        dict with the per-batch mean Wasserstein critic loss ``d_loss`` and
        generator loss ``g_loss``.
    """
    loss = {'d_loss':.0, 'g_loss':.0}
    was_training = model.training
    model.eval()
    try:
        with torch.inference_mode():
            for real_imgs, labels in tqdm(testloader):
                real_imgs = real_imgs.to(device)
                labels = labels.to(device)

                fake_imgs = model.generator(labels)
                d_fake_logits = model.discriminator(fake_imgs, labels)
                d_real_logits = model.discriminator(real_imgs, labels)
                d_loss = d_fake_logits.mean() - d_real_logits.mean()
                g_loss = -torch.mean(d_fake_logits)

                loss['d_loss'] += d_loss.item()
                loss['g_loss'] += g_loss.item()
    finally:
        model.train(was_training)

    # max(..., 1) guards against an empty loader
    n_batches = max(len(testloader), 1)
    loss['d_loss'] /= n_batches
    loss['g_loss'] /= n_batches
    return loss

In [None]:
def train(model: nn.Module, trainloader: DataLoader,
          testloader: DataLoader, device: torch.device, 
          EPOCHS: int, DG_TRAIN_RATIO: int=3):
    """Train for EPOCHS epochs, logging losses and recording a GIF of generated samples.

    Every 50 epochs the GIF collected so far and a checkpoint (model state,
    both optimizer states, loss history) are written under ../gifs and ../models.

    Returns:
        {'train_loss': {'d_loss', 'g_loss', 'd_wass_loss', 'd_gp'},
         'test_loss': {'d_loss', 'g_loss'}}, each a list with one value per epoch.
    """
    model = model.to(device)
    # Create output folders up front; otherwise the first save (epoch 50) can
    # crash after a long training run.
    for out_dir in ('../gifs', '../models'):
        Path(out_dir).mkdir(parents=True, exist_ok=True)

    results = {'train_loss':{'d_loss':[], 'g_loss':[], 'd_wass_loss':[], 'd_gp':[]},
               'test_loss':{'d_loss':[], 'g_loss':[]}}

    # fixed labels, so every GIF frame shows the same classes in the same cells
    n_classes = getattr(model.generator, 'n_classes', 10)
    gif_rand_labels = torch.randint(0, n_classes, (64,)).to(device)
    gif_gen_imgs = []

    for epoch in range(1, EPOCHS+1):
        print(f"\nEPOCH ===================================> {epoch:3d}/{EPOCHS}")
        print("******************** Training ********************")
        train_results = train_step(model, trainloader, device, DG_TRAIN_RATIO)
        print(f"Generator Loss: ------------------------> {train_results['g_loss']:.4f}")
        print(f"Discriminator Loss: --------------------> {train_results['d_loss']:.4f}")
        print(f"Discriminator Wasserstein Loss: --------> {train_results['d_wass_loss']:.4f}")
        print(f"Discriminator Gradient Penalty: --------> {train_results['d_gp']:.4f}")

        print("******************** Testing ********************")
        test_results = test_step(model, testloader, device)
        print(f"Generator Loss: ------------------------> {test_results['g_loss']:.4f}")
        print(f"Discriminator Loss: --------------------> {test_results['d_loss']:.4f}")

        # saving results
        for k, v in train_results.items():
            results['train_loss'][k].append(v)
        for k, v in test_results.items():
            results['test_loss'][k].append(v)

        # gif frame generation: [-1, 1] grid -> HWC uint8 in [0, 255]
        gen_imgs = make_grid(model(gif_rand_labels).to(device='cpu')).permute(1, 2, 0).numpy()
        gen_imgs = np.clip((gen_imgs + 1) * 127.5, 0, 255).astype(np.uint8)
        gif_gen_imgs.append(gen_imgs)

        if epoch % 50 == 0:
            # save the GIF of all frames collected so far
            imageio.mimsave(f'../gifs/cifar10_cwgan_gp_{epoch}_epochs.gif', 
                            gif_gen_imgs, loop=65535)
            # save checkpoint
            torch.save({
                'epoch':epoch,
                'model_state_dict':model.state_dict(),
                'optimizers_state_dict': [model.g_optimizer.state_dict(),
                                          model.d_optimizer.state_dict()],
                'losses': results},  f'../models/cifar10_cwgan_gp_{epoch}_epochs.pth')

    imageio.mimsave('../gifs/cifar10_cwgan_gp_training.gif', gif_gen_imgs, loop=65535)
    return results

In [None]:
# Full training run: EPOCHS epochs; train() writes a GIF and a checkpoint every 50 epochs.
results = train(model, trainloader, testloader, device,
                EPOCHS=EPOCHS, DG_TRAIN_RATIO=DG_TRAIN_RATIO)

### Visualization

In [None]:
def plot_losses(ax, title, g_losses, d_losses):
    """Plot per-epoch generator and discriminator losses on one set of axes.

    Args:
        ax: either a matplotlib ``Axes`` or the ``matplotlib.pyplot`` module
            (plots into its current axes).
        title: plot title.
        g_losses: generator loss per epoch.
        d_losses: discriminator loss per epoch.
    """
    ax.plot(g_losses, label='Generator Loss')
    ax.plot(d_losses, label='Discriminator Loss')
    # An Axes has set_title/set_xlabel/set_ylabel (its `title` is a Text, not a
    # method); the pyplot module uses title/xlabel/ylabel.
    if hasattr(ax, 'set_title'):
        ax.set_title(title)
        ax.set_xlabel('Epochs')
        ax.set_ylabel('Loss')
    else:
        ax.title(title)
        ax.xlabel('Epochs')
        ax.ylabel('Loss')
    ax.legend()

# train() returns {'train_loss': {...}, 'test_loss': {...}}, each a dict of
# per-epoch lists; the old 'train_losses_g'-style keys do not exist (KeyError).
train_loss, test_loss = results['train_loss'], results['test_loss']

plt.figure(figsize=(14, 4))
plt.subplot(121)
plot_losses(plt, "Train Losses", train_loss['g_loss'], train_loss['d_loss'])

plt.subplot(122)
plot_losses(plt, "Test Losses", test_loss['g_loss'], test_loss['d_loss'])
plt.tight_layout()
plt.show()

### Load The Model

In [None]:
# Rebuild the same architecture, then restore weights and optimizer state.
generator = Generator(LATENT_DIM, len(classes))
discriminator = Discriminator(len(classes), torch.tensor(imgs[0].shape))
g_optimizer = optim.Adam(generator.parameters())
d_optimizer = optim.Adam(discriminator.parameters())

model = CWGAN_GP(generator, discriminator, LAMBDA_GP)

# map_location='cpu' lets a checkpoint saved from an XPU run load on a CPU-only
# machine. torch.load unpickles the file, so only load checkpoints you trust.
checkpoint = torch.load("../models/cifar10_cwgan_gp_200_epochs.pth", map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
g_optimizer.load_state_dict(checkpoint['optimizers_state_dict'][0])
d_optimizer.load_state_dict(checkpoint['optimizers_state_dict'][1])
model.compile(g_optimizer, d_optimizer)

# sample 8 random classes and show the generated images
rand_labels = torch.randint(0, 10, (8,))
gen_imgs = model(rand_labels)
show_imgs(gen_imgs, rand_labels)