In [1]:
import curl

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

curl.init()

  from .autonotebook import tqdm as notebook_tqdm


Using Communicator type:  <class 'curl.communicator.distributed_communicator.DistributedCommunicator'>
[<>] Waiting for connections...
[<>] DEFAULT ARGS: {'DISTRIBUTED_BACKEND': 'gloo', 'RENDEZVOUS': 'file:///tmp/xcrypten-Vcrypten-rcrypten-wcrypten-Jcrypten-Jcrypten-Ocrypten-Lcrypten-ncrypten-v', 'WORLD_SIZE': 1, 'RANK': 0, 'TTP': False}
[Device] LUTs initialized for cpu



In [2]:
MODEL_NAME = "roneneldan/TinyStories-1M"

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [3]:
# This is a trick to enable tokenizer padding if not yet possible before
tokenizer.pad_token = tokenizer.eos_token

In [4]:
# NOTE: think could also have a different name like `n_embd` in the config, depending on the model
hidden_size = model.config.hidden_size

In [5]:
INPUT_TEXT = "the little girl"

tokens = tokenizer(
    INPUT_TEXT,
    return_tensors='pt',
)
NUM_INPUT_TOKENS = len(tokens)  # actual number of input tokens

In [6]:
# To ensure a constant input size, we will always pad or truncate to a fixed size
encoded_input = tokenizer(
    INPUT_TEXT,
    return_tensors='pt',
    padding="max_length",
    truncation=True,
    max_length=hidden_size,
)

output = model.forward(
    input_ids=encoded_input["input_ids"],
    attention_mask=encoded_input["attention_mask"],
)

In [7]:
# shape: batch_size x sequence length x vocab size
# interpretation: for every token in the sequence, what is the probability distribution over all tokens in the vocab for what the following token is?
output.logits.shape

torch.Size([1, 64, 50257])

In [8]:
# The prediction is the logit argmax of the next token prediction for the last position of the actual input sequence

predicted_token_id = torch.argmax(output.logits[:, NUM_INPUT_TOKENS, :], dim=-1)
print("Next predicted token is: ", tokenizer.decode(predicted_token_id))
print("Text completion is: ", INPUT_TEXT + tokenizer.decode(predicted_token_id))

Next predicted token is:   was
Text completion is:  the little girl was


In [9]:
private_model = curl.nn.from_pytorch(model, encoded_input["input_ids"])

  if input_shape[-1] > 1 or self.sliding_window is not None:
  if past_key_values_length > 0:
  mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
  param = torch.from_numpy(numpy_helper.to_array(node))


In [10]:
private_model.encrypt()

Graph encrypted module

In [11]:
private_input = curl.cryptensor(encoded_input["input_ids"], precision=0)

In [12]:
private_output = private_model.forward(private_input)

In [13]:
next_token_id = torch.argmax(private_output.get_plain_text()[:, NUM_INPUT_TOKENS, :], dim=-1)
print("Next predicted token is: ", tokenizer.decode(next_token_id))
print("Text completion is: ", INPUT_TEXT + tokenizer.decode(next_token_id))

Next predicted token is:  urious
Text completion is:  the little girlurious


In [14]:
def generate(
    model: curl.nn.module.Graph,
    input_tokens: curl.mpc.mpc.MPCTensor,
    sequence_length: int,
    max_new_tokens: int,
) -> list[curl.mpc.mpc.MPCTensor]:

    sequence = input_tokens.clone()

    generated_tokens = []
    for _ in range(max_new_tokens):
        # Perform forward pass & compute logits
        [logits] = model.forward(sequence)  # NOTE: assumption = batch size of 1
        # Select logit of last token in input sequence (which hold logits over the vocab for next token)
        next_logit = logits[sequence_length - 1]
        # Compute argmax to determine next token
        next_token_id = next_logit.argmax(one_hot=False)
        generated_tokens.append(next_token_id)
        next_token_id = next_token_id._tensor.data.item()  # NOTE: if we encoded the IDs as rationals, this would need to be upscaled
        # Check if we reached the maximum sequence length
        # Here this is equal to the logits dim because we padded / truncated the sequence to this maximum length
        if sequence_length >= len(logits):
            # We make room for the next token by discarding the first token in the sequence
            raw_truncated_input_tokens = raw_truncated_input_tokens._tensor.data[1:]
            sequence._tensor.data = torch.cat(
                (
                    raw_truncated_input_tokens,
                    torch.tensor([[next_token_id]]),
                ),
                dim=1,
            )  
        else:
            # Replace padding token by new token & update sequence length bc input sequence is now 1 token longer
            sequence._tensor.data[0][sequence_length] = next_token_id
            sequence_length += 1

    return generated_tokens

In [15]:
[input_token_ids] = private_input.reveal().int()
input_seq = tokenizer.decode(
    input_token_ids,
    skip_special_tokens=True,
)
print("Input sequence:", input_seq)

private_generated_token_ids = generate(
    model=private_model,
    input_tokens=private_input,
    sequence_length=NUM_INPUT_TOKENS,
    max_new_tokens=3,
)

# TODO: There has to be a cleaner way to structure this
generated_token_ids = [
    x.reveal().int().item()
    for x in private_generated_token_ids
]
generated_seq = tokenizer.decode(
    generated_token_ids,
    skip_special_tokens=True,
)

output_seq = input_seq + generated_seq
print("Output sequence:", output_seq)

Input sequence: the little girl
Output sequence: the little girlokes again emergencies
