# Imports and configs

In [None]:
from transformers import Gemma2ForSequenceClassification, GemmaTokenizerFast
from transformers.data.data_collator import pad_without_fast_tokenizer_warning
from concurrent.futures import ThreadPoolExecutor
from peft import PeftModel
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch

In [None]:
class CFG:
    """Static settings for this inference notebook (paths and token limits)."""

    # Parquet file holding the test rows (id, prompt, response_a, response_b).
    test_path = "/kaggle/input/wsdm-cup-multilingual-chatbot-arena/test.parquet"

    # Directory of the base Gemma-2 checkpoint; also used to load the tokenizer.
    gemma_dir = "/kaggle/input/gemma-2-9b-it-bnb-4bit-unsloth/transformers/default/1"
    # LoRA adapter checkpoint applied on top of the base model.
    lora_dir = "/kaggle/input/wsdm-cup-gemma-2-9b-4-bit-qlora/gemma2-9b-4bit/gemma-2-9b-it-bnb-4bit-3072-8-f0/checkpoint-2422"

    # Token cap for the concatenated prompt + responses text.
    max_length = 3072
    # NOTE(review): not read anywhere in the visible notebook; the inference
    # cell passes its own per-bucket batch sizes.
    batch_size = 4

# Loading data

In [None]:
# Read the test set and replace every missing value with an empty string so
# the string concatenation in the tokenizing cells never sees NaN/None.
test = pd.read_parquet(CFG.test_path).fillna('')

# Tokenizing

In [None]:
def tokenize(tokenizer, prompt, response_a, response_b, max_length=CFG.max_length):
    """Join each (prompt, response_a, response_b) triple into one text and tokenize the batch.

    Every text has the layout ``<prompt>: ...``, a blank line,
    ``<response_a>: ...``, a blank line, ``<response_b>: ...``.

    Args:
        tokenizer: Callable tokenizer accepting a list of strings plus
            ``max_length`` and ``truncation`` keyword arguments.
        prompt: Iterable of prompt strings.
        response_a: Iterable of first-response strings, aligned with ``prompt``.
        response_b: Iterable of second-response strings, aligned with ``prompt``.
        max_length: Token cap forwarded to the tokenizer (truncation is on).

    Returns:
        A ``(input_ids, attention_mask)`` pair taken from the tokenizer output.
    """
    texts = [
        "<prompt>: " + p_text
        + "\n\n<response_a>: " + a_text
        + "\n\n<response_b>: " + b_text
        for p_text, a_text, b_text in zip(prompt, response_a, response_b)
    ]
    encoded = tokenizer(texts, truncation=True, max_length=max_length)

    return encoded['input_ids'], encoded['attention_mask']

In [None]:
# Load the fast Gemma tokenizer from the base-model directory.
tokenizer = GemmaTokenizerFast.from_pretrained(CFG.gemma_dir)
# Ask for an EOS token on encoded sequences.
# NOTE(review): presumably mirrors the setup used when the adapter was trained — confirm.
tokenizer.add_eos_token = True
# Pad on the right-hand side when batches are padded.
tokenizer.padding_side = "right"

In [None]:
# Shorten over-long texts column by column: when a text encodes to more than
# `max_no` token ids, keep its head (up to the end of token `s_no`) and its
# tail (from the start of token `e_no`), joined by a "(snip)" marker.
for col in ['prompt', 'response_a', 'response_b']:
    test[col] = test[col].fillna('')

    # (token-count threshold, head token index, tail token index) for this column.
    max_no, s_no, e_no = (1024, 511, -512) if col == 'prompt' else (3072, 1535, -1536)

    text_list = []
    for text in tqdm(test[col]):
        encoded = tokenizer(text, return_offsets_mapping=True)
        if len(encoded['input_ids']) > max_no:
            # Offsets are (start_char, end_char) pairs into `text`.
            offsets = encoded['offset_mapping']
            head_end = offsets[s_no][1]
            tail_start = offsets[e_no][0]
            text = text[:head_end] + "\n(snip)\n" + text[tail_start:]
        text_list.append(text)

    test[col] = text_list

In [None]:
def _build_model_inputs(first_response, second_response):
    """Tokenize prompt + the two responses (in the given order) into a frame keyed by id."""
    frame = pd.DataFrame()
    frame["id"] = test["id"]
    frame["input_ids"], frame["attention_mask"] = tokenize(
        tokenizer, test["prompt"], first_response, second_response
    )
    # Token count per row; used later for sorting and bucketing.
    frame["length"] = frame["input_ids"].apply(len)
    return frame


# Original ordering: response_a first, response_b second.
data = _build_model_inputs(test["response_a"], test["response_b"])

# Test-time augmentation: the same rows with response_a & response_b swapped.
aug_data = _build_model_inputs(test["response_b"], test["response_a"])

# Model

In [None]:
def _load_base_model(device_name):
    """Load the base sequence-classification model onto the named device."""
    return Gemma2ForSequenceClassification.from_pretrained(
        CFG.gemma_dir,
        device_map=torch.device(device_name),
        use_cache=False,
    )


# One full copy of the model per GPU so the two can run in parallel.
model_0 = _load_base_model("cuda:0")
model_1 = _load_base_model("cuda:1")

In [None]:
# Wrap each base model with the LoRA adapter stored in CFG.lora_dir.
model_0, model_1 = [
    PeftModel.from_pretrained(base_model, CFG.lora_dir)
    for base_model in (model_0, model_1)
]

In [None]:
# Switch both model copies to evaluation mode before running inference.
model_0.eval()
model_1.eval()

# Inference

In [None]:
@torch.no_grad()
@torch.autocast(device_type="cuda")  # replaces the deprecated torch.cuda.amp.autocast()
def inference(df, model, device, batch_size, max_length=CFG.max_length):
    """Score every row of ``df`` with ``model`` and store the result in ``df['winner']``.

    Rows are processed in consecutive chunks of ``batch_size``. Each chunk is
    padded to its own longest sequence using the module-level ``tokenizer``,
    moved to ``device`` and passed through ``model``. The softmax probability
    of class index 1 is kept for each row.

    Args:
        df: DataFrame with ``input_ids`` and ``attention_mask`` columns holding
            per-row token lists. It is modified in place.
        model: Callable that accepts the padded batch as keyword arguments and
            returns an object exposing ``logits``.
        device: Device the padded batch is moved to before the forward pass.
        batch_size: Number of rows per forward pass.
        max_length: Unused; kept so existing call sites remain valid.

    Returns:
        The same ``df`` object, with a new ``winner`` column of probabilities.
    """
    winners = []

    for start_idx in range(0, len(df), batch_size):
        # iloc slicing clamps at the end of the frame, so the last chunk may
        # simply be shorter than batch_size.
        chunk = df.iloc[start_idx:start_idx + batch_size]
        inputs = pad_without_fast_tokenizer_warning(
            tokenizer,
            {
                "input_ids": chunk["input_ids"].to_list(),
                "attention_mask": chunk["attention_mask"].to_list(),
            },
            padding="longest",
            pad_to_multiple_of=None,
            return_tensors="pt",
        )
        outputs = model(**inputs.to(device))
        # Bring probabilities back to the CPU before converting to Python floats.
        proba = outputs.logits.softmax(-1).cpu()

        winners.extend(proba[:, 1].tolist())

    df['winner'] = winners

    return df

In [None]:
# Stack the original and swapped-response rows, remember each row's position
# in that stacked order in an 'index' column, then put the longest sequences first.
data = (
    pd.concat([data, aug_data], ignore_index=True)
    .assign(index=lambda frame: np.arange(len(frame), dtype=np.int32))
    .sort_values("length", ascending=False)
)

In [None]:
# Split rows into two buckets by token count: key 0 holds sequences longer
# than 1536 tokens, key 1 holds the rest.
is_long = data["length"] > 1536
data_dict = {
    0: data[is_long].reset_index(drop=True),
    1: data[~is_long].reset_index(drop=True),
}

In [None]:
# One model copy per GPU; each worker thread drives one of them.
model_copies = (model_0, model_1)
gpu_devices = (torch.device("cuda:0"), torch.device("cuda:1"))

result_df = []
# Bucket 0 (long sequences) runs with batch size 2, bucket 1 (short ones) with 16.
for i, batch_size in enumerate([2, 16]):
    bucket = data_dict[i]
    if len(bucket) == 0:
        continue

    # Even-positioned rows go to the first worker, odd-positioned rows to the second.
    halves = tuple(bucket.iloc[offset::2].copy() for offset in (0, 1))

    with ThreadPoolExecutor(max_workers=2) as executor:
        scored = list(
            executor.map(
                inference,
                halves,
                model_copies,
                gpu_devices,
                (batch_size, batch_size),
            )
        )

    result_df.append(pd.concat(scored, axis=0))

In [None]:
# Put all scored rows back into the stacked order recorded in the 'index' column:
# original rows first, swapped-response rows second.
result_df = pd.concat(result_df).sort_values('index').reset_index(drop=True)

n_rows = len(aug_data)
winner_scores = result_df['winner'].values

# The swapped pass scores the opposite ordering, so its probability is flipped
# before being averaged with the original pass.
proba = winner_scores[:n_rows]
tta_proba = 1 - winner_scores[n_rows:]
proba = (proba + tta_proba) / 2

# Keep one row per test id and overwrite its score with the averaged value.
result_df = result_df[:n_rows].copy()
result_df['winner'] = proba

# Submission

In [None]:
# Turn the averaged probability into a label: below 0.5 -> 'model_a', otherwise 'model_b'.
sub = result_df.loc[:, ['id', 'winner']].copy()
sub['winner'] = np.where(sub['winner'].lt(0.5), 'model_a', 'model_b')

# Write the submission file, then show the first rows as the cell output.
sub.to_csv('submission.csv', index=False)
sub.head()