In [1]:
# Re-import edited modules (e.g. the local opencoconut package) before each
# cell executes, so library changes take effect without restarting the kernel.
%reload_ext autoreload
%autoreload 2

## Train

In [2]:
import os
import torch
import logging
from transformers import (
    Trainer,
    TrainingArguments,
    AutoTokenizer,
)
from opencoconut import (
    AutoCoconutForCausalLM,
    CoconutConfig,
    CoTDataset,
)
from pathlib import Path
import gc

# Configure logging
from loguru import logger

def clear_memory():
    """Free unreachable Python objects, then release cached accelerator memory.

    `gc.collect()` must run first: tensors kept alive only by reference cycles
    are released by the collector, and only after that can the allocator
    caches hand their memory back to the device.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # `get_device` may pick Apple's MPS backend, which keeps its own cache.
    if torch.backends.mps.is_available() and hasattr(torch, "mps"):
        torch.mps.empty_cache()

def get_device():
    """Name of the preferred torch device, by priority: "mps", then "cuda", else "cpu"."""
    # Probes are checked in order; the first available backend wins.
    probes = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for device_name, is_available in probes:
        if is_available():
            return device_name
    return "cpu"

In [3]:
# Run presets. Each preset writes to its own output directory: the training
# cell saves checkpoints even when DEBUG is on, so a shared directory would
# let a debug or medium run overwrite the small run's checkpoints.
config_debug = {
    # 'model_name': "Qwen/Qwen2.5-0.5B",
    # 'model_name': 'axel-datos/qwen2.5-0.5b-instruct_MATH_lisa',
    'model_name': 'plaguss/Qwen2.5-0.5B-Math-Shepherd-PRM-0.1',
    'batch_size': 5,
    'learning_rate': 4e-4,
    'samples_per_epoch': 12,
    'output_dir': "./output/debug",
    'num_epochs': 1,
}
config_small = {
    'model_name': "Qwen/Qwen2.5-0.5B",
    'batch_size': 5,
    'learning_rate': 1e-4,
    'samples_per_epoch': 1000,
    'output_dir': "./output/small",
    'num_epochs': 1,
}
config_medium = {
    # Qwen2.5 has no 2.5B checkpoint ("Qwen/Qwen2.5-2.5B" cannot be loaded);
    # the nearest released base sizes are 1.5B and 3B.
    'model_name': "Qwen/Qwen2.5-1.5B",
    'batch_size': 1,
    'learning_rate': 5e-5,
    'samples_per_epoch': 30000,
    'output_dir': "./output/medium",
    'num_epochs': 3,
}
config = config_small

# DEBUG switches to the tiny preset and exports DEBUG=1 for library code.
DEBUG = False
if DEBUG:
    config = config_debug
    os.environ["DEBUG"] = "1"

config

{'model_name': 'Qwen/Qwen2.5-0.5B',
 'batch_size': 5,
 'learning_rate': 0.0001,
 'samples_per_epoch': 1000,
 'output_dir': './output/small',
 'num_epochs': 1}

## Load model

In [4]:
# Load the checkpoint's tokenizer only to inspect its chat template below; the
# model-loading cell re-creates it and overrides its bos/eos/pad tokens.
tokenizer = AutoTokenizer.from_pretrained(config['model_name'], padding_side="left")

In [5]:
# Display the Jinja chat template shipped with the checkpoint (reference only).
tokenizer.get_chat_template()

'{%- if tools %}\n    {{- \'<|im_start|>system\\n\' }}\n    {%- if messages[0][\'role\'] == \'system\' %}\n        {{- messages[0][\'content\'] }}\n    {%- else %}\n        {{- \'You are a helpful assistant.\' }}\n    {%- endif %}\n    {{- "\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>" }}\n    {%- for tool in tools %}\n        {{- "\\n" }}\n        {{- tool | tojson }}\n    {%- endfor %}\n    {{- "\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\"name\\": <function-name>, \\"arguments\\": <args-json-object>}\\n</tool_call><|im_end|>\\n" }}\n{%- else %}\n    {%- if messages[0][\'role\'] == \'system\' %}\n        {{- \'<|im_start|>system\\n\' + messages[0][\'content\'] + \'<|im_end|>\\n\' }}\n    {%- else %}\n        {{- \'<|im_start|>system\\nYou are a he

In [6]:


# Initialize model and tokenizer
logger.info("Initializing model and tokenizer")
output_dir = Path(config['output_dir'])

# tokenizer
# Override bos/eos/pad with ChatML-style markers. This happens before
# CoconutConfig.from_tokenizer(tokenizer) below, so the config is built from
# the overridden tokenizer (presumably it reads these tokens — TODO confirm).
tokenizer = AutoTokenizer.from_pretrained(config['model_name'], padding_side="left")
tokenizer.bos_token = "<|im_start|>"
tokenizer.eos_token = "<|im_end|>"
tokenizer.pad_token = "<|endoftext|>"

# config and model
# FIXME we won't need custom tokens anymore
coconut_config = CoconutConfig.from_tokenizer(tokenizer)
# Load the checkpoint in bfloat16 onto the device chosen by get_device().
model = AutoCoconutForCausalLM.from_pretrained(
    config['model_name'], coconut_config, torch_dtype=torch.bfloat16, device_map=get_device(), 
)
# model.resize_token_embeddings(len(tokenizer))
# Attach the tokenizer so code that only receives the model (e.g. the training
# callback) can reach it as `model.tokenizer`.
model.tokenizer = tokenizer

# Training arguments, grouped as: output, optimizer/schedule, batching/length,
# precision, logging/saving. `max_steps` is derived from the preset
# (samples per epoch // batch size, times epochs); when it is positive it takes
# precedence over `num_train_epochs`.
training_args = TrainingArguments(
    output_dir=output_dir,
    optim="adamw_torch",  # to save memory try: adamw_bnb_8bit or paged_adamw_8bit
    learning_rate=config['learning_rate'],
    warmup_ratio=0.1,
    per_device_train_batch_size=config['batch_size'],
    gradient_accumulation_steps=1,
    num_train_epochs=config['num_epochs'],
    max_steps=(config['samples_per_epoch'] // config['batch_size']) * config['num_epochs'],
    bf16=True,
    bf16_full_eval=True,
    logging_steps=20,  # TODO ideally log to tensorboard every step, but to the UI every 100 steps
    save_steps=10000,
)


[32m2025-01-16 09:42:59.061[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mInitializing model and tokenizer[0m
Some weights of CoconutQwen2ForCausalLM were not initialized from the model checkpoint at Qwen/Qwen2.5-0.5B and are newly initialized: ['switch.0.weight', 'switch.1.bias', 'switch.1.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Infer (before training)

In [7]:
from transformers import AutoTokenizer, TextStreamer
# Stream generated text to stdout, echoing the prompt and special tokens too.
streamer = TextStreamer(tokenizer, skip_special_tokens=False, skip_prompt=False)

# Word problem used as a quick qualitative check (expected answer: $300).
prompt = (
    "John cuts his grass to 2 inches. "
    "It grows .5 inches per month. "
    "When it gets to 4 inches he cuts it back down to 2 inches. "
    "It cost $100 to get his grass cut. How much does he pay per year?"
)

# Worked reference solution, kept for reading; infer_example does not consume it.
ans = """
# since it starts at 2 and never gets cut below 2, we can consider only the extra growth
growth_annual = 0.5*12
cost_per_inch = 100/2
cuts = growth_annual // 2 # round it down
cost_per_year = growth_annual * cost_per_inch
print(f"cost per year: {cost_per_year}==300.0")
"""

# inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# print('## With out thought token')
# outputs1 = model.generate(
#     **inputs,
#     do_sample=True,
#     max_new_tokens=256,
#     streamer=streamer,
#     eos_token_id=tokenizer.eos_token_id,
# )


def infer_example(model, tokenizer, prompt=prompt, ans="300"):
    """Greedily generate up to 64 tokens after `prompt`, streaming them to stdout.

    Args:
        model: causal LM exposing `.device` and a HF-style `.generate`.
        tokenizer: tokenizer matching `model`; used both to encode the prompt
            and to decode the streamed output.
        prompt: text to continue (defaults to the module-level `prompt`).
        ans: expected answer; currently unused, kept for call compatibility.

    Returns:
        The token-id tensor returned by `model.generate` (prompt ids followed
        by the generated ids).
    """
    # Stream with the tokenizer we were given. The module-level `streamer` is
    # bound to the global `tokenizer`, which need not be the one passed here.
    example_streamer = TextStreamer(tokenizer, skip_prompt=False, skip_special_tokens=False)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    print('---BEGIN---')
    outputs = model.generate(
        **inputs,
        do_sample=False,
        max_new_tokens=64,
        streamer=example_streamer,
        eos_token_id=tokenizer.eos_token_id,
        # Use the tokenizer's pad id instead of generate()'s eos fallback.
        pad_token_id=tokenizer.pad_token_id,
    )
    print('---END---')
    return outputs

infer_example(model, tokenizer)

Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.


---BEGIN---
John cuts his grass to 2 inches. It grows .5 inches per month. When it gets to 4 inches he cuts it back down to 2 inches. It cost $100 to get his grass cut. How much does he pay per year?En el lado, 19999999999999999999999999999999999999999999999999999999999
---END---


tensor([[13079, 15104,   806, 16359,   311,   220,    17, 14924,    13,  1084,
         27715,   659,    20, 14924,   817,  2254,    13,  3197,   432,  5221,
           311,   220,    19, 14924,   566, 15104,   432,  1182,  1495,   311,
           220,    17, 14924,    13,  1084,  2783,   400,    16,    15,    15,
           311,   633,   806, 16359,  3931,    13,  2585,  1753,  1558,   566,
          2291,   817,  1042,    30,  1702,   655, 43424,    11,   220,    16,
            24,    24,    24,    24,    24,    24,    24,    24,    24,    24,
            24,    24,    24,    24,    24,    24,    24,    24,    24,    24,
            24,    24,    24,    24,    24,    24,    24,    24,    24,    24,
            24,    24,    24,    24,    24,    24,    24,    24,    24,    24,
            24,    24,    24,    24,    24,    24,    24,    24,    24,    24,
            24,    24,    24,    24,    24,    24,    24,    24]],
       device='cuda:0')

## Base eval

In [8]:
from opencoconut.evaluate import evaluate, extract_generated_answer
from torch.utils.data import DataLoader

# Validation split built at `current_stage=model.coconut_config.stages`
# (presumably the final curriculum stage); only 64 rows when DEBUG is on.
dataset_val = CoTDataset(
    "casperhansen/gsm8k_synthetic_cot",
    tokenizer,
    split=("valid[:64]" if DEBUG else "valid"),
    max_length=256,
    current_stage=model.coconut_config.stages,
    coconut_config=model.coconut_config,
)
dataset_val[0]  # smoke check: fetch one item so dataset errors surface here
dataloader_val = DataLoader(dataset_val, shuffle=False, batch_size=config['batch_size'])
# Short generations in DEBUG keep evaluation quick.
max_new_tokens = 64 if not DEBUG else 6
# eval final model
# accuracy = evaluate(dataloader_val, tokenizer, model, max_new_tokens)
# print(f"Accuracy: {accuracy}")

### QC

In [9]:
if DEBUG:
    # QC forward, part 1: build one stage-0 training batch and move it to the
    # device reported by get_device().
    clear_memory()
    dataset = CoTDataset(
        "casperhansen/gsm8k_synthetic_cot",
        tokenizer,
        split=f"train[:{config['samples_per_epoch']}]",
        current_stage=0,
        coconut_config=coconut_config,
        max_length=256,  # all less than 256, most < 128
    )
    dl = torch.utils.data.DataLoader(dataset, batch_size=config['batch_size'])
    batch = {name: tensor.to(get_device()) for name, tensor in next(iter(dl)).items()}
    print(batch.keys(), batch['input_ids'].shape)

In [10]:
if DEBUG:
    # QC forward: part 2
    # Run the part-1 batch through the model twice without gradients: first in
    # eval mode (print output keys), then in train mode (print keys, loss shape
    # and loss value).
    with torch.no_grad():
        model.eval()
        o = model(**batch, output_hidden_states=True)
        print(o.keys())
        # Leaves the model in train mode for the training cells that follow.
        model.train()
        o = model(**batch, output_hidden_states=True)
        print(o.keys(), o['loss'].shape)
        print(o['loss'].item())

In [11]:
# if DEBUG:
#     # QC train
#     clear_memory()
#     dataset = CoTDataset(
#         "casperhansen/gsm8k_synthetic_cot",
#         tokenizer,
#         max_length=256, # all less than 256, most < 128
#         coconut_config=coconut_config,
#         current_stage=1,
#         split=f"train[:{config['samples_per_epoch']}]",
#     )
#     trainer = Trainer(
#         model=model,
#         args=training_args,
#         train_dataset=dataset,
#     )
#     trainer.train()

### Train

In [12]:
# Save the untrained (base) model and tokenizer as a reference checkpoint,
# using the same `stage{N}` folder naming as the training cell.
if not DEBUG:
    current_output_dir = output_dir / "stage0" / "checkpoint-base"
    model.save_pretrained(current_output_dir)
    tokenizer.save_pretrained(current_output_dir)

In [13]:
# torch.autograd.set_detect_anomaly(True)

In [None]:

# Build the training dataset for the current curriculum stage.
stage = 0
logger.info("preparing dataset")
dataset = CoTDataset(
    "casperhansen/gsm8k_synthetic_cot",
    tokenizer,
    split=f"train[:{config['samples_per_epoch']}]",
    current_stage=stage,
    coconut_config=coconut_config,
    max_length=256,  # all less than 256, most < 128
)
logger.info(f"dataset size: {len(dataset)}")

# Each stage trains into its own sub-directory of the run's output directory.
current_output_dir = output_dir / f"stage{stage}"
training_args.output_dir = current_output_dir

from transformers import TrainerCallback
class GenExCallback(TrainerCallback):
    """Trainer callback that prints a sample generation while training runs.

    The callback never stores the model, so there is no `self.model`. The
    Trainer passes the model being trained to each event as the `model`
    keyword argument, and it is read from `kwargs` instead. Samples are
    generated only on the local main process.
    """

    def _show_example(self, state, model):
        """Generate the sample prompt with `model` in eval mode, then restore its mode."""
        if model is None or not state.is_local_process_zero:
            return None
        was_training = model.training
        model.eval()
        try:
            return infer_example(model, model.tokenizer)
        finally:
            # Hand the model back to the training loop in the mode it had.
            if was_training:
                model.train()

    def on_epoch_end(self, args, state, control, logs=None, **kwargs):
        self._show_example(state, kwargs.get("model"))

    def on_log(self, args, state, control, logs=None, **kwargs):
        if state.is_local_process_zero:
            self._show_example(state, kwargs.get("model"))
            print(logs)


# Baseline sample before any training, for comparison with callback output.
infer_example(model, model.tokenizer)
logger.info("starting training")
trainer = Trainer(
    model=model,
    train_dataset=dataset,
    args=training_args,
    callbacks=[GenExCallback],  # passing the class is fine; Trainer instantiates it
)
trainer.train()

# Copy the tokenizer into every checkpoint-* folder under the stage directory,
# so each checkpoint directory also contains the tokenizer.
for folder in os.listdir(current_output_dir):
    if not folder.startswith("checkpoint-"):
        continue
    checkpoint_folder = os.path.join(current_output_dir, folder)
    if not os.path.isdir(checkpoint_folder):
        continue
    tokenizer.save_pretrained(checkpoint_folder)

# Save the final model and tokenizer under <stage dir>/checkpoint-final.
# Note: `current_output_dir` is rebound to that final checkpoint path.
current_output_dir = current_output_dir / "checkpoint-final"
model.save_pretrained(current_output_dir)
tokenizer.save_pretrained(current_output_dir)
logger.info(f"finished stage {stage}. Saved to {current_output_dir}")

clear_memory()

# Sample generations during training are printed by GenExCallback.

[32m2025-01-16 09:44:43.763[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mpreparing dataset[0m
[32m2025-01-16 09:44:47.725[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m13[0m - [1mdataset size: 1000[0m


---BEGIN---


Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.


John cuts his grass to 2 inches. It grows .5 inches per month. When it gets to 4 inches he cuts it back down to 2 inches. It cost $100 to get his grass cut. How much does he pay per year?En el lado, 19999999999999999999999999999999999999999999999999999999999


[32m2025-01-16 09:58:29.241[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m35[0m - [1mstarting training[0m


---END---


Step,Training Loss


## Infer (after training)

In [12]:
# NOTE(review): CUDA reads CUDA_LAUNCH_BLOCKING when its context is created.
# The model is already on the GPU by this point, so this likely has no effect.
# To get synchronous kernel launches for debugging, set it before importing
# torch (or in the environment that launches the kernel).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [13]:
# model = model.cpu() # HACK

In [None]:

# add beginning of thought token?
# inputs['input_ids'], inputs['attention_mask'] = model.append_bot_token(inputs['input_ids'], inputs['attention_mask'])

# print('\n## With thought token')
# outputs2 = model.generate(
#     **inputs,
#     do_sample=True,
#     max_new_tokens=64,
#     streamer=streamer,
#     eos_token_id=tokenizer.eos_token_id,
# )

## Eval

In [15]:
from transformers import (
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)
from opencoconut import AutoCoconutForCausalLM, CoTDataset, split_sequences
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

In [16]:
# Ensure left padding for batched generation during evaluation.
# `model.tokenizer` was assigned this same tokenizer object when the model was
# loaded, so the second line only matters if it has been replaced since.
tokenizer.padding_side = "left"
model.tokenizer.padding_side = "left"

In [None]:
# eval final model
# Accuracy of the in-memory `model` on `dataloader_val`, generating at most
# `max_new_tokens` tokens per example.
accuracy = evaluate(dataloader_val, tokenizer, model, max_new_tokens)
print(f"Accuracy: {accuracy}")

In [None]:
# # eval all check, we want to see acc increasing from the first, which is the base model

# # FIXME loading seems broken :(

# checkpoints = sorted(output_dir.glob("stage*/*base")) + sorted(output_dir.glob("stage*/*final"))

# for checkpoint in checkpoints:
#     print(f"Loading checkpoint: {checkpoint}")
#     model = AutoCoconutForCausalLM.from_pretrained(
#         checkpoint, torch_dtype=torch.bfloat16, device_map=get_device()
#     ).eval()
#     model.tokenizer = tokenizer
#     accuracy = evaluate(dataloader_val, tokenizer, model, max_new_tokens)
#     print(f"Checkpoint: {checkpoint}, Accuracy: {accuracy}")
#     clear_memory()