In [1]:
%reload_ext autoreload
%autoreload 2

## Train

In [2]:
import os
import torch
import logging
from transformers import (
    Trainer,
    TrainingArguments,
    AutoTokenizer,
)
from opencoconut import (
    AutoCoconutForCausalLM,
    CoconutConfig,
    CoTDataset,
)
from pathlib import Path
import gc

# Configure logging
from loguru import logger

def clear_memory():
    torch.cuda.empty_cache()
    gc.collect()

def get_device():
    if torch.backends.mps.is_available():
        return "mps"
    elif torch.cuda.is_available():
        return "cuda"
    else:
        return "cpu"

In [3]:
config_debug = {
    'model_name': "Qwen/Qwen2.5-0.5B",
    'batch_size': 12,
    'learning_rate': 4e-4,
    'samples_per_epoch': 48,
    'output_dir': "./output/small",
    'num_epochs': 1,
}
config_small = {
    'model_name': "Qwen/Qwen2.5-0.5B",
    'batch_size': 12,
    'learning_rate': 1e-4,
    'samples_per_epoch': 2000,
    'output_dir': "./output/small",
    'num_epochs': 3,
}
config_medium = {
    'model_name': "Qwen/Qwen2.5-2.5B",
    'batch_size': 1,
    'learning_rate': 5e-5,
    'samples_per_epoch': 30000,
    'output_dir': "./output/small",
    'num_epochs': 3,
}
config = config_small

DEBUG = True
if DEBUG:
    config = config_debug
    os.environ["DEBUG"] = "1"

config

{'model_name': 'Qwen/Qwen2.5-0.5B',
 'batch_size': 12,
 'learning_rate': 0.0004,
 'samples_per_epoch': 48,
 'output_dir': './output/small',
 'num_epochs': 1}

In [4]:


# Initialize model and tokenizer
logger.info("Initializing model and tokenizer")
output_dir = Path(config['output_dir'])

# tokenizer
tokenizer = AutoTokenizer.from_pretrained(config['model_name'])
tokenizer.bos_token = "<|im_start|>"
tokenizer.eos_token = "<|im_end|>"
tokenizer.pad_token = "<|endoftext|>"

# config and model
coconut_config = CoconutConfig.from_tokenizer(
    tokenizer,
    stages=4,
    continuous_thoughts=2,
)
model = AutoCoconutForCausalLM.from_pretrained(
    config['model_name'], coconut_config, torch_dtype=torch.bfloat16, device_map=get_device()
)
model.resize_token_embeddings(len(tokenizer))
if os.getenv("DEBUG", "0") == "1":
    model.tokenizer = tokenizer

# Set up training arguments
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=config['batch_size'],
    gradient_accumulation_steps=1,
    learning_rate=config['learning_rate'],
    warmup_ratio=0.1,
    max_steps=config['samples_per_epoch']//config['batch_size']*config['num_epochs'],
    logging_steps=100, # TODO ideally we log to tensorboard every step, but to ui every 100 steps
    save_steps=10000,
    bf16=True,
    bf16_full_eval=True,
    optim="adamw_torch", # save memory: adamw_bnb_8bit
)


[32m2025-01-14 20:23:30.528[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mInitializing model and tokenizer[0m


In [5]:
# save base checkpoint
if not DEBUG:
    current_output_dir = output_dir/"stage0}"
    current_output_dir = current_output_dir/"checkpoint-base"
    model.save_pretrained(current_output_dir)
    tokenizer.save_pretrained(current_output_dir)

In [6]:
# QC test
clear_memory()
dataset = CoTDataset(
    "casperhansen/gsm8k_synthetic_cot",
    tokenizer,
    max_length=256, # all less than 256, most < 128
    coconut_config=coconut_config,
    current_stage=1,
    split=f"train[:{config['samples_per_epoch']}]",
)
dl = torch.utils.data.DataLoader(dataset, batch_size=config['batch_size'])
batch = next(iter(dl))
batch = {k: v.to(get_device()) for k, v in batch.items()}
print(batch.keys(), batch['input_ids'].shape)

dict_keys(['input_ids', 'attention_mask', 'labels']) torch.Size([12, 256])


In [None]:
# QC test
for i in [0, 1, 2]:
    model.current_stage = i
    print(f"Stage {i}")
    with torch.no_grad():
        # model.eval()
        # o = model(**batch, output_hidden_states=True)
        # print(o.keys())
        model.train()
        o = model(**batch, output_hidden_states=True)
        print(o.keys(), o['loss'].shape)
        print(o['loss'].item())

Stage 0
odict_keys(['loss', 'logits', 'past_key_values', 'hidden_states']) torch.Size([])
3.0440244674682617
Stage 1
kv0 torch.Size([12, 2, 42, 64])
kv.0-a torch.Size([12, 2, 42, 64]) 42


In [None]:
o.hidden_states

In [None]:
# QC test
clear_memory()
dataset = CoTDataset(
    "casperhansen/gsm8k_synthetic_cot",
    tokenizer,
    max_length=256, # all less than 256, most < 128
    coconut_config=coconut_config,
    current_stage=1,
    split=f"train[:{config['samples_per_epoch']}]",
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()

In [None]:

# Initialize trainer
for stage in range(coconut_config.stages):
    clear_memory()
    logger.info(f"starting stage {stage}")
    logger.info("preparing dataset")
    dataset = CoTDataset(
        "casperhansen/gsm8k_synthetic_cot",
        tokenizer,
        max_length=256, # all less than 256, most < 128
        coconut_config=coconut_config,
        current_stage=stage,
        split=f"train[:{config['samples_per_epoch']}]",
    )
    logger.info(f"dataset size: {len(dataset)}")
    model.current_stage = stage
    current_output_dir = output_dir/f"stage{stage}"
    training_args.output_dir = current_output_dir

    if stage == 0:
        training_args.num_train_epochs = config['num_epochs']
    elif stage == coconut_config.stages-2:
        # Penultimate stage removes all the remaining language reasoning chain
        # This handles the long-tail distribution of reasoning chains longer than 3 steps
        dataset.include_reasoning_steps = False
        training_args.num_train_epochs = config['num_epochs']
    elif stage == coconut_config.stages-1:
        # For all datasets, after the standard schedule,
        # the model stays in the final training stage, until the 50th epoch.
        dataset.include_reasoning_steps = True
        training_args.num_train_epochs = config['num_epochs'] * 3

    logger.info("starting training")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()

    # save tokenizer to all checkpoints after training
    for folder in os.listdir(current_output_dir):
        if folder.startswith("checkpoint-"):
            checkpoint_folder = os.path.join(current_output_dir, folder)
            if os.path.isdir(checkpoint_folder):
                tokenizer.save_pretrained(checkpoint_folder)

    # save final checkpoint
    current_output_dir = current_output_dir/"checkpoint-final"
    model.save_pretrained(current_output_dir)
    tokenizer.save_pretrained(current_output_dir)
    logger.info(f"finished stage {stage}. Saved to {current_output_dir}")

    clear_memory()

## Infer

In [None]:
from transformers import AutoTokenizer, TextStreamer
streamer = TextStreamer(tokenizer, skip_prompt=False, skip_special_tokens=False)

prompt = "John cuts his grass to 2 inches. " \
         "It grows .5 inches per month. " \
         "When it gets to 4 inches he cuts it back down to 2 inches. " \
         "It cost $100 to get his grass cut. How much does he pay per year?"

ans = """
# since it starts at 2 and never gets cut below 2, we can consider only the extra growth
growth_annual = 0.5*12
cost_per_inch = 100/2
cuts = growth_annual // 2 # round it down
cost_per_year = growth_annual * cost_per_inch
print(f"cost per year: {cost_per_year}==300.0")
"""

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

print('## With out thought token')
outputs1 = model.generate(
    **inputs,
    do_sample=True,
    max_new_tokens=64,
    streamer=streamer,
    eos_token_id=tokenizer.eos_token_id,
)

# add beginning of thought token?
inputs['input_ids'], inputs['attention_mask'] = model.append_bot(inputs['input_ids'], inputs['attention_mask'])

print('\n## With thought token')
outputs2 = model.generate(
    **inputs,
    do_sample=True,
    max_new_tokens=64,
    streamer=streamer,
    eos_token_id=tokenizer.eos_token_id,
)

## Eval

In [7]:
from transformers import (
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)
from opencoconut import AutoCoconutForCausalLM, CoTDataset, split_sequences
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

In [41]:

def extract_generated_answer(model_output: str, eos_token="<|im_end|>"):
    answer_prefix = "Answer: "
    start_index = model_output.find(answer_prefix)

    if start_index == -1:
        return None

    start_index += len(answer_prefix)
    end_index = model_output.find(eos_token, start_index)

    if end_index == -1:
        return None

    extracted_answer = model_output[start_index:end_index].strip()
    return extracted_answer

def pp(s):
    s = s.replace(tokenizer.eos_token, '')
    s = s.replace(tokenizer.bos_token, '')
    s = s.replace(tokenizer.pad_token, '')
    return s

@torch.no_grad()
def evaluate(
    dataloader,
    tokenizer: PreTrainedTokenizer,
    model: PreTrainedModel,
    max_new_tokens: int,
    verbose = 1,
    add_bot = False
):
    total_instances = 0
    total_correct = 0
    for batch in tqdm(dataloader):
        if add_bot:
            batch["input_ids"], batch["attention_mask"] = model.append_bot(batch["input_ids"], batch["attention_mask"])

        (
            thought_ids,
            language_ids,
            thought_mask,
            _,
            _,
            _,
        ) = split_sequences(**batch, coconut_config=model.coconut_config)
        batch_size = thought_ids.shape[0]
        total_instances += batch_size

        # Generate
        beam_output = model.generate(
            input_ids=thought_ids.to(model.device),
            attention_mask=thought_mask.to(model.device),
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
        # Evaluate
        for thought_ids_batch, output_batch in zip(thought_ids, beam_output):
            decoded_language_ids = tokenizer.decode(language_ids[0])
            decoded_pred_text = tokenizer.decode(output_batch)
            answer = extract_generated_answer(
                decoded_language_ids, eos_token=tokenizer.eos_token
            )
            pred_answer = extract_generated_answer(
                decoded_pred_text, eos_token=tokenizer.eos_token
            )
            if answer == pred_answer:
                total_correct += 1
            if verbose>1:
                print(
                    f"Input: {pp(tokenizer.decode(thought_ids_batch, skip_special_tokens=True))}\n"
                    f"decoded_language_ids: {pp(decoded_language_ids)}\n"
                    f"decoded_pred_text: {pp(decoded_pred_text)}\n"
                    f"Target: {pp(answer)}\n"
                    f"Predicted: {pp(pred_answer)}\n"
                )
    if verbose>0:
        print(
            f"Input: {pp(tokenizer.decode(thought_ids_batch, skip_special_tokens=True))}\n"
            f"decoded_language_ids: {pp(decoded_language_ids)}\n"
            f"decoded_pred_text: {pp(decoded_pred_text)}\n"
            f"Target: {pp(answer)}\n"
            f"Predicted: {pp(pred_answer)}\n"
        )
    accuracy = total_correct / total_instances
    return accuracy

In [42]:
tokenizer.padding_side = "left"

In [None]:
max_new_tokens = 256

# Load data
dataset = CoTDataset(
    "casperhansen/gsm8k_synthetic_cot",
    tokenizer,
    max_length=max_new_tokens,
    coconut_config=model.coconut_config,
    current_stage=model.coconut_config.stages,
    split="valid",
)
dataloader = DataLoader(dataset, batch_size=config['batch_size'], shuffle=False)

# eval final model
accuracy = evaluate(dataloader, tokenizer, model, max_new_tokens)
print(f"Accuracy: {accuracy}")

In [None]:
# what if we explicitly start of BoThough token
accuracy = evaluate(dataloader, tokenizer, model, max_new_tokens, add_bot=True)
print(f"Accuracy: {accuracy}")

In [None]:
# eval all check

checkpoints = sorted(output_dir.glob("stage*/*base")) + sorted(output_dir.glob("stage*/*final"))

for checkpoint in checkpoints:
    print(f"Loading checkpoint: {checkpoint}")
    model = AutoCoconutForCausalLM.from_pretrained(
        checkpoint, torch_dtype=torch.bfloat16, device_map=get_device()
    ).eval()
    model.tokenizer = tokenizer
    accuracy = evaluate(dataloader, tokenizer, model, max_new_tokens)
    print(f"Checkpoint: {checkpoint}, Accuracy: {accuracy}")
    clear_memory()