In [1]:
import logging
from functools import partial

from collections import OrderedDict
from typing import Sequence, Any, Iterable, Optional, List
import numpy as np
# import click
# import click_log
from tqdm import tqdm_notebook as tqdm
import torch
import torch.nn as nn
import torch.nn.functional as tnnf
from torchvision.datasets.mnist import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torch.optim.lr_scheduler import StepLR
from tt_model import *

In [4]:
# Hyper-parameters handed to TTModel (defined in tt_model).  How each entry is
# consumed is decided there -- the grouping below only reflects the key names.
config = dict(
    # Input image size: 32 * 32 = 1024 pixels.
    resize_shape=(32, 32),

    # "Layer 1" entries.  prod(in_factors) = 4**5 = 1024 matches the pixel
    # count above and prod(hidd_out_factors) = 2**5 = 32.
    # NOTE(review): presumably a tensor-train factorisation of a 1024 -> 32
    # map with one einsum operand per factor -- confirm against TTModel.
    in_factors=(4, 4, 4, 4, 4),
    l1_ranks=(8, 8, 8, 8),
    hidd_out_factors=(2, 2, 2, 2, 2),
    ein_string1="nabcde,aoiv,bijw,cjkx,dkly,elpz",

    # "Layer 2" entries.  prod(hidd_in_factors) = 4 * 8 = 32 and
    # prod(out_factors) = 5 * 2 = 10.
    hidd_in_factors=(4, 8),
    l2_ranks=(8,),
    out_factors=(5, 2),
    ein_string2='nab,aoix,bipy',
)

class AttrDict(dict):
    """A ``dict`` whose items can also be read and written as attributes.

    ``d.key`` and ``d['key']`` refer to the same entry because the instance's
    attribute namespace is aliased to the mapping itself.
    """

    def __init__(self, *positional, **named):
        # Accepts exactly what the built-in ``dict`` constructor accepts.
        super().__init__(*positional, **named)
        # Point the attribute dictionary at the mapping so both access styles
        # share a single storage.
        self.__dict__ = self
        
# Wrap the plain dict so its entries are reachable as attributes
# (cfg.in_factors, cfg.l1_ranks, ...); all keys are strings, so keyword
# unpacking builds the same mapping.
cfg = AttrDict(**config)
# Build the network from the attribute-style configuration.
model = TTModel(cfg)

In [5]:

# Number of samples in the MNIST training split (asserted after loading) and
# number of digit classes.
MNIST_DATASET_SIZE = 60000
NUM_LABELS = 10

# Per-image pre-processing, applied in order:
#   1. pad 2 pixels on every side (28x28 MNIST digits become 32x32, matching
#      config['resize_shape']),
#   2. convert the PIL image to a float tensor,
#   3. standardise with mean 0.1 and std 0.2752.
# NOTE(review): these statistics are consistent with the usual MNIST values
# (0.1307 / 0.3081) recomputed over the zero-padded 32x32 image -- confirm.
MNIST_TRANSFORM = transforms.Compose([
    transforms.Pad(padding=2),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1,), std=(0.2752,)),
])



# --- Reproducibility ---------------------------------------------------------
# Seed torch's global RNG so the random train/validation split and the
# shuffling order of the training loader are identical on every
# "Restart & Run All".  (The model above was already constructed, so its
# initial weights are not affected by this seed.)
SEED = 0
torch.manual_seed(SEED)

# --- Run configuration -------------------------------------------------------
device = torch.device('cpu')

train_dataset_size = 40000
batch_size = 10
learning_rate = 1e-3
n_epochs = 10

# --- Data --------------------------------------------------------------------
# Downloads MNIST into ./mnist on first use.
dataset = MNIST('mnist', train=True, download=True, transform=MNIST_TRANSFORM)
assert len(dataset) == MNIST_DATASET_SIZE

# Hold out everything that is not used for training as the validation set.
val_dataset_size = MNIST_DATASET_SIZE - train_dataset_size
train_dataset, val_dataset = random_split(
    dataset, (train_dataset_size, val_dataset_size)
)

# Pinned host memory only helps when batches are copied to a CUDA device.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=(device.type == "cuda"),
)
# Validation metrics do not depend on sample order, so skip the shuffle.
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=(device.type == "cuda"),
)

# --- Model / optimiser -------------------------------------------------------
model = model.to(device)
# NOTE(review): the training cell rebinds `optimizer` to Adam before any step
# is taken, so this SGD instance is currently unused -- keep only one of them.
optimizer = torch.optim.SGD(
    model.parameters(), lr=learning_rate, momentum=0.95, weight_decay=0.0005
)


In [7]:
def acc(model, loader, device=None):
    """Return the classification accuracy of ``model`` over ``loader``.

    The accuracy is the number of correctly classified samples divided by the
    total number of samples, so a shorter final batch is weighted by its
    actual size instead of counting as much as a full batch.

    Args:
        model: Callable mapping a batch of inputs to a ``(batch, classes)``
            score tensor.  If it exposes ``training``/``eval()``/``train()``
            (as ``nn.Module`` does) it is evaluated in eval mode and its
            previous mode is restored afterwards.
        loader: Iterable yielding ``(inputs, labels)`` tensor pairs.
        device: Optional device the inputs are moved to before the forward
            pass.  ``None`` (default) leaves the inputs where they are.

    Returns:
        float: Fraction of samples whose arg-max prediction equals the label.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    # Remember the mode so evaluation does not leak eval() into training.
    was_training = getattr(model, "training", False)
    if was_training:
        model.eval()

    n_correct = 0
    n_total = 0
    try:
        with torch.no_grad():
            for b, gt in tqdm(loader):
                if device is not None:
                    b = b.to(device)
                # Compare on the CPU so this works wherever the model lives.
                pred = model(b).argmax(1).cpu()
                n_correct += (pred == gt.cpu()).sum().item()
                n_total += pred.numel()
    finally:
        if was_training:
            model.train()

    if n_total == 0:
        raise ValueError("acc() received a loader that yields no samples")
    return n_correct / n_total


# Loss for multi-class classification on the raw model outputs.
lf = nn.CrossEntropyLoss()
# Rebinds `optimizer`: from here on the model is trained with Adam.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for ep in range(n_epochs):
    # Make sure the model is in training mode at the start of every epoch,
    # whatever the evaluation at the end of the previous one did.
    model.train()
    for b, gt in tqdm(train_loader):
        # Inputs and labels must live on the same device as the model.
        b = b.to(device)
        gt = gt.to(device)

        optimizer.zero_grad()
        out = model(b)
        loss = lf(out, gt)
        loss.backward()
        optimizer.step()

    # Validation accuracy after each epoch.
    print(acc(model, val_loader))
    

HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

KeyboardInterrupt: 

In [175]:
# Evaluate the current `model` on the validation loader; as the last
# expression of the cell, the returned accuracy is shown as the cell output.
acc(model, val_loader)

HBox(children=(IntProgress(value=0, max=200), HTML(value='')))




0.6506500000000003