In [1]:
import torch 
import torch.nn as nn 
import torch.nn.functional as f
import numpy as np 
from torch.utils.data import TensorDataset, DataLoader
import matplotlib.pyplot as plt 

In [2]:
data = np.load('2_qubit_crit_data.npz')

In [3]:
data.files

['ground_state', 'fields']

In [4]:
data['ground_state'].shape

(1000, 4)

In [5]:
training_data_x = data['ground_state'][:80]
training_data_y = data['fields'][:80][:,[0,2,4]]
val_data_x = data['ground_state'][80:]
val_data_y = data['fields'][80:][:,[0,2,4]]

In [6]:
print(val_data_x.shape)
print(val_data_y)

(920, 4)
[[1.         0.16935936 0.01      ]
 [1.         0.17135135 0.01      ]
 [1.         0.17334334 0.01      ]
 ...
 [1.         1.99601602 0.01      ]
 [1.         1.99800801 0.01      ]
 [1.         2.         0.01      ]]


In [7]:
training_data_x.shape

(80, 4)

In [8]:
training_data = TensorDataset(torch.Tensor(training_data_x), torch.Tensor(training_data_y))
validation_data = TensorDataset(torch.Tensor(val_data_x), torch.Tensor(val_data_y))

In [9]:
training_loader = DataLoader(training_data, batch_size = 16)
val_loader = DataLoader(validation_data, batch_size = 36)

In [10]:
for i,(wf, fields,) in enumerate(training_loader):
    print(wf.shape)
    print(fields.shape)

torch.Size([16, 4])
torch.Size([16, 3])
torch.Size([16, 4])
torch.Size([16, 3])
torch.Size([16, 4])
torch.Size([16, 3])
torch.Size([16, 4])
torch.Size([16, 3])
torch.Size([16, 4])
torch.Size([16, 3])


In [11]:
def seq_gen(num_q):
    if num_q == 2:
        return ['00','01', '10','11']
    else:
        temp = []
        
        smaller_vals = seq_gen(num_q-1)
        for i in ['0','1']:
            for each in smaller_vals:
                temp.append(i+each)
        return temp 

In [12]:
class MPS_autoencoder(nn.Module):
    def __init__(self):
        super(MPS_autoencoder, self).__init__()
        
        self.encoder = nn.Sequential(nn.Linear(3,16),
                                nn.ReLU(),
                                nn.Linear(16,8)
                               )
    def encode(self, x):
        encoded = self.encoder(x)
        temp = encoded.view(-1, 2,2,2)
#         print(temp.shape)
        spin_up, spin_down = torch.split(temp, 1, dim = 1)

        return spin_up, spin_down
    def decode(self, spin_up, spin_down):
        # spin_up and spin_down (1,2,2)
#         mps = {'0':spin_up, '1':spin_down}
        mps = [spin_up, spin_down]        
        coeffs = [] 
        
#         states = seq_gen(2)
#         for state in states:
#             mat = mps[state[0]]
            
#             for site in state[1:]:
#                 mat = torch.matmul(mat, mps[site])
#             diagonal = torch.diagonal(mat, dim1=-1,dim2=-2)
#             coeffs.append(torch.sum(diagonal, dim = -1, keepdim = True))
        
            
        for i in range(2):
            for j in range(2):
                mat = torch.matmul(mps[i], mps[j])
                diagonal = torch.diagonal(mat , dim1=-1, dim2=-2)
                coeffs.append(torch.sum(diagonal, dim = -1, keepdim = True))
        
        c_i = coeffs[0]
        for i in coeffs[1:]:
            c_i = torch.cat((c_i, i), dim = 1)
        return c_i.squeeze()
    
    def forward(self, x):
        spin_up, spin_down = self.encode(x)
        gs = self.decode(spin_up, spin_down)
#         print(torch.norm(gs, dim = 1).view(-1,1).shape)
        gs = gs / torch.norm(gs, dim = 1).view(-1,1)
        return gs

In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MPS_autoencoder()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_func = nn.MSELoss()
for epoch in range(1001):
    total = 0
    for i,(wf,fields,) in enumerate(training_loader):
        gs = model(fields)
        
        loss = loss_func(gs, wf)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
    if epoch % 10 == 0:
        print("Epoch {} :".format(epoch + 1), total / len(training_loader))

  return torch._C._cuda_getDeviceCount() > 0


Epoch 1 : 0.1339497348293662
Epoch 11 : 0.002605735615361482
Epoch 21 : 0.002357731034862809
Epoch 31 : 0.0021125629195012153
Epoch 41 : 0.0018221224803710357
Epoch 51 : 0.0014371862285770475
Epoch 61 : 0.0007816101337084546
Epoch 71 : 8.905008544388693e-05
Epoch 81 : 6.583329723071074e-05
Epoch 91 : 6.324484038486844e-05
Epoch 101 : 6.0833810857729985e-05
Epoch 111 : 5.792274205305148e-05
Epoch 121 : 5.5666594926151446e-05
Epoch 131 : 5.214515022089472e-05
Epoch 141 : 4.966433552908711e-05
Epoch 151 : 4.6704159922228426e-05
Epoch 161 : 4.419503120516311e-05
Epoch 171 : 4.1661814520921325e-05
Epoch 181 : 3.973133607360069e-05
Epoch 191 : 4.0405213985650336e-05
Epoch 201 : 5.31418516402482e-05
Epoch 211 : 0.00013129858089087066
Epoch 221 : 9.637874100008049e-05
Epoch 231 : 2.6255233933625277e-05
Epoch 241 : 2.7598055748967453e-05
Epoch 251 : 4.521108348853886e-05
Epoch 261 : 8.993109440780245e-05
Epoch 271 : 3.803054878517287e-05
Epoch 281 : 3.2115918656927533e-05
Epoch 291 : 4.11255168

In [14]:
total = 0
with torch.no_grad():
    for i,(wf,fields,) in enumerate(val_loader):
        gs = model(fields)
        loss = loss_func(gs, wf)
        total += loss.item()
    
print(total)

0.5532865592977032


In [15]:
with torch.no_grad():
    for i,(wf, fields,) in enumerate(val_loader):
        gs = model(fields)
        for j in range(10):
            print("Fields:\t", fields[j])
            print("Wavefunction:\t", wf[j].data) 
            print("Reconstructed:\t", gs[j])
            print("*"*80)
        break

Fields:	 tensor([1.0000, 0.1694, 0.0100])
Wavefunction:	 tensor([0.9501, 0.0525, 0.0525, 0.3030])
Reconstructed:	 tensor([0.9532, 0.0537, 0.0537, 0.2927])
********************************************************************************
Fields:	 tensor([1.0000, 0.1714, 0.0100])
Wavefunction:	 tensor([0.9483, 0.0533, 0.0533, 0.3082])
Reconstructed:	 tensor([0.9528, 0.0541, 0.0541, 0.2938])
********************************************************************************
Fields:	 tensor([1.0000, 0.1733, 0.0100])
Wavefunction:	 tensor([0.9465, 0.0541, 0.0541, 0.3134])
Reconstructed:	 tensor([0.9524, 0.0546, 0.0546, 0.2949])
********************************************************************************
Fields:	 tensor([1.0000, 0.1753, 0.0100])
Wavefunction:	 tensor([0.9447, 0.0548, 0.0548, 0.3186])
Reconstructed:	 tensor([0.9520, 0.0550, 0.0550, 0.2961])
********************************************************************************
Fields:	 tensor([1.0000, 0.1773, 0.0100])
Wavefunction:	