In [2]:
import torch
import numpy as np
import math
import argparse
from torch.autograd import Variable
from augerino import datasets, models, losses
import glob
import re
import pandas as pd
import sys

from data.generate_data import *
import seaborn as sns
import matplotlib.pyplot as plt

In [3]:
# Dataset sizes handed to the Mario data generator.
ntrain = 10000
ntest = 5000

# Directory prefix that later cells prepend to saved weights and logs.
savedir = "./saved-outputs/"

# Softplus module; the trainer applies it to the raw width parameter when logging.
softplus = torch.nn.Softplus()

trainloader, testloader = generate_mario_data(
    ntrain=ntrain,
    ntest=ntest,
    batch_size=128,
    dpath="./data/",
)

  train_images = train_images[np.ix_(trainshuffler), ::].squeeze()


In [4]:
def trainer(model, reg=0.01, epochs=20):
    """Train ``model`` on the global ``trainloader`` and log the augmentation widths.

    Each optimisation step appends one row to the log containing, in order:
    ``softplus(model.aug.width)`` (read after the optimizer step), the gradient
    stored on ``model.aug.width`` by that step's backward pass, and the scalar loss.

    Parameters
    ----------
    model : torch.nn.Module
        Must expose ``model.aug.width`` as a tensor that receives a gradient.
        Moved to the GPU when CUDA is available.
    reg : float
        Forwarded to ``losses.unif_aug_loss`` as its ``reg`` keyword.
    epochs : int
        Number of full passes over ``trainloader``.

    Returns
    -------
    pandas.DataFrame
        One row per iteration with columns ``index``, ``width0..``, ``grad0..``
        and ``loss``. Empty (but with the same columns) if no iteration ran.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model = model.cuda()

    # Build the optimizer only after the model sits on its final device, as the
    # torch.optim documentation recommends.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.)

    criterion = losses.unif_aug_loss
    logger = []

    for epoch in range(epochs):  # loop over the dataset multiple times
        for inputs, labels in trainloader:
            if use_cuda:
                inputs, labels = inputs.cuda(), labels.cuda()

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels, model, reg=reg)
            loss.backward()
            optimizer.step()

            # The widths are read after the step; the gradient is the one that
            # produced the step (it is not cleared until the next zero_grad()).
            log = softplus(model.aug.width).tolist()
            log += model.aug.width.grad.data.tolist()
            log += [loss.item()]
            logger.append(log)

    # Size the column names from the width parameter instead of hard-coding 6.
    # assumes model.aug.width is 1-D so each entry maps to one column — TODO confirm
    n_width = model.aug.width.numel()
    columns = (['width' + str(i) for i in range(n_width)]
               + ['grad' + str(i) for i in range(n_width)]
               + ['loss'])

    # Passing the columns to the constructor also works when no rows were logged,
    # where assigning ``.columns`` afterwards would raise a length mismatch.
    logdf = pd.DataFrame(logger, columns=columns)
    logdf = logdf.reset_index()
    return logdf

In [5]:
# Overrides the relative savedir assigned in an earlier cell; every save below uses this prefix.
# NOTE(review): absolute, user-specific path — presumably needs editing before running on another machine.
savedir = "/home/sp2058/augerino/experiments/mario-iggy/saved-outputs/"

In [6]:
# High-regularisation run (reg=0.1).
# Raw initial widths: -5 for every entry except index 2, which starts at -1.
start_widths = torch.full((6,), -5.)
start_widths[2] = -1.

net = models.SimpleConv(c=32, num_classes=4)
augerino = models.UniformAug()
high_model = models.AugAveragedModel(net, augerino, ncopies=1)
high_model.aug.set_width(start_widths)

high_logger = trainer(high_model, reg=0.1)

In [None]:
# Persist the high-reg run: model weights first, then the per-iteration log.
high_weights_path = savedir + "highreg.pt"
high_log_path = savedir + "high_logger.pkl"

torch.save(high_model.state_dict(), high_weights_path)
high_logger.to_pickle(high_log_path)

In [None]:
# Low-regularisation run (reg=0.01).
# Raw initial widths: -5 for every entry except index 2, which starts at -1.
start_widths = torch.full((6,), -5.)
start_widths[2] = -1.

net = models.SimpleConv(c=32, num_classes=4)
augerino = models.UniformAug()
low_model = models.AugAveragedModel(net, augerino, ncopies=1)
low_model.aug.set_width(start_widths)

low_logger = trainer(low_model, reg=0.01)

In [None]:
# Persist the low-reg run: model weights first, then the per-iteration log.
low_weights_path = savedir + "lowreg.pt"
low_log_path = savedir + "low_logger.pkl"

torch.save(low_model.state_dict(), low_weights_path)
low_logger.to_pickle(low_log_path)

In [None]:
# Mid-regularisation run (reg=0.05).
# Raw initial widths: -5 for every entry except index 2, which starts at -1.
start_widths = torch.full((6,), -5.)
start_widths[2] = -1.

net = models.SimpleConv(c=32, num_classes=4)
augerino = models.UniformAug()
mid_model = models.AugAveragedModel(net, augerino, ncopies=1)
mid_model.aug.set_width(start_widths)

mid_logger = trainer(mid_model, reg=0.05)

In [None]:
# Persist the mid-reg run. The weights must come from mid_model: saving
# high_model here would write the high-reg weights under the mid-reg filename.
torch.save(mid_model.state_dict(), savedir + "midreg.pt")
mid_logger.to_pickle(savedir + "mid_logger.pkl")

## Plotting

In [None]:
# Add a symmetric band [-w/2, +w/2] to each run's log, built from its
# 'width2' column; the plotting cells below draw these two columns.
for run_log in (low_logger, high_logger, mid_logger):
    half_width = run_log['width2'] / 2.
    run_log['lowbd'] = -half_width
    run_log['upbd'] = half_width

In [None]:
# Opacity and outline width of the shaded band drawn by plot_shade.
alpha = 0.1
lwd = 0.

def plot_shade(logger, ax, color, label=""):
    """Draw one run's width band on ``ax``.

    Shades the region between the ``lowbd`` and ``upbd`` columns of ``logger``
    and traces both bounds as lines, all on the axes passed in.

    Parameters
    ----------
    logger : pandas.DataFrame
        Must contain ``lowbd`` and ``upbd`` columns; its index is the x-axis.
    ax : matplotlib.axes.Axes
        Axes that receives the band and both bound lines.
    color
        Colour used for the band and the lines.
    label : str
        Legend label, attached to the lower-bound line only.
    """
    ax.fill_between(logger.index, logger['lowbd'], logger['upbd'],
                    alpha=alpha, color=color,
                    linewidth=lwd)
    # Pass ax explicitly: without it seaborn draws on the current axes, which
    # need not be the axes the caller asked for.
    sns.lineplot(x=logger.index, y='lowbd', color=color, data=logger,
                 label=label, ax=ax)
    sns.lineplot(x=logger.index, y='upbd', color=color, data=logger, ax=ax)

In [None]:
# y-axis tick positions paired with their display labels, from -pi/2 to pi/2.
_ticks = [
    (-np.pi/2, r"-$\pi$/2"),
    (-np.pi/4, r'-$\pi$/4'),
    (0, '0'),
    (np.pi/4, r'$\pi$/4'),
    (np.pi/2, r'$\pi$/2'),
]
tick_pts = [_pt for _pt, _lab in _ticks]
tick_labs = [_lab for _pt, _lab in _ticks]

In [None]:
# Single-panel figure comparing the width band of the three runs.
fs = 14
fig, ax0 = plt.subplots(1, 1, figsize=(8, 4), dpi=100)

pal = sns.color_palette("tab10")
col0, col1, col2 = pal[0], pal[1], pal[2]

# One shaded band per run, drawn in the order low, mid, high.
for run_log, run_color, run_label in (
    (low_logger, col0, "Low Reg"),
    (mid_logger, col1, "Mid Reg"),
    (high_logger, col2, "High Reg"),
):
    plot_shade(run_log, ax0, run_color, run_label)

ax0.set_xlabel("Iteration", fontsize=fs)
ax0.set_ylabel("Rotation Width", fontsize=fs)
ax0.tick_params("both", labelsize=fs-2)
sns.despine()

# No x ticks; y ticks sit at the multiples of pi/4 defined in tick_pts.
ax0.set_xticks([])
ax0.set_yticks(tick_pts)
ax0.set_yticklabels(tick_labs)

# Optional tweaks left disabled:
# ax0.set_title("Rotation Distributions")
# ax0.set_xlim(0, 500)
# ax0.legend()
# plt.setp(ax0.get_legend().get_texts(), fontsize=fs-4) # for legend text
plt.show()