In [18]:
from tt_model import TTModel
import torch
from torch import nn
from tqdm import tqdm

In [2]:
config = {
    'in_factors': (4, 4, 2, 2, 2),
    'l1_ranks': (1, 1, 1, 1),
    'hidd_out_factors': (2, 2, 2, 2, 2),
    'ein_string1': "nabcde,aoiv,bijw,cjkx,dkly,elpz",
    
    'hidd_in_factors': (1, 32),
    'l2_ranks': (1,),
    'out_factors': (1, 2),
    'ein_string2': 'nab,aoix,bipy',
}

class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self
        
cfg = AttrDict(config)
model = TTModel(cfg)

In [3]:
model

TTModel(
  (net): Sequential(
    (tt0): TTLayer(
      (cores): ParameterList(
          (0): Parameter containing: [torch.FloatTensor of size 4x1x1x2]
          (1): Parameter containing: [torch.FloatTensor of size 4x1x1x2]
          (2): Parameter containing: [torch.FloatTensor of size 2x1x1x2]
          (3): Parameter containing: [torch.FloatTensor of size 2x1x1x2]
          (4): Parameter containing: [torch.FloatTensor of size 2x1x1x2]
      )
    )
    (relu): ReLU()
    (tt1): TTLayer(
      (cores): ParameterList(
          (0): Parameter containing: [torch.FloatTensor of size 1x1x1x1]
          (1): Parameter containing: [torch.FloatTensor of size 32x1x1x2]
      )
    )
  )
)

In [4]:
t = torch.randn(1, 128)

In [5]:
model(t)

tensor([[ 0.0345, -0.0638]], grad_fn=<AsStridedBackward>)

In [6]:
torch.random.manual_seed(42)
data = torch.randn(100000, 128)

In [7]:
target = torch.argmax(model(data), axis=1)

In [51]:
np.savez('synth_data', {'data': data, 'target': target})

In [8]:
target

tensor([0, 0, 1,  ..., 0, 1, 0])

In [9]:
dataset = torch.utils.data.TensorDataset(data, target)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

In [10]:
train_dataset[:]

(tensor([[-0.4786, -1.0465,  0.9237,  ...,  0.2253,  1.2762,  0.4853],
         [ 0.3496,  0.6459,  0.1231,  ..., -0.5505, -1.4439, -0.9298],
         [-0.5105, -0.6921,  3.0666,  ..., -0.4517, -1.4604,  0.6890],
         ...,
         [ 0.4699, -0.1049,  1.1463,  ..., -0.1402,  2.1803, -0.3842],
         [ 0.4893, -0.5197, -2.5331,  ...,  0.9431, -0.3671, -0.5147],
         [-0.8304, -1.0810,  1.0682,  ..., -0.5584,  1.3846, -0.0069]]),
 tensor([0, 0, 1,  ..., 1, 0, 0]))

In [23]:
data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=512)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=512)

In [29]:
DEV = torch.device('cuda')

In [48]:
model = nn.Sequential(
    nn.Linear(128, 32),
    nn.ReLU(),
    nn.Linear(32, 2),
    nn.Softmax(dim=1)
).to(DEV)

In [49]:
optimizer = torch.optim.Adam(model.parameters())

In [50]:
for epoch in range(1, 100):
    model.train()
    sum_ = 0
    count_ = 0
    for X, y in data_loader:
        out = model(X.to(DEV))
        loss = torch.nn.functional.cross_entropy(out, y.to(DEV))
        sum_ += loss.item()
        count_ += 1
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    model.eval()
    sum1_ = 0
    count1_ = 0
    sum2_ = 0
    count2_ = 0
    for X, y in test_loader:
        out = model(X.to(DEV))
        bin_out = torch.argmax(out, axis=1)
        sum2_ += (bin_out == y.to(DEV)).sum().item() / X.shape[0]
        count2_ += 1
        sum1_ += torch.nn.functional.cross_entropy(out, y.to(DEV)).item()
        count1_ += 1
    print(epoch, sum_ / count_, sum1_ / count1_, sum2_ / count2_)

1 0.550980325527252 0.46881173029541967 0.858837890625
2 0.4463082136242253 0.43861790373921394 0.87314453125
3 0.4275215023262486 0.42570673376321794 0.88427734375
4 0.413534686443912 0.41188603192567824 0.90009765625
5 0.3988949443883957 0.3979099452495575 0.9154296875
6 0.38707061216330074 0.38882550224661827 0.92626953125
7 0.3795865003470403 0.38352841287851336 0.931640625
8 0.3744413862182836 0.38005381971597674 0.9337890625
9 0.370511502407159 0.37752060517668723 0.9361328125
10 0.36726545433329927 0.37547651380300523 0.93740234375
11 0.3644474017771946 0.37378252297639847 0.93935546875
12 0.36195406298728505 0.37228023409843447 0.94072265625
13 0.3597514502182128 0.37102048099040985 0.941455078125
14 0.3577718920768446 0.3698779411613941 0.941796875
15 0.3560138191007505 0.36885996013879774 0.94248046875
16 0.35443125560784794 0.36797242388129237 0.944287109375
17 0.35300500928216677 0.3671853855252266 0.9451171875
18 0.3516938851517477 0.36652852296829225 0.94541015625
19 0.35

In [53]:
model = TTModel(cfg).to(DEV)

In [54]:
optimizer = torch.optim.Adam(model.parameters())

In [55]:
for epoch in range(1, 100):
    model.train()
    sum_ = 0
    count_ = 0
    for X, y in data_loader:
        out = model(X.to(DEV))
        loss = torch.nn.functional.cross_entropy(out, y.to(DEV))
        sum_ += loss.item()
        count_ += 1
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    model.eval()
    sum1_ = 0
    count1_ = 0
    sum2_ = 0
    count2_ = 0
    for X, y in test_loader:
        out = model(X.to(DEV))
        bin_out = torch.argmax(out, axis=1)
        sum2_ += (bin_out == y.to(DEV)).sum().item() / X.shape[0]
        count2_ += 1
        sum1_ += torch.nn.functional.cross_entropy(out, y.to(DEV)).item()
        count1_ += 1
    print(epoch, sum_ / count_, sum1_ / count1_, sum2_ / count2_)

1 0.6887694156853257 0.656994016468525 0.664013671875
2 0.6059344893048524 0.5519380688667297 0.70966796875
3 0.4952477258481797 0.45318569913506507 0.773486328125
4 0.435436566923834 0.4163321189582348 0.79853515625
5 0.3749590307284313 0.3254692979156971 0.850537109375
6 0.2577802103226352 0.1991395901888609 0.912451171875
7 0.16380377379572317 0.14377107061445712 0.938134765625
8 0.12663134899298856 0.11694768136367202 0.949267578125
9 0.1045853113577624 0.09737274991348385 0.9587890625
10 0.0873448651544987 0.08135504191741347 0.96484375
11 0.07248587638234637 0.067342540435493 0.9712890625
12 0.059311243628335605 0.055155406100675465 0.977099609375
13 0.048016602944606426 0.04515776135958731 0.981884765625
14 0.03866149792981565 0.036880869837477806 0.98505859375
15 0.03129530260280059 0.030392512492835522 0.988232421875
16 0.025665058167117418 0.025487600197084247 0.990771484375
17 0.021435917580535837 0.021896723099052905 0.992724609375
18 0.018293157463098408 0.0191486058756709

KeyboardInterrupt: 