In [1]:
import torch
import torch.nn as nn
from torch.distributions import Normal, Gamma

from torchvision.datasets.mnist import MNIST
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms

from tqdm import tqdm_notebook as tqdm

In [2]:
from tt_model import TTModel

config = {
    'resize_shape': (32, 32),
    
    'in_factors': (4, 4, 4, 4, 4),
    'l1_ranks': (8, 8, 8, 8),
    'hidd_out_factors': (2, 2, 2, 2, 2),
    'ein_string1': "nabcde,aoiv,bijw,cjkx,dkly,elpz",
    
    'hidd_in_factors': (4, 8),
    'l2_ranks': (8,),
    'out_factors': (5, 2),
    'ein_string2': 'nab,aoix,bipy',
}

class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self
        
cfg = AttrDict(config)

model = TTModel(cfg)

In [3]:
NUM_LABELS = 10
MNIST_TRANSFORM = transforms.Compose((
    transforms.Pad(2),
    transforms.ToTensor(),
    transforms.Normalize((0.1,), (0.2752,))
))

device = torch.device('cpu')

train_size = 40000
batch_size = 10
dataset = MNIST('mnist', train=True, download=True, transform=MNIST_TRANSFORM)

In [4]:
train_dataset, val_dataset = random_split(dataset, (train_size, len(dataset) - train_size))

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=(device.type == "cuda"))
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, pin_memory=(device.type == "cuda"))

model = model.to(device)

#optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.95, weight_decay=0.0005)

In [None]:
def log_prior(model, lambdas=None, a_l=1, b_l=5):
    log_prior_sum = 0
    for name, core_tensor in model.named_parameters():
        if 'tt' not in name:
            continue
            
        core_mean = torch.zeros_like(core_tensor)
        
        if lambdas is None:
            core_std = torch.ones_like(core_tensor)
        else:
            layer_idx = int(name.split('tt')[-1].split('.')[0])
            core_idx = int(name.split('cores.')[-1])
            
            prev_rank = core_tensor.shape[1]
            next_rank = core_tensor.shape[2]
        
            if prev_rank == 1:
                l_next = lambdas[layer_idx][core_idx]
                l_prev = l_next
            elif next_rank  == 1:
                l_prev = lambdas[layer_idx][core_idx - 1]
                l_next = l_prev
            else:
                l_prev = lambdas[layer_idx][core_idx - 1]
                l_next = lambdas[layer_idx][core_idx]
            
            core_std = torch.einsum('p,q->pq', l_prev, l_next)
            core_std = core_std.repeat(core_tensor.shape[0], core_tensor.shape[3], 1, 1).permute(0, 2, 3, 1)
            
        log_prior_sum += Normal(core_mean, core_std).log_prob(core_tensor).sum()
    
    if lambdas is not None:
        for l in lambdas:
            log_prior_sum += Gamma(a_l, b_l).log_prob(l).sum()
        
    return log_prior_sum
        


def log_posterior(model, out, gt, lambdas=None, likelihood_coef=1.):
    """Calculate log-posterior for core tensors and lambdas (optional)

    Parameters:   
        model : TT-model with core tensors as parameters, 
        out : Model output
        gt : Ground truth
        lambdas : LR-parameters \lambda, if any
        likelihood_coef : Coefficient to multiply log-likelihood by (for batches)
    
    Returns:
        Log-posterior 
    """
    
    out = torch.nn.functional.softmax(out, dim=0)
    log_g = torch.log(out)
    log_likelihood = (gt * log_g).sum()
    
    #print(out, log_g)
    
    log_prior_sum = log_prior(model, lambdas, 1, 5)
    
    # not including margnial log-likelihood log(p(D))
    return likelihood_coef * log_likelihood + log_prior_sum

In [None]:
def acc(model, loader):
    accs = []
    with torch.no_grad():
        for b, gt in tqdm(loader):
            out = model(b.to(device)).argmax(1).cpu().numpy()
            gt = gt.numpy()
            accs.append(sum(out == gt) / len(out))
    return sum(accs) / len(accs) 

learning_rate = 1e-3
n_epochs = 100

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for ep in range(n_epochs):
    for b, gt in tqdm(train_loader):
        
        optimizer.zero_grad()
    
        out = model(b.to(device))
        onehot_gt = torch.zeros(gt.shape[0], NUM_LABELS).scatter_(1, gt.view(-1, 1), 1)

        likelihood_coef = len(train_dataset) / b.shape[0]
        loss = -log_posterior(model, out, onehot_gt.to(device), likelihood_coef=likelihood_coef)
        
        loss.backward()
        optimizer.step()
    
    print(loss.item())
    #print(acc(model, val_loader))
    

HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))


33149.625


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))


47201.34375


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))


32192.984375


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))


36987.19921875


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))


45767.80859375


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))


42220.15234375


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))


38874.96875


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))


36237.734375


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))


25575.453125


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))