# Imports and configs

In [1]:
from transformers import Gemma2ForSequenceClassification, GemmaTokenizerFast
from transformers.data.data_collator import pad_without_fast_tokenizer_warning
from concurrent.futures import ThreadPoolExecutor
from timeit import default_timer as timer
from peft import PeftModel
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch

In [2]:
class CFG:
    test_path = "/kaggle/input/wsdm-cup-multilingual-chatbot-arena/test.parquet"
    
    gemma_dir = "/kaggle/input/gemma-2-9b-4bit-it-unsloth/transformers/default/1/gemma-2-9b-it-4bit-unsloth_old"
    lora_dir = "/kaggle/input/wsdm-cup-gemma-2-9b-4-bit-qlora/gemma2-9b-4bit/gemma-2-9b-it-bnb-4bit-3072-8/checkpoint-3027"
    
    max_length = 3072
    batch_size = 4

# Loading data

In [3]:
test = pd.read_parquet(CFG.test_path).fillna('')

In [4]:
if len(test) > 10_000:
    time_limit = int(3600 * 12) 
else:
    time_limit = int(3600 * 4.75)

# Tokenizing

In [5]:
def tokenize(tokenizer, prompt, response_a, response_b, max_length=CFG.max_length):
    prompt = ["<prompt>: " + t for t in prompt]
    response_a = ["\n\n<response_a>: " + t for t in response_a]
    response_b = ["\n\n<response_b>: " + t for t in response_b]
    
    texts = [p + r_a + r_b for p, r_a, r_b in zip(prompt, response_a, response_b)]
    tokenized = tokenizer(texts, max_length=max_length, truncation=True)
    
    return tokenized['input_ids'], tokenized['attention_mask']

In [6]:
tokenizer = GemmaTokenizerFast.from_pretrained(CFG.gemma_dir)
tokenizer.add_eos_token = True
tokenizer.padding_side = "right"

In [7]:
for col in ['prompt', 'response_a', 'response_b']:
    test[col] = test[col].fillna('')
    text_list = []
    if col == "prompt":
        max_no = 512
        s_no = 255
        e_no = -256
    else:
        max_no = 3072
        s_no = 1535
        e_no = -1536
    for text in tqdm(test[col]):
        encoded = tokenizer(text, return_offsets_mapping=True)
        if len(encoded['input_ids']) > max_no:
            start_idx, end_idx = encoded['offset_mapping'][s_no]
            new_text = text[:end_idx]
            start_idx, end_idx = encoded['offset_mapping'][e_no]
            new_text = new_text + "\n(snip)\n" + text[start_idx:]
            text = new_text
        text_list.append(text)
    test[col] = text_list

100%|██████████| 3/3 [00:00<00:00, 856.04it/s]
100%|██████████| 3/3 [00:00<00:00, 786.43it/s]
100%|██████████| 3/3 [00:00<00:00, 665.87it/s]


In [8]:
data = pd.DataFrame()
data["id"] = test["id"]
data["input_ids"], data["attention_mask"] = tokenize(tokenizer, test["prompt"], test["response_a"], test["response_b"])
data["length"] = data["input_ids"].apply(len)

aug_data = pd.DataFrame()
aug_data["id"] = test["id"]
# swap response_a & response_b
aug_data['input_ids'], aug_data['attention_mask'] = tokenize(tokenizer, test["prompt"], test["response_b"], test["response_a"])
aug_data["length"] = aug_data["input_ids"].apply(len)

# Model

In [9]:
model_0 = Gemma2ForSequenceClassification.from_pretrained(
    CFG.gemma_dir,
    device_map=torch.device("cuda:0"),
    use_cache=False,
)

model_1 = Gemma2ForSequenceClassification.from_pretrained(
    CFG.gemma_dir,
    device_map=torch.device("cuda:1"),
    use_cache=False,
)

Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.
Some weights of Gemma2ForSequenceClassification were not initialized from the model checkpoint at /kaggle/input/gemma-2-9b-4bit-it-unsloth/transformers/default/1/gemma-2-9b-it-4bit-unsloth_old and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.
Some weights of Gemma2ForSequenceClassification were not initialized from the model checkpoint at /kaggle/input/gemma-2-9b-4bit-it-unsloth/transformers/default/1/gemma-2-9b-it-4bit-unsloth_old and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for p

In [10]:
model_0 = PeftModel.from_pretrained(model_0, CFG.lora_dir)
model_1 = PeftModel.from_pretrained(model_1, CFG.lora_dir)

In [11]:
model_0.eval()
model_1.eval()

PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): Gemma2ForSequenceClassification(
      (model): Gemma2Model(
        (embed_tokens): Embedding(256000, 3584, padding_idx=0)
        (layers): ModuleList(
          (0-15): 16 x Gemma2DecoderLayer(
            (self_attn): Gemma2SdpaAttention(
              (q_proj): Linear4bit(in_features=3584, out_features=4096, bias=False)
              (k_proj): Linear4bit(in_features=3584, out_features=2048, bias=False)
              (v_proj): Linear4bit(in_features=3584, out_features=2048, bias=False)
              (o_proj): Linear4bit(in_features=4096, out_features=3584, bias=False)
              (rotary_emb): Gemma2RotaryEmbedding()
            )
            (mlp): Gemma2MLP(
              (gate_proj): Linear4bit(in_features=3584, out_features=14336, bias=False)
              (up_proj): Linear4bit(in_features=3584, out_features=14336, bias=False)
              (down_proj): Linear4bit(in_features=14336, out_features=3584,

# Inference

In [12]:
@torch.no_grad()
@torch.cuda.amp.autocast()
def inference(df, model, device, batch_size, max_length=CFG.max_length):
    winners = []
    
    for start_idx in range(0, len(df), batch_size):
        end_idx = min(start_idx + batch_size, len(df))
        tmp = df.iloc[start_idx:end_idx]
        input_ids = tmp["input_ids"].to_list()
        attention_mask = tmp["attention_mask"].to_list()
        inputs = pad_without_fast_tokenizer_warning(
            tokenizer,
            {"input_ids": input_ids, "attention_mask": attention_mask},
            padding="longest",
            pad_to_multiple_of=None,
            return_tensors="pt",
        )
        outputs = model(**inputs.to(device))
        proba = outputs.logits.softmax(-1).cpu()
        
        winners.extend(proba[:, 1].tolist())
    
    df['winner'] = winners
    
    return df

  @torch.cuda.amp.autocast()


## No TTA

In [13]:
global_timer = timer()

In [14]:
data['index'] = np.arange(len(data), dtype=np.int32)
data = data.sort_values("length", ascending=False)

In [15]:
data_dict = {}
data_dict[0] = data[data["length"] > 1024].reset_index(drop=True)
data_dict[1] = data[data["length"] <= 1024].reset_index(drop=True)

In [16]:
result_df = []
for i, batch_size in enumerate([CFG.batch_size, CFG.batch_size]):
    if len(data_dict[i]) == 0:
        continue
        
    sub_1 = data_dict[i].iloc[0::2].copy()
    sub_2 = data_dict[i].iloc[1::2].copy()
    
    with ThreadPoolExecutor(max_workers=2) as executor:
        results = executor.map(
            inference, 
            (sub_1, sub_2), 
            (model_0, model_1), 
            (torch.device("cuda:0"), torch.device("cuda:1")), 
            (batch_size, batch_size)
        )
        
    result_df.append(pd.concat(list(results), axis=0))

In [17]:
result_df = pd.concat(result_df).sort_values('index').reset_index(drop=True)

## TTA

In [18]:
aug_data['index'] = np.arange(len(aug_data), dtype=np.int32)
aug_data = aug_data.sort_values("length", ascending=False)

In [19]:
CONFIDENCE_THRESHOLD = 0.2
not_confident_mask = abs(result_df['winner'] - 0.5) < CONFIDENCE_THRESHOLD

aug_data = aug_data[aug_data['index'].isin(result_df[not_confident_mask]['index'])]

In [20]:
aug_data_dict = {}
aug_data_dict[0] = aug_data[aug_data["length"] > 1024].reset_index(drop=True)
aug_data_dict[1] = aug_data[aug_data["length"] <= 1024].reset_index(drop=True)

In [21]:
aug_result_df = []
for i, batch_size in enumerate([CFG.batch_size, CFG.batch_size]):
    if len(aug_data_dict[i]) == 0:
        continue

    if timer() - global_timer > (time_limit - 300):
        break
        
    sub_1 = aug_data_dict[i].iloc[0::2].copy()
    sub_2 = aug_data_dict[i].iloc[1::2].copy()
    
    with ThreadPoolExecutor(max_workers=2) as executor:
        results = executor.map(
            inference, 
            (sub_1, sub_2), 
            (model_0, model_1), 
            (torch.device("cuda:0"), torch.device("cuda:1")), 
            (batch_size, batch_size)
        )
        
    aug_result_df.append(pd.concat(list(results), axis=0))

In [22]:
time_taken = timer() - global_timer

print(f'time for 10_000 (max  4.5 hr): {10_000/3*time_taken/60/60:4.1f}')
print(f'time for 25_000 (max 12.0 hr): {25_000/3*time_taken/60/60:4.1f}')

time for 10_000 (max  4.5 hr):  5.4
time for 25_000 (max 12.0 hr): 13.5


## Combining the results

In [23]:
if len(aug_result_df) > 0:
    aug_result_df = pd.concat(aug_result_df).sort_values('index').reset_index(drop=True)
    aug_result_df["winner"] = 1 - aug_result_df['winner']
    
    result_df = result_df.merge(
        aug_result_df[['index', 'winner']], 
        on='index', 
        how='left', 
        suffixes=('', '_aug')
    )
    
    mask = result_df['winner_aug'].notna()
    result_df.loc[mask, 'winner'] = (result_df.loc[mask, 'winner'] + result_df.loc[mask, 'winner_aug']) / 2
    
    result_df = result_df.drop('winner_aug', axis=1)

# Submission

In [24]:
sub = result_df[['id', 'winner']].copy()
sub['winner'] = np.where(sub['winner'] < 0.5, 'model_a', 'model_b')
sub.to_csv('submission.csv', index=False)
sub.head()

Unnamed: 0,id,winner
0,327228,model_b
1,1139415,model_a
2,1235630,model_a
