<a href="https://colab.research.google.com/github/Reennon/multigec-models/blob/main/notebooks/gemma_3_12b/multigec/multigec_prediction_minimal.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os

from google.colab import userdata

os.environ["GIT_TOKEN"] = userdata.get('git_token')

In [2]:
from google.colab import drive
drive.mount('/gdrive')

Drive already mounted at /gdrive; to attempt to forcibly remount, call drive.mount("/gdrive", force_remount=True).


In [3]:
!git clone https://$GIT_TOKEN@github.com/Reennon/multigec-models.git

fatal: destination path 'multigec-models' already exists and is not an empty directory.


In [4]:
%cd multigec-models

/content/multigec-models


In [5]:
!git pull

Already up to date.


In [6]:
!pip install -U bitsandbytes peft accelerate datasets sentencepiece wandb python-dotenv wtpsplit -q
!pip install flash-attn --no-build-isolation -q
!pip install wtpsplit==2.1.1 -q
!pip install syntok==1.4.4 -q
!pip install omegaconf -q
!pip install wandb -q
!pip install --upgrade git+https://github.com/huggingface/transformers.git -q
!pip install --upgrade trl -q
!pip install pandas numpy -q

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for transformers (pyproject.toml) ... [?25l[?25hdone
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
adapters 1.1.0 requires transformers~=4.47.1, but you have transformers 4.52.0.dev0 which is incompatible.
wtpsplit 2.1.1 requires huggingface-hub==0.25.2, but you have huggingface-hub 0.30.1 which is incompatible.[0m[31m
[0m

In [7]:
import os

from omegaconf import OmegaConf
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from huggingface_hub import login
from src.utils.multigec import sentences, LANG_TO_CODE, LANG_CODE_TO_TOKEN
from langchain_core.prompts import PromptTemplate

from src.instruction_templates import multigec_prompts

import torch
import wandb

from transformers import BitsAndBytesConfig
from tqdm import tqdm
from trl.trainer import ConstantLengthDataset
import pandas as pd
from datasets import Dataset
from transformers.trainer_callback import EarlyStoppingCallback

from transformers import TrainingArguments
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig, PeftModelForCausalLM

tqdm.pandas()

In [8]:
parameters = OmegaConf.load("./params/gemma_3_12b.yaml")

In [9]:
track     = "minimal"
model_name = "gemma-3-12b-it"
fine_tuned_model_name = f"gemma-3-12b-it-multigec"
experiment_name = f"multigec-{track}-{model_name}"

hf_key   = userdata.get("hf_key")
secret_wandb = userdata.get("wandb_key")
in_path  = f"/gdrive/MyDrive/multigec/datasets/multigec_{track}.csv"

# Path where the output will be saved to
out_path = f"/gdrive/MyDrive/multigec/preds/{model_name}/multigec_test_{track}.csv"
out_model_dir = f"/gdrive/MyDrive/multigec/models/multigec/{fine_tuned_model_name}"
QUANTIZE_4BIT = True
device   = "cuda:0"

In [10]:
login(hf_key)

In [11]:
!env TORCH_USE_CUDA_DSA=1 -q

env: ‘-q’: No such file or directory


In [12]:
base_model = "google/gemma-3-12b-it"
saved_checkpoint = out_model_dir + "/checkpoint-500-minimal-best"

quantization_config = None
if QUANTIZE_4BIT:
  quantization_config = BitsAndBytesConfig(
      load_in_4bit=True,
      bnb_4bit_quant_type="nf4",
      bnb_4bit_use_double_quant=True,
      bnb_4bit_compute_dtype=torch.bfloat16,
      bnb_4bit_quant_storage=torch.bfloat16,
  )
tokenizer = AutoTokenizer.from_pretrained(saved_checkpoint)
config = AutoConfig.from_pretrained(base_model)
config.text_config.use_cache = False
base_model_instance = AutoModelForCausalLM.from_pretrained(
    base_model,
    config=config,
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    device_map=device,
    attn_implementation="eager",
)
base_model_instance.resize_token_embeddings(len(tokenizer))
model = PeftModelForCausalLM.from_pretrained(
    base_model_instance,
    saved_checkpoint,
    torch_dtype=torch.bfloat16,
    device_map=device,
    ignore_mismatched_sizes=True
)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
You have set `use_cache` to `False`, but cache_implementation is set to hybrid. cache_implementation will have no effect.


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

In [13]:
multigec_df = pd.read_csv(in_path)
train_df = multigec_df.loc[multigec_df.loc[:, "split"] == "train"]
val_df = multigec_df.loc[multigec_df.loc[:, "split"] == "val"]
test_df = multigec_df.loc[multigec_df.loc[:, "split"] == "test"]
test_df.target = None

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test_df.target = None


In [14]:
def formatting_prompts_func(example):
    language_code = LANG_TO_CODE[example["language"]]
    # Since special tokens for Gemma models does not have |, we remove them
    language_token = LANG_CODE_TO_TOKEN[language_code].replace("|", "")

    user_input = example['feature']
    prompt_template = multigec_prompts[example["language"]].prompt_template
    instruction = prompt_template.format(original_text=user_input)

    text = f"<start_of_turn>user\n{language_token}{instruction}<end_of_turn>\n<start_of_turn>model\n"

    return text

In [None]:
from torch.utils.data import DataLoader
from tqdm import tqdm

max_new_tokens = 1600
batch_size = 15
save_each = 15

# Assume test_df already exists and out_path is defined.
# Filter rows that need processing (target is empty/NaN).
to_process_df = test_df[test_df["target"].isna()].reset_index()  # preserve original index in "index" column

def collate_fn(examples):
    indices = [ex["index"] for ex in examples]
    texts = [formatting_prompts_func(example) for example in examples]
    tokenized = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt"
    )
    tokenized["indices"] = indices
    return tokenized


# Create DataLoader using only the rows that need processing.
dataloader = DataLoader(
    to_process_df.to_dict(orient="records"),
    batch_size=batch_size,
    collate_fn=collate_fn
)

processed_rows = 0
for batch in tqdm(dataloader):
    input_ids = batch["input_ids"].to(model.device)
    attention_mask = batch.get("attention_mask", None)
    if attention_mask is not None:
        attention_mask = attention_mask.to(model.device)
    prompt_padded_len = len(input_ids[0])

    gen_tokens = model.generate(
        input_ids,
        attention_mask=attention_mask,
        temperature=parameters.baseline.temperature,
        top_p=parameters.baseline.top_p,
        top_k=parameters.baseline.top_k,
        max_new_tokens=max_new_tokens,
        do_sample=True,
    )

    # Remove the prompt tokens from the generated tokens
    gen_tokens = [gt[prompt_padded_len:] for gt in gen_tokens]

    # Decode generated tokens to text corrections
    corrections = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
    corrections = ["".join(c) for c in corrections]

    # Update the original DataFrame using the indices provided in the batch.
    for idx, corr in zip(batch["indices"], corrections):
        test_df.loc[idx, "target"] = corr
        processed_rows += 1

        if processed_rows % save_each == 0:
            test_df.to_csv(out_path, index=False)
            print(f"Saved progress after processing {processed_rows} rows.")

# Final save after processing all batches.
test_df.to_csv(out_path, index=False)
print("Final save complete!")


  0%|          | 0/141 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
  1%|          | 1/141 [06:09<14:23:19, 370.00s/it]

Saved progress after processing 20 rows.


  1%|▏         | 2/141 [07:53<8:14:06, 213.28s/it] 

Saved progress after processing 40 rows.


In [None]:
import gc

gc.collect(generation=2)

In [None]:
torch.cuda.empty_cache()

In [None]:
from google.colab import runtime
runtime.unassign()