In [1]:
import kan
import kan.utils as ku
import torch
import torch.nn as nn
import numpy as np
from libraries import utils
from libraries import magnetization
from libraries import j1j2_functions
import numpy.random as npr
import qutip as qt
import sympy
import random

In [2]:
# Chain length and the two coupling constants handed to the Hamiltonian builder.
N, J1, J2 = 10, 1, 0.2
h = j1j2_functions.J1J2_hamiltonian(N, J1, J2)

# `eigs[0]` holds the eigenvalues and `eigs[1]` the matching eigenstates.
# NOTE(review): presumably `h` is a qutip Qobj whose eigenstates() returns the
# spectrum sorted from lowest energy, making entry 0 the ground state — confirm.
eigs = h.eigenstates()
gse, gs = eigs[0][0], eigs[1][0]

In [4]:
def get_nonzero_states(N, gs, threshold):
    """Collect the basis indices whose amplitude in `gs` is not negligible.

    Parameters
    ----------
    N : int
        The first ``2**N`` entries of `gs` are scanned.
    gs : indexable
        State vector; ``gs[i][0]`` must expose a ``.real`` attribute.
    threshold : float
        An entry is kept only when ``abs(real part) > threshold``.

    Returns
    -------
    (list[int], list[int])
        The kept indices and, aligned with them, +1 for a positive real part
        and -1 otherwise.
    """
    states, signs = [], []
    for index in range(2 ** N):
        amplitude = gs[index][0].real
        # Written as a negated ">" so that non-comparable values (NaN) are
        # skipped exactly as a plain "> threshold" test would skip them.
        if not abs(amplitude) > threshold:
            continue
        states.append(index)
        signs.append(1 if amplitude > 0 else -1)
    return states, signs

def find_deviations(states, signs, pred_signs):
    """Return the states where a predicted sign pattern disagrees with the true one.

    The prediction is compared twice: as given ("forward") and with every
    predicted sign negated ("rev"). Whichever comparison yields fewer
    mismatches is reported; a tie goes to "rev". The chosen label is printed.

    Parameters
    ----------
    states, signs, pred_signs : sequences of equal length
        Aligned state identifiers, true signs and predicted signs.

    Returns
    -------
    list
        The mismatching states under the chosen orientation.
    """
    direct_mismatches = [
        state
        for state, actual, guess in zip(states, signs, pred_signs)
        if actual != guess
    ]
    flipped_mismatches = [
        state
        for state, actual, guess in zip(states, signs, pred_signs)
        if actual != -guess
    ]

    # "forward" wins only when it is strictly better than the flipped pattern.
    use_forward = len(flipped_mismatches) > len(direct_mismatches)
    print('forward' if use_forward else 'rev')
    return direct_mismatches if use_forward else flipped_mismatches

In [5]:
# Amplitudes whose magnitude does not exceed this value are ignored.
threshold = 1e-10
statesf, signsf = get_nonzero_states(N, gs, threshold)

# Reference sign pattern: +1 when count_half_magnetization(state) is odd,
# -1 when it is even.
pred_signsf = [
    2 * (magnetization.count_half_magnetization(state) % 2) - 1
    for state in statesf
]

# Number of retained basis states, then the number of states whose true sign
# disagrees with the reference pattern (in its better orientation).
print(len(statesf))
deviations = find_deviations(statesf, signsf, pred_signsf)
print(len(deviations))


252
forward
10


In [6]:
# NOTE: `input` shadows the Python builtin; the name is kept because later
# cells in this notebook refer to it.
input = utils.generate_input_samples(N, statesf)

# True signs as a float64 column vector, one row per retained state.
labels = torch.tensor(signsf, dtype=torch.float64).reshape(-1, 1)

# The "test" split is the very same data as the "train" split.
dataset = dict(
    train_input=input,
    train_label=labels,
    test_input=input,
    test_label=labels,
)
print(input.shape, labels.shape)

torch.Size([252, 10]) torch.Size([252, 1])


In [88]:
def MSR(bin_state, N):
    """Smooth sign factor computed from the even-indexed columns of `bin_state`.

    Evaluates ``-tanh(10 * cos(pi * s))`` where ``s`` is the row-wise sum of
    columns ``0, 2, 4, ... < N``. For integer ``s`` this is approximately +1
    when ``s`` is odd and approximately -1 when ``s`` is even.
    NOTE(review): the name presumably stands for "Marshall sign rule" — confirm.

    Parameters
    ----------
    bin_state : torch.Tensor
        2-D tensor indexed as ``bin_state[:, i]``.
    N : int
        Exclusive upper bound for the column indices that are summed.

    Returns
    -------
    torch.Tensor
        Column vector of shape ``(-1, 1)``.
    """
    even_site_total = sum(bin_state[:, site] for site in range(0, N, 2))
    phase = torch.cos(np.pi * even_site_total)
    sign = -torch.tanh(10 * phase)
    return sign.reshape((-1, 1))

In [89]:
# Spot check: basis-state index stored at position 5 of the retained states.
statesf[5]

62

In [90]:
# Inspect a single sample: its input row, the MSR factor for that row, and its
# true label. In the recorded output MSR gives +1 while the label is -1.
num = 20
print(dataset['train_input'][num])
# Index with a one-element list so MSR still receives a 2-D (1, N) batch.
print(MSR(dataset['train_input'][[num]], N))
print(labels[num])


tensor([0., 0., 1., 1., 1., 1., 1., 0., 0., 0.])
tensor([[1.]])
tensor([-1.], dtype=torch.float64)


In [91]:
from tqdm import tqdm
def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., lamb_coef=0., lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1.,start_grid_update_step=-1, stop_grid_update_step=50, batch=-1,
              metrics=None, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n', display_metrics=None):
    """Train `self` so that ``MSR(x, N) * self.forward(x)`` matches the labels.

    This appears to be adapted from pykan's ``KAN.fit``; the difference is that
    the network output is multiplied by the module-level ``MSR`` factor before
    the loss is taken. ``MSR`` and ``N`` are read from module scope.

    Parameters
    ----------
    self : model
        Object providing ``forward``, ``get_params``, ``get_reg``,
        ``update_grid``, ``attribute``, ``node_attribute``, ``log_history``,
        ``disable_symbolic_in_fit`` and the ``save_act`` /
        ``symbolic_enabled`` attributes.
    dataset : dict
        Must contain ``'train_input'``, ``'train_label'``, ``'test_input'``
        and ``'test_label'``.
    opt : {"LBFGS", "Adam"}
        Optimizer to use.
    steps : int
        Number of optimisation steps.
    log : int
        The progress-bar description is refreshed every `log` steps.
    lamb : float
        Weight of the regularisation term in the objective.
    lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff : float
        Forwarded to ``self.get_reg``.
    update_grid : bool
        Whether ``self.update_grid`` is called during training.
    grid_update_num, start_grid_update_step, stop_grid_update_step : int
        Grid updates happen every
        ``max(1, int(stop_grid_update_step / grid_update_num))`` steps, for
        steps in ``[start_grid_update_step, stop_grid_update_step)``.
    loss_fn : callable or None
        ``loss_fn(pred, target)``; defaults to mean squared error.
    lr : float
        Learning rate passed to the optimizer.
    batch : int
        Batch size; ``-1`` or anything larger than the training set means
        full batch.
    metrics : list of callables or None
        Each is called with no arguments every step; ``.item()`` of the
        result is stored under the callable's ``__name__``.
    save_fig, save_fig_freq :
        When `save_fig` is true, ``self.save_act`` is switched on at every
        `save_fig_freq`-th step. No figure is produced by this function.
    in_vars, out_vars, beta, img_folder :
        Accepted for signature compatibility; not used in the body.
    singularity_avoiding, y_th :
        Forwarded to ``self.forward``.
    reg_metric : str
        Forwarded to ``self.get_reg``; ``'edge_backward'`` and
        ``'node_backward'`` additionally trigger ``self.attribute()`` /
        ``self.node_attribute()`` first.
    display_metrics : list of str or None
        Keys of the result dict to show in the progress bar instead of the
        default train/test/reg summary.

    Returns
    -------
    dict
        ``'train_loss'`` and ``'test_loss'`` (square root of the loss per
        step), ``'reg'``, plus one list per entry of `metrics`.

    Raises
    ------
    ValueError
        If `opt` is neither ``"Adam"`` nor ``"LBFGS"``.
    Exception
        If an entry of `display_metrics` is not a key of the result dict.
    """
    # Fail early: previously an unknown optimizer name skipped training and
    # reported whatever stale values the train_loss/reg_ globals held.
    if opt not in ("Adam", "LBFGS"):
        raise ValueError(f"unknown optimizer {opt!r}; expected 'Adam' or 'LBFGS'")

    if lamb > 0. and not self.save_act:
        print('setting lamb=0. If you want to set lamb > 0, set self.save_act=True')

    old_save_act, old_symbolic_enabled = self.disable_symbolic_in_fit(lamb)

    pbar = tqdm(range(steps), desc='description', ncols=100)

    if loss_fn is None:
        loss_fn = lambda x, y: torch.mean((x - y) ** 2)
    loss_fn_eval = loss_fn

    # Clamp to at least 1 so a small stop_grid_update_step cannot make the
    # modulo below divide by zero.
    grid_update_freq = max(1, int(stop_grid_update_step / grid_update_num))

    if opt == "Adam":
        optimizer = torch.optim.Adam(self.get_params(), lr=lr)
    else:
        optimizer = kan.LBFGS(self.get_params(), lr=lr, history_size=10, line_search_fn="strong_wolfe", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)

    results = {'train_loss': [], 'test_loss': [], 'reg': []}
    if metrics is not None:
        for metric_fn in metrics:
            results[metric_fn.__name__] = []

    if batch == -1 or batch > dataset['train_input'].shape[0]:
        batch_size = dataset['train_input'].shape[0]
        batch_size_test = dataset['test_input'].shape[0]
    else:
        batch_size = batch
        batch_size_test = batch

    # Kept as module-level globals, as in the original, so the LBFGS closure
    # can hand its last loss/regulariser back and any other cell that reads
    # them keeps working.
    global train_loss, reg_

    def signed_forward(inputs):
        # Single definition of the model prediction, shared by the LBFGS
        # closure, the Adam branch and the test evaluation.
        return MSR(inputs, N) * self.forward(inputs, singularity_avoiding=singularity_avoiding, y_th=y_th)

    def compute_reg():
        # The regulariser is only evaluated while activations are being saved.
        if not self.save_act:
            return torch.tensor(0.)
        if reg_metric == 'edge_backward':
            self.attribute()
        if reg_metric == 'node_backward':
            self.node_attribute()
        return self.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)

    def closure():
        global train_loss, reg_
        optimizer.zero_grad()
        pred = signed_forward(dataset['train_input'][train_id])
        train_loss = loss_fn(pred, dataset['train_label'][train_id])
        reg_ = compute_reg()
        objective = train_loss + lamb * reg_
        objective.backward()
        return objective

    for step in pbar:

        # Re-enable activation saving for the final step if it was on before.
        if step == steps - 1 and old_save_act:
            self.save_act = True

        # NOTE(review): this only switches save_act on; there is no plotting
        # code here and the flag is not restored afterwards.
        if save_fig and step % save_fig_freq == 0:
            self.save_act = True

        train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
        test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)

        if update_grid and start_grid_update_step <= step < stop_grid_update_step and step % grid_update_freq == 0:
            self.update_grid(dataset['train_input'][train_id])

        if opt == "LBFGS":
            optimizer.step(closure)
        else:
            # Adam now optimises the same MSR-weighted prediction as LBFGS and
            # as the test loss; it previously used the bare network output.
            pred = signed_forward(dataset['train_input'][train_id])
            train_loss = loss_fn(pred, dataset['train_label'][train_id])
            reg_ = compute_reg()
            loss = train_loss + lamb * reg_
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        test_loss = loss_fn_eval(signed_forward(dataset['test_input'][test_id]), dataset['test_label'][test_id])

        if metrics is not None:
            for metric_fn in metrics:
                results[metric_fn.__name__].append(metric_fn().item())

        results['train_loss'].append(torch.sqrt(train_loss).cpu().detach().numpy())
        results['test_loss'].append(torch.sqrt(test_loss).cpu().detach().numpy())
        results['reg'].append(reg_.cpu().detach().numpy())

        if step % log == 0:
            if display_metrics is None:
                pbar.set_description("| train_loss: %.2e | test_loss: %.2e | reg: %.2e | " % (torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(), reg_.cpu().detach().numpy()))
            else:
                string = ''
                data = ()
                for metric in display_metrics:
                    string += f' {metric}: %.2e |'
                    if metric not in results:
                        raise Exception(f'{metric} not recognized')
                    data += (results[metric][-1],)
                pbar.set_description(string % data)

    self.log_history('fit')
    # revert back to original state
    self.symbolic_enabled = old_symbolic_enabled
    return results

In [152]:
# Build a KAN with layer widths N -> N -> 1 -> 1.
dev_model2N = kan.KAN(width = [N, N, 1, 1])
# Run one forward pass on the training inputs before fixing a symbolic function.
# NOTE(review): presumably fix_symbolic needs the activations cached by this
# call (the recorded output reports an r2 fit score) — confirm.
dev_model2N(dataset['train_input'])
# Replace the activation addressed by (2, 0, 0) with a symbolic tanh.
dev_model2N.fix_symbolic(2, 0, 0, 'tanh')
# Overwrite that activation's affine parameters with [1, 0, 1, 0] and freeze
# them (requires_grad=False).
# NOTE(review): assuming pykan's c*f(a*x+b)+d convention this is exactly
# tanh(x) — confirm against the pykan documentation.
dev_model2N.symbolic_fun[2].affine = nn.Parameter(torch.tensor([[[1, 0, 1, 0]]], dtype=torch.float64), requires_grad=False)

checkpoint directory created: ./model
saving model version 0.0
r2 is 0.9978489875793457
saving model version 0.1


In [153]:
# Short warm-up run without regularisation (lamb=0).
fit(dev_model2N, dataset, steps=5, lamb=0); # this seems to work very well, but training for too many steps sometimes gives a much higher loss

| train_loss: 3.75e-09 | test_loss: 3.75e-09 | reg: 9.87e+01 | : 100%|█| 5/5 [00:02<00:00,  1.92it/s

saving model version 0.2





In [154]:
# Continue training with regularisation switched on (lamb=5e-3).
fit(dev_model2N, dataset, steps=25, lamb=5e-3);

| train_loss: 9.06e-02 | test_loss: 9.06e-02 | reg: 2.21e+01 | : 100%|█| 25/25 [00:17<00:00,  1.45it

saving model version 0.3





In [155]:
# Show the layer widths of the pruned model; the pruned model itself is not
# assigned to any name in this cell.
dev_model2N.prune().width

saving model version 0.4


[[10, 0], [6, 0], [1, 0], [1, 0]]

In [161]:
# Root-mean-square error of the MSR-weighted model output against the labels,
# evaluated on the full data set.
(MSR(input, N) * dev_model2N(input, y_th=1000) - labels).pow(2).mean().sqrt()

tensor(1.7594, dtype=torch.float64, grad_fn=<SqrtBackward0>)

In [158]:
# Count the samples whose predicted sign differs from the sign of the label.
pred = MSR(input, N) * dev_model2N(input)
pred_dev = [
    row
    for row in range(labels.shape[0])
    if torch.sgn(pred[row][0]) != torch.sgn(labels[row][0])
]
print(len(pred_dev))

223


In [143]:
# Rebind the name to the pruned model.
# NOTE: not idempotent — every re-run of this cell prunes the model again.
dev_model2N = dev_model2N.prune()

saving model version 0.4


In [144]:
# Retrain the pruned model with stronger regularisation (lamb=1e-2) and with
# grid updates disabled.
fit(dev_model2N, dataset, steps=25, lamb=1e-2, update_grid=False);

| train_loss: 2.57e-01 | test_loss: 2.57e-01 | reg: 5.39e+00 | : 100%|█| 25/25 [00:10<00:00,  2.44it

saving model version 0.5





In [145]:
# Re-count the sign mismatches between prediction and labels after retraining.
pred = MSR(input, N) * dev_model2N(input)
pred_dev = [
    row
    for row in range(labels.shape[0])
    if torch.sgn(pred[row][0]) != torch.sgn(labels[row][0])
]
print(len(pred_dev))

6


In [146]:
# Show the layer widths a further pruning pass would give; the result is not
# assigned to any name in this cell.
dev_model2N.prune().width

saving model version 0.6


[[10, 0], [3, 0], [1, 0], [1, 0]]

In [None]:
# Rebind the name to the pruned model.
# NOTE: this cell has no execution count in the saved notebook, so it is
# unclear whether it ran before the cells that follow.
dev_model2N = dev_model2N.prune()

In [147]:
# Another training round at lamb=5e-3, with grid updates disabled.
fit(dev_model2N, dataset, steps=25, lamb=5e-3, update_grid=False);

| train_loss: 2.51e-01 | test_loss: 2.51e-01 | reg: 5.69e+00 | : 100%|█| 25/25 [00:04<00:00,  5.35it

saving model version 0.6





In [148]:
# Final count of samples whose predicted sign differs from the label's sign.
pred = MSR(input, N) * dev_model2N(input)
pred_dev = [
    row
    for row in range(labels.shape[0])
    if torch.sgn(pred[row][0]) != torch.sgn(labels[row][0])
]
print(len(pred_dev))

6
