In [1]:
import torch
import torch.nn as nn
from torch.distributions import Normal, Gamma

from torchvision.datasets.mnist import MNIST
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms

from tqdm import tqdm_notebook as tqdm

In [2]:
from tt_model import TTModel
from tt_model import vectorize_params, unvectorize_params

config = {
    'resize_shape': (32, 32),
    
    'in_factors': (4, 4, 4, 4, 4),
    'l1_ranks': (8, 8, 8, 8),
    'hidd_out_factors': (2, 2, 2, 2, 2),
    'ein_string1': "nabcde,aoiv,bijw,cjkx,dkly,elpz",
    
    'hidd_in_factors': (4, 8),
    'l2_ranks': (8,),
    'out_factors': (5, 2),
    'ein_string2': 'nab,aoix,bipy',
}

class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self
        
cfg = AttrDict(config)

model = TTModel(cfg)

In [3]:
NUM_LABELS = 10
MNIST_TRANSFORM = transforms.Compose((
    transforms.Pad(2),
    transforms.ToTensor(),
    transforms.Normalize((0.1,), (0.2752,))
))

device = torch.device('cpu')

train_size = 40000
batch_size = 10
dataset = MNIST('mnist', train=True, download=True, transform=MNIST_TRANSFORM)

In [4]:
train_dataset, val_dataset = random_split(dataset, (train_size, len(dataset) - train_size))

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=(device.type == "cuda"))
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, pin_memory=(device.type == "cuda"))

model = model.to(device)

#optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.95, weight_decay=0.0005)

In [5]:
def log_prior(model, lambdas=None, a_l=1, b_l=5):
    log_prior_sum = 0
    for name, core_tensor in model.named_parameters():
        if 'tt' not in name:
            continue
            
        core_mean = torch.zeros_like(core_tensor)
        
        if lambdas is None:
            core_std = torch.ones_like(core_tensor)
        else:
            layer_idx = int(name.split('tt')[-1].split('.')[0])
            core_idx = int(name.split('cores.')[-1])
            
            prev_rank = core_tensor.shape[1]
            next_rank = core_tensor.shape[2]
        
            if prev_rank == 1:
                l_next = lambdas[layer_idx][core_idx]
                l_prev = l_next
            elif next_rank  == 1:
                l_prev = lambdas[layer_idx][core_idx - 1]
                l_next = l_prev
            else:
                l_prev = lambdas[layer_idx][core_idx - 1]
                l_next = lambdas[layer_idx][core_idx]
            
            core_std = torch.einsum('p,q->pq', l_prev, l_next)
            core_std = core_std.repeat(core_tensor.shape[0], core_tensor.shape[3], 1, 1).permute(0, 2, 3, 1)
            
        log_prior_sum += Normal(core_mean, core_std).log_prob(core_tensor).sum()
    
    if lambdas is not None:
        for layer_lambdas in lambdas:
            for l in layer_lambdas:
                log_prior_sum += Gamma(a_l, b_l).log_prob(l).sum()
        
    return log_prior_sum
        


def log_posterior(model, input, gt, lambdas=None, likelihood_coef=1.):
    """Calculate log-posterior for core tensors and lambdas (optional)

    Parameters:   
        model : TT-model with core tensors as parameters, 
        input : Model input
        gt : Ground truth
        lambdas : LR-parameters \lambda, if any
        likelihood_coef : Coefficient to multiply log-likelihood by (for batches)
    
    Returns:
        Log-posterior 
    """
    
    out = torch.nn.functional.softmax(model(input), dim=1)
    log_g = torch.log(out)
    log_likelihood = (gt * log_g).sum()
    
    #print(out, log_g)
    
    log_prior_sum = log_prior(model, lambdas, 1, 5)
    
    # not including margnial log-likelihood log(p(D))
    return likelihood_coef * log_likelihood + log_prior_sum

In [21]:
def acc(model, loader):
    accs = []
    with torch.no_grad():
        for b, gt in tqdm(loader):
            out = model(b.to(device)).argmax(1).cpu().numpy()
            gt = gt.numpy()
            accs.append(sum(out == gt) / len(out))
    return sum(accs) / len(accs) 

learning_rate = 1e-3
n_epochs = 100

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for ep in range(n_epochs):
    for b, gt in tqdm(train_loader):
        
        optimizer.zero_grad()
    
        onehot_gt = torch.zeros(gt.shape[0], NUM_LABELS).scatter_(1, gt.view(-1, 1), 1)

        likelihood_coef = len(train_dataset) / batch_size
        loss = -log_posterior(model, b.to(device), onehot_gt.to(device), likelihood_coef=likelihood_coef)
        
        loss.backward()
        optimizer.step()
    
    print(loss.item())
    #print(acc(model, val_loader))
    

HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

13179.3818359375


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

2089.62548828125


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

8099.8623046875


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

5793.25


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

2991.842529296875


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

11371.572265625


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

4531.1279296875


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

3848.16259765625


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

7232.333984375


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

2285.934326171875


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

13634.74609375


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

4904.47607421875


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

3431.09375


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

2016.231201171875


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

13805.0234375


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

20207.6171875


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

3635.41845703125


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

24677.533203125


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

3999.40576171875


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

7582.1923828125


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

4553.427734375


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

4246.4482421875


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

5368.42919921875


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

12090.70703125


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

16243.7294921875


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

14423.7861328125


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

3541.14404296875


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

11034.5126953125


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

2869.46044921875


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

2612.28173828125


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

4550.822265625


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

3054.267578125


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

2199.796630859375


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

15994.8466796875


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

9123.369140625


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

4007.50732421875


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

21017.12890625


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

6704.69482421875


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

2935.379150390625


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

13986.154296875


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

3588.729248046875


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

20497.31640625


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

2850.29296875


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

KeyboardInterrupt: 

In [None]:
def svgd_step(particles, kernel_sigma=1., eps=1e-4, n_steps=100):
    for step in range(n_steps):
        log_posterior_grads = []
        kernels = []
        phi = []
        
        for theta, (b, gt) in zip(particles, loader):
            model, lambdas = unvectorize_params(theta, cfg)
            
            onehot_gt = torch.zeros(gt.shape[0], NUM_LABELS).scatter_(1, gt.view(-1, 1), 1)
            log_posterior_i = log_posterior(model, b, onehot_gt, lambdas)
            log_posterior_i.backward()
            log_posterior_grad = theta.grad
            
            log_posterior_grads.append(log_posterior_grad)
            theta.grad.data.zero_()
            
            kernels.append(Normal(theta, kernel_sigma))
        
        for k, theta_k in enumerate(particles):
            phi_k = 0
            for i, theta_i in enumerate(particles):
                #theta_i.grad.data.zero_()

                kernel_ik = torch.exp(kernels[k].log_prob(theta_i))
                kernel_ik.backward()
                kernel_ik_grad = theta_i.grad
                theta_i.grad.data.zero_()

                phi_k += kernel_ik * log_posterior_grads[i] + kernel_grad
            phi_k /= len(particles)
            phi.append(phi_k)
            
        for theta, phi_k in particles:
            theta += eps * phi_k
    
    return particles

In [12]:
class GaussKernel:
    def __init__(self, sigma):
        self.sigma = sigma
    def __call__(self, x_1, x_2)

def get_svgd_ranks(kernel=None, n_particles=50, n_steps=100, rank_as_mean=True, a_l=1, b_l=5):
    # sample particles:
    # * initialize empty model for G's
    # * sample lambdas from priors
    # * vectorize the particle
    particles = []
    for n in range(n_particles):
        model = TTModel(cfg)
        lambdas = []
        layer_ranks = [cfg.l1_ranks, cfg.l2_ranks]
        for layer in range(2):
            lambdas.append([])
            for j, rank in enumerate(layer_ranks[layer]):
                lambdas[layer].append(Gamma(a_l, b_l).rsample([rank]))
        
        particle = vectorize_params(model, lambdas)
        particles.append(particle)
        
    return particles
    # n_steps times perform svgd_step
    for step in range(n_steps):
        particles = svgd_step(particles, kernel)
    
    if rank_as_mean:
        # calculate mean of posterior lambdas
        pass
    else:
        # calculate mode of posterior lambdas
        pass
    
    return ranks
    

In [None]:
torch.distribution.Normal()

In [56]:
Gamma(1, 5).sample([8])

tensor([0.0556, 0.3757, 0.9405, 0.0749, 0.1693, 0.2281, 0.4548, 0.5641])

In [13]:
a = get_svgd_ranks()

In [17]:
a[0].shape

torch.Size([1992])