In [36]:
import pytorch_lightning as pl
import torch
import Operator as op
import utils
import numpy as np
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))

g_dtype = torch.float32

Using cpu device


Setting up the Hamiltonian

In [37]:
lattice_sites = 3
hamiltonian = []
h = -1
'''
for l in range(lattice_sites):
  hamiltonian = Sx(l)* (h) + Sz(l) * Sz((l+1) % lattice_sites) + hamiltonian
print_op_list(hamiltonian)
'''
hamiltonian = op.Sx(2) + ([op.Sz(1)*op.Sy(2)] + hamiltonian)
op.print_op_list(hamiltonian)

adding to sequence
Hamiltonan = Sx_2 + Sz_1 * Sy_2


In [38]:
from torch.utils.data import Dataset, DataLoader
#setting up the datamodule

class spin_data(Dataset):
    def __init__(self, lattice_sites, num_t_values):
        self.spins = utils.get_all_spin_configs(lattice_sites).type(g_dtype)
        self.alpha_arr = torch.rand((num_t_values, 2), dtype=g_dtype)

    def __len__(self):
        return self.alpha_arr.shape[0]

    def __getitem__(self, index):
        return self.spins, self.alpha_arr[index, :].broadcast_to(self.spins.shape[0], self.alpha_arr.shape[1])

    def cuda(self):
        self.spins = self.spins.to(device)
        self.alpha_arr = self.alpha_arr.to(device)



In [39]:
data = spin_data(lattice_sites, 3)
data.cuda()
dataloader = DataLoader(dataset=data, batch_size=2)
data_iter = iter(dataloader)
spins, alpha = next(data_iter)

print(spins.size(), alpha.size())
print(alpha, '\n\n' ,spins)

torch.Size([2, 8, 3]) torch.Size([2, 8, 2])
tensor([[[0.8734, 0.5823],
         [0.8734, 0.5823],
         [0.8734, 0.5823],
         [0.8734, 0.5823],
         [0.8734, 0.5823],
         [0.8734, 0.5823],
         [0.8734, 0.5823],
         [0.8734, 0.5823]],

        [[0.1604, 0.1025],
         [0.1604, 0.1025],
         [0.1604, 0.1025],
         [0.1604, 0.1025],
         [0.1604, 0.1025],
         [0.1604, 0.1025],
         [0.1604, 0.1025],
         [0.1604, 0.1025]]]) 

 tensor([[[-1., -1., -1.],
         [ 1., -1., -1.],
         [-1., -1.,  1.],
         [-1.,  1., -1.],
         [ 1.,  1., -1.],
         [ 1., -1.,  1.],
         [-1.,  1.,  1.],
         [ 1.,  1.,  1.]],

        [[-1., -1., -1.],
         [ 1., -1., -1.],
         [-1., -1.,  1.],
         [-1.,  1., -1.],
         [ 1.,  1., -1.],
         [ 1., -1.,  1.],
         [-1.,  1.,  1.],
         [ 1.,  1.,  1.]]])


In [43]:
import torch
from torch.nn import functional as F
import pytorch_lightning as pl
from torch import nn

class Model(pl.LightningModule):

  def __init__(self, lattice_sites):
    super().__init__()
    self.lattice_net = nn.Sequential(
      nn.Conv1d(1, 8, kernel_size=2, padding=1, padding_mode='circular'),
      nn.ReLU(),
      nn.Conv1d(8, 16, kernel_size=2, padding=1, padding_mode='circular'),
      nn.Flatten(start_dim=1, end_dim=-1)
    )
    
    self.tNN = nn.Sequential(
      nn.Linear(2, 16),
      nn.ReLU(),
      nn.Linear(16, 32),
      nn.ReLU(),
      nn.Linear(32, 64)
    )

    self.psi = nn.Sequential(
      nn.Linear( 64 + 16 * ( lattice_sites + 2 ), 128 ),
      nn.ReLU(),
      nn.Linear(128, 64),
      nn.ReLU(),
      nn.Linear(64,2)
    )

  def forward(self, spins, alpha):
    '''
    Forward function of the neural net to calculate psi. uses same spin config for all alpha values
    Parameters
    __________
    spins: tensor, dtype=int
      tensor of input spins to wave function 
      size = (num_spin_configs, lattice_sites)
    alpha: tensor, dtype=float
      other inputs to hamiltonian e.g. (time, ext_param) 
      size = (num_alpha_configs, num_inputs)

    Returns
    _______
    psi: tensor, dtype=complex
      wave function for a combination of (spins, alpha) 
      size = (num_spin_configs, num_alpha_configs)
    '''

    spin_shape = spins.shape
    alpha_shape = alpha.shape
    #unsqueeze since circular padding needs tensor of dim 3
    spins = torch.flatten(spins, end_dim=-2).unsqueeze(1)
    alpha = torch.flatten(alpha, end_dim=-2)

    lat_out = self.lattice_net(spins)
    t_out = self.tNN(alpha)

    rad_and_phase = (self.psi(torch.cat((lat_out, t_out), dim=1)))
    psi = rad_and_phase[:, 0] * torch.exp( 1.j * rad_and_phase[:, 1] )
    return psi
    
  
  def training_step(self, batch, batch_idx):
    #get psi(s, alpha)

    #get map and s' to the s from batch_idx

    #get psi(s', alpha)

    #calc dt_psi(s, alpha)

    #get mat_els for all alphas

    #get map, s' mat_els for O_loc_init

    #calc loss
    print('nyi')

  def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

model = Model(lattice_sites)
print(model)



print(model(spins, alpha))

#trainer = pl.Trainer(fast_dev_run=True)
#trainer.fit(model, dataloader)


Model(
  (lattice_net): Sequential(
    (0): Conv1d(1, 8, kernel_size=(2,), stride=(1,), padding=(1,), padding_mode=circular)
    (1): ReLU()
    (2): Conv1d(8, 16, kernel_size=(2,), stride=(1,), padding=(1,), padding_mode=circular)
    (3): Flatten(start_dim=1, end_dim=-1)
  )
  (tNN): Sequential(
    (0): Linear(in_features=2, out_features=16, bias=True)
    (1): ReLU()
    (2): Linear(in_features=16, out_features=32, bias=True)
    (3): ReLU()
    (4): Linear(in_features=32, out_features=64, bias=True)
  )
  (psi): Sequential(
    (0): Linear(in_features=144, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=2, bias=True)
  )
)
torch.complex64


In [6]:
from timeit import default_timer as timer

map = utils.get_map(hamiltonian, lattice_sites)
map = map.to(device)
print("map: ", map)

mat_els = utils.get_total_mat_els(hamiltonian, lattice_sites)
mat_els = mat_els.to(device)
print("mat els: ", mat_els)

#1.dim: Batch
#2.dim: lattice sites
#s_config = get_all_spin_configs(lattice_sites)
s_config = torch.tensor([1,1,1]).reshape(1,3)
s_config = (s_config.type(torch.float32)).to(device)
print("spin config: ", s_config)

start = timer()
s_p = utils.get_sp(s_config, map)
psi_s = model(s_config)
psi_sp = model(s_p.reshape(-1, lattice_sites)).reshape(s_p.shape[0], s_p.shape[1])
print(psi_sp.shape, psi_s.shape, s_config.shape)
O_loc = utils.calc_Oloc(psi_sp, mat_els, s_config)
end = timer()

print(f"time to calculate O_loc: {end - start:.2e}") 

map:  tensor([[[ 1,  1, -1],
         [ 1,  1, -1]]], dtype=torch.int8)
mat els:  tensor([[[[ 1.+0.j,  1.+0.j,  1.+0.j],
          [ 1.+0.j,  1.+0.j,  0.+1.j]],

         [[ 1.+0.j,  1.+0.j,  1.+0.j],
          [ 1.+0.j, -1.+0.j, -0.-1.j]]]])
spin config:  tensor([[1., 1., 1.]])


TypeError: forward() missing 1 required positional argument: 'alpha'

In [9]:
spins = torch.full((2,5), 0.5)
spins[1,:] /=4

alpha = torch.full((4,2), 1)
alpha[:, 1] = 2
alpha[:2, : ]*=4
spins_batched = torch.flatten(torch.broadcast_to(spins, (alpha.shape[0], spins.shape[0], spins.shape[1])), end_dim=-2)
alpha_batched = torch.flatten(torch.broadcast_to(alpha.unsqueeze(1), (alpha.shape[0], spins.shape[0], alpha.shape[1])), end_dim=-2)
print(alpha)
print(spins)

print(spins_batched)
print(alpha_batched)
#torch.cat((alpha, spins), dim=2)

tensor([[4, 8],
        [4, 8],
        [1, 2],
        [1, 2]])
tensor([[0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
tensor([[0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
tensor([[4, 8],
        [4, 8],
        [4, 8],
        [4, 8],
        [1, 2],
        [1, 2],
        [1, 2],
        [1, 2]])


In [19]:
a = torch.arange(0,16).reshape(4,2,2)
print(a[:,:,1])
res = a[:,:,0]*torch.exp(1.j*0.5*np.pi*a[:,:,1])
print(res)


tensor([[ 1,  3],
        [ 5,  7],
        [ 9, 11],
        [13, 15]])
tensor([[-0.0000e+00+0.j,  2.3850e-08-2.j],
        [-1.3511e-06+4.j,  3.9816e-06-6.j],
        [-2.8620e-07+8.j,  1.3153e-05-10.j],
        [-8.2495e-06+12.j,  8.3474e-07-14.j]])
