In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision
import torch.nn.init as init
import tqdm
import sys
import matplotlib.pyplot as plt
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

REAL = 1.
FAKE = 0.

In [3]:
class Reshape(nn.Module):
    def __init__(self, shape):
        super(Reshape, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(*self.shape)

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(in_features=110, out_features=384*4*4),
            nn.ReLU(inplace=True),

            Reshape((-1, 384, 4, 4)),

            nn.ConvTranspose2d(in_channels=384, out_channels=192, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(in_channels=192, out_channels=96, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(in_channels=96, out_channels=3, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.Tanh()
        )
        self.apply(self.init_weights)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten the input
        return self.model(x)

    def init_weights(self, m):
      if isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose2d):
        init.normal_(m.weight, mean=0, std=0.02)
        if m.bias is not None:
            init.zeros_(m.bias)

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.relu1 = nn.LeakyReLU(0.2)
        self.dropout1 = nn.Dropout(0.5)

        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.batch_norm2 = nn.BatchNorm2d(32)
        self.relu2 = nn.LeakyReLU(0.2)
        self.dropout2 = nn.Dropout(0.5)

        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.batch_norm3 = nn.BatchNorm2d(64)
        self.relu3 = nn.LeakyReLU(0.2)
        self.dropout3 = nn.Dropout(0.5)

        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False)
        self.batch_norm4 = nn.BatchNorm2d(128)
        self.relu4 = nn.LeakyReLU(0.2)
        self.dropout4 = nn.Dropout(0.5)

        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False)
        self.batch_norm5 = nn.BatchNorm2d(256)
        self.relu5 = nn.LeakyReLU(0.2)
        self.dropout5 = nn.Dropout(0.5)

        self.conv6 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=False)
        self.batch_norm6 = nn.BatchNorm2d(512)
        self.relu6 = nn.LeakyReLU(0.2)
        self.dropout6 = nn.Dropout(0.5)

        self.discriminator = nn.Linear(512 * 4 * 4, 1)
        self.classifier = nn.Linear(512 * 4 * 4, 10)

        self.softmax = nn.Softmax(dim=1)
        self.sigmoid = nn.Sigmoid()

        self.apply(self.init_weights)


    def forward(self, x):
        x = self.dropout1(self.relu1(self.conv1(x)))
        x = self.dropout2(self.relu2(self.batch_norm2(self.conv2(x))))
        x = self.dropout3(self.relu3(self.batch_norm3(self.conv3(x))))
        x = self.dropout4(self.relu4(self.batch_norm4(self.conv4(x))))
        x = self.dropout5(self.relu5(self.batch_norm5(self.conv5(x))))
        x = self.dropout6(self.relu6(self.batch_norm6(self.conv6(x))))
        x = x.view(-1, 512 * 4 * 4)
        discriminator = self.discriminator(x)
        classifier = self.classifier(x)
        choice = self.sigmoid(discriminator).view(-1, 1).squeeze(1)
        classes = self.softmax(classifier)

        return choice, classes
    def init_weights(self, m):
      if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.normal_(m.weight, mean=0, std=0.02)
        if m.bias is not None:
            init.zeros_(m.bias)

In [4]:
def scale_transform(image):
  return image * 2 - 1
def get_data() -> DataLoader:
  data_transforms = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    scale_transform
  ])

  dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=data_transforms)
  # loader = DataLoader(dataset, batch_size=100, shuffle=False)
  loader = DataLoader(dataset, batch_size=100, shuffle=True)

  return loader

def create_random_seeding():
  batch_tensors = []
  labels = []
  for _ in range(100):
      random_part = torch.normal(mean=0, std = 1, size = (100, 1, 1)) # better to use uniform random or normal random?
      one_hot_vector = torch.zeros(10, 1, 1)
      label = torch.randint(0, 10, (1,))
      labels.append(label)
      one_hot_vector[label] = 1
      complete_tensor = torch.cat((random_part, one_hot_vector), dim=0)
      batch_tensors.append(complete_tensor)

  final_tensor = torch.stack(batch_tensors)

  return final_tensor.clone().detach().requires_grad_(True), torch.tensor(labels)

def get_dis_accuracies(dis_labels, dis_predictions):
  correct = 0

  for y, y_hat in zip(dis_labels, dis_predictions):
    if y == round(float(y_hat.detach().cpu().numpy())):
      correct += 1
  return correct / len(dis_labels)

def get_class_accuracies(class_labels, class_predictions):
  correct = 0
  for y, y_hat in zip(class_labels, class_predictions):
    if y == y_hat.argmax(dim = -1):
      correct += 1
  return correct / len(class_labels)
def save_model_state(model, epoch, filepath):
  state = {
      'state_dict': model.state_dict(),
      'epoch': epoch
  }
  torch.save(state, filepath)

def create_sample_seeding() -> torch.Tensor:
  batch_tensors = []
  for i in range(20):
      random_part = torch.normal(mean=0, std = 1, size = (100, 1, 1)) # better to use uniform random or normal random?
      one_hot_vector = torch.zeros(10, 1, 1)
      label = i % 10
      one_hot_vector[label] = 1
      complete_tensor = torch.cat((random_part, one_hot_vector), dim=0)
      batch_tensors.append(complete_tensor)

  final_tensor = torch.stack(batch_tensors)

  return final_tensor

def sample_images(netG, save=False, epoch=0):
  seeds = create_sample_seeding()
  seeds = seeds.to(device)
  netG.eval()
  with torch.no_grad():
    imgs = netG(seeds)

  imgs = imgs / 2 + .5
  grid = torchvision.utils.make_grid(imgs, nrow=10, padding = 4)
  grid = grid.detach().cpu().numpy()
  plt.imshow(np.transpose(grid, (1,2,0)))
  plt.xticks([])
  plt.yticks([])
  if save:
    plt.savefig(f"img_e{epoch}.png")
    np.save(f"grid_e{epoch}.npy", imgs.detach().cpu().numpy())
  else:
    plt.show()

def sample_real_images():
  dataloader = get_data()
  data_iter = iter(dataloader)
  images, labels = next(data_iter)
  sample_image, sample_label = images[0], labels[0]
  sample_image = sample_image / 2 + .5
  plt.imshow(np.transpose(sample_image, (1,2,0)))
  plt.show()

In [5]:
netG = Generator()
netD = Discriminator()

G_path = "G_100wac.pth"
D_path = "D_100wac.pth"
netG.load_state_dict(torch.load(G_path)['state_dict'])
netD.load_state_dict(torch.load(D_path)['state_dict'])

netG = netG.to(device)
netD = netD.to(device)

netG.train()
netD.train()

dis_criterion = nn.BCELoss().to(device)
class_criterion = nn.NLLLoss().to(device)

optimizer_G = optim.Adam(netG.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_D = optim.Adam(netD.parameters(), lr=0.0001, betas=(0.5, 0.999))

dataloader = get_data()

gradientsD = []
gradientsG = []

for real_Xs, real_ys in tqdm.tqdm(dataloader, desc='training...', file=sys.stdout):
  # for real_Xs, real_ys in dataloader:
    # TRAIN DISCRIMINATOR
    netD.train()
    netD.zero_grad()

    real_Xs = real_Xs.to(device)
    real_ys = real_ys.to(device)

    # D on real data
    dis_prediction, class_prediction = netD(real_Xs)
    dis_labels = torch.full((100, ), REAL, requires_grad=True).to(device)
    loss_R = dis_criterion(dis_prediction, dis_labels) + class_criterion(class_prediction, real_ys)
    loss_R.backward()

    # Generate fake data
    fake_seeding, fake_labels = create_random_seeding()
    fake_seeding = fake_seeding.to(device)
    fake_labels = fake_labels.to(device)
    fake_Xs = netG(fake_seeding)
    fake_Xs = fake_Xs.to(device)

    # D on fake data
    dis_prediction, class_prediction = netD(fake_Xs.detach())
    dis_labels = torch.full((100, ), FAKE, requires_grad=True).to(device)
    loss_F = class_criterion(class_prediction, fake_labels) + dis_criterion(dis_prediction, dis_labels)
    loss_F.backward()

    gradientsD = [param.grad for param in netD.parameters()]
    netD.zero_grad()

    # TRAIN GENERATOR
    netG.train()
    netG.zero_grad()

    dis_prediction, class_prediction = netD(fake_Xs)
    dis_labels = torch.full((100, ), REAL, requires_grad=True).to(device)

    loss_G = class_criterion(class_prediction, fake_labels) + dis_criterion(dis_prediction, dis_labels)
    loss_G.backward()
    gradientsG = [param.grad for param in netG.parameters()]
    optimizer_G.step()

    break
print(gradientsG)
print(gradientsD)

FileNotFoundError: ignored