In [1]:
import numpy as np
import numpy.random as npr
import qutip as qt
import math
import torch
import torch.nn as nn
import libraries.utils as utils

In [2]:
from libraries.NeuralStates import SparseStateVector
class MHNeuralState(SparseStateVector):
    def __init__(self, N, model, output_to_psi, x_func, x0, num_samples, burnin = 0, lag = 0, chains = 1, new = False):
        """
        Initializes distribution of samples and vector values

        Args:
            N (int): number of qubits
            model: torch model representing psi(x), which returns complex amplitude given integer state
            output_to_psi (function): takes in output of model to compute complex amplitude
            x_func (function): takes in state x and generates proposal x*
            x0 (int): integer state to begin sampling
            num_samples (int): number of proposal x* generated
            burnin (int): number of samples to throw away before accepting first sample
            lag (int): number of samples to throw away in-between accepting samples
        """
        # uses arbitrary x_func for MH sampling
        super().__init__()
        self.distribution = {}
        self.samples = num_samples
        self.list = []
        self.nn_output = {}

        self.N = N
        self.model = model
        self.output_to_psi = output_to_psi
        self.x_func = x_func
        self.x0 = x0
        self.num_samples = num_samples
        self.burnin = burnin
        self.lag = lag
        self.chains = chains # note any of these could possibly be modified by a client

        num_uniform = burnin + num_samples * (lag + 1)
        self.rand_uniform = npr.uniform(0, 1, num_uniform)
        self.index = 0
        if chains == 1 and isinstance(x0, int):
            if not new:
                self._init_single_chain()
            else:
                self._init_single_chain_new()
        elif len(x0) == chains:
            self._init_multi_chain()
        else:
            raise Exception('invalid initial values or number of chains')
    
    def _init_single_chain(self):
        def psi(x):
            tens = torch.tensor([utils.generate_state_array(x, self.N)], dtype = torch.float32)
            nn_output = self.model(tens)
            return self.output_to_psi(nn_output)[0], nn_output[0]
        num_uniform = self.burnin + self.num_samples * (self.lag + 1)
        rand_uniform = npr.uniform(0, 1, num_uniform)
        index = 0

        x = self.x0

        psi_val, nn_val = psi(x)
        self.values[x] = psi_val
        self.nn_output[x] = nn_val

        for _ in range(self.burnin):
            new_x = self.x_func(x)
            new_psi_val = self.values[new_x] if new_x in self.values else psi(new_x)[0]
            ratio = abs(new_psi_val) ** 2 / abs(psi_val) ** 2
            if ratio > 1 or ratio > rand_uniform[index]:
                x = new_x
                psi_val = new_psi_val
            index += 1
        for _ in range(self.num_samples):
            for _ in range(self.lag):
                new_x = self.x_func(x)
                new_psi_val = self.values[new_x] if new_x in self.values else psi(new_x)[0]
                ratio = abs(new_psi_val) ** 2 / abs(psi_val) ** 2
                if ratio > 1 or ratio > rand_uniform[index]:
                    x = new_x
                    psi_val = new_psi_val
                index += 1
            new_x = self.x_func(x)
            if new_x in self.values: new_psi_val, new_nn_val = self.values[new_x], self.nn_output[new_x]
            else: new_psi_val, new_nn_val = psi(new_x)
            ratio = abs(new_psi_val) ** 2 / abs(psi_val) ** 2
            if ratio > 1 or ratio > rand_uniform[index]:
                self.distribution[new_x] = self.distribution.get(new_x, 0) + 1
                self.list.append(new_x)
                x = new_x 
                psi_val = new_psi_val 
            else: 
                self.distribution[x] = self.distribution.get(x, 0) + 1
                self.list.append(x)
            self.values[new_x] = new_psi_val
            self.nn_output[new_x] = new_nn_val
            index += 1

    def _init_single_chain_new(self):
        def psi(x):
            tens = torch.tensor([utils.generate_state_array(x, self.N)], dtype = torch.float32)
            nn_output = self.model(tens)
            return self.output_to_psi(nn_output)[0], nn_output[0]

        x = self.x0

        def run_single_sample(modify = False):
            nonlocal x
            nonlocal psi_val
            new_x = self.x_func(x)
            if new_x in self.values: new_psi_val, new_nn_val = self.values[new_x], self.nn_output[new_x]
            else: new_psi_val, new_nn_val = psi(new_x)
            ratio = abs(new_psi_val) ** 2 / abs(psi_val) ** 2
            print(x, self.rand_uniform[self.index])
            if ratio > 1 or ratio > self.rand_uniform[self.index]:
                if modify:
                    self.distribution[new_x] = self.distribution.get(new_x, 0) + 1
                    self.list.append(new_x)
                x = new_x 
                psi_val = new_psi_val 
            elif modify: 
                self.distribution[x] = self.distribution.get(x, 0) + 1
                self.list.append(x)
            self.values[new_x] = new_psi_val
            self.nn_output[new_x] = new_nn_val
            self.index += 1
            

        psi_val, nn_val = psi(x)
        self.values[x] = psi_val
        self.nn_output[x] = nn_val
        print('here', x)
        for _ in range(self.burnin):
            run_single_sample(modify = False)
        for _ in range(self.num_samples):
            for _ in range(self.lag):
                run_single_sample(modify = False)
            run_single_sample(modify = True)

    def _init_multi_chain(self):
        def psi(xs):
            tens = utils.generate_input_samples(self.N, xs)
            nn_output = self.model(tens)
            return self.output_to_psi(nn_output), nn_output

        xs = self.x0.clone()
        psi_vals, nn_vals = psi(xs)
        for i, x in enumerate(xs):
            self.values[x] = psi_vals[i]
            self.nn_output[x] = nn_vals[i]
        
        for _ in range(self.burnin):
            self._run_single_chained_sample(xs, psi_vals, self.x_func, psi, modify = False)
        for _ in range(self.num_samples):
            for _ in range(self.lag):
                self._run_single_chained_sample(xs, psi_vals, self.x_func, psi, modify = False)
            self._run_single_chained_sample(xs, psi_vals, self.x_func, psi, modify = True)
        
        del self.index
    
    def _run_single_chained_sample(self, xs, psis, x_func, psi_function, modify = False):
        new_xs = x_func(xs)
        new_psi_vals, new_nn_vals = psi_function(new_xs)[0]
        for i, new_psi_val in enumerate(new_psi_vals):
            ratio = abs(new_psi_val) ** 2 / abs(psis[i]) ** 2
            if ratio > 1 or ratio > self.rand_uniform[self.index]:
                if modify:
                    self.distribution[new_xs[i]] = self.distribution.get(new_xs[i], 0) + 1
                    self.list.append(new_xs[i])
                xs[i] = new_xs[i]
                psis[i] = new_psi_val
            elif modify:
                self.distribution[xs[i]] = self.distribution(xs[i], 0) + 1
                self.list.append(xs[i])
            self.values[new_xs[i]] = new_psi_val
            self.nn_output[new_xs[i]] = new_nn_vals[i]
            self.index += 1
        

            

In [3]:
torch.manual_seed(0)
npr.seed(0)

In [4]:
N = 10
layers = []
layers.append(nn.Linear(N, 32))
for _ in range(2):
    layers.append(nn.Linear(32, 32))
    layers.append(nn.SELU())
layers.append(nn.Linear(32, 2))
mlp_model = nn.Sequential(*layers)

In [5]:
print(torch.rand((2, 2)))

tensor([[0.6973, 0.1897],
        [0.5673, 0.7153]])


In [6]:
print(utils.bitflip_x(0, 10, 1))

32


In [7]:
state_new = MHNeuralState(N, mlp_model, utils.log_amp_phase, lambda x: utils.bitflip_x(x, N, 1), 0, 20, burnin = 2, lag = 1, chains = 1, new = True)

here 0
0 0.8442657485810173
512 0.8579456176227568
520 0.8472517387841254
584 0.6235636967859723
712 0.3843817072926998
716 0.2975346065444723
717 0.05671297731744318
709 0.2726562945801132
741 0.47766511732134986
229 0.8121687287754932
245 0.4799771723750573
229 0.3927847961008297
165 0.8360787635373775
181 0.3373961604172684
165 0.6481718720511972
173 0.36824153984054797
189 0.9571551589530464
173 0.14035078041264515
429 0.8700872583584364
445 0.4736080452737105
437 0.8009107519796442
309 0.5204774795512048
277 0.6788795301189603
309 0.7206326547259168
308 0.5820197920751071
310 0.5373732294490107
278 0.7586156243223572
790 0.10590760718779213
798 0.4736004193466574
799 0.18633234332675996
831 0.7369181771289581
830 0.21655035442437187
828 0.13521817340545206
824 0.3241410077932141
808 0.14967486718368317
812 0.22232138825158765
813 0.38648898112586194
805 0.9025984755294046
801 0.4499499899112276
800 0.6130634578841324
928 0.9023485831739843
896 0.09928035035897387


In [8]:
print(state_new.list)

[712, 717, 741, 245, 165, 165, 189, 429, 437, 277, 308, 278, 798, 831, 828, 808, 813, 801, 928, 384]


In [9]:
print(state_new.rand_uniform)

[0.84426575 0.85794562 0.84725174 0.6235637  0.38438171 0.29753461
 0.05671298 0.27265629 0.47766512 0.81216873 0.47997717 0.3927848
 0.83607876 0.33739616 0.64817187 0.36824154 0.95715516 0.14035078
 0.87008726 0.47360805 0.80091075 0.52047748 0.67887953 0.72063265
 0.58201979 0.53737323 0.75861562 0.10590761 0.47360042 0.18633234
 0.73691818 0.21655035 0.13521817 0.32414101 0.14967487 0.22232139
 0.38648898 0.90259848 0.44994999 0.61306346 0.90234858 0.09928035]
