<a href="https://colab.research.google.com/github/Reennon/multigec-models/blob/main/notebooks/aya_expanse_8b/multigec/multigec_prediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os

from google.colab import userdata

os.environ["GIT_TOKEN"] = userdata.get('git_token')

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!git clone https://$GIT_TOKEN@github.com/Reennon/omnigec-models.git

In [None]:
%cd omnigec-models

In [None]:
!git pull

In [None]:
!pip install -U bitsandbytes peft accelerate datasets sentencepiece wandb python-dotenv wtpsplit -q
!pip install flash-attn --no-build-isolation -q
!pip install wtpsplit==2.1.1 -q
!pip install syntok==1.4.4 -q
!pip install omegaconf -q
!pip install wandb -q
!pip install --upgrade transformers trl -q
!pip install pandas numpy -q

In [None]:
from google.colab import drive
drive.mount('/gdrive')

In [None]:
import os

from omegaconf import OmegaConf
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from huggingface_hub import login
from src.utils.multigec import sentences, LANG_TO_CODE, LANG_CODE_TO_TOKEN
from src.utils.aya_utils import inference_formatting_prompts_func as formatting_prompts_func
from langchain_core.prompts import PromptTemplate

from src.instruction_templates import multigec_prompts

import torch
import wandb

from transformers import BitsAndBytesConfig
from tqdm import tqdm
from trl.trainer import ConstantLengthDataset
import pandas as pd
from datasets import Dataset
from transformers.trainer_callback import EarlyStoppingCallback

from transformers import TrainingArguments
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig

from peft import PeftModel, PeftModelForCausalLM

tqdm.pandas()

In [None]:
parameters = OmegaConf.load("./params/aya_expanse_8b.yaml")

In [None]:
# --- Run configuration -------------------------------------------------------
# Dataset variant; it is interpolated into the input and output CSV file names.
track = "minimal"
fine_tuned_model_name = "aya-expanse-8b-multigec"

# Runtime switches used by the model-loading cell.
QUANTIZE_4BIT = True
device = "cuda:0"

# Secrets are read from the Colab secret store rather than written in the notebook.
hf_key = userdata.get("hf_key")
secret_wandb = userdata.get("wandb_key")

# Google Drive locations: input dataset, prediction output, and model checkpoints.
in_path = f"/content/drive/MyDrive/multigec/datasets/multigec_{track}.csv"
out_path = f"/content/drive/MyDrive/multigec/preds/multigec_test_{track}.csv"
out_model_dir = f"/content/drive/MyDrive/multigec/models/{fine_tuned_model_name}"

In [None]:
login(hf_key)

In [None]:
# `!env VAR=1 ...` runs in a child shell, so the assignment disappeared as soon as
# the command returned and never reached this kernel. Set the variable on the
# kernel process itself instead.
# NOTE(review): whether the installed PyTorch build reads TORCH_USE_CUDA_DSA at
# runtime (rather than only at build time) is not verified here — confirm.
os.environ["TORCH_USE_CUDA_DSA"] = "1"

In [None]:
# Hub id of the base model and the fine-tuned checkpoint directory on Drive.
base_model = "CohereForAI/aya-expanse-8b"
saved_checkpoint = f"{out_model_dir}/checkpoint-700"

# Optional 4-bit NF4 quantization with double quantization and bfloat16 compute.
quantization_config = (
    BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    if QUANTIZE_4BIT
    else None
)

# The tokenizer is read from the checkpoint directory, the config from the base model.
tokenizer = AutoTokenizer.from_pretrained(saved_checkpoint)
config = AutoConfig.from_pretrained(base_model)

base_model_instance = AutoModelForCausalLM.from_pretrained(
    base_model,
    config=config,
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    device_map=device,
    attn_implementation="flash_attention_2",
)

# Match the embedding matrix to the checkpoint tokenizer's vocabulary size
# before the adapter weights are attached.
base_model_instance.resize_token_embeddings(len(tokenizer))

model = PeftModelForCausalLM.from_pretrained(
    base_model_instance,
    saved_checkpoint,
    torch_dtype=torch.bfloat16,
    device_map=device,
    ignore_mismatched_sizes=True,
)


In [None]:
# Load the full dataset and partition it by the "split" column.
# Each partition is an explicit copy: later cells write predictions into
# test_df["target"], and writing into a frame that pandas still tracks as a
# selection of multigec_df triggers SettingWithCopyWarning. With .copy() every
# split clearly owns its own data.
multigec_df = pd.read_csv(in_path)
split_column = multigec_df["split"]
train_df = multigec_df.loc[split_column == "train"].copy()
val_df = multigec_df.loc[split_column == "val"].copy()
test_df = multigec_df.loc[split_column == "test"].copy()

In [None]:
from torch.utils.data import DataLoader
from tqdm import tqdm

# Settings for the batched inference run below.
max_new_tokens = 1600  # generation budget per example
batch_size = 10        # rows per generate() call
save_each = 200        # write the CSV every `save_each` processed rows

# Only rows whose "target" is still missing are sent to the model.
# reset_index() keeps the original row labels in an "index" column so that each
# prediction can be written back to the matching row of test_df.
to_process_df = test_df.loc[test_df["target"].isna()].reset_index()

def collate_fn(examples):
    """Build one tokenized generation batch from a list of row dicts.

    Every example must provide an "index" key. Those values are passed through
    unchanged under the "indices" key of the returned batch so that predictions
    can be matched back to their source rows. The prompt text of each example
    comes from `formatting_prompts_func`; the prompts are tokenized together
    with padding and truncation enabled and returned as PyTorch tensors.
    """
    row_labels = [row["index"] for row in examples]
    prompts = list(map(formatting_prompts_func, examples))

    batch = tokenizer(prompts, padding=True, truncation=True, return_tensors="pt")
    batch["indices"] = row_labels
    return batch


# Wrap the pending rows in a DataLoader; collate_fn tokenizes each batch.
dataloader = DataLoader(
    to_process_df.to_dict(orient="records"),
    collate_fn=collate_fn,
    batch_size=batch_size,
)

processed_rows = 0
for batch in tqdm(dataloader):
    input_ids = batch["input_ids"].to(model.device)
    attention_mask = batch.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(model.device)
    # Every prompt in the batch is padded to the same width.
    prompt_padded_len = input_ids.shape[1]

    gen_tokens = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=parameters.baseline.temperature,
        top_p=parameters.baseline.top_p,
        top_k=parameters.baseline.top_k,
    )

    # generate() returns prompt + continuation; keep only the continuation.
    gen_tokens = [sequence[prompt_padded_len:] for sequence in gen_tokens]

    # batch_decode already yields one string per sequence.
    corrections = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)

    # Write each prediction back to its source row via the carried row label.
    for row_label, correction in zip(batch["indices"], corrections):
        test_df.loc[row_label, "target"] = correction
        processed_rows += 1

        # Periodic checkpoint so a crash does not lose finished rows.
        if processed_rows % save_each != 0:
            continue
        test_df.to_csv(out_path, index=False)
        print(f"Saved progress after processing {processed_rows} rows.")

# Final save after processing all batches.
test_df.to_csv(out_path, index=False)
print("Final save complete!")


In [None]:
# Same target file as the inference loop writes, so use the same layout
# (index=False); otherwise the file gains an unnamed index column.
test_df.to_csv(out_path, index=False)

In [None]:
test_df

In [None]:
test_df.reset_index()

In [None]:
test_df[:1].target

In [None]:
print(test_df[:1].feature.item())

In [None]:
# NOTE(review): `all_corrections` is not assigned anywhere in the visible part of
# this notebook, so on a fresh kernel this cell raises NameError — confirm
# whether it is a leftover from an earlier version of the inference code.
print(all_corrections[0])

In [None]:
"".join(all_corrections)

In [None]:
test_df[:10]

In [None]:
# Keep the on-disk layout identical to the saves made by the inference loop
# (index=False), so the file does not gain an unnamed index column.
test_df.to_csv(out_path, index=False)

In [None]:
def make_correction(
    model,
    input: pd.Series,
    tokenizer,
    parameters,
    seq_lenght: int
):
    """Generate a correction for a single dataset row.

    Args:
        model: Object exposing `.device` and `.generate(...)`.
        input: One dataset row; it is handed to `formatting_prompts_func` to
            build the prompt text.
        tokenizer: Callable tokenizer that also provides `batch_decode`.
        parameters: Config object read by attribute access
            (`parameters.baseline.temperature`, `.top_p`, `.top_k`), e.g. the
            OmegaConf object loaded by this notebook — a plain dict does not work.
        seq_lenght: Maximum number of new tokens to generate.

    Returns:
        The decoded continuation as a single string (prompt tokens removed).
    """
    # Keep the row argument intact; the prompt gets its own name.
    prompt = formatting_prompts_func(input)
    encoded = tokenizer(
        [prompt],
        padding=True,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(model.device)
    # Forward the attention mask as the batched inference loop does, instead of
    # leaving generate() to infer it from the input ids.
    attention_mask = encoded.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(model.device)
    prompt_padded_len = input_ids.shape[1]

    # Generate corrections
    gen_tokens = model.generate(
        input_ids,
        attention_mask=attention_mask,
        temperature=parameters.baseline.temperature,
        top_p=parameters.baseline.top_p,
        top_k=parameters.baseline.top_k,
        max_new_tokens=seq_lenght,
        do_sample=True,
    )
    # generate() returns prompt + continuation; drop the prompt part.
    continuations = [sequence[prompt_padded_len:] for sequence in gen_tokens]
    outputs = tokenizer.batch_decode(
        continuations,
        skip_special_tokens=True
    )
    # Join divided texts if any
    return "".join(outputs)

# Smoke-test make_correction on the first five test rows.
# The assignment goes through .loc on test_df itself: the previous
# `test_df[:5]["target"] = ...` assigned into a temporary slice (chained
# assignment), which pandas does not guarantee to propagate back to test_df.
first_rows = test_df.index[:5]
test_df.loc[first_rows, "target"] = test_df.loc[first_rows].progress_apply(
    lambda x: make_correction(
        model=model,
        input=x,
        tokenizer=tokenizer,
        parameters=parameters,
        seq_lenght=1600,
    ),
    axis=1,
)
test_df

In [None]:
test_df[:1].feature.item()

In [None]:
print(test_df[:1].target.item())

In [None]:
# Presumably a deliberate barrier so that "Run all" stops before the training
# cells below. `e` was never defined, so the old line failed with NameError
# instead of raising the intended exception with a readable message.
raise Exception("Stopped on purpose: the cells below are the training workflow; run them manually.")

In [None]:
def formatting_prompts_func(example):
    """Render one dataset row as a complete chat-formatted training text.

    The row's "language" selects both the language token (via LANG_TO_CODE and
    LANG_CODE_TO_TOKEN) and the instruction template (via multigec_prompts);
    "feature" is inserted into the template as `original_text`, and "target" is
    appended after the chatbot token.

    Note: this definition replaces the `formatting_prompts_func` imported at the
    top of the notebook, so cells executed after this one use this version.
    """
    language = example["language"]
    language_token = LANG_CODE_TO_TOKEN[LANG_TO_CODE[language]]

    source_text = example['feature']
    template = multigec_prompts[example["language"]].prompt_template
    instruction = template.format(original_text=source_text)
    target = example['target']

    return f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{language_token}{instruction}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{target}"

In [None]:
# Register every language token as a special token, then resize the model's
# embedding matrix to the tokenizer's new vocabulary size.
num_added_toks = tokenizer.add_tokens(
    list(LANG_CODE_TO_TOKEN.values()),
    special_tokens=True,
)
model.resize_token_embeddings(len(tokenizer))

In [None]:
# Token length of each packed training sequence; validation uses half of it.
seq_length = 1300

training_dataset = Dataset.from_pandas(train_df)
val_dataset = Dataset.from_pandas(val_df)

# Options shared by the train and validation packed datasets.
shared_cld_options = dict(
    tokenizer=tokenizer,
    eos_token_id=tokenizer.eos_token_id,
    shuffle=True,
    append_concat_token=True,
    add_special_tokens=True,
    formatting_func=formatting_prompts_func,
)

cld_train_dataset = ConstantLengthDataset(
    dataset=training_dataset,
    seq_length=seq_length,
    **shared_cld_options,
)
cld_val_dataset = ConstantLengthDataset(
    dataset=val_dataset,
    seq_length=seq_length // 2,
    **shared_cld_options,
)

In [None]:
# parameters.training["gradient_accumulation_steps"] = 8
# parameters.training["per_device_train_batch_size"] = 5
# parameters.training["per_device_eval_batch_size"] = 2
# # Try these
# parameters.training["lr_scheduler_type"] = "cosine"
# parameters.training["max_grad_norm"] = 1.0
# parameters.training["output_dir"] = "results"
# parameters.training["num_train_epochs"] = 12
# parameters.training["eval_strategy"] = "steps"
# parameters.training["save_strategy"] = "steps"        # Save model after every step
# parameters.training["metric_for_best_model"] = "eval_loss"
# parameters.training["greater_is_better"] = False
# parameters.training["group_by_length"] = False # we already use CLD
# parameters.training["save_total_limit"] = 3           # Keep only the 3 most recent checkpoints
# parameters.training["load_best_model_at_end"] = True # EarlyStoppingCallback requires load_best_model_at_end = True
# parameters.training["warmup_steps"] = 50
# parameters.training["learning_rate"] = 3e-5
# parameters.training["save_steps"] = 25
# parameters.training["weight_decay"] = 0.0
# early_stopping_patience = 75

In [None]:
cld_train_dataset

In [None]:
dict(parameters.training)

In [None]:
# def formatting_prompts_func_old(example):
#     output_texts = []
#     for i in range(len(example['inputs'])):
#         language_code = LANG_TO_CODE[example["language"]]
#         language_token = LANG_CODE_TO_TOKEN[language_code]

#         user_input = example['inputs'][i]
#         prompt_template = multigec_prompts[example["language"]]
#         instruction = prompt_template.format(original_text=user_input)
#         target = example['targets'][i]


#         text = f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{language_token}{instruction}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{target}"
#         output_texts.append(text)
#     return output_texts

# def formatting_prompts_func(example):
#     language_code = LANG_TO_CODE[example["language"]]
#     language_token = LANG_CODE_TO_TOKEN[language_code]

#     user_input = example['feature']
#     prompt_template = multigec_prompts[example["language"]]
#     instruction = prompt_template.format(original_text=user_input)
#     target = example['target']

#     text = f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{language_token}{instruction}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{target}"

#     return text

In [None]:
# Start a Weights & Biases run for this training job.
# NOTE(review): `wandb_project_name` is not assigned anywhere in the visible
# part of this notebook, so this raises NameError on a fresh kernel — confirm
# where it should come from. No wandb login call is visible either, although a
# `secret_wandb` value is loaded in the configuration cell.
run = wandb.init(
    project=wandb_project_name,
    job_type="training",
    anonymous="allow"
)

# Record the training hyper-parameters from the YAML config on the run.
wandb.config.update(dict(parameters.training))

# LoRA adapter settings, all taken from the `lora` section of the config.
peft_config = LoraConfig(
    r=parameters.lora.r,
    lora_alpha=parameters.lora.lora_alpha,
    target_modules=list(parameters.lora.target_modules),
    bias=parameters.lora.bias,
    task_type=parameters.lora.task_type
)
# NOTE(review): `output_dir` is passed explicitly next to `**parameters.training`;
# if the YAML's training section also defines `output_dir` (or `packing` /
# `max_seq_length`) this call fails with a duplicate-keyword TypeError — confirm.
training_arguments = SFTConfig(
    **parameters.training,
    packing=True,
    max_seq_length=seq_length,
    output_dir=out_model_dir,
)
# NOTE(review): if the cells ran top to bottom, `model` is the PEFT-wrapped model
# loaded for inference earlier, and a `peft_config` is passed on top of it —
# confirm this is intended.
trainer = SFTTrainer(
    model=model,
    train_dataset=cld_train_dataset,
    eval_dataset=cld_val_dataset,
    peft_config=peft_config,
    args=training_arguments,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=parameters.early_stopping.early_stopping_patience)],
    #formatting_func=formatting_prompts_func,
)

# Restrict scaled-dot-product attention to the flash kernel while training.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False
):
    # NOTE(review): confirm that the trainer object actually exposes `unload()`;
    # if it does not, this line raises AttributeError before training starts.
    trainer.unload() # Should remove peft warnings
    # Resume from an existing checkpoint in the output directory.
    trainer.train(resume_from_checkpoint=True)

In [None]:
import time

# Store the trained weights on Drive in a folder named after the W&B run,
# then register the saved files with the run.
run_name = wandb.run.name
model_out_path = f"/gdrive/MyDrive/models/multigec/{run_name}"
os.makedirs(model_out_path, exist_ok=True)
trainer.model.save_pretrained(model_out_path)
wandb.save(f"{model_out_path}/*")

# Re-enable the KV cache and switch to eval mode for the prediction cells.
model.config.use_cache = True
model.eval()

# Pause for 30 seconds before the next cell runs.
time.sleep(30)

In [None]:
# Re-select the test split from the full dataset.
test_df = multigec_df.loc[multigec_df.loc[:, "split"] == "test"]

# NOTE(review): this redefines `make_correction`; an earlier cell defines a
# function with the same name and signature, and whichever cell ran last wins.
def make_correction(
    model,
    input: pd.Series,
    tokenizer,
    parameters: dict,
    seq_lenght: int
):
    """Generate a correction for one dataset row via the tokenizer's chat template.

    The row's "language" selects the language token and the prompt template.
    `sentences(input)` supplies the list of texts to correct; each one is
    prefixed with the language token, wrapped in the prompt, converted with
    `get_message_format`, tokenized, and passed to `model.generate`. The decoded
    continuations are concatenated into one string and returned.

    `parameters` is read by attribute access (`parameters.baseline.temperature`,
    `.top_p`, `.top_k`), and `seq_lenght` is used as `max_new_tokens`.

    NOTE(review): `get_message_format` is neither defined nor imported in the
    visible part of this notebook — confirm where it comes from.
    """
    language: str = input["language"]
    language_code: str = LANG_TO_CODE[language]
    language_token: str = LANG_CODE_TO_TOKEN[language_code]
    # NOTE(review): the training formatter in this notebook reads
    # `.prompt_template` from `multigec_prompts[...]`, while here the entry is
    # formatted directly — confirm which access is correct.
    prompt_template: PromptTemplate = multigec_prompts[language]

    # Prepare inputs
    inputs: list[str] = sentences(input)
    # The comprehension variable `input` is local to the comprehension and does
    # not overwrite the `input` argument.
    inputs = [language_token + prompt_template.format(original_text=input) for input in inputs]
    inputs: list[dict[str, str]] = get_message_format(inputs)
    input_ids = tokenizer.apply_chat_template(
        inputs,
        tokenize=True,
        add_generation_prompt=True,
        padding=True,
        return_tensors="pt",
    )
    input_ids = input_ids.to(model.device)
    # Width of the (padded) prompt; used below to cut the prompt off the output.
    prompt_padded_len = len(input_ids[0])

    # Generate corrections
    gen_tokens = model.generate(
        input_ids,
        temperature=parameters.baseline.temperature,
        top_p=parameters.baseline.top_p,
        top_k=parameters.baseline.top_k,
        max_new_tokens=seq_lenght,
        do_sample=True,
    )
    # Keep only the tokens that follow the prompt in every returned sequence.
    gen_tokens = [
        gt[prompt_padded_len:] for gt in gen_tokens
    ]
    outputs: list[str] = tokenizer.batch_decode(
        gen_tokens,
        skip_special_tokens=True
    )
    # Join divided texts if any
    correction = "".join(outputs)

    return correction

# Correct every test row and keep the result.
# - axis=1 makes the lambda receive one row at a time (make_correction reads
#   input["language"]); without it DataFrame.apply iterates over the columns.
# - The returned Series is stored in "target"; previously it was discarded.
# - Work on an explicit copy so the assignment does not target a selection of
#   multigec_df.
test_df = test_df.copy()
test_df["target"] = test_df.progress_apply(
    lambda x: make_correction(
        model=model,
        input=x,
        tokenizer=tokenizer,
        parameters=parameters,
        seq_lenght=seq_length,
    ),
    axis=1,
)
test_df

In [None]:
# Write predictions without the DataFrame index, matching the saves made by the
# batched inference loop to the same path.
test_df.to_csv(out_path, index=False)

In [None]:
wandb.finish()

In [None]:
# Presumably a deliberate barrier so that "Run all" stops before the GPU/runtime
# clean-up cells below. The old `raise Exception e` was a SyntaxError rather
# than a raised exception with a readable message.
raise Exception("Stopped on purpose: the cells below are manual GPU/runtime clean-up utilities.")

In [None]:
!nvidia-smi

In [None]:
torch.cuda.empty_cache()

In [None]:
!pip install pynvml -q

In [None]:
from google.colab import runtime
runtime.unassign()

In [None]:
import gc

gc.collect()

In [None]:
import gc
import time
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlShutdown


# NOTE(review): `clear_gpu_memory` was called below but not defined in the
# visible part of this notebook; this definition mirrors the clean-up cells
# nearby (gc.collect / torch.cuda.empty_cache). Drop it if a project helper with
# this name is meant to be used instead.
def clear_gpu_memory():
    """Run the garbage collector and release PyTorch's cached CUDA memory."""
    gc.collect()
    torch.cuda.empty_cache()


def wait_until_enough_gpu_memory(min_memory_available, max_retries=10, sleep_time=5):
    """Block until the current CUDA device reports enough free memory.

    Args:
        min_memory_available: Required free memory in bytes.
        max_retries: Number of polling attempts before giving up.
        sleep_time: Seconds to wait after each unsuccessful attempt.

    Raises:
        RuntimeError: If the requested amount is still not free after
            `max_retries` attempts.
    """
    nvmlInit()
    try:
        handle = nvmlDeviceGetHandleByIndex(torch.cuda.current_device())

        for _ in range(max_retries):
            # Try to free memory before each measurement.
            clear_gpu_memory()
            info = nvmlDeviceGetMemoryInfo(handle)
            if info.free >= min_memory_available:
                return
            print(f"Waiting for {min_memory_available} bytes of free GPU memory. Retrying in {sleep_time} seconds...")
            time.sleep(sleep_time)

        raise RuntimeError(f"Failed to acquire {min_memory_available} bytes of free GPU memory after {max_retries} retries.")
    finally:
        # Release the NVML session on every exit path.
        nvmlShutdown()

# Usage example
min_memory_available = 38 * 1024 * 1024 * 1024  # 38 GiB
clear_gpu_memory()
wait_until_enough_gpu_memory(min_memory_available)

In [None]:
del model

In [None]:
del trainer