In [None]:
# Mount Google Drive into the Colab runtime at /content/drive so that outputs
# written under that path persist outside the session.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
from torch import autograd

import numpy as np
import scipy as sp
from scipy import stats
from matplotlib import pyplot as plt
import sys, os
import copy
import tqdm

import math
from torch.optim import Optimizer

In [None]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Root output directory on the mounted Drive; the plotting and checkpointing
# code joins file names onto it.
path = '/content/drive/MyDrive/Numgan'

In [None]:
## Extragradient Algo


required = object()


class Extragradient(Optimizer):
    """Common machinery for optimizers that take an extrapolation (look-ahead) step.

    Subclasses supply :meth:`update`, which returns the increment to add to a
    parameter. :meth:`extrapolation` applies that increment in place while
    remembering the pre-extrapolation values; :meth:`step` then applies a freshly
    computed increment to the remembered values.

    Arguments:
        params (iterable): an iterable of :class:`torch.Tensor` s or
            :class:`dict` s. Specifies what Tensors should be optimized.
        defaults: (dict): a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).
    """

    def __init__(self, params, defaults):
        super().__init__(params, defaults)
        # Snapshot of the parameters taken at the first extrapolation since the last step.
        self.params_copy = []

    def update(self, p, group):
        """Return the increment for parameter ``p`` (or None); must be overridden."""
        raise NotImplementedError

    def extrapolation(self):
        """Move every parameter by its current update, snapshotting the originals first.

        Several extrapolations may precede one ``step``; only the values held
        before the first of them are stored.
        """
        needs_snapshot = not self.params_copy
        for group in self.param_groups:
            for param in group['params']:
                delta = self.update(param, group)
                if needs_snapshot:
                    # Stored for every parameter (even ones without an update) so
                    # that positions line up with the traversal done in step().
                    self.params_copy.append(param.data.clone())
                if delta is not None:
                    param.data.add_(delta)

    def step(self, closure=None):
        """Apply the update to the snapshotted parameters and discard the snapshot.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is given.

        Raises:
            RuntimeError: if no extrapolation has been performed since the last step.
        """
        if not self.params_copy:
            raise RuntimeError('Need to call extrapolation before calling step.')

        loss = closure() if closure is not None else None

        group_param_pairs = (
            (group, param)
            for group in self.param_groups
            for param in group['params']
        )
        for index, (group, param) in enumerate(group_param_pairs):
            delta = self.update(param, group)
            if delta is None:
                continue
            # The increment is added to the saved (pre-extrapolation) values,
            # which then become the live parameter data.
            param.data = self.params_copy[index].add_(delta)

        # Release the snapshot so the next cycle starts from a fresh one.
        self.params_copy = []
        return loss


In [None]:

class ExtraAdam(Extragradient):
    """Implements the Adam algorithm with extrapolation step.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_

    Raises:
        ValueError: if ``lr`` or ``eps`` is negative, or either beta lies
            outside ``[0, 1)``.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(ExtraAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        """Restore optimizer state, defaulting ``amsgrad`` to False for groups that lack it."""
        super(ExtraAdam, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    def update(self, p, group):
        """Compute the Adam increment for one parameter and advance its running state.

        Every call increments the per-parameter step counter and updates the
        moment estimates, so calls made from both ``extrapolation`` and ``step``
        advance the state.

        Arguments:
            p (torch.Tensor): the parameter whose ``.grad`` is consumed.
            group (dict): the parameter group holding ``lr``, ``betas``, ``eps``,
                ``weight_decay`` and ``amsgrad``.

        Returns:
            The tensor to add to the parameter, or None when ``p.grad`` is None.

        Raises:
            RuntimeError: if the gradient is sparse.
        """
        if p.grad is None:
            return None
        grad = p.grad.data
        if grad.is_sparse:
            raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
        amsgrad = group['amsgrad']

        state = self.state[p]

        # State initialization
        if len(state) == 0:
            state['step'] = 0
            # Exponential moving average of gradient values
            state['exp_avg'] = torch.zeros_like(p.data)
            # Exponential moving average of squared gradient values
            state['exp_avg_sq'] = torch.zeros_like(p.data)
            if amsgrad:
                # Maintains max of all exp. moving avg. of sq. grad. values
                state['max_exp_avg_sq'] = torch.zeros_like(p.data)

        exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
        if amsgrad:
            max_exp_avg_sq = state['max_exp_avg_sq']
        beta1, beta2 = group['betas']

        state['step'] += 1

        if group['weight_decay'] != 0:
            # Out-of-place add so p.grad itself is left untouched.
            grad = grad.add(p.data, alpha=group['weight_decay'])

        # Decay the first and second moment running average coefficient.
        # The keyword forms (alpha= / value=) replace the deprecated
        # positional-scalar overloads of add_ / addcmul_.
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
            # Use the max. for normalizing running avg. of gradient
            denom = max_exp_avg_sq.sqrt().add_(group['eps'])
        else:
            denom = exp_avg_sq.sqrt().add_(group['eps'])

        bias_correction1 = 1 - beta1 ** state['step']
        bias_correction2 = 1 - beta2 ** state['step']
        step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

        # Negative sign: the caller *adds* the returned tensor to the parameter.
        return -step_size*exp_avg/denom

In [None]:
# Function for plotting samples using kernel density estimation. Same as the original implementation.
def kde(mu, tau, bbox=[-1.6, 1.6, -1.6, 1.6], save_file="", xlabel="", ylabel="", cmap='Blues'):
    """Draw a filled-contour Gaussian kernel density estimate of 2-D points.

    Arguments:
        mu: x coordinates of the samples.
        tau: y coordinates of the samples.
        bbox: ``[xmin, xmax, ymin, ymax]`` plotting window; also the extent of
            the 300x300 evaluation grid.
        save_file: when non-empty, the figure is saved to this path and closed;
            otherwise it is shown.
        xlabel, ylabel: axis labels.
        cmap: colormap passed to ``contourf``.
    """
    values = np.vstack([mu, tau])
    kernel = sp.stats.gaussian_kde(values)

    fig, ax = plt.subplots()
    ax.axis(bbox)
    ax.set_aspect(abs(bbox[1]-bbox[0])/abs(bbox[3]-bbox[2]))
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    # Hide ticks and tick labels on both axes. These options take booleans:
    # the string 'off' is truthy and would leave the ticks switched on.
    ax.tick_params(
        axis='x',           # changes apply to the x-axis
        which='both',       # both major and minor ticks are affected
        bottom=False,       # ticks along the bottom edge are off
        top=False,          # ticks along the top edge are off
        labelbottom=False)  # labels along the bottom edge are off
    ax.tick_params(
        axis='y',           # changes apply to the y-axis
        which='both',       # both major and minor ticks are affected
        left=False,         # ticks along the left edge are off
        right=False,        # ticks along the right edge are off
        labelleft=False)    # labels along the left edge are off

    # Evaluate the density on a regular grid covering the bounding box.
    xx, yy = np.mgrid[bbox[0]:bbox[1]:300j, bbox[2]:bbox[3]:300j]
    positions = np.vstack([xx.ravel(), yy.ravel()])
    f = np.reshape(kernel(positions).T, xx.shape)
    ax.contourf(xx, yy, f, cmap=cmap)

    if save_file != "":
        plt.savefig(save_file, bbox_inches='tight')
        # Close saved figures so repeated calls do not accumulate open figures.
        plt.close(fig)
    else:
        plt.show()

# Function to plot eigenvalues in the complex plane. Same as the original implementation.
def complex_scatter(points, bbox=None, save_file="", xlabel="real part", ylabel="imaginary part", cmap='Blues'):
    """Scatter complex numbers in the plane (real part on x, imaginary part on y).

    Arguments:
        points: iterable of values exposing ``.real`` and ``.imag``.
        bbox: optional ``[xmin, xmax, ymin, ymax]`` axis limits.
        save_file: when non-empty, the figure is written there and closed;
            otherwise it is displayed.
        xlabel, ylabel: axis labels.
        cmap: accepted for signature compatibility; not used by the plot.
    """
    fig, ax = plt.subplots()
    if bbox is not None:
        ax.axis(bbox)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    real_parts = [value.real for value in points]
    imag_parts = [value.imag for value in points]

    ax.plot(real_parts, imag_parts, 'X')
    ax.grid()

    if save_file == "":
        plt.show()
        return
    plt.savefig(save_file, bbox_inches='tight')
    plt.close(fig)

In [None]:
from os.path import join
def plot_eigens(iteration):
    """Plot the eigenvalues of the game Jacobian and of its consensus-regularised variant.

    A fresh batch is drawn, the stacked gradient field ``v`` (generator-loss
    gradient w.r.t. the generator parameters followed by discriminator-loss
    gradient w.r.t. the discriminator parameters) is differentiated row by row
    with respect to all parameters to form the Jacobian ``J``, and the
    eigenvalues of ``J`` and of ``J - gamma * J^T J`` are saved as scatter plots
    under ``path`` (``Eig_v_*`` and ``Eig_w_*`` respectively).

    Arguments:
        iteration: training step, used only in the output file names.
    """
    gen_out, real_in, fake_d_out_gen, fake_d_out_disc, fake_d_out, real_d_out = batch_net_outputs()
    gen_loss_detached, disc_loss_detached, gen_loss, disc_loss = net_losses(fake_d_out_gen, fake_d_out_disc, fake_d_out, real_d_out)
    p_count = sum(p.numel() for p in params)

    # create_graph=True keeps the gradients differentiable for the second pass below.
    gen_net.zero_grad()
    gen_grad = autograd.grad(gen_loss, gen_net.parameters(), retain_graph=True, create_graph=True)
    disc_net.zero_grad()
    disc_grad = autograd.grad(disc_loss, disc_net.parameters(), retain_graph=True, create_graph=True)

    v = torch.cat([t.flatten() for t in list(gen_grad) + list(disc_grad)])

    # One backward pass per component of v: row r holds d v[r] / d params.
    jacobian = torch.zeros(p_count, p_count, device=device)
    for row in range(p_count):
        row_grads = autograd.grad(v[row], params, retain_graph=True)
        jacobian[row] = torch.cat([g.flatten() for g in row_grads])

    jacobian2 = jacobian - gamma * torch.mm(jacobian.T, jacobian)

    eigens = torch.linalg.eigvals(jacobian)
    eigens2 = torch.linalg.eigvals(jacobian2)

    # The leftover debugging print and breakpoint() that used to sit here were
    # removed: the breakpoint stopped the whole training loop inside pdb.

    os.makedirs(path, exist_ok=True)
    file_tag = "_".join([str(iteration), method, optim_name, str(z_dim), str(gamma), net_size]) + '.png'
    cmap = 'Blues'
    plot_window = [-1.0, 1.0, -0.15, 0.15]

    save_path = join(path, 'Eig_v_' + file_tag)
    complex_scatter(eigens.cpu().detach().numpy(), bbox=plot_window, save_file=save_path, cmap=cmap)

    save_path = join(path, 'Eig_w_' + file_tag)
    complex_scatter(eigens2.cpu().detach().numpy(), bbox=plot_window, save_file=save_path, cmap=cmap)


def plot_kde(iteration, real_input=False):
    """Save a kernel-density plot of generated samples, or of the real data distribution.

    Arguments:
        iteration: training step, used in the file name for generated samples.
        real_input: when True, plot noisy samples around the eight points
            ``(cos(2*pi*k/8), sin(2*pi*k/8))`` in red and save them as
            ``KDE/original.png``; when False, plot generator samples in blue.
    """
    n_samples = batch_size * 5
    kde_dir = join(path, 'KDE')
    # savefig does not create missing directories.
    os.makedirs(kde_dir, exist_ok=True)

    # Plotting needs values only, so no autograd graph is recorded.
    with torch.no_grad():
        if real_input:
            angles = torch.tensor([2*np.pi*k/8 for k in range(n_samples)]).to(device)
            mus = torch.stack([torch.cos(angles), torch.sin(angles)], dim=1)
            inp = mus + torch.normal(mean=0, std=sigma, size=[n_samples, 2]).to(device)

            save_path = join(kde_dir, 'original.png')
            cmap = 'Reds'
        else:
            # The generator is only evaluated when its samples are actually plotted.
            z = torch.normal(mean=0, std=1, size=[n_samples, z_dim]).to(device)
            inp = gen_net(z)

            save_path = join(kde_dir, method + "_" + str(iteration) + "_" + optim_name + "_" + str(z_dim) + "_" + str(gamma) + "_" + net_size + '.png')
            cmap = 'Blues'

    kde(inp[:, 0].cpu().detach().numpy(), inp[:, 1].cpu().detach().numpy(), save_file=save_path, cmap=cmap)


In [None]:
class Net(nn.Module):
    """Fully connected network: four 16-unit ReLU layers followed by a linear output layer.

    Arguments:
        indim: size of the input features.
        outdim: size of the output.
    """

    def __init__(self, indim, outdim):
        super().__init__()
        width = 16
        # Attribute names fc1..fc5 are the keys used in saved state dicts.
        self.fc1 = nn.Linear(indim, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.fc4 = nn.Linear(width, width)
        self.fc5 = nn.Linear(width, outdim)

    def forward(self, x):
        """Apply the ReLU hidden layers in order, then the output layer without activation."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            hidden = F.relu(layer(hidden))
        return self.fc5(hidden)

class Net_big(nn.Module):
    """Deeper fully connected network: seven 16-unit ReLU layers followed by a linear output layer.

    Arguments:
        indim: size of the input features.
        outdim: size of the output.
    """

    def __init__(self, indim, outdim):
        super().__init__()
        width = 16
        # Attribute names fc1..fc8 are the keys used in saved state dicts.
        self.fc1 = nn.Linear(indim, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.fc4 = nn.Linear(width, width)
        self.fc5 = nn.Linear(width, width)
        self.fc6 = nn.Linear(width, width)
        self.fc7 = nn.Linear(width, width)
        self.fc8 = nn.Linear(width, outdim)

    def forward(self, x):
        """Apply the ReLU hidden layers in order, then the output layer without activation."""
        hidden = x
        relu_layers = (self.fc1, self.fc2, self.fc3, self.fc4,
                       self.fc5, self.fc6, self.fc7)
        for layer in relu_layers:
            hidden = F.relu(layer(hidden))
        return self.fc8(hidden)

In [None]:
criterion = nn.BCEWithLogitsLoss()


def net_losses(fake_d_out_gen, fake_d_out_disc, fake_d_out, real_d_out):
    """Compute generator and discriminator losses from discriminator logits.

    Arguments:
        fake_d_out_gen: logits on generated samples used for the "detached" generator loss.
        fake_d_out_disc: logits on generated samples used for the "detached" discriminator loss.
        fake_d_out: logits on generated samples used for the coupled losses.
        real_d_out: logits on real samples.

    Returns:
        ``(gen_loss_detached, disc_loss_detached, gen_loss, disc_loss)``. Generator
        losses score fake logits against a target of ones; discriminator losses
        score fake logits against zeros plus real logits against ones.
    """
    def scored_as_fake(logits):
        return criterion(input=logits, target=torch.zeros_like(logits))

    def scored_as_real(logits):
        return criterion(logits, target=torch.ones_like(logits))

    gen_loss_detached = scored_as_real(fake_d_out_gen)
    disc_loss_detached = scored_as_fake(fake_d_out_disc) + scored_as_real(real_d_out)

    gen_loss = scored_as_real(fake_d_out)
    disc_loss = scored_as_fake(fake_d_out) + scored_as_real(real_d_out)

    return gen_loss_detached, disc_loss_detached, gen_loss, disc_loss

In [None]:
# Experiment parameters (read as globals by the cells below)
z_dim = 32          # dimension of the latent noise fed to the generator
gamma = 10.0        # weight on the gradient of 0.5*||v||^2 in the ConsOpt branch; also scales J^T J in plot_eigens
lr = 1e-4           # learning rate of the RMSprop fallback optimizers
sigma = 1e-2        # std of the Gaussian noise added around each real-data mode
steps = 20000       # the training loop runs steps + 1 iterations
batch_size = 512    # samples per training batch (plot_kde draws 5x this many)
method = 'ConsOpt' # training-loop branch; alternatives noted by the author: 'SimGA', 'ConsOpt'
optim_name = 'Adam' # optimizer family: 'Adam', 'ExtraAdam', anything else selects RMSprop
net_size = 'small'  # 'small' selects Net, any other value selects Net_big

lr_adam = 3e-4      # learning rate of both Adam optimizers
beta = 0.55         # first Adam beta of the discriminator optimizer
alpha = 0.6         # NOTE(review): not referenced anywhere in the visible notebook

In [None]:
# The training loop reads `dis_optimizer` / `gen_optimizer` only when
# method == 'ExtraAdam' and `gen_opt` / `disc_opt` in every other case, while
# this cell picks which pair to create from `optim_name`. Fail here, with a clear
# message, instead of hitting a NameError (or silently reusing stale optimizers
# left in the kernel by an earlier run) once training has already started.
if (method == 'ExtraAdam') != (optim_name == 'ExtraAdam'):
    raise ValueError(
        "Inconsistent configuration: method={!r} with optim_name={!r}. "
        "Set both to 'ExtraAdam' or neither.".format(method, optim_name))

# The generator is built before the discriminator; both share the selected architecture.
if net_size == "small":
    gen_net = Net(z_dim, 2).to(device)
    disc_net = Net(2, 1).to(device)
else:
    gen_net = Net_big(z_dim, 2).to(device)
    disc_net = Net_big(2, 1).to(device)

# Flat list of all parameters, generator first; the Jacobian and consensus
# gradients are computed with respect to this ordering.
params = list(gen_net.parameters()) + list(disc_net.parameters())


if optim_name == "Adam":
    gen_opt = optim.Adam(gen_net.parameters(), lr=lr_adam, betas=(0.5, 0.9))
    disc_opt = optim.Adam(disc_net.parameters(), lr=lr_adam, betas=(beta, 0.9))
elif optim_name == 'ExtraAdam':
    dis_optimizer = ExtraAdam(disc_net.parameters(), lr=2e-4, betas=(0.5, 0.9))
    gen_optimizer = ExtraAdam(gen_net.parameters(), lr=2e-5, betas=(0.5, 0.9))
else:
    # Any other optimizer name falls back to RMSprop.
    gen_opt = optim.RMSprop(gen_net.parameters(), lr=lr)
    disc_opt = optim.RMSprop(disc_net.parameters(), lr=lr)


In [None]:
def batch_net_outputs():
    """Draw one batch and run it through the generator and discriminator.

    Returns:
        ``(gen_out, real_in, fake_d_out_gen, fake_d_out_disc, fake_d_out, real_d_out)``:
        generated samples, real samples, and discriminator logits on the generated
        samples computed three ways (through a deep copy of the discriminator,
        on detached generator output, and fully coupled), followed by the logits
        on the real samples.
    """
    latent = torch.normal(mean=0, std=1, size=[batch_size, z_dim]).to(device)
    gen_out = gen_net(latent)

    # A deep copy of the discriminator scores the samples here, so a backward pass
    # through these logits reaches the generator but not disc_net's own parameters.
    disc_snapshot = copy.deepcopy(disc_net)
    fake_d_out_gen = disc_snapshot(gen_out)
    # Detached input: a backward pass through these logits stops at the discriminator.
    fake_d_out_disc = disc_net(gen_out.detach())
    # Fully coupled: gradients reach both networks.
    fake_d_out = disc_net(gen_out)

    # Real data: sample k is centred at angle 2*pi*k/8 on the unit circle
    # (eight distinct centres), perturbed by Gaussian noise of std sigma.
    mode_angles = torch.tensor([2*np.pi*idx/8 for idx in range(batch_size)]).to(device)
    centres = torch.stack([torch.cos(mode_angles), torch.sin(mode_angles)], dim=1)
    jitter = torch.normal(mean=0, std=sigma, size=[batch_size, 2]).to(device)
    real_in = centres + jitter
    real_d_out = disc_net(real_in)

    return gen_out, real_in, fake_d_out_gen, fake_d_out_disc, fake_d_out, real_d_out

In [None]:
# Uncomment the lines below if you need to load a saved model

# gen_path = join(path, 'Models', 'gen_SimGA_0.pt')
# disc_path = join(path, 'Models', 'disc_SimGA_0.pt')
# gen_net.load_state_dict(torch.load(gen_path))
# disc_net.load_state_dict(torch.load(disc_path))

In [None]:
# plot_kde(0, real_input=True)

In [None]:
# Main training loop. Every iteration draws a fresh batch; every 5000 iterations
# the eigenvalue and KDE plots are written and both networks are checkpointed.
n_gen_update = 0  # number of generator step() calls made in the ExtraAdam branch
for i in tqdm.notebook.tqdm(range(steps+1)):

    gen_out, real_in, fake_d_out_gen, fake_d_out_disc, fake_d_out, real_d_out = batch_net_outputs()
    gen_loss_detached, disc_loss_detached, gen_loss, disc_loss = net_losses(fake_d_out_gen, fake_d_out_disc, fake_d_out, real_d_out)

    if i%5000 == 0:
        plot_eigens(i)
        plot_kde(i)

        # torch.save does not create missing directories.
        models_dir = join(path, 'Models')
        os.makedirs(models_dir, exist_ok=True)
        run_tag = "_".join([method, str(i), optim_name, str(z_dim), str(gamma), net_size])
        gen_path = join(models_dir, 'gen_' + run_tag + '.pt')
        disc_path = join(models_dir, 'disc_' + run_tag + '.pt')
        torch.save(gen_net.state_dict(), gen_path)
        torch.save(disc_net.state_dict(), disc_path)

    if method == 'ConsOpt':
        # Gradient field v of the coupled game, kept differentiable so that the
        # regulariser L = 0.5 * ||v||^2 can itself be differentiated.
        gen_net.zero_grad()
        gen_grad = autograd.grad(gen_loss, gen_net.parameters(), retain_graph=True, create_graph=True)
        disc_net.zero_grad()
        disc_grad = autograd.grad(disc_loss, disc_net.parameters(), retain_graph=True, create_graph=True)

        v = list(gen_grad) + list(disc_grad)
        v = torch.cat([t.flatten() for t in v])

        L = 1/2 * torch.dot(v, v)
        jgrads = autograd.grad(L, params, retain_graph=True)

        # Generator update: seed .grad with gamma * dL/dparam, then let backward()
        # accumulate the generator loss gradient on top of it.
        # NOTE(review): create_graph=True on these backward() calls looks
        # unnecessary for a first-order optimizer step; confirm before removing.
        gen_opt.zero_grad()
        # zip() replaces the index loop that used to overwrite the outer step counter `i`.
        for param, jgrad in zip(params, jgrads):
            param.grad = jgrad * gamma
        gen_loss_detached.backward(retain_graph=True, create_graph=True)
        gen_opt.step()

        # Discriminator update: same seeding, then the discriminator loss gradient.
        disc_opt.zero_grad()
        for param, jgrad in zip(params, jgrads):
            param.grad = jgrad * gamma
        disc_loss_detached.backward(retain_graph=True, create_graph=True)
        disc_opt.step()
    elif method == "ExtraAdam":
        # Even iterations (i = 0, 2, ...) extrapolate, odd iterations take the real step.

        # Discriminator: generator parameters are marked as not requiring grad
        # for the duration of this backward pass.
        for p in gen_net.parameters():
            p.requires_grad = False
        dis_optimizer.zero_grad()
        disc_loss.backward(retain_graph=True)

        if (i+1)%2 != 0:
            dis_optimizer.extrapolation()
        else:
            dis_optimizer.step()

        for p in gen_net.parameters():
            p.requires_grad = True

        # Generator: discriminator parameters are marked as not requiring grad
        # for the duration of this backward pass.
        for p in disc_net.parameters():
            p.requires_grad = False
        gen_optimizer.zero_grad()
        gen_loss.backward()

        if (i+1)%2 != 0:
            gen_optimizer.extrapolation()
        else:
            n_gen_update += 1
            gen_optimizer.step()

        for p in disc_net.parameters():
            p.requires_grad = True
    else:
        # Plain simultaneous updates using the detached losses.
        gen_opt.zero_grad()
        gen_loss_detached.backward(retain_graph=True, create_graph=True)
        gen_opt.step()

        disc_opt.zero_grad()
        disc_loss_detached.backward(retain_graph=True, create_graph=True)
        disc_opt.step()



  0%|          | 0/20001 [00:00<?, ?it/s]


sys.settrace() should not be used when the debugger is being used.
This may cause the debugger to stop working correctly.
If this is needed, please check: 
http://pydev.blogspot.com/2007/06/why-cant-pydev-debugger-work-with.html
to see how to restore the debug tracing back correctly.
Call Location:
  File "/usr/lib/python3.9/bdb.py", line 334, in set_trace
    sys.settrace(self.trace_dispatch)



torch.Size([2259])
> <ipython-input-18-d0ca30d8166c>(30)plot_eigens()
-> save_path = join(path, 'Eig_v_' + str(iteration) + "_" + method + "_" + optim_name + "_" + str(z_dim) + "_" + str(gamma) + "_" + net_size + '.png')
(Pdb) eigens.shape
torch.Size([2259])
(Pdb) n
> <ipython-input-18-d0ca30d8166c>(31)plot_eigens()
-> cmap='Blues'
(Pdb) eigen2.shape
*** NameError: name 'eigen2' is not defined
(Pdb) eigens2
tensor([-20.4477+0.0000j,  -1.8414+0.4181j,  -1.8414-0.4181j,  ...,
          0.0000+0.0000j,   0.0000+0.0000j,   0.0000+0.0000j], device='cuda:0')
(Pdb) eigens2.shape
torch.Size([2259])
--KeyboardInterrupt--
(Pdb) exit



sys.settrace() should not be used when the debugger is being used.
This may cause the debugger to stop working correctly.
If this is needed, please check: 
http://pydev.blogspot.com/2007/06/why-cant-pydev-debugger-work-with.html
to see how to restore the debug tracing back correctly.
Call Location:
  File "/usr/lib/python3.9/bdb.py", line 359, in set_quit
    sys.settrace(None)



BdbQuit: ignored

In [None]:
# NOTE(review): `eigens` is only ever assigned as a local variable inside
# plot_eigens, so evaluating it at notebook level raises NameError (as the
# recorded output of this cell shows). Remove this cell, or have plot_eigens
# return the eigenvalues if they are needed here.
eigens

NameError: ignored