In [1]:
import numpy as np
import libraries.tfim_functions as tfim_functions
import libraries.utils as utils
import torch
import torch.nn as nn
from libraries.NeuralStates import *
from kan import MultKAN
import time

In [2]:
def generate_eloc_distr(sampled_vector, N, J, Gamma, model):
    """Compute the local energy of every sampled basis state, one model call at a time.

    For each state ``s`` in ``sampled_vector.distribution`` the result is

        sum over s' of  H(s, s') * exp((o'[0] - o[0]) + 1j * (o'[1] - o[1]))

    where ``s'`` ranges over ``tfim_functions.generate_adjacencies(s, N)``,
    ``H`` is ``tfim_functions.calc_H_elem`` and ``o`` / ``o'`` are the network
    outputs for ``s`` / ``s'`` (components 0 and 1 are presumably log-amplitude
    and phase -- confirm against ``utils.log_amp_phase``).

    This is the unbatched reference implementation: every state that is not
    already in ``sampled_vector.nn_output`` is pushed through ``model`` on its
    own, as a batch of one.

    Args:
        sampled_vector: object exposing ``distribution`` (iterable of basis
            states) and ``nn_output`` (mapping from basis state to an
            already-computed network output).
        N: size argument forwarded to the encoding and Hamiltonian helpers.
        J, Gamma: forwarded unchanged to ``tfim_functions.calc_H_elem``
            (presumably the Ising coupling and transverse field -- confirm).
        model: callable mapping a batch of encoded states to one output row
            per state, each row indexable at 0 and 1.

    Returns:
        dict mapping each sampled basis state to its accumulated local energy
        (the integer 0 if a state has no adjacencies).
    """
    # Outputs evaluated during this call, so no state hits the model twice.
    fresh_outputs = {}

    def model_to_output(x):
        if x in sampled_vector.nn_output:
            return sampled_vector.nn_output[x]
        if x not in fresh_outputs:
            # Encode through the same helper as generate_eloc_distr_efficient
            # (here with a one-element batch) instead of relying on a `lib`
            # name that is never imported explicitly in this notebook.
            fresh_outputs[x] = model(utils.generate_input_samples(N, [x]))[0]
        return fresh_outputs[x]

    eloc_values = {}
    for basis_state in sampled_vector.distribution:
        output = model_to_output(basis_state)
        eloc = 0
        for adjacency in tfim_functions.generate_adjacencies(basis_state, N):
            output_prime = model_to_output(adjacency)
            # Exponent of the ratio between the two network outputs.
            log_ratio = (output_prime[0] - output[0]) + 1.j * (output_prime[1] - output[1])
            eloc += tfim_functions.calc_H_elem(N, J, Gamma, basis_state, adjacency) * torch.exp(log_ratio)
        eloc_values[basis_state] = eloc
    return eloc_values

def generate_eloc_distr_efficient(sampled_vector, N, J, Gamma, model):
    """Batched variant of ``generate_eloc_distr``: at most one model call in total.

    Every state whose network output is needed but absent from
    ``sampled_vector.nn_output`` is collected first, encoded with
    ``utils.generate_input_samples`` and evaluated in a single forward pass.
    The local energies are then accumulated exactly as in the unbatched
    version:

        sum over s' of  H(s, s') * exp((o'[0] - o[0]) + 1j * (o'[1] - o[1]))

    Sampled states themselves are queued when they are missing from
    ``nn_output`` (not only their adjacencies), so this function accepts the
    same inputs as ``generate_eloc_distr``.

    Args:
        sampled_vector: object exposing ``distribution`` (iterable of basis
            states) and ``nn_output`` (mapping from basis state to an
            already-computed network output).
        N: size argument forwarded to the encoding and Hamiltonian helpers.
        J, Gamma: forwarded unchanged to ``tfim_functions.calc_H_elem``
            (presumably the Ising coupling and transverse field -- confirm).
        model: callable mapping a batch of encoded states to one output row
            per state, each row indexable at 0 and 1.

    Returns:
        dict mapping each sampled basis state to its accumulated local energy
        (the integer 0 if a state has no adjacencies).
    """
    cached = sampled_vector.nn_output
    to_calculate = []     # uncached states, in the order they will be batched
    batch_row = {}        # state -> row index in the batched model output
    adjacency_lists = {}  # sampled state -> its adjacencies, generated once

    def enqueue(state):
        if state not in cached and state not in batch_row:
            batch_row[state] = len(to_calculate)
            to_calculate.append(state)

    for basis_state in sampled_vector.distribution:
        neighbours = list(tfim_functions.generate_adjacencies(basis_state, N))
        adjacency_lists[basis_state] = neighbours
        enqueue(basis_state)
        for adjacency in neighbours:
            enqueue(adjacency)

    # Skip the forward pass entirely when everything is already cached.
    nn_output_calcs = model(utils.generate_input_samples(N, to_calculate)) if to_calculate else None

    def model_to_output(x):
        if x in cached:
            return cached[x]
        # Every state looked up below was enqueued above, so the row exists.
        return nn_output_calcs[batch_row[x]]

    eloc_values = {}
    for basis_state, neighbours in adjacency_lists.items():
        output = model_to_output(basis_state)
        eloc = 0
        for adjacency in neighbours:
            output_prime = model_to_output(adjacency)
            # Exponent of the ratio between the two network outputs.
            log_ratio = (output_prime[0] - output[0]) + 1.j * (output_prime[1] - output[1])
            eloc += tfim_functions.calc_H_elem(N, J, Gamma, basis_state, adjacency) * torch.exp(log_ratio)
        eloc_values[basis_state] = eloc
    return eloc_values

In [3]:
# Two-input toy problem: small enough to compare both local-energy
# implementations on every basis state.
N = 2

# Linear(N, 32) -> 2 x [Linear(32, 32), SELU] -> Linear(32, 2); the model emits
# two values per input state.
layers = [nn.Linear(N, 32)]
for _ in range(2):
    layers += [nn.Linear(32, 32), nn.SELU()]
layers.append(nn.Linear(32, 2))
mlp_model = nn.Sequential(*layers)

In [4]:
# Sampled neural state for the N = 2 MLP (the "MH" prefix presumably stands for
# Metropolis-Hastings -- TODO confirm against libraries.NeuralStates).
# Positional arguments: N, the model, utils.log_amp_phase, a callable wrapping
# utils.bitflip_x(x, N, 1), then 0 and 256. The last value is presumably the
# number of samples (a later cell's comment refers to num_samples) -- verify.
# NOTE: the lambda looks up the global N when it is called, not when it is
# defined, so rebinding N in a later cell also changes what this callable
# passes to utils.bitflip_x.
mlp_mh_state = MHNeuralState(N, mlp_model, utils.log_amp_phase, lambda x: utils.bitflip_x(x, N, 1), 0, 256)

In [5]:
# Reference result: unbatched local energies with J = Gamma = 1.
generate_eloc_distr(mlp_mh_state, N, J=1, Gamma=1, model=mlp_model)

{0: tensor(-3.8771+0.0212j, grad_fn=<AddBackward0>),
 1: tensor(0.1552+0.2450j, grad_fn=<AddBackward0>),
 2: tensor(-0.0864-0.2296j, grad_fn=<AddBackward0>),
 3: tensor(-4.1640-0.1147j, grad_fn=<AddBackward0>)}

In [6]:
# Batched implementation on the same state; the values should match the
# unbatched output of the previous cell.
generate_eloc_distr_efficient(mlp_mh_state, N, J=1, Gamma=1, model=mlp_model)

{0: tensor(-3.8771+0.0212j, grad_fn=<AddBackward0>),
 1: tensor(0.1552+0.2450j, grad_fn=<AddBackward0>),
 2: tensor(-0.0864-0.2296j, grad_fn=<AddBackward0>),
 3: tensor(-4.1640-0.1147j, grad_fn=<AddBackward0>)}

In [7]:
# Larger problem: ten inputs. Rebinding N and mlp_model here replaces the
# N = 2 versions defined in the earlier cells.
N = 10

# Same architecture as before:
# Linear(N, 32) -> 2 x [Linear(32, 32), SELU] -> Linear(32, 2).
layers = [nn.Linear(N, 32)]
for _ in range(2):
    layers += [nn.Linear(32, 32), nn.SELU()]
layers.append(nn.Linear(32, 2))
mlp_model = nn.Sequential(*layers)

In [8]:
# Sampled neural state for the N = 10 MLP. Same positional arguments as the
# N = 2 cell except the last one, which is 5 here instead of 256 (presumably
# the number of samples -- verify against MHNeuralState). With so few samples,
# most adjacent states will not already have a cached network output.
# NOTE: the lambda looks up the global N when it is called, not when defined.
mlp_mh_state = MHNeuralState(N, mlp_model, utils.log_amp_phase, lambda x: utils.bitflip_x(x, N, 1), 0, 5)

In [9]:
# Reference result for N = 10: unbatched local energies with J = Gamma = 1.
generate_eloc_distr(mlp_mh_state, N, J=1, Gamma=1, model=mlp_model)

{512: tensor(-16.0188+0.0788j, grad_fn=<AddBackward0>),
 0: tensor(-19.9186-0.1198j, grad_fn=<AddBackward0>),
 2: tensor(-15.9095-0.0636j, grad_fn=<AddBackward0>),
 6: tensor(-16.0861+0.0744j, grad_fn=<AddBackward0>),
 22: tensor(-12.0823+0.0055j, grad_fn=<AddBackward0>)}

In [10]:
# Batched implementation for N = 10; the values should match the unbatched
# output of the previous cell.
generate_eloc_distr_efficient(mlp_mh_state, N, J=1, Gamma=1, model=mlp_model)

{512: tensor(-16.0188+0.0788j, grad_fn=<AddBackward0>),
 0: tensor(-19.9186-0.1198j, grad_fn=<AddBackward0>),
 2: tensor(-15.9095-0.0636j, grad_fn=<AddBackward0>),
 6: tensor(-16.0861+0.0744j, grad_fn=<AddBackward0>),
 22: tensor(-12.0823+0.0055j, grad_fn=<AddBackward0>)}

In [19]:
# N = 10 MLP state built with 256 as the last argument (num_samples), used by
# the timing comparison in the next cell.
mlp_mh_state_256 = MHNeuralState(N, mlp_model, utils.log_amp_phase, lambda x: utils.bitflip_x(x, N, 1), 0, 256)
# Observations from varying num_samples:
# - As num_samples grows, the batched and unbatched timings converge, because
#   the number of adjacent states that still need a model evaluation goes to 0.
# - As num_samples shrinks, batching helps more, because more of the adjacent
#   states have not been evaluated yet.
# - Roughly a 3x improvement was seen for num_samples=5.

In [20]:
# Time both local-energy implementations on the same 256-sample MLP state.
# time.perf_counter() is used because it is a monotonic, high-resolution clock
# meant for measuring short intervals; time.time() follows the system clock,
# which can be adjusted while the cell is running.
start = time.perf_counter()
generate_eloc_distr(mlp_mh_state_256, N, J=1, Gamma=1, model=mlp_model)
print(f'MLP unbatched time: {time.perf_counter() - start}')
start = time.perf_counter()
generate_eloc_distr_efficient(mlp_mh_state_256, N, J=1, Gamma=1, model=mlp_model)
print(f'MLP batched time: {time.perf_counter() - start}')

MLP unbatched time: 0.14342761039733887
MLP batched time: 0.08694863319396973


In [None]:
# KAN counterpart of the MLP experiment: a MultKAN with width [N, N, 2] on the
# CPU, i.e. N inputs and two outputs per state like the MLP above.
kan_model = MultKAN(width=[N, N, 2], device='cpu')
# Same MHNeuralState arguments as the 256-sample MLP state, with the KAN as the model.
kan_mh_state = MHNeuralState(N, kan_model, utils.log_amp_phase, lambda x: utils.bitflip_x(x, N, 1), 0, 256)
# Observation: for the KAN, batching gives a massive speed-up regardless of
# num_samples (see the timings printed by the next cell).

checkpoint directory created: ./model
saving model version 0.0


In [30]:
# Time both local-energy implementations on the same 256-sample KAN state.
# time.perf_counter() is used because it is a monotonic, high-resolution clock
# meant for measuring short intervals; time.time() follows the system clock,
# which can be adjusted while the cell is running.
start = time.perf_counter()
generate_eloc_distr(kan_mh_state, N, J=1, Gamma=1, model=kan_model)
print(f'KAN unbatched time: {time.perf_counter() - start}')
start = time.perf_counter()
generate_eloc_distr_efficient(kan_mh_state, N, J=1, Gamma=1, model=kan_model)
print(f'KAN batched time: {time.perf_counter() - start}')

KAN unbatched time: 9.036884784698486
KAN batched time: 0.13714289665222168
