# Our First Step: Run the Original Model

We mainly use `deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B` in our experiments, a reasoning model distilled from DeepSeek-R1 onto Qwen2.5-Math-1.5B.
We plan to apply Coconut (Chain of Continuous Thought) to this distilled model and compare its performance with the original chain-of-thought (CoT) reasoning.

## Import Dependencies and the Model

In [7]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

device = torch.device('mps')

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", device_map=device, torch_dtype=torch.float16)

## Load Datasets

In [5]:
from datasets import load_dataset

# Login using e.g. `huggingface-cli login` to access this dataset
ds = load_dataset("open-r1/OpenThoughts-114k-math")

ds['train']

Dataset({
    features: ['source', 'problem', 'solution', 'messages', 'system', 'conversations', 'generated_token_count', 'correct'],
    num_rows: 89120
})

## Modify the CoT Chain to use `<sot>` and `<eot>`

In DeepSeek R1, the chain of thought is wrapped in `<think>` and `</think>` tags, but we will use the special tokens `<sot>` and `<eot>` instead.

That's because Coconut needs a dedicated marker to tell whether the chain has terminated. With ordinary text CoT, we can simply parse the generated text and look for `</think>`. That does not carry over to Coconut, where the reasoning steps are continuous hidden states rather than decoded text.

Hence, we will use `<sot>` and `<eot>` to mark the start and end of the chain.
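To make the role of the two markers concrete, here is a minimal sketch of locating the thought span in a sequence of token IDs. The IDs `SOT_ID` and `EOT_ID` and the helper `find_thought_span` are hypothetical and used only for illustration; in practice the IDs would come from the tokenizer after the special tokens are added.

```python
from typing import List, Optional, Tuple

# Hypothetical token IDs, for illustration only; real IDs come from the tokenizer.
SOT_ID = 9001
EOT_ID = 9002


def find_thought_span(
    token_ids: List[int], sot_id: int = SOT_ID, eot_id: int = EOT_ID
) -> Optional[Tuple[int, int]]:
    """Return (start, end) slice bounds of the tokens strictly between
    the first <sot> and the next <eot>, or None if either marker is missing."""
    try:
        start = token_ids.index(sot_id) + 1
        end = token_ids.index(eot_id, start)
    except ValueError:
        return None
    return start, end


ids = [11, 12, SOT_ID, 21, 22, 23, EOT_ID, 31]
span = find_thought_span(ids)
thought = ids[span[0]:span[1]] if span is not None else []
```

Because each marker is a single token, the check reduces to an integer comparison and does not depend on decoding any text.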

## Add Special Tokens

In [8]:
# SOT: Start of Thought, EOT: End of Thought, SOS: Start of Solution, EOS: End of Solution
special_tokens_dict = {'additional_special_tokens': ['<sot>', '<eot>', '<sos>', '<eos>']}
tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))

Embedding(151667, 1536)

### Preprocess the Dataset

We will preprocess the dataset to add `<sot>` and `<eot>` to the CoT chain.

In [11]:
import re

ft_ds = load_dataset('ServiceNow-AI/R1-Distill-SFT', 'v1')

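def replace_tags(example):
    text = example['reannotated_assistant_content']
    # re.DOTALL lets `.` match newlines, so multi-line reasoning blocks are rewritten too
    text = re.sub(r'<think>(.*?)</think>', r'<sot>\1<eot>', text, flags=re.DOTALL)
    # Note: the non-greedy match stops at the first `}`, so nested braces
    # such as \boxed{\frac{1}{2}} are cut short by this pattern
    text = re.sub(r'\\boxed\{(.*?)\}', r'<sos>\1<eos>', text)
    example['reannotated_assistant_content'] = text
    return example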


ft_ds = ft_ds.map(replace_tags)

Resolving data files:   0%|          | 0/52 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/51 [00:00<?, ?it/s]

Map:   0%|          | 0/1679162 [00:00<?, ? examples/s]
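A regular expression with a non-greedy group cannot match nested braces, so an answer like `\boxed{\frac{1}{2}}` would be cut at the first `}`. One possible alternative, sketched below, is to scan for the matching closing brace by counting depth. The helper `replace_boxed` is hypothetical and not part of the notebook's pipeline; it assumes braces inside the answer are balanced and leaves the text untouched when they are not.

```python
def replace_boxed(text: str, open_tag: str = "<sos>", close_tag: str = "<eos>") -> str:
    """Replace every \\boxed{...} with open_tag ... close_tag, honoring nested braces."""
    marker = "\\boxed{"
    out = []
    i = 0
    while True:
        start = text.find(marker, i)
        if start == -1:
            out.append(text[i:])  # no more \boxed{ occurrences
            break
        depth = 1
        j = start + len(marker)
        # Walk forward until the brace opened by \boxed{ is closed.
        while j < len(text) and depth > 0:
            if text[j] == "{":
                depth += 1
            elif text[j] == "}":
                depth -= 1
            j += 1
        if depth != 0:
            out.append(text[i:])  # unbalanced braces: keep the rest as-is
            break
        out.append(text[i:start])
        out.append(open_tag + text[start + len(marker):j - 1] + close_tag)
        i = j
    return "".join(out)


example = replace_boxed(r"Answer: \boxed{\frac{1}{2}}.")
```

Such a function could be called inside `replace_tags` in place of the second `re.sub`, at the cost of a slower pure-Python scan.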

Because the dataset is large (about 1.68 million examples), we take a random subset of 1,024 examples to test the training setup.

In [24]:
train_subset = ft_ds['train'].shuffle(seed=42).select(range(1024))

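def tokenize_function(example):
    # Note: padding every example to 32,768 tokens is very memory-hungry
    tokens = tokenizer(
        example["reannotated_assistant_content"],
        truncation=True,
        padding="max_length",
        max_length=32768
    )
    # The causal-LM loss needs labels: copy input_ids and set padding to -100 so it is ignored
    tokens["labels"] = [
        [tok if keep else -100 for tok, keep in zip(ids, mask)]
        for ids, mask in zip(tokens["input_ids"], tokens["attention_mask"])
    ]
    # No manual device transfer here: the Trainer moves each batch to the device
    return tokens

tokenized_ds = train_subset.map(tokenize_function, batched=True)
tokenized_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])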

Map:   0%|          | 0/1024 [00:00<?, ? examples/s]

### Try to train the model

In [27]:
from transformers import TrainingArguments, Trainer
import gc
import torch

gc.collect()
torch.mps.empty_cache()

training_args = TrainingArguments(
    output_dir="./stage_1",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    num_train_epochs=3,  # 3 passes over the 1,024-example subset
    learning_rate=5e-5,
    warmup_steps=50,
    weight_decay=0.01,
    save_strategy="epoch",
    logging_dir="./logs"
)

class MPS_Trainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        inputs = {k: v.to(device) for k, v in inputs.items()}  # Move inputs to MPS
        outputs = model(**inputs)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss

trainer = MPS_Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds,
    tokenizer=tokenizer
)

trainer.train()


RuntimeError: MPS backend out of memory (MPS allocated: 16.22 GB, other allocations: 816.00 KB, max allowed: 18.13 GB). Tried to allocate 2.00 GB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).
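A rough back-of-the-envelope estimate hints at why this run exhausts memory. The sketch below computes only the size of the output logits for one padded sequence, using the vocabulary size from the resized embedding (151667), the `max_length` used above (32768), and 2 bytes per value for `float16`. It assumes the logits for all positions are materialized at once; weights, activations, gradients, and optimizer state come on top, so the real peak is higher. The helper `logits_bytes` and the 2,048-token comparison length are illustrative assumptions, not settings from this notebook.

```python
def logits_bytes(batch_size: int, seq_len: int, vocab_size: int, bytes_per_value: int) -> int:
    """Size in bytes of a (batch_size, seq_len, vocab_size) logits tensor."""
    return batch_size * seq_len * vocab_size * bytes_per_value


GIB = 1024 ** 3
VOCAB_SIZE = 151667   # from the resized embedding shown earlier
FP16_BYTES = 2        # torch.float16

full = logits_bytes(1, 32768, VOCAB_SIZE, FP16_BYTES)   # padded to max_length=32768
short = logits_bytes(1, 2048, VOCAB_SIZE, FP16_BYTES)   # hypothetical shorter max_length

print(f"logits at 32768 tokens: {full / GIB:.2f} GiB")
print(f"logits at 2048 tokens:  {short / GIB:.2f} GiB")
```

Under these assumptions, the logits alone take more than half of the 18.13 GB limit reported in the error, which suggests that a shorter `max_length` or padding only to the longest sequence in each batch would be worth trying.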