# Imports and configs

In [None]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType
from transformers import Gemma2ForSequenceClassification, GemmaTokenizerFast
from transformers import Trainer, TrainingArguments, DataCollatorWithPadding
from sklearn.metrics import log_loss, accuracy_score
from sklearn.model_selection import StratifiedKFold
from datasets import Dataset
import pandas as pd
import warnings
import joblib
import torch

warnings.filterwarnings("ignore")

In [None]:
class CFG:
    """Central collection of the settings read by the rest of this notebook."""

    # --- Base model and input size ---
    checkpoint = "unsloth/gemma-2-9b-it-bnb-4bit"
    max_length = 3072  # token cap handed to the tokenizer (truncation is on)

    # --- Cross-validation ---
    seed = 42
    n_splits = 5
    current_fold = 0  # fold whose rows are held out for validation

    # --- Optimisation ---
    optim_type = "adamw_8bit"
    lr = 2e-4
    warmup_steps = 20
    n_epochs = 1
    per_device_train_batch_size = 2
    per_device_eval_batch_size = 8
    gradient_accumulation_steps = 2

    # --- LoRA adapter ---
    freeze_layers = 16  # layers with an index below this get no adapter
    lora_r = 16
    lora_alpha = 2 * lora_r  # alpha is kept at twice the rank
    lora_dropout = 0.05
    lora_bias = "none"

In [None]:
# Output-directory name for this run. The model part is taken from
# CFG.checkpoint (the text after the last "/") rather than being typed out a
# second time, so the name cannot drift from the checkpoint that is loaded.
_model_name = CFG.checkpoint.rstrip("/").rsplit("/", 1)[-1]
CHECKPOINT_BASE_NAME = f"{_model_name}-{CFG.max_length}-{CFG.per_device_train_batch_size}-f{CFG.current_fold}"

# Loading data

In [None]:
# Read the training parquet and keep a random 100-row subset with a fresh
# 0..99 index. The draw is seeded with CFG.seed so the same rows are selected
# on every run; without a seed the subset (and therefore the folds and the
# saved predictions) would change each time the notebook is executed.
dataset = (
    pd.read_parquet("/kaggle/input/wsdm-cup-multilingual-chatbot-arena/train.parquet")
    .sample(100, random_state=CFG.seed)
    .reset_index(drop=True)
)
# Encode the target as an integer class id: "model_a" -> 0, "model_b" -> 1.
dataset["winner"] = dataset["winner"].map({"model_a": 0, "model_b": 1})

In [None]:
# Tag every row with the id of the fold in which it serves as validation data.
# The split is stratified on the "winner" label and shuffled with CFG.seed.
skf = StratifiedKFold(n_splits=CFG.n_splits, shuffle=True, random_state=CFG.seed)
fold_splits = skf.split(dataset, dataset["winner"])
for fold_id, (_, val_rows) in enumerate(fold_splits):
    dataset.loc[val_rows, "fold"] = fold_id

In [None]:
# Rows of the current fold become the validation set; all other rows are
# used for training.
in_val_fold = dataset["fold"] == CFG.current_fold
train = dataset[~in_val_fold]
val = dataset[in_val_fold]

In [None]:
# Convert both pandas frames into `datasets.Dataset` objects (train first,
# then val).
train, val = (Dataset.from_pandas(frame) for frame in (train, val))

# Tokenizing

In [None]:
# Load the tokenizer that ships with the configured checkpoint.
tokenizer = GemmaTokenizerFast.from_pretrained(CFG.checkpoint)
# Pad on the right-hand side of each sequence.
tokenizer.padding_side = "right"
# Ask the tokenizer to add an EOS token to encoded sequences.
tokenizer.add_eos_token = True

In [None]:
class Tokenizer:
    """Callable that turns a batch of prompt/response rows into model inputs.

    Each row is rendered as one string of the form
    ``<prompt>: ...\\n\\n<response_a>: ...\\n\\n<response_b>: ...`` and passed
    to the wrapped tokenizer with truncation at ``max_length``.
    """

    def __init__(self, tokenizer, max_length):
        """Store the wrapped tokenizer callable and the truncation length."""
        self.max_length = max_length
        self.tokenizer = tokenizer

    def __call__(self, batch):
        """Tokenize one batch.

        Args:
            batch: mapping with the columns ``"prompt"``, ``"response_a"``,
                ``"response_b"`` (sequences of strings) and ``"winner"``.

        Returns:
            A dict holding everything the wrapped tokenizer returned plus a
            ``"labels"`` entry copied from ``batch["winner"]``.
        """
        rows = zip(batch["prompt"], batch["response_a"], batch["response_b"])
        texts = []
        for prompt_text, answer_a, answer_b in rows:
            # str.join (like "+") only accepts strings, so a non-string cell
            # still raises TypeError instead of being silently stringified.
            pieces = (
                "<prompt>: ", prompt_text,
                "\n\n<response_a>: ", answer_a,
                "\n\n<response_b>: ", answer_b,
            )
            texts.append("".join(pieces))

        encoded = dict(self.tokenizer(texts, max_length=self.max_length, truncation=True))
        encoded["labels"] = batch["winner"]
        return encoded

In [None]:
# Build the batch encoder once and run it over both splits. With batched=True
# the encoder is called with whole batches, which is what Tokenizer.__call__
# iterates over.
encode = Tokenizer(tokenizer, max_length=CFG.max_length)

map_options = {"batched": True}
train = train.map(encode, **map_options)
val = val.map(encode, **map_options)

# Modeling

In [None]:
# Indices of the layers that receive a LoRA adapter: every index in 0..41 that
# is not below CFG.freeze_layers.
# NOTE(review): 42 is hard-coded; presumably the layer count of the
# checkpoint in CFG.checkpoint -- confirm if the checkpoint changes.
lora_layer_ids = [layer for layer in range(42) if layer >= CFG.freeze_layers]

lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    # Adapter size and regularisation
    r=CFG.lora_r,
    lora_alpha=CFG.lora_alpha,
    lora_dropout=CFG.lora_dropout,
    bias=CFG.lora_bias,
    # Where the adapters are attached
    target_modules=["q_proj", "k_proj", "v_proj"],
    layers_to_transform=lora_layer_ids,
)

In [None]:
# Load the checkpoint as a sequence classifier with two output labels
# (matching the 0/1 encoding of "winner"), in float16, letting the library
# choose device placement.
model = Gemma2ForSequenceClassification.from_pretrained(
    CFG.checkpoint,
    device_map="auto",
    torch_dtype=torch.float16,
    num_labels=2,
)

In [None]:
# Turn off use_cache on the model config before wrapping the model.
model.config.use_cache = False
# Prepare the quantized model for training, then attach the LoRA adapters
# described by lora_config.
model = get_peft_model(prepare_model_for_kbit_training(model), lora_config)

# Training

In [None]:
def compute_metrics(eval_preds):
    """Compute accuracy and log loss for one evaluation pass.

    Args:
        eval_preds: object exposing ``predictions`` (a NumPy array of raw
            logits with one column per class) and ``label_ids`` (the integer
            class id of each row).

    Returns:
        dict with the keys ``"acc"`` and ``"log_loss"``.
    """
    logits = eval_preds.predictions
    labels = eval_preds.label_ids
    # Cast to float32 before the softmax, then go back to NumPy for sklearn.
    probs = torch.from_numpy(logits).float().softmax(-1).numpy()
    # Name the full class set explicitly. Otherwise log_loss infers the
    # classes from `labels` and raises when an evaluation split happens to
    # contain a single class only.
    class_ids = list(range(probs.shape[-1]))
    loss = log_loss(y_true=labels, y_pred=probs, labels=class_ids)
    acc = accuracy_score(y_true=labels, y_pred=logits.argmax(-1))
    return {"acc": acc, "log_loss": loss}

In [None]:
# Trainer configuration, driven by CFG.
training_args = TrainingArguments(
    output_dir=CHECKPOINT_BASE_NAME,
    overwrite_output_dir=True,
    # Pass the notebook-wide seed explicitly so that changing CFG.seed also
    # changes the Trainer's seed; previously the Trainer always used its own
    # default regardless of CFG.seed.
    seed=CFG.seed,
    # Schedule and batch sizes
    num_train_epochs=CFG.n_epochs,
    per_device_train_batch_size=CFG.per_device_train_batch_size,
    gradient_accumulation_steps=CFG.gradient_accumulation_steps,
    per_device_eval_batch_size=CFG.per_device_eval_batch_size,
    # Logging, evaluation and checkpointing cadence
    logging_steps=10,
    eval_strategy="epoch",
    save_strategy="steps",
    save_steps=200,
    save_total_limit=1,
    # Optimiser
    optim=CFG.optim_type,
    fp16=True,
    learning_rate=CFG.lr,
    warmup_steps=CFG.warmup_steps,
    report_to="none",
)

In [None]:
# Collator that pads each batch using the tokenizer's padding settings.
padding_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Bring model, data, metrics and arguments together in one Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=val,
    tokenizer=tokenizer,
    data_collator=padding_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Run fine-tuning with the Trainer; as the last expression of the cell, the
# value returned by train() is what the notebook displays.
trainer.train()

# Saving out-of-fold (OOF) predictions

In [None]:
# Ground-truth labels of the held-out fold.
y_true = val["winner"]
# Raw logits for the held-out fold, converted to class probabilities with a
# float32 softmax over the last axis.
prediction_output = trainer.predict(val)
logits = prediction_output.predictions
y_pred_probs = torch.softmax(torch.from_numpy(logits).float(), dim=-1).numpy()

In [None]:
# The predicted class is the one with the highest probability.
predicted_labels = y_pred_probs.argmax(axis=-1)
acc = accuracy_score(y_true=y_true, y_pred=predicted_labels)
print(f"Fold {CFG.current_fold} - Accuracy: {acc:.4f}")

In [None]:
# Persist the held-out-fold probabilities; the file name records the fold id
# and the accuracy reached on it.
oof_path = f"y_pred_probs_fold_{CFG.current_fold}_acc_{acc:.6f}.pkl"
joblib.dump(y_pred_probs, oof_path)