In [5]:
import torch 
import torch.nn as nn 
import torch.nn.functional as f
import numpy as np 
from torch.utils.data import TensorDataset, DataLoader
import matplotlib.pyplot as plt 

In [6]:
data = np.load('2_qubit_crit_data.npz')

In [7]:
data.files

['ground_state', 'fields']

In [8]:
data['ground_state'].shape

(10000, 4)

In [9]:
training_data_x = data['ground_state'][:80]
training_data_y = data['fields'][:80][:,[0,2,4]]
val_data_x = data['ground_state'][80:]
val_data_y = data['fields'][80:][:,[0,2,4]]

In [10]:
print(val_data_x.shape)
print(val_data_y)

(9920, 4)
[[1.         0.02592159 0.01      ]
 [1.         0.02612061 0.01      ]
 [1.         0.02631963 0.01      ]
 ...
 [1.         1.99960196 0.01      ]
 [1.         1.99980098 0.01      ]
 [1.         2.         0.01      ]]


In [11]:
training_data_x.shape

(80, 4)

In [12]:
training_data = TensorDataset(torch.Tensor(training_data_x), torch.Tensor(training_data_y))
validation_data = TensorDataset(torch.Tensor(val_data_x), torch.Tensor(val_data_y))

In [13]:
training_loader = DataLoader(training_data, batch_size = 16)
val_loader = DataLoader(validation_data, batch_size = 36)

In [14]:
for i,(wf, fields,) in enumerate(training_loader):
    print(wf.shape)
    print(fields.shape)

torch.Size([16, 4])
torch.Size([16, 3])
torch.Size([16, 4])
torch.Size([16, 3])
torch.Size([16, 4])
torch.Size([16, 3])
torch.Size([16, 4])
torch.Size([16, 3])
torch.Size([16, 4])
torch.Size([16, 3])


In [15]:
def seq_gen(num_q):
    if num_q == 2:
        return ['00','01', '10','11']
    else:
        temp = []
        
        smaller_vals = seq_gen(num_q-1)
        for i in ['0','1']:
            for each in smaller_vals:
                temp.append(i+each)
        return temp 

In [16]:
class MPS_autoencoder(nn.Module):
    def __init__(self):
        super(MPS_autoencoder, self).__init__()
        
        self.encoder = nn.Sequential(nn.Linear(3,16),
                                nn.ReLU(),
                                nn.Linear(16,8)
                               )
    def encode(self, x):
        encoded = self.encoder(x)
        temp = encoded.view(-1, 2,2,2)
#         print(temp.shape)
        spin_up, spin_down = torch.split(temp, 1, dim = 1)

        return spin_up, spin_down
    def decode(self, spin_up, spin_down):
        # spin_up and spin_down (1,2,2)
#         mps = {'0':spin_up, '1':spin_down}
        mps = [spin_up, spin_down]        
        coeffs = [] 
        
#         states = seq_gen(2)
#         for state in states:
#             mat = mps[state[0]]
            
#             for site in state[1:]:
#                 mat = torch.matmul(mat, mps[site])
#             diagonal = torch.diagonal(mat, dim1=-1,dim2=-2)
#             coeffs.append(torch.sum(diagonal, dim = -1, keepdim = True))
        
            
        for i in range(2):
            for j in range(2):
                mat = torch.matmul(mps[i], mps[j])
                diagonal = torch.diagonal(mat , dim1=-1, dim2=-2)
                coeffs.append(torch.sum(diagonal, dim = -1, keepdim = True))
        
        c_i = coeffs[0]
        for i in coeffs[1:]:
            c_i = torch.cat((c_i, i), dim = 1)
        return c_i.squeeze()
    
    def forward(self, x):
        spin_up, spin_down = self.encode(x)
        gs = self.decode(spin_up, spin_down)
#         print(torch.norm(gs, dim = 1).view(-1,1).shape)
        gs = gs / torch.norm(gs, dim = 1).view(-1,1)
        return gs

In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MPS_autoencoder()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_func = nn.MSELoss()
for epoch in range(1001):
    total = 0
    for i,(wf,fields,) in enumerate(training_loader):
        gs = model(fields)
        
        loss = loss_func(gs, wf)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
    if epoch % 10 == 0:
        print("Epoch {} :".format(epoch + 1), total / len(training_loader))

Epoch 1 : 0.11380127472802996
Epoch 11 : 5.98325816099532e-05
Epoch 21 : 1.0366542574047343e-05
Epoch 31 : 6.1916120102978315e-06
Epoch 41 : 3.729656782525126e-06
Epoch 51 : 2.413095609199445e-06
Epoch 61 : 1.8240137194425187e-06
Epoch 71 : 1.6271635260523e-06
Epoch 81 : 1.5811931930898026e-06
Epoch 91 : 1.574293455064435e-06
Epoch 101 : 1.5735553674289803e-06
Epoch 111 : 1.5727395151543532e-06
Epoch 121 : 1.5713356560809188e-06
Epoch 131 : 1.5694105940156077e-06
Epoch 141 : 1.5670850984861318e-06
Epoch 151 : 1.5643758317196444e-06
Epoch 161 : 1.5613520162105487e-06
Epoch 171 : 1.5580056270891874e-06
Epoch 181 : 1.554354406607672e-06
Epoch 191 : 1.5503786855219913e-06
Epoch 201 : 1.5461314532672077e-06
Epoch 211 : 1.5415692772080546e-06
Epoch 221 : 1.5367304570190755e-06
Epoch 231 : 1.5316173858082038e-06
Epoch 241 : 1.526255982753355e-06
Epoch 251 : 1.520599002446943e-06
Epoch 261 : 1.5147014892136212e-06
Epoch 271 : 1.5085315808960332e-06
Epoch 281 : 1.5021360070477386e-06
Epoch 291 

In [18]:
total = 0
with torch.no_grad():
    for i,(wf,fields,) in enumerate(val_loader):
        gs = model(fields)
        loss = loss_func(gs, wf)
        total += loss.item()
    
print(total)

18.07402462364189


In [19]:
with torch.no_grad():
    for i,(wf, fields,) in enumerate(val_loader):
        gs = model(fields)
        for j in range(10):
            print("Fields:\t", fields[j])
            print("Wavefunction:\t", wf[j].data) 
            print("Reconstructed:\t", gs[j])
            print("*"*80)
        break

Fields:	 tensor([1.0000, 0.0259, 0.0100])
Wavefunction:	 tensor([0.9999, 0.0065, 0.0065, 0.0084])
Reconstructed:	 tensor([0.9999, 0.0070, 0.0070, 0.0043])
********************************************************************************
Fields:	 tensor([1.0000, 0.0261, 0.0100])
Wavefunction:	 tensor([0.9999, 0.0066, 0.0066, 0.0085])
Reconstructed:	 tensor([0.9999, 0.0070, 0.0070, 0.0043])
********************************************************************************
Fields:	 tensor([1.0000, 0.0263, 0.0100])
Wavefunction:	 tensor([0.9999, 0.0066, 0.0066, 0.0086])
Reconstructed:	 tensor([0.9999, 0.0071, 0.0071, 0.0043])
********************************************************************************
Fields:	 tensor([1.0000, 0.0265, 0.0100])
Wavefunction:	 tensor([0.9999, 0.0067, 0.0067, 0.0087])
Reconstructed:	 tensor([0.9999, 0.0071, 0.0071, 0.0043])
********************************************************************************
Fields:	 tensor([1.0000, 0.0267, 0.0100])
Wavefunction:	