In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt

In [2]:
from matplotlib import gridspec

In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM

In [4]:
from bitarray import bitarray

In [5]:
from functools import reduce

In [7]:
def softmax(x, axis=-1):
    # Subtract the maximum value for numerical stability
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / np.sum(e_x, axis=axis, keepdims=True)

def normalize(v):
    assert len(np.where(v < 0)[0]) == 0
    s = v.sum()
    return v/s if s > 0 else v

def entropy(p, axis=None, base=2):
    p = np.asarray(p, dtype=float)
    
    # Handle zero probabilities: 0*log(0) = 0
    # Only compute log for positive elements to avoid numerical issues
    mask = p > 0
    log_p = np.zeros_like(p)
    log_p[mask] = np.log(p[mask]) / np.log(base)
    
    # Shannon entropy: H = -sum(p * log(p))
    entropy = -np.sum(p * log_p, axis=axis)
    
    return entropy

In [8]:
def block_partition(data, block_size):
    bitdata = bitarray(data)
    # pad
    r = len(bitdata)%block_size
    bitdata += bitarray('0'*(block_size-r) )
#     print(len(bitdata))
    assert len(bitdata) % block_size == 0
    
    blocks = []
    for i in range(0, len(bitdata), block_size):
        block = bitdata[i:i+block_size]
        blocks.append(block)
    return blocks

def barr_to_int(barr):
    val = 0
    for bit in barr:
        val = (val << 1) | bit
    return val


def int_to_barr(n: int, width: int = None) -> bitarray:
    if n < 0:
        raise ValueError("Only non-negative integers are supported.")
    if n == 0:
        bits = bitarray('0')
    else:
        bits = bitarray()
        while n > 0:
            bits.append(n & 1)  # extract least significant bit
            n >>= 1
        bits.reverse()
    
    # Pad with leading zeros if width is specified
    if width is not None:
        if len(bits) > width:
            raise ValueError(f"Integer too large to fit in {width} bits")
        pad = bitarray('0' * (width - len(bits)))
        bits = pad + bits
    
    return bits

In [9]:
# Load small model (runs on CPU fine)
model_name = "distilgpt2"  # or "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)


In [10]:
def get_topk_distribution(context: str, k: int = 10, decode=False):
    # Encode context
    inputs = tokenizer(context, return_tensors="pt")
    
    # Forward pass (no gradient needed)
    with torch.no_grad():
        outputs = model(**inputs)
    
    # Get logits for last token
#     print(outputs.logits.shape)
    logits = outputs.logits[:, -1, :]  # shape: [batch_size, vocab_size]
#     print(logits.shape)
    
    # Convert to probabilities
    probs = torch.softmax(logits, dim=-1)
    
    # Get top-k
    topk_probs, topk_indices = torch.topk(probs, k)
    
    # Decode tokens
    if decode:
        topk_tokens = [tokenizer.decode([idx]) for idx in topk_indices[0]]
        return list(zip(topk_tokens, topk_probs[0].tolist()))
    else:
        return topk_probs.numpy().squeeze(0), topk_indices.numpy().squeeze(0)

In [11]:
def mec(p: np.array, q: np.array):
    """
    Algorithm 1:  https://arxiv.org/pdf/1611.04035.pdf
    We adjust Algorithm 1 and follow the advice in the text in order to reconstruct the matrix.

    Supposedly has 1-bit guarantee - unfortunately not clear if equal to kacaoglu2

    We require len(p) == q.
    """
    p = p.copy()#.astype(np.longdouble)
    p /= p.sum()
    q = q.copy()#.astype(np.longdouble)
    q /= q.sum()
    
    if p.shape[0] > q.shape[0]:
        q = np.concatenate([q, np.zeros(p.shape[0]-q.shape[0]#, dtype=np.longdouble
                                       )])
    elif q.shape[0] > p.shape[0]:
        p = np.concatenate([p, np.zeros(q.shape[0]-p.shape[0]#, dtype=np.longdouble
                                       )])
    assert len(p) == len(q), "len(p) must be equal to len(q)!"
    # Joint distribution
    J = np.zeros((q.shape[0], p.shape[0])#, dtype=np.longdouble
                )  

    # e = []
    M = np.stack((p, q), 0)
    r = M.max(axis=1).min()
    while r > 0:
        # e.append(r)
        a_i = M.argmax(axis=1)
        M[0, a_i[0]] -= r
        M[1, a_i[1]] -= r
        J[a_i[0], a_i[1]] = r
        r = M.max(axis=1).min()
    return J

In [12]:
# VOCAB_SIZE = 50_257 # domain of covertext distribution

def encode(ciphertext_bits_arr, context, steps, topk=40, block_size=10):
    '''
    Encode ciphertext into covertext.
    
    '''
    mu_dom = 2**block_size
    ct_blocks = block_partition(ciphertext_bits_arr, block_size)
    ct_blocks_idxs = list(map(barr_to_int, ct_blocks))
    mus = [np.ones(mu_dom,)/mu_dom for _ in ct_blocks] # init uniforms
    mus_entropy = np.array([entropy(mu) for mu in mus])

    # autoregressive conditional: p(C_j | C_1:j-1 = S_1:j-1)
    C_ac_probs, C_ac_idxs = get_topk_distribution(context, k=topk, decode=False)  
    S = []

    step = 0
    while step < steps:
        print(f'Step: {step}')
        istar = np.argmax(mus_entropy) # i in [0, len(mus)]
        mu_istar = mus[istar]
        # coupling
        p, q = mu_istar, C_ac_probs
        M = mec(p,q) # (mu_dom, topk)
        M = M[:p.shape[0], :q.shape[0]] # eliminate padding 0s
        # condition on realization of block
        d_token = normalize(M[ct_blocks_idxs[istar]]) # (topk, 1)
        print(f'\tH( g(C_j|X_i* = x_i*) ): {entropy(d_token)}')

        S_j_ix = np.random.choice(np.arange(0,topk), p=d_token) # S_j_ix is index in [0, topk]
        S_j = C_ac_idxs[S_j_ix] # S_j is token index in [0, VOCAB_SIZE]
        S.append(S_j)
        # update context, generate new AC distribution
        context += tokenizer.decode([S_j])
        C_ac_probs, C_ac_idxs = get_topk_distribution(context, k=topk, decode=False)  
        print(f'\tContext: {context}')

        # condition on realization of next token
        mu_istar_prime = normalize(M[:, S_j_ix])
        # update
        mus[istar] = mu_istar_prime
        mus_entropy[istar] = entropy(mu_istar_prime)

        step += 1

    return S, context

In [40]:
global_M = 0

In [41]:
def decode(S, context, steps, topk=40, block_size=10):
#     block_size = 10
#     ciphertext_bits_arr=bitarray(b'hello_world')
#     context = 'The quick brown fox'
#     S = tokenizer.encode('es on their way home. It has a nice warm, white face so let out a bit of')
#     topk=40

    mu_dom = 2**block_size
    mus = [np.ones(mu_dom,)/mu_dom for _ in range(9)] # init uniforms
    mus_entropy = np.array([entropy(mu) for mu in mus])
    C_ac_probs, C_ac_idxs = get_topk_distribution(context, k=topk, decode=False)  

    step = 0
    while step < steps:
#         print(f'Step: {step}')
        istar = np.argmax(mus_entropy) # i in [0, len(mus)]
        mu_istar = mus[istar]
        # coupling
        p, q = mu_istar, C_ac_probs
        M = mec(p,q) # (mu_dom, topk)
        
        M = M[:p.shape[0], :q.shape[0]] # eliminate padding 0s
        global_M = M.copy()
        print(global_M)

        S_j_ix = np.where(C_ac_idxs == S[step])[0].item()
    #         S_j_ix = np.random.choice(np.arange(0,topk), p=d_token) # S_j_ix is index in [0, topk]
    #         S_j = C_ac_idxs[S_j_ix] # S_j is token index in [0, VOCAB_SIZE]
    #         S.append(S_j)
        # update context, generate new AC distribution
        context += tokenizer.decode([ S[step] ])
        C_ac_probs, C_ac_idxs = get_topk_distribution(context, k=topk, decode=False)  
#         print(f'\tContext: {context}')

        # condition on realization of next token
        mu_istar_prime = normalize(M[:, S_j_ix])
        # update
        mus[istar] = mu_istar_prime
        mus_entropy[istar] = entropy(mu_istar_prime)

        step += 1
       
    
    out_barr = bitarray('')
    block_ids = []
    for i in mus:
        idcand = np.random.choice(np.arange(1024), p=i)
        block_ids.append(idcand)
        barr = int_to_barr(idcand, width=block_size)
        out_barr += barr
    return block_ids, out_barr.tobytes()

In [46]:
S, finaltext = encode(bitarray(b'hello_world'), 'The quick brown fox', 50)

Step: 0
	H( g(C_j|X_i* = x_i*) ): -0.0
	Context: The quick brown foxes
Step: 1
	H( g(C_j|X_i* = x_i*) ): -0.0
	Context: The quick brown foxes on
Step: 2
	H( g(C_j|X_i* = x_i*) ): -0.0
	Context: The quick brown foxes on their
Step: 3
	H( g(C_j|X_i* = x_i*) ): -0.0
	Context: The quick brown foxes on their way
Step: 4
	H( g(C_j|X_i* = x_i*) ): -0.0
	Context: The quick brown foxes on their way home
Step: 5
	H( g(C_j|X_i* = x_i*) ): -0.0
	Context: The quick brown foxes on their way home.
Step: 6
	H( g(C_j|X_i* = x_i*) ): -0.0
	Context: The quick brown foxes on their way home. It
Step: 7
	H( g(C_j|X_i* = x_i*) ): -0.0
	Context: The quick brown foxes on their way home. It has
Step: 8
	H( g(C_j|X_i* = x_i*) ): -0.0
	Context: The quick brown foxes on their way home. It has a
Step: 9
	H( g(C_j|X_i* = x_i*) ): -0.0
	Context: The quick brown foxes on their way home. It has a nice
Step: 10
	H( g(C_j|X_i* = x_i*) ): -0.0
	Context: The quick brown foxes on their way home. It has a nice warm
Step: 11


	Context: The quick brown foxes on their way home. It has a nice warm, white beak-back shape and is easy to see.<|endoftext|>What's not always your typical Android phone is a very portable phone. However, what it's worth is just one
Step: 48
	H( g(C_j|X_i* = x_i*) ): 3.8391480452111444
	Context: The quick brown foxes on their way home. It has a nice warm, white beak-back shape and is easy to see.<|endoftext|>What's not always your typical Android phone is a very portable phone. However, what it's worth is just one small
Step: 49
	H( g(C_j|X_i* = x_i*) ): 4.494296552080877
	Context: The quick brown foxes on their way home. It has a nice warm, white beak-back shape and is easy to see.<|endoftext|>What's not always your typical Android phone is a very portable phone. However, what it's worth is just one small addition


In [47]:
S

[274,
 319,
 511,
 835,
 1363,
 13,
 632,
 468,
 257,
 3621,
 5814,
 11,
 2330,
 307,
 461,
 12,
 1891,
 5485,
 290,
 318,
 2562,
 284,
 766,
 13,
 50256,
 2061,
 338,
 407,
 1464,
 534,
 7226,
 5565,
 3072,
 318,
 257,
 845,
 17726,
 3072,
 13,
 2102,
 11,
 644,
 340,
 338,
 2861,
 318,
 655,
 530,
 1402,
 3090]

In [48]:
decode(S, 'The quick brown fox', 50)

[[0.00097656 0.         0.         ... 0.         0.         0.        ]
 [0.00097656 0.         0.         ... 0.         0.         0.        ]
 [0.00097656 0.         0.         ... 0.         0.         0.        ]
 ...
 [0.         0.         0.         ... 0.         0.         0.        ]
 [0.         0.         0.         ... 0.         0.         0.        ]
 [0.         0.         0.         ... 0.         0.         0.        ]]
[[0.00097656 0.         0.         ... 0.         0.         0.        ]
 [0.00097656 0.         0.         ... 0.         0.         0.        ]
 [0.00097656 0.         0.         ... 0.         0.         0.        ]
 ...
 [0.         0.         0.00055819 ... 0.         0.         0.        ]
 [0.         0.         0.         ... 0.         0.00052305 0.        ]
 [0.         0.         0.         ... 0.         0.         0.        ]]
[[9.76562500e-04 0.00000000e+00 0.00000000e+00 ... 0.00000000e+00
  0.00000000e+00 0.00000000e+00]
 [9.76562500e

[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ...

([417, 598, 795, 111, 381, 886, 988, 620, 400], b'hello_world\x00')

In [50]:
c2 = 'The quick brown fox'
for _ in range(75):
    probs, ixs = get_topk_distribution(c2, k=40, decode=False)
    s_ix = torch.multinomial(torch.tensor(probs), 1)
    c2 += tokenizer.decode([ ixs[s_ix] ])
print(c2)

The quick brown foxes have never been known to have had a tendency to bite. They need to have a nice mouth.


There are over 1,000 foxes on the continent in this genus - they have been found in Europe's most populated European states. This is because, since the beginning of the last great European European expansion, they have been around for many centuries in Europe


In [18]:
def plot_joint_with_marginals(joint_dist, cmap="Blues"):
    """
    Plot a joint distribution as a heatmap with marginals.


    Parameters
    ----------
    joint_dist : 2D numpy array
    Discrete joint distribution (rows ~ X, cols ~ Y).
    cmap : str
    Colormap for the heatmap.
    """
    joint_dist = np.array(joint_dist, dtype=float)
    joint_dist /= joint_dist.sum() # normalize if not already


    # Marginals
    marg_x = joint_dist.sum(axis=1) # sum over Y
    marg_y = joint_dist.sum(axis=0) # sum over X


    # Figure layout
    fig = plt.figure(figsize=(8, 8))
    gs = gridspec.GridSpec(2, 2, width_ratios=[4, 1], height_ratios=[1, 4],
    wspace=0.05, hspace=0.05)


    ax_joint = fig.add_subplot(gs[1, 0])
    ax_marg_x = fig.add_subplot(gs[0, 0], sharex=ax_joint)
    ax_marg_y = fig.add_subplot(gs[1, 1], sharey=ax_joint)


    # Heatmap (square aspect)
    im = ax_joint.imshow(joint_dist, origin="lower", aspect="equal", cmap=cmap)
    ax_joint.set_xlabel("Y")
    ax_joint.set_ylabel("X")


    # Marginals
    ax_marg_x.bar(range(joint_dist.shape[1]), marg_y)
    ax_marg_y.barh(range(joint_dist.shape[0]), marg_x)


    # Clean axes
    plt.setp(ax_marg_x.get_xticklabels(), visible=False)
    plt.setp(ax_marg_y.get_yticklabels(), visible=False)


    ax_marg_x.set_yticks([])
    ax_marg_y.set_xticks([])



    plt.show()