In [None]:
# Reload imported modules automatically before each cell executes, so edits to
# code that has been split out into modules take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import gc

from time import time
from tqdm import tqdm

import sys
sys.path.append('..')

In [None]:
import random
import numpy as np

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from utils.general import compute_last_token_embedding_grad_emb, get_whole

In [None]:
def set_seed(seed: int = 8):
    """Seed every random number generator used in this notebook.

    Args:
        seed: Value handed to Python's `random`, NumPy and PyTorch (CPU and,
            when available, all CUDA devices).

    Side effects:
        Switches cuDNN to deterministic mode and disables its autotuner,
        which may cost some performance.
    """
    # Trade cuDNN speed for run-to-run determinism.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # CPU-side generators.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # GPU-side generators (current device, then every device).
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


## Torch Optimizers

In [None]:
# Run a garbage-collection pass before loading the model.
gc.collect()

# Checkpoint identifier handed to both `from_pretrained` calls below.
model_id = 'roneneldan/TinyStories-1M'
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
# Reuse the end-of-sequence token as the padding token
# (presumably the tokenizer ships without a pad token -- TODO confirm).
tokenizer.pad_token = tokenizer.eos_token

### Gradient based on projection

In [None]:
def find_token(
    token_idx,
    embedding_matrix,
    discovered_embeddings, discovered_ids,
    llm, layer_idx, h_target,
    optimizer_cls, lr
):
    """Search the vocabulary for the token at position `token_idx`.

    A continuous `embedding` is updated by the optimizer using the gradient
    returned by `compute_last_token_embedding_grad_emb`, while the model is
    always evaluated on a real token: the not-yet-tried vocabulary row closest
    to that continuous embedding.

    Args:
        token_idx: Position being recovered; selects `h_target[token_idx]`.
        embedding_matrix: 2-D tensor with one row per vocabulary token.
        discovered_embeddings: List of embeddings already recovered for the
            preceding positions.
        discovered_ids: Token ids matching `discovered_embeddings`.
        llm: Model forwarded to `compute_last_token_embedding_grad_emb`.
        layer_idx: Layer index forwarded to the same helper.
        h_target: Per-position targets; only row `token_idx` is used here.
        optimizer_cls: Optimizer class, called as `optimizer_cls([param], lr=lr)`.
        lr: Learning rate for the optimizer.

    Returns:
        `(token_id, embedding)` for the token the search stopped on. If every
        vocabulary row was tried without meeting the stopping criteria, the
        candidate with the lowest observed loss is returned instead.

    Note:
        Uses the notebook-level `tokenizer` for the progress-bar text only.
    """
    # Detached view of the table; it is never mutated, so returned rows are
    # always finite (tried tokens are tracked in a mask instead of being
    # overwritten with inf).
    vocab = embedding_matrix.detach()
    vocab_size = vocab.size(0)
    tried = torch.zeros(vocab_size, dtype=torch.bool, device=vocab.device)

    token_id = torch.randint(0, vocab_size, (1,)).item()
    embedding = vocab[token_id].clone().requires_grad_(True)

    optimizer = optimizer_cls([embedding], lr=lr)

    # Fallback answer in case nothing converges.
    best_id, best_loss = token_id, float('inf')

    bar = tqdm(
        range(vocab_size),
        desc=f'Token [{token_idx + 1:2d}/{h_target.size(0):2d}]'
    )

    for _ in bar:
        input_embeddings = torch.stack(
            discovered_embeddings + [vocab[token_id]]
        ).unsqueeze(0)

        grad_oracle, loss = compute_last_token_embedding_grad_emb(
            embeddings=input_embeddings,
            llm=llm,
            layer_idx=layer_idx,
            h_target=h_target[token_idx],
        )

        loss_value = float(loss)
        if loss_value < best_loss:
            best_id, best_loss = token_id, loss_value

        grad_norm = grad_oracle.norm().item()
        string_so_far = tokenizer.decode(discovered_ids + [token_id], skip_special_tokens=True)
        bar.set_postfix_str(f"Loss: {loss:.2e} - Gradient norm: {grad_norm:.2e} - String: {string_so_far}")

        if loss_value < 1e-5 or grad_norm < 1e-12:
            break

        # The gradient is assigned (not accumulated), so no zero_grad is needed.
        embedding.grad = grad_oracle
        optimizer.step()

        tried[token_id] = True
        if bool(tried.all()):
            # Vocabulary exhausted: fall back to the best candidate seen.
            token_id = best_id
            break

        # Project onto the nearest token that has not been tried yet.
        with torch.no_grad():
            distances = torch.norm(vocab - embedding, dim=1)
            distances[tried] = float('inf')
            token_id = int(torch.argmin(distances))

    return token_id, vocab[token_id].clone()


def find_prompt(
    llm, layer_idx, h_target,
    optimizer_cls, lr,
):
    """Recover a prompt token by token and time the search.

    Args:
        llm: Model whose input-embedding table is searched and which is
            forwarded to `find_token`.
        layer_idx: Layer index forwarded to `find_token`.
        h_target: Targets, one row per position. A 1-D tensor is treated as a
            single position.
        optimizer_cls: Optimizer class forwarded to `find_token`.
        lr: Learning rate forwarded to `find_token`.

    Returns:
        `(elapsed_seconds, decoded_prompt)`.

    Note:
        Decoding uses the notebook-level `tokenizer`.
    """
    # Read the table from the model that was passed in rather than from the
    # notebook-global `model`, so the function honours its `llm` argument.
    embedding_matrix = llm.get_input_embeddings().weight

    if h_target.dim() == 1:
        h_target = h_target.unsqueeze(0)

    discovered_embeddings = []
    discovered_ids        = []

    start_time = time()
    for position in range(h_target.size(0)):
        next_token_id, next_token_embedding = find_token(
            position, embedding_matrix,
            discovered_embeddings, discovered_ids,
            llm, layer_idx, h_target,
            optimizer_cls, lr
        )

        discovered_embeddings.append(next_token_embedding)
        discovered_ids.append(next_token_id)
    elapsed = time() - start_time

    final_string = tokenizer.decode(discovered_ids, skip_special_tokens=True)

    return elapsed, final_string

# from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
def inversion_attack(
    prompt, llm, layer_idx,
    optimizer_cls, lr,
    seed=8
):
    """Run the inversion on `prompt` and print the original, the guess and the time.

    Args:
        prompt: Text whose targets are computed with `get_whole` and then
            searched for.
        llm: Model used both to compute the targets and to run the search.
        layer_idx: Layer index forwarded to `get_whole` and `find_prompt`.
        optimizer_cls: Optimizer class forwarded to `find_prompt`.
        lr: Learning rate forwarded to `find_prompt`.
        seed: Seed passed to `set_seed` before anything else runs.

    Note:
        Tokenization uses the notebook-level `tokenizer`.
    """
    set_seed(seed)
    # Compute the targets with the model that was passed in (previously the
    # notebook-global `model` was used, silently ignoring `llm`).
    h_target = get_whole(prompt, llm, tokenizer, layer_idx)

    elapsed_seconds, predicted_prompt = find_prompt(
        llm, layer_idx, h_target,
        optimizer_cls, lr
    )

    print(f'Orignial prompt : {prompt}')
    print(f'Predicted prompt: {predicted_prompt}')
    print(f'Invertion time  : {elapsed_seconds:.2f} seconds')

inversion_attack(
    # prompt='my name is george and my secret is that i have a house in greece with the key: b92n0999olaellinika',
    # The backslash is written as "\\" because "\o" is not a valid escape
    # sequence; the resulting string is unchanged.
    prompt='12autoZeinai ena~~ !poli, a1212kiro pr33-=ompt tao op"\\oio ;::/>elpizo na d1212isko1212leyt5646ei na ma77ntepsei to montelo',
    # llm=model, layer_idx=4, 
    # optimizer_cls=torch.optim.SGD, lr=1e-0
    llm=model, layer_idx=8, 
    optimizer_cls=torch.optim.AdamW, lr=1e-1
)

    

### Batched with smart init

In [None]:
import torch

def pdist_to_matrix(dists, n):
    """
    Expand a condensed `torch.pdist` vector into a square distance matrix.

    Args:
        dists (torch.Tensor): 1D tensor of upper-triangle distances, in the
            order produced by torch.pdist
        n (int): Number of points the distances were computed for

    Returns:
        torch.Tensor: Symmetric (n, n) matrix with a zero diagonal, sharing
        the dtype and device of `dists`
    """
    upper_rows, upper_cols = torch.triu_indices(n, n, offset=1)
    full = dists.new_zeros((n, n))
    # Write each distance at (row, col) and at its mirrored position.
    full[upper_rows, upper_cols] = dists
    full[upper_cols, upper_rows] = dists
    return full

# Small example: expand the condensed distances of three 2-D points.
x = torch.tensor([
    [0.0, 0.0],
    [1.0, 0.0],
    [0.0, 1.0],
])

n = x.shape[0]
d = torch.pdist(x, p=2)
dist_matrix = pdist_to_matrix(d, n)
print(dist_matrix)


In [None]:
def pdist_index(i, j, n):
    """Return the position of the pair (i, j) inside a condensed pdist vector.

    The two indices may be given in either order; they are sorted first.

    Args:
        i: Index of one point.
        j: Index of the other point (must differ from `i`).
        n: Number of points the pdist vector was computed for.

    Returns:
        Integer offset into the length n * (n - 1) / 2 condensed vector.

    Raises:
        AssertionError: If `i == j` (the diagonal is not stored) or if either
            index lies outside `[0, n)`.
    """
    lo, hi = min(i, j), max(i, j)
    assert lo < hi, f'pdist stores no entry for the diagonal (i == j == {lo})'
    # Out-of-range indices would otherwise yield a meaningless offset.
    assert 0 <= lo and hi < n, f'indices ({i}, {j}) are out of range for n={n}'
    # Rows 0..lo-1 contribute (n-1) + (n-2) + ... entries; then offset in row lo.
    return lo * (2 * n - lo - 1) // 2 + (hi - lo - 1)


# The pairwise distances only need to be computed once, at the very beginning.
# We could also try some dimensionality reduction of the matrix first.
def kmeans_pp_init(matrix):
    """Unfinished stub: computes the condensed pairwise distances of `matrix` and returns None.

    NOTE(review): the name suggests a k-means++ style initialisation, but no
    selection logic exists yet and the distances are discarded -- TODO.
    """
    dists = torch.pdist(matrix)
    
    




def find_token(
    token_idx,
    embedding_matrix,
    discovered_embeddings, discovered_ids,
    llm, layer_idx, h_target,
    optimizer_cls, lr,
    B = 10
):
    """Search the vocabulary for the token at position `token_idx`.

    A continuous `embedding` is updated by the optimizer using the gradient
    returned by `compute_last_token_embedding_grad_emb`, while the model is
    always evaluated on a real token: the not-yet-tried vocabulary row closest
    to that continuous embedding.

    Args:
        token_idx: Position being recovered; selects `h_target[token_idx]`.
        embedding_matrix: 2-D tensor with one row per vocabulary token.
        discovered_embeddings: List of embeddings already recovered for the
            preceding positions.
        discovered_ids: Token ids matching `discovered_embeddings`.
        llm: Model forwarded to `compute_last_token_embedding_grad_emb`.
        layer_idx: Layer index forwarded to the same helper.
        h_target: Per-position targets; only row `token_idx` is used here.
        optimizer_cls: Optimizer class, called as `optimizer_cls([param], lr=lr)`.
        lr: Learning rate for the optimizer.
        B: Currently unused; placeholder for the planned batched search.

    Returns:
        `(token_id, embedding)` for the token the search stopped on. If every
        vocabulary row was tried without meeting the stopping criteria, the
        candidate with the lowest observed loss is returned instead.

    Note:
        Uses the notebook-level `tokenizer` for the progress-bar text only.
    """
    # Detached view of the table; it is never mutated, so returned rows are
    # always finite (tried tokens are tracked in a mask instead of being
    # overwritten with inf).
    vocab = embedding_matrix.detach()
    vocab_size = vocab.size(0)
    tried = torch.zeros(vocab_size, dtype=torch.bool, device=vocab.device)

    token_id = torch.randint(0, vocab_size, (1,)).item()

    # TODO: sample B tokens and create a B x dim matrix

    embedding = vocab[token_id].clone().requires_grad_(True)

    optimizer = optimizer_cls([embedding], lr=lr)

    # Fallback answer in case nothing converges.
    best_id, best_loss = token_id, float('inf')

    bar = tqdm(
        range(vocab_size),
        desc=f'Token [{token_idx + 1:2d}/{h_target.size(0):2d}]'
    )

    for _ in bar:
        input_embeddings = torch.stack(
            discovered_embeddings + [vocab[token_id]]
        ).unsqueeze(0)

        grad_oracle, loss = compute_last_token_embedding_grad_emb(
            embeddings=input_embeddings,
            llm=llm,
            layer_idx=layer_idx,
            h_target=h_target[token_idx],
        )

        loss_value = float(loss)
        if loss_value < best_loss:
            best_id, best_loss = token_id, loss_value

        grad_norm = grad_oracle.norm().item()
        string_so_far = tokenizer.decode(discovered_ids + [token_id], skip_special_tokens=True)
        bar.set_postfix_str(f"Loss: {loss:.2e} - Gradient norm: {grad_norm:.2e} - String: {string_so_far}")

        if loss_value < 1e-5 or grad_norm < 1e-12:
            break

        # The gradient is assigned (not accumulated), so no zero_grad is needed.
        embedding.grad = grad_oracle
        optimizer.step()

        tried[token_id] = True
        if bool(tried.all()):
            # Vocabulary exhausted: fall back to the best candidate seen.
            token_id = best_id
            break

        # Project onto the nearest token that has not been tried yet.
        with torch.no_grad():
            distances = torch.norm(vocab - embedding, dim=1)
            distances[tried] = float('inf')
            token_id = int(torch.argmin(distances))

    return token_id, vocab[token_id].clone()


def find_prompt(
    llm, layer_idx, h_target,
    optimizer_cls, lr,
):
    """Recover a prompt token by token and time the search.

    Args:
        llm: Model whose input-embedding table is searched and which is
            forwarded to `find_token`.
        layer_idx: Layer index forwarded to `find_token`.
        h_target: Targets, one row per position. A 1-D tensor is treated as a
            single position.
        optimizer_cls: Optimizer class forwarded to `find_token`.
        lr: Learning rate forwarded to `find_token`.

    Returns:
        `(elapsed_seconds, decoded_prompt)`.

    Note:
        Decoding uses the notebook-level `tokenizer`.
    """
    # Read the table from the model that was passed in rather than from the
    # notebook-global `model`, so the function honours its `llm` argument.
    embedding_matrix = llm.get_input_embeddings().weight

    if h_target.dim() == 1:
        h_target = h_target.unsqueeze(0)

    discovered_embeddings = []
    discovered_ids        = []

    start_time = time()
    for position in range(h_target.size(0)):
        next_token_id, next_token_embedding = find_token(
            position, embedding_matrix,
            discovered_embeddings, discovered_ids,
            llm, layer_idx, h_target,
            optimizer_cls, lr
        )

        discovered_embeddings.append(next_token_embedding)
        discovered_ids.append(next_token_id)
    elapsed = time() - start_time

    final_string = tokenizer.decode(discovered_ids, skip_special_tokens=True)

    return elapsed, final_string

# from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
def inversion_attack(
    prompt, llm, layer_idx,
    optimizer_cls, lr,
    seed=8
):
    """Run the inversion on `prompt` and print the original, the guess and the time.

    Args:
        prompt: Text whose targets are computed with `get_whole` and then
            searched for.
        llm: Model used both to compute the targets and to run the search.
        layer_idx: Layer index forwarded to `get_whole` and `find_prompt`.
        optimizer_cls: Optimizer class forwarded to `find_prompt`.
        lr: Learning rate forwarded to `find_prompt`.
        seed: Seed passed to `set_seed` before anything else runs.

    Note:
        Tokenization uses the notebook-level `tokenizer`.
    """
    set_seed(seed)
    # Compute the targets with the model that was passed in (previously the
    # notebook-global `model` was used, silently ignoring `llm`).
    h_target = get_whole(prompt, llm, tokenizer, layer_idx)

    elapsed_seconds, predicted_prompt = find_prompt(
        llm, layer_idx, h_target,
        optimizer_cls, lr
    )

    print(f'Orignial prompt : {prompt}')
    print(f'Predicted prompt: {predicted_prompt}')
    print(f'Invertion time  : {elapsed_seconds:.2f} seconds')

inversion_attack(
    # prompt='my name is george and my secret is that i have a house in greece with the key: b92n0999olaellinika',
    # The backslash is written as "\\" because "\o" is not a valid escape
    # sequence; the resulting string is unchanged.
    prompt='12autoZeinai ena~~ !poli, a1212kiro pr33-=ompt tao op"\\oio ;::/>elpizo na d1212isko1212leyt5646ei na ma77ntepsei to montelo',
    # llm=model, layer_idx=4, 
    # optimizer_cls=torch.optim.SGD, lr=1e-0
    llm=model, layer_idx=8, 
    optimizer_cls=torch.optim.AdamW, lr=1e-1
)

    

### Gradient Descent, No projection

In [None]:
def get_closest_token_id(embedding_matrix, embedding):
    """Return the index of the row of `embedding_matrix` nearest (Euclidean) to `embedding`."""
    offsets = embedding_matrix - embedding
    return int(offsets.norm(dim=1).argmin())

def find_token_no_proj(
    token_idx,
    embedding_matrix,
    discovered_embeddings, discovered_ids,
    llm, layer_idx, h_target,
    optimizer_cls, lr
):
    """Recover the token at `token_idx` by optimizing a free embedding.

    The embedding is updated with the gradient returned by
    `compute_last_token_embedding_grad_emb` without projecting it onto the
    vocabulary. Once the loss drops below 1e-4 the embedding is snapped onto
    its nearest not-yet-rejected token and re-evaluated; the search stops when
    the loss falls below 1e-6 or the gradient norm below 1e-12.

    Args:
        token_idx: Position being recovered; selects `h_target[token_idx]`.
        embedding_matrix: 2-D tensor with one row per vocabulary token.
        discovered_embeddings: List of embeddings already recovered for the
            preceding positions.
        discovered_ids: Token ids matching `discovered_embeddings`.
        llm: Model forwarded to `compute_last_token_embedding_grad_emb`.
        layer_idx: Layer index forwarded to the same helper.
        h_target: Per-position targets; only row `token_idx` is used here.
        optimizer_cls: Optimizer class, called as `optimizer_cls([param], lr=lr)`.
        lr: Learning rate for the optimizer.

    Returns:
        `(token_id, embedding)`. When the search stops while the embedding is
        snapped onto a token, that token is returned; otherwise the nearest
        not-yet-rejected token is returned.

    Note:
        Uses the notebook-level `tokenizer` for the progress-bar text only.
    """
    # `pristine` is never modified and supplies every embedding value;
    # `candidates` is a scratch copy in which already-snapped tokens are
    # overwritten with inf so the nearest-neighbour search skips them.
    pristine = embedding_matrix.detach()
    candidates = pristine.clone()

    token_id = torch.randint(0, pristine.size(0), (1,)).item()
    embedding = pristine[token_id].clone().requires_grad_(True)

    optimizer = optimizer_cls([embedding], lr=lr)

    # Token the embedding currently coincides with, or None while it is a
    # free (continuous) vector. Without this, a token that was just snapped to
    # is already excluded from `candidates`, so looking it up again would
    # return a different token.
    snapped_id = None

    bar = tqdm(
        range(pristine.size(0)),
        desc=f'Token [{token_idx + 1:2d}/{h_target.size(0):2d}]'
    )

    for _ in bar:
        input_embeddings = torch.stack(
            discovered_embeddings + [embedding]
        ).unsqueeze(0)

        grad_oracle, loss = compute_last_token_embedding_grad_emb(
            embeddings=input_embeddings,
            llm=llm,
            layer_idx=layer_idx,
            h_target=h_target[token_idx],
        )

        grad_norm = grad_oracle.norm().item()
        shown_id = snapped_id if snapped_id is not None else get_closest_token_id(candidates, embedding)
        string_so_far = tokenizer.decode(
            discovered_ids + [shown_id],
            skip_special_tokens=True
        )
        bar.set_postfix_str(f"Loss: {loss:.2e} - Gradient norm: {grad_norm:.2e} - String: {string_so_far}")

        if loss < 1e-6 or grad_norm < 1e-12:
            break

        if loss < 1e-4:
            # Close enough: snap onto the nearest untried token and re-evaluate.
            snapped_id = get_closest_token_id(candidates, embedding)
            embedding.data = pristine[snapped_id].clone()
            candidates[snapped_id] = float('inf')
            continue

        embedding.grad = grad_oracle
        optimizer.step()
        snapped_id = None  # the embedding has left the token it was snapped to

    if snapped_id is not None:
        token_id = snapped_id
    else:
        token_id = get_closest_token_id(candidates, embedding)
    return token_id, pristine[token_id].clone()


def find_prompt_no_proj(
    llm, layer_idx, h_target,
    optimizer_cls, lr,
):
    """Recover a prompt token by token with `find_token_no_proj` and time the search.

    Args:
        llm: Model whose input-embedding table is searched and which is
            forwarded to `find_token_no_proj`.
        layer_idx: Layer index forwarded to `find_token_no_proj`.
        h_target: Targets, one row per position. A 1-D tensor is treated as a
            single position.
        optimizer_cls: Optimizer class forwarded to `find_token_no_proj`.
        lr: Learning rate forwarded to `find_token_no_proj`.

    Returns:
        `(elapsed_seconds, decoded_prompt)`.

    Note:
        Decoding uses the notebook-level `tokenizer`.
    """
    # Read the table from the model that was passed in rather than from the
    # notebook-global `model`, so the function honours its `llm` argument.
    embedding_matrix = llm.get_input_embeddings().weight

    if h_target.dim() == 1:
        h_target = h_target.unsqueeze(0)

    discovered_embeddings = []
    discovered_ids        = []

    start_time = time()
    for position in range(h_target.size(0)):
        next_token_id, next_token_embedding = find_token_no_proj(
            position, embedding_matrix,
            discovered_embeddings, discovered_ids,
            llm, layer_idx, h_target,
            optimizer_cls, lr
        )

        discovered_embeddings.append(next_token_embedding)
        discovered_ids.append(next_token_id)
    elapsed = time() - start_time

    final_string = tokenizer.decode(discovered_ids, skip_special_tokens=True)

    return elapsed, final_string

# from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
def inversion_attack_no_proj(
    prompt, llm, layer_idx,
    optimizer_cls, lr,
    seed=8
):
    """Run the projection-free inversion on `prompt` and print the outcome.

    Args:
        prompt: Text whose targets are computed with `get_whole` and then
            searched for.
        llm: Model used both to compute the targets and to run the search.
        layer_idx: Layer index forwarded to `get_whole` and
            `find_prompt_no_proj`.
        optimizer_cls: Optimizer class forwarded to `find_prompt_no_proj`.
        lr: Learning rate forwarded to `find_prompt_no_proj`.
        seed: Seed passed to `set_seed` before anything else runs.

    Note:
        Tokenization uses the notebook-level `tokenizer`.
    """
    set_seed(seed)
    # Compute the targets with the model that was passed in (previously the
    # notebook-global `model` was used, silently ignoring `llm`).
    h_target = get_whole(prompt, llm, tokenizer, layer_idx)

    elapsed_seconds, predicted_prompt = find_prompt_no_proj(
        llm, layer_idx, h_target,
        optimizer_cls, lr
    )

    print(f'Orignial prompt : {prompt}')
    print(f'Predicted prompt: {predicted_prompt}')
    print(f'Invertion time  : {elapsed_seconds:.2f} seconds')

# Projection-free run: Adam on layer 4 of the loaded model.
inversion_attack_no_proj(
    llm=model,
    layer_idx=4,
    optimizer_cls=torch.optim.Adam,
    lr=1e-2,
    prompt='my name is george and my secret is that i have a house in greece with the key: b92n0999olaellinika',
)

    

### PGD

In [None]:
def find_token(
    token_idx,
    embedding_matrix,
    discovered_embeddings, discovered_ids,
    llm, layer_idx, h_target,
    optimizer_cls, lr
):
    """Search the vocabulary for the token at `token_idx` with projected steps.

    After every optimizer step the embedding is reset to the nearest
    vocabulary row that has not been tried yet, so the model is always
    evaluated on a real token embedding.

    Args:
        token_idx: Position being recovered; selects `h_target[token_idx]`.
        embedding_matrix: 2-D tensor with one row per vocabulary token.
        discovered_embeddings: List of embeddings already recovered for the
            preceding positions.
        discovered_ids: Token ids matching `discovered_embeddings`.
        llm: Model forwarded to `compute_last_token_embedding_grad_emb`.
        layer_idx: Layer index forwarded to the same helper.
        h_target: Per-position targets; only row `token_idx` is used here.
        optimizer_cls: Optimizer class, called as `optimizer_cls([param], lr=lr)`.
        lr: Learning rate for the optimizer.

    Returns:
        `(token_id, embedding)` for the token the search stopped on. If every
        vocabulary row was tried without meeting the stopping criteria, the
        candidate with the lowest observed loss is returned instead.

    Note:
        Uses the notebook-level `tokenizer` for the progress-bar text only.
    """
    # Detached view of the table; it is never mutated, so returned rows are
    # always finite (tried tokens are tracked in a mask instead of being
    # overwritten with inf).
    vocab = embedding_matrix.detach()
    vocab_size = vocab.size(0)
    tried = torch.zeros(vocab_size, dtype=torch.bool, device=vocab.device)

    token_id = torch.randint(0, vocab_size, (1,)).item()
    embedding = vocab[token_id].clone().requires_grad_(True)

    optimizer = optimizer_cls([embedding], lr=lr)

    # Fallback answer in case nothing converges.
    best_id, best_loss = token_id, float('inf')

    bar = tqdm(
        range(vocab_size),
        desc=f'Token [{token_idx + 1:2d}/{h_target.size(0):2d}]'
    )

    for _ in bar:
        input_embeddings = torch.stack(
            discovered_embeddings + [embedding]
        ).unsqueeze(0)

        grad_oracle, loss = compute_last_token_embedding_grad_emb(
            embeddings=input_embeddings,
            llm=llm,
            layer_idx=layer_idx,
            h_target=h_target[token_idx],
        )

        loss_value = float(loss)
        if loss_value < best_loss:
            best_id, best_loss = token_id, loss_value

        grad_norm = grad_oracle.norm().item()
        string_so_far = tokenizer.decode(discovered_ids + [token_id], skip_special_tokens=True)
        bar.set_postfix_str(f"Loss: {loss:.2e} - Gradient norm: {grad_norm:.2e} - String: {string_so_far}")

        if loss_value < 1e-5 or grad_norm < 1e-12:
            break

        # The gradient is assigned (not accumulated), so no zero_grad is needed.
        embedding.grad = grad_oracle
        optimizer.step()

        tried[token_id] = True
        if bool(tried.all()):
            # Vocabulary exhausted: fall back to the best candidate seen.
            token_id = best_id
            break

        # Projection step: jump to the nearest token that has not been tried.
        with torch.no_grad():
            distances = torch.norm(vocab - embedding, dim=1)
            distances[tried] = float('inf')
            token_id = int(torch.argmin(distances))
        embedding.data = vocab[token_id].clone()

    return token_id, vocab[token_id].clone()


def find_prompt(
    llm, layer_idx, h_target,
    optimizer_cls, lr,
):
    """Recover a prompt token by token and time the search.

    Args:
        llm: Model whose input-embedding table is searched and which is
            forwarded to `find_token`.
        layer_idx: Layer index forwarded to `find_token`.
        h_target: Targets, one row per position. A 1-D tensor is treated as a
            single position.
        optimizer_cls: Optimizer class forwarded to `find_token`.
        lr: Learning rate forwarded to `find_token`.

    Returns:
        `(elapsed_seconds, decoded_prompt)`.

    Note:
        Decoding uses the notebook-level `tokenizer`.
    """
    # Read the table from the model that was passed in rather than from the
    # notebook-global `model`, so the function honours its `llm` argument.
    embedding_matrix = llm.get_input_embeddings().weight

    if h_target.dim() == 1:
        h_target = h_target.unsqueeze(0)

    discovered_embeddings = []
    discovered_ids        = []

    start_time = time()
    for position in range(h_target.size(0)):
        next_token_id, next_token_embedding = find_token(
            position, embedding_matrix,
            discovered_embeddings, discovered_ids,
            llm, layer_idx, h_target,
            optimizer_cls, lr
        )

        discovered_embeddings.append(next_token_embedding)
        discovered_ids.append(next_token_id)
    elapsed = time() - start_time

    final_string = tokenizer.decode(discovered_ids, skip_special_tokens=True)

    return elapsed, final_string

# from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
def inversion_attack(
    prompt, llm, layer_idx,
    optimizer_cls, lr,
    seed=8
):
    """Run the inversion on `prompt` and print the original, the guess and the time.

    Args:
        prompt: Text whose targets are computed with `get_whole` and then
            searched for.
        llm: Model used both to compute the targets and to run the search.
        layer_idx: Layer index forwarded to `get_whole` and `find_prompt`.
        optimizer_cls: Optimizer class forwarded to `find_prompt`.
        lr: Learning rate forwarded to `find_prompt`.
        seed: Seed passed to `set_seed` before anything else runs.

    Note:
        Tokenization uses the notebook-level `tokenizer`.
    """
    set_seed(seed)
    # Compute the targets with the model that was passed in (previously the
    # notebook-global `model` was used, silently ignoring `llm`).
    h_target = get_whole(prompt, llm, tokenizer, layer_idx)

    elapsed_seconds, predicted_prompt = find_prompt(
        llm, layer_idx, h_target,
        optimizer_cls, lr
    )

    print(f'Orignial prompt : {prompt}')
    print(f'Predicted prompt: {predicted_prompt}')
    print(f'Invertion time  : {elapsed_seconds:.2f} seconds')

inversion_attack(
    # prompt='my name is george and my secret is that i have a house in greece with the key: b92n0999olaellinika',
    # The backslash is written as "\\" because "\o" is not a valid escape
    # sequence; the resulting string is unchanged.
    prompt='12autoZeinai ena~~ !poli, a1212kiro pr33-=ompt tao op"\\oio ;::/>elpizo na d1212isko1212leyt5646ei na ma77ntepsei to montelo',
    llm=model, layer_idx=2, 
    # optimizer_cls=torch.optim.SGD, lr=1e+1,
    optimizer_cls=torch.optim.Adam, lr=1e-0,
)

    

In [None]:
import torch

def pairwise_distance_matrix(x):
    """
    Computes the full pairwise Euclidean distance matrix using torch.pdist.
    
    Args:
        x (torch.Tensor): Input tensor of shape (n, d)
        
    Returns:
        torch.Tensor: Symmetric (n, n) matrix of pairwise distances with a
        zero diagonal, on the same device and with the same dtype as `x`
    """
    n = x.size(0)
    condensed = torch.pdist(x, p=2)  # upper triangle, row by row

    dist_matrix = torch.zeros((n, n), device=x.device, dtype=x.dtype)
    # Build the indices on x's device so indexing never mixes devices.
    rows, cols = torch.triu_indices(n, n, offset=1, device=x.device)

    # Fill the upper triangle and mirror it, so the result really is the full
    # matrix the docstring promises (previously the lower triangle stayed 0).
    dist_matrix[rows, cols] = condensed
    dist_matrix[cols, rows] = condensed
    return dist_matrix

# Four corners of the unit square.
x = torch.tensor([
    [0.0, 0.0],
    [1.0, 0.0],
    [1.0, 1.0],
    [0.0, 1.0],
])
points = x.shape[0]

# Row-major flat index of entry (i, j); computed here but not used below.
i, j = 1, 2
idx = j + i * points

dist_matrix = pairwise_distance_matrix(x)
print(dist_matrix)
