In [1]:
import logging
from functools import partial

from collections import OrderedDict
from typing import Sequence, Any, Iterable, Optional, List
import numpy as np
# import click
# import click_log
from tqdm import tqdm_notebook as tqdm
import torch
import torch.nn as nn
import torch.nn.functional as tnnf
from torchvision.datasets.mnist import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torch.optim.lr_scheduler import StepLR
from tt_model import *


In [2]:
# Hyper-parameters consumed by TTModel (imported from tt_model, not shown here).
# NOTE(review): what each entry means to TTModel is inferred from its name
# only -- confirm against tt_model. The arithmetic relations noted below hold
# for the values as written.
config = dict(
    # 32 * 32 = 1024 pixels, which equals the product of `in_factors`.
    resize_shape=(32, 32),

    # First layer settings: 4**5 = 1024 on the input side, 2**5 = 32 on the
    # output side, with four ranks between the five factors.
    in_factors=(4, 4, 4, 4, 4),
    l1_ranks=(8, 8, 8, 8),
    hidd_out_factors=(2, 2, 2, 2, 2),
    # Presumably the einsum subscripts for the first layer -- TODO confirm.
    ein_string1="nabcde,aoiv,bijw,cjkx,dkly,elpz",

    # Second layer settings: 4 * 8 = 32 matches the first layer's output
    # size; 5 * 2 = 10 outputs.
    hidd_in_factors=(4, 8),
    l2_ranks=(16,),
    out_factors=(5, 2),
    # Presumably the einsum subscripts for the second layer -- TODO confirm.
    ein_string2='nab,aoix,bipy',
)

class AttrDict(dict):
    """A ``dict`` whose items are also reachable as attributes.

    The instance's ``__dict__`` is bound to the mapping itself, so
    ``d.key`` and ``d['key']`` read and write the very same storage.
    """

    def __init__(self, *args, **kwargs):
        # Takes exactly the arguments the built-in ``dict`` constructor takes.
        super().__init__(*args, **kwargs)
        # Deliberate self-reference: attribute storage *is* the dictionary.
        self.__dict__ = self
        
# Seed PyTorch's global RNG before anything random happens so that
# "Restart Kernel -> Run All" repeats the same run. This covers any
# torch-based parameter initialisation inside TTModel (assumed -- TTModel is
# defined in tt_model) and later consumers of the global RNG such as dataset
# splitting and DataLoader shuffling. GPU kernels may still add small
# non-deterministic differences.
SEED = 0
torch.manual_seed(SEED)

# Wrap the settings so they can be read as attributes (cfg.in_factors, ...).
cfg = AttrDict(config)
model = TTModel(cfg)

In [5]:

MNIST_DATASET_SIZE = 60000  # expected size of the MNIST training split (asserted below)
NUM_LABELS = 10

# Pad every image by 2 pixels on each side, convert it to a tensor, then
# standardise with a fixed mean / std.
# NOTE(review): 0.1 and 0.2752 look like the commonly quoted MNIST statistics
# (mean 0.1307, std 0.3081) recomputed for 28x28 digits zero-padded to 32x32:
# 0.1307 * 784 / 1024 ~= 0.1 -- confirm.
MNIST_TRANSFORM = transforms.Compose((
    transforms.Pad(2),
    transforms.ToTensor(),
    transforms.Normalize((0.1,), (0.2752,))
))


# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing when tensors are moved to the device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 500
train_dataset_size = 40000

dataset = MNIST('mnist', train=True, download=True, transform=MNIST_TRANSFORM)
assert len(dataset) == MNIST_DATASET_SIZE
# Hold out everything beyond the first `train_dataset_size` randomly chosen
# samples for validation (40000 train / 20000 validation).
train_dataset, val_dataset = random_split(
    dataset, (train_dataset_size, MNIST_DATASET_SIZE - train_dataset_size)
)

# Pinned host memory only helps when batches are copied to a CUDA device.
pin_memory = device.type == "cuda"
# Only the training data needs reshuffling every epoch; accuracy on the
# validation set does not depend on the order in which it is visited.
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory
)
val_loader = DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory
)

model = model.to(device)



In [None]:
def acc(model, loader):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    The model is switched to evaluation mode for the duration of the pass and
    put back into whatever mode it was in before, so calling this between
    training epochs does not silently change how training behaves.

    Accuracy is computed as ``correct / total`` over all samples rather than as
    the mean of per-batch accuracies, so a smaller final batch is not
    over-weighted.

    Args:
        model: ``torch.nn.Module`` mapping a batch to per-class scores along
            dimension 1.
        loader: iterable of ``(inputs, labels)`` batches; labels are CPU tensors.

    Returns:
        Accuracy in ``[0, 1]`` as a Python float, or ``nan`` if ``loader``
        yields no samples.
    """
    was_training = model.training
    model.eval()
    n_correct = 0
    n_seen = 0
    try:
        with torch.no_grad():
            for b, gt in tqdm(loader):
                # Predicted class = index of the largest score for each sample.
                pred = model(b.to(device)).argmax(1).cpu()
                n_correct += int((pred == gt).sum())
                n_seen += pred.numel()
    finally:
        # Restore the caller's mode even if evaluation fails part-way.
        model.train(was_training)
    if n_seen == 0:
        return float("nan")
    return n_correct / n_seen

learning_rate = 1e-3
n_epochs = 100

# Cross-entropy between the model's outputs and the integer class labels.
lf = nn.CrossEntropyLoss()
# Alternative optimizer kept for reference:
# optimizer = torch.optim.SGD(
#     model.parameters(), lr=learning_rate, momentum=0.95, weight_decay=0.0005
# )
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for ep in range(n_epochs):
    # Explicitly (re-)enter training mode every epoch, so training never
    # depends on which mode the evaluation pass at the end of the previous
    # epoch left the model in.
    model.train()
    for b, gt in tqdm(train_loader):
        # Standard step: clear old gradients, forward, loss, backward, update.
        optimizer.zero_grad()
        out = model(b.to(device))
        loss = lf(out, gt.to(device))
        loss.backward()
        optimizer.step()

    # Report validation accuracy once per epoch.
    print(acc(model, val_loader))
    

HBox(children=(IntProgress(value=0, max=80), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


0.6545500000000002


HBox(children=(IntProgress(value=0, max=80), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


0.7586000000000002


HBox(children=(IntProgress(value=0, max=80), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


0.80655


HBox(children=(IntProgress(value=0, max=80), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


0.8362000000000004


HBox(children=(IntProgress(value=0, max=80), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


0.8567500000000001


HBox(children=(IntProgress(value=0, max=80), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


0.8681999999999999


HBox(children=(IntProgress(value=0, max=80), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


0.8794500000000003


HBox(children=(IntProgress(value=0, max=80), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


0.8856500000000004


HBox(children=(IntProgress(value=0, max=80), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


0.8952


HBox(children=(IntProgress(value=0, max=80), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


0.90065


HBox(children=(IntProgress(value=0, max=80), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


0.90635


HBox(children=(IntProgress(value=0, max=80), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


0.9093999999999998


HBox(children=(IntProgress(value=0, max=80), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


0.9113500000000002


HBox(children=(IntProgress(value=0, max=80), HTML(value='')))




HBox(children=(IntProgress(value=0, max=40), HTML(value='')))


0.9151000000000001


HBox(children=(IntProgress(value=0, max=80), HTML(value='')))

In [None]:
# Show the validation accuracy of the model in its current state (useful,
# for example, after the training cell above has been stopped early).
acc(model, val_loader)