# In [None]:
from itertools import zip_longest
from  tqdm import tqdm
from sklearn.model_selection import StratifiedKFold
from transformers import BertTokenizer, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch
import  pandas as pd
# Load test.csv from the lmsys-chatbot-arena input directory. The code below
# reads its 'id', 'prompt', 'response_a' and 'response_b' columns.
test = pd.read_csv('/kaggle/input/lmsys-chatbot-arena/test.csv')


def process(example):
    """Clean one row's text fields and build interleaved prompt/response texts.

    Args:
        example: Mapping-like row with string values under 'prompt',
            'response_a' and 'response_b', each a bracketed list of
            double-quoted items separated by `","`.

    Returns:
        A ``pd.Series`` indexed by 'prompt', 'response_a', 'response_b',
        'text_a' and 'text_b'. The first three hold the items of each field
        joined by single spaces; the last two hold prompt items alternating
        with the matching response items, joined by newlines.
    """

    def _split_items(raw):
        # Trim bracket characters from both ends, cut on the `","` separator,
        # then remove leftover quotes and any literal '[MASK]' text so the
        # only mask token in the final input is the one added later.
        pieces = raw.strip('[]').split('","')
        return [piece.strip('"').replace('[MASK]', '') for piece in pieces]

    def _interleave(questions, answers):
        # Alternate question/answer items; the shorter list is padded with ''
        # and empty items are skipped.
        merged = []
        for question, answer in zip_longest(questions, answers, fillvalue=''):
            if question:
                merged.append(question)
            if answer:
                merged.append(answer)
        return merged

    prompts = _split_items(example['prompt'])
    answers_a = _split_items(example['response_a'])
    answers_b = _split_items(example['response_b'])

    column_names = ['prompt', 'response_a', 'response_b', 'text_a','text_b']
    column_values = [
        ' '.join(prompts),
        ' '.join(answers_a),
        ' '.join(answers_b),
        '\n'.join(_interleave(prompts, answers_a)),
        '\n'.join(_interleave(prompts, answers_b)),
    ]
    return pd.Series(column_values, index=column_names)

# Overwrite the raw text columns with their cleaned versions and add the
# interleaved 'text_a' / 'text_b' columns produced by process().
test[['prompt', 'response_a', 'response_b', 'text_a','text_b']] = test.apply(process, axis=1)
# Drop incomplete rows, then rebuild a contiguous 0..n-1 index so that
# positional and label-based row lookups agree downstream (dropna() alone
# leaves gaps in the index wherever a row was removed).
test = test.dropna().reset_index(drop=True)


# pet
class Senmamtic_news(Dataset):
    """Dataset that turns each row into a single masked comparison prompt.

    Every item is the text "<prompt>: Response A: ... Response B: ... Which is
    better? ..." followed by the tokenizer's mask token, tokenized to at most
    1600 tokens.

    Args:
        data: DataFrame with 'prompt', 'response_a' and 'response_b' columns.
            These three columns are overwritten in place with their
            length-limited versions during construction.
        tokenizer: Tokenizer exposing ``sep_token``, ``mask_token``,
            ``mask_token_id``, ``tokenize``, ``convert_tokens_to_string`` and
            ``__call__``.
    """

    # A field longer than MAX_FIELD_TOKENS tokens is shortened to its first
    # and last KEEP_EDGE_TOKENS tokens.
    MAX_FIELD_TOKENS = 512
    KEEP_EDGE_TOKENS = 256

    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer
        self.sep_token = tokenizer.sep_token
        self.mask_token = tokenizer.mask_token

        prompt_new = self._shorten_column('prompt')
        # response_a
        print(f'== response_a ==')
        response_a_new = self._shorten_column('response_a')
        # response_b
        print(f'== response_b ==')
        response_b_new = self._shorten_column('response_b')

        self.data['prompt'] = prompt_new
        self.data['response_a'] = response_a_new
        self.data['response_b'] = response_b_new

    def _shorten_column(self, column):
        """Return the column's texts (NaN -> ""), keeping only head and tail of long ones."""
        keep = self.KEEP_EDGE_TOKENS
        shortened = []
        texts = self.data[column].fillna("").values
        for text in tqdm(texts, total=len(self.data)):
            length = len(self.tokenizer(text, add_special_tokens=False)['input_ids'])
            if length > self.MAX_FIELD_TOKENS:
                # Tokenize once and reuse the result for both ends.
                tokens = self.tokenizer.tokenize(text)
                text = self.tokenizer.convert_tokens_to_string(tokens[:keep] + tokens[-keep:])
            shortened.append(text)
        return shortened

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Build and tokenize the prompt for the row at position ``idx``.

        Returns:
            Dict with 'input_ids' and 'attention_mask' (1-D long tensors) and
            'mask_index', the position(s) in 'input_ids' equal to the
            tokenizer's mask token id.
        """
        # Positional lookup: ``idx`` comes from a sampler counting 0..len-1,
        # while the frame's index labels may have gaps (e.g. after dropna()),
        # so label-based ``self.data[col][idx]`` can raise KeyError or return
        # the wrong row.
        prompt = self.data['prompt'].iloc[idx]
        response_a = self.data['response_a'].iloc[idx]
        response_b = self.data['response_b'].iloc[idx]
        system_prompt = f"""{prompt}:
                        Response A: {response_a}
                        Response B: {response_b}
                        Which is better? Choose 'A', 'B', or 'both'.
""" +self.mask_token

        inputs = self.tokenizer(system_prompt, truncation=True, max_length=1600)
        input_ids = torch.tensor(inputs['input_ids'], dtype=torch.long)
        # NOTE(review): the mask token sits at the very end of the text, so if
        # truncation to 1600 tokens ever cuts it off this comes back empty —
        # confirm the per-field limits keep the total under max_length.
        mask_index = torch.where(input_ids == self.tokenizer.mask_token_id)[0]

        return {
            'input_ids': input_ids,
            'attention_mask': torch.tensor(inputs['attention_mask'], dtype=torch.long),
            'mask_index': mask_index
        }
def collate_fn_semantic(batch):
    """Merge a list of dataset items into one padded batch.

    Args:
        batch: List of dicts, each holding 'input_ids', 'attention_mask' and
            'mask_index' tensors.

    Returns:
        Dict with 'input_ids' and 'attention_mask' right-padded with zeros to
        the longest sequence in the batch (batch dimension first), and
        'mask_index' stacked along a new leading dimension.
    """
    padded = {
        key: pad_sequence([sample[key] for sample in batch], batch_first=True)
        for key in ('input_ids', 'attention_mask')
    }
    # torch.stack requires every item's mask_index to have the same shape.
    padded['mask_index'] = torch.stack([sample['mask_index'] for sample in batch])
    return padded
def get_semantic_data_loader(batch_size=32, mode='train',shuffle=True):
    """Build a DataLoader over the module-level ``train`` or ``test`` frame.

    Args:
        batch_size: Number of items per batch.
        mode: 'train' wraps the ``train`` frame; any other value wraps ``test``.
        shuffle: Whether to shuffle; only honoured in 'train' mode.

    Returns:
        A ``DataLoader`` over a ``Senmamtic_news`` dataset that batches with
        ``collate_fn_semantic``.
    """
    tokenizer = AutoTokenizer.from_pretrained('/kaggle/input/deberta-small/deberta')
    training = mode == 'train'
    # Only the selected frame is looked up, so the other name need not exist.
    # NOTE(review): `train` is not defined in the visible part of this file —
    # confirm it exists before calling with mode='train'.
    frame = train if training else test
    return DataLoader(
        Senmamtic_news(frame, tokenizer),
        batch_size=batch_size,
        shuffle=shuffle if training else False,
        collate_fn=collate_fn_semantic,
        drop_last=False,
    )
import torch
from torch import nn
from transformers import AutoModelForMaskedLM, AutoConfig
from tqdm import tqdm

# Define the data-loader function and the index dictionary
# from get_dataloader import get_semantic_data_loader


# Model class definition
class BertPET(nn.Module):
    """Masked-LM wrapper that scores three fixed vocabulary ids at the mask position.

    Args:
        model_path: Path or identifier passed to
            ``AutoModelForMaskedLM.from_pretrained`` and
            ``AutoConfig.from_pretrained``.
    """

    def __init__(self, model_path):
        super().__init__()
        self.bert = AutoModelForMaskedLM.from_pretrained(model_path)
        self.config = AutoConfig.from_pretrained(model_path)

    # Label mapping noted by the original author: {"336": 0, "736": 1, "462": 2}
    # NOTE(review): forward() selects id 732, not 736 — confirm which id the
    # checkpoint was trained with.
    def forward(self, input_ids, attention_mask, mask_index, labels=None, return_logits=False):
        """Return logits for vocabulary ids 336, 732 and 462 at the mask position.

        Args:
            input_ids: Token ids passed to the masked-LM model.
            attention_mask: Attention mask passed to the masked-LM model.
            mask_index: 2-D long tensor (batch, k) of sequence positions to
                read; with k == 1 the result is (batch, 3).
            labels: Unused.
            return_logits: When false the method returns ``None``.
                NOTE(review): looks like a training branch was removed here.

        Returns:
            Tensor of the three selected logits per row, or ``None``.
        """
        vocab_logits = self.bert(input_ids=input_ids, attention_mask=attention_mask).logits
        vocab_size = vocab_logits.size(-1)
        # Repeat each position across the vocabulary axis so gather() pulls
        # the full logit vector at that position.
        gather_index = mask_index.unsqueeze(-1).expand(-1, -1, vocab_size)
        at_mask = torch.gather(vocab_logits, 1, gather_index).squeeze(1)
        if not return_logits:
            return None
        # Keep only the logits of the three answer token ids.
        return at_mask[:, [336, 732, 462]]

# Inference function
def inference(model, dataloader, device):
    """Run the model over every batch and collect softmax probabilities.

    Args:
        model: Callable module accepting ``input_ids``, ``attention_mask``,
            ``mask_index`` and ``return_logits`` keyword arguments.
        dataloader: Iterable of dict batches with 'input_ids',
            'attention_mask' and 'mask_index' tensors.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        A list with one entry per input row, each a list of probabilities
        obtained by a softmax over the last dimension of the model's logits.
    """
    model.eval()
    to_probs = nn.Softmax(dim=-1)
    collected = []

    # Gradients are not needed for prediction.
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Inference", leave=False):
            tensors = {
                key: batch[key].to(device)
                for key in ('input_ids', 'attention_mask', 'mask_index')
            }
            logits = model(**tensors, return_logits=True)
            collected.extend(to_probs(logits).cpu().numpy().tolist())

    return collected

# Select the device: CUDA when available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model and its weights
model_path = '/kaggle/input/deberta-small/deberta'  # pretrained model path
checkpoint_path = '/kaggle/input/deberta-v3/model_checkpoint_epoch_2.pt'  # best-weights checkpoint path
model = BertPET(model_path)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.to(device)

# Build the test data loader (batch size 2, no shuffling in 'test' mode)
test_loader = get_semantic_data_loader(2, mode='test')

# Run inference; one probability triple per row of `test`
test_probs = inference(model, test_loader, device)


sample_sub = pd.read_csv('/kaggle/input/lmsys-chatbot-arena/sample_submission.csv')

# Attach the three probabilities to `test`, in the order the model emits them.
test[['winner_model_a', 'winner_model_b', 'winner_tie']] = test_probs
# Left merge keeps every id of the sample submission; ids that are missing
# from `test` end up with NaN probabilities.
sample_sub = sample_sub[['id']].merge(test[['id', 'winner_model_a', 'winner_model_b', 'winner_tie']], on='id', how='left')
sample_sub.to_csv('submission.csv', index=False)
# NOTE(review): `display` is neither defined nor imported in this file —
# presumably the IPython/Jupyter built-in; confirm before running as a script.
display(sample_sub.head())