<a href="https://colab.research.google.com/github/Reennon/multigec-models/blob/main/notebooks/aya_expanse_8b/multigec/multigec_prediction_chunked.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os

from google.colab import userdata

# Expose the Colab secret 'git_token' as the GIT_TOKEN environment variable so that
# shell commands launched from this notebook can read it as $GIT_TOKEN.
os.environ["GIT_TOKEN"] = userdata.get('git_token')

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the dataset, prediction and model paths
# configured later in this notebook all live under /content/drive/MyDrive.
drive.mount('/content/drive')

In [None]:
# Clone the project repository, authenticating with the token exported as $GIT_TOKEN.
# NOTE(review): the token is part of the clone URL, so git presumably records it as the
# remote of the checkout — confirm that is acceptable for this runtime.
!git clone https://$GIT_TOKEN@github.com/Reennon/multigec-models.git

In [None]:
# Work from the repository root so that the `src.*` imports and ./params/*.yaml resolve.
%cd multigec-models

In [None]:
# Update the checkout in case the repository had already been cloned in this runtime.
!git pull

In [None]:
!pip install -U bitsandbytes peft accelerate datasets sentencepiece wandb python-dotenv wtpsplit -q
!pip install flash-attn --no-build-isolation -q
!pip install wtpsplit==2.1.1 -q
!pip install syntok==1.4.4 -q
!pip install omegaconf -q
!pip install wandb -q
!pip install --upgrade transformers trl -q
!pip install pandas numpy -q

In [None]:
from google.colab import drive
# NOTE(review): this repeats the Drive mount performed near the top of the notebook;
# presumably harmless when Drive is already mounted — confirm, or drop one of the two cells.
drive.mount('/content/drive')

In [None]:
import os

from omegaconf import OmegaConf
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from huggingface_hub import login
from src.utils.multigec import sentences, LANG_TO_CODE, LANG_CODE_TO_TOKEN
from langchain_core.prompts import PromptTemplate

from src.instruction_templates import multigec_prompts

import torch
import wandb

from transformers import BitsAndBytesConfig
from tqdm import tqdm
from trl.trainer import ConstantLengthDataset
import pandas as pd
from datasets import Dataset
from transformers.trainer_callback import EarlyStoppingCallback

from transformers import TrainingArguments
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig

from peft import PeftModel, PeftModelForCausalLM

# Register tqdm's pandas integration (adds the progress_* variants such as progress_apply).
tqdm.pandas()

In [None]:
# Load the run parameters; the generation step reads parameters.baseline.temperature,
# parameters.baseline.top_p and parameters.baseline.top_k from this object.
parameters = OmegaConf.load("./params/aya_expanse_8b.yaml")

In [None]:
# Track name; it selects the dataset file, the prediction file names and the model directory below.
track     = "fluency"
fine_tuned_model_name = f"aya-expanse-8b-multigec-{track}"

# Credentials are read from Colab secrets rather than being hard-coded.
hf_key   = userdata.get("hf_key")
secret_wandb = userdata.get("wandb_key")

# Input dataset (holds a "split" column used to separate train/val/test rows).
in_path  = f"/content/drive/MyDrive/multigec/datasets/multigec_{track}.csv"
# Final predictions: one row per original test row.
out_path = f"/content/drive/MyDrive/multigec/preds/multigec_test_{track}.csv"
# Chunk-level partial progress, reloaded on the next run if the file exists.
temp_out_path = f"/content/drive/MyDrive/multigec/preds/temp_multigec_test_{track}_chunked.csv"
# Directory holding the fine-tuned checkpoints.
out_model_dir = f"/content/drive/MyDrive/multigec/models/{fine_tuned_model_name}"
# When True the base model is loaded with a 4-bit BitsAndBytesConfig.
QUANTIZE_4BIT = True
# Device passed as device_map when loading the model.
device   = "cuda:0"

In [None]:
# Authenticate with the Hugging Face Hub using the key read from Colab secrets.
login(hf_key)

In [None]:
!env TORCH_USE_CUDA_DSA=1 -q

In [None]:
# Base checkpoint identifier and the fine-tuned checkpoint directory.
base_model = "CohereForAI/aya-expanse-8b"
saved_checkpoint = out_model_dir + "/checkpoint-400"

# 4-bit NF4 quantization (double quantization, bfloat16 compute) when enabled, else none.
quantization_config = (
    BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    if QUANTIZE_4BIT
    else None
)

# The tokenizer is read from the fine-tuned checkpoint, the model config from the base model.
tokenizer = AutoTokenizer.from_pretrained(saved_checkpoint)
config = AutoConfig.from_pretrained(base_model)

base_load_kwargs = dict(
    config=config,
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    device_map=device,
    attn_implementation="flash_attention_2",
)
base_model_instance = AutoModelForCausalLM.from_pretrained(base_model, **base_load_kwargs)

# Match the embedding matrix to the checkpoint tokenizer's vocabulary size before
# the adapter weights are loaded on top of the base model.
base_model_instance.resize_token_embeddings(len(tokenizer))

adapter_load_kwargs = dict(
    torch_dtype=torch.bfloat16,
    device_map=device,
    ignore_mismatched_sizes=True,
)
model = PeftModelForCausalLM.from_pretrained(
    base_model_instance,
    saved_checkpoint,
    **adapter_load_kwargs,
)


In [None]:
# Load the track's dataset and separate it by the "split" column.
multigec_df = pd.read_csv(in_path)
# Each split is an explicit copy: test_df is written to later in the notebook (its "target"
# column receives the predictions), and assigning into a filtered slice of multigec_df
# would otherwise raise SettingWithCopyWarning.
train_df = multigec_df[multigec_df["split"] == "train"].copy()
val_df = multigec_df[multigec_df["split"] == "val"].copy()
test_df = multigec_df[multigec_df["split"] == "test"].copy()

In [None]:
def formatting_prompts_func(example):
    """Build the single-turn chat prompt for one dataset row.

    Args:
        example: mapping with a "language" key (used to look up both the language
            token and the prompt template) and a "feature" key (the text to correct).

    Returns:
        The prompt string: a user turn holding the language token followed by the
        formatted instruction, then the opening of the chatbot turn.
    """
    language = example["language"]
    language_token = LANG_CODE_TO_TOKEN[LANG_TO_CODE[language]]

    source_text = example['feature']
    template = multigec_prompts[language].prompt_template
    instruction = template.format(original_text=source_text)

    return f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{language_token}{instruction}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"

In [None]:
import os
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from wtpsplit import SaT

# ------------------------------------------------------------------------------
# ASSUMPTIONS:
# 1) You already have:
#       model         -> Your loaded model
#       tokenizer     -> Your tokenizer
#       test_df       -> A DataFrame with at least columns ["feature", "target"]
#       out_path      -> Path to save final CSV
#       temp_out_path -> A "partial progress" CSV path
# 2) You want to process only rows where "target" is NaN, splitting the "feature"
#    if it exceeds a certain max_tokens threshold, then recombining outputs.
# ------------------------------------------------------------------------------

# Instantiate the sentence splitter (wtpsplit). chunk_large_texts calls sat.split() on
# lines that are too long to be used as a single chunk. The splitter is switched to half
# precision and moved to the GPU.
sat = SaT("sat-3l")
sat.half().to("cuda")

# Adjustable parameters
max_new_tokens = 600  # generation budget passed to model.generate for every batch
batch_size = 100  # number of chunk prompts per DataLoader batch
save_each = 100  # the partial-progress CSV is rewritten every `save_each` generated chunks
max_tokens = 600  # chunking threshold: a text/line with at least this many tokens (special tokens excluded) is split further

# ----------------------------------------------------------------------
# Helper functions to chunk large texts using newline -> sentence -> word.
# ----------------------------------------------------------------------
def _count_tokens(tokenizer, text):
    """Return how many tokens `tokenizer` yields for `text`, special tokens excluded."""
    return len(tokenizer(text, add_special_tokens=False).input_ids)


def _split_by_words(text, tokenizer, max_tokens):
    """Greedily pack the whitespace-separated words of `text` into pieces under `max_tokens`."""
    pieces = []
    current_words = []
    for word in text.split():
        if _count_tokens(tokenizer, " ".join(current_words + [word])) < max_tokens:
            current_words.append(word)
            continue
        if current_words:
            pieces.append(" ".join(current_words))
        # A single word that is itself over budget still becomes its own piece.
        current_words = [word]
    if current_words:
        pieces.append(" ".join(current_words))
    return pieces


def _pack_sentences(sentences, tokenizer, max_tokens):
    """Greedily merge consecutive sentences into chunks that stay under `max_tokens`."""
    chunks = []
    current = ""
    for sent in sentences:
        candidate = (current + " " + sent).strip()
        if _count_tokens(tokenizer, candidate) < max_tokens:
            current = candidate
            continue

        # The sentence does not fit into the running chunk: close that chunk first.
        if current:
            chunks.append(current)

        stripped = sent.strip()
        if _count_tokens(tokenizer, stripped) < max_tokens:
            # The sentence fits on its own, so it opens the next chunk and later sentences
            # may still join it (it used to be emitted as an isolated chunk).
            current = stripped
        else:
            # Even alone the sentence is too long: fall back to word-level pieces.
            chunks.extend(_split_by_words(sent, tokenizer, max_tokens))
            current = ""

    if current:
        chunks.append(current)
    return chunks


def chunk_large_texts(df, tokenizer, text_col="feature", max_tokens=2048):
    """
    Split the texts in 'df[text_col]' that reach 'max_tokens' into smaller chunks.

    A text below the threshold is kept as a single row with chunk_id 0. Otherwise the text
    is split on newlines; a line that is still too long is split into sentences with the
    module-level `sat` splitter and the sentences are packed greedily; a sentence that is
    too long on its own is split on whitespace.

    Args:
        df: input frame; every column of a row (e.g. the 'index' column holding the
            original row label) is copied onto each of its chunk rows.
        tokenizer: callable returning an object with `.input_ids` for a text.
        text_col: name of the column holding the text to split.
        max_tokens: token count (special tokens excluded) a chunk must stay below.

    Returns:
        A new DataFrame with one row per chunk, a fresh 0..n-1 index and an added
        'chunk_id' column numbering the chunks of each original row from 0. For an
        empty input the result is empty but carries the input columns plus 'chunk_id'.
    """
    chunked_rows = []

    for _, row in df.iterrows():
        full_text = row[text_col]

        # Quickly check if it needs chunking
        if _count_tokens(tokenizer, full_text) < max_tokens:
            new_row = row.copy()
            new_row["chunk_id"] = 0
            chunked_rows.append(new_row)
            continue

        # Otherwise, we chunk: newline first, then sentences, then words.
        all_chunks = []
        for line in full_text.split("\n"):
            if _count_tokens(tokenizer, line) < max_tokens:
                all_chunks.append(line)
            else:
                all_chunks.extend(_pack_sentences(sat.split(line), tokenizer, max_tokens))

        # Now create rows for each chunk
        for chunk_i, chunk_text in enumerate(all_chunks):
            new_row = row.copy()
            new_row[text_col] = chunk_text
            new_row["chunk_id"] = chunk_i
            chunked_rows.append(new_row)

    if not chunked_rows:
        # Keep the expected columns so downstream column lookups work on empty input.
        empty_df = df.iloc[0:0].copy()
        empty_df["chunk_id"] = pd.Series(dtype="int64")
        return empty_df.reset_index(drop=True)

    chunked_df = pd.DataFrame(chunked_rows).reset_index(drop=True)
    return chunked_df

# ------------------------------------------------------------------------------
# Prompt-formatting function used by collate_fn for every chunk row
# ------------------------------------------------------------------------------
def formatting_prompts_func(ex):
    """
    Build the chat prompt the model sees for one chunk row.

    This replaces a placeholder that returned "Some system prompt:" plus the raw text and,
    because it is defined later in the notebook, silently shadowed the real prompt builder.
    The prompt now carries the language token, the language-specific instruction template
    and the chat-turn tokens.

    Args:
        ex: row dictionary with a "language" key (used to look up the language token and
            the prompt template) and a "feature" key (the chunk text to correct).
            NOTE(review): chunk rows are copies of test_df rows, so "language" is assumed
            to be a column of the dataset — confirm.

    Returns:
        The prompt string: a user turn holding the language token followed by the
        formatted instruction, then the opening of the chatbot turn.
    """
    language_code = LANG_TO_CODE[ex["language"]]
    language_token = LANG_CODE_TO_TOKEN[language_code]

    prompt_template = multigec_prompts[ex["language"]].prompt_template
    instruction = prompt_template.format(original_text=ex["feature"])

    return f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{language_token}{instruction}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"

# ------------------------------------------------------------------------------
# Chunk-level frame: build it from test_df unless saved progress already exists
# ------------------------------------------------------------------------------
if not os.path.exists(temp_out_path):
    print("No partial file found. Splitting & creating chunk-level data ...")
    # Only rows without a final 'target' are split; reset_index() keeps the original
    # row label in an "index" column so chunk outputs can be mapped back later.
    pending_mask = test_df["target"].isna()
    to_process_df = test_df[pending_mask].reset_index()

    expanded_df = chunk_large_texts(
        df=to_process_df,
        tokenizer=tokenizer,
        text_col="feature",
        max_tokens=max_tokens,
    )
    # Make sure there is a 'target' column to receive the generated text.
    if "target" not in expanded_df.columns:
        expanded_df["target"] = None
else:
    print(f"Resuming from partial expansions at {temp_out_path} ...")
    expanded_df = pd.read_csv(temp_out_path)

# ------------------------------------------------------------------------------
# 4) Batch collation for the DataLoader over chunk rows that still need answers
# ------------------------------------------------------------------------------
def collate_fn(examples):
    """Turn a list of chunk-row dicts into one tokenized batch.

    Every row is rendered with formatting_prompts_func and the prompts are tokenized
    together (padded, truncated, returned as "pt" tensors). The original row labels
    and chunk ids travel with the batch under "indices" and "chunk_ids" so that the
    generated text can be written back to the matching chunk rows.
    """
    row_labels = [row["index"] for row in examples]  # original test_df row index
    chunk_numbers = [row["chunk_id"] for row in examples]

    prompts = [formatting_prompts_func(row) for row in examples]
    batch = tokenizer(
        prompts,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    batch["indices"] = row_labels
    batch["chunk_ids"] = chunk_numbers
    return batch

# We only generate for chunk rows whose 'target' is still NaN (unanswered).
# 'target' is cast to object dtype first: an all-NaN column is float64, and writing the
# decoded strings into a float64 column is an incompatible-dtype assignment that pandas
# has deprecated. Missing values stay missing after the cast, so the filter is unchanged.
expanded_df["target"] = expanded_df["target"].astype(object)
expanded_df_for_gen = expanded_df[expanded_df["target"].isna()]

if len(expanded_df_for_gen) == 0:
    print("No unanswered chunks found; skipping generation.")
else:
    dataloader = DataLoader(
        expanded_df_for_gen.to_dict(orient="records"),
        batch_size=batch_size,
        collate_fn=collate_fn
    )

    processed_rows = 0
    for batch in tqdm(dataloader, desc="Generating"):
        input_ids = batch["input_ids"].to(model.device)
        attention_mask = batch.get("attention_mask", None)
        if attention_mask is not None:
            attention_mask = attention_mask.to(model.device)

        # Width of the padded prompt block; generated tokens start right after it.
        # NOTE(review): slicing at this offset assumes the tokenizer pads on the left,
        # so that every prompt ends at the same column — confirm tokenizer.padding_side.
        prompt_padded_len = input_ids.shape[1]

        # Generate (sampling, with the decoding parameters from the loaded configuration)
        gen_tokens = model.generate(
            input_ids,
            attention_mask=attention_mask,
            temperature=parameters.baseline.temperature,
            top_p=parameters.baseline.top_p,
            top_k=parameters.baseline.top_k,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            use_cache=True,
        )

        # Strip off the prompt portion from each
        gen_tokens = [gt[prompt_padded_len:] for gt in gen_tokens]

        # Decode
        corrections = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)

        # Save these chunk-level outputs into expanded_df, matching on (index, chunk_id)
        for idx, cid, corr in zip(batch["indices"], batch["chunk_ids"], corrections):
            mask = (expanded_df["index"] == idx) & (expanded_df["chunk_id"] == cid)
            expanded_df.loc[mask, "target"] = corr
            processed_rows += 1

            # Partial save for safety
            if processed_rows % save_each == 0:
                expanded_df.to_csv(temp_out_path, index=False)
                print(f"[Checkpoint] Saved expanded_df progress ({processed_rows} chunks).")

    # Final save of chunk-level expansions
    expanded_df.to_csv(temp_out_path, index=False)
    print(f"Chunk-level expansions saved to {temp_out_path}.")

# ------------------------------------------------------------------------------
# 5) Recombine chunk outputs into test_df (one output per original row)
# ------------------------------------------------------------------------------
# Write into an explicit copy whose 'target' column is object-typed: test_df is a filtered
# slice of the full dataset, and an all-NaN 'target' is float64, which cannot take strings
# without an incompatible-dtype assignment.
test_df = test_df.copy()
if "target" in test_df.columns:
    test_df["target"] = test_df["target"].astype(object)

# Stable sort so the chunks of each original row are joined in chunk_id order even if the
# chunk-level frame was reordered (e.g. after editing the partial-progress CSV).
ordered_chunks = expanded_df.sort_values(["index", "chunk_id"], kind="stable")
grouped = ordered_chunks.groupby("index")["target"].apply(list).reset_index(name="chunks")
for _, row in grouped.iterrows():
    orig_idx = row["index"]      # the real test_df index
    chunk_outputs = row["chunks"]
    # Chunks without an output are skipped; the rest are joined with newlines.
    combined_output = "\n".join(str(x) for x in chunk_outputs if pd.notnull(x))
    test_df.loc[orig_idx, "target"] = combined_output

# Final save of test_df
test_df.to_csv(out_path, index=False)
print("Final save complete!")


In [None]:
from google.colab import runtime
# Release the Colab runtime now that the notebook has finished; anything not written to
# Drive by this point is presumably lost with the runtime — confirm before relying on it.
runtime.unassign()