## Optimizing rotation during training
* The goal is to make a network that represents all images at some canonical orientation while being trained on images at varying orientations.
* This should result in a simpler overall model that can do the same work.
* Alternate between optimizing the rotation parameters and optimizing the network weights.
* We could also first do this on one batch at a fixed orientation, but this is much less cool, so we won't.

For each batch:
1. Optimize the rotation parameters: try random rotations, pick the best, then take a couple of SGD steps on it.
2. Do the backward pass to update the network weights.

We could iterate this process on a single batch, gradually converging on the "best" representation; instead, we hope that over the course of training the network learns it for all images at once.

# TODO
1. Check results: as in the other notebook, run examples, sample from the VAE, etc.
2. Make the standard plot of loss vs. rotation, comparing this VAE to one trained on rotation-augmented images. Hopefully this one does better for a given latent size. Optimize both! The idea is that the other model has had to learn to handle the other rotations too, and so won't do as well on the standard ones. Big if.
3. Clean up this notebook and the other one.
4. Write things up (not exactly sure how to present this new addition, if it works). A first guess:

**Outline**


**ABS/intro**
We propose an alternative method to enable models to generalize to affine transformations. We show that data augmentation alone has limitations. We present a method that enables VAEs trained on a single orientation (a subset of the possible distribution) to generalize, as well as a way to train the VAE on the full distribution, without increasing the model capacity.

The model capacity required for this subset of the possible distribution (rotations of 0 degrees) is less than that required for a model expected to generalize to the broader distribution of all possible rotations. This means that naively using data augmentation should increase the required model capacity.


**Generalizing to Rotations**
Given a VAE trained on images at one orientation $$r_0$$, we can make this model generalize to other rotations by taking an image at some rotation and finding the rotation that minimizes the reconstruction error. This is implemented as: (from paper)

But this procedure requires the training images to be supplied at one given orientation; otherwise it devolves into plain data augmentation, with the increased model capacity that requires. It would be better if the model could be trained on a dataset without this restriction.
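A minimal sketch of the rotation search described above, in plain NumPy. It is not the implementation referenced from the paper: `rotate_image` and `best_rotation` are hypothetical helpers, `rotate_image` uses nearest-neighbour sampling with zero fill, and `recon_error` stands in for the trained VAE's reconstruction error on a rotated input. The toy check at the end uses a "model" that only reconstructs one fixed template at its canonical orientation.

```python
import numpy as np


def rotate_image(img, angle_deg):
    """Rotate a 2-D image about its centre by `angle_deg` degrees.

    Hypothetical helper: nearest-neighbour sampling, pixels rotated in
    from outside the frame are filled with zeros.
    """
    h, w = img.shape
    theta = np.deg2rad(angle_deg)
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    ys, xs = np.mgrid[0:h, 0:w]
    # inverse mapping: for each output pixel, look up the source pixel
    dy, dx = ys - cy, xs - cx
    src_x = np.rint(np.cos(theta) * dx + np.sin(theta) * dy + cx).astype(int)
    src_y = np.rint(-np.sin(theta) * dx + np.cos(theta) * dy + cy).astype(int)
    inside = (src_x >= 0) & (src_x < w) & (src_y >= 0) & (src_y < h)
    out = np.zeros_like(img)
    out[inside] = img[src_y[inside], src_x[inside]]
    return out


def best_rotation(img, recon_error, angles):
    """Return (angle, error) for the rotation of `img` the model reconstructs best."""
    errors = [recon_error(rotate_image(img, a)) for a in angles]
    i = int(np.argmin(errors))
    return angles[i], errors[i]


# toy check: the "model" only reconstructs an L-shaped template at its canonical orientation
T = np.zeros((8, 8))
T[1:7, 2] = 1.0
T[6, 2:6] = 1.0
x = rotate_image(T, 90)  # input presented at a different orientation
angle, err = best_rotation(x, lambda im: float(((im - T) ** 2).sum()), [0, 90, 180, 270])
print(angle, err)  # → 270 0.0
```

With a real VAE, `recon_error` would run the model on the rotated image, and the candidate angles could be a finer grid or random samples.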

**Rotation Training**
We use MNIST with random rotations. (Would it have been better to use fixed rotations? Probably the same, because the random rotation search ignores the input rotation during training anyway.)

As an alternative, we repeatedly optimize the rotation angle of each image during training. The goal is for the VAE to learn, for each class of image, the orientation that minimizes the loss for a given model complexity; this is not necessarily the standard orientation in which we view the images. We formulate this as an optimization step added to the training loop: for each batch, the rotation of each image is optimized to reduce the model's reconstruction loss before backpropagation is done. Overall, the training process is:

for each batch:
* try 20 random rotations and keep the one, $$r_{opt}$$, giving the lowest VAE loss
* refine $$r_{opt}$$ with SGD on the rotation angle
* with the images rotated by $$r_{opt}$$, perform standard backprop on the VAE.
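A minimal sketch of this per-batch procedure, written against hypothetical hooks rather than the notebook's model: `loss_fn(image, angle)` stands in for the VAE loss of one image rotated by `angle` (in radians), and `weight_step(images, angles)` stands in for the usual optimizer step on the VAE weights. To keep it self-contained, the SGD on the angle uses a central finite-difference gradient; that is an assumption, and in PyTorch one would more likely make the rotation differentiable (e.g. with `F.affine_grid` and `F.grid_sample`) and let autograd supply the gradient.

```python
import numpy as np


def optimise_rotation(loss_of_angle, n_random=20, n_sgd=2, lr=0.5, eps=1e-3, rng=None):
    """Random search over rotation angles (radians), then a few gradient steps.

    `loss_of_angle` returns the loss for one image rotated by the given angle.
    The gradient is a central finite difference (a stand-in for autograd).
    `lr` depends on the scale of the loss and would need tuning for a real VAE.
    """
    rng = np.random.default_rng() if rng is None else rng
    # 1. try `n_random` random rotations and keep the one with the lowest loss
    candidates = rng.uniform(0.0, 2 * np.pi, size=n_random)
    angle = candidates[np.argmin([loss_of_angle(a) for a in candidates])]
    # 2. refine that rotation with a couple of gradient-descent steps
    for _ in range(n_sgd):
        grad = (loss_of_angle(angle + eps) - loss_of_angle(angle - eps)) / (2 * eps)
        angle -= lr * grad
    return float(angle % (2 * np.pi))


def train_on_batch(images, loss_fn, weight_step, rng=None, **search_kwargs):
    """One training step: optimise each image's rotation, then update the weights.

    `loss_fn(image, angle)` and `weight_step(images, angles)` are hypothetical
    hooks for the VAE forward pass and the optimizer step.
    """
    angles = [
        optimise_rotation(lambda a, x=img: loss_fn(x, a), rng=rng, **search_kwargs)
        for img in images
    ]
    # 3. with the rotations fixed at r_opt, do the usual backprop on the model weights
    weight_step(images, angles)
    return angles


# toy usage: each "image" is just its best angle, and the loss is 1 - cos(angle - best)
rng = np.random.default_rng(0)
recorded = {}
angles = train_on_batch(
    [2.0, 5.0],
    loss_fn=lambda target, a: 1.0 - np.cos(a - target),
    weight_step=lambda images, angles: recorded.update(angles=angles),
    rng=rng,
    n_sgd=30,
)
```

Keeping the search per image means each image can settle on its own orientation; the weights only see the batch after every rotation has been fixed.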

This is implemented naively, with no attempt at efficiency. It is likely that optimizing per batch would be more useful, but this was not tested. A much better approach would be to incorporate an STN into the optimization process: although the STN's prediction is not always correct, it could be much faster than simple random search. The process would alternate: apply the STN, evaluate, apply the STN again, and so on, mixed with randomness and SGD. This way we keep the benefit of explicitly optimizing the objective, with more speed than pure random search.

In this way we make models equivariant to rotation through optimization, as an alternative to explicitly encoding this into the model (Taco) or approximating it (STN). While the model is practically limited in its current form, this idea of optimizing a representation over some set of transforms could be useful for enforcing generalization.

5. Add a couple of notes to GitHub.

In [2]:
import os
import sys
import json
import math
import random
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
from IPython.display import display

%reload_ext autoreload
%autoreload 2
%matplotlib inline

import torch
import torch.nn.functional as F
from torchvision import transforms
import torchvision.transforms.functional as TF
from PIL import Image, ImageOps

sys.path.append(str(Path.cwd().parent))
from utils.display import read_img_to_np, torch_to_np
from utils.norms import MNIST_norm
import model.model as module_arch
import data_loader.data_loaders as module_data
import model.loss as module_loss
import model.metric as module_metric
from model.model import AffineVAE
from data_loader.data_loaders import make_generators_MNIST, make_generators_MNIST_CTRNFS

device = torch.device("cuda:0")

In [3]:
def get_model_loaders_config(PATH, old_gpu='cuda:0', new_gpu='cuda:1'):
    """Load a trained model, its data loaders, loss and metrics from a run directory.

    PATH: path to the dir where the training results of a run are saved.
    old_gpu / new_gpu: checkpoint tensors saved on `old_gpu` are remapped to `new_gpu`.
    """
    PATH = Path(PATH)
    config_loc = PATH / 'config.json'
    weight_path = PATH / 'model_best.pth'
    with open(config_loc) as f:
        config = json.load(f)

    def get_instance(module, name, config, *args):
        # instantiate config[name]['type'] from `module` with config[name]['args']
        return getattr(module, config[name]['type'])(*args, **config[name]['args'])

    # build the loaders once and take both splits
    loaders = get_instance(module_data, 'data_loader', config)
    data_loader, valid_data_loader = loaders['train'], loaders['val']

    model = get_instance(module_arch, 'arch', config)
    if config['n_gpu'] > 1:
        # checkpoints saved from DataParallel have 'module.'-prefixed keys
        model = torch.nn.DataParallel(model)

    # remap tensors saved on old_gpu onto new_gpu
    checkpoint = torch.load(weight_path, map_location={old_gpu: new_gpu})
    model.load_state_dict(checkpoint['state_dict'])
    model = model.to(torch.device(new_gpu)).eval()

    loss_fn = get_instance(module_loss, 'loss', config)
    metric_fns = [getattr(module_metric, met) for met in config['metrics']]

    return model, data_loader, valid_data_loader, loss_fn, metric_fns, config

In [5]:
config_loc = '/media/rene/data/equivariance/mnist/vae_mnist_L16/0129_230250'
VAE, data_loader, valid_data_loader, loss_fn, metric_fns, config = get_model_loaders_config(config_loc, old_gpu='cuda:1', new_gpu='cuda:0')
VAE = VAE.to(device)

AffineVAE = getattr(module_arch, 'AffineVAE')
affine_model = AffineVAE(pre_trained_VAE=VAE, img_size=28, input_dim=1, output_dim=1, latent_size=8, use_STN=False)
affine_model = affine_model.to(device)

files_dict_loc = '/media/rene/data/MNIST/files_dict.pkl'
data_loaders = make_generators_MNIST_CTRNFS(files_dict_loc, batch_size=1, num_workers=4, 
                                            return_size=28, rotation_range=None, normalize=False)


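def vae_loss_unreduced(output, target, KLD_weight=1):
    """Per-sample VAE loss (no reduction over the batch): reconstruction + weighted KL."""
    recon_x, mu_logvar = output
    half = mu_logvar.size(1) // 2
    mu = mu_logvar[:, :half]
    logvar = mu_logvar[:, half:]
    # KL(N(mu, sigma^2) || N(0, I)), with 2 * logvar playing the role of log(sigma^2)
    KLD = -0.5 * torch.sum(1 + 2 * logvar - mu.pow(2) - (2 * logvar).exp(), dim=1)
    # squared-error reconstruction term, summed over channels and pixels
    recon_loss = F.mse_loss(recon_x, target, reduction='none').sum(dim=(1, 2, 3))
    return recon_loss + KLD_weight * KLD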


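# `rotate_mnist_batch` is not defined or imported in this notebook. Hypothetical
# stand-in, assuming it zero-pads each image to `return_size` x `return_size`
# and rotates it by `fixed_rotation` degrees.
def rotate_mnist_batch(x, return_size=28, fixed_rotation=0):
    pad = max(return_size - x.shape[-1], 0)
    x = F.pad(x, (pad // 2, pad - pad // 2, pad // 2, pad - pad // 2))
    return TF.center_crop(TF.rotate(x, fixed_rotation), [return_size, return_size])


data, target = next(iter(data_loader))
batch_size = data.shape[0]
# note: affine_model was built with img_size=28, so return_size may need to match it
rot_x = rotate_mnist_batch(data, return_size=40, fixed_rotation=45)
rot_x, target = rot_x.to(device), target.to(device)
output = affine_model(rot_x, deterministic=True)
loss = vae_loss_unreduced(output, rot_x)  # one loss value per image
loss.size()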
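The KL term in `vae_loss_unreduced` is the closed-form KL divergence between a diagonal Gaussian posterior and a standard normal prior. Written out, assuming the second half of `mu_logvar` (named `logvar` in the code) holds $$s_j = \log\sigma_j$$, as the `2 * logvar` factor implies:

$$
D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \operatorname{diag}\sigma^2)\,\|\,\mathcal{N}(0, I)\big)
= -\frac{1}{2}\sum_{j}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
= -\frac{1}{2}\sum_{j}\left(1 + 2s_j - \mu_j^2 - e^{2s_j}\right)
$$

If the encoder instead outputs the log-variance $$v_j = \log\sigma_j^2$$, the factor of 2 should be dropped in both places, giving $$-\frac{1}{2}\sum_j\left(1 + v_j - \mu_j^2 - e^{v_j}\right)$$.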