In [1]:
%reload_ext autoreload
%autoreload 2

## Train

In [2]:
import os
import torch
import logging
from transformers import (
    Trainer,
    TrainingArguments,
    AutoTokenizer,
)
from opencoconut import (
    AutoCoconutForCausalLM,
    CoconutConfig,
    CoTDataset,
)
from pathlib import Path
import gc

# Configure logging
from loguru import logger

def clear_memory():
    """Release unused accelerator memory held by PyTorch's caching allocator.

    The garbage collector runs first so that tensors kept alive only by
    reference cycles are actually freed. Only then is the allocator asked to
    return its now-unused cached blocks to the device. Handles CUDA and,
    when present, Apple MPS (which `get_device` prefers).
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # torch.mps.empty_cache exists from torch 2.0; guard for older builds.
    if torch.backends.mps.is_available() and hasattr(torch, "mps"):
        torch.mps.empty_cache()

def get_device():
    """Return the name of the preferred torch device backend.

    Backends are probed in priority order: Apple MPS first, then CUDA.
    Falls back to "cpu" when neither accelerator is available.
    """
    probes = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for backend_name, is_available in probes:
        if is_available():
            return backend_name
    return "cpu"

In [3]:
# Run presets. Each preset writes to its own output_dir so that switching presets
# (or toggling DEBUG) never overwrites another run's checkpoints.
config_debug = {
    'model_name': "Qwen/Qwen2.5-0.5B",
    'batch_size': 4,
    'learning_rate': 4e-4,
    'samples_per_epoch': 12,  # tiny smoke-test run
    'output_dir': "./output/debug",
    'num_epochs': 1,
}
config_small = {
    'model_name': "Qwen/Qwen2.5-0.5B",
    'batch_size': 4,
    'learning_rate': 1e-4,
    'samples_per_epoch': 2000,
    'output_dir': "./output/small",
    'num_epochs': 1,
}
config_medium = {
    # Qwen2.5 base checkpoints come in 0.5B / 1.5B / 3B / 7B / ... -- there is no 2.5B.
    # 1.5B is the next size up from the small preset; switch to 3B if memory allows.
    'model_name': "Qwen/Qwen2.5-1.5B",
    'batch_size': 1,
    'learning_rate': 5e-5,
    'samples_per_epoch': 30000,
    'output_dir': "./output/medium",
    'num_epochs': 3,
}
config = config_small

# DEBUG swaps in the tiny preset and exports DEBUG=1 to the environment
# (presumably read by the opencoconut package -- confirm).
DEBUG = False
if DEBUG:
    config = config_debug
    os.environ["DEBUG"] = "1"

config

{'model_name': 'Qwen/Qwen2.5-0.5B',
 'batch_size': 4,
 'learning_rate': 0.0001,
 'samples_per_epoch': 2000,
 'output_dir': './output/small',
 'num_epochs': 1}

## Load model

In [4]:


# Initialize model and tokenizer
# Defines the globals used by every later cell: output_dir, tokenizer, coconut_config,
# model and training_args.
logger.info("Initializing model and tokenizer")
output_dir = Path(config['output_dir'])

# tokenizer
# Left padding -- presumably for batched generation in the eval cells, which also
# re-assert padding_side="left".
tokenizer = AutoTokenizer.from_pretrained(config['model_name'], padding_side="left")
# Override the special tokens with Qwen's chat markers. These are set before
# CoconutConfig.from_tokenizer below, which receives this tokenizer.
tokenizer.bos_token = "<|im_start|>"
tokenizer.eos_token = "<|im_end|>"
tokenizer.pad_token = "<|endoftext|>"

# config and model
# FIXME we won't need custom tokens anymore
coconut_config = CoconutConfig.from_tokenizer(
    tokenizer,
    stages=4,  # number of curriculum stages; the training loop iterates range(stages)
    continuous_thoughts=2,  # presumably latent thoughts per replaced step -- confirm in opencoconut
)
model = AutoCoconutForCausalLM.from_pretrained(
    config['model_name'], coconut_config, torch_dtype=torch.bfloat16, device_map=get_device()
)
# Keep the embedding matrix sized to the tokenizer's vocabulary.
# NOTE(review): presumably needed in case from_tokenizer added tokens -- confirm.
model.resize_token_embeddings(len(tokenizer))
# The Coconut model keeps a reference to the same tokenizer object.
model.tokenizer = tokenizer

# Set up training arguments
# NOTE(review): per the transformers docs a positive max_steps overrides
# num_train_epochs, so any per-stage num_train_epochs set in the training loop is
# ignored unless max_steps is reset there.
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=config['batch_size'],
    gradient_accumulation_steps=1,
    learning_rate=config['learning_rate'],
    warmup_ratio=0.1,
    # num_epochs passes over samples_per_epoch examples at batch_size per step
    max_steps=config['samples_per_epoch']//config['batch_size']*config['num_epochs'],
    logging_steps=100, # TODO ideally we log to tensorboard every step, but to ui every 100 steps
    save_steps=10000,
    bf16=True,
    bf16_full_eval=True,
    optim="adamw_torch", # save memory: adamw_bnb_8bit
)


2025-01-15 20:53:48.708 | INFO     | __main__:<module>:2 - Initializing model and tokenizer
Some weights of CoconutQwen2ForCausalLM were not initialized from the model checkpoint at Qwen/Qwen2.5-0.5B and are newly initialized: ['switch.0.bias', 'switch.0.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Base eval

In [5]:
from opencoconut.evaluate import evaluate, extract_generated_answer
from torch.utils.data import DataLoader

# Load val data
# NOTE(review): the validation set is built with current_stage=coconut_config.stages,
# while the training loop only runs stages 0..stages-1 -- confirm this is the intended
# evaluation stage in CoTDataset.
dataset_val = CoTDataset(
    "casperhansen/gsm8k_synthetic_cot",
    tokenizer,
    max_length=256,
    coconut_config=model.coconut_config,
    current_stage=model.coconut_config.stages,
    split="valid[:64]" if DEBUG else "valid",  # small slice in DEBUG for speed
)
# shuffle=False keeps the evaluation order stable between runs.
dataloader_val = DataLoader(dataset_val, batch_size=config['batch_size'], shuffle=False)
# Generation budget for evaluate(); tiny in DEBUG to keep smoke runs fast.
max_new_tokens = 6 if DEBUG else 64
# eval base model (before any Coconut training) -- currently disabled; uncomment to
# record the baseline accuracy.
# accuracy = evaluate(dataloader_val, tokenizer, model, max_new_tokens)
# print(f"Accuracy: {accuracy}")

### QC: forward-pass sanity checks (run only when DEBUG is True)

In [None]:
if DEBUG:
    # QC forward: part 1
    # Build one stage-1 training batch and move it to the device, for the forward
    # checks in part 2.
    clear_memory()
    dataset = CoTDataset(
        "casperhansen/gsm8k_synthetic_cot",
        tokenizer,
        max_length=256, # all less than 256, most < 128
        coconut_config=coconut_config,
        current_stage=1,
        split=f"train[:{config['samples_per_epoch']}]",
    )
    dl = torch.utils.data.DataLoader(dataset, batch_size=config['batch_size'])
    batch = next(iter(dl))
    # Move every batch entry to the device the model was loaded on (get_device()).
    # NOTE(review): assumes every value in the batch is a tensor -- confirm against CoTDataset.
    batch = {k: v.to(get_device()) for k, v in batch.items()}
    print(batch.keys(), batch['input_ids'].shape)

In [None]:
if DEBUG:
    # QC forward: part 2
    # For stages 0-2, run the part-1 batch through the model first in eval mode and
    # then in train mode, printing the output keys and the loss.
    # Everything runs under no_grad: this only exercises the forward pass, no backward.
    # NOTE(review): the stages are hard-coded to [0, 1, 2] while coconut_config uses
    # stages=4. The model is left in train mode with current_stage=2; the training loop
    # resets current_stage for each stage.
    for i in [0, 1, 2]:
        model.current_stage = i
        print(f"Stage {i}")
        with torch.no_grad():
            model.eval()
            o = model(**batch, output_hidden_states=True)
            print(o.keys())
            model.train()
            o = model(**batch, output_hidden_states=True)
            # .item() below only succeeds if the loss is a single-element tensor.
            print(o.keys(), o['loss'].shape)
            print(o['loss'].item())

In [8]:
# if DEBUG:
#     # QC train
#     clear_memory()
#     dataset = CoTDataset(
#         "casperhansen/gsm8k_synthetic_cot",
#         tokenizer,
#         max_length=256, # all less than 256, most < 128
#         coconut_config=coconut_config,
#         current_stage=1,
#         split=f"train[:{config['samples_per_epoch']}]",
#     )
#     trainer = Trainer(
#         model=model,
#         args=training_args,
#         train_dataset=dataset,
#     )
#     trainer.train()

### Train

In [9]:
# save base checkpoint
# Snapshot the pre-training model (base weights plus the freshly initialized Coconut
# layers) under stage0/checkpoint-base, alongside the stage-0 training outputs.
if not DEBUG:
    current_output_dir = output_dir / "stage0" / "checkpoint-base"
    model.save_pretrained(current_output_dir)
    tokenizer.save_pretrained(current_output_dir)

In [10]:
# torch.autograd.set_detect_anomaly(True)

In [None]:

# Curriculum training: one Trainer run per Coconut stage, each writing to
# output_dir/stage{N}. The same model object is trained in place across stages.
for stage in range(coconut_config.stages):

    # Drop references to the previous stage's dataset/trainer so their memory can be freed.
    dataset = trainer = None
    clear_memory()

    logger.info(f"starting stage {stage}")
    logger.info("preparing dataset")
    dataset = CoTDataset(
        "casperhansen/gsm8k_synthetic_cot",
        tokenizer,
        max_length=256, # all less than 256, most < 128
        coconut_config=coconut_config,
        current_stage=stage,
        split=f"train[:{config['samples_per_epoch']}]",
    )
    logger.info(f"dataset size: {len(dataset)}")
    model.current_stage = stage
    current_output_dir = output_dir/f"stage{stage}"
    training_args.output_dir = current_output_dir

    # Epoch budget for this stage: the base budget unless a special stage overrides it.
    # Stage 0 always uses the base budget, even when it is also the penultimate/final stage.
    stage_epochs = config['num_epochs']
    if stage > 0 and stage == coconut_config.stages - 2:
        # Penultimate stage removes all the remaining language reasoning chain
        # This handles the long-tail distribution of reasoning chains longer than 3 steps
        dataset.include_reasoning_steps = False
    elif stage > 0 and stage == coconut_config.stages - 1:
        # For all datasets, after the standard schedule,
        # the model stays in the final training stage, until the 50th epoch.
        # (approximated here as 3x the base epoch budget)
        dataset.include_reasoning_steps = True
        stage_epochs = config['num_epochs'] * 3

    # A positive max_steps (set when training_args was built) overrides
    # num_train_epochs in the HF Trainer. Disable it so the per-stage epoch budget
    # above actually controls how long each stage trains.
    training_args.max_steps = -1
    training_args.num_train_epochs = stage_epochs

    logger.info("starting training")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()

    # The Trainer is not given the tokenizer, so copy it into every intermediate
    # checkpoint to make each one loadable on its own. glob() yields nothing (instead
    # of raising) if the stage directory was never created.
    for checkpoint_folder in current_output_dir.glob("checkpoint-*"):
        if checkpoint_folder.is_dir():
            tokenizer.save_pretrained(checkpoint_folder)

    # save final checkpoint
    final_output_dir = current_output_dir/"checkpoint-final"
    model.save_pretrained(final_output_dir)
    tokenizer.save_pretrained(final_output_dir)
    logger.info(f"finished stage {stage}. Saved to {final_output_dir}")

    clear_memory()

## Infer

In [12]:
# Debugging aid: ask CUDA to launch kernels synchronously so device-side errors are
# reported at the offending call.
# NOTE(review): CUDA reads this variable when its context is created. The model was
# already placed on get_device() in the load-model cell, so this likely has no effect
# in the current kernel -- set it before any CUDA work (or restart) to be sure.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [21]:
# HACK: move the model to CPU for the inference and eval cells that follow.
# NOTE(review): the reason is not recorded; presumably a workaround for a device-side
# error (CUDA_LAUNCH_BLOCKING is set just before). The model was loaded in bfloat16, so
# CPU inference is likely slow -- confirm this is still needed.
model = model.cpu() # HACK

In [None]:
from transformers import AutoTokenizer, TextStreamer
# Stream generated text to stdout as it is produced. The prompt and special tokens are
# kept visible so control/thought tokens can be inspected.
streamer = TextStreamer(tokenizer, skip_prompt=False, skip_special_tokens=False)

# Word problem used as a qualitative smoke test of the trained model.
prompt = "John cuts his grass to 2 inches. " \
         "It grows .5 inches per month. " \
         "When it gets to 4 inches he cuts it back down to 2 inches. " \
         "It cost $100 to get his grass cut. How much does he pay per year?"

# Reference solution for the prompt (expected answer 300.0), kept for the reader only.
# It is not executed or used below.
ans = """
# since it starts at 2 and never gets cut below 2, we can consider only the extra growth
growth_annual = 0.5*12
cost_per_inch = 100/2
cuts = growth_annual // 2 # round it down
cost_per_year = growth_annual * cost_per_inch
print(f"cost per year: {cost_per_year}==300.0")
"""

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Sampling is stochastic; seed it so re-running the cell reproduces the same output.
torch.manual_seed(0)

print('## With out thought token')
outputs1 = model.generate(
    **inputs,
    do_sample=True,
    max_new_tokens=64,
    streamer=streamer,
    eos_token_id=tokenizer.eos_token_id,
)

# add beginning of thought token?
# inputs['input_ids'], inputs['attention_mask'] = model.append_bot_token(inputs['input_ids'], inputs['attention_mask'])

# print('\n## With thought token')
# outputs2 = model.generate(
#     **inputs,
#     do_sample=True,
#     max_new_tokens=64,
#     streamer=streamer,
#     eos_token_id=tokenizer.eos_token_id,
# )

## Eval

In [22]:
from transformers import (
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)
from opencoconut import AutoCoconutForCausalLM, CoTDataset, split_sequences
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

In [23]:
# Ensure left padding before batched generation in the eval cells.
# model.tokenizer was assigned this same tokenizer object in the load-model cell, so
# the second line is redundant unless model.tokenizer has since been replaced.
tokenizer.padding_side = "left"
model.tokenizer.padding_side = "left"

In [None]:
# eval final model
# Evaluates whatever `model` currently holds: after the training loop, the weights from
# the last curriculum stage (moved to CPU above). Uses the validation dataloader and
# max_new_tokens defined in the Base eval section.
accuracy = evaluate(dataloader_val, tokenizer, model, max_new_tokens)
print(f"Accuracy: {accuracy}")

In [None]:
# What if we explicitly start with the beginning-of-thought (BoT) token?
# Same evaluation as above, but with add_bot=True -- presumably prepends the BoT token
# to each prompt; confirm in opencoconut.evaluate.
accuracy = evaluate(dataloader_val, tokenizer, model, max_new_tokens, add_bot=True)
print(f"Accuracy: {accuracy}")

In [None]:
# # eval all check, we want to see acc increasing from the first, which is the base model

# # FIXME loading seems broken :(

# checkpoints = sorted(output_dir.glob("stage*/*base")) + sorted(output_dir.glob("stage*/*final"))

# for checkpoint in checkpoints:
#     print(f"Loading checkpoint: {checkpoint}")
#     model = AutoCoconutForCausalLM.from_pretrained(
#         checkpoint, torch_dtype=torch.bfloat16, device_map=get_device()
#     ).eval()
#     model.tokenizer = tokenizer
#     accuracy = evaluate(dataloader_val, tokenizer, model, max_new_tokens)
#     print(f"Checkpoint: {checkpoint}, Accuracy: {accuracy}")
#     clear_memory()