<a href="https://colab.research.google.com/github/Reennon/multigec-models/blob/main/notebooks/aya_expanse_8b/multigec/multigec.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
from google.colab import userdata

# Publish the GitHub token stored in Colab's Secrets panel as an environment
# variable, so shell escapes (`!git ...`) can reference it as $GIT_TOKEN
# without the token ever being written into the notebook.
os.environ.update({"GIT_TOKEN": userdata.get("git_token")})

In [None]:
!git clone https://$GIT_TOKEN@github.com/Reennon/omnigec-models.git

In [None]:
# Make the cloned repository the working directory so the relative
# `./params/...` path and the `src.*` imports resolve.
%cd omnigec-models

In [None]:
# Pull the latest commits so the run uses the current `src/` code and `params/`.
!git pull

In [None]:
!pip install -U bitsandbytes peft accelerate datasets sentencepiece wandb python-dotenv wtpsplit -q
!pip install flash-attn --no-build-isolation -q
!pip install wtpsplit==2.1.1 -q
!pip install syntok==1.4.4 -q
!pip install omegaconf -q
!pip install wandb -q
!pip install --upgrade transformers trl -q
!pip install pandas numpy -q

In [None]:
from google.colab import drive

# Google Drive holds the input dataset and receives predictions and model
# checkpoints (see the /gdrive/MyDrive/omnigec/... paths in the config cell).
drive.mount(mountpoint='/gdrive')

In [None]:
# All imports for the fine-tuning run. The project-local `src.*` modules are
# resolved from the omnigec-models checkout (the current working directory).
import os

from omegaconf import OmegaConf
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from huggingface_hub import login
# NOTE(review): `sentences` and `LANG_TO_CODE` are not referenced in the visible cells.
from src.utils.multigec import sentences, LANG_TO_CODE, LANG_CODE_TO_TOKEN
# Formatter passed to ConstantLengthDataset as `formatting_func` to build the training text.
from src.utils.aya_utils import training_formatting_prompts_func as formatting_prompts_func
# NOTE(review): PromptTemplate is unused in the visible cells, and langchain-core is
# not installed by the original pip cell — confirm it is available on the runtime.
from langchain_core.prompts import PromptTemplate

# NOTE(review): `multigec_prompts` is not referenced in the visible cells.
from src.instruction_templates import multigec_prompts

import torch
import wandb

from transformers import BitsAndBytesConfig
from tqdm import tqdm
from trl.trainer import ConstantLengthDataset
import pandas as pd
from datasets import Dataset
from transformers.trainer_callback import EarlyStoppingCallback

# NOTE(review): TrainingArguments is unused — the run is configured with SFTConfig.
from transformers import TrainingArguments
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig

# Register tqdm progress bars on pandas objects (e.g. `.progress_apply`).
tqdm.pandas()

In [None]:
# Hyper-parameters from the repository; later cells read its `training`, `lora`
# and `early_stopping` sections.
parameters = OmegaConf.load("./params/aya_expanse_8b.yaml")

In [None]:
# ---- Run configuration -------------------------------------------------------
track = "minimal"
model_name = "aya-expanse-8b"
# Derived from `model_name` so the two names cannot drift apart when the model changes.
fine_tuned_model_name = f"{model_name}-multigec-{track}"
experiment_name = f"multigec-{track}-{model_name}"

# Credentials come from Colab's Secrets panel, never from the notebook text.
hf_key = userdata.get("hf_key")
secret_wandb = userdata.get("wandb_key")

# All persistent artefacts live under one directory on the mounted Google Drive.
drive_root = "/gdrive/MyDrive/omnigec"
in_path = f"{drive_root}/datasets/multigec_{track}.csv"
out_path = f"{drive_root}/preds/multigec_test_{track}.csv"
out_model_dir = f"{drive_root}/models/{fine_tuned_model_name}"

QUANTIZE_4BIT = True
device = "cuda:0"

In [None]:
# W&B project name; with the defaults above this is "AYA-EXPANSE-8B-multigec-minimal".
wandb_project_name = f"{model_name.upper()}-multigec-{track}"

# Authenticate with the key read from Colab secrets in the config cell.
wandb.login(key=secret_wandb)

In [None]:
# Authenticate to the Hugging Face Hub so the checkpoint can be downloaded.
login(token=hf_key)

In [None]:
import os

# `!env TORCH_USE_CUDA_DSA=1` only set the variable inside a throw-away subshell,
# so the kernel never saw it. Set it on the kernel process itself, before the
# model (and therefore CUDA) is loaded.
# NOTE(review): PyTorch documents TORCH_USE_CUDA_DSA as a build-time option, so with
# prebuilt wheels this may have no effect; CUDA_LAUNCH_BLOCKING=1 is the usual runtime
# switch for debugging CUDA errors (at a large speed cost).
os.environ["TORCH_USE_CUDA_DSA"] = "1"

In [None]:
checkpoint = "CohereForAI/aya-expanse-8b"

# Optional 4-bit NF4 weights with double quantization and bf16 compute;
# left as None (full bf16 weights) when QUANTIZE_4BIT is off.
quantization_config = (
    BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    if QUANTIZE_4BIT
    else None
)

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
config = AutoConfig.from_pretrained(checkpoint)

# Load onto the configured GPU using the flash-attn 2 kernels installed during setup.
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    config=config,
    quantization_config=quantization_config,
    torch_dtype="bfloat16",
    device_map=device,
    attn_implementation="flash_attention_2",
)

In [None]:
# Load the MultiGEC CSV and split it by its `split` column.
# reset_index(drop=True) gives each split a fresh RangeIndex, so that
# `Dataset.from_pandas` does not carry the original row labels along as an
# extra `__index_level_0__` column.
multigec_df = pd.read_csv(in_path)
train_df = multigec_df.loc[multigec_df["split"] == "train"].reset_index(drop=True)
val_df = multigec_df.loc[multigec_df["split"] == "val"].reset_index(drop=True)

# Fail fast: an empty split would otherwise only surface inside the trainer
# (evaluation and early stopping both depend on the validation split).
assert len(train_df) > 0, f"no 'train' rows in {in_path}"
assert len(val_df) > 0, f"no 'val' rows in {in_path}"

In [None]:
# Register one special token per language code, then grow the embedding matrix
# to the new vocabulary size.
# NOTE(review): the new rows sit in embed_tokens/lm_head, which LoRA does not train
# unless listed in `modules_to_save`; confirm their initialisation is acceptable.
new_lang_tokens = list(LANG_CODE_TO_TOKEN.values())
num_added_toks = tokenizer.add_tokens(new_lang_tokens, special_tokens=True)
model.resize_token_embeddings(len(tokenizer))

In [None]:
# Token budget of each packed training sequence; validation uses half of it.
seq_length = 1600
# Seed the row shuffle with the trainer's seed (TrainingArguments defaults to 42),
# so the data order is reproducible across runs.
data_seed = parameters.training.get("seed", 42)

# preserve_index=False: never store the pandas index as an `__index_level_0__` column.
training_dataset = Dataset.from_pandas(train_df, preserve_index=False).shuffle(seed=data_seed)
val_dataset = Dataset.from_pandas(val_df, preserve_index=False).shuffle(seed=data_seed)


def _packed(dataset, length):
    """Wrap `dataset` in a ConstantLengthDataset yielding `length`-token sequences.

    Rows are rendered with `formatting_prompts_func`, tokenized with the
    (token-extended) `tokenizer`, and concatenated with EOS separators.
    """
    return ConstantLengthDataset(
        tokenizer=tokenizer,
        dataset=dataset,
        seq_length=length,
        eos_token_id=tokenizer.eos_token_id,
        shuffle=True,
        append_concat_token=True,
        add_special_tokens=True,
        formatting_func=formatting_prompts_func,
    )


cld_train_dataset = _packed(training_dataset, seq_length)
cld_val_dataset = _packed(val_dataset, seq_length // 2)

In [None]:
# Override the per-device batch size loaded from params/aya_expanse_8b.yaml.
# NOTE(review): presumably sized to fit the Colab GPU at seq_length=1600 — confirm.
TRAIN_BATCH_SIZE = 6
parameters.training.per_device_train_batch_size = TRAIN_BATCH_SIZE

In [None]:
# Plain-Python copy of the training section: DictConfig/ListConfig values are
# turned into dict/list, so SFTConfig and W&B receive ordinary, serialisable types.
training_params = OmegaConf.to_container(parameters.training, resolve=True)

# One W&B run per experiment, named after `experiment_name` from the config cell.
run = wandb.init(
    project=wandb_project_name,
    name=experiment_name,
    job_type="training",
    anonymous="allow",
)
wandb.config.update(training_params)

# LoRA adapter settings, taken from the `lora` section of the params file.
peft_config = LoraConfig(
    r=parameters.lora.r,
    lora_alpha=parameters.lora.lora_alpha,
    target_modules=list(parameters.lora.target_modules),
    bias=parameters.lora.bias,
    task_type=parameters.lora.task_type,
)
training_arguments = SFTConfig(
    **training_params,
    packing=True,
    max_seq_length=seq_length,
    output_dir=out_model_dir,
)
# NOTE(review): EarlyStoppingCallback needs load_best_model_at_end=True, an eval
# strategy and metric_for_best_model — confirm these are set in the params file.
trainer = SFTTrainer(
    model=model,
    train_dataset=cld_train_dataset,
    eval_dataset=cld_val_dataset,
    peft_config=peft_config,
    args=training_arguments,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=parameters.early_stopping.early_stopping_patience)],
)

# The model was loaded with attn_implementation="flash_attention_2", so the
# deprecated torch.backends.cuda.sdp_kernel(...) context previously wrapped
# around this call had no effect on it.
trainer.train()

# Persist the final adapter explicitly, since the runtime is released right after
# training. The tokenizer is saved too: it carries the added language tokens, and
# SFTTrainer was not given it.
trainer.save_model(out_model_dir)
tokenizer.save_pretrained(out_model_dir)
run.finish()

In [None]:
from google.colab import drive, runtime

# Release the Colab VM once everything above has finished. First close the W&B
# run (a no-op if none is active) and flush pending writes to the mounted Drive,
# so checkpoints under /gdrive are not lost when the VM goes away.
wandb.finish()
drive.flush_and_unmount()
runtime.unassign()