In [1]:
import numpy as np
import numpy.random as npr
import qutip as qt
import math
import torch
import torch.nn as nn
import libraries.utils as utils
import time

In [2]:
def generate_input_samples(state_nums, N):
    # state_nums = torch.tensor(state_nums, dtype=torch.long)  # shape: (B,)
    powers = torch.arange(N, dtype=torch.long)               # shape: (N,)
    bits = (state_nums.unsqueeze(1) >> powers) & 1           # shape: (B, N)
    return bits.to(torch.float32)
def bitflip_batch(xs, N, flips):
    """
    Vectorized random bit flips on a batch of integers.

    Args:
        xs (Tensor): shape (B,), integers
        N (int): number of bits
        flips (int): number of random bit flips per element

    Returns:
        Tensor of shape (B,), integers after bit flips
    """
    B = xs.shape[0]
    xs = xs.clone()

    # Generate random bit indices for each flip and sample
    bit_indices = torch.randint(0, N, size=(B, flips))

    # Compute bitmasks: 1 << bit index
    bitmasks = (1 << bit_indices)  # shape: (B, flips)

    flip_masks = bitmasks[:, 0]
    for i in range(1, flips):
        flip_masks = flip_masks ^ bitmasks[:, i]

    return xs ^ flip_masks


In [3]:
from libraries.NeuralStates import SparseStateVector
class MHNeuralState(SparseStateVector):
    def __init__(self, N, model, output_to_psi, x_func, x0, num_samples, burnin = 0, lag = 0, chains = 1, new = False):
        """
        Initializes distribution of samples and vector values

        Args:
            N (int): number of qubits
            model: torch model representing psi(x), which returns complex amplitude given integer state
            output_to_psi (function): takes in output of model to compute complex amplitude
            x_func (function): takes in state x and generates proposal x*
            x0 (int): integer state to begin sampling
            num_samples (int): number of proposal x* generated
            burnin (int): number of samples to throw away before accepting first sample
            lag (int): number of samples to throw away in-between accepting samples
        """
        # uses arbitrary x_func for MH sampling
        super().__init__()
        self.distribution = {}
        self.samples = num_samples
        self.list = []
        self.nn_output = {}

        self.N = N
        self.model = model
        self.output_to_psi = output_to_psi
        self.x_func = x_func
        self.x0 = x0
        self.num_samples = num_samples
        self.burnin = burnin
        self.lag = lag
        self.chains = chains # note any of these could possibly be modified by a client

        self.forwards = 0
        self.times = {'forwards':0, 'x_func':0, 'gen_samples':0, 'convert_samples': 0, 'output_to_psi':0, 'calc_ratio':0, 'modifications':0}
        self.gen_states = 0

        if new or chains > 1:
            num_uniform = burnin * chains + num_samples * (lag + 1)
            self.rand_uniform = torch.rand(num_uniform)
        self.index = 0
        if chains == 1 and isinstance(x0, int):
            if not new:
                self._init_single_chain()
            else:
                self._init_single_chain_new()
        elif len(x0) == chains:
            self._init_multi_chain()
        else:
            del self.index
            raise Exception('invalid initial values or number of chains')
        del self.index
    
    
    def _init_single_chain(self):
        def psi(x):
            tens = torch.tensor([utils.generate_state_array(x, self.N)], dtype = torch.float32)
            nn_output = self.model(tens)
            self.forwards += 1
            return self.output_to_psi(nn_output)[0], nn_output[0]
        num_uniform = self.burnin + self.num_samples * (self.lag + 1)
        rand_uniform = npr.uniform(0, 1, num_uniform)
        index = 0

        x = self.x0

        psi_val, nn_val = psi(x)
        self.values[x] = psi_val
        self.nn_output[x] = nn_val

        for _ in range(self.burnin):
            new_x = self.x_func(x)
            new_psi_val = self.values[new_x] if new_x in self.values else psi(new_x)[0]
            ratio = abs(new_psi_val) ** 2 / abs(psi_val) ** 2
            print(x, rand_uniform[index])
            if ratio > 1 or ratio > rand_uniform[index]:
                x = new_x
                psi_val = new_psi_val
            index += 1
        for _ in range(self.num_samples):
            for _ in range(self.lag):
                new_x = self.x_func(x)
                new_psi_val = self.values[new_x] if new_x in self.values else psi(new_x)[0]
                ratio = abs(new_psi_val) ** 2 / abs(psi_val) ** 2
                if ratio > 1 or ratio > rand_uniform[index]:
                    x = new_x
                    psi_val = new_psi_val
                index += 1
            new_x = self.x_func(x)
            if new_x in self.values: new_psi_val, new_nn_val = self.values[new_x], self.nn_output[new_x]
            else: new_psi_val, new_nn_val = psi(new_x)
            ratio = abs(new_psi_val) ** 2 / abs(psi_val) ** 2
            if ratio > 1 or ratio > rand_uniform[index]:
                self.distribution[new_x] = self.distribution.get(new_x, 0) + 1
                self.list.append(new_x)
                x = new_x 
                psi_val = new_psi_val 
            else: 
                self.distribution[x] = self.distribution.get(x, 0) + 1
                self.list.append(x)
            self.values[new_x] = new_psi_val
            self.nn_output[new_x] = new_nn_val
            index += 1

    def _init_single_chain_new(self):
        def psi(x):
            start = time.time()
            arr = [utils.generate_state_array(x, self.N)]
            self.times['gen_samples'] += time.time() - start

            start = time.time()
            tens = torch.tensor(arr, dtype = torch.float32)
            self.times['convert_samples'] += time.time() - start
            
            self.gen_states += 1
            start = time.time()
            nn_output = self.model(tens)
            self.times['forwards'] += time.time() - start
            self.forwards += 1
            return self.output_to_psi(nn_output)[0], nn_output[0]

        x = self.x0

        def run_single_sample(modify = False):
            nonlocal x
            nonlocal psi_val
            new_x = self.x_func(x)
            if new_x in self.values: new_psi_val, new_nn_val = self.values[new_x], self.nn_output[new_x]
            else: new_psi_val, new_nn_val = psi(new_x)
            ratio = abs(new_psi_val) ** 2 / abs(psi_val) ** 2
            if ratio > 1 or ratio > self.rand_uniform[self.index]:
                if modify:
                    self.distribution[new_x] = self.distribution.get(new_x, 0) + 1
                    self.list.append(new_x)
                x = new_x 
                psi_val = new_psi_val 
            elif modify: 
                self.distribution[x] = self.distribution.get(x, 0) + 1
                self.list.append(x)
            self.values[new_x] = new_psi_val
            self.nn_output[new_x] = new_nn_val
            self.index += 1
            

        psi_val, nn_val = psi(x)
        self.values[x] = psi_val
        self.nn_output[x] = nn_val
        print('here', x)
        for _ in range(self.burnin):
            run_single_sample(modify = False)
        for _ in range(self.num_samples):
            for _ in range(self.lag):
                run_single_sample(modify = False)
            run_single_sample(modify = True)

    def _init_multi_chain(self):
        def psi(xs):
            self.gen_states += len(xs)
            start = time.time()
            # tens = utils.generate_input_samples(self.N, xs)
            tens = generate_input_samples(xs, self.N)
            # arr = [utils.generate_state_array(x, self.N) for x in xs]
            self.times['gen_samples'] += time.time() - start

            # start = time.time()
            # tens = torch.tensor(arr, dtype = torch.float32)
            # self.times['convert_samples'] += time.time() - start

            start = time.time()
            nn_output = self.model(tens)
            self.times['forwards'] += time.time() - start
            self.forwards += 1

            start = time.time()
            res = self.output_to_psi(nn_output)
            self.times['output_to_psi'] += time.time() - start
            return res, nn_output

        xs = self.x0[:]
        psi_vals, nn_vals = psi(xs)
        for i, x in enumerate(xs):
            self.values[x] = psi_vals[i]
            self.nn_output[x] = nn_vals[i]
        
        for _ in range(self.burnin):
            self._run_single_chained_sample(xs, psi_vals, self.x_func, psi, modify = False)
        num_iters = self.num_samples // self.chains
        remainder = self.num_samples % self.chains
        for c in range(num_iters):
            for _ in range(self.lag):
                self._run_single_chained_sample(xs, psi_vals, self.x_func, psi, modify = False)
            self._run_single_chained_sample(xs, psi_vals, self.x_func, psi, modify = True)
        if remainder != 0:
            xs = xs[:remainder]
            psi_vals = psi(xs)[0]
            self._run_single_chained_sample(xs, psi_vals, self.x_func, psi, modify = True)
    
    def _run_single_chained_sample(self, xs, psis, x_func, psi_function, modify = False):
        start = time.time()
        # new_xs = torch.tensor([x_func(x) for x in xs])
        new_xs = bitflip_batch(xs, self.N, 1)
        self.times['x_func'] += time.time() - start

        new_psi_vals, new_nn_vals = psi_function(new_xs)
        start = time.time()
        ratios = torch.abs(new_psi_vals) ** 2 / torch.abs(psis) ** 2
        self.times['calc_ratio'] += time.time() - start

        start = time.time()
        accept_mask = (ratios > 1) | (ratios > self.rand_uniform[self.index : self.index + len(xs)])
        xs[accept_mask] = new_xs[accept_mask]
        psis[accept_mask] = new_psi_vals[accept_mask]

        if modify:
            accepted = new_xs[accept_mask]
            rejected = xs[~accept_mask]

            for x in accepted.tolist():
                self.distribution[x] = self.distribution.get(x, 0) + 1
                self.list.append(x)

            for x in rejected.tolist():
                self.distribution[x] = self.distribution.get(x, 0) + 1
                self.list.append(x)

        for x, val, nn in zip(new_xs.tolist(), new_psi_vals.tolist(), new_nn_vals.tolist()):
            self.values[x] = val
            self.nn_output[x] = nn
        
        self.index += len(xs)

        # for i, (x, new_x, new_psi_val, new_nn_val, ratio) in enumerate(zip(xs, new_xs, new_psi_vals, new_nn_vals, ratios)):
        #     # start = time.time()
        #     # ratio = abs(new_psi_val) ** 2 / abs(psis[i]) ** 2
        #     # self.times['calc_ratio'] += time.time() - start
            
        #     if ratio > 1 or ratio > self.rand_uniform[self.index]:
        #         if modify:
        #             # input_samples[i] = torch.tensor(utils.generate_state_array(new_xs[i], self.N), dtype = torch.int32)
        #             self.distribution[new_x] = self.distribution.get(new_xs[i], 0) + 1
        #             self.list.append(new_x.item())
        #         xs[i] = new_x
        #         psis[i] = new_psi_val
        #     elif modify:
        #         self.distribution[x] = self.distribution.get(x, 0) + 1
        #         self.list.append(x.item())
        #     self.values[new_x] = new_psi_val
        #     self.nn_output[new_x] = new_nn_val
        #     self.index += 1
        self.times['modifications'] += time.time() - start
        

            

In [4]:
torch.manual_seed(0)
npr.seed(0)

In [5]:
N = 10
layers = []
layers.append(nn.Linear(N, 32))
for _ in range(2):
    layers.append(nn.Linear(32, 32))
    layers.append(nn.SELU())
layers.append(nn.Linear(32, 2))
mlp_model = nn.Sequential(*layers)

In [6]:
print(torch.rand((2, 2)))

tensor([[0.6973, 0.1897],
        [0.5673, 0.7153]])


In [7]:
print(utils.bitflip_x(0, 10, 1))

32


In [8]:
state_old = MHNeuralState(N, mlp_model, utils.log_amp_phase, lambda x: utils.bitflip_x(x, N, 1), 0, 20, burnin = 2, lag = 1, chains = 1, new = False)

0 0.8442657485810173
512 0.8579456176227568


In [9]:
print(state_old.list) # the results are the same for new and old method using the same seed, check parallel_MH_2_seeding.ipynb
print(state_old.forwards)

[712, 717, 741, 245, 165, 165, 189, 429, 437, 277, 308, 278, 798, 831, 828, 808, 813, 801, 928, 384]
42


In [10]:
state_new = MHNeuralState(N, mlp_model, utils.log_amp_phase, lambda x: utils.bitflip_x(x, N, 1), 0, 20, burnin = 2, lag = 1, chains = 1, new = True)

here 0


In [11]:
print(state_new.list)
print(state_new.forwards)

[128, 644, 644, 648, 666, 158, 140, 460, 448, 449, 384, 456, 192, 192, 204, 201, 201, 449, 197, 221]
32


In [13]:
chains = 10
state_chained = MHNeuralState(N, mlp_model, utils.log_amp_phase, lambda x: utils.bitflip_x(x, N, 1), torch.zeros(chains, dtype = torch.long), 200, burnin = 50, lag = 5, chains = chains)

In [14]:
print(state_chained.list)
print(state_chained.forwards)

[699, 702, 113, 325, 748, 604, 405, 736, 197, 921, 445, 589, 330, 952, 349, 696, 500, 31, 353, 72, 55, 580, 442, 430, 725, 892, 212, 49, 360, 863, 93, 1016, 175, 214, 597, 296, 308, 146, 361, 540, 454, 884, 249, 484, 617, 925, 533, 542, 357, 520, 916, 999, 198, 885, 832, 707, 628, 307, 514, 437, 443, 231, 645, 327, 352, 194, 453, 573, 787, 970, 934, 470, 540, 199, 361, 451, 466, 801, 530, 283, 300, 576, 685, 327, 260, 995, 979, 167, 537, 216, 61, 323, 437, 398, 337, 873, 497, 133, 521, 538, 445, 288, 993, 506, 627, 366, 945, 173, 45, 513, 493, 49, 883, 793, 804, 162, 15, 192, 352, 648, 268, 419, 866, 1007, 16, 304, 540, 162, 102, 420, 508, 522, 787, 998, 919, 695, 774, 170, 212, 945, 617, 556, 465, 38, 722, 295, 422, 447, 936, 553, 1002, 254, 690, 42, 81, 160, 631, 276, 863, 913, 718, 47, 558, 321, 816, 1016, 217, 897, 405, 739, 714, 589, 65, 812, 268, 0, 236, 716, 280, 277, 430, 749, 231, 134, 905, 96, 230, 524, 512, 305, 1006, 675, 999, 702, 940, 198, 483, 256, 376, 473]
171


In [15]:
N = 30
layers = []
layers.append(nn.Linear(N, 32))
for _ in range(2):
    layers.append(nn.Linear(32, 32))
    layers.append(nn.SELU())
layers.append(nn.Linear(32, 2))
larger_model = nn.Sequential(*layers)

In [16]:
start = time.time()
state_old = MHNeuralState(N, larger_model, utils.log_amp_phase, lambda x: utils.bitflip_x(x, N, 1), 0, 2000, burnin = 50, lag = 10, chains = 1, new = True)
print(time.time() - start)

here 0
3.635636806488037


In [17]:
print(state_old.forwards)
print(state_old.times)
print(state_old.gen_states)

21166
{'forwards': 1.391538381576538, 'x_func': 0, 'gen_samples': 0.06081223487854004, 'convert_samples': 0.15440583229064941, 'output_to_psi': 0, 'calc_ratio': 0, 'modifications': 0}
21166


In [26]:
start = time.time()
chains = 50
state_chained = MHNeuralState(N, larger_model, utils.log_amp_phase, lambda x: utils.bitflip_x(x, N, 1), torch.tensor([0] * chains, dtype = torch.long), 2000, burnin = 50, lag = 10, chains = chains)
print(time.time() - start)

0.15192699432373047


In [27]:
print(state_chained.forwards)
print(state_chained.times)
print(state_chained.gen_states)

491
{'forwards': 0.03354287147521973, 'x_func': 0.017861604690551758, 'gen_samples': 0.01392674446105957, 'convert_samples': 0, 'output_to_psi': 0.0161285400390625, 'calc_ratio': 0.010540008544921875, 'modifications': 0.05142545700073242}
24550
