In [1]:
import torch 
import torch.nn as nn 
import torch.nn.functional as f
import numpy as np 
from torch.utils.data import TensorDataset, DataLoader
import matplotlib.pyplot as plt 

In [40]:
# Path of the archive holding the 2-qubit dataset (relative to the notebook).
DATA_FILE = '2_qubit_data.npz'
data = np.load(DATA_FILE)

In [41]:
# Names of the arrays stored in the .npz archive.
data.files

['ground_state', 'fields']

In [42]:
# (n_samples, 4) — presumably one amplitude per 2-qubit basis state; TODO confirm.
data['ground_state'].shape

(100, 4)

In [120]:
# Split the dataset: the first N_TRAIN rows train, the remaining rows validate.
# Despite the suffixes, `*_x` holds ground-state amplitudes (the regression
# targets) and `*_y` holds field values (the model inputs): the model is
# called as model(fields) and compared against the wavefunction.
# NOTE(review): the split is not shuffled and the rows appear to be ordered on
# a field grid, so the validation rows may be an extrapolation region rather
# than an i.i.d. held-out set. Confirm this is intended.
N_TRAIN = 80
FIELD_COLUMNS = [0, 2, 4]  # columns of data['fields'] fed to the model

# Read each array out of the archive once instead of once per slice.
ground_states = data['ground_state']
selected_fields = data['fields'][:, FIELD_COLUMNS]

training_data_x = ground_states[:N_TRAIN]
training_data_y = selected_fields[:N_TRAIN]
val_data_x = ground_states[N_TRAIN:]
val_data_y = selected_fields[N_TRAIN:]

In [121]:
# Inspect the validation split: the amplitude array shape, then the field values.
print(val_data_x.shape, val_data_y, sep='\n')

(20, 4)
[[1.         1.77777778 0.        ]
 [1.         1.77777778 0.22222222]
 [1.         1.77777778 0.44444444]
 [1.         1.77777778 0.66666667]
 [1.         1.77777778 0.88888889]
 [1.         1.77777778 1.11111111]
 [1.         1.77777778 1.33333333]
 [1.         1.77777778 1.55555556]
 [1.         1.77777778 1.77777778]
 [1.         1.77777778 2.        ]
 [1.         2.         0.        ]
 [1.         2.         0.22222222]
 [1.         2.         0.44444444]
 [1.         2.         0.66666667]
 [1.         2.         0.88888889]
 [1.         2.         1.11111111]
 [1.         2.         1.33333333]
 [1.         2.         1.55555556]
 [1.         2.         1.77777778]
 [1.         2.         2.        ]]


In [122]:
# Shape of the training ground-state amplitudes (the model's targets, despite the `_x` suffix).
training_data_x.shape

(80, 4)

In [123]:
def to_float_tensor(array):
    """Copy a NumPy array into a new float32 tensor.

    Uses torch.tensor with an explicit dtype instead of the legacy
    torch.Tensor(...) constructor, whose dtype depends on the global default.
    """
    return torch.tensor(array, dtype=torch.float32)


# Each sample is a (wavefunction, fields) pair.
training_data = TensorDataset(to_float_tensor(training_data_x), to_float_tensor(training_data_y))
validation_data = TensorDataset(to_float_tensor(val_data_x), to_float_tensor(val_data_y))

In [124]:
# Shuffle the training set every epoch: the rows appear to be ordered on a
# field grid, so unshuffled minibatches would each cover a narrow contiguous
# range of field values. A seeded generator keeps the batch order reproducible.
training_loader = DataLoader(
    training_data,
    batch_size=16,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)
# Evaluate the whole validation set as a single, unshuffled batch.
val_loader = DataLoader(validation_data, batch_size=len(validation_data))

In [125]:
# Sanity check: each training batch yields (batch, 4) amplitudes and (batch, 3) fields.
for i, batch in enumerate(training_loader):
    wf, fields = batch
    print(wf.shape, fields.shape, sep='\n')

torch.Size([16, 4])
torch.Size([16, 3])
torch.Size([16, 4])
torch.Size([16, 3])
torch.Size([16, 4])
torch.Size([16, 3])
torch.Size([16, 4])
torch.Size([16, 3])
torch.Size([16, 4])
torch.Size([16, 3])


In [11]:
def seq_gen(num_q):
    """Return every bitstring of length ``num_q`` in lexicographic order.

    '0' sorts before '1', so ``seq_gen(2) == ['00', '01', '10', '11']`` and
    ``seq_gen(3)`` starts ``['000', '001', ...]``.

    Args:
        num_q: number of qubits (non-negative int). ``seq_gen(0)`` returns
            ``['']`` and ``seq_gen(1)`` returns ``['0', '1']``.

    Returns:
        A list of ``2 ** num_q`` strings.

    Raises:
        ValueError: if ``num_q`` is negative.
    """
    import itertools

    if num_q < 0:
        raise ValueError(f"num_q must be non-negative, got {num_q}")
    # product('01', repeat=n) varies the last position fastest, matching the
    # '0'-prefix-first recursive construction for every n.
    return [''.join(bits) for bits in itertools.product('01', repeat=num_q)]

In [126]:
class MPS_autoencoder(nn.Module):
    """Map field parameters to a normalized two-site MPS wavefunction.

    An MLP encoder turns each field vector into two ``bond_dim x bond_dim``
    matrices, A_0 (``spin_up``) and A_1 (``spin_down``). The decoder closes the
    two-site MPS with a trace, so the amplitude of basis state |ij> is
    Tr(A_i A_j), ordered [00, 01, 10, 11]. ``forward`` L2-normalizes the result.

    Args:
        n_fields: number of input field values per sample (default 3).
        hidden_dim: width of the encoder's hidden layer (default 16).
        bond_dim: MPS bond dimension D (default 2).
    """

    def __init__(self, n_fields=3, hidden_dim=16, bond_dim=2):
        super().__init__()
        self.bond_dim = bond_dim
        # Output = 2 local states x (D x D) matrix entries.
        self.encoder = nn.Sequential(
            nn.Linear(n_fields, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * bond_dim * bond_dim),
        )

    def encode(self, x):
        """Return ``(spin_up, spin_down)``, each of shape (batch, 1, D, D)."""
        encoded = self.encoder(x)
        mats = encoded.view(-1, 2, self.bond_dim, self.bond_dim)
        spin_up, spin_down = torch.split(mats, 1, dim=1)
        return spin_up, spin_down

    def decode(self, spin_up, spin_down):
        """Contract the MPS into its four unnormalized amplitudes.

        Args:
            spin_up: A_0 with shape (..., 1, D, D) or (1, D, D).
            spin_down: A_1 with the same shape as ``spin_up``.

        Returns:
            Tensor of shape (..., 4) holding Tr(A_i A_j) ordered
            [00, 01, 10, 11]. The batch dimension is kept even when the batch
            holds a single sample.
        """
        mps = torch.cat((spin_up, spin_down), dim=-3)  # (..., 2, D, D)
        # Tr(A_s A_t) = sum_{a,b} A_s[a, b] * A_t[b, a]
        coeffs = torch.einsum('...sab,...tba->...st', mps, mps)
        return coeffs.flatten(-2)

    def forward(self, x):
        """Return L2-normalized amplitudes of shape (batch, 4) for fields ``x``."""
        spin_up, spin_down = self.encode(x)
        gs = self.decode(spin_up, spin_down)
        # Clamp avoids a 0/0 NaN if every amplitude happens to be zero.
        norm = torch.norm(gs, dim=-1, keepdim=True).clamp(min=1e-12)
        return gs / norm

In [127]:
# Train the encoder so the normalized MPS amplitudes match the reference
# ground states (MSE loss, Adam).
N_EPOCHS = 1001
LOG_EVERY = 10        # print the mean training loss every LOG_EVERY epochs
LEARNING_RATE = 1e-2

torch.manual_seed(0)  # reproducible weight initialization
# NOTE(review): `device` is never used below — the model and batches stay on
# the CPU. Using it would also require moving inputs in every later cell that
# calls `model`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MPS_autoencoder()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_func = nn.MSELoss()
model.train()
for epoch in range(N_EPOCHS):
    total = 0
    for wf, fields in training_loader:
        gs = model(fields)  # fields are the inputs, wf is the target
        loss = loss_func(gs, wf)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
    if epoch % LOG_EVERY == 0:
        print("Epoch {} :".format(epoch + 1), total / len(training_loader))

Epoch 1 : 0.13737019002437592
Epoch 11 : 0.0034235646831803023
Epoch 21 : 0.0013732730032643304
Epoch 31 : 0.000958823754626792
Epoch 41 : 0.0007515706100093666
Epoch 51 : 0.0006388937406882178
Epoch 61 : 0.0005515202909009531
Epoch 71 : 0.00045577145647257566
Epoch 81 : 0.0003415234656131361
Epoch 91 : 0.00021838934408151544
Epoch 101 : 0.0001290649212023709
Epoch 111 : 0.0008863025926984846
Epoch 121 : 0.0001132978723035194
Epoch 131 : 6.357299716910347e-05
Epoch 141 : 0.00022687121527269482
Epoch 151 : 5.7259979075752196e-05
Epoch 161 : 3.885745281877462e-05
Epoch 171 : 0.00046149680711096155
Epoch 181 : 4.865696082561044e-05
Epoch 191 : 2.8356862640066537e-05
Epoch 201 : 7.767187344143167e-05
Epoch 211 : 5.118376357131638e-05
Epoch 221 : 2.5560339781804942e-05
Epoch 231 : 2.7390567993279546e-05
Epoch 241 : 0.000343644945678534
Epoch 251 : 3.4301742925890724e-05
Epoch 261 : 2.0543496793834494e-05
Epoch 271 : 0.00010101677798957098
Epoch 281 : 4.069356164109195e-05
Epoch 291 : 1.9061

In [128]:
# Mean validation MSE over all samples. Each batch's mean loss is weighted by
# its size, so the result does not depend on val_loader's batch_size.
model.eval()
total = 0
with torch.no_grad():
    for wf, fields in val_loader:
        gs = model(fields)
        loss = loss_func(gs, wf)
        total += loss.item() * len(wf)
val_loss = total / len(val_loader.dataset)

print(val_loss)

9.803861757973209e-05


In [130]:
# Compare target and reconstructed wavefunctions for the first few
# validation samples (first batch only).
N_SHOW = 10
with torch.no_grad():
    for wf, fields in val_loader:
        gs = model(fields)
        # min() guards against a first batch with fewer than N_SHOW rows.
        for j in range(min(N_SHOW, len(fields))):
            print("Fields:\t", fields[j])
            print("Wavefunction:\t", wf[j])
            print("Reconstructed:\t", gs[j])
            print("*" * 80)
        break  # only the first batch is inspected

Fields:	 tensor([1.0000, 1.7778, 0.0000])
Wavefunction:	 tensor([0.6104, 0.3570, 0.3570, 0.6104])
Reconstructed:	 tensor([0.6187, 0.3408, 0.3408, 0.6204])
********************************************************************************
Fields:	 tensor([1.0000, 1.7778, 0.2222])
Wavefunction:	 tensor([0.7272, 0.3484, 0.3484, 0.4779])
Reconstructed:	 tensor([0.7361, 0.3394, 0.3394, 0.4774])
********************************************************************************
Fields:	 tensor([1.0000, 1.7778, 0.4444])
Wavefunction:	 tensor([0.8080, 0.3285, 0.3285, 0.3624])
Reconstructed:	 tensor([0.8209, 0.3213, 0.3213, 0.3459])
********************************************************************************
Fields:	 tensor([1.0000, 1.7778, 0.6667])
Wavefunction:	 tensor([0.8584, 0.3057, 0.3057, 0.2763])
Reconstructed:	 tensor([0.8611, 0.3025, 0.3025, 0.2745])
********************************************************************************
Fields:	 tensor([1.0000, 1.7778, 0.8889])
Wavefunction:	