In [1]:
import torch 
import torch.nn as nn 
import torch.nn.functional as f
import numpy as np 
from torch.utils.data import TensorDataset, DataLoader
import matplotlib.pyplot as plt 

In [2]:
data = np.load('6_qubit_crit_data.npz')

In [3]:
data.files

['ground_state', 'fields']

In [4]:
data['fields'].shape

(1000, 18)

In [5]:
training_data_x = data['ground_state'][:800]
training_data_y = data['fields'][:800][:,[0,6,12]]
val_data_x = data['ground_state'][800:]
val_data_y = data['fields'][800:][:,[0,6,12]]

In [6]:
training_data = TensorDataset(torch.Tensor(training_data_x), torch.Tensor(training_data_y))
validation_data = TensorDataset(torch.Tensor(val_data_x), torch.Tensor(val_data_y))

In [7]:
training_loader = DataLoader(training_data, batch_size = 16)
val_loader = DataLoader(validation_data, batch_size = 36)

In [8]:
def seq_gen(num_q):
    if num_q == 2:
        return ['00','01', '10','11']
    else:
        temp = []
        
        smaller_vals = seq_gen(num_q-1)
        for i in ['0','1']:
            for each in smaller_vals:
                temp.append(i+each)
        return temp   

In [15]:
class MPS_autoencoder(nn.Module):
    def __init__(self):
        super(MPS_autoencoder, self).__init__()
        
        self.num_qubits = 6
        
        self.mps_size = 6
        self.encoder = nn.Sequential(nn.Linear(3,4),
                                nn.ReLU(),
                                nn.Linear(4,2 * (self.mps_size ** 2))
                               )
    def encode(self, x):
        encoded = self.encoder(x)
        temp = encoded.view(-1, 2,self.mps_size,self.mps_size)
        spin_up, spin_down = torch.split(temp, 1, dim = 1)
        return spin_up, spin_down
    
    def decode(self, spin_up, spin_down):
        # spin_up and spin_down (1,mps.size,mps.size)
        mps = {'0':spin_up, '1':spin_down}
                
        coeffs = [] 
        
        states = seq_gen(self.num_qubits)
        for state in states:
            mat = mps[state[0]]
            
            for site in state[1:]:
                mat = torch.matmul(mat, mps[site])
            diagonal = torch.diagonal(mat, dim1=-1,dim2=-2)
            coeffs.append(torch.sum(diagonal, dim = -1, keepdim = True))
                
        
        c_i = coeffs[0]
        for i in coeffs[1:]:
            c_i = torch.cat((c_i, i), dim = 1)
        return c_i.squeeze()
    
    def forward(self, x):
        spin_up, spin_down = self.encode(x)
        gs = self.decode(spin_up, spin_down)
        gs = gs / torch.norm(gs, dim = 1).view(-1,1)
        return gs

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MPS_autoencoder()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_func = nn.MSELoss()
for epoch in range(1001):
    total = 0
    for i,(wf,fields,) in enumerate(training_loader):
        gs = model(fields)
        
        loss = loss_func(gs, wf)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
    if epoch % 10 == 0:
        print("Epoch {} :".format(epoch + 1), total / len(training_loader))

Epoch 1 : 0.004971677256980911
Epoch 11 : 0.00015158625419644522
Epoch 21 : 4.486560447730881e-05
Epoch 31 : 2.448362934501347e-05
Epoch 41 : 8.602374282418168e-05
Epoch 51 : 0.00015825390945792606
Epoch 61 : 0.00010276851475737202
Epoch 71 : 8.603214224876865e-05
Epoch 81 : 6.342118163956911e-05
Epoch 91 : 4.9318882397528795e-05
Epoch 101 : 2.4785746454654145e-05
Epoch 111 : 3.0135906674786382e-05
Epoch 121 : 2.738809110951479e-05
Epoch 131 : 3.001431813572708e-05
Epoch 141 : 2.750064187011958e-05
Epoch 151 : 1.6504495993103773e-05
Epoch 161 : 1.441741747839842e-05
Epoch 171 : 1.2418249621077849e-05
Epoch 181 : 1.161982826943131e-05
Epoch 191 : 9.9842692020502e-05
Epoch 201 : 1.5409307461595745e-05
Epoch 211 : 1.29560444295862e-05
Epoch 221 : 1.313040651780284e-05
Epoch 231 : 9.15009176651438e-06
Epoch 241 : 5.741448690969264e-05
Epoch 251 : 1.400887959391639e-05
Epoch 261 : 8.618079163795756e-06
Epoch 271 : 2.0548521051750866e-05
Epoch 281 : 1.182142537402342e-05
Epoch 291 : 1.250019

In [17]:
total = 0
with torch.no_grad():
    for i,(wf,fields,) in enumerate(val_loader):
        gs = model(fields)
        loss = loss_func(gs, wf)
        total += loss.item()
    
print(total)

0.048823071643710136


In [18]:
with torch.no_grad():
    for i,(wf, fields,) in enumerate(val_loader):
        gs = model(fields)
        for j in range(10):
            print("Fields:\t", fields[j])
            print("Wavefunction:\t", wf[j].data) 
            print("Reconstructed:\t", gs[j])
            print("*"*80)
        break

Fields:	 tensor([1.0000, 1.6036, 0.0100])
Wavefunction:	 tensor([0.3390, 0.1600, 0.1600, 0.1361, 0.1600, 0.0823, 0.1361, 0.1300, 0.1600,
        0.0787, 0.0823, 0.0743, 0.1361, 0.0743, 0.1300, 0.1317, 0.1600, 0.0823,
        0.0787, 0.0743, 0.0823, 0.0468, 0.0743, 0.0794, 0.1361, 0.0743, 0.0743,
        0.0758, 0.1300, 0.0794, 0.1317, 0.1497, 0.1600, 0.1361, 0.0823, 0.1300,
        0.0787, 0.0743, 0.0743, 0.1317, 0.0823, 0.0743, 0.0468, 0.0794, 0.0743,
        0.0758, 0.0794, 0.1497, 0.1361, 0.1300, 0.0743, 0.1317, 0.0743, 0.0794,
        0.0758, 0.1497, 0.1300, 0.1317, 0.0794, 0.1497, 0.1317, 0.1497, 0.1497,
        0.3090])
Reconstructed:	 tensor([0.8389, 0.1462, 0.1462, 0.0760, 0.1462, 0.0412, 0.0760, 0.0574, 0.1462,
        0.0400, 0.0412, 0.0284, 0.0760, 0.0289, 0.0574, 0.0521, 0.1462, 0.0412,
        0.0400, 0.0289, 0.0412, 0.0161, 0.0284, 0.0283, 0.0760, 0.0284, 0.0289,
        0.0264, 0.0574, 0.0283, 0.0521, 0.0672, 0.1462, 0.0760, 0.0412, 0.0574,
        0.0400, 0.0284, 0.0289