# Eval coconut

In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from coconut import Coconut
from dataset import (
    get_dataset,
    get_question_latent_dataset,
    get_cot_latent_dataset,
    MyCollator,
    collate_and_add_latent
)
from tqdm import tqdm
import os
import json
from collections import defaultdict
from utils import Config, set_seed
import numpy as np
torch.set_grad_enabled(False)

set_seed(0)

In [2]:
coconut = True
cot=False
no_thoughts=False
no_cot = False
device = 'cuda'
model_path = "meta-llama/Llama-3.2-1B"
gen = False # if want to test the generalizing model.
load_from_hf = True

if load_from_hf:
    model_path = "weijie210/Llama-3.2-1B_cladder_6"
model = AutoModelForCausalLM.from_pretrained(model_path,device_map = device)
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token

tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_tokens("<|start-latent|>")
tokenizer.add_tokens("<|end-latent|>")
tokenizer.add_tokens("<|latent|>")
latent_id = tokenizer.convert_tokens_to_ids("<|latent|>")
start_id = tokenizer.convert_tokens_to_ids("<|start-latent|>")
end_id = tokenizer.convert_tokens_to_ids("<|end-latent|>")

if gen:
    if coconut:
        load_model_path = "checkpoint/cladder-coconut_gen_llama_1B/checkpoint_8" # use checkpoint_5 for latent 5, 1 natural langauge step
    else:
        load_model_path = "checkpoint/cladder-cot_gen_llama_1B/checkpoint_3"
else:
    if coconut:
        load_model_path = "checkpoint/cladder-coconut_llama_1B/checkpoint_8"
    else:
        load_model_path = "checkpoint/cladder-cot_llama_1B/checkpoint_12"


if coconut:
    if not load_from_hf:
        model.resize_token_embeddings(len(tokenizer))
        embeddings = model.get_input_embeddings()
        target_id = tokenizer.convert_tokens_to_ids("<<")
        # initialize the new token embeddings with a known token
        # it helps stablize the training
        for token_id in [latent_id, start_id, end_id]:
            target_embedding = embeddings.weight.data[token_id]
            embeddings.weight.data[token_id] = target_embedding
            # The input embeddings and lm heads are tied in GPT2. So the code below is not necessary
            lm_head = model.lm_head
            lm_head.weight.data[token_id] = lm_head.weight.data[target_id]

    # load model
    model = Coconut(model, latent_id, start_id, end_id, tokenizer.eos_token_id)
if not load_from_hf:
    saved_weights = torch.load(load_model_path, map_location=device)
    model.load_state_dict(saved_weights, strict=False)


config.json:   0%|          | 0.00/883 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/180 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/51.1k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/335 [00:00<?, ?B/s]

In [3]:
if coconut:
    num_latent_tokens = 6
    c_thought = 1
else:
    num_latent_tokens = 0
    c_thought = 0

In [4]:
if gen:
    test_path = "data/cladder_test_gen.json" 
else:
    test_path = "data/cladder_test.json"
n_samples = 300
ds = get_dataset(test_path, tokenizer)
ds = ds.select(range(n_samples))
original_ds = json.load(open(test_path))
original_ds = original_ds[:n_samples]
configs = Config({'pad_latent_to_max':True,'max_latent_stage':num_latent_tokens,'c_thought':c_thought})
test_ds = get_question_latent_dataset(
        num_latent_tokens,
        ds,
        configs,
        start_id,
        latent_id,
        end_id,
        no_special_marker= not coconut,
    )

Map (num_proc=32):   0%|          | 0/300 [00:00<?, ? examples/s]

Map (num_proc=32):   0%|          | 0/300 [00:00<?, ? examples/s]

Get base performance

In [5]:

def get_perf(ds,compare_step=False):
    acc = []
    acc_step = []
    for i,batch in tqdm(enumerate(ds),total=len(ds)):
        idx = batch['idx']
        step = original_ds[idx]['steps']
        last_step = step[-1].split(',')[-1].strip()
        ans = original_ds[idx]['answer'].strip().lower()
        ans_id = tokenizer.encode(' '+ ans, add_special_tokens=False)[0] # add space to the start of the answer
        batch = {k: torch.tensor(v).to(device).unsqueeze(0) for k, v in batch.items() if v != None and k not in ["idx", "position_ids"]}
        if coconut:
            outputs,output_logits = model.generate(**batch,max_new_tokens=10 if num_latent_tokens == 6 else 50,return_logits = True) # more token if language steps
            ans_prob = output_logits.to(device).softmax(-1)[-1,ans_id].item() # prob of the answer
        else:
            outputs = model.generate(**batch,max_new_tokens=200,pad_token_id=tokenizer.eos_token_id)
        if compare_step: # only possible if num_latent_tokens < 6
            pred_step = pred.split('#')[0].split(',')[-1].strip()
            acc_step.append(pred_step == last_step) # compare step as well. (answer is only binary.)

        pred = tokenizer.decode(outputs[0, batch['input_ids'].shape[1]:], skip_special_tokens=True)
        pred_ans = pred.split('#')[-1].strip().lower()
        acc.append(pred_ans == ans)
    return np.mean(acc), None if not compare_step else np.mean(acc_step)
            

In [6]:
acc,_ = get_perf(test_ds,compare_step=False)
print (f'Base Accuracy: {acc:.2f}')

  0%|                                                                                                       | 0/300 [00:00<?, ?it/s]We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class (https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)
100%|█████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [00:49<00:00,  6.08it/s]

Base Accuracy: 0.82





What happens if we instead ablate the number of latent tokens?

In [None]:
num_latent_range = range(num_latent_tokens)
for num_latent in tqdm(num_latent_range,total = len(num_latent_range)): # change number of latent tokens given.
    configs = Config({'pad_latent_to_max':True,'max_latent_stage':num_latent,'c_thought':c_thought})
    test_ds = get_question_latent_dataset(
            num_latent,
            ds,
            configs,
            start_id,
            latent_id,
            end_id,
            no_special_marker= not coconut, # if 0, no special marker as well
        )
    acc,_ = get_perf(test_ds,compare_step=False)
    print (f'Accuracy with {num_latent} latent tokens: {acc:.2f}')
    

  0%|                                                                                                         | 0/6 [00:00<?, ?it/s]

Map (num_proc=32):   0%|          | 0/300 [00:00<?, ? examples/s]

100%|█████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [00:27<00:00, 11.06it/s]
 17%|████████████████▏                                                                                | 1/6 [00:28<02:22, 28.42s/it]

Accuracy with 0 latent tokens: 0.42


Map (num_proc=32):   0%|          | 0/300 [00:00<?, ? examples/s]

100%|█████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [00:31<00:00,  9.52it/s]
 33%|████████████████████████████████▎                                                                | 2/6 [01:01<02:03, 30.98s/it]

Accuracy with 1 latent tokens: 0.42


Map (num_proc=32):   0%|          | 0/300 [00:00<?, ? examples/s]

100%|█████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [00:35<00:00,  8.46it/s]
 50%|████████████████████████████████████████████████▌                                                | 3/6 [01:37<01:40, 33.61s/it]

Accuracy with 2 latent tokens: 0.42


Map (num_proc=32):   0%|          | 0/300 [00:00<?, ? examples/s]

100%|█████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [00:39<00:00,  7.56it/s]
 67%|████████████████████████████████████████████████████████████████▋                                | 4/6 [02:18<01:13, 36.52s/it]

Accuracy with 3 latent tokens: 0.42


Map (num_proc=32):   0%|          | 0/300 [00:00<?, ? examples/s]

100%|█████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [00:43<00:00,  6.83it/s]
 83%|████████████████████████████████████████████████████████████████████████████████▊                | 5/6 [03:04<00:39, 39.64s/it]

Accuracy with 4 latent tokens: 0.42


Map (num_proc=32):   0%|          | 0/300 [00:00<?, ? examples/s]

100%|█████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [00:48<00:00,  6.24it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [03:53<00:00, 38.90s/it]

Accuracy with 5 latent tokens: 0.41





In [None]:
# Get latents first
all_latents = []
for i,batch in tqdm(enumerate(test_ds), total=len(test_ds)):
    batch = {k: torch.tensor(v).to(device).unsqueeze(0) for k, v in batch.items()}
    idx = batch['idx'].item()
    steps = original_ds[idx]['steps']
    batch['labels'] = batch['input_ids'].clone()
    output = model(**batch,store_latent=True)
    all_latents.append(output.latents)

    # for i in range(output.latents.shape[1]): # cosine similarity
    #     print (torch.nn.functional.cosine_similarity(output.latents[0,i], output.latents[0,-1], dim=-1))
    # break
    # decoded_latent = []
    # for latent in output.latents[0]:
    #     logits = latent @ model.base_causallm.lm_head.weight.T
    #     decoded_latent.append((tokenizer.decode(logits.argmax(dim = 0))))


all_latents = torch.cat(all_latents, dim=0)


See what would happen if we intervene and insert a random latent at each step.

In [None]:
step_acc = defaultdict(list)
step_diff = []
for i,batch in tqdm(enumerate(test_ds),total=len(test_ds)):
    idx = batch['idx']
    rung = original_ds[idx]['rung']
    step = original_ds[idx]['steps'][-1].split(',')[-1].strip()
    ans = original_ds[idx]['answer'].strip().lower()
    ans_id = tokenizer.encode(' '+ ans, add_special_tokens=False)[0] # add space to the start of the answer
    batch = {k: torch.tensor(v).to(device).unsqueeze(0) for k, v in batch.items() if v != None and k not in ["idx", "position_ids"]}

    random_idx = i
    while random_idx == i: # make sure we don't sample the same latent
        random_idx = np.random.randint(0,all_latents.shape[0],(1,)).item()
    for step_idx in range(num_latent_tokens):
        sample_latent = all_latents[random_idx]
        outputs = model.generate(**batch,max_new_tokens=10,insert_latent = sample_latent,insert_step=step_idx)
        pred = tokenizer.decode(outputs[0, batch['input_ids'].shape[1]:])
        pred_ans = pred.split('#')[-1].strip().lower()
        step_acc[step_idx].append(pred_ans == ans)


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [04:54<00:00,  1.02it/s]


In [None]:
for step,scores in step_acc.items():
    acc = sum(scores) / len(scores)
    print(f"Step {step}: ans acc: {acc:.2f}")

Step 0: ans acc: 0.74
Step 1: ans acc: 0.74
Step 2: ans acc: 0.74
Step 3: ans acc: 0.74
Step 4: ans acc: 0.74
Step 5: ans acc: 0.74


# save HF model

In [None]:
from huggingface_hub import HfApi, HfFolder, Repository, create_repo, upload_folder

model_name = "Llama-3.2-1B_cladder_6"
local_dir = "checkpoint/Llama-3.2-1B_cladder_6"
model.base_causallm.save_pretrained(local_dir)
tokenizer.save_pretrained(local_dir)

create_repo(model_name, repo_type="model", exist_ok=True)
upload_folder(
    repo_id=f"weijie210/{model_name}",
    folder_path=local_dir,
    repo_type="model"
)