In [2]:
import os
import sys

import torch
import transformers

import numllama

In [3]:
# Legacy checkpoint compatibility shim.
# Custom (non-builtin) types were passed to the Lightning module's __init__,
# so the saved checkpoints contain references to those classes under their
# original module paths (`svgai`, `svgai.train`), which no longer exist here.
# Aliasing the old module names to their new locations lets the references
# in the checkpoint resolve to the correct classes.
# For future modules: pass only built-in types to the LightningModule
# __init__ so checkpoints stay portable across projects.
sys.modules.update({
    "svgai": numllama,
    "svgai.train": numllama.addition,
})

In [4]:
# Optional checkpoint of the numeric-embedding model pretrained on the addition
# task (numllama.addition.AdditionLightning). When a checkpoint is loaded, its
# embedding/encoder configs are reused so NumLlama's numeric layers are built
# with the same configuration as the pretrained ones.
# The path can be overridden with the NUMLLAMA_ADDITION_CKPT environment
# variable. Set that variable to an empty string (or assign None here) to skip
# loading and use the freshly initialized configs in the `else` branch instead.
load_numeric_checkpoint = os.environ.get(
    "NUMLLAMA_ADDITION_CKPT",
    "/home/xkadlci2/svgai/checkpoints/vocal-frost-603__c5a8xfde/global-step=240000__valid-acc=1.000.ckpt",
)

if load_numeric_checkpoint:
    addition_model = numllama.addition.AdditionLightning.load_from_checkpoint(load_numeric_checkpoint)
    # model_dump() turns the stored config object into a plain dict.
    numeric_input_emb_config = addition_model.model.embedding_config.model_dump()
    # NOTE(review): used as-is; presumably already a plain dict (it is passed
    # straight to NumLlamaConfig) — confirm.
    numeric_encoder_config = addition_model.model.num_encoder_config
else:
    # No pretrained numeric model: later cells skip the weight transfer.
    addition_model = None
    # min_value/max_value are also passed to add_num_tokens_to_tokenizer
    # as the range of numbers registered as dedicated tokens.
    numeric_input_emb_config = dict(
        embedding_dim=256,
        min_value=0,
        max_value=10000,
        use_l2_norm=False,
        norm_const=None,
    )
    # Hydra-style spec (`_target_`) for the numeric encoder backbone.
    numeric_encoder_config = dict(
        _target_="numllama.nn.feedforward_backbone",
        model_dim=256,
        ff_dim=128,
        num_blocks=8,
        normalization=None,
        use_skips=True,
        skips_are_learnable=False,
        linears_constraint=None,
        dropout=0,
        activation_fn=dict(
            _target_="torch.nn.GELU"
        ),
    )



In [5]:
# Pretrained Llama checkpoint that NumLlama extends.
checkpoint_name = "meta-llama/Llama-3.2-1B"

# Start from the stock Llama config and layer the numeric-embedding settings on top.
original_config = transformers.LlamaConfig.from_pretrained(checkpoint_name)
base_config_kwargs = original_config.to_dict()
config = numllama.NumLlamaConfig(
    numeric_input_emb_config=numeric_input_emb_config,
    numeric_encoder_config=numeric_encoder_config,
    **base_config_kwargs,
)

# Load the pretrained Llama weights into the NumLlama architecture.
num_llama: numllama.NumLlamaForCausalLM = numllama.NumLlamaForCausalLM.from_pretrained(
    checkpoint_name,
    config=config,
)

# Create the new numeric embedding layer inside Llama.
num_llama.apply_numeric_patch()

In [6]:
# Tokenizer that matches the pretrained Llama checkpoint.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=checkpoint_name,
)

In [7]:
# Baseline: how the stock Llama tokenizer splits a mix of words and numbers.
test_string = "hello 0 1 2 3 250 6401 131070 131071"
original_tokens = tokenizer.tokenize(test_string)
original_ids = tokenizer.encode(test_string, add_special_tokens=False)
list(zip(original_tokens, original_ids, strict=True))

[('hello', 15339),
 ('Ġ', 220),
 ('0', 15),
 ('Ġ', 220),
 ('1', 16),
 ('Ġ', 220),
 ('2', 17),
 ('Ġ', 220),
 ('3', 18),
 ('Ġ', 220),
 ('250', 5154),
 ('Ġ', 220),
 ('640', 14033),
 ('1', 16),
 ('Ġ', 220),
 ('131', 9263),
 ('070', 17819),
 ('Ġ', 220),
 ('131', 9263),
 ('071', 24508)]

In [8]:
# Change how Llama tokenizes numbers. The tokenizer object is patched in
# place (the return value is not used); compare the tokenization shown
# before and after this step.
# NOTE(review): unclear whether this patch is idempotent — restart the
# kernel rather than running this cell twice.
numllama.patch_llama_digit_splitting(tokenizer)



In [9]:
# Register the configured integer range as numeric tokens in the tokenizer,
# passing the model so it can be updated for the new token ids.
num_token_min = numeric_input_emb_config["min_value"]
num_token_max = numeric_input_emb_config["max_value"]
numllama.add_num_tokens_to_tokenizer(num_token_min, num_token_max, tokenizer, num_llama)



In [10]:
# Inspect the patched tokenization of the same test string (token, id) pairs.
patched_pairs = zip(
    tokenizer.tokenize(test_string),
    tokenizer.encode(test_string, add_special_tokens=False),
    strict=True,
)
list(patched_pairs)

[('hello', 15339),
 (' 0', 128256),
 (' 1', 128257),
 (' 2', 128258),
 (' 3', 128259),
 (' 250', 128506),
 (' 6401', 134657),
 (' 131070', 259326),
 (' 131071', 259327)]

In [11]:
# Copy the numeric-embedding weights trained on the addition task into
# NumLlama's numeric embedding (skipped when no checkpoint was loaded).
if addition_model is not None:
    num_state_dict = addition_model.model.embedding.state_dict()
    target_numeric_emb = num_llama.get_numeric_emb()
    target_numeric_emb.load_state_dict(num_state_dict)

In [12]:
# Trying out batched generation.
# Llama-3 ships without a pad token, so batching prompts of different token
# lengths fails unless one is set; reusing EOS is the usual workaround.
# Decoder-only models must be left-padded for batched generation so new
# tokens are appended directly after each prompt rather than after padding.
# Note: this mutates the shared tokenizer's padding settings.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

input_str = [
    "Hello, it's me.",
    "Hello, 2 is what?"
]

inputs = tokenizer(input_str, return_tensors="pt", padding=True, truncation=True)

with num_llama.build_num_latents():
    outputs = num_llama.generate(
        **inputs.to(num_llama.device),
        max_new_tokens=5,
        # Explicit pad id avoids the implicit "Setting pad_token_id" fallback.
        pad_token_id=tokenizer.pad_token_id,
    )

tokenizer.batch_decode(outputs)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


["<|begin_of_text|>Hello, it's me. 35293 10674 63234 7739 36662",
 '<|begin_of_text|>Hello, 2 is what? 63214 7750 35282 63447 63155']