# Optimizing Affine Transforms

### For testing:
* Need to first check that all the right tensors require grad (maybe just print it out)
* Need to check that the correct tensors are changing in the optimization (use a check to see what is changing and print them out) *** Maybe looking at the optimizer is enough to know what ones can be updated, but won't show what will in practice.

* Next level debugging is that the syntax is all good, but checking that there aren't some exploding gradients

In [10]:
import os
import sys
import json
import math
import random
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
from IPython.display import display, HTML

%reload_ext autoreload
%autoreload 2
%matplotlib inline

import torch
import torch.nn.functional as F
from torchvision import transforms
import torchvision.transforms.functional as TF
from PIL import Image, ImageOps

# Nicer way to import the module?
sys.path.append(str(Path.cwd().parent))

from utils.display import read_img_to_np, torch_to_np
from utils.norms import MNIST_norm
import model.model as module_arch
import data_loader.data_loaders as module_data
import model.loss as module_loss
import model.metric as module_metric
from model.model import AffineVAE

device = torch.device("cuda:0")

from data_loader.data_loaders import make_generators_MNIST, make_generators_MNIST_CTRNFS

In [11]:
def get_model_loaders_config(PATH, old_gpu='cuda:0', new_gpu='cuda:1'):
    """PATH: path to dir where training results of a run are saved"""
    PATH = Path(PATH)
    config_loc = PATH / 'config.json'
    weight_path = PATH / 'model_best.pth'
    config = json.load(open(config_loc))
    
    
    def get_instance(module, name, config, *args):
        return getattr(module, config[name]['type'])(*args, **config[name]['args'])

    data_loader = get_instance(module_data, 'data_loader', config)['train']
    valid_data_loader = get_instance(module_data, 'data_loader', config)['val']
    model = get_instance(module_arch, 'arch', config)
    model = model.to(torch.device(new_gpu))
    checkpoint = torch.load(weight_path, map_location={'cuda:0': 'cuda:1'})
    state_dict = checkpoint['state_dict']
    
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)

    model.load_state_dict(state_dict)
    model = model.to(device).eval()
    
    loss_fn = get_instance(module_loss, 'loss', config)
    metric_fns = [getattr(module_metric, met) for met in config['metrics']]
    
    return model, data_loader, valid_data_loader, loss_fn, metric_fns, config


# def display_batch(x, recon_x):
#     """Display tensor images"""
#     fig, ax = plt.subplots(x.size()[0], 2, sharex='col', sharey='row',figsize=(10,10))
#     for i in range(x.size()[0]):
#         ax[i, 0].imshow(torch_to_np(x[i]), cmap='Greys',  interpolation='nearest')
#         ax[i, 1].imshow(torch_to_np(recon_x[i]), cmap='Greys',  interpolation='nearest')
#         ax[i, 0].axis('off')
#         ax[i, 1].axis('off')
#     plt.show()

        
# def visualize_dataloader(data_loader, device, num_samples=6, bw=True):
#     """Randomly sample from dataloader and display"""
#     with torch.cuda.device(device.index):
#         fig, ax = plt.subplots(int(num_samples/2), 2, sharex='col', sharey='row',figsize=(10,10))
#         for i in range(0, int(num_samples/2)):
#             # sample randomly from the nth batch for the nth row of imgs
#             data, label = next(iter(data_loader))
#             tensor_img = data[random.randint(0, data.size()[0]),:, :, :]
#             ax[i, 0].imshow(torch_to_np(tensor_img), cmap='Greys',  interpolation='nearest')
#             tensor_img = data[random.randint(0, data.size()[0]), :, :, :]
#             ax[i, 1].imshow(torch_to_np(tensor_img), cmap='Greys',  interpolation='nearest')
#             print(f'tensor_img size: {tensor_img.size()}')
#             print(f'Max: {torch.max(tensor_img)}, Min: {torch.min(tensor_img)}')
#             ax[i, 0].axis('off')
#             ax[i, 1].axis('off')
    

# def display_results_auto(vae_model, config, device, rotate=0, pad = 6,
#                          norm=None, num_samples=3, data='bw', size = 28, label_col_name='label', save_loc=None):
#     """ QUESTIONABLE, REDO IT"""
#     with torch.cuda.device(device.index): # ??? Why the fuck???        
#         files_dict_loc = config['data_loader']['args']['files_dict_loc']
#         with open(files_dict_loc, 'rb') as f:
#             files_df = pickle.load(f)['train']
            
#         if label_col_name:
#             all_labels = files_df[label_col_name].unique()
#         else:
#             all_labels = list(range(num_samples))
#         row_names = []
#         col_names = ['Original', "Reconstructed"]

#         fig, ax = plt.subplots(num_samples, 2, sharex='col', sharey='row',figsize=(10,10))

#         for i, label in enumerate(all_labels[0:num_samples]):
#             if label_col_name:
#                 sample_df = files_df.loc[files_df[label_col_name] == label].sample(n=1)
#                 label = sample_df[label_col_name].iloc[0]
#                 row_names.append(label)
#             else:
#                 sample_df = files_df.sample(n=1)
#             img_path = sample_df['path'].iloc[0]

#             if data == 'bw': # Assume MNIST
#                 img = read_img_to_np(img_path, bw=True)
#                 transform = transforms.Compose([
#                                                 transforms.RandomRotation((rotate, rotate), expand=True),
#                                                 transforms.Resize(size),
#                                                 transforms.Pad((pad, pad)),
#                                                 transforms.ToTensor(),
#                                                 transforms.Normalize(*MNIST_norm)])
#             else:
#                 img = read_img_to_np(img_path, bw=False, size=size)
#                 transform = transforms.Compose([
#                                                 transforms.RandomRotation((rotate, rotate), expand=True),
#                                                 transforms.Resize(size),
#                                                 transforms.ToTensor(),
#                                                 transforms.Normalize(*norm)])

#             tensor_img = transform(Image.open(img_path)).unsqueeze(0).to(device)
#             tensor_label = torch.from_numpy(np.array(label)).unsqueeze(0).type(torch.LongTensor).to(device)
#             output = vae_model(tensor_img,  deterministic=False)
#             recon_x = output[0]
            
#             ax[i, 0].imshow(torch_to_np(tensor_img), cmap='Greys',  interpolation='nearest')
#             ax[i, 1].imshow(torch_to_np(recon_x), cmap='Greys',  interpolation='nearest')
#             ax[i, 0].axis('off')
#             ax[i, 1].axis('off')

#         for curr_ax, col in zip(ax[0], col_names):
#             curr_ax.set_title(col)
#         if label_col_name:
#             for curr_ax, row in zip(ax[:,0], row_names):
#                 curr_ax.set_ylabel(row, rotation=0, size='large')
#         if save_loc:
#             plt.savefig(save_loc, bbox_inches='tight')

In [12]:
def pad_to_size(img, new_size):
    delta_width = new_size - img.size()[1]
    delta_height = new_size - img.size()[2]
    pad_width = delta_width //2
    pad_height = delta_height //2
    img = F.pad(img, (pad_height, pad_height, pad_width, pad_width), 'constant', 0)
    return img

def rotate_mnist_batch(x, return_size=40, fixed_rotation=None):
    """Rotate batch without squishing the img. Pad all imgs to same size"""
    batch_size = x.shape[0]
    rot_x = torch.zeros((batch_size, 1, return_size, return_size))
    for i in range(batch_size):
        img = TF.to_pil_image(x[i, :, :])
        if fixed_rotation:
            img = TF.rotate(img, fixed_rotation)

        img = transforms.ToTensor()(img)
        if return_size:
            img = pad_to_size(img, return_size)
        # MNIST norm, wrong because imgs are padded
        img = transforms.Normalize((0.1307,), (0.3081,))(img)
        rot_x[i, :, :, :] = img
    return rot_x
    

def get_vae_MNIST_perf(model, data_loader, loss_fn, metric_fns, device, fixed_rotation, return_size):
    """Evaluate performance on MNIST Dataset using a given rotation
    
    Dataloader should be MNISTCustomTRNFS with size=28x28, unnormalized, not rotated"""

    with torch.cuda.device(device.index):
        model = model.to(device)
        model.eval()

        total_loss = 0.0
        total_metrics = torch.zeros(len(metric_fns))

        with torch.no_grad():
            for i, (data, target) in enumerate(tqdm(data_loader)):
                batch_size = data.shape[0]
                new_data = rotate_mnist_batch(data, return_size=40, fixed_rotation=fixed_rotation)
                new_data, target = new_data.to(device), target.to(device)
                output = model(new_data, deterministic=True)

                # computing loss, metrics on test set
                loss = loss_fn(output, new_data)
                total_loss += loss.item() * batch_size
                for i, metric in enumerate(metric_fns):
                    total_metrics[i] += metric(output, new_data) * batch_size

        n_samples = len(data_loader.sampler)
        log = {'loss': total_loss / n_samples}
        log.update({met.__name__ : total_metrics[i].item() / n_samples for i, met in enumerate(metric_fns)})
        return log

## Optimizing rotation using SGD and random restarts
* Can compute gradient of $\theta$ to optimize the rotation to minimize VAE loss
* Random restarts are needed because this tends to get caught in local optimum

In [28]:
def AFFINE_MNIST_rot_perf(model, data_loader, loss_fn, device, fixed_rotation, 
                          optimize=False, iterations=0, num_rand_restarts=200, num_imgs=1000):
    """Evaluate performance on MNIST Dataset using a given rotation
    Dataloader should be MNISTCustomTRNFS with size=28x28, unnormalized, not rotated"""

    with torch.cuda.device(device.index):
        model = model.to(device)
        model.eval()
        total_loss = 0.0

        with torch.no_grad():
            for i, (data, target) in enumerate(data_loader):
                batch_size = data.shape[0]
                rot_x = rotate_mnist_batch(data, return_size=40, fixed_rotation=fixed_rotation)
                rot_x, target = rot_x.to(device), target.to(device)
                if optimize:
                    best_affine, loss = affine_model.optimize_rotation(rot_x, num_times=num_rand_restarts, 
                                                                       iterations=iterations)
                else:
                    output = model(rot_x, deterministic=True, theta=0.0)
                    loss = loss_fn(output, rot_x).item()

                total_loss += loss * batch_size
                if i>num_imgs:
                    break

        n_samples = len(data_loader.sampler)
        log = {'loss': total_loss / num_imgs}
        return log

### Grab decent AVAE and dataloaders

In [29]:
config_loc = '/media/rene/data/equivariance/mnist/vae_mnist_L16/0129_230250'
VAE, data_loader, valid_data_loader, loss_fn, metric_fns, config = get_model_loaders_config(config_loc, old_gpu='cuda:1', new_gpu='cuda:0')
VAE = VAE.to(device)

AffineVAE = getattr(module_arch, 'AffineVAE')
affine_model = AffineVAE(pre_trained_VAE=VAE, img_size=28, input_dim=1, output_dim=1, latent_size=8, use_STN=False)
affine_model = affine_model.to(device)

files_dict_loc = '/media/rene/data/MNIST/files_dict.pkl'
data_loaders = make_generators_MNIST_CTRNFS(files_dict_loc, batch_size=1, num_workers=4, 
                                            return_size=28, rotation_range=None, normalize=False)

In [34]:
for iterations in [2, 5, 10, 20, 40]:
    log = AFFINE_MNIST_rot_perf(affine_model, data_loaders['val'], loss_fn, device, fixed_rotation=45, 
                                optimize=True, iterations=iterations, num_rand_restarts=1, num_imgs=1000)
    print(f"Iterations: {iterations} Loss: {log['loss']}")

Iterations: 2 Loss: 669.2090291748046
Iterations: 5 Loss: 664.2208488464355
Iterations: 10 Loss: 658.0452225341797
Iterations: 20 Loss: 660.5350748901367
Iterations: 40 Loss: 653.7899044799805


In [35]:
for iterations in [2, 5, 10, 20, 40]:
    log = AFFINE_MNIST_rot_perf(affine_model, data_loaders['val'], loss_fn, device, fixed_rotation=45, 
                                optimize=True, iterations=iterations, num_rand_restarts=10, num_imgs=1000)
    print(f"Iterations: {iterations} Loss: {log['loss']}")

Iterations: 2 Loss: 619.0337117309571
Iterations: 5 Loss: 612.3658827819825
Iterations: 10 Loss: 614.3349027099609


KeyboardInterrupt: 

In [26]:
for rand_restarts in [5, 10, 20, 40]:
    log = AFFINE_MNIST_rot_perf(affine_model, data_loaders['val'], loss_fn, device, fixed_rotation=45, 
                                optimize=True, iterations=20, num_rand_restarts=rand_restarts, num_imgs=1000)
    print(f"Rand Restarts: {rand_restarts} Loss: {log['loss']}")

Rand Restarts: 5 Loss: 6.292737271118164
Rand Restarts: 10 Loss: 6.367885220336914
Rand Restarts: 20 Loss: 6.184860906982422
Rand Restarts: 40 Loss: 6.0255543212890625


## Performance of optimization over different rotations

In [None]:
config_loc = '/media/rene/data/equivariance/mnist/vae_mnist_L16/0129_230250'
VAE, data_loader, valid_data_loader, loss_fn, metric_fns, config = get_model_loaders_config(config_loc, old_gpu='cuda:1', new_gpu='cuda:0')
VAE = VAE.to(device)

AffineVAE = getattr(module_arch, 'AffineVAE')
affine_model = AffineVAE(pre_trained_VAE=VAE, img_size=28, input_dim=1, output_dim=1, latent_size=8, use_STN=False)
affine_model = affine_model.to(device)

files_dict_loc = '/media/rene/data/MNIST/files_dict.pkl'
data_loaders = make_generators_MNIST_CTRNFS(files_dict_loc, batch_size=1, num_workers=4, 
                                            return_size=28, rotation_range=None, normalize=False)

results = pd.DataFrame()
results_opt = pd.DataFrame()
for rotation in range(0, 180, 15):
    print(f'rotation: {rotation}')
    log = AFFINE_MNIST_rot_perf(affine_model, data_loaders['val'], loss_fn, device, fixed_rotation=rotation, 
                                optimize=False, iterations=0, num_rand_restarts=0, num_imgs=1000)
    log['rotation'] = rotation
    results = results.append(log, ignore_index=True)
    
    log_opt = AFFINE_MNIST_rot_perf(affine_model, data_loaders['val'], loss_fn, device, fixed_rotation=rotation, 
                                optimize=True, iterations=10, num_rand_restarts=40, num_imgs=1000)
    log_opt['rotation'] = rotation
    results_opt = results_opt.append(log_opt, ignore_index=True)

results.to_csv('/media/rene/code/equivariance/results/affine_rot_nonopt_sgd20_r40.csv')
results_opt.to_csv('/media/rene/code/equivariance/results/affine_rot_opt_sgd20_r40.csv')

fig, ax = plt.subplots()
ax.plot(results['rotation'], results['loss'], label="VAE")
ax.plot(results_opt['rotation'], results_opt['loss'], label="AVAE")

ax.set(xlabel='Rotation', ylabel='VAE Loss',
       title='Effect of optimizing rotation on VAE loss')
ax.legend(loc='best')

plt.savefig('/media/rene/code/equivariance/imgs/rotation_opt_sgd20_r40.png', bbox='tight')
plt.show()

rotation: 0


# TODO
1. Check graphs/results
2. Clean this notebook
3. Extend to jointly optimizing across a few variables - Finish making the new trainer with the unreduced loss and add in the optimization loop.
4. Test it and run it on the real datasets.

## Optimizing rotation during training
* Goal is to make a network that will represent all images at some cannonical orientation while being trained on variable orientation images.
* This should result in an overall more simple model able to do the same work.
* Switch between optimizing the model parameters, and optimizing the network.
* Could also do it on one batch at a fixed orientation first, but this is much less cool so we won't

For each batch:
1. Optimize the rotation parameters (randomly)
2. Do the backward pass to update the weights, but also update theta

In [54]:
def vae_loss_unreduced(output, target, KLD_weight=1):
    recon_x, mu_logvar  = output
    mu = mu_logvar[:, 0:int(mu_logvar.size()[1]/2)]
    logvar = mu_logvar[:, int(mu_logvar.size()[1]/2):]
    KLD = -0.5 * torch.sum(1 + 2 * logvar - mu.pow(2) - (2 * logvar).exp(), dim=1)
    BCE = F.mse_loss(recon_x, target, reduction='none')    
    BCE = torch.sum(BCE, dim=(1, 2, 3))
    loss = BCE + KLD_weight*KLD
    return loss


data, target = next(iter(data_loader))
batch_size = data.shape[0]
rot_x = rotate_mnist_batch(data, return_size=40, fixed_rotation=45)
rot_x, target = rot_x.to(device), target.to(device)
output = affine_model(rot_x, deterministic=True)
loss = vae_loss_unreduced(output, rot_x)
loss.size()

torch.Size([128])

In [220]:
import numpy as np
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.autograd import Variable

import sys
sys.path.insert(0,'..')

from torch.autograd import Variable
from torchvision import transforms
from torch.nn import functional as F
import torch.optim as optim

from base import BaseModel
from model.loss import make_vae_loss


class TESTAffineVAE(nn.Module):
    def __init__(self, pre_trained_VAE=None, img_size=28, input_dim=1, output_dim=1, latent_size=16, rotation_only=False):
        """Do we always need the device whenever we're creating tensors in the model? Whats the proper way to do this?"""
        super(TESTAffineVAE, self).__init__()
        if pre_trained_VAE is None:
            self.VAE = VAE(input_dim=input_dim, output_dim=output_dim, latent_size=latent_size, img_size=img_size)
        else:
            self.VAE = pre_trained_VAE
        self.latent_size = 16
        self.img_size = img_size
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.rotation_only = rotation_only
        self.theta = None
        self.affine_params=None
        self.optim_params=None
        
#         mu_logvar = torch.nn.Parameter(torch.randn((1, 2*self.latent_size), device="cuda", requires_grad=True))
#         self.register_parameter('mu_logvar', mu_logvar) 
#         recon_x = torch.nn.Parameter(torch.randn((1), device="cuda", requires_grad=True))
#         self.register_parameter('recon_x', mu_logvar) 

            
#     def theta_to_affine(self, theta, affine_params):
#         affine_params[:, 0, 0] = torch.cos(theta)
#         affine_params[:, 0, 1] = torch.sin(theta)
#         affine_params[:, 1, 0] = -1*torch.sin(theta)
#         affine_params[:, 1, 1] = torch.cos(theta)
#         return affine_params

    def affine(self, x, affine_params, padding_mode='zeros'):
        grid = F.affine_grid(affine_params, x.size()).cuda()
        x = F.grid_sample(x, grid, padding_mode=padding_mode)
        return x
    
    def affine_inv(self, x, affine_params, padding_mode='zeros'):
        inv_affine_params = torch.cuda.FloatTensor(affine_params.size()).fill_(0)
        A_inv =  torch.inverse(affine_params[:, :, :2].squeeze())
        b = affine_params[:, : , 2:]
        b_inv = torch.matmul(A_inv, b)
        b_inv = b_inv.squeeze()
        inv_affine_params[:, :2, :2] = A_inv
        inv_affine_params[:, :, 2] = -1*b_inv
        grid = F.affine_grid(inv_affine_params, x.size()).cuda()
        x = F.grid_sample(x, grid, padding_mode=padding_mode)
        return x
    
    def vae_loss(self, output, target, KLD_weight=1):
            """loss is BCE + KLD. target is original x"""
            recon_x, mu_logvar  = output
            mu = mu_logvar[:, 0:int(mu_logvar.size()[1]/2)]
            logvar = mu_logvar[:, int(mu_logvar.size()[1]/2):]
            KLD = -0.5 * torch.sum(1 + 2 * logvar - mu.pow(2) - (2 * logvar).exp())
            BCE = F.mse_loss(recon_x, target, reduction='sum')
            loss = BCE + KLD_weight*KLD
            return loss
    
    def opt_latent(self, x, iterations=50, num_times=1):
        deterministic=True
        lr = .01
        
        with torch.enable_grad():
            for trial in range(num_times):
                theta = torch.cuda.FloatTensor(1).uniform_(-2*math.pi, 2*math.pi)
                theta = theta.data.clone().detach().requires_grad_(True).cuda()
                optimizer = optim.Adam([theta], lr=lr)

                for i in range(iterations):
                    affine_params = torch.cat([torch.cos(theta), torch.sin(theta), 
                                       torch.tensor([0.0], requires_grad=True, device="cuda"), 
                                       -1*torch.sin(theta), torch.cos(theta),
                                       torch.tensor([0.0], requires_grad=True, device="cuda")]).view(-1, 2, 3)
                    
                    x_affine = self.affine(x, affine_params)
                    print('x_affine.requires_grad', x_affine.requires_grad)
                    mu_logvar = self.VAE.encode(x_affine)
                    z = self.VAE.reparameterize(mu_logvar, deterministic=deterministic)
                    recon_x = self.VAE.decode(z)
                    recon_x = self.affine_inv(recon_x, affine_params)
                    loss = self.vae_loss((recon_x, mu_logvar), x)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
            return loss.item()
        
        
#         with torch.enable_grad():
#             for trial in range(num_times):
#                 print(trial)
#                 mu_logvar = torch.randn((1, 2*self.latent_size), device="cuda", requires_grad=True)
#                 mu_logvar = mu_logvar.data.clone().detach().requires_grad_(True).cuda()
#                 optimizer = optim.Adam([mu_logvar], lr=lr)

#                 for i in range(iterations):
#                     z = self.VAE.reparameterize(mu_logvar, deterministic=deterministic)
#                     recon_x = self.VAE.decode(z)
#                     recon_x = self.affine_inv(recon_x, affine_params)
#                     loss = self.vae_loss((recon_x, mu_logvar), img)
#                     optimizer.zero_grad()
#                     loss.backward()
#                     optimizer.step()
#                     print(loss.item())
#             return loss.item()
            

    def optimize_affine_params(self, x, only_rotation=False, only_shear=False, num_times=100, iterations=50, KLD_weight=1):
        with torch.autograd.set_detect_anomaly(True):
            vae_loss = make_vae_loss(KLD_weight=1)

            lr = .01
            best_loss = 10000000000000

            for trial in range(num_times):
                if only_rotation:
                    x_rot = torch.tensor([0.0], dtype=torch.float32,\
                                         requires_grad=True, device = "cuda")
                    x_sin = torch.sin(x_rot)
                    print('x_sin.requires_grad', x_sin.requires_grad)
                    
#                     print(type(self.VAE.parameters()))
#                     print(type(list(self.VAE.parameters())[0]))
#                     theta = nn.Parameter(torch.from_numpy(np.random.uniform(-2*math.pi, 2*math.pi, 1)).float().clone().detach().cuda())
                    theta = nn.Parameter(torch.cuda.FloatTensor(1).uniform_(-2*math.pi, 2*math.pi))
#                     theta = torch.tensor(torch.cuda.FloatTensor([0.0]).uniform_(-2*math.pi, 2*math.pi).data, requires_grad=True)
#                     print(theta)
#                     print('theta.requires_grad', theta.requires_grad)
#                     print('theta[0].requires_grad', theta[0].requires_grad)
#                     add = torch.add(theta, 2.0)
#                     abc = torch.sin(theta)
#                     print('abc.requires_grad', abc.requires_grad)
#                     print('add.requires_grad', add.requires_grad)

                    affine_params = torch.cat([torch.cos(theta), torch.sin(theta), 
                               torch.tensor([0.0], requires_grad=False, device="cuda"), 
                               -1*torch.sin(theta), torch.cos(theta),
                               torch.tensor([0.0], requires_grad=False, device="cuda")]).view(-1, 2, 3)
                    
                elif only_shear:
#                     c_x = torch.cuda.FloatTensor([0.0]).uniform_(1, 1.5)
                    c_x = torch.cuda.FloatTensor([0.0]).uniform_(-.3, .3)
                    c_y = torch.cuda.FloatTensor([0.0]).uniform_(-.3, .3)
#                     c_y = torch.cuda.FloatTensor([0.0]).uniform_(.6, 1.4)

                    optim_params = [c_x, c_y]
                    affine_params = torch.cat([torch.tensor([1.0], requires_grad=True, device="cuda"), c_x, 
                               torch.tensor([0.0], requires_grad=True, device="cuda"), 
                               c_y, torch.tensor([1.0], requires_grad=True, device="cuda"), 
                               torch.tensor([0.0], requires_grad=True, device="cuda")]).view(-1, 2, 3)
                                        
                else: # initialize to some resonable amount of scaling. This currently includes weird shears.
                    affine_params = 4*torch.rand(x.size()[0], 2, 3, requires_grad=True) -2
                    optim_params = [affine_params]
                    
                if iterations > 0:
                    x = nn.Parameter(x)
                    optimizer = torch.optim.Adam(optim_params, lr=lr)
                    print('optimizer', optimizer)
                    for i in range(iterations):
                        print('x.requires_grad', x.requires_grad)
                        print('affine_params.requires_grad', affine_params.requires_grad)
                        x_affine = self.affine(x, affine_params)
                        print('x_affine.requires_grad', x_affine.requires_grad)
                        mu_logvar = self.VAE.encode(x_affine)
                        z = self.VAE.reparameterize(mu_logvar, deterministic=True)
                        recon_x = self.VAE.decode(z)
                        recon_x = self.affine_inv(recon_x, affine_params)
                        print('recon_x.requires_grad', recon_x.requires_grad)
                        
                        loss = vae_loss((recon_x, mu_logvar), x)
#                         print(loss.item(), self.theta, self.affine_params[0, 0, 0])
#                         print(self.VAE.dec_conv1.weight.grad.data.sum())
                        
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                        
                        if loss.item() < best_loss:
                            best_loss = loss.item()
                            best_affine_params = affine_params
                else:
                    x_affine = self.affine(x, affine_params)
                    recon_x, mu_logvar = self.VAE(x_affine, deterministic=True)
                    recon_x = self.affine_inv(recon_x, affine_params)
                    loss = vae_loss((recon_x, mu_logvar), x)
                    if loss.item() < best_loss:
                        best_loss = loss.item()
                        best_affine_params = affine_params
            return best_affine_params, best_loss

        
    def forward(self, x, theta=None, affine_params=None, deterministic=False, return_affine=False):
        """forward pass with optionally learned affine transform. 
        
        !! If no affine passed in will also check if model has some

        Options:
        None: This is the identity transform, equivalent to normal VAE

        Delete::::
        explicit: will use the affine_params provided, else equivalent to None
        stn: use learned params. If STN module isn't trained, will give nonsense
        optimized: optimize affine params to minimize reconstruction loss
        rot_optimized: optimized, but constrained to rotations.
        """

        # learned affine 
        if self.use_STN: 
            affine_params = self.get_stn_params(x)
            
        # initalize affine to rotation 
        elif theta is not None:
            theta = torch.cuda.FloatTensor([theta])
            affine_params = torch.cat([torch.cos(theta), torch.sin(theta), 
                                           torch.tensor([0.0], requires_grad=True, device="cuda"), 
                                           -1*torch.sin(theta), torch.cos(theta), 
                                           torch.tensor([0.0], requires_grad=True, device="cuda")]).view(-1, 2, 3)

        # initialize to identity for each image if affine param not specified and not stn
        elif affine_params is None:
            affine_params = torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float).view(2, 3).cuda()
            affine_params = affine_params.expand(x.size()[0], affine_params.size()[0], affine_params.size()[1]).clone()
                    
        x_affine = self.affine(x, affine_params)
        mu_logvar = self.VAE.encode(x_affine)
        z = self.VAE.reparameterize(mu_logvar, deterministic)
        recon_x = self.VAE.decode(z)
        recon_x = self.affine_inv(recon_x, affine_params)
        if return_affine:
            return recon_x, mu_logvar, affine_params, x_affine
        else:
            return recon_x, mu_logvar

def AFFINE_MNIST_rot_perf_LATENT(model, data_loader, loss_fn, device, fixed_rotation, 
                          optimize=False, iterations=0, num_rand_restarts=200, num_imgs=1000):
    """Evaluate performance on MNIST Dataset using a given rotation
    Dataloader should be MNISTCustomTRNFS with size=28x28, unnormalized, not rotated"""

    with torch.cuda.device(device.index):
        model = model.to(device)
        model.eval()
        total_loss = 0.0

        with torch.no_grad():
            for i, (data, target) in enumerate(data_loader):
                batch_size = data.shape[0]
                rot_x = rotate_mnist_batch(data, return_size=40, fixed_rotation=fixed_rotation)
                rot_x, target = rot_x.to(device), target.to(device)
                rot_x.requires_grad = True
                target.requires_grad = True
                if optimize:
                    loss = affine_model.opt_latent(rot_x, iterations=iterations, num_times=num_rand_restarts)
                else:
                    output = model(rot_x, deterministic=True, theta=0.0)
                    loss = loss_fn(output, rot_x).item()

                total_loss += loss * batch_size
                if i>num_imgs:
                    break

        n_samples = len(data_loader.sampler)
        log = {'loss': total_loss / n_samples}
        return log

In [221]:
config_loc = '/media/rene/data/equivariance/mnist/vae_mnist_L16/0129_230250'
VAE, data_loader, valid_data_loader, loss_fn, metric_fns, config = get_model_loaders_config(config_loc, old_gpu='cuda:1', new_gpu='cuda:0')
VAE = VAE.to(device)

AffineVAE = getattr(module_arch, 'AffineVAE')
affine_model = TESTAffineVAE(pre_trained_VAE=VAE, img_size=28, input_dim=1, output_dim=1, latent_size=16)
affine_model = affine_model.to(device)

files_dict_loc = '/media/rene/data/MNIST/files_dict.pkl'
data_loaders = make_generators_MNIST_CTRNFS(files_dict_loc, batch_size=1, num_workers=4, 
                                            return_size=28, rotation_range=None, normalize=False)

In [222]:
log = AFFINE_MNIST_rot_perf_LATENT(affine_model, data_loaders['val'], loss_fn, device, fixed_rotation=45, 
                                   optimize=True, iterations=5, num_rand_restarts=1, num_imgs=3)
log

0
x_affine.requires_grad True
tensor([[[-0.1036,  0.9946,  0.0000],
         [-0.9946, -0.1036,  0.0000]]], device='cuda:0',
       grad_fn=<ViewBackward>)
503.4684143066406
x_affine.requires_grad True
tensor([[[-0.1136,  0.9935,  0.0000],
         [-0.9935, -0.1136,  0.0000]]], device='cuda:0',
       grad_fn=<ViewBackward>)
502.7657470703125
x_affine.requires_grad True
tensor([[[-0.1226,  0.9925,  0.0000],
         [-0.9925, -0.1226,  0.0000]]], device='cuda:0',
       grad_fn=<ViewBackward>)
502.376953125
x_affine.requires_grad True
tensor([[[-0.1316,  0.9913,  0.0000],
         [-0.9913, -0.1316,  0.0000]]], device='cuda:0',
       grad_fn=<ViewBackward>)
502.05572509765625
x_affine.requires_grad True
tensor([[[-0.1405,  0.9901,  0.0000],
         [-0.9901, -0.1405,  0.0000]]], device='cuda:0',
       grad_fn=<ViewBackward>)
501.8534240722656
0
x_affine.requires_grad True
tensor([[[ 0.0705, -0.9975,  0.0000],
         [ 0.9975,  0.0705,  0.0000]]], device='cuda:0',
       grad_fn=<

{'loss': 0.31353234252929685}

10000000000.0

TypeError: get_AFFINE_MNIST_perf() got multiple values for argument 'fixed_rotation'