In [17]:
import os, sys
sys.path.append('/playpen-ssd/smerrill/llm_decisions')
from dataclasses import dataclass, field
from datasets import (Dataset, IterableDataset,)
import torch
from transformers import AutoTokenizer, TrainingArguments
from trl.commands.cli_utils import  TrlParser
from transformers import (
    AutoModelForCausalLM,
    PreTrainedTokenizerBase,
    AutoTokenizer,
    BitsAndBytesConfig,
        set_seed,
)

from trl import setup_chat_format
from peft import LoraConfig
import numpy as np

from trl import (
   SFTTrainer)
from collections import Counter
import re
import wandb
from utils import train_on_responses_only, train_test_split


In [None]:
model_name = "meta-llama/Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)

train_data, test_data, train_completion_data = train_test_split('kateacuff')

tmp = [{"text": text} for text in train_data]
train_data = Dataset.from_list(tmp)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [9]:
print(train_data['text'][19])

<|begin_of_text|><|start_header_id|>assistant<|end_header_id|>

kateacuff: I would agree that that the online retreat and half a day, but I would keep it open to having part to either online or with a delayed larger group. I think having all day long online, you would see people ducking under their desk. Probably I'm entering.<|eot_id|>

<|start_header_id|>user<|end_header_id|>

rossholden: I'm entering our seven of Zoom. I can tell you I'm. My attention is drifting.
katrinacallsen: I move or do we have to make a motion you just we just have to tell you.
patrickmclau: I think we just need some consensus from the board, it sounds like we might have that around a half day. zoom meeting would we want to keep the the June 5 date, which was our original date for the retreat.
jonnoalcaro: I would be in favor of keeping the June 5th date. And on the half day issue, I agree with Kate that we need to potentially set another date for a part two, because at least some of the things that I'd like 

### Add special tokens to tokenizer

In [18]:
# Extract speaker tokens
def extract_speakers(text):
    return re.findall(r"^(?:speaker \d+|[a-zA-Z0-9_]+):", text, flags=re.MULTILINE)

speaker_counter = Counter()
for sample in train_data:
    speakers = extract_speakers(sample["text"])
    speaker_counter.update(speakers)

speaker_tokens = list(speaker_counter.keys())

# Add special tokens
special_tokens = {
    "additional_special_tokens": speaker_tokens + [
        "<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>"
    ]
}

tokenizer.add_special_tokens(special_tokens)


def debug_tokenization(example_text: str, tokenizer: PreTrainedTokenizerBase):
    tokens = tokenizer(example_text, return_tensors="pt", add_special_tokens=False)
    input_ids = tokens["input_ids"][0]
    decoded = [tokenizer.decode([tid]) for tid in input_ids]

    print("=== Tokenized Input ===")
    for i, (tid, tok) in enumerate(zip(input_ids, decoded)):
        print(f"{i:03}: {tid.item():>5}  ->  {repr(tok)}")

debug_tokenization(train_data[0]['text'], tokenizer)

=== Tokenized Input ===
000: 128000  ->  '<|begin_of_text|>'
001: 128006  ->  '<|start_header_id|>'
002:   882  ->  'user'
003: 128007  ->  '<|end_header_id|>'
004:   271  ->  '\n\n'
005: 128256  ->  'katrinacallser:'
006:  2100  ->  ' So'
007:   374  ->  ' is'
008:   430  ->  ' that'
009:  1057  ->  ' our'
010:  2218  ->  ' target'
011:    13  ->  '.'
012:  1120  ->  ' just'
013:   520  ->  ' at'
014:   279  ->  ' the'
015:  7314  ->  ' beginning'
016:   315  ->  ' of'
017:  7552  ->  ' February'
018:    13  ->  '.'
019:   358  ->  ' I'
020:  2846  ->  "'m"
021: 20910  ->  ' wondering'
022:   422  ->  ' if'
023:   430  ->  ' that'
024:   596  ->  "'s"
025:   279  ->  ' the'
026:  2218  ->  ' target'
027:   627  ->  '.\n'
028: 128257  ->  'patrickmclaughlin:'
029:  3011  ->  ' That'
030:   574  ->  ' was'
031:   279  ->  ' the'
032:  4113  ->  ' original'
033:  2218  ->  ' target'
034:   994  ->  ' when'
035:   584  ->  ' we'
036:  1051  ->  ' were'
037:  1701  ->  ' using'
038:  6790 

### Model

In [None]:
# Model    
torch_dtype = torch.bfloat16
quant_storage_dtype = torch.bfloat16

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_quant_storage=quant_storage_dtype,
)

model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=quantization_config,
attn_implementation="sdpa", # use sdpa, alternatively use "flash_attention_2"
torch_dtype=quant_storage_dtype,
use_cache=False if training_args.gradient_checkpointing else True,  # this is needed for gradient checkpointing
)

ValueError: `rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {'factor': 8.0, 'low_freq_factor': 1.0, 'high_freq_factor': 4.0, 'original_max_position_embeddings': 8192, 'rope_type': 'llama3'}

In [7]:
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = train_data,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
    dataset_num_proc = 2,
    packing = False, # Can make training 5x faster for short sequences.
    args = TrainingArguments(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        # num_train_epochs = 1, # Set this for 1 full training run.
        max_steps = 300,
        learning_rate = 1e-4,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.1,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none", # Use this for WandB etc
    ),
)



trainer = train_on_responses_only(
    trainer,
    instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
    response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n"
)

Unsloth: Tokenizing ["text"] (num_proc=2): 100%|██████████| 202/202 [00:00<00:00, 203.88 examples/s]
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Map (num_proc=128): 100%|██████████| 202/202 [00:01<00:00, 137.99 examples/s]


In [8]:
tokenizer.decode(trainer.train_dataset[5]["input_ids"])

"<|begin_of_text|><|begin_of_text|><|start_header_id|>assistant<|end_header_id|>\n\nkateacuff: Well, this. Let's see, will this be open to the public or school board members? Yes.<|eot_id|>\n\n<|start_header_id|>user<|end_header_id|>\n\nkwicks: You will have to request an invitation and get the Zoom link sent directly to you.\njonnoalcaro: Karen, I don't have a question, but I've already put it on my calendar, and I'd love to get an invitation.\nkwicks: Fantastic. We will get one out to each of you. That would be great.\nunknownspeaker: Yeah, I'd like to have one, too.\nbernardhairston: You can actually click on the link in the PowerPoint in front of you, and it will allow you to register.\njenniferjohnston: Okay, doesn't look like it's working, but all right. It's located on the announcement section of the electronic school board.\nbernardhairston: We will also be working with our communications department to advertise this through the media and also our sources of communicating, such

In [9]:
space = tokenizer(" ", add_special_tokens = False).input_ids[0]
tokenizer.decode([space if x == -100 else x for x in trainer.train_dataset[5]["labels"]])

"      kateacuff: Well, this. Let's see, will this be open to the public or school board members? Yes.<|eot_id|>\n\n                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 "

In [10]:
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 202 | Num Epochs = 6 | Total steps = 300
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 4 x 1) = 4
 "-____-"     Trainable parameters = 10,485,760/8,000,000,000 (0.13% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss
1,1.2877



KeyboardInterrupt



## Single Evaluation

In [15]:
train_completion_data[0]['completion']

" She's down there. Oh, she's there with her. There she is."

In [16]:
train_completion_data[0]['prompt']

"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\ngrahampaige: Virginia Freedom of Information Act, Section 2-23711A of the Code of Virginia, under subsection 1, for the discussion, consideration, or interviews of prospective candidates for employment in the assignment, appointment, promotion, performance, demotion, salaries, disciplining, or resignation of public officers, appointees, or employees of any public body. Do I have a second?\nunknownspeaker: Second. We have a motion and a second. Ms. Johnston? Ms. Colson? Yes. Ms. Osborne? Yes. Dr. Acuff? Yes. Mr. Page? Yes. Mr. Alcaro? Yes.\njonnoalcaro: The motion passes. We are in closed session. See you on the other side.\nunknownspeaker: Bye. Hello again. Is that Dave's dog?\ngrahampaige: No, actually, I think that one of those is mine. Oh, OK. We have four 19 and 20-year-olds hanging out doing a bonfire tonight.\nunknownspeaker: Oh, OK.\njonnoalcaro: And so this is going to excite the dogs all evening long. So I'm going t

In [None]:
input_text = train_completion_data[5]['prompt']

inputs = tokenizer(
    input_text,  # your full prompt string
    return_tensors="pt"
).to(model.device)

# Create a streamer to print output as it's generated (optional)
text_streamer = TextStreamer(tokenizer, skip_prompt=True)

# Generate a response
_ = model.generate(
    **inputs,
    streamer=text_streamer,
    max_new_tokens=128,
    use_cache=True,
    temperature=1.5,
    top_p=0.9  # (use `top_p`, not `min_p`)
)

 The goal in making a prediction about vaccine development is a fascinating area of modeling, isn't it? So, I'd love to help you dive deeper, perhaps to explore modeling complex interactions or making predictive decisions with some level of uncertainty. I'm more than just a statistical tool - I'm here to assist in various domains!

When modeling vaccination outcomes, a crucial approach is to use machine learning or simulate dynamic models using techniques, such as, say agent based modeling using techniques like ABM in MATLAB.

A crucial area of concern here might be how many people to vaccinate, to stop any particular infection from spreading. So, a model could


### Evaluation

In [31]:
import evaluate

def compute_perplexity(
    model,
    dataset,
    tokenizer,
    max_length=1024,
    device=None,
    verbose=True):
    """
    Compute average perplexity of model predicting 'completion' given 'prompt'.

    Parameters:
    - model: causal LM
    - dataset: list of dicts with keys 'prompt' and 'completion'
    - tokenizer: tokenizer matching model
    - max_length: max tokens to feed into model
    - device: torch device
    - verbose: if True, print per-example details

    Returns:
    - perplexity (float)
    - generated_texts (list of dicts with 'prompt', 'completion', 'decoded_completion')
    """
    model.eval()
    device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
    model.to(device)

    total_loss = 0.0
    total_tokens = 0
    
    generated_texts = []
    reference_texts = []
    
    if tokenizer.pad_token is None:
        print("No pad token set, assigning eos_token as pad_token")
        tokenizer.pad_token = tokenizer.eos_token

    for idx, example in enumerate(tqdm(dataset, desc="Computing perplexity")):
        prompt_text = example['prompt']
        completion_text = example['completion']

        prompt_tokens = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False).input_ids
        completion_tokens = tokenizer(completion_text, return_tensors="pt", add_special_tokens=False).input_ids

        input_ids = torch.cat([prompt_tokens, completion_tokens], dim=1)

        if input_ids.size(1) > max_length:
            input_ids = input_ids[:, -max_length:]

        prompt_len = prompt_tokens.size(1)
        if input_ids.size(1) < prompt_len:
            prompt_len = input_ids.size(1)

        labels = input_ids.clone()
        labels[:, :prompt_len] = -100

        input_ids = input_ids.to(device)
        labels = labels.to(device)
        attention_mask = (input_ids != tokenizer.pad_token_id).long()

        try:
            with torch.no_grad():
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss
        except Exception as e:
            print(f"\n❌ Skipping example {idx + 1} due to model error: {e}")
            print(f"input_ids shape: {input_ids.shape}, prompt_len: {prompt_len}")
            continue

        completion_token_count = (labels != -100).sum().item()
        total_loss += loss.item() * completion_token_count
        total_tokens += completion_token_count

        decoded_completion = tokenizer.decode(input_ids[0, prompt_len:], skip_special_tokens=True)

        # Save to generated_texts
        generated_texts.append(decoded_completion)
        reference_texts.append(completion_text)


    if total_tokens == 0:
        print("⚠️ No valid completion tokens found, returning inf perplexity")
        return float('inf'), generated_texts

    avg_loss = total_loss / total_tokens
    perplexity = math.exp(avg_loss)

    print(f"\n✅ Overall average loss per token: {avg_loss:.4f}")
    print(f"✅ Overall perplexity: {perplexity:.2f}")

    return perplexity, generated_texts, reference_texts

def compute_metrics(generated_texts, reference_texts):

    # Compute metrics
    bleu_score = bleu.compute(predictions=generated_texts, references=[[r] for r in reference_texts])
    rouge_score = rouge.compute(predictions=generated_texts, references=reference_texts)
    bertscore_result = bertscore.compute(predictions=generated_texts, references=reference_texts, lang="en")

    # Average BERTScore F1
    avg_bertscore_f1 = sum(bertscore_result['f1']) / len(bertscore_result['f1'])
            
    return bleu_score, rouge_score, bertscore_result, avg_bertscore_f1

In [28]:
#test_data, train_completion_data
perplexity, generated_texts, reference_texts = compute_perplexity(model,
                                                                train_completion_data,
                                                                tokenizer,
                                                                max_length=1024,
                                                                device=None,
                                                                verbose=False)

Computing perplexity:  24%|██▍       | 59/242 [00:07<00:22,  8.17it/s]


❌ Skipping example 58 due to model error: 
input_ids shape: torch.Size([1, 863]), prompt_len: 837


Computing perplexity:  28%|██▊       | 68/242 [00:08<00:21,  8.28it/s]


❌ Skipping example 67 due to model error: 
input_ids shape: torch.Size([1, 851]), prompt_len: 740


Computing perplexity: 100%|██████████| 242/242 [00:28<00:00,  8.64it/s]


✅ Overall average loss per token: 0.6880
✅ Overall perplexity: 1.99





In [None]:
import evaluate
bleu = evaluate.load("bleu")
rouge = evaluate.load("rouge")
bertscore = evaluate.load("bertscore")
bleu_score, rouge_score, bertscore_result, avg_bertscore_f1 = compute_metrics(generated_texts, reference_texts)

FileNotFoundError: Couldn't find a module script at /playpen-ssd/smerrill/llm_decisions/notebooks/bertscore/bertscore.py. Module 'bertscore' doesn't exist on the Hugging Face Hub either.