# MAP-training

## In this notebook we perform MAP-training with additional constraint on ranks 

In [71]:
import torch
import torch.nn as nn
from torch.distributions import Normal, Gamma

from random import shuffle

from torchvision.datasets.mnist import MNIST
from torch.utils.data import DataLoader, random_split, Subset
import torchvision.transforms as transforms

#from tqdm import tqdm_notebook as tqdm
from tqdm import tqdm
from tt_model import *

In [78]:
# 7666 12
config = {
    'resize_shape': (32, 32),
    
    'in_factors': (4, 4, 4, 4, 4),
    'l1_ranks': (8, 8, 8, 8),
    'hidd_out_factors': (2, 2, 2, 2, 2),
    'ein_string1': "nabcde,aoiv,bijw,cjkx,dkly,elpz",
    
    'hidd_in_factors': (4, 8),
    'l2_ranks': (16,),
    'out_factors': (5, 2),
    'ein_string2': 'nab,aoix,bipy',
}

parameters_config = {
    'batch_size': 200,
    'train_size': 40000,
    'device': torch.device('cuda'),
    'learning_rate': 1e-4,
    'n_epochs': 100,
    
    # gamma distribution parameters from paper
#     'a_l': 1,
#     'b_l': 5,
    'a_l': 1,
    'b_l': 15,
}

class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self
        
model_cfg = AttrDict(config)
params_cfg = AttrDict(parameters_config)


In [12]:
model = TTModel(model_cfg)
model = model.to(params_cfg.device)

In [79]:
NUM_LABELS = 10
MNIST_TRANSFORM = transforms.Compose((
    transforms.Pad(2),
    transforms.ToTensor(),
    transforms.Normalize((0.1,), (0.2752,))
))

dataset = MNIST('mnist', train=True, download=True, transform=MNIST_TRANSFORM)

In [80]:
train_dataset, val_dataset = random_split(dataset, (params_cfg.train_size, len(dataset) - params_cfg.train_size))
train_subset_indices = list(range(len(train_dataset)))
shuffle(train_subset_indices)
train_subset_indices = train_subset_indices[:10000]
train_subset = Subset(train_dataset, train_subset_indices)

train_loader = DataLoader(train_dataset, batch_size=params_cfg.batch_size, shuffle=True, pin_memory=(params_cfg.device.type == "cuda"))
val_loader = DataLoader(val_dataset, batch_size=params_cfg.batch_size, shuffle=True, pin_memory=(params_cfg.device.type == "cuda"))
train_subset_loader = DataLoader(train_subset, batch_size=params_cfg.batch_size, shuffle=True, pin_memory=(params_cfg.device.type == "cuda"))

In [81]:
model = model.to(params_cfg.device)

#optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.95, weight_decay=0.0005)

In [82]:
def log_prior(model, lambdas=None, a_l=1, b_l=5):
    log_prior_sum = 0
    for name, core_tensor in model.named_parameters():
        if 'tt' not in name:
            continue
        core_mean = torch.zeros_like(core_tensor)
        
        if lambdas is None:
            core_std = torch.ones_like(core_tensor)
        else:
            layer_idx = int(name.split('tt')[-1].split('.')[0])
            core_idx = int(name.split('cores.')[-1])
            
            prev_rank = core_tensor.shape[1]
            next_rank = core_tensor.shape[2]
        
            if prev_rank == 1:
                l_next = lambdas[layer_idx][core_idx]
                l_prev = l_next
            elif next_rank  == 1:
                l_prev = lambdas[layer_idx][core_idx - 1]
                l_next = l_prev
            else:
                l_prev = lambdas[layer_idx][core_idx - 1]
                l_next = lambdas[layer_idx][core_idx]
            
#             print(l_prev.shape, l_next.shape)
            core_std = torch.einsum('p,q->pq', l_prev, l_next)
            core_std = core_std.repeat(core_tensor.shape[0], core_tensor.shape[3], 1, 1).permute(0, 2, 3, 1)
            
        log_prior_sum += Normal(core_mean, core_std).log_prob(core_tensor).sum()
    log_g = log_prior_sum
    log_lambda = 0
    if lambdas is not None:
        for layer_lambdas in lambdas:
            for l in layer_lambdas:
                log_lambda += Gamma(a_l, b_l).log_prob(l).sum()
        
    return log_g, log_lambda
        


def log_posterior(model, input, gt, lambdas=None, likelihood_coef=1., a_l=1, b_l=5):
    """Calculate log-posterior for core tensors and lambdas (optional)

    Parameters:   
        model : TT-model with core tensors as parameters, 
        input : Model input
        gt : Ground truth
        lambdas : LR-parameters \lambda, if any
        likelihood_coef : Coefficient to multiply log-likelihood by (for batches)
    
    Returns:
        Log-posterior 
    """
    model_out = model(input)
    
    log_g = torch.nn.functional.log_softmax(model_out, dim=1)
    log_likelihood = (gt * log_g).sum()
    
    log_g_prior, log_lambda = log_prior(model, lambdas, a_l, b_l)
    
    # not including margnial log-likelihood log(p(D))
    global global_step
#     if global_step % 2000 == 0:
#         print('Likelihood loss', -(likelihood_coef * log_likelihood).item())
#         print('Prior G', -log_g_prior.item())
#         print('Prior lambda', -log_lambda.item())
    global_step += 1
    return likelihood_coef * log_likelihood + (log_g_prior + log_lambda)

In [84]:
def acc(model, loader):
    accs = []
    with torch.no_grad():
        for b, gt in loader:
            out = model(b.to(params_cfg.device)).argmax(1).cpu().numpy()
            gt = gt.numpy()
            accs.append(sum(out == gt) / len(out))
    return sum(accs) / len(accs) 


lambdas1 = nn.ParameterList([nn.Parameter(torch.distributions.Gamma(params_cfg.a_l, params_cfg. b_l).sample([r])) for r in model_cfg.l1_ranks]).to(params_cfg.device)
lambdas2 = nn.ParameterList([nn.Parameter(torch.distributions.Gamma(params_cfg.a_l, params_cfg.b_l).sample([r])) for r in model_cfg.l2_ranks]).to(params_cfg.device)
                     

optimizer = torch.optim.Adam(list(model.parameters())
                             + list(lambdas1)
                             + list(lambdas2)
                             , lr=params_cfg.learning_rate)


global_step = 0
for ep in range(params_cfg.n_epochs):
    for b, gt in train_loader:
        
        optimizer.zero_grad()
    
        onehot_gt = torch.zeros(gt.shape[0], NUM_LABELS).scatter_(1, gt.view(-1, 1), 1)

        likelihood_coef = len(train_dataset) / params_cfg.batch_size
        loss = -log_posterior(model, b.to(params_cfg.device), onehot_gt.to(params_cfg.device),
                              lambdas=[lambdas1,lambdas2], 
                              a_l=params_cfg.a_l, b_l=params_cfg.b_l, likelihood_coef=likelihood_coef)
        
        loss.backward()
        optimizer.step()
    
#     print(loss.item())
    print('lambdas proposed shapes:', 
          np.array(model_cfg['l1_ranks']) - [(q.cpu().numpy() < 1e-2).sum() for q in lambdas1.state_dict().values()],
          np.array(model_cfg['l2_ranks']) - [(q.cpu().numpy() < 1e-2).sum() for q in lambdas2.state_dict().values()])
    print(f"val_acc = {acc(model, val_loader)}")
    print(f"train acc = {acc(model, train_subset_loader)}")
    

lambdas proposed shapes: [8 6 3 8] [16]
val_acc = 0.8630999999999998
train acc = 0.861
lambdas proposed shapes: [8 6 2 8] [16]
val_acc = 0.8575499999999995
train acc = 0.8530999999999999
lambdas proposed shapes: [8 5 2 8] [16]
val_acc = 0.8553000000000003
train acc = 0.8516000000000002
lambdas proposed shapes: [8 4 2 8] [16]
val_acc = 0.8528500000000001
train acc = 0.8484999999999998
lambdas proposed shapes: [8 4 2 8] [16]
val_acc = 0.84455
train acc = 0.8405000000000002
lambdas proposed shapes: [8 4 2 8] [16]
val_acc = 0.8282499999999997
train acc = 0.8225999999999997
lambdas proposed shapes: [8 4 2 8] [16]
val_acc = 0.8211999999999996
train acc = 0.8170999999999999
lambdas proposed shapes: [8 4 2 8] [16]
val_acc = 0.8182999999999997
train acc = 0.8136
lambdas proposed shapes: [8 3 2 8] [16]
val_acc = 0.81395
train acc = 0.8107000000000002
lambdas proposed shapes: [8 3 2 8] [16]
val_acc = 0.7966500000000003
train acc = 0.7930999999999999
lambdas proposed shapes: [8 2 2 8] [16]
val_acc

KeyboardInterrupt: 

We can see fair rank constraints obtained by optimization by lambdas, but accuracy drop is too high.
