In [26]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tensornetworks_pytorch.TNModels import PosMPS, Born
print(torch.__version__)

1.8.0.dev20201128


Use some data

In [27]:
import pickle
for dataset in [#'biofam',
    'flare','lymphography','spect','tumor','votes']:
    with open('datasets/'+dataset, 'rb') as f:
            a=pickle.load(f)
    X=a[0].astype(int)
    print(dataset)
    print("\tdata shape:", X.shape)
    print(f"\trange of X values: {X.min()} -- {X.max()}")

def load_dataset(dataset):
    with open('datasets/'+dataset, 'rb') as f:
            a=pickle.load(f)
    X=a[0]
    X=X.astype(int)

    print("data shape:", X.shape)
    print(f"range of X values: {X.min()} -- {X.max()}")
    print(f"setting d={X.max()+1}")
    d = X.max()+1
    return X, d

flare
	data shape: (1065, 13)
	range of X values: 0 -- 7
lymphography
	data shape: (148, 19)
	range of X values: 0 -- 7
spect
	data shape: (187, 23)
	range of X values: 0 -- 1
tumor
	data shape: (339, 17)
	range of X values: 0 -- 3
votes
	data shape: (435, 17)
	range of X values: 0 -- 2


In [103]:
X,d = load_dataset('flare')

data shape: (1065, 13)
range of X values: 0 -- 7
setting d=8


In [106]:
D = 7
mps = PosMPS(D=D, d=d, verbose=True)
rBorn = Born(D=D, d=d, dtype=torch.float, verbose=True) 
cBorn = Born(D=D, d=d, dtype=torch.cfloat, verbose=True)
models = (mps, rBorn, cBorn)
for model in models:
    print(model.core.shape, model.name)
print("===")
for model in models:
    model.fit(X, d)

torch.Size([8, 7, 7]) Positive MPS
torch.Size([8, 7, 7]) Born model torch.float32
torch.Size([8, 7, 7]) Born model torch.complex64
===


In [107]:
from tqdm.notebook import tqdm
batchsize=30
trainloader = DataLoader(X, batch_size=batchsize, shuffle=True)
model = rBorn
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.0)

max_epochs = 5
print("Training...")
av_batch_loss_running = -1000
for epoch in range(max_epochs):
    print("epoch", epoch)
    batch_loss = []
    for batch_idx, batch in tqdm(enumerate(trainloader), leave=False):
        model.zero_grad()
        neglogprob = 0
        for x in batch:
            neglogprob -= model(x)
        loss = neglogprob / batchsize
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            batch_loss.append(loss.item())
            if batch_idx % 10 == 0: # print every 10th batch loss
                print("\tbatch", batch_idx, "loss", loss.item())
    av_batch_loss = torch.Tensor(batch_loss).mean().item()
    print("\tavg batch_loss", av_batch_loss)
    if abs(av_batch_loss_running - av_batch_loss) < .1:
        print("Early stopping")
        break
    av_batch_loss_running = av_batch_loss
print('Finished training. Last av loss = ', av_batch_loss)

Training...
epoch 0


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

	batch 0 loss 26.36189842224121
	batch 10 loss 19.781164169311523
	batch 20 loss 17.422203063964844
	batch 30 loss 16.521820068359375
	avg batch_loss 18.695804595947266
epoch 1


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

	batch 0 loss 16.022523880004883
	batch 10 loss 15.371594429016113
	batch 20 loss 15.091082572937012
	batch 30 loss 14.59440803527832
	avg batch_loss 15.08620834350586
epoch 2


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

	batch 0 loss 14.636351585388184
	batch 10 loss 14.228934288024902
	batch 20 loss 13.433638572692871
	batch 30 loss 14.586594581604004
	avg batch_loss 13.883378028869629
epoch 3


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

	batch 0 loss 14.344987869262695
	batch 10 loss 13.527606010437012
	batch 20 loss 12.727429389953613
	batch 30 loss 11.847824096679688
	avg batch_loss 12.794553756713867
epoch 4


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

	batch 0 loss 11.9341402053833
	batch 10 loss 11.759820938110352
	batch 20 loss 13.489293098449707
	batch 30 loss 13.647462844848633
	avg batch_loss 12.394817352294922
Finished training. Av loss =  12.394817352294922
