In [1]:
import numpy
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai import OpenAI
from datasets import load_from_disk, Dataset
from document_store import DocumentStore, E5EmbeddingFunction
from dotenv import load_dotenv
import os, json
from tqdm import tqdm
import pickle
from time import time
load_dotenv()

True

In [2]:
# Curated QA dataset; later cells index its 'train' split.
dataset = load_from_disk('dataset_curated')

In [4]:
# Inspect one training record to see its fields (question, answer, source metadata, ...).
dataset['train'][0]

{'document_id': '0013205',
 'document_source': 'MPlusHealthTopics',
 'document_url': 'https://www.nlm.nih.gov/medlineplus/a1c.html',
 'category': 'Other',
 'umls_cui': None,
 'umls_semantic_types': None,
 'umls_semantic_group': None,
 'synonyms': 'Glycohemoglobin|HbA1C|Hemoglobin A1C test',
 'question_id': '0000001-1',
 'question_focus': 'A1C',
 'question_type': 'information',
 'question': 'Do you have information about A1C',
 'answer': 'Summary : A1C is a blood test for type 2 diabetes and prediabetes. It measures your average blood glucose, or blood sugar, level over the past 3 months. Doctors may use the A1C alone or in combination with other diabetes tests to make a diagnosis. They also use the A1C to see how well you are managing your diabetes. This test is different from the blood sugar checks that people with diabetes do every day.    Your A1C test result is given in percentages. The higher the percentage, the higher your blood sugar levels have been:       - A normal A1C level 

In [2]:
# Fireworks exposes an OpenAI-compatible endpoint, so the OpenAI client is reused.
# The key is read with os.environ[...] so a missing FIREWORKS_API_KEY fails fast here:
# with api_key=None the OpenAI client falls back to OPENAI_API_KEY and would send that
# credential to the Fireworks endpoint.
client = OpenAI(
    api_key=os.environ['FIREWORKS_API_KEY'],
    base_url="https://api.fireworks.ai/inference/v1"
)

# The single tool offered to the model: a free-text `search` over the document store.
tools = [{
    "type": "function",
    "function": {
        "name": "search",
        "description": "Search documents in the document store with a detailed long query.",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "Query to search documents"},
            },
            "required": ["query"]
        }
    }
}]

# Document store configured with a Chroma DB path, a BM25 index path and a document-content
# directory; the `search` tool calls are executed against it.
doc_store = DocumentStore(
    collection_name="documents",
    chroma_db_path="./chroma_db",
    bm25_index_path="./bm25_index",
    document_content_dir="./document_content",
    device="cuda"
)


In [6]:
def get_model_response(prompt, max_turns=5):
    """Run a tool-calling conversation with the strong (teacher) model.

    Starts from a single user message containing ``prompt`` and alternates
    between querying the model and executing the ``search`` tool calls it
    requests against ``doc_store``. A "turn" is either one model call or one
    round of tool execution; at most ``max_turns`` turns are taken (default 5,
    same budget as llm_rollout). The loop stops early once the model replies
    without tool calls.

    Args:
        prompt: The user question sent as the first message.
        max_turns: Maximum number of turns to run.

    Returns:
        The full conversation history. It contains the initial user dict, the
        ``ChatCompletionMessage`` objects returned by the model, and one
        ``{'role': 'tool', ...}`` dict per tool call holding the JSON-encoded
        search result. Returns ``None`` if anything fails; the error is printed
        so batch callers can skip the example.
    """
    messages = [{'role': 'user', 'content': prompt}]
    try:
        for _ in range(max_turns):
            last_message = messages[-1]

            if isinstance(last_message, ChatCompletionMessage):
                if not last_message.tool_calls:
                    # Final answer reached; further turns would do nothing.
                    break
                # The API expects one tool message per tool_call_id, so answer
                # every requested call, not only the first one.
                for tool_call in last_message.tool_calls:
                    arguments = json.loads(tool_call.function.arguments)
                    messages.append({
                        'role': 'tool',
                        'tool_call_id': tool_call.id,
                        'content': json.dumps(doc_store.search(**arguments, n_results=5)),
                    })
                continue

            # Last message is a user/tool dict: ask the model for the next step.
            chat_completion = client.chat.completions.create(
                model="accounts/fireworks/models/deepseek-v3",  # Using strong model for distillation
                messages=messages,
                tools=tools,
            )
            messages.append(chat_completion.choices[0].message)

        return messages  # Return full conversation history like llm_rollout
    except Exception as e:
        # Previously `raise e` preceded this, making the None return unreachable
        # even though callers check `if messages is None`.
        print(f"Error getting model response: {e}")
        return None

In [6]:
# Smoke-test the teacher rollout on the first training question.
messages = get_model_response(dataset['train'][0]["question"])

In [18]:
# Teacher-model rollouts over the training split, checkpointed to sft_rollout/{index}.pkl
# (each file holds the rollouts gathered since the previous checkpoint).
# START_INDEX lets an interrupted run resume: examples before it are skipped, so set it
# to the first index not yet covered by an existing checkpoint file.
START_INDEX = 534
CHECKPOINT_EVERY = 100

train_split = dataset['train']
last_index = len(train_split) - 1
os.makedirs('sft_rollout', exist_ok=True)
rollout_results = []

for index, example in tqdm(enumerate(train_split), total=len(train_split), desc="Processing examples"):
    if index < START_INDEX:
        continue
    prompt = example["question"]

    # Get model response
    time_start = time()
    messages = get_model_response(prompt)
    if messages is None:
        # Failed rollout: log its index, but still fall through to the checkpoint
        # check so pending results are not lost when a checkpoint/final example fails.
        print(index)
    else:
        rollout_results.append({
            'example': example,
            'messages': messages,
            'time_elapsed': time() - time_start
        })

    is_checkpoint = (index % CHECKPOINT_EVERY == 0 and index > 0) or index == last_index
    if is_checkpoint and rollout_results:
        with open(f'sft_rollout/{index}.pkl', 'wb') as f:
            pickle.dump(rollout_results, f)
        rollout_results = []

Processing examples: 872it [3:39:44, 15.12s/it] 


In [21]:
# Rollouts collected since the last checkpoint file was written (not yet saved to disk).
len(rollout_results)

0

In [3]:
# Merge every checkpoint file written by the rollout loop.
# Files are read in a fixed order: sorting by (name length, name) puts numeric names such
# as "600.pkl" before "1000.pkl" without crashing on non-numeric names. This matters
# because the dedup step keeps the LAST rollout seen per question, and os.listdir order
# is arbitrary.
# pickle.load can execute arbitrary code: only load checkpoint files this notebook wrote.
all_rollout_result = []

checkpoint_files = sorted(
    (name for name in os.listdir("sft_rollout") if name.endswith(".pkl")),
    key=lambda name: (len(name), name),
)
for file in checkpoint_files:
    with open(os.path.join("sft_rollout", file), 'rb') as f:
        result = pickle.load(f)
        print(len(result))
        all_rollout_result.extend(result)

100
100
67
100
34
71
100
100
100
1


In [4]:
# deduplicate by question text; a later rollout for the same question replaces an earlier one
all_rollout_result = {rollout['example']['question']: rollout for rollout in all_rollout_result}

In [5]:
# Back to a plain list of unique rollouts (first-seen question order).
all_rollout_result = [*all_rollout_result.values()]

In [6]:
from llm_rollout import get_reward_functions

In [7]:
# Mapping of reward name -> reward function, built with the document store's embedding function.
reward_functions = get_reward_functions(doc_store.embedding_function)

In [8]:
def process_example(example):
    """Score one rollout with every reward function.

    ``example`` is a rollout record holding the original dataset row under
    ``'example'`` and the conversation under ``'messages'``. Returns a dict
    mapping each reward name to its function's result on that pair.
    """
    source_row, conversation = example['example'], example['messages']
    return {
        reward_name: reward_fn(source_row, conversation)
        for reward_name, reward_fn in reward_functions.items()
    }

In [9]:
# Attach a dict of rewards (one entry per reward function) to every deduplicated rollout, in place.
for example in tqdm(all_rollout_result, desc="Computing Reward"):
    example['reward'] = process_example(example)

Computing Reward: 100%|██████████| 772/772 [00:35<00:00, 21.54it/s]


In [10]:
# Keep only rollouts whose retrieval reached the answer document at all (MRR above zero).
sft_dataset = [rollout for rollout in all_rollout_result if rollout['reward']['mrr'] > 0]

In [11]:
def hot_fix(example):
    """Truncate stored search results in tool messages to the top 5 hits.

    Every dict message with ``role == 'tool'`` has a JSON-object ``content``
    whose ``ids`` and ``documents`` lists are cut to their first 5 entries and
    re-serialized (other keys are left untouched, key order preserved). The
    tool-message dicts are modified in place and ``example`` itself is returned.
    """
    tool_messages = (
        msg for msg in example['messages']
        if isinstance(msg, dict) and msg.get('role') == 'tool'
    )
    for tool_message in tool_messages:
        payload = json.loads(tool_message['content'])
        for key in ('ids', 'documents'):
            payload[key] = payload[key][:5]
        tool_message['content'] = json.dumps(payload)
    return example
    

In [12]:
# Truncate every kept rollout's tool results to the top 5 hits.
sft_dataset = list(map(hot_fix, sft_dataset))

In [13]:
# Stricter cut: keep rollouts with MRR of at least 0.2 (subsumes the earlier > 0 filter).
sft_dataset = list(filter(lambda rollout: rollout['reward']['mrr'] >= 0.2, sft_dataset))

In [18]:
# Number of rollouts kept for SFT after the MRR filters.
len(sft_dataset)

57

In [1]:
from transformers import AutoTokenizer

# Llama 3.2 tokenizer, with its chat template replaced by the pythonic tool-calling
# template read from a local Jinja file.
tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-3B-Instruct", use_fast=True)

with open('tool_chat_template_llama3.2_pythonic.jinja', 'r') as f:
    chat_template = f.read()

tokenizer.chat_template = chat_template

In [15]:
from transformers import AutoTokenizer, AutoModel

# Rebinds `tokenizer` to the Qwen2.5 tokenizer; the <|im_end|>-based curation cell depends on it.
tokenizer = AutoTokenizer.from_pretrained("Qwen2.5-7B-Instruct-bnb-4bit", use_fast=True)

In [17]:
# Qwen special tokens: <|im_end|> is the EOS / end-of-turn marker used to split turns during curation.
tokenizer.special_tokens_map

{'eos_token': '<|im_end|>',
 'pad_token': '<|endoftext|>',
 'additional_special_tokens': ['<|im_start|>',
  '<|im_end|>',
  '<|object_ref_start|>',
  '<|object_ref_end|>',
  '<|box_start|>',
  '<|box_end|>',
  '<|quad_start|>',
  '<|quad_end|>',
  '<|vision_start|>',
  '<|vision_end|>',
  '<|vision_pad|>',
  '<|image_pad|>',
  '<|video_pad|>']}

In [16]:
def format_messages_for_tokenizer(messages):
    """Convert a rollout conversation into plain dicts for ``apply_chat_template``.

    Dict messages (the user prompt and tool results) are passed through as the
    same objects. Model messages (``ChatCompletionMessage``-like objects) become
    assistant dicts:

    * with tool calls: ``{'role': 'assistant', 'tool_calls': [...]}``, where each
      call's JSON ``arguments`` string is decoded into a dict (any text content
      on that message is not kept);
    * without tool calls: ``{'role': 'assistant', 'content': message.content}``.
    """
    def _as_dict(message):
        if isinstance(message, dict):
            return message
        if not message.tool_calls:
            return {'role': 'assistant', 'content': message.content}
        calls = [
            {
                'function': {
                    'arguments': json.loads(call.function.arguments),
                    'name': call.function.name,
                }
            }
            for call in message.tool_calls
        ]
        return {'role': 'assistant', 'tool_calls': calls}

    return [_as_dict(message) for message in messages]

In [16]:
from rich import print as rprint

In [4]:
# Render the first SFT example with the chat template and split it on Llama's <|eot_id|>
# turn delimiter for inspection. Requires `sft_dataset` from the filtering cells; the saved
# output is a NameError from a kernel session where it was not defined.
tokenizer.apply_chat_template(
    format_messages_for_tokenizer(sft_dataset[0]['messages']),
    tokenize=False,
    tools=tools
).split('<|eot_id|>')

NameError: name 'sft_dataset' is not defined

In [22]:
# Check how the Qwen end-of-turn marker tokenizes without added special tokens.
tokenizer.encode('<|im_end|>', add_special_tokens=False)

[151645, 198]

In [None]:
# curate the SFT dataset llama
# Each <|eot_id|>-terminated segment is one turn. System turns and tool-result turns get
# completion mask 0 (not trained on); all other turns get 1. Produces parallel lists
# `input_ids` and `completion_masks`.
# NOTE(review): the header strings below assume the Llama 3 chat format, with tool results
# rendered under an `ipython` (or `tool`) header by tool_chat_template_llama3.2_pythonic.jinja
# — confirm against the rendered template.
LLAMA_MASKED_HEADERS = (
    '<|start_header_id|>system<|end_header_id|>',
    '<|start_header_id|>ipython<|end_header_id|>',
    '<|start_header_id|>tool<|end_header_id|>',
)

input_ids = []
completion_masks = []

for example in tqdm(sft_dataset):
    templated_text = tokenizer.apply_chat_template(
        format_messages_for_tokenizer(example['messages']),
        tokenize=False,
        tools=tools
    )

    # The text ends with <|eot_id|>, so split() leaves an empty remainder; skip blank
    # pieces so no spurious extra <|eot_id|> is appended.
    templated_text_splitted = [x + '<|eot_id|>' for x in templated_text.split('<|eot_id|>') if x.strip()]

    _input_ids = []
    _completion_masks = []

    for t in templated_text_splitted:
        tokens = tokenizer.encode(t, add_special_tokens=False)
        _input_ids.extend(tokens)
        is_masked = any(header in t for header in LLAMA_MASKED_HEADERS)
        _completion_masks.extend([0 if is_masked else 1] * len(tokens))

    input_ids.append(_input_ids)
    completion_masks.append(_completion_masks)

In [19]:
# curate the SFT dataset qwen
# Each <|im_end|>-terminated segment is one turn. System turns and tool-response turns
# (tool results come back as a user turn wrapped in <tool_response>) are masked out;
# everything else is trainable. Produces parallel lists `input_ids` and `labels`.
input_ids = []
labels = []

for example in tqdm(sft_dataset):
    templated_text = tokenizer.apply_chat_template(
        format_messages_for_tokenizer(example['messages']),
        tokenize=False,
        tools=tools
    )

    # Drop the trailing "\n" remainder after the final <|im_end|>.
    templated_text_splitted = [x + '<|im_end|>' for x in templated_text.split('<|im_end|>') if x != '\n']

    _input_ids = []
    _completion_masks = []

    for t in templated_text_splitted:
        tokens = tokenizer.encode(t, add_special_tokens=False)
        _input_ids.extend(tokens)
        if '<|im_start|>system' in t or '<|im_start|>user\n<tool_response>' in t:
            _completion_masks.extend([0]*len(tokens))
        else:
            _completion_masks.extend([1]*len(tokens))

    # Labels are pre-shifted: _labels[i] is the target for position i, i.e. token i+1.
    # Keep it only when token i+1 itself is a completion token. Masking by position i
    # (the previous behaviour) skipped the first token of each assistant turn and trained
    # on the token that follows each assistant turn.
    # NOTE(review): stock Hugging Face *ForCausalLM losses shift labels internally; if this
    # dataset is trained with such a loss, the labels should not be pre-shifted here.
    _labels = [
        token if keep else -100
        for token, keep in zip(_input_ids[1:], _completion_masks[1:])
    ] + [-100]

    input_ids.append(_input_ids)
    labels.append(_labels)

100%|██████████| 57/57 [00:00<00:00, 164.15it/s]


In [20]:
# Sanity check: concatenating the per-segment encodings reproduces tokenizing the whole
# rendered conversation at once, minus its final token (the trailing "\n" piece that the
# segment split drops).
tokenizer.apply_chat_template(
    format_messages_for_tokenizer(sft_dataset[0]['messages']),
    tools=tools
)[:-1] == input_ids[0]

True

In [23]:
# Token length of the first SFT example.
len(input_ids[0])

3169

In [24]:
# Save the Qwen-tokenized SFT data (input_ids plus pre-shifted, masked labels) to disk.
Dataset.from_dict(
    {
        'input_ids':input_ids,
        'labels':labels
    }
).save_to_disk('sft_dataset')

Saving the dataset (0/1 shards):   0%|          | 0/57 [00:00<?, ? examples/s]

In [54]:
from datasets import load_from_disk

In [55]:
# Reload the saved SFT dataset (rebinds `dataset`, replacing the curated QA dataset).
dataset = load_from_disk('sft_dataset')

In [58]:
# Tail of the first example's completion mask.
# NOTE(review): the Qwen curation cell saves only 'input_ids' and 'labels', so this column
# exists only in an earlier (Llama-era) save of sft_dataset — confirm before re-running.
dataset[0]['completion_masks'][-10:]

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

In [1]:
from transformers import Trainer, TrainingArguments, AutoModelForCausalLM
from peft import AutoPeftModelForCausalLM, LoraConfig, TaskType, get_peft_model
from torch.nn import CrossEntropyLoss
from datasets import load_from_disk

In [16]:
# Load the base Llama 3.2 3B Instruct causal LM.
model = AutoModelForCausalLM.from_pretrained(
    "Llama-3.2-3B-Instruct",
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [17]:
# Define LoRA Config: rank-8 adapters (alpha 32, dropout 0.1) on the q/k/v/o and
# gate/up/down projection layers; bias terms are not trained.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.1,
    bias='none',
    task_type=TaskType.CAUSAL_LM
)

# Get PEFT model (wraps the base model; only adapter weights are trainable)
model = get_peft_model(model, lora_config)

# Print trainable parameters
model.print_trainable_parameters()

trainable params: 12,156,928 || all params: 3,224,906,752 || trainable%: 0.3770


In [5]:
# Reload previously saved LoRA adapters together with their base model, keeping adapter weights trainable.
model = AutoPeftModelForCausalLM.from_pretrained('Llama-3.2-3B-Instruct-lora', is_trainable=True)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
# Confirm the reloaded adapters are trainable.
model.print_trainable_parameters()

trainable params: 12,156,928 || all params: 3,224,906,752 || trainable%: 0.3770


In [7]:
# Load the saved SFT dataset with columns returned as torch tensors.
dataset = load_from_disk('sft_dataset').with_format("torch")

In [10]:
# Peek at the first two examples.
dataset[:2]

{'input_ids': [tensor([128000, 128006,   9125,  ...,  32653,     13, 128009]),
  tensor([128000, 128006,   9125,  ...,   1862,     13, 128009])],
 'completion_masks': [tensor([0, 0, 0,  ..., 1, 1, 1]),
  tensor([0, 0, 0,  ..., 1, 1, 1])]}

: 

In [9]:
# Smoke-test a forward pass on one example; reshape(1, -1) adds the batch dimension.
model(input_ids=dataset[0]['input_ids'].reshape(1,-1))

CausalLMOutputWithPast(loss=None, logits=tensor([[[ 1.8155,  1.9450,  7.3831,  ..., -1.2538, -1.2538, -1.2535],
         [-8.8965, -6.4822, -5.3258,  ...,  9.5690,  9.5682,  9.5681],
         [ 4.3489,  2.7922,  3.8053,  ..., -1.4698, -1.4689, -1.4697],
         ...,
         [19.1167,  5.4830,  2.1013,  ...,  2.0884,  2.0887,  2.0883],
         [ 3.9380, -1.7175,  1.4737,  ...,  3.5807,  3.5820,  3.5811],
         [-2.8495, -1.6710, -3.4438,  ...,  7.2351,  7.2344,  7.2331]]]), past_key_values=DynamicCache(), hidden_states=None, attentions=None)

In [1]:
from datasets import Dataset
from random import randint

In [3]:
# Synthetic data: 1000 sequences of 10000 random token ids in [0, 128000] (randint is inclusive).
# No seed is set, so the data differs between runs.
input_ids = [[randint(0, 128000) for _ in range(10000)] for _ in range(1000)]

In [5]:
# Save the synthetic sequences as a dummy dataset.
# NOTE(review): these labels equal the input ids with the last one replaced by -100 (i.e. NOT
# shifted), whereas the Qwen SFT cell builds pre-shifted labels (input_ids[1:] + [-100]).
# Confirm which convention the training loss expects.
Dataset.from_dict({
    'input_ids': input_ids,
    'labels': [input_id[:-1] + [-100] for input_id in input_ids]
}).save_to_disk('dummy_dataset')

Saving the dataset (0/1 shards):   0%|          | 0/1000 [00:00<?, ? examples/s]