In [1]:
from libraries import lib
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

In [2]:
import numpy.random as npr
import random

class SparseStateVector:
    """
    Container class for dictionary (self.values) with keys of integer states
    and values being complex amplitude of psi
    """
    def __init__(self):
        self.values = {}
        self.normalized = False

    def TFIM_multiply(psi, N, J, Gamma):
        """
        Returns new sparse vector representing H|psi>
        """
        prod = SparseStateVector()
        for state in psi.values:
            jtotal = 0
            for site in range(N - 1):
                jtotal += J if ((state >> site) ^ (state >> site + 1)) & 1 else -J
            jtotal += J if ((state >> (N - 1)) ^ (state >> 0)) & 1 else -J 
            prod.values[state] = jtotal * psi.values[state]
        
        for state in psi.values:
            for site in range(N):
                flipped_state = state ^ (1 << site)
                prod.values[flipped_state] = prod.values.get(flipped_state, 0) - Gamma * psi.values[state]
        return prod

    def inner_product(v1, v2):
        """
        Returns <v1|v2> for two SparseStateVectors
        """
        prod = 0
        for s in v1.values:
            if s in v2.values:
                prod += torch.conj(v1.values[s]) * v2.values[s]
        return prod

    def TFIM_expectation_from_sparse(psi, N, J, Gamma):
        """
        Returns <psi|H|psi>/<psi|psi> for SparseStateVector psi
        """
        # do H|psi> then <psi| (H|psi>)
        hpsi = SparseStateVector.TFIM_multiply(psi, N, J, Gamma)
        exp = SparseStateVector.inner_product(psi, hpsi)
        if not psi.normalized:
            mag2 = SparseStateVector.inner_product(psi, psi)
            return (exp / mag2).real
        return exp.real

    def normalize(self):
        """
        Normalizes values
        """
        mag = sum(abs(self.values[s]) ** 2 for s in self.values) ** 0.5
        for s in self.values:
            self.values[s] = self.values[s] / mag
        self.normalized = True

    def to_prob_distribution(self, N):
        """
        Returns 1D list representing real probability distribution
        """
        if not self.normalized:
            mag2 = sum(abs(self.values[s]) ** 2 for s in self.values)
            return [(abs(self.values.get(s, 0)) ** 2 / mag2).item() for s in range(0, 2 ** N)]
        return [abs(self.values.get(s, 0)) ** 2 for s in range(0, 2 ** N)]
    
    def to_dense_vector(self, N):
        """
        Returns 1D list of dense representation
        """
        if not self.normalized:
            mag = sum(abs(self.values[s]) ** 2 for s in self.values) ** 0.5
            return [(self.values.get(s, 0) / mag).item() for s in range(0, 2 ** N)]
        return [self.values.get(s, 0).item() for s in range(0, 2 ** N)]

class UniformNeuralState(SparseStateVector):
    def __init__(self, N, model, output_to_psi, num_samples):
        """
        Initializes sparse vector values

        Args:
            N (int): number of qubits
            model: torch model representing psi(x), which returns complex amplitude given integer state
            output_to_psi (function): takes in output of model to compute complex amplitude
            num_samples (int): number of unique integer samples to take
            informed (bool): whether to guarantee sample first and last states
        """
        super().__init__()
        self.samples = num_samples
        self.distribution = {}
        self.list = []
        self.nn_output = {}
        def psi(x):
            tens = torch.tensor([lib.generate_state_array(x, N)], dtype = torch.float32)
            nn_output = model(tens)
            return output_to_psi(nn_output)[0], nn_output[0]
        if num_samples >= 2 ** N:
            for state in range(2 ** N):
                self.distribution[state] = 1
                self.list.append(state)
                self.values[state], self.nn_output[state] = psi(state)
        else:
            sampled_states = set()
            for _ in range(num_samples):
                x = random.getrandbits(N)
                while x in sampled_states:
                    x = random.getrandbits(N)
                sampled_states.add(x)
                self.distribution[x] = self.distribution.get(x, 0) + 1
                self.list.append(x)
            for state in sampled_states:
                self.values[state], self.nn_output[state] = psi(state)

class MHNeuralState(SparseStateVector):
    def __init__(self, N, model, output_to_psi, x_func, x0, num_samples, burnin = 0, lag = 0):
        """
        Initializes distribution of samples and vector values

        Args:
            N (int): number of qubits
            model: torch model representing psi(x), which returns complex amplitude given integer state
            output_to_psi (function): takes in output of model to compute complex amplitude
            x_func (function): takes in state x and generates proposal x*
            x0 (int): intger state to begin sampling
            num_samples (int): number of proposal x* generated
            burnin (int): number of samples to throw away before accepting first sample
            lag (int): number of samples to throw away in-between accepting samples
            informed (bool): whether to guarantee sample first and last states
        """
        # uses arbitrary x_func for MH sampling
        super().__init__()
        self.distribution = {}
        self.list = []
        self.nn_output = {}
        self.samples = num_samples
        def psi(x):
            tens = torch.tensor([lib.generate_state_array(x, N)], dtype = torch.float32)
            nn_output = model(tens)
            return output_to_psi(nn_output)[0], nn_output[0]
        num_uniform = burnin + num_samples * (lag + 1)
        rand_uniform = npr.uniform(0, 1, num_uniform)
        index = 0

        x = x0
        psi_val, nn_val = psi(x)
        self.values[x] = psi_val
        self.nn_output[x] = nn_val
        for _ in range(burnin):
            new_x = x_func(x)
            new_psi_val = self.values[new_x] if new_x in self.values else psi(new_x)[0]
            ratio = abs(new_psi_val) ** 2 / abs(psi_val) ** 2
            if ratio > 1 or ratio > rand_uniform[index]:
                x = new_x
                psi_val = new_psi_val
            index += 1
        for _ in range(num_samples):
            for _ in range(lag):
                new_x = x_func(x)
                new_psi_val = self.values[new_x] if new_x in self.values else psi(new_x)[0]
                ratio = abs(new_psi_val) ** 2 / abs(psi_val) ** 2
                if ratio > 1 or ratio > rand_uniform[index]:
                    x = new_x
                    psi_val = new_psi_val
                index += 1
            new_x = x_func(x)
            if new_x in self.values: new_psi_val, new_nn_val = self.values[new_x], self.nn_output[new_x]
            else: new_psi_val, new_nn_val = psi(new_x)
            ratio = abs(new_psi_val) ** 2 / abs(psi_val) ** 2
            if ratio > 1 or ratio > rand_uniform[index]:
                self.distribution[new_x] = self.distribution.get(new_x, 0) + 1
                self.list.append(new_x)
                x = new_x 
                psi_val = new_psi_val 
            else: 
                self.distribution[x] = self.distribution.get(x, 0) + 1
                self.list.append(x)
            self.values[new_x] = new_psi_val
            self.nn_output[new_x] = new_nn_val
            index += 1

In [3]:
def log_amp_phase(nn_output):
    return torch.exp(nn_output[:, 0] + 1.j * nn_output[:, 1])
def bitflip_x(x, N, flips):
    new_x = x
    for _ in range(flips):
        new_x = x ^ (1 << npr.randint(0, N))
    return new_x
# only for nn_output = (log(amp), phase)
def generate_eloc_list(sampled_vector, N, J, Gamma, model):
    nn_output_calcs = {}
    def model_to_output(x):
        if x in sampled_vector.nn_output:
            return sampled_vector.nn_output[x]
        if x in nn_output_calcs:
            return nn_output_calcs[x]
        tens = torch.tensor([lib.generate_state_array(x, N)], dtype = torch.float32)
        output = model(tens)[0]
        nn_output_calcs[x] = output
        return output

    eloc_values = []
    for basis_state in sampled_vector.list:
        eloc = 0
        output = model_to_output(basis_state)
        for adjacency in lib.generate_adjacencies(basis_state, N):
            output_prime = model_to_output(adjacency)
            eloc += lib.calc_H_elem(N, J, Gamma, basis_state, adjacency) * torch.exp(output_prime[0] - output[0] + 1.j * 2 * np.pi * (output_prime[1] - output[1]))
        eloc += lib.calc_H_elem(N, J, Gamma, basis_state, basis_state)
        eloc_values.append(eloc)
    return eloc_values

In [4]:
def set_gradients(sampled_vector, N, J, Gamma, model): # also only for log(amp), phase
    nn_output_calcs = {}
    def model_to_output(x):
        if x in sampled_vector.nn_output:
            return sampled_vector.nn_output[x]
        if x in nn_output_calcs:
            return nn_output_calcs[x]
        tens = torch.tensor([lib.generate_state_array(x, N)], dtype = torch.float32)
        output = model(tens)[0]
        nn_output_calcs[x] = output
        return output
    def output_to_log(x):
        return x[0] + 1.j * x[1]

    params = [n for n in model.parameters()]
    energy = lib.TFIM_expectation_using_locals(sampled_vector, N, J, Gamma, model, log_amp_phase)
    for i in range(len(params)):
        p = params[i]
        grad = torch.zeros(p.shape)
        tot = 0
        for basis_state in sampled_vector.list:
            for adjacency in lib.generate_adjacencies(basis_state, N):
                log_psi = output_to_log(model_to_output(basis_state))
                log_psi_p = output_to_log(model_to_output(adjacency))
                tot += lib.calc_H_elem(N, J, Gamma, basis_state, adjacency) * torch.exp(log_psi_p - log_psi).detach()
        avg = tot / len(sampled_vector.list)

        for basis_state in sampled_vector.distribution:
            for adjacency in lib.generate_adjacencies(basis_state, N):
                log_psi = output_to_log(model_to_output(basis_state))
                log_psi_p = output_to_log(model_to_output(adjacency))
                
                log_psi.conj().real.backward(retain_graph=True)
                grad_re_log = p.grad.clone()
                log_psi.conj().imag.backward(retain_graph=True)
                grad_im_log = p.grad.clone()

                mag_psi = abs(torch.exp(log_psi)).detach()
                H_psi_over_psi = lib.calc_H_elem(N, J, Gamma, basis_state, adjacency) * torch.exp(log_psi_p - log_psi).detach() - avg
                grad += mag_psi ** 2 * (grad_re_log * H_psi_over_psi.real - grad_im_log * H_psi_over_psi.imag)

                # print(log_psi, log_psi_p, grad_re_log, grad_im_log, mag_psi, psi_over_psi)
        p.grad = grad
        # psi_times_psi = SparseStateVector.inner_product(sampled_vector, sampled_vector).real
        #sum(abs(torch.exp(output_to_log(model_to_output(basis_state)))) ** 2 for basis_state in sampled_vector.distribution)
        # torch.log(psi_times_psi).backward(retain_graph=True)
        # grad_log_psi_psi = p.grad.clone()

        # p.grad = grad / psi_times_psi.detach() - grad_log_psi_psi * energy.detach()
    return energy

def update_gradients(model, lr):
    with torch.no_grad():
        for p in model.parameters():
            p -= lr * p.grad 


In [59]:
N = 6; J = 1; Gamma = 0.1

In [60]:
layers = []
layers.append(nn.Linear(N, 32))
for _ in range(2):
    layers.append(nn.Linear(32, 32))
    layers.append(nn.SELU())
layers.append(nn.Linear(32, 2))
model = nn.Sequential(*layers)

In [61]:
params = [n for n in model.parameters()]
print(params[0])
params[0] = torch.zeros(params[0].shape)
print(params[0])

Parameter containing:
tensor([[ 1.2399e-01, -6.9577e-02,  2.0703e-01,  2.7970e-01,  1.7864e-01,
         -2.8881e-01],
        [ 2.4756e-01,  2.9709e-01, -2.8225e-01,  8.7880e-02, -1.7080e-01,
         -3.0567e-02],
        [ 2.0476e-01, -2.4860e-01, -3.1950e-02,  2.6819e-02, -1.0525e-01,
         -3.7605e-01],
        [ 2.9361e-01, -2.0021e-04,  6.7889e-02, -3.1584e-01, -3.4393e-02,
          2.3144e-01],
        [-2.7328e-01,  2.4530e-01, -3.9614e-01, -1.2676e-01, -1.1906e-01,
          2.6199e-02],
        [-3.4742e-01,  2.0432e-01, -2.5909e-01,  3.3132e-01, -3.7207e-01,
          3.4773e-01],
        [-4.0441e-01, -2.6788e-01,  3.3489e-02, -3.3940e-01, -2.5297e-01,
         -3.8853e-01],
        [ 1.2889e-01,  1.3381e-01, -1.7449e-01,  3.8351e-02,  3.6061e-01,
          6.2137e-03],
        [-2.1327e-01, -3.8807e-01,  3.4640e-01,  3.2096e-01,  1.3811e-01,
          3.0136e-01],
        [-1.4434e-01, -5.5235e-02,  5.5037e-02, -2.3440e-01,  1.3363e-01,
          1.4014e-01],
        

In [62]:
num_epochs = 1000
data_rate = 1
num_samples = 128

In [64]:
epochs = []
energy_data = []
optimizer = torch.optim.SGD(model.parameters(), lr = 1e-5)
for epoch in range(num_epochs):
    # mh_state = MHNeuralState(N, model, log_amp_phase, lambda x: bitflip_x(x, N, 1), 2 ** (N - 1), num_samples)
    mh_state = UniformNeuralState(N, model, log_amp_phase, 2 ** N)
    optimizer.zero_grad()
    energy = set_gradients(mh_state, N, J, Gamma, model)
    # energy = lib.TFIM_expectation_using_locals(mh_state, N, J, Gamma, model, log_amp_phase)
    optimizer.step()
    # update_gradients(model, 1e-5)
    if epoch % data_rate == 0:
        energy_data.append(energy.item().real)
        epochs.append(epoch)
    if epoch % 1 == 0:
        print(epoch, energy)
print(energy_data[-1])
print(min(energy_data))

0 tensor(-0.6855, grad_fn=<SelectBackward0>)
1 tensor(-0.8520, grad_fn=<SelectBackward0>)
2 tensor(-0.9844, grad_fn=<SelectBackward0>)
3 tensor(-1.1351, grad_fn=<SelectBackward0>)
4 tensor(-1.3058, grad_fn=<SelectBackward0>)
5 tensor(-1.5043, grad_fn=<SelectBackward0>)
6 tensor(-1.7415, grad_fn=<SelectBackward0>)
7 tensor(-2.0257, grad_fn=<SelectBackward0>)
8 tensor(-2.3663, grad_fn=<SelectBackward0>)
9 tensor(-2.7730, grad_fn=<SelectBackward0>)
10 tensor(-3.2517, grad_fn=<SelectBackward0>)
11 tensor(-3.7958, grad_fn=<SelectBackward0>)
12 tensor(-4.3776, grad_fn=<SelectBackward0>)
13 tensor(-4.9396, grad_fn=<SelectBackward0>)
14 tensor(-5.4076, grad_fn=<SelectBackward0>)
15 tensor(-5.7279, grad_fn=<SelectBackward0>)
16 tensor(-5.9100, grad_fn=<SelectBackward0>)
17 tensor(-5.9898, grad_fn=<SelectBackward0>)
18 tensor(-6.0010, grad_fn=<SelectBackward0>)
19 tensor(nan, grad_fn=<SelectBackward0>)
20 tensor(nan, grad_fn=<SelectBackward0>)


KeyboardInterrupt: 