<a href="https://colab.research.google.com/github/Reennon/multigec-models/blob/main/notebooks/aya_expanse_8b/multigec/multigec_prediction_fluency.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os

from google.colab import userdata

# Publish the GitHub token from Colab's secret store as an environment variable so
# shell cells can reference it as $GIT_TOKEN.
os.environ.update(GIT_TOKEN=userdata.get('git_token'))

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive. The dataset, adapter-checkpoint and prediction
# paths defined in the config cell all live under /content/drive/MyDrive/omnigec.
drive.mount('/content/drive')

In [None]:
# Clone the project repository, authenticating with the GIT_TOKEN exported in the first cell.
# NOTE(review): a token embedded in the URL is likely persisted in the clone's remote URL
# (.git/config); consider a credential helper. Re-running this cell fails once the directory exists.
!git clone https://$GIT_TOKEN@github.com/Reennon/omnigec-models.git

In [None]:
# Work from the repo root so the `src.*` imports and the relative ./params path resolve.
%cd omnigec-models

In [None]:
# Fetch the latest commits (a no-op right after a fresh clone; matters when the clone already existed).
!git pull

In [None]:
!pip install -U bitsandbytes peft accelerate datasets sentencepiece wandb python-dotenv wtpsplit -q
!pip install flash-attn --no-build-isolation -q
!pip install wtpsplit==2.1.1 -q
!pip install syntok==1.4.4 -q
!pip install omegaconf -q
!pip install wandb -q
!pip install --upgrade transformers trl -q
!pip install pandas numpy -q

In [None]:
from google.colab import drive
# NOTE(review): Drive is already mounted at /content/drive by an earlier cell, and every
# path in the config cell uses /content/drive. This second mount at /gdrive appears
# redundant; confirm nothing else reads /gdrive before removing it.
drive.mount('/gdrive')

In [None]:
import os

from omegaconf import OmegaConf
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from huggingface_hub import login
from src.utils.multigec import sentences, LANG_TO_CODE, LANG_CODE_TO_TOKEN
from src.utils.aya_utils import inference_formatting_prompts_func as formatting_prompts_func
from langchain_core.prompts import PromptTemplate

from src.instruction_templates import multigec_prompts

import torch
import wandb

from transformers import BitsAndBytesConfig
from tqdm import tqdm
from trl.trainer import ConstantLengthDataset
import pandas as pd
from datasets import Dataset
from transformers.trainer_callback import EarlyStoppingCallback

from transformers import TrainingArguments
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig

from peft import PeftModel, PeftModelForCausalLM

# Register tqdm's pandas integration (the `progress_apply` family of methods).
# NOTE(review): no `progress_*` call appears in the cells visible here; possibly unused.
tqdm.pandas()

In [None]:
# Relative path: resolved against the cloned repo because of the earlier `%cd omnigec-models`.
# Its `baseline` section provides the sampling settings (temperature/top_p/top_k) used in generation.
parameters = OmegaConf.load("./params/aya_expanse_8b.yaml")

In [None]:
# --- Run configuration -----------------------------------------------------
track = "fluency"
fine_tuned_model_name = f"aya-expanse-8b-multigec-{track}"

# Secrets from Colab's `userdata` store.
hf_key = userdata.get("hf_key")
secret_wandb = userdata.get("wandb_key")

# Inputs, predictions and fine-tuned checkpoints all live on the mounted Drive.
_omnigec_root = "/content/drive/MyDrive/omnigec"
in_path = f"{_omnigec_root}/datasets/multigec_{track}.csv"
out_path = f"{_omnigec_root}/preds/multigec_test_{track}.csv"
out_model_dir = f"{_omnigec_root}/models/{fine_tuned_model_name}"

# Model-loading switches.
QUANTIZE_4BIT = True
device = "cuda:0"

In [None]:
# Authenticate to the Hugging Face Hub with the `hf_key` secret before loading models from it.
login(hf_key)

In [None]:
# Request PyTorch CUDA device-side assertions for debugging.
# The previous `!env TORCH_USE_CUDA_DSA=1 -q` ran in a throw-away subshell, so it never
# reached the kernel, and `-q` was treated as the command to execute.
# Setting the variable on the kernel process itself, before the model-loading cell
# initialises CUDA, is what actually makes it visible to torch.
# NOTE(review): PyTorch's error text asks to *compile* with TORCH_USE_CUDA_DSA, so on
# prebuilt wheels this may have no effect; CUDA_LAUNCH_BLOCKING=1 may be what is wanted.
os.environ["TORCH_USE_CUDA_DSA"] = "1"

In [None]:
base_model = "CohereForAI/aya-expanse-8b"
# PEFT adapter (and tokenizer) saved during fine-tuning.
saved_checkpoint = f"{out_model_dir}/checkpoint-400"

# Optional 4-bit NF4 quantization (bitsandbytes) with bf16 compute; None keeps bf16 weights.
quantization_config = (
    BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    if QUANTIZE_4BIT
    else None
)

# The tokenizer is taken from the checkpoint rather than the base repo. The embeddings are
# resized to its length below, so it presumably carries tokens added during fine-tuning.
tokenizer = AutoTokenizer.from_pretrained(saved_checkpoint)
config = AutoConfig.from_pretrained(base_model)

base_model_instance = AutoModelForCausalLM.from_pretrained(
    base_model,
    config=config,
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    device_map=device,
    attn_implementation="flash_attention_2",
)
# Match the embedding matrix to the (possibly extended) tokenizer before attaching the adapter.
base_model_instance.resize_token_embeddings(len(tokenizer))

model = PeftModelForCausalLM.from_pretrained(
    base_model_instance,
    saved_checkpoint,
    torch_dtype=torch.bfloat16,
    device_map=device,
    ignore_mismatched_sizes=True,
)


In [None]:
multigec_df = pd.read_csv(in_path)

# Materialise each split as an independent frame. The prediction cell writes into
# test_df with `.loc[idx, "target"] = ...`. A bare boolean-mask selection is flagged by
# pandas as a possible view of multigec_df, which raises SettingWithCopyWarning on
# those writes. `.copy()` makes the ownership explicit; values are unchanged.
split_col = multigec_df["split"]
train_df = multigec_df.loc[split_col == "train"].copy()
val_df = multigec_df.loc[split_col == "val"].copy()
test_df = multigec_df.loc[split_col == "test"].copy()

In [None]:
from torch.utils.data import DataLoader
from tqdm import tqdm

max_new_tokens = 1600
batch_size = 10
save_each = 50  # write test_df to out_path after every `save_each` newly predicted rows

# Batched generation with a decoder-only model requires LEFT padding. With right padding,
# shorter prompts are followed by pad tokens, so the model continues after padding instead
# of after the prompt (transformers warns about this). The tokenizer was loaded from the
# fine-tuning checkpoint, whose saved padding side is not guaranteed, so force it here.
tokenizer.padding_side = "left"

# Only rows whose target is still empty/NaN need a prediction. reset_index() keeps the
# original test_df index in an "index" column so results can be written back in place.
to_process_df = test_df[test_df["target"].isna()].reset_index()


def collate_fn(examples):
    """Tokenize a list of row records into one padded batch for generation.

    Args:
        examples: row dicts from ``to_process_df``. Each must contain the ``"index"``
            key added by ``reset_index()``, plus whatever ``formatting_prompts_func`` reads.

    Returns:
        The tokenizer output (PyTorch tensors, padded to the longest prompt) with an
        extra ``"indices"`` list holding each row's original ``test_df`` index.
    """
    indices = [ex["index"] for ex in examples]
    texts = [formatting_prompts_func(example) for example in examples]
    tokenized = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    tokenized["indices"] = indices
    return tokenized


# DataLoader over only the rows that still need processing.
dataloader = DataLoader(
    to_process_df.to_dict(orient="records"),
    batch_size=batch_size,
    collate_fn=collate_fn,
)

processed_rows = 0
for batch in tqdm(dataloader):
    input_ids = batch["input_ids"].to(model.device)
    attention_mask = batch.get("attention_mask", None)
    if attention_mask is not None:
        attention_mask = attention_mask.to(model.device)
    # All prompts in the batch are padded to this length, so every token after this
    # column is newly generated.
    prompt_padded_len = input_ids.shape[1]

    gen_tokens = model.generate(
        input_ids,
        attention_mask=attention_mask,
        temperature=parameters.baseline.temperature,
        top_p=parameters.baseline.top_p,
        top_k=parameters.baseline.top_k,
        max_new_tokens=max_new_tokens,
        do_sample=True,
    )

    # Strip the prompt tokens and decode. batch_decode already yields one str per row.
    corrections = tokenizer.batch_decode(
        gen_tokens[:, prompt_padded_len:], skip_special_tokens=True
    )

    # Write predictions back into test_df using the original row indices.
    for idx, corr in zip(batch["indices"], corrections):
        test_df.loc[idx, "target"] = corr
        processed_rows += 1

        if processed_rows % save_each == 0:
            test_df.to_csv(out_path, index=False)
            print(f"Saved progress after processing {processed_rows} rows.")

# Final save after processing all batches (also runs when nothing needed processing).
test_df.to_csv(out_path, index=False)
print("Final save complete!")


In [None]:
from google.colab import runtime
# Release (unassign) the Colab runtime once predictions have been written by the final
# test_df.to_csv in the previous cell. Nothing after this would run on this runtime,
# so keep this as the last cell.
runtime.unassign()