In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tensornetworks_pytorch.TNModels import PosMPS, Born
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

print(torch.__version__)

In [None]:
### import data
import pickle
# Survey the bundled datasets ('biofam' is currently disabled): for each one,
# print its name, the shape of its first pickled element and the range of the
# integer-cast values.
for dataset in ('flare', 'lymphography', 'spect', 'tumor', 'votes'):
    with open('datasets/' + dataset, 'rb') as f:
        a = pickle.load(f)
    X = a[0].astype(int)
    print(dataset)
    print("\tdata shape:", X.shape)
    print(f"\trange of X values: {X.min()} -- {X.max()}")

def load_dataset(dataset):
    """Read ``datasets/<dataset>`` and return its data together with ``d``.

    The file is unpickled, its first element is cast to ``int`` and a short
    summary (shape and value range) is printed.

    Args:
        dataset (str): file name inside the ``datasets/`` directory.

    Returns:
        tuple: ``(X, d)`` where ``X`` is the integer-cast first element of the
        pickled object and ``d`` is ``X.max() + 1``.
    """
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # use this on dataset files from a trusted source.
    with open('datasets/' + dataset, 'rb') as f:
        payload = pickle.load(f)
    X = payload[0].astype(int)

    largest = X.max()
    print("\tdata shape:", X.shape)
    print(f"\trange of X values: {X.min()} -- {largest} ==> d={largest+1}")
    d = largest + 1
    return X, d

In [None]:
### initialize models
# Load one dataset and build every model variant on it. The constructors are
# called in a fixed order; keep that order if the models' initial weights are
# random and runs need to be comparable -- TODO confirm against TNModels.
dataset = 'lymphography'
print("dataset:", dataset)
# X is the integer data array, d is X.max()+1 (see load_dataset).
X,d = load_dataset(dataset)

# D is reported as the "bond dim" in the plot titles of the training cells.
D = 2
# PosMPS models: one non-homogeneous, one homogeneous.
mps      = PosMPS(X, d, D, homogeneous=False)
mps_hom  = PosMPS(X, d, D, homogeneous=True)

# Disabled: PosMPS variants with log_stability=True.
# mps_s    = PosMPS(X, d, D, homogeneous=False, log_stability=True)
# mps_s_hom= PosMPS(X, d, D, homogeneous=True, log_stability=True)

# Born models with real dtype (torch.float), log_stability off.
rBorn      = Born(X, d, D, dtype=torch.float, homogeneous=False, log_stability=False) 
rBorn_hom  = Born(X, d, D, dtype=torch.float, homogeneous=True, log_stability=False) 

# Born models with real dtype, log_stability on ("_s" suffix).
rBorn_s    = Born(X, d, D, dtype=torch.float, homogeneous=False, log_stability=True) 
rBorn_s_hom= Born(X, d, D, dtype=torch.float, homogeneous=True, log_stability=True) 

# Born models with complex dtype (torch.cfloat), log_stability off.
cBorn      = Born(X, d, D, dtype=torch.cfloat, homogeneous=False, log_stability=False)
cBorn_hom  = Born(X, d, D, dtype=torch.cfloat, homogeneous=True, log_stability=False)

# Born models with complex dtype, log_stability on.
cBorn_s    = Born(X, d, D, dtype=torch.cfloat, homogeneous=False, log_stability=True)
cBorn_s_hom= Born(X, d, D, dtype=torch.cfloat, homogeneous=True, log_stability=True)

# Groups iterated by the hook-registration and training cells:
# `models` holds the homogeneous=False variants, `models_hom` the
# homogeneous=True ones. (mps_c / mps_c_hom are not defined in this cell.)
models     = (
    rBorn, cBorn, rBorn_s, cBorn_s, mps#, mps_c
)
models_hom = (
    rBorn_hom, cBorn_hom, rBorn_s_hom, cBorn_s_hom, mps_hom#, mps_c_hom
)

def clip_grad(grad, clip_val, param_name, verbose=False):
    """Clamp a gradient element-wise to ``[-clip_val, clip_val]``.

    Intended for use as a tensor backward hook. For complex gradients the
    real and imaginary parts are clamped independently. A message is printed
    if the incoming gradient contains NaN (NaN entries are not altered by
    clamping).

    Args:
        grad (torch.Tensor): gradient to clip; real or complex of any precision.
        clip_val (float): non-negative bound applied symmetrically.
        param_name (str): name used in the NaN diagnostic message.
        verbose (bool): if True, print a line whenever the smallest or largest
            entry exceeds ``clip_val`` in magnitude.

    Returns:
        torch.Tensor: a new tensor with the clipped values.
    """
    if torch.isnan(grad).any():
        print(f"│ Hook: NaN value in gradient of {param_name}, {grad.size()}")

    # min()/max() are only needed for the verbose report and are undefined on
    # empty tensors, so compute them lazily.
    report = verbose and grad.numel() > 0

    # is_complex() covers complex64 and complex128; torch.clamp itself does
    # not accept complex tensors, so the parts are clamped separately.
    if grad.is_complex():
        if report:
            for ext, v in [("min", grad.real.min()), ("max", grad.real.max())]:
                if abs(v) > clip_val:
                    print(f"clipping real {ext} {v:.2} to size {clip_val}")
            for ext, v in [("min", grad.imag.min()), ("max", grad.imag.max())]:
                if abs(v) > clip_val:
                    print(f"clipping imag {ext} {1.j*v:.2} to size {clip_val}")
        return torch.complex(grad.real.clamp(-clip_val, clip_val),
                             grad.imag.clamp(-clip_val, clip_val))

    if report:
        for ext, v in [("min", grad.min()), ("max", grad.max())]:
            if abs(v) > clip_val:
                print(f"clipping {ext} {v:.2} to size {clip_val}")
    return torch.clamp(grad, -clip_val, clip_val)

# Attach a gradient-clipping hook to every parameter of every model and warn
# about parameters that already contain NaN.
# NOTE: re-running this cell registers the hooks a second time on the same
# parameters; rebuild the models first if you need to re-run it.
print("Initializing models:")
clip_val = 1000
for model in (*models, *models_hom):
    print(f"\t{model.core.shape} model type: {model.name}")
    # named_parameters() pairs each parameter with its own name, so there is
    # no positional lookup into state_dict() (which may also list buffers).
    for pname, p in model.named_parameters():
        # Bind the current name and clip value as lambda defaults. A plain
        # closure would look the names up when the hook fires, i.e. after the
        # loops finished, and every hook would report the last parameter.
        p.register_hook(
            lambda grad, pname=pname, clip_val=clip_val:
                clip_grad(grad, clip_val, pname, verbose=True))
        if torch.isnan(p).any():
            print(f"{pname} contains a NaN value!")

In [None]:
def _halt_training(loss_values, reason, plot):
    """Report an aborted run: print the loss history and ``reason``, optionally plot, return the history."""
    print("│ loss values:", *(f"{x:.3f}" for x in loss_values))
    print(reason)
    if plot:
        plt.plot(loss_values)
        plt.show()
    return loss_values


def train(self, dataset, batchsize, max_epochs, plot=True, **optim_kwargs):
    """Train the model ``self`` on ``dataset`` with mini-batch SGD.

    The batch loss is the mean over the batch of ``-self(x)``, evaluated one
    datapoint at a time (presumably ``self(x)`` is a log-probability --
    confirm against the model class). Training stops early when the average
    batch loss changes by less than 1e-4 between consecutive epochs, and
    aborts as soon as a parameter or a gradient contains NaN.

    Args:
        self: model to train; must be callable on a single datapoint and
            provide ``parameters()``, ``zero_grad()`` and ``name``.
        dataset: data accepted by ``torch.utils.data.DataLoader``.
        batchsize (int): mini-batch size.
        max_epochs (int): upper bound on the number of epochs.
        plot (bool): if True, plot the per-epoch loss curve before returning.
        **optim_kwargs: forwarded to ``torch.optim.SGD`` (e.g. ``lr``).

    Returns:
        list[float]: average batch loss of every completed epoch.
    """
    trainloader = DataLoader(dataset, batch_size=batchsize, shuffle=True)
    optimizer = torch.optim.SGD(self.parameters(), **optim_kwargs)
    early_stopping_threshold = 0.0001  # min difference in epoch loss
    loss_values = []  # by-epoch average loss values
    print('╭───────────────────────────')
    print(f"│Training {self.name},")
    print(f"│  batchsize:{batchsize} opt:{optim_kwargs}.")
    previous_epoch_loss = None  # no epoch finished yet
    with tqdm(range(max_epochs), leave=True) as tepochs:
        for epoch in tepochs:
            batch_loss = []
            with tqdm(trainloader, unit="batch", leave=False, desc=f"epoch {epoch}") as tepoch:
                for batch in tepoch:
                    # Weights are checked before each batch, i.e. right after
                    # the previous optimizer step.
                    if any(torch.isnan(p).any() for p in self.parameters()):
                        return _halt_training(
                            loss_values,
                            "└────Stopped. After updating, model weights contain a NaN value!",
                            plot)
                    self.zero_grad()
                    neglogprob = 0
                    for x in batch:
                        neglogprob -= self(x)
                    loss = neglogprob / len(batch)
                    loss.backward()
                    # A parameter that did not take part in the forward pass
                    # has grad None; skip it instead of crashing in isnan().
                    if any(p.grad is not None and torch.isnan(p.grad).any()
                           for p in self.parameters()):
                        return _halt_training(
                            loss_values,
                            "└────Stopped. Gradient contains a NaN value!",
                            plot)
                    optimizer.step()
                    batch_loss.append(loss.item())
                    tepoch.set_postfix(loss=batch_loss[-1])
            av_batch_loss = torch.Tensor(batch_loss).mean().item()
            loss_values.append(av_batch_loss)
            tepochs.set_postfix(av_batch_loss=av_batch_loss)
            if (previous_epoch_loss is not None
                    and abs(previous_epoch_loss - av_batch_loss) < early_stopping_threshold):
                print("└────Early stopping.")
                break
            previous_epoch_loss = av_batch_loss
    print("│ loss values:", *(f"{x:.3f}" for x in loss_values))
    if plot:
        plt.plot(loss_values)
        plt.show()
    print('╰────────Finished─training──\n')
    return loss_values

In [None]:
# modelhom_loss_values={}
# for model in [rBorn_hom]:
#     loss_values = train(model, X, batchsize=20, plot=False, max_epochs=50, lr=0.1)
#     plt.plot(loss_values, label=model.name)
#     plt.ylabel('avg loss (NLL)')
#     plt.xlabel('Epoch')
#     plt.title(f"dataset: {dataset} (d={d}), bond dim={D}")
#     modelhom_loss_values["model.name"]=loss_values
# plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
# plt.show()

In [None]:
# Train every homogeneous model with the same settings and overlay the
# per-epoch loss curves in a single figure.
modelhom_loss_values={}
batchsize=35
max_epochs=100
optim_kwargs = dict(lr=0.1)
for model in models_hom:
    loss_values = train(model, X, batchsize=batchsize, plot=False, max_epochs=max_epochs, **optim_kwargs)
    plt.plot(loss_values, label=model.name)
    # Key by the model's actual name: the string literal "model.name" made
    # every model overwrite the same single entry.
    modelhom_loss_values[model.name]=loss_values
# Axis labels and title do not depend on the model, so set them once.
plt.ylabel('avg loss (NLL)')
plt.xlabel('Epoch')
plt.title(f"dataset: {dataset} (d={d}), bond dim={D}\n batchsize:{batchsize}, {optim_kwargs}")
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
plt.show()

In [None]:
# Train every non-homogeneous model with the same settings and overlay the
# per-epoch loss curves in a single figure.
# Results go into their own dict: resetting `modelhom_loss_values` here threw
# away the loss curves recorded for the homogeneous models.
model_loss_values={}
batchsize=35
max_epochs=100
optim_kwargs = dict(lr=0.05)
for model in models:
    loss_values = train(model, X, batchsize=batchsize, plot=False, max_epochs = max_epochs, **optim_kwargs)
    plt.plot(loss_values, label=model.name)
    # Key by the model's actual name: the string literal "model.name" made
    # every model overwrite the same single entry.
    model_loss_values[model.name]=loss_values
# Axis labels and title do not depend on the model, so set them once.
plt.ylabel('avg loss (NLL)')
plt.xlabel('Epoch')
plt.title(f"dataset: {dataset} (d={d}), bond dim={D}\n batchsize:{batchsize}, {optim_kwargs}")
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
plt.show()

## Useful things?

In [None]:
class TTrain(nn.Module):
    """Abstract class for Tensor Train models.  Use instantiating class.

    Parameters:
        dataset: array-like exposing a 2-D ``shape``; only ``shape[0]``
            (number of datapoints) and ``shape[1]`` (sequence length) are
            read here.
        d (int): physical dimension (number of categories in data)
        D (int): bond dimension
        dtype (torch.dtype):
            torch.float for real, or torch.cfloat for complex
        homogeneous (bool): if True, a single core of shape ``(d, D, D)`` is
            stored; otherwise one core per position, shape
            ``(seqlen, d, D, D)``.
        verbose (bool): stored as ``self.verbose``; not otherwise used in the
            code shown here.
    """
    def __init__(self, dataset, d, D, dtype, homogeneous=True, verbose=False):
        super().__init__()
        self.D = D
        self.d = d
        self.verbose = verbose
        self.homogeneous = homogeneous
        self.n_datapoints = dataset.shape[0]
        self.seqlen = dataset.shape[1]

        w_init = self.randomsign_ones # alternatively, use torch.ones

        # Scale factors for the initial weights. float() keeps them plain
        # Python numbers even when d is a NumPy integer.
        k_core = float(d*D*D)**-0.5
        k_vectors = float(d)**-0.5
        # TODO k should be (d*D*D)**-0.5,
        # we should use randn instead of rand,
        # but this seems to make more NaNs for the homogeneous models.

        # the following are set to nn.Parameters thus are backpropped over
        if homogeneous: # initialize single core to be repeated
            core_shape = (d, D, D)
        else: # initialize seqlen different non-homogeneous cores
            core_shape = (self.seqlen, d, D, D)
        self.core = nn.Parameter(k_core * w_init(core_shape, dtype=dtype))
        self.left_boundary = nn.Parameter(
            k_vectors * w_init((D,), dtype=dtype))
        self.right_boundary = nn.Parameter(
            k_vectors * w_init((D,), dtype=dtype))

    @staticmethod
    def randomsign_ones(shape, dtype=torch.float):
        """Return a tensor of ones with random sign.

        If ``dtype`` is ``torch.cfloat`` each entry is instead drawn
        uniformly from the four units ``1, -1, 1j, -1j``.

        Args:
            shape: size accepted by ``torch.zeros`` (int or tuple of ints).
            dtype (torch.dtype): dtype of the returned tensor.

        Returns:
            torch.Tensor: tensor of the requested shape and dtype.
        """
        x = torch.zeros(shape)
        if dtype==torch.cfloat:
            random4 = torch.randint_like(x, 4)
            # 0 -> +1, 1 -> -1, 2 -> +1j, 3 -> -1j
            r = (random4==0).to(x.dtype) - (random4==1).to(x.dtype)
            i = (random4==2).to(x.dtype) - (random4==3).to(x.dtype)
            out = torch.complex(r, i)
        else:
            # 0 -> +1, 1 -> -1
            out = 1 - 2*torch.randint_like(x, 2)
        # `out` is already a tensor: convert it with .to() rather than
        # torch.tensor(out, ...), which copy-constructs and emits a UserWarning.
        return out.to(dtype)
