In [None]:
import repitl.matrix_itl as itl
import repitl.kernel_utils as ku
import numpy as np 
import matplotlib.pyplot as plt
import torch
from IPython import display

# DCGAN for CIFAR10

In [None]:
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import HTML
from torchvision import datasets, transforms

# Set random seeds for reproducibility of module-level random draws (e.g. fixed_noise).
# main() re-seeds python/numpy/torch from config.seed again before training.
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
np.random.seed(manualSeed)  # seed numpy too, matching main()
torch.manual_seed(manualSeed);  # trailing ';' hides the returned Generator repr in the notebook

In [None]:
# Root directory for the CIFAR-10 dataset (main() passes download=True, so it is fetched here if absent)
dataroot = "./data/cifar10"

# Number of DataLoader worker processes
workers = 2

# Batch size during training
batch_size = 64

# Size of random Fourier features (RFFs).
# NOTE(review): not referenced anywhere else in this notebook; rffDoPE is still a stub.
rff_size = 256

# Spatial size of the training images. No resize transform is applied in main():
# CIFAR-10 images are already 32x32 and both Generator and Discriminator are
# hard-wired to 32x32, so this value is informational only (it is not read elsewhere).
image_size = 32

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Base number of feature maps in the generator
ngf = 64

# Base number of feature maps in the discriminator
ndf = 64

# Output dimension of the discriminator's final linear projection (Discriminator.lin_proj)
npf = 64

# Number of training epochs
num_epochs = 200

# Learning rate for both Adam optimizers
lr = 0.00005

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

# Number of GPUs to wrap the models with nn.DataParallel (only used when > 1).
# This does NOT choose CPU vs GPU: main() picks the device from Config.no_cuda
# and torch.cuda.is_available().
ngpu = 1

# Bandwidth passed as `sigma` to topKernel; grows with the projection size as sqrt(npf / 2)
sigma = np.sqrt(npf/2)

# Maximum number of generator updates per discriminator update; train() stops the
# inner loop early once the generator objective reaches the current DiME value
g_iter = 80

# Kernel employed to compare discriminator representations (alternatives kept for reference)
topKernel = ku.factorizedLaplacianKernel
# topKernel = ku.gaussianKernel
# topKernel = ku.ellipticalLaplacianKernel



In [None]:
def weights_init(m):
    """DCGAN-style initializer, meant to be used via ``nn.Module.apply``.

    * Modules whose class name contains ``Conv`` (Conv2d, ConvTranspose2d, ...)
      get weights drawn from N(0, 0.02).
    * Modules whose class name contains ``BatchNorm`` get weights drawn from
      N(1, 0.02) and a zero bias.
    Every other module (e.g. ``nn.Linear``) keeps PyTorch's default init.
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in layer_type:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
        
class Generator(nn.Module):
    """Transposed-convolution generator producing 32x32 images from latent noise.

    Input is a latent batch shaped (N, nz, 1, 1) (as built for ``fixed_noise`` and
    in ``train``); spatial size grows 1 -> 4 -> 8 -> 16 -> 32 and the final Tanh
    puts pixels in [-1, 1], matching the Normalize(0.5, 0.5) applied to real images.
    """
    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu  # stored only; not used inside this module
        upsampling = [
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),       # -> (ngf*8, 4, 4)
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),  # -> (ngf*4, 8, 8)
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),  # -> (ngf*2, 16, 16)
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(ngf * 2, nc, 4, 2, 1, bias=False),       # -> (nc, 32, 32)
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*upsampling)

    def forward(self, input):
        images = self.main(input)
        return images

In [None]:
class Discriminator(nn.Module):
    """Convolutional encoder mapping (N, nc, 32, 32) images to (N, npf) features.

    Three stride-2 convolutions shrink the input to (ndf*4, 4, 4); the result is
    flattened and linearly projected to ``npf`` features. There is no output
    activation: ``train`` feeds the raw projection to the kernel-based objectives.
    """
    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu  # stored only; not used inside this module
        downsampling = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),           # -> (ndf, 16, 16)
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),      # -> (ndf*2, 8, 8)
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),  # -> (ndf*4, 4, 4)
        ]
        self.main = nn.Sequential(*downsampling)
        # ndf*4 channels times the 4x4 spatial map left over from a 32x32 input
        self.lin_proj = nn.Linear(ndf * 4 * 4 * 4, npf)

    def forward(self, input):
        feature_maps = self.main(input)
        flat = feature_maps.view(feature_maps.size(0), -1)
        return self.lin_proj(flat)

# Defining the objective functions

In [None]:
def permuteGram(K):
    """Shuffle a square Gram matrix with one random permutation on both axes.

    Equivalent to P @ K @ P.T for a random permutation matrix P, so the
    sample order changes consistently for rows and columns.
    """
    n_rows, n_cols = K.shape[0], K.shape[1]
    assert n_rows == n_cols, "matrix dimensions must be the same"
    perm = torch.randperm(n_rows)
    return K[perm][:, perm]

def permuteData(X):
    """Return ``X`` with its first axis randomly shuffled."""
    return X[torch.randperm(X.shape[0])]

def rffDoPE(X, Y):
    """Placeholder for an RFF-based estimator between ``X`` and ``Y`` (see ``rff_size``).

    Not implemented yet. It raises instead of silently returning ``None``, so an
    accidental call fails loudly rather than propagating ``None`` into a loss.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("rffDoPE is not implemented yet")

def matrixDiME(Kx, Ky, alpha=1.01):
    '''Single-permutation estimate of DiME between two Gram matrices.

    Returns H_alpha(Kx, Ky_perm) - H_alpha(Kx, Ky), where Ky_perm is ``Ky`` with
    rows and columns shuffled by one random permutation (breaking the pairing
    between the two views). H_alpha is ``itl.matrixAlphaJointEntropy``.
    '''
    # Draw the permutation before either entropy evaluation.
    shuffled_Ky = permuteGram(Ky)
    joint_paired = itl.matrixAlphaJointEntropy([Kx, Ky], alpha=alpha)
    joint_shuffled = itl.matrixAlphaJointEntropy([Kx, shuffled_Ky], alpha=alpha)
    return joint_shuffled - joint_paired

def matrixCondEnt(Kl, Kz, alpha=1.01):
    '''Matrix-based conditional entropy H_alpha(L | Z) = H_alpha(Z, L) - H_alpha(Z).

    ``Kl`` and ``Kz`` are Gram matrices of the labels and the representations;
    the entropies come from ``itl.matrixAlphaJointEntropy`` / ``itl.matrixAlphaEntropy``.
    '''
    joint = itl.matrixAlphaJointEntropy([Kz, Kl], alpha=alpha)
    marginal_z = itl.matrixAlphaEntropy(Kz, alpha=alpha)
    return joint - marginal_z



# Defining the training function

In [None]:
# Fixed latent batch used to visualise generator progress. It is created on the CPU
# and moved to the training device inside train(). The previous hard-coded "cuda:1"
# device made this cell fail on CPU-only and single-GPU machines.
fixed_noise = torch.randn(64, nz, 1, 1)

def train(args, gen, disc, device, dataloader, optimizerG, optimizerD, epoch, iters):
    """Run one epoch of DiME-based adversarial training.

    For every batch:
    * The discriminator is updated once to maximise the DiME between the
      real/fake labels and its representations of the batch.
    * The generator is then updated up to ``args.g_iter`` times to maximise the
      conditional entropy H(L | Z) of the labels given those representations.
      It stops early once H(L | Z) reaches the DiME value.

    Args:
        args: configuration object; reads ``nz``, ``g_iter`` and ``epochs``.
        gen, disc: generator and discriminator modules.
        device: torch device the models live on.
        dataloader: yields batches whose element 0 is the image tensor.
        optimizerG, optimizerD: optimizers for ``gen`` and ``disc``.
        epoch: 1-based epoch index (main() iterates 1..epochs).
        iters: global batch counter at the start of this epoch.

    Returns:
        The updated global batch counter (``iters + len(dataloader)``).
    """
    gen.train()
    disc.train()

    ax_img = plt.subplot(121)   # left: grid of G(fixed_noise)
    ax_feat = plt.subplot(122)  # right: first two feature dims, coloured by real/fake label
    noise_vis = fixed_noise.to(device)

    for i, data in enumerate(dataloader, 0):
        # ---------------- Discriminator step: maximise DiME ----------------
        disc.zero_grad()
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        real_output = disc(real_cpu)
        noise = torch.randn(b_size, args.nz, 1, 1, device=device)
        fake = gen(noise)
        # detach() so this step does not backpropagate into the generator
        fake_output = disc(fake.detach())
        output = torch.cat((real_output, fake_output), dim=0)

        Kz = topKernel(output, output, sigma=sigma)
        # Real samples are labelled 1 and fake samples 0; Kl[a, b] == 1 iff a and b share a label
        labels = torch.ones(2 * b_size, dtype=torch.long)
        labels[b_size:] = 0
        L = torch.nn.functional.one_hot(labels).type(real_output.dtype)
        Kl = torch.matmul(L, L.t()).to(device)
        DiME = matrixDiME(Kl, Kz)
        d_loss = -DiME  # minimising -DiME == gradient ascent on DiME
        d_loss.backward()
        optimizerD.step()

        # ------- Generator steps: maximise H(L | Z) until it reaches DiME -------
        dime_target = DiME.detach()
        G_obj = torch.zeros((), dtype=d_loss.dtype, device=d_loss.device)
        j = 0
        while (j < args.g_iter) and (G_obj < dime_target):
            gen.zero_grad()
            fake = gen(noise)
            fake_output = disc(fake)
            # NOTE(review): real_output was computed before optimizerD.step(). The real
            # features therefore come from the pre-update discriminator, while fake_output
            # uses the updated one. Confirm this is intended; otherwise recompute the real
            # features (under torch.no_grad()) after the discriminator step.
            output = torch.cat((real_output.detach(), fake_output), dim=0)
            Kz = topKernel(output, output, sigma=sigma)
            Hl_z = matrixCondEnt(Kl, Kz)
            g_loss = -Hl_z  # gradient ascent on the conditional entropy
            g_loss.backward()
            optimizerG.step()
            G_obj = Hl_z.detach()
            j += 1
        if j == args.g_iter:
            print(f'Went over {args.g_iter}. DiME: {DiME.item()}, G_obj: {G_obj.item()}')

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f'
                  % (epoch, args.epochs, i, len(dataloader),
                     DiME.item(), G_obj.item()))

        # Visualise progress every 500 global iterations and on the very last batch.
        # `epoch` is 1-based, so the final epoch is args.epochs (not args.epochs - 1).
        is_final_batch = (epoch == args.epochs) and (i == len(dataloader) - 1)
        if (iters % 500 == 0) or is_final_batch:
            with torch.no_grad():
                fake_vis = gen(noise_vis).detach().cpu()
            grid_im = vutils.make_grid(fake_vis, padding=2, normalize=True)
            display.clear_output(wait=True)
            ax_img.cla()
            ax_img.imshow(grid_im.permute([1, 2, 0]))
            ax_feat.cla()
            feats = output.detach().to('cpu').numpy()
            ax_feat.scatter(feats[:, 0], feats[:, 1], c=labels.numpy())
            display.display(plt.gcf())
            print("epoch %d" % (epoch,))
        iters += 1

    return iters

In [None]:
class Config:
    """Container of training hyper-parameters, passed to ``train`` as ``args``.

    Values are copied from the module-level constants when the object is
    constructed. Editing a constant after ``config`` exists does not affect it.
    """
    def __init__(self):
        self.batch_size = batch_size  # samples per DataLoader batch
        self.epochs = num_epochs      # number of epochs; main() iterates 1..epochs
        self.lr = lr                  # Adam learning rate for both G and D
        self.beta1 = beta1            # Adam beta1
        self.nz = nz                  # latent vector size
        self.no_cuda = False          # True forces CPU even if CUDA is available
        self.seed = manualSeed        # random seed, copied from manualSeed
        self.g_iter = g_iter          # max generator steps per discriminator step

config = Config()

In [None]:
def _build_net(net_cls, device):
    """Instantiate ``net_cls(ngpu)`` on ``device``, wrap it for multi-GPU if requested, and init its weights."""
    net = net_cls(ngpu).to(device)
    if (device.type == 'cuda') and (ngpu > 1):
        # NOTE(review): DataParallel expects the module on device_ids[0] (cuda:0), but
        # main() may place it on cuda:1. This only matters when ngpu > 1.
        net = nn.DataParallel(net, list(range(ngpu)))
    # Conv weights ~ N(0, 0.02); BatchNorm weights ~ N(1, 0.02) with zero bias
    net.apply(weights_init)
    return net

def main():
    """Train the generator/discriminator pair on CIFAR-10 and save the generator.

    Hyper-parameters come from the module-level ``config`` object and the
    constants in the configuration cell. When training finishes, the generator
    weights are written to ``model_dcgan.h5`` in the working directory.
    """
    use_cuda = not config.no_cuda and torch.cuda.is_available()
    if use_cuda:
        # Keep the original preference for GPU 1, but fall back to GPU 0 on
        # single-GPU machines instead of failing with an invalid device ordinal.
        device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
    else:
        device = torch.device("cpu")

    # Set random seeds and deterministic pytorch for reproducibility
    random.seed(config.seed)        # python random seed
    torch.manual_seed(config.seed)  # pytorch random seed
    np.random.seed(config.seed)     # numpy random seed
    torch.backends.cudnn.deterministic = True

    # Load CIFAR-10; Normalize(0.5, 0.5) maps ToTensor's [0, 1] pixels to [-1, 1]
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    trainset = datasets.CIFAR10(root=dataroot, train=True,
                                download=True, transform=transform)
    # Use page-locked host memory only when training on a GPU
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=config.batch_size,
                                              shuffle=True, num_workers=workers,
                                              pin_memory=use_cuda)

    # Build G first, then D (same construction/initialisation order as before)
    netG = _build_net(Generator, device)
    netD = _build_net(Discriminator, device)

    # Setup Adam optimizers for both G and D
    optimizerD = optim.Adam(netD.parameters(), lr=config.lr, betas=(config.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=config.lr, betas=(config.beta1, 0.999))

    iters = 0
    for epoch in range(1, config.epochs + 1):
        train(config, netG, netD, device, trainloader, optimizerG, optimizerD, epoch, iters)
        # train() advances its counter once per batch, so carry the global count
        # forward here; otherwise `iters` restarts at 0 every epoch.
        iters += len(trainloader)

    # Save the generator checkpoint. This is a torch.save archive despite the
    # ".h5" extension; load it with torch.load, not an HDF5 reader.
    torch.save(netG.state_dict(), "model_dcgan.h5")

if __name__ == '__main__':
    main()