In [8]:
import logging
from functools import partial

from typing import Sequence, Any, Iterable, Optional, List
import numpy as np
# import click
# import click_log
from tqdm import tqdm_notebook as tqdm
import torch
import torch.nn as nn
import torch.nn.functional as tnnf
from torchvision.datasets.mnist import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torch.optim.lr_scheduler import StepLR


In [9]:
class TTLayer(nn.Module):
    def __init__(self, in_factors, out_factors, ranks, ein_string):
        super().__init__()
        self.in_factors = in_factors
        self.out_factors = out_factors
        self.ein_string = ein_string
        assert len(in_factors) == len(out_factors) == len(ranks) + 1, 'Input factorization should match output factorization and should be equal to len(ranks) - 1'
#         assert len(ranks) == 4, 'Now we consider particular factorization for given dataset'

        self.cores = [nn.Parameter(torch.randn(in_factors[0], 1, ranks[0], out_factors[0], ) * 0.8)]
        for i in range(1, len(in_factors) - 1):
            self.cores.append(nn.Parameter(torch.randn(in_factors[0], ranks[i-1], ranks[i], out_factors[0],) * 0.1))
        self.cores.append(nn.Parameter(torch.randn(in_factors[-1], ranks[-1], 1, out_factors[-1], ) * 0.8))
#         print(self.cores)
    def forward(self, x):
        reshaped_input = x.reshape(-1, *self.in_factors)
#         print('reshaped_input', reshaped_input.shape)
        # in the einsum below, n stands for index of sample in the batch,
        # abcde - indices corresponding to h1, h2, hw, w1, w2 modes
        # o, i, j, k, l, p - indices corresponding to the 4 tensor train ranks
        # v, w, x, y, z - indices corresponding to o1, o2, o3, o4, o5

        result = torch.einsum(
            self.ein_string,
            reshaped_input, *self.cores
        )
        return result.reshape(-1, np.prod(self.out_factors))
    
    def parameters(self):
        return self.cores

class TTModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.net = nn.Sequential(
            nn.Upsample(size=cfg.resize_shape, mode="bilinear", align_corners=False),
            TTLayer(cfg.in_factors, cfg.hidd_out_factors, cfg.l1_ranks, cfg.ein_string1),
            nn.ReLU(),
#             nn.Linear(np.prod(hidd_factors), NUM_LABELS),
            TTLayer(cfg.hidd_in_factors, cfg.out_factors, cfg.l2_ranks, cfg.ein_string2),
            nn.Softmax(dim=1))

        # self.

    def forward(self, x):
        return self.net(x)
    def parameters(self,):
        return self.net[1].parameters() + list(self.net[3].parameters())

In [10]:
config = {
    'resize_shape': (32, 32),
    
    'in_factors': (4, 4, 4, 4, 4),
    'l1_ranks': (8, 8, 8, 8),
    'hidd_out_factors': (2, 2, 2, 2, 2),
    'ein_string1': "nabcde,aoiv,bijw,cjkx,dkly,elpz",
    
    'hidd_in_factors': (4, 8),
    'l2_ranks': (16,),
    'out_factors': (5, 2),
    'ein_string2': 'nab,aoix,bipy',
}

class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self
        
cfg = AttrDict(config)
model = TTModel(cfg)

In [22]:

# from ignite.metrics import Loss, Accuracy
# from ignite.engine import Events, create_supervised_evaluator
# from ignite.contrib.handlers.param_scheduler import LRScheduler

# from libcrap import shuffled
# from libcrap.torch import set_random_seeds
# from libcrap.torch.click import (
#     click_dataset_root_option, click_models_dir_option, click_tensorboard_log_dir_option,
#     click_seed_and_device_options
# )
# from libcrap.torch.training import (
#     add_checkpointing, add_early_stopping, add_weights_and_grads_logging,
#     setup_trainer, setup_evaluator, setup_tensorboard_logger,
#     make_standard_prepare_batch_with_events, add_logging_input_images
# )

logger = logging.getLogger()

MNIST_DATASET_SIZE = 60000
NUM_LABELS = 10

MNIST_TRANSFORM = transforms.Compose((
    transforms.Pad(2),
    transforms.ToTensor(),
    transforms.Normalize((0.1,), (0.2752,))
))



device = torch.device('cpu')

train_dataset_size = 40000
batch_size = 10
learning_rate = 1e-3
n_epochs = 10
dataset = MNIST('mnist', train=True, download=True, transform=MNIST_TRANSFORM)
assert len(dataset) == MNIST_DATASET_SIZE
train_dataset, val_dataset = random_split(
    dataset, (train_dataset_size, MNIST_DATASET_SIZE - train_dataset_size)
)

train_loader, val_loader = (
    DataLoader(
        dataset_, batch_size=batch_size, shuffle=True, pin_memory=(device.type == "cuda")
    )
    for dataset_ in (train_dataset, val_dataset)
)

model = model.to(device)
optimizer = torch.optim.SGD(
    model.parameters(), lr=learning_rate, momentum=0.95, weight_decay=0.0005
)


In [24]:
lf = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for ep in range(n_epochs):
    for b, gt in tqdm(train_loader):
        
        optimizer.zero_grad()
#         b = b.cuda()
        out = model(b)
#         print(out.argmax(1), gt)
        loss = lf(out, gt.to(device))
#         print(loss.item())
        loss.backward()
        optimizer.step()
    
    print(acc(model, val_loader))
    

HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

KeyboardInterrupt: 

In [None]:
def acc(model, loader):
    accs = []
    with torch.no_grad():
        for b, gt in tqdm(loader):
            out = model(b).argmax(1).numpy()
            gt = gt.numpy()
            accs.append(sum(out == gt) / len(out))
    return sum(accs) / len(accs) 


In [175]:
acc(model, val_loader)

HBox(children=(IntProgress(value=0, max=200), HTML(value='')))




0.6506500000000003