In [1]:
import kan
import torch
import torch.nn as nn
import numpy as np
from libraries import utils
from libraries import magnetization
import numpy.random as npr
import qutip as qt

In [2]:
N = 6; J1 = 1; J2 = 0

In [3]:
# mlp_model = nn.Sequential(
#     nn.Linear(N, 5*N),
#     nn.Sigmoid(),
#     nn.Linear(5*N, 2),
#     nn.Sigmoid()
# )
layers = []
layers.append(nn.Linear(N, 32))
for _ in range(2):
    layers.append(nn.Linear(32, 32))
    layers.append(nn.SELU())
layers.append(nn.Linear(32, 2))
mlp_model = nn.Sequential(*layers)

In [4]:
def J1J2_hamiltonian(N, j1, j2):
    id = qt.qeye(2)
    x = qt.sigmax()
    y = qt.sigmay()
    z = qt.sigmaz()
    sxi = []; syi = []; szi = []
    for i in range(N):
        sxi.append(qt.tensor([id] * i + [x] + [id] * (N - i - 1)))
        syi.append(qt.tensor([id] * i + [y] + [id] * (N - i - 1)))
        szi.append(qt.tensor([id] * i + [z] + [id] * (N - i - 1)))
    sis = [sxi, syi, szi]
    J1_term = j1 * (sum(sis[coord][i] * sis[coord][i + 1] for coord in range(len(sis)) for i in range(N - 1)) + sum(sis[coord][N - 1] * sis[coord][0] for coord in range(len(sis))))
    J2_term = j2 * (sum(sis[coord][i] * sis[coord][i + 2] for coord in range(len(sis)) for i in range(N - 2)) + sum(sis[coord][N - 2] * sis[coord][0] for coord in range(len(sis))) + sum(sis[coord][N - 1] * sis[coord][1] for coord in range(len(sis))))
    return J1_term + J2_term


In [5]:
h = J1J2_hamiltonian(N, J1, J2)
h

Quantum object: dims=[[2, 2, 2, 2, 2, 2], [2, 2, 2, 2, 2, 2]], shape=(64, 64), type='oper', dtype=CSR, isherm=True
Qobj data =
[[6. 0. 0. ... 0. 0. 0.]
 [0. 2. 2. ... 0. 0. 0.]
 [0. 2. 2. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 2. 2. 0.]
 [0. 0. 0. ... 2. 2. 0.]
 [0. 0. 0. ... 0. 0. 6.]]

In [6]:
eigs = h.eigenstates()
gse = eigs[0][0]
gs = eigs[1][0]
print(gse)

-11.21110255092797


In [7]:
for i in range(2 ** N):
    print(bin(i), gs[i][0])

0b0 0j
0b1 0j
0b10 (-8.293141773196102e-18+0j)
0b11 (1.4810524357047876e-17+0j)
0b100 (1.8666965225373454e-17+0j)
0b101 (-1.0525177992486299e-16+0j)
0b110 (2.7022968735892047e-17+0j)
0b111 (-0.06292062151869875+0j)
0b1000 0j
0b1001 (-4.1265433344754184e-17+0j)
0b1010 (1.2190834469159135e-16+0j)
0b1011 (0.20781269586291384+0j)
0b1100 (-1.2059301697899346e-17+0j)
0b1101 (-0.20781269586291368+0j)
0b1110 (0.06292062151869872+0j)
0b1111 (-1.1102230246251563e-16+0j)
0b10000 (-1.866696522537345e-17+0j)
0b10001 (-1.2789183584906882e-16+0j)
0b10010 (-4.9810630208592375e-17+0j)
0b10011 (-0.2078126958629138+0j)
0b10100 (-1.0709850332208275e-16+0j)
0b10101 (0.4785460132445264+0j)
0b10110 (-0.2078126958629138+0j)
0b10111 (1.4186407099380446e-17+0j)
0b11000 (6.860533993502216e-17+0j)
0b11001 (-0.2078126958629138+0j)
0b11010 (0.20781269586291382+0j)
0b11011 (-5.044225972796643e-18+0j)
0b11100 (-0.06292062151869879+0j)
0b11101 (2.9539520701718193e-18+0j)
0b11110 (-2.594959655701273e-18+0j)
0b11111 (-1

In [8]:
gs.dag() * h * gs

(-11.211102550927974+0j)

In [9]:
input = utils.generate_input_torch(N)
print(input.shape)

torch.Size([64, 6])


In [10]:
mlp_model(input)

tensor([[-1.1929e-02,  1.8568e-01],
        [-6.7976e-02,  2.6127e-01],
        [-2.4875e-03,  3.3872e-01],
        [-4.6186e-02,  3.8103e-01],
        [ 2.5269e-02,  2.2020e-01],
        [-3.3347e-02,  2.8118e-01],
        [ 3.1065e-02,  3.6357e-01],
        [-2.7021e-02,  4.1352e-01],
        [-1.0046e-01,  2.4727e-02],
        [-1.3677e-01,  1.1315e-01],
        [-1.1655e-01,  1.7444e-01],
        [-1.3664e-01,  2.3219e-01],
        [-8.6669e-02,  6.2621e-02],
        [-1.2047e-01,  1.4311e-01],
        [-8.7269e-02,  2.0231e-01],
        [-1.2295e-01,  2.6812e-01],
        [-1.4822e-01,  1.0291e-02],
        [-1.8022e-01,  9.2268e-02],
        [-1.2646e-01,  1.5634e-01],
        [-1.6633e-01,  2.4178e-01],
        [-8.9690e-02,  4.4332e-02],
        [-1.3345e-01,  1.1796e-01],
        [-8.9779e-02,  1.8794e-01],
        [-1.2266e-01,  2.5293e-01],
        [-2.4327e-01, -1.1737e-01],
        [-2.7693e-01, -3.9381e-02],
        [-2.3025e-01,  1.9170e-02],
        [-2.6474e-01,  9.030

In [11]:
def amp_phase(nn_output):
    return nn_output[:, 0] * torch.exp(1.j * 2 * np.pi * nn_output[:, 1])
def reim(nn_output):
    return nn_output[:, 0] + 1.j * nn_output[:, 1]
psi = amp_phase(mlp_model(input))
print(psi)
print(psi.dtype)

tensor([-0.0047-1.0968e-02j,  0.0048-6.7805e-02j,  0.0013-2.1109e-03j,
         0.0339-3.1399e-02j,  0.0047+2.4828e-02j,  0.0065-3.2709e-02j,
        -0.0203+2.3486e-02j,  0.0231-1.3971e-02j, -0.0992-1.5545e-02j,
        -0.1036-8.9254e-02j, -0.0533-1.0366e-01j, -0.0153-1.3578e-01j,
        -0.0800-3.3228e-02j, -0.0750-9.4309e-02j, -0.0258-8.3380e-02j,
         0.0140-1.2215e-01j, -0.1479-9.5769e-03j, -0.1508-9.8727e-02j,
        -0.0702-1.0518e-01j, -0.0086-1.6611e-01j, -0.0862-2.4661e-02j,
        -0.0984-9.0100e-02j, -0.0341-8.3041e-02j,  0.0023-1.2264e-01j,
        -0.1801+1.6357e-01j, -0.2685+6.7827e-02j, -0.2286-2.7666e-02j,
        -0.2233-1.4228e-01j, -0.1796+9.4576e-02j, -0.2406-7.5758e-05j,
        -0.1775-7.3791e-02j, -0.1516-1.6426e-01j, -0.0843-4.0456e-02j,
        -0.0761-9.9111e-02j, -0.0203-8.5509e-02j,  0.0132-1.2135e-01j,
        -0.0495-3.3197e-02j, -0.0489-7.9833e-02j, -0.0048-5.8232e-02j,
         0.0183-8.4577e-02j, -0.1713+5.9777e-02j, -0.2200-3.1426e-02j,
      

In [12]:
np_array = amp_phase(mlp_model(input)).detach().numpy()
np_array.reshape((1, -1)).conj() @ h.full() @ np_array.reshape((-1, 1))

array([[9.51835589-1.11022302e-16j]])

In [13]:
torch.tensor(h.full(), dtype = torch.complex64).dtype # default complex128

torch.complex64

In [14]:
amp_phase(mlp_model(input)).reshape((1, -1)).conj() @ torch.tensor(h.full(), dtype = torch.complex64) @ amp_phase(mlp_model(input)).reshape((-1, 1))

tensor([[9.5184+6.8767e-08j]], grad_fn=<MmBackward0>)

In [42]:
# mlp_model = nn.Sequential(
#     nn.Linear(N, 5*N),
#     nn.Sigmoid(),
#     nn.Linear(5*N, 2),
#     nn.Sigmoid()
# )
layers = []
layers.append(nn.Linear(N, 32))
for _ in range(2):
    layers.append(nn.Linear(32, 32))
    layers.append(nn.SELU())
layers.append(nn.Linear(32, 2))

mlp_model = nn.Sequential(*layers)
torch_h = torch.tensor(h.full(), dtype = torch.complex64)
epochs = []
loss_data = []
data_rate = 1
optimizer = torch.optim.Adam(mlp_model.parameters(), lr = 1e-3) # adam does better here idk why
for epoch in range(1000):
    psi = utils.log_amp_phase(mlp_model(input)).reshape((-1, 1))
    loss = ((psi.conj().T @ torch_h @ psi).real[0][0]) / (psi.conj().T @ psi).real[0][0]
    if epoch % data_rate == 0:
        loss_data.append(loss.item())
        epochs.append(epoch)
        print(f'{epoch}: {loss}')
    if loss < -12:
        break
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0: 5.979545593261719
1: 5.964103698730469
2: 5.943150043487549
3: 5.916226863861084
4: 5.882883548736572
5: 5.842474460601807
6: 5.794460773468018
7: 5.737898826599121
8: 5.672120571136475
9: 5.596042156219482
10: 5.508618354797363
11: 5.408514022827148
12: 5.295520305633545
13: 5.1670241355896
14: 5.0208353996276855
15: 4.8550896644592285
16: 4.668469429016113
17: 4.458975315093994
18: 4.223692417144775
19: 3.9612302780151367
20: 3.6707708835601807
21: 3.3516879081726074
22: 3.004044771194458
23: 2.6278624534606934
24: 2.2240965366363525
25: 1.794600486755371
26: 1.3425474166870117
27: 0.8723588585853577
28: 0.3900247812271118
29: -0.0974024161696434
30: -0.5823349356651306
31: -1.0572309494018555
32: -1.515920877456665
33: -1.9521870613098145
34: -2.3588356971740723
35: -2.7359349727630615
36: -3.0851993560791016
37: -3.4043962955474854
38: -3.6921682357788086
39: -3.9478421211242676
40: -4.171072483062744
41: -4.362283706665039
42: -4.524722576141357
43: -4.657912254333496
44: -4.76

In [29]:
[n for n in mlp_model.parameters()]

[Parameter containing:
 tensor([[ 0.5101,  0.0084, -0.3574, -0.0776,  0.1328,  0.4142],
         [ 0.3997,  0.1774, -0.1862, -0.3531,  0.1492,  0.2229],
         [ 0.1122, -0.0074,  0.2557,  0.1233,  0.0894,  0.3561],
         [-0.3501, -0.1817, -0.1540, -0.2015,  0.0700, -0.3685],
         [-0.3679, -0.1001,  0.2148, -0.4137,  0.3538, -0.1567],
         [ 0.2592, -0.2104,  0.2584, -0.2593,  0.3182,  0.1000],
         [ 0.0212, -0.3234, -0.1083, -0.1857,  0.2341,  0.0142],
         [-0.1380,  0.2852, -0.0426, -0.2209, -0.1583,  0.3989],
         [ 0.2973,  0.3304, -0.1577,  0.1130, -0.0560,  0.2166],
         [ 0.1102, -0.2337,  0.2242,  0.3692, -0.2632, -0.0921],
         [-0.2175, -0.3259, -0.1853,  0.1383, -0.1380,  0.0722],
         [ 0.1423, -0.3090, -0.1184,  0.1309,  0.0412, -0.3136],
         [-0.0940, -0.0174, -0.0039, -0.2360,  0.1301,  0.2654],
         [-0.1971, -0.3389,  0.1918, -0.1708,  0.3067, -0.0939],
         [-0.0017,  0.2367,  0.0313,  0.3127, -0.3349,  0.3774],
  

In [30]:
psi = amp_phase(mlp_model(input))
psi

tensor([-2.1334e-04-2.4326e-04j, -1.4894e-06+5.8175e-06j,
        -3.5071e-06-2.6359e-05j,  1.8070e-04-1.3097e-04j,
         2.7573e-05-2.5970e-04j,  4.2158e-04-2.1967e-04j,
        -3.1382e-04+2.1701e-04j, -1.0638e-02-2.5830e-03j,
        -4.7091e-05-3.9410e-04j, -3.5408e-04+3.4436e-04j,
         3.8777e-04-3.9496e-04j,  3.5511e-02+7.9580e-03j,
         6.3011e-05-5.7060e-05j, -3.4953e-02-7.9270e-03j,
         1.0274e-02+2.8188e-03j,  4.4108e-05+1.0141e-04j,
         8.1715e-06-1.2301e-04j,  1.7471e-04-1.2046e-04j,
        -2.7556e-04+2.3507e-04j, -3.4943e-02-7.7097e-03j,
         6.7849e-04-6.3571e-04j,  8.1016e-02+1.8307e-02j,
        -3.5366e-02-8.2248e-03j, -1.0776e-05-1.9489e-05j,
        -6.5353e-05+6.4143e-05j, -3.5089e-02-8.0852e-03j,
         3.5250e-02+7.8575e-03j,  1.7410e-04+4.2313e-04j,
        -1.0372e-02-1.5386e-03j,  3.6646e-05+6.9016e-05j,
        -1.1014e-05-2.0483e-05j, -1.3038e-05+3.0265e-05j,
        -1.8189e-06-2.0079e-05j,  1.6070e-04-1.3045e-04j,
        -7.933

In [31]:
(psi.reshape((-1, 1)).conj().T @ psi.reshape((-1, 1))).real[0][0]

tensor(0.0301, grad_fn=<SelectBackward0>)

In [32]:
(psi.reshape((1, -1)).conj() @ torch_h @ psi.reshape((-1, 1))).real[0][0] / (psi.reshape((-1, 1)).conj().T @ psi.reshape((-1, 1))).real[0][0]

tensor(-11.2085, grad_fn=<DivBackward0>)

In [33]:
qpsi = qt.Qobj(psi.data).unit()
qpsi.dims = [N * [2], N * [1]]
qpsi

Quantum object: dims=[[2, 2, 2, 2, 2, 2], [1, 1, 1, 1, 1, 1]], shape=(64, 1), type='ket', dtype=Dense
Qobj data =
[[-1.22918402e-03-1.40151689e-03j]
 [-8.58091856e-06+3.35178704e-05j]
 [-2.02063560e-05-1.51866675e-04j]
 [ 1.04111973e-03-7.54593641e-04j]
 [ 1.58862233e-04-1.49625330e-03j]
 [ 2.42893264e-03-1.26562166e-03j]
 [-1.80806559e-03+1.25031808e-03j]
 [-6.12882132e-02-1.48817954e-02j]
 [-2.71314983e-04-2.27063392e-03j]
 [-2.04001595e-03+1.98400814e-03j]
 [ 2.23414748e-03-2.27555287e-03j]
 [ 2.04595028e-01+4.58502407e-02j]
 [ 3.63037350e-04-3.28754590e-04j]
 [-2.01383689e-01-4.56717089e-02j]
 [ 5.91959707e-02+1.62405203e-02j]
 [ 2.54130111e-04+5.84262649e-04j]
 [ 4.70802580e-05-7.08744426e-04j]
 [ 1.00661957e-03-6.94016913e-04j]
 [-1.58765191e-03+1.35433997e-03j]
 [-2.01321874e-01-4.44197949e-02j]
 [ 3.90911629e-03-3.66266998e-03j]
 [ 4.66772699e-01+1.05475542e-01j]
 [-2.03760706e-01-4.73870130e-02j]
 [-6.20874544e-05-1.12285309e-04j]
 [-3.76531268e-04+3.69560223e-04j]
 [-2.021683

In [34]:
qpsi.dag() @ h @ qpsi

(-11.208535050346196-1.2757401446613433e-16j)

In [None]:
kan_model = kan.KAN(width = [N, N, 2])
torch_h = torch.tensor(h.full(), dtype = torch.complex64)
epochs = []
loss_data = []
data_rate = 10
optimizer = torch.optim.Adam(kan_model.parameters(), lr = 0.01) # adam also seems to work better here?
for epoch in range(1000):
    psi = utils.log_amp_phase(kan_model(input)).reshape((-1, 1))
    loss = ((psi.conj().T @ torch_h @ psi).real[0][0]) / (psi.conj().T @ psi).real[0][0]
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch % data_rate == 0:
        loss_data.append(loss.item())
        epochs.append(epoch)
        print(f'{epoch}: {loss}')

checkpoint directory created: ./model
saving model version 0.0
0: 5.941198348999023
10: 4.7441086769104
20: 0.29270175099372864
30: -2.058305263519287
40: -2.1971001625061035
50: -2.649165391921997
60: -3.519038438796997
70: -4.229647159576416
80: -5.445335388183594
90: -7.375014781951904
100: -8.666014671325684
110: -9.676692962646484
120: -10.175782203674316
130: -10.587321281433105
140: -10.895034790039062
150: -11.076375961303711
160: -11.153007507324219
170: -11.18027114868164
180: -11.19078540802002
190: -11.195976257324219
200: -11.198993682861328
210: -11.201033592224121
220: -11.20247745513916
230: -11.20357608795166
240: -11.204439163208008
250: -11.20513916015625
260: -11.20572280883789
270: -11.206212043762207
280: -11.206634521484375
290: -11.206993103027344
300: -11.207316398620605
310: -11.207592964172363
320: -11.207842826843262
330: -11.208062171936035
340: -11.208261489868164
350: -11.2084379196167
360: -11.208598136901855
370: -11.20874309539795
380: -11.208873748779

In [40]:
qpsi = qt.Qobj(psi.detach().reshape((-1, 1))) 
qpsi

Quantum object: dims=[[64], [1]], shape=(64, 1), type='ket', dtype=Dense
Qobj data =
[[-5.34438848e-01-2.49952182e-01j]
 [ 4.66661930e-01-4.43741649e-01j]
 [-1.00064382e-01+6.91466965e-03j]
 [ 7.96453431e-02-5.74716687e-01j]
 [ 6.15118183e-02-1.27333635e-02j]
 [-8.13165009e-02+1.19967234e+00j]
 [ 1.08455978e-01-7.85700858e-01j]
 [ 4.73558426e+01+1.51166525e+01j]
 [-1.55318575e-02+6.47792593e-03j]
 [-2.02472210e-01-8.61598253e-01j]
 [ 3.98380198e-02+1.40113330e+00j]
 [-1.56519241e+02-4.98399391e+01j]
 [-5.01848720e-02-7.25200295e-01j]
 [ 1.56708267e+02+4.96559296e+01j]
 [-4.75954361e+01-1.46566916e+01j]
 [ 6.34698808e-01-2.38212585e-01j]
 [ 1.27692912e-02-4.32747863e-02j]
 [ 1.41410637e+00+1.26532447e+00j]
 [-7.45498180e-01-1.20220053e+00j]
 [ 1.56726456e+02+4.96098633e+01j]
 [ 3.04221392e-01+4.28409457e-01j]
 [-3.60820770e+02-1.14207634e+02j]
 [ 1.56699387e+02+4.95416031e+01j]
 [-2.84062833e-01+2.60302961e-01j]
 [-2.87431926e-01-3.85013074e-01j]
 [ 1.56665039e+02+4.96239548e+01j]
 [-1.

In [43]:
N=10
h10 = J1J2_hamiltonian(N, J1, J2)
eigs10 = h10.eigenstates()
print(f'GSE: {eigs10[0][0]}')

GSE: -18.06178541796817


In [54]:
gs10 = eigs10[1][0]

In [46]:
torch_h10 = torch.tensor(h10.full(), dtype = torch.complex64)

In [53]:
N=10
input = utils.generate_input_torch(N)
layers = []
layers.append(nn.Linear(N, 32))
for _ in range(2):
    layers.append(nn.Linear(32, 32))
    layers.append(nn.SELU())
layers.append(nn.Linear(32, 2))

mlp_model = nn.Sequential(*layers)
torch_h = torch.tensor(h.full(), dtype = torch.complex64)
epochs = []
loss_data = []
data_rate = 50
optimizer = torch.optim.Adam(mlp_model.parameters(), lr = 1e-3) # adam does better here idk why
for epoch in range(10000):
    psi = utils.log_amp_phase(mlp_model(input)).reshape((-1, 1))
    loss = ((psi.conj().T @ torch_h10 @ psi).real[0][0]) / (psi.conj().T @ psi).real[0][0]
    if epoch % data_rate == 0:
        loss_data.append(loss.item())
        epochs.append(epoch)
        print(f'{epoch}: {loss}')
    if loss < -19:
        break
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0: 9.974654197692871
50: -8.192296981811523
100: -9.931819915771484
150: -10.249652862548828
200: -10.450944900512695
250: -10.706153869628906
300: -11.185789108276367
350: -11.67613410949707
400: -12.459854125976562
450: -13.860406875610352
500: -15.413647651672363
550: -16.558095932006836
600: -17.32294464111328
650: -17.68755340576172
700: -17.85002326965332
750: -17.9227294921875
800: -17.96568489074707
850: -17.990381240844727
900: -18.004594802856445
950: -18.016645431518555
1000: -18.02418327331543
1050: -18.02886390686035
1100: -18.033506393432617
1150: -18.035249710083008
1200: -18.0389347076416
1250: -18.041248321533203
1300: -18.042905807495117
1350: -18.04239273071289
1400: -18.04582405090332
1450: -18.0471248626709
1500: -18.048112869262695
1550: -18.04811668395996
1600: -18.04895782470703
1650: -18.049583435058594
1700: -18.05001449584961
1750: -18.050552368164062
1800: -18.052082061767578
1850: -18.050081253051758
1900: -18.05249786376953
1950: -18.053436279296875
2000: 

In [57]:
qt.fidelity(gs10, qt.Qobj(psi.detach()).unit())

np.float64(0.999731455296855)

In [None]:
kan_model = kan.KAN(width = [N, N, 2]) # KANs seem sort of random here, get stuck at various local mins?
epochs = []
loss_data = []
data_rate = 50
optimizer = torch.optim.SGD(kan_model.parameters(), lr = 0.1) # for some reason SGD is better now? Adam gets stuck on ~15.8 but SGD reaches ~17.9
for epoch in range(1000):
    psi = utils.log_amp_phase(kan_model(input)).reshape((-1, 1))
    loss = ((psi.conj().T @ torch_h10 @ psi).real[0][0]) / (psi.conj().T @ psi).real[0][0]
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch % data_rate == 0:
        loss_data.append(loss.item())
        epochs.append(epoch)
        print(f'{epoch}: {loss}')

checkpoint directory created: ./model
saving model version 0.0
0: 9.944035530090332
50: -11.1666898727417
100: -13.350960731506348
150: -14.125144958496094
200: -15.085332870483398
250: -15.717145919799805
300: -16.457189559936523
350: -15.185018539428711
400: -16.994850158691406
450: -17.623767852783203
500: -17.510866165161133
550: -17.1640682220459
600: -17.694957733154297
650: -17.756561279296875
700: -17.67783546447754
750: -17.939626693725586
800: -17.67343521118164
850: -17.713966369628906
900: -17.84929847717285
950: -17.791135787963867


In [67]:
[n for n in eigs10[0]]

[np.float64(-18.06178541796817),
 np.float64(-16.368829386954662),
 np.float64(-16.368829386954662),
 np.float64(-16.368829386954655),
 np.float64(-15.082389741633772),
 np.float64(-14.173117497250766),
 np.float64(-14.173117497250761),
 np.float64(-14.173117497250761),
 np.float64(-14.173117497250761),
 np.float64(-14.173117497250761),
 np.float64(-14.173117497250761),
 np.float64(-12.984659666930337),
 np.float64(-12.984659666930334),
 np.float64(-12.984659666930334),
 np.float64(-12.984659666930334),
 np.float64(-12.984659666930334),
 np.float64(-12.984659666930328),
 np.float64(-11.903727476217746),
 np.float64(-11.903727476217746),
 np.float64(-11.90372747621774),
 np.float64(-11.90372747621774),
 np.float64(-11.903727476217737),
 np.float64(-11.56494440140296),
 np.float64(-11.56494440140296),
 np.float64(-11.56494440140296),
 np.float64(-11.56494440140296),
 np.float64(-11.564944401402956),
 np.float64(-11.564944401402956),
 np.float64(-11.545147698039608),
 np.float64(-11.54514

In [66]:
qt.fidelity(gs10, qt.Qobj(psi.detach()).unit())

np.float64(0.6700357758745074)

In [31]:
J1J2_hamiltonian(N, 1, 0.5).eigenstates()

(array([-9.00000000e+00, -9.00000000e+00, -7.00000000e+00, -7.00000000e+00,
        -7.00000000e+00, -6.27491722e+00, -6.27491722e+00, -6.27491722e+00,
        -6.27491722e+00, -6.27491722e+00, -6.27491722e+00, -5.00000000e+00,
        -5.00000000e+00, -4.00000000e+00, -4.00000000e+00, -4.00000000e+00,
        -4.00000000e+00, -4.00000000e+00, -4.00000000e+00, -1.00000000e+00,
        -1.00000000e+00, -1.00000000e+00, -1.00000000e+00, -1.00000000e+00,
        -1.00000000e+00, -1.33431039e-14, -1.09552879e-14, -8.38736674e-15,
        -5.47058516e-15, -4.47688597e-15, -7.30838347e-16, -3.64032313e-16,
         2.94849333e-15,  3.57153156e-15,  7.31070219e-15,  1.00000000e+00,
         1.00000000e+00,  1.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         1.00000000e+00,  1.27491722e+00,  1.27491722e+00,  1.27491722e+00,
         1.27491722e+00,  1.27491722e+00,  1.27491722e+00,  4.00000000e+00,
         4.00000000e+00,  4.00000000e+00,  4.00000000e+00,  4.00000000e+00,
         4.0

In [41]:
kan_model_half = kan.KAN(width = [N, N, 2])
torch_h_half = torch.tensor(J1J2_hamiltonian(N, J1, 0.5).full(), dtype = torch.complex64)
epochs = []
loss_data = []
data_rate = 10
optimizer = torch.optim.SGD(kan_model_half.parameters(), lr = 0.1)
for epoch in range(10000):
    psi = utils.log_amp_phase(kan_model_half(input))
    loss = (psi.reshape((1, -1)).conj() @ torch_h_half @ psi.reshape((-1, 1))).real[0][0]
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch % data_rate == 0:
        loss_data.append(loss.item())
        epochs.append(epoch)
        print(f'{epoch}: {loss}')

checkpoint directory created: ./model
saving model version 0.0
0: 724.072021484375
10: 1.2257259651704544e-09
20: 1.2257072024013382e-09
30: 1.225688439632222e-09
40: 1.2256698989077108e-09
50: 1.2256512471608971e-09
60: 1.2256231585183741e-09
70: 1.2255857440024442e-09
80: 1.225566981233328e-09
90: 1.2255483294865144e-09
100: 1.2255295667173982e-09
110: 1.225510803948282e-09
120: 1.2255017001194801e-09
130: 1.2254642856035503e-09
140: 1.225445522834434e-09
150: 1.225426760065318e-09
160: 1.2254081083185042e-09
170: 1.2253705827802719e-09
180: 1.2253427161823538e-09
190: 1.22532406443554e-09
200: 1.2253054126887264e-09
210: 1.225267887150494e-09
220: 1.225267887150494e-09
230: 1.2252304726345642e-09
240: 1.2252024950143436e-09
250: 1.2252024950143436e-09
260: 1.22518384326753e-09
270: 1.2251464287516e-09
280: 1.225127665982484e-09
290: 1.2250996883622634e-09
300: 1.2250809255931472e-09
310: 1.2250436220995198e-09
320: 1.2250436220995198e-09
330: 1.2250063186058924e-09
340: 1.2249875558

KeyboardInterrupt: 

In [None]:
qpsi = qt.Qobj(psi.detach().reshape((-1, 1))) 
qpsi.unit().full()

array([[ 9.99992637e-01-3.83310970e-03j],
       [ 1.78347848e-04-4.85553812e-05j],
       [-1.93661075e-33+2.59227160e-33j],
       [-2.17638390e-36+4.77463518e-36j],
       [-2.65953468e-24+1.02748419e-24j],
       [-3.64079651e-27+2.40796499e-27j],
       [ 0.00000000e+00+0.00000000e+00j],
       [ 0.00000000e+00+0.00000000e+00j],
       [-5.20536203e-24-1.73716236e-24j],
       [-8.37007079e-27-8.94324690e-28j],
       [ 0.00000000e+00+0.00000000e+00j],
       [ 0.00000000e+00+0.00000000e+00j],
       [ 0.00000000e+00+0.00000000e+00j],
       [ 0.00000000e+00+0.00000000e+00j],
       [-0.00000000e+00+0.00000000e+00j],
       [-0.00000000e+00+0.00000000e+00j],
       [-2.64390314e-14-5.88669146e-13j],
       [-2.15702420e-16-8.00728213e-16j],
       [ 0.00000000e+00+0.00000000e+00j],
       [ 0.00000000e+00+0.00000000e+00j],
       [ 4.79205779e-36+1.25630208e-35j],
       [ 1.29768377e-38+1.99458802e-38j],
       [ 0.00000000e+00-0.00000000e+00j],
       [ 0.00000000e+00-0.00000000

: 