In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import os
from pathlib import Path

sys.path.append(str(Path.cwd().parent))

In [3]:
import torch
from hamiltonians.Ising import Ising
from model.model import TransformerModel

In [4]:
def gpu_setup():
    """Select the PyTorch device, report the choice, and return it.

    Returns:
        torch.device: ``cuda`` when ``torch.cuda.is_available()`` is true,
        otherwise ``cpu``.
    """
    if torch.cuda.is_available():
        torch_device = torch.device("cuda")
        print("PyTorch is using GPU {}".format(torch.cuda.current_device()))
    else:
        torch_device = torch.device("cpu")
        print("GPU unavailable; using CPU")
    # Hand the selection back to the caller; previously it was computed and
    # then dropped, so nothing downstream could use it.
    return torch_device

In [5]:
gpu_setup()
# Use the GPU as the default device only when one is actually present;
# hard-coding "cuda" makes every later tensor creation fail on CPU-only hosts
# even though gpu_setup() itself reports a CPU fallback.
torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_dtype(torch.float32)

PyTorch is using GPU 0


In [6]:
# Move the working directory up one level so that the relative paths used in
# later cells ("TFIM_ground_states", "results") are resolved from there.
# NOTE(review): the target is relative to the current directory, so re-running
# this cell without restarting the kernel moves up one more level each time.
os.chdir("..")

In [7]:
# torch.arange(8, 8 + 2, 2) holds the single value 8; reshape(-1, 1) makes it a
# one-column 2-D tensor, so iterating over it yields one 1-element tensor per size.
system_sizes = torch.arange(8, 8 + 2, 2).reshape(-1, 1)
# One Ising Hamiltonian per entry of system_sizes (a single one here).
Hamiltonians = [Ising(size, periodic=True, get_basis=True) for size in system_sizes]
# Dataset directory, relative to the current working directory.
data_dir_path = os.path.join("TFIM_ground_states", "2024-08-02T20-42-27.916")
# Load the dataset for every Hamiltonian; the keyword arguments are forwarded
# to the project's load_dataset as written.
# NOTE: `ham` remains bound to the last Hamiltonian once the loop finishes.
for ham in Hamiltonians:
    ham.load_dataset(
        data_dir_path,
        batch_size=1024,
        samples_in_epoch=100,
        sampling_type="shuffled",
    )

Loaded dataset for system size 8 from TFIM_ground_states/2024-08-02T20-42-27.916/8.arrow.
(h_min, h_step, h_max) = (0, 0.1, 2).


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)


In [8]:
# Hyperparameters handed to TransformerModel below.
param_dim = Hamiltonians[0].param_dim
embedding_size = 32
n_head = 8
n_hid = embedding_size  # n_hid is deliberately set equal to embedding_size
n_layers = 8
dropout = 0
minibatch = 10000
# NOTE(review): the next three names are assigned but not used by any cell
# visible here — presumably carried over from a training script; confirm
# before removing them.
param_range = None
point_of_interest = None
use_SR = False

# The first six arguments are passed positionally, so their order must match
# the project's TransformerModel constructor (not visible from this notebook).
testmodel = TransformerModel(
    system_sizes,
    param_dim,
    embedding_size,
    n_head,
    n_hid,
    n_layers,
    dropout=dropout,
    minibatch=minibatch,
)



In [9]:
# Put the model on the GPU when one is available and on the CPU otherwise,
# instead of calling .cuda() unconditionally (which raises on CPU-only hosts).
model_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
testmodel.to(model_device)

results_dir = "results"
paper_checkpoint_name = "ckpt_100000_Ising_32_8_8_0.ckpt"
paper_checkpoint_path = os.path.join(results_dir, paper_checkpoint_name)
# map_location remaps tensors that were saved from a CUDA device onto the
# device chosen above, so the checkpoint can also be restored without a GPU.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load(paper_checkpoint_path, map_location=model_device)
testmodel.load_state_dict(checkpoint)

<All keys matched successfully>

In [10]:
# Work with the first Hamiltonian (the only one, since system_sizes has a
# single entry); this rebinds the `ham` name left over from the loading loop.
ham = Hamiltonians[0]

In [11]:
dataset = ham.training_dataset
sampler = ham.sampler
# Only the first batch of indices is needed. next(iter(...)) says so directly
# and raises StopIteration when the sampler is empty, whereas the previous
# one-iteration loop would silently leave basis/params/psi_true unbound (or
# stale from an earlier run of this cell).
idx = next(iter(sampler))
basis, params, psi_true = dataset[idx]

In [12]:
basis

tensor([[0, 1, 1,  ..., 0, 1, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [0, 0, 1,  ..., 0, 1, 1],
        ...,
        [1, 0, 1,  ..., 0, 0, 1],
        [0, 1, 0,  ..., 0, 0, 1],
        [1, 0, 1,  ..., 0, 1, 1]], device='cuda:0')

In [13]:
params

tensor([[0.6000],
        [1.4000],
        [1.8000],
        ...,
        [1.4000],
        [1.6000],
        [0.7000]], device='cuda:0', dtype=torch.float64)

In [14]:
psi_true

tensor([-0.0027,  0.0232,  0.0545,  ...,  0.1153, -0.0271, -0.0062],
       device='cuda:0', dtype=torch.float64)

In [15]:
hfull = ham.full_H(1)

In [16]:
hfull

<Compressed Sparse Row sparse matrix of dtype 'float64'
	with 2164 stored elements and shape (256, 256)>

In [17]:
from scipy.sparse.linalg import eigsh

In [18]:
ham.update_param(1)

In [19]:
# Lowest eigenpair of the sparse Hamiltonian: k=1 with which="SA" requests the
# smallest algebraic eigenvalue and its eigenvector.
# NOTE: an eigenvector is only defined up to an overall sign, so `vecs` may
# differ from a stored ground state by a factor of -1.
vals, vecs = eigsh(hfull, k=1, which="SA")

In [20]:
energy, psi_dataset = ham.retrieve_ground(1)

In [21]:
psi = vecs[:, 0].reshape(-1)

In [22]:
# Element-wise difference between the eigsh ground state and the dataset one.
# NOTE(review): this compares raw signed amplitudes. In the vectors printed
# further down, the two agree in magnitude but differ in sign on some entries,
# so `diff` mostly reflects a sign convention — confirm whether magnitudes
# should be compared instead.
diff = torch.tensor(psi) - psi_dataset

In [23]:
# Largest deviation in magnitude: the maximum of the signed difference would
# ignore entries where the eigsh value lies below the dataset value.
diff.abs().max()

tensor(0.2584, device='cuda:0', dtype=torch.float64)

In [24]:
torch.set_printoptions(threshold=10_000_000)

In [46]:
print(psi_dataset)

tensor([-0.4591, -0.1292, -0.1292, -0.0774, -0.1292, -0.0404, -0.0774, -0.0607,
        -0.1292, -0.0378, -0.0404, -0.0274, -0.0774, -0.0274, -0.0607, -0.0564,
        -0.1292, -0.0375, -0.0378, -0.0246, -0.0404, -0.0137, -0.0274, -0.0247,
        -0.0774, -0.0246, -0.0274, -0.0221, -0.0607, -0.0247, -0.0564, -0.0607,
        -0.1292, -0.0378, -0.0375, -0.0246, -0.0378, -0.0128, -0.0246, -0.0221,
        -0.0404, -0.0128, -0.0137, -0.0110, -0.0274, -0.0113, -0.0247, -0.0274,
        -0.0774, -0.0246, -0.0246, -0.0192, -0.0274, -0.0110, -0.0221, -0.0246,
        -0.0607, -0.0221, -0.0247, -0.0246, -0.0564, -0.0274, -0.0607, -0.0774,
        -0.1292, -0.0404, -0.0378, -0.0274, -0.0375, -0.0137, -0.0246, -0.0247,
        -0.0378, -0.0128, -0.0128, -0.0113, -0.0246, -0.0110, -0.0221, -0.0274,
        -0.0404, -0.0137, -0.0128, -0.0110, -0.0137, -0.0060, -0.0110, -0.0137,
        -0.0274, -0.0110, -0.0113, -0.0128, -0.0247, -0.0137, -0.0274, -0.0404,
        -0.0774, -0.0274, -0.0246, -0.02

In [26]:
print(torch.tensor(psi))

tensor([-0.4591,  0.1292,  0.1292, -0.0774,  0.1292, -0.0404, -0.0774,  0.0607,
         0.1292, -0.0378, -0.0404,  0.0274, -0.0774,  0.0274,  0.0607, -0.0564,
         0.1292, -0.0375, -0.0378,  0.0246, -0.0404,  0.0137,  0.0274, -0.0247,
        -0.0774,  0.0246,  0.0274, -0.0221,  0.0607, -0.0247, -0.0564,  0.0607,
         0.1292, -0.0378, -0.0375,  0.0246, -0.0378,  0.0128,  0.0246, -0.0221,
        -0.0404,  0.0128,  0.0137, -0.0110,  0.0274, -0.0113, -0.0247,  0.0274,
        -0.0774,  0.0246,  0.0246, -0.0192,  0.0274, -0.0110, -0.0221,  0.0246,
         0.0607, -0.0221, -0.0247,  0.0246, -0.0564,  0.0274,  0.0607, -0.0774,
         0.1292, -0.0404, -0.0378,  0.0274, -0.0375,  0.0137,  0.0246, -0.0247,
        -0.0378,  0.0128,  0.0128, -0.0113,  0.0246, -0.0110, -0.0221,  0.0274,
        -0.0404,  0.0137,  0.0128, -0.0110,  0.0137, -0.0060, -0.0110,  0.0137,
         0.0274, -0.0110, -0.0113,  0.0128, -0.0247,  0.0137,  0.0274, -0.0404,
        -0.0774,  0.0274,  0.0246, -0.02

In [27]:
hfull[0:4, 0:4].todense()

matrix([[-8.,  1.,  1.,  0.],
        [ 1., -4.,  0.,  1.],
        [ 1.,  0., -4.,  1.],
        [ 0.,  1.,  1., -4.]])

In [28]:
# Parameter value (1.0) passed to the model through set_param.
# NOTE: this rebinds `params`, which earlier held the parameters of a dataset batch.
params = torch.tensor([[1.0]])
# NOTE(review): the model is given system size 15, while the Hamiltonian whose
# basis is evaluated in the following cells has 8 sites (full_H is 256 x 256) —
# confirm that 15 is intended.
system_size = torch.tensor([15])
testmodel.set_param(system_size, params)

In [47]:
from model.model_utils import compute_psi

In [48]:
log_probs, log_phases = compute_psi(testmodel, ham.basis, ham.symmetry)

In [49]:
log_probs

tensor([-2.0448, -3.4559, -4.3917, -4.0036, -4.5463, -5.6334, -5.2093, -4.2599,
        -4.5859, -5.8897, -6.6644, -6.0234, -5.4894, -6.2668, -5.5583, -4.3355,
        -4.5859, -5.9480, -6.8449, -6.3382, -6.8510, -7.7849, -7.2551, -6.1323,
        -5.5512, -6.7054, -7.3426, -6.4548, -5.8779, -6.4406, -5.6325, -4.2599,
        -4.5463, -5.9327, -6.8557, -6.4024, -6.9684, -7.9767, -7.5084, -6.4548,
        -6.8510, -8.0753, -8.7824, -7.9914, -7.4908, -8.1220, -7.3678, -6.0234,
        -5.4894, -6.7809, -7.6168, -6.9498, -7.4908, -8.2456, -7.6213, -6.3382,
        -5.8779, -6.9116, -7.4592, -6.4024, -5.8797, -6.3043, -5.4765, -4.0036,
        -4.3917, -5.8165, -6.7240, -6.3043, -6.8557, -7.9128, -7.4549, -6.4406,
        -6.8449, -8.1240, -8.8482, -8.1220, -7.6168, -8.3110, -7.5779, -6.2668,
        -6.6644, -8.0057, -8.8482, -8.2456, -8.7824, -9.6125, -9.0210, -7.7849,
        -7.3426, -8.4260, -8.9919, -7.9767, -7.4592, -7.9128, -7.0955, -5.6334,
        -5.2093, -6.5941, -7.4549, -6.91

In [50]:
probs = torch.exp(log_probs)

In [51]:
probs

tensor([1.2941e-01, 3.1558e-02, 1.2379e-02, 1.8249e-02, 1.0606e-02, 3.5764e-03,
        5.4654e-03, 1.4124e-02, 1.0195e-02, 2.7677e-03, 1.2755e-03, 2.4215e-03,
        4.1301e-03, 1.8982e-03, 3.8555e-03, 1.3095e-02, 1.0195e-02, 2.6111e-03,
        1.0648e-03, 1.7676e-03, 1.0584e-03, 4.1599e-04, 7.0654e-04, 2.1715e-03,
        3.8828e-03, 1.2242e-03, 6.4734e-04, 1.5730e-03, 2.8006e-03, 1.5955e-03,
        3.5797e-03, 1.4124e-02, 1.0606e-02, 2.6513e-03, 1.0535e-03, 1.6576e-03,
        9.4114e-04, 3.4337e-04, 5.4848e-04, 1.5730e-03, 1.0584e-03, 3.1114e-04,
        1.5342e-04, 3.3837e-04, 5.5821e-04, 2.9693e-04, 6.3124e-04, 2.4215e-03,
        4.1301e-03, 1.1352e-03, 4.9211e-04, 9.5886e-04, 5.5821e-04, 2.6242e-04,
        4.8992e-04, 1.7676e-03, 2.8006e-03, 9.9612e-04, 5.7614e-04, 1.6576e-03,
        2.7957e-03, 1.8284e-03, 4.1840e-03, 1.8249e-02, 1.2379e-02, 2.9781e-03,
        1.2017e-03, 1.8284e-03, 1.0535e-03, 3.6602e-04, 5.7860e-04, 1.5955e-03,
        1.0648e-03, 2.9633e-04, 1.4364e-

In [52]:
# Amplitude = sqrt(probability) = exp(log_prob / 2). Halving in log space
# avoids the underflow to zero that exp(log_probs) suffers for very negative
# log-probabilities before the square root is taken.
amps = torch.exp(log_probs / 2)

In [53]:
amps

tensor([0.3597, 0.1776, 0.1113, 0.1351, 0.1030, 0.0598, 0.0739, 0.1188, 0.1010,
        0.0526, 0.0357, 0.0492, 0.0643, 0.0436, 0.0621, 0.1144, 0.1010, 0.0511,
        0.0326, 0.0420, 0.0325, 0.0204, 0.0266, 0.0466, 0.0623, 0.0350, 0.0254,
        0.0397, 0.0529, 0.0399, 0.0598, 0.1188, 0.1030, 0.0515, 0.0325, 0.0407,
        0.0307, 0.0185, 0.0234, 0.0397, 0.0325, 0.0176, 0.0124, 0.0184, 0.0236,
        0.0172, 0.0251, 0.0492, 0.0643, 0.0337, 0.0222, 0.0310, 0.0236, 0.0162,
        0.0221, 0.0420, 0.0529, 0.0316, 0.0240, 0.0407, 0.0529, 0.0428, 0.0647,
        0.1351, 0.1113, 0.0546, 0.0347, 0.0428, 0.0325, 0.0191, 0.0241, 0.0399,
        0.0326, 0.0172, 0.0120, 0.0172, 0.0222, 0.0157, 0.0226, 0.0436, 0.0357,
        0.0183, 0.0120, 0.0162, 0.0124, 0.0082, 0.0110, 0.0204, 0.0254, 0.0148,
        0.0112, 0.0185, 0.0240, 0.0191, 0.0288, 0.0598, 0.0739, 0.0370, 0.0241,
        0.0316, 0.0234, 0.0148, 0.0194, 0.0350, 0.0266, 0.0149, 0.0110, 0.0176,
        0.0221, 0.0172, 0.0256, 0.0526, 

In [36]:
log_phases

tensor([-0.1827,  6.1037,  6.1033, -0.1785,  6.1047, -0.1762, -0.1782,  6.1049,
         6.1048, -0.1760, -0.1747,  6.1071, -0.1756,  6.1085,  6.1066, -0.1782,
         6.1048, -0.1746, -0.1744,  6.1089, -0.1736,  6.1124,  6.1089, -0.1762,
        -0.1755,  6.1096,  6.1100, -0.1741,  6.1059, -0.1741, -0.1751,  6.1049,
         6.1047, -0.1735, -0.1744,  6.1099, -0.1734,  6.1141,  6.1101, -0.1741,
        -0.1736,  6.1138,  6.1133, -0.1697,  6.1107, -0.1694, -0.1726,  6.1071,
        -0.1756,  6.1115,  6.1102, -0.1744,  6.1107, -0.1700, -0.1714,  6.1089,
         6.1059, -0.1745, -0.1709,  6.1099, -0.1732,  6.1102,  6.1096, -0.1785,
         6.1033, -0.1741, -0.1748,  6.1102, -0.1744,  6.1128,  6.1102, -0.1741,
        -0.1744,  6.1131,  6.1147, -0.1694,  6.1102, -0.1707, -0.1711,  6.1085,
        -0.1747,  6.1136,  6.1147, -0.1700,  6.1133, -0.1681, -0.1673,  6.1124,
         6.1100, -0.1698, -0.1656,  6.1141, -0.1709,  6.1128,  6.1132, -0.1762,
        -0.1782,  6.1108,  6.1102, -0.17

In [37]:
# Complex phase factors built from the unscaled log_phases.
# NOTE(review): entries of log_phases that differ by about 2*pi (e.g. -0.18 and
# 6.10 in the output above) map to the same phase factor here, so every
# component ends up with the same sign. The last cell of the notebook halves
# the phase as well and obtains alternating signs — confirm which convention
# compute_psi's outputs follow.
phases = torch.exp(1j * log_phases)

In [38]:
phases

tensor([0.9834-0.1817j, 0.9839-0.1785j, 0.9839-0.1789j, 0.9841-0.1775j,
        0.9841-0.1775j, 0.9845-0.1753j, 0.9842-0.1772j, 0.9842-0.1773j,
        0.9841-0.1774j, 0.9846-0.1751j, 0.9848-0.1738j, 0.9845-0.1752j,
        0.9846-0.1747j, 0.9848-0.1738j, 0.9844-0.1757j, 0.9842-0.1772j,
        0.9841-0.1774j, 0.9848-0.1737j, 0.9848-0.1735j, 0.9849-0.1734j,
        0.9850-0.1727j, 0.9855-0.1699j, 0.9848-0.1735j, 0.9845-0.1753j,
        0.9846-0.1746j, 0.9850-0.1727j, 0.9850-0.1724j, 0.9849-0.1732j,
        0.9843-0.1764j, 0.9849-0.1733j, 0.9847-0.1742j, 0.9842-0.1773j,
        0.9841-0.1775j, 0.9850-0.1727j, 0.9848-0.1736j, 0.9850-0.1724j,
        0.9850-0.1725j, 0.9857-0.1682j, 0.9851-0.1722j, 0.9849-0.1732j,
        0.9850-0.1727j, 0.9857-0.1686j, 0.9856-0.1691j, 0.9856-0.1689j,
        0.9852-0.1716j, 0.9857-0.1686j, 0.9851-0.1717j, 0.9845-0.1752j,
        0.9846-0.1747j, 0.9853-0.1709j, 0.9851-0.1722j, 0.9848-0.1735j,
        0.9852-0.1716j, 0.9856-0.1691j, 0.9853-0.1706j, 0.9849-0

In [39]:
predicted_psi = amps * phases

In [40]:
predicted_psi

tensor([0.3537-0.0654j, 0.1748-0.0317j, 0.1095-0.0199j, 0.1329-0.0240j,
        0.1013-0.0183j, 0.0589-0.0105j, 0.0728-0.0131j, 0.1170-0.0211j,
        0.0994-0.0179j, 0.0518-0.0092j, 0.0352-0.0062j, 0.0484-0.0086j,
        0.0633-0.0112j, 0.0429-0.0076j, 0.0611-0.0109j, 0.1126-0.0203j,
        0.0994-0.0179j, 0.0503-0.0089j, 0.0321-0.0057j, 0.0414-0.0073j,
        0.0320-0.0056j, 0.0201-0.0035j, 0.0262-0.0046j, 0.0459-0.0082j,
        0.0614-0.0109j, 0.0345-0.0060j, 0.0251-0.0044j, 0.0391-0.0069j,
        0.0521-0.0093j, 0.0393-0.0069j, 0.0589-0.0104j, 0.1170-0.0211j,
        0.1013-0.0183j, 0.0507-0.0089j, 0.0320-0.0056j, 0.0401-0.0070j,
        0.0302-0.0053j, 0.0183-0.0031j, 0.0231-0.0040j, 0.0391-0.0069j,
        0.0320-0.0056j, 0.0174-0.0030j, 0.0122-0.0021j, 0.0181-0.0031j,
        0.0233-0.0041j, 0.0170-0.0029j, 0.0248-0.0043j, 0.0484-0.0086j,
        0.0633-0.0112j, 0.0332-0.0058j, 0.0219-0.0038j, 0.0305-0.0054j,
        0.0233-0.0041j, 0.0160-0.0027j, 0.0218-0.0038j, 0.0414-0

In [41]:
amp_recovered = torch.abs(predicted_psi)

In [42]:
amp_recovered

tensor([0.3597, 0.1776, 0.1113, 0.1351, 0.1030, 0.0598, 0.0739, 0.1188, 0.1010,
        0.0526, 0.0357, 0.0492, 0.0643, 0.0436, 0.0621, 0.1144, 0.1010, 0.0511,
        0.0326, 0.0420, 0.0325, 0.0204, 0.0266, 0.0466, 0.0623, 0.0350, 0.0254,
        0.0397, 0.0529, 0.0399, 0.0598, 0.1188, 0.1030, 0.0515, 0.0325, 0.0407,
        0.0307, 0.0185, 0.0234, 0.0397, 0.0325, 0.0176, 0.0124, 0.0184, 0.0236,
        0.0172, 0.0251, 0.0492, 0.0643, 0.0337, 0.0222, 0.0310, 0.0236, 0.0162,
        0.0221, 0.0420, 0.0529, 0.0316, 0.0240, 0.0407, 0.0529, 0.0428, 0.0647,
        0.1351, 0.1113, 0.0546, 0.0347, 0.0428, 0.0325, 0.0191, 0.0241, 0.0399,
        0.0326, 0.0172, 0.0120, 0.0172, 0.0222, 0.0157, 0.0226, 0.0436, 0.0357,
        0.0183, 0.0120, 0.0162, 0.0124, 0.0082, 0.0110, 0.0204, 0.0254, 0.0148,
        0.0112, 0.0185, 0.0240, 0.0191, 0.0288, 0.0598, 0.0739, 0.0370, 0.0241,
        0.0316, 0.0234, 0.0148, 0.0194, 0.0350, 0.0266, 0.0149, 0.0110, 0.0176,
        0.0221, 0.0172, 0.0256, 0.0526, 

In [43]:
phases_recovered = torch.angle(predicted_psi)

In [44]:
phases_recovered

tensor([-0.1827, -0.1795, -0.1798, -0.1785, -0.1785, -0.1762, -0.1782, -0.1783,
        -0.1784, -0.1760, -0.1747, -0.1761, -0.1756, -0.1746, -0.1766, -0.1782,
        -0.1784, -0.1746, -0.1744, -0.1743, -0.1736, -0.1708, -0.1743, -0.1762,
        -0.1755, -0.1736, -0.1732, -0.1741, -0.1773, -0.1741, -0.1751, -0.1783,
        -0.1785, -0.1735, -0.1744, -0.1732, -0.1734, -0.1690, -0.1731, -0.1741,
        -0.1736, -0.1694, -0.1699, -0.1697, -0.1725, -0.1694, -0.1726, -0.1761,
        -0.1756, -0.1717, -0.1730, -0.1744, -0.1725, -0.1700, -0.1714, -0.1743,
        -0.1773, -0.1745, -0.1709, -0.1732, -0.1732, -0.1730, -0.1736, -0.1785,
        -0.1798, -0.1741, -0.1748, -0.1730, -0.1744, -0.1704, -0.1730, -0.1741,
        -0.1744, -0.1700, -0.1685, -0.1694, -0.1730, -0.1707, -0.1711, -0.1746,
        -0.1747, -0.1696, -0.1685, -0.1700, -0.1699, -0.1681, -0.1673, -0.1708,
        -0.1732, -0.1698, -0.1656, -0.1690, -0.1709, -0.1704, -0.1700, -0.1762,
        -0.1782, -0.1724, -0.1730, -0.17

In [54]:
log_amps = log_probs / 2

In [60]:
# Alternative reconstruction: halve both the log-probability and the phase
# before exponentiating, i.e. exp(log_probs / 2) * exp(1j * log_phases / 2).
# In the output the real parts alternate in sign, matching (up to an overall
# sign) the pattern of the eigsh vector printed earlier.
((log_probs + 1j * log_phases) / 2).exp()

tensor([ 0.3582-0.0328j, -0.1769+0.0159j, -0.1108+0.0100j,  0.1346-0.0120j,
        -0.1026+0.0092j,  0.0596-0.0053j,  0.0736-0.0066j, -0.1184+0.0106j,
        -0.1006+0.0090j,  0.0524-0.0046j,  0.0356-0.0031j, -0.0490+0.0043j,
         0.0640-0.0056j, -0.0434+0.0038j, -0.0619+0.0055j,  0.1140-0.0102j,
        -0.1006+0.0090j,  0.0509-0.0045j,  0.0325-0.0028j, -0.0419+0.0037j,
         0.0324-0.0028j, -0.0203+0.0017j, -0.0265+0.0023j,  0.0464-0.0041j,
         0.0621-0.0055j, -0.0349+0.0030j, -0.0253+0.0022j,  0.0395-0.0034j,
        -0.0527+0.0047j,  0.0398-0.0035j,  0.0596-0.0052j, -0.1184+0.0106j,
        -0.1026+0.0092j,  0.0513-0.0045j,  0.0323-0.0028j, -0.0406+0.0035j,
         0.0306-0.0027j, -0.0185+0.0016j, -0.0233+0.0020j,  0.0395-0.0034j,
         0.0324-0.0028j, -0.0176+0.0015j, -0.0123+0.0011j,  0.0183-0.0016j,
        -0.0235+0.0020j,  0.0172-0.0015j,  0.0250-0.0022j, -0.0490+0.0043j,
         0.0640-0.0056j, -0.0336+0.0029j, -0.0221+0.0019j,  0.0308-0.0027j,
        -0.0