In [1]:
import torch
import torch.nn as nn
import libraries.utils as utils

In [2]:
x = torch.rand(3)  # 3 uniform samples in [0, 1); no seed is set, so values differ per run
x

tensor([0.5091, 0.1381, 0.3728])

In [3]:
x[0]  # integer indexing returns a 0-dim tensor, not a Python float

tensor(0.5091)

In [4]:
x[0] = 1  # in-place write; the Python int is stored using x's float dtype

In [5]:
x  # only the first entry changed; the rest keep their earlier values

tensor([1.0000, 0.1381, 0.3728])

In [6]:
x = torch.rand((2, 2))  # rebinds x to a 2x2 uniform tensor (unseeded)
x

tensor([[0.4747, 0.3035],
        [0.5381, 0.5888]])

In [7]:
x[0] = torch.tensor([0, 0])  # overwrite a whole row; the integer tensor is cast to x's float dtype
x

tensor([[0.0000, 0.0000],
        [0.5381, 0.5888]])

In [8]:
utils.generate_state_array(1, 5)  # reference helper: returns a plain Python list, least-significant bit first

[1, 0, 0, 0, 0]

In [9]:
def generate_state_array(state_num, N):
    """
    Binary expansion of a single integer state as a float32 tensor.

    Bit 0 (the least-significant bit) comes first, so
    ``generate_state_array(1, 5)`` gives ``tensor([1., 0., 0., 0., 0.])``.

    Args:
        state_num (int): integer state index.
        N (int): number of low-order bits to extract.

    Returns:
        Tensor of shape (N,), dtype float32; entry i is bit i of state_num.
    """
    bit_positions = torch.arange(N)
    value = torch.tensor(state_num, dtype=torch.long)
    # Shift the scalar right by every position at once (broadcasts to shape (N,)),
    # then keep only the lowest bit of each shifted copy.
    shifted = value >> bit_positions
    return (shifted & 1).float()

In [10]:
generate_state_array(1, 5)  # same bits as the utils version, but returned as a float32 tensor

tensor([1., 0., 0., 0., 0.])

In [11]:
%%timeit
# Baseline: list-based helper from libraries.utils, wrapped in [] so the tensor has shape (1, N).
torch.tensor([utils.generate_state_array(1, 5)], dtype=torch.float32)

3.56 μs ± 163 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [12]:
%%timeit
# Vectorized single-state version. For one small state it measured slower than the
# baseline (~13.9 us vs ~3.56 us), presumably due to per-call tensor construction overhead.
generate_state_array(1, 5).unsqueeze(0)

13.9 μs ± 322 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [13]:
generate_state_array(1, 5).unsqueeze(0)  # shape (1, 5): matches the baseline's [list]-wrapped tensor

tensor([[1., 0., 0., 0., 0.]])

In [None]:
def generate_input_samples(state_nums, N, dtype=None):
    """
    Vectorized binary expansion of a batch of integer states (LSB first).

    Args:
        state_nums (sequence of int or 1-D integer Tensor): the B state indices.
        N (int): number of low-order bits to extract per state.
        dtype (torch.dtype, optional): if given, the bits are cast to this dtype
            (e.g. ``torch.float32`` to match ``generate_state_array``). The default,
            None, keeps the raw ``torch.long`` bits.

    Returns:
        Tensor of shape (B, N); entry [b, i] is bit i of state_nums[b].
    """
    # as_tensor accepts lists and Tensors alike; torch.tensor(Tensor) emits a copy-construct warning.
    state_nums = torch.as_tensor(state_nums, dtype=torch.long)  # shape: (B,)
    powers = torch.arange(N, dtype=torch.long)                  # shape: (N,)
    # Broadcast (B, 1) >> (N,) -> (B, N), then keep only the lowest bit of each shift.
    bits = (state_nums.unsqueeze(1) >> powers) & 1              # shape: (B, N)
    return bits if dtype is None else bits.to(dtype)

In [15]:
N = 10  # number of bits; the benchmarks below enumerate all 2**N = 1024 states

In [16]:
%%timeit
# NOTE(review): argument order here is (N, states), the reverse of the locally defined
# generate_input_samples(state_nums, N) - confirm against libraries.utils.
# The list construction is inside the timed region (same as the next cell).
utils.generate_input_samples(N, [n for n in range(2 ** N)])

2.05 ms ± 51.7 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [17]:
%%timeit
# NOTE(review): the local generate_input_samples returns torch.long bits with no float32 cast,
# so this timing excludes that conversion - confirm what the utils version returns before comparing.
# NOTE(review): the definition cell is marked In [None], so this timing may reflect an
# earlier version of the function; re-run top to bottom to confirm.
generate_input_samples([n for n in range(2 ** N)], N)

119 μs ± 2.78 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [20]:
# Step-by-step walk-through of the generate_input_samples body for N = 3 (all 8 states).
N = 3  # NOTE: rebinds the global N used by the benchmark cells (reset later by N = 10)
state_nums = [n for n in range(2 ** N)]
state_nums = torch.tensor(state_nums, dtype=torch.long)  # shape: (B,)
powers = torch.arange(N, dtype=torch.long)               # shape: (N,)
bits = (state_nums.unsqueeze(1) >> powers) & 1           # shape: (B, N)
print(state_nums)
print(powers)
print(state_nums.unsqueeze(1))            # column vector (B, 1) so it broadcasts against powers
print(state_nums.unsqueeze(1) >> powers)  # (B, N): each state shifted right by 0..N-1
print(bits)                               # lowest bit of each shift -> LSB-first binary rows
print(bits.to(torch.float32))

tensor([0, 1, 2, 3, 4, 5, 6, 7])
tensor([0, 1, 2])
tensor([[0],
        [1],
        [2],
        [3],
        [4],
        [5],
        [6],
        [7]])
tensor([[0, 0, 0],
        [1, 0, 0],
        [2, 1, 0],
        [3, 1, 0],
        [4, 2, 1],
        [5, 2, 1],
        [6, 3, 1],
        [7, 3, 1]])
tensor([[0, 0, 0],
        [1, 0, 0],
        [0, 1, 0],
        [1, 1, 0],
        [0, 0, 1],
        [1, 0, 1],
        [0, 1, 1],
        [1, 1, 1]])
tensor([[0., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 1., 0.],
        [0., 0., 1.],
        [1., 0., 1.],
        [0., 1., 1.],
        [1., 1., 1.]])


In [26]:
def bitflip_batch(xs, N, flips, replace=True):
    """
    Vectorized random bit flips on a batch of integers.

    For each element, ``flips`` bit positions are drawn uniformly from [0, N)
    and the bits at those positions are toggled with XOR.

    Args:
        xs (Tensor): shape (B,), integers
        N (int): number of bits
        flips (int): number of random bit flips per element; 0 returns a copy of xs
        replace (bool): if True (default), positions are drawn independently, so the
            same bit may be drawn more than once. An even number of hits on one bit
            cancels out, so an element can end up with fewer than ``flips`` changed
            bits. If False, ``flips`` distinct positions are drawn, so exactly
            ``flips`` bits change; this requires ``flips <= N``.

    Returns:
        Tensor of shape (B,), integers after bit flips

    Raises:
        ValueError: if ``flips`` is negative, or if ``replace`` is False and ``flips > N``.
    """
    if flips < 0:
        raise ValueError(f"flips must be non-negative, got {flips}")
    if flips == 0:
        # Nothing to flip; indexing bitmasks[:, 0] below would be out of range.
        return xs.clone()

    B = xs.shape[0]

    if replace:
        # Independent draws per flip and per sample.
        bit_indices = torch.randint(0, N, size=(B, flips))
    else:
        if flips > N:
            raise ValueError(f"cannot flip {flips} distinct bits out of N={N}")
        # argsort of i.i.d. uniform noise is a random permutation per row;
        # its first `flips` columns are distinct bit positions.
        bit_indices = torch.rand(B, N).argsort(dim=1)[:, :flips]

    # Compute bitmasks: 1 << bit index
    bitmasks = 1 << bit_indices  # shape: (B, flips)

    # XOR-reduce the masks along the flip axis (b ^ b == 0, so repeated draws cancel).
    flip_masks = bitmasks[:, 0]
    for i in range(1, flips):
        flip_masks = flip_masks ^ bitmasks[:, i]

    # XOR allocates a new tensor, so the caller's xs is never modified.
    return xs ^ flip_masks


In [28]:
xs = torch.tensor([3, 5, 12])  # e.g., 0b0011, 0b0101, 0b1100
N = 4
flips = 1  # with a single draw nothing can cancel, so each element differs in exactly one bit
bitflip_batch(xs, N, flips)  # random: no seed is set, so the result differs per run

tensor([1, 7, 4])

In [37]:
N = 10  # back to 10 bits (1024 states) for the bit-flip benchmarks

In [38]:
%%timeit
# Per-element baseline: one Python-level call into libraries.utils per state.
# NOTE(review): the lambda adds one extra function call per element to the timed loop.
f = lambda x : utils.bitflip_x(x, N, 1)
[f(x) for x in range(2 ** N)]

1.6 ms ± 129 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [39]:
%%timeit
# Batched version: all 2**N states in one call (tensor arange is built inside the timed region).
bitflip_batch(torch.arange(0, 2 ** N), N, 1)

25.1 μs ± 1.72 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [44]:
# Step-by-step walk-through of the bitflip_batch body (N = 5, flips = 2), printing each stage.
N = 5
xs = torch.arange(0, 2 ** N)
flips = 2
B = xs.shape[0]
xs = xs.clone()  # redundant here: xs is a fresh arange and is never modified in place below
print(B, xs) 

# Generate random bit indices for each flip and sample (independent draws, so repeats are possible)
bit_indices = torch.randint(0, N, size=(B, flips))
print(bit_indices)

# Compute bitmasks: 1 << bit index
bitmasks = (1 << bit_indices)  # shape: (B, flips)
print('b', bitmasks)

flip_masks = bitmasks[:, 0]
print('f1', flip_masks)
# XOR-combine the per-draw masks. A row whose two draws hit the same bit (e.g. indices [2, 2]
# -> masks [4, 4]) gives 4 ^ 4 == 0, so that element ends up unchanged.
for i in range(1, flips):
    flip_masks = flip_masks ^ bitmasks[:, i]
print('f2', flip_masks)

res = xs ^ flip_masks
print(res)

32 tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
tensor([[1, 2],
        [3, 2],
        [2, 2],
        [0, 2],
        [1, 1],
        [0, 3],
        [0, 0],
        [2, 1],
        [0, 3],
        [4, 0],
        [3, 1],
        [3, 2],
        [4, 1],
        [3, 0],
        [0, 3],
        [2, 1],
        [0, 2],
        [4, 4],
        [0, 2],
        [3, 1],
        [1, 1],
        [3, 4],
        [1, 1],
        [4, 2],
        [3, 1],
        [2, 3],
        [1, 0],
        [4, 4],
        [0, 4],
        [1, 3],
        [3, 2],
        [0, 4]])
b tensor([[ 2,  4],
        [ 8,  4],
        [ 4,  4],
        [ 1,  4],
        [ 2,  2],
        [ 1,  8],
        [ 1,  1],
        [ 4,  2],
        [ 1,  8],
        [16,  1],
        [ 8,  2],
        [ 8,  4],
        [16,  2],
        [ 8,  1],
        [ 1,  8],
        [ 4,  2],
        [ 1,  4],
        [16, 16],
        [ 1, 