In [None]:
import numpy as np
import pandas as pd
import os
import warnings
from math import floor, ceil
import re
import random
import torch 
from torch import nn
import torch.optim as optim
from sklearn.model_selection import GroupKFold
from scipy.stats import spearmanr

from transformers import BertTokenizer, BertModel, BertConfig, XLNetTokenizer, XLNetModel, XLNetConfig

warnings.filterwarnings('ignore')

In [None]:
def seed_all(seed_value):
    """Seed every RNG this notebook touches so that runs are reproducible.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and torch
    (CPU, plus every CUDA device when one is available), exports
    ``PYTHONHASHSEED`` and configures cuDNN for deterministic kernels.

    Args:
        seed_value: integer seed applied to every generator.
    """
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    # Only affects Python interpreters started after this call; the hash
    # seed of the current process is already fixed.
    os.environ['PYTHONHASHSEED'] = str(seed_value)

    if torch.cuda.is_available():
        # Seeds all visible GPUs (covers the current device as well).
        torch.cuda.manual_seed_all(seed_value)

    # benchmark=True lets cuDNN auto-tune and select kernels that may be
    # non-deterministic, defeating deterministic=True; set them together.
    # These flags are harmless on CPU-only machines.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
SEED = 42  # single seed shared by python, numpy and torch
seed_all(seed_value=SEED)

In [None]:
sub = pd.read_csv('/kaggle/input/google-quest-challenge/sample_submission.csv')
# Every column after the leading id column is a prediction target.
TARGET_COLS = sub.columns[1:].tolist()
for target_name in TARGET_COLS:
    print(target_name)

In [None]:
def decontract(text):
    """Expand common English contractions in ``text``.

    Rules are applied in order, so specific forms ("won't", "can't", "y'all",
    "I'm") are rewritten before the generic suffix rules ("n't", "'re", "'d",
    ...) can see them. Both the ASCII apostrophe and the typographic right
    quote are recognised. Every pattern requires a trailing space, so a
    contraction at the very end of the text or directly before punctuation
    is left unchanged.
    """
    ordered_rules = (
        (r"(W|w)on(\'|\’)t ", "will not "),
        (r"(C|c)an(\'|\’)t ", "can not "),
        (r"(Y|y)(\'|\’)all ", "you all "),
        (r"(Y|y)a(\'|\’)ll ", "you all "),
        (r"(I|i)(\'|\’)m ", "i am "),
        # NOTE(review): this only matches the non-word "aisn't"; it was
        # probably meant to be "ain't". Left unchanged so inference-time
        # preprocessing matches what the models were trained with.
        (r"(A|a)isn(\'|\’)t ", "is not "),
        (r"n(\'|\’)t ", " not "),
        (r"(\'|\’)re ", " are "),
        (r"(\'|\’)d ", " would "),
        (r"(\'|\’)ll ", " will "),
        (r"(\'|\’)t ", " not "),
        (r"(\'|\’)ve ", " have "),
    )
    expanded = text
    for pattern, replacement in ordered_rules:
        expanded = re.sub(pattern, replacement, expanded)
    return expanded

In [None]:
MAX_LEN = 512


class QuestDatasetBert(torch.utils.data.Dataset):
    """Google QUEST rows encoded for the (uncased) BERT checkpoint.

    Each item is ``[CLS] title [SEP] question [SEP] answer [SEP]`` right-padded
    to ``MAX_LEN`` ids, together with segment ids, an attention mask and, when
    ``labeled``, the float targets in ``TARGET_COLS`` order.

    Args:
        df: frame with ``question_title``, ``question_body`` and ``answer``
            columns (plus the ``TARGET_COLS`` columns when ``labeled``).
        train_mode: if True, over-long fields lose a random contiguous span
            (augmentation); otherwise their head and tail are kept.
        labeled: if True, ``__getitem__`` also returns the label tensor.
    """

    # Padding id and [SEP] id as used by get_token_ids / get_seg_ids; these
    # are presumably the [PAD]/[SEP] ids of the bundled vocab -- TODO confirm
    # if the vocabulary file changes.
    PAD_ID = 0
    SEP_ID = 102

    def __init__(self, df, train_mode=True, labeled=True):
        self.df = df
        self.train_mode = train_mode
        self.labeled = labeled
        self.tokenizer = BertTokenizer.from_pretrained('../input/btp-uncased-base/')

    def __getitem__(self, index):
        """Return ``(token_ids, seg_ids, masks[, labels])`` for row ``index``."""
        row = self.df.iloc[index]
        token_ids, seg_ids, masks = self.get_token_ids(row)
        if self.labeled:
            labels = self.get_label(row)
            return token_ids, seg_ids, masks, labels
        return token_ids, seg_ids, masks

    def __len__(self):
        return len(self.df)

    def select_tokens(self, tokens, max_num):
        """Shorten ``tokens`` to at most ``max_num`` items.

        Train mode drops one random contiguous span; eval mode keeps the first
        ``max_num // 2`` and the last ``max_num - max_num // 2`` tokens.
        """
        if len(tokens) <= max_num:
            return tokens
        if max_num <= 0:
            # Without this guard eval mode returns the *whole* list
            # (tokens[-0:] == tokens) and train mode raises (randint(0, -1)).
            return []
        if self.train_mode:
            num_remove = len(tokens) - max_num
            remove_start = random.randint(0, len(tokens) - num_remove - 1)
            return tokens[:remove_start] + tokens[remove_start + num_remove:]
        head = max_num // 2
        return tokens[:head] + tokens[-(max_num - head):]

    def preprocessing_text(self, text):
        """Strip tabs/newlines and expand contractions."""
        # NOTE(review): line breaks are removed without inserting a space, so
        # "foo\nbar" becomes "foobar". Kept as-is for parity with training.
        text = re.sub('\t|\n|\r', '', text)
        text = decontract(text)
        return text

    def trim_input(self, title, question, answer, max_sequence_length=MAX_LEN,
                   t_max_len=30, q_max_len=239, a_max_len=239):
        """Tokenize the three fields and trim them to fit ``max_sequence_length``.

        The default budgets (30 + 239 + 239 plus 4 special tokens) sum to 512.
        When trimming is needed, budget a field does not use is handed on:
        unused title budget is split between answer and question, then unused
        answer budget goes to the question (or unused question budget to the
        answer).

        Returns:
            ``(title_tokens, question_tokens, answer_tokens)``.

        Raises:
            ValueError: if trimming is needed and the redistributed budgets do
                not add up to ``max_sequence_length`` (non-default budgets).
        """
        t = self.tokenizer.tokenize(self.preprocessing_text(title))
        q = self.tokenizer.tokenize(self.preprocessing_text(question))
        a = self.tokenizer.tokenize(self.preprocessing_text(answer))

        t_len = len(t)
        q_len = len(q)
        a_len = len(a)

        # +4 accounts for [CLS] and the three [SEP] tokens.
        if (t_len + q_len + a_len + 4) > max_sequence_length:

            if t_max_len > t_len:
                t_new_len = t_len
                a_max_len = a_max_len + floor((t_max_len - t_len) / 2)
                q_max_len = q_max_len + ceil((t_max_len - t_len) / 2)
            else:
                t_new_len = t_max_len

            if a_max_len > a_len:
                a_new_len = a_len
                q_new_len = q_max_len + (a_max_len - a_len)
            elif q_max_len > q_len:
                a_new_len = a_max_len + (q_max_len - q_len)
                q_new_len = q_len
            else:
                a_new_len = a_max_len
                q_new_len = q_max_len

            if t_new_len + a_new_len + q_new_len + 4 != max_sequence_length:
                raise ValueError("New sequence length should be %d, but is %d"
                                 % (max_sequence_length, (t_new_len + a_new_len + q_new_len + 4)))

            t = self.select_tokens(t, t_new_len)
            q = self.select_tokens(q, q_new_len)
            a = self.select_tokens(a, a_new_len)

        return t, q, a

    def get_token_ids(self, row):
        """Encode one row into ``(ids, seg_ids, masks)`` tensors of length ``MAX_LEN``."""
        t_tokens, q_tokens, a_tokens = self.trim_input(row.question_title, row.question_body, row.answer)

        tokens = ['[CLS]'] + t_tokens + ['[SEP]'] + q_tokens + ['[SEP]'] + a_tokens + ['[SEP]']
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        if len(token_ids) < MAX_LEN:
            token_ids += [self.PAD_ID] * (MAX_LEN - len(token_ids))
        ids = torch.tensor(token_ids)
        seg_ids = self.get_seg_ids(ids)
        masks = self.get_masks(ids)
        return ids, seg_ids, masks

    def get_seg_ids(self, ids):
        """Segment 0 up to and including the second [SEP], segment 1 afterwards.

        Title and question therefore share segment 0; the answer, the final
        [SEP] and any padding get segment 1. ``ids`` is a 1-D tensor.
        """
        is_sep = (ids == self.SEP_ID).long()
        # Number of separators strictly before each position.
        seps_before = torch.cumsum(is_sep, dim=0) - is_sep
        return (seps_before >= 2).to(ids.dtype)

    def get_masks(self, ids):
        """1 for real tokens, 0 for padding (same dtype as ``ids``' int64)."""
        return (ids != self.PAD_ID).long()

    def get_label(self, row):
        """Targets of ``row`` in ``TARGET_COLS`` order as a float32 tensor."""
        return torch.tensor(row[TARGET_COLS].values.astype(np.float32))
    
class QuestDatasetXLNet(torch.utils.data.Dataset):
    """Google QUEST rows encoded for the (cased) XLNet checkpoint.

    Each item is ``title <sep> question <sep> answer <sep> <cls>`` right-padded
    to ``MAX_LEN`` ids, together with segment ids, an attention mask, the
    position of ``<cls>`` (before padding) and, when ``labeled``, the float
    targets in ``TARGET_COLS`` order.

    Args:
        df: frame with ``question_title``, ``question_body`` and ``answer``
            columns (plus the ``TARGET_COLS`` columns when ``labeled``).
        train_mode: if True, over-long fields lose a random contiguous span
            (augmentation); otherwise their head and tail are kept.
        labeled: if True, ``__getitem__`` also returns the label tensor.
    """

    # Padding id and separator id as used by get_token_ids / get_seg_ids;
    # presumably <pad>/<sep> in the XLNet SentencePiece vocab -- TODO confirm.
    PAD_ID = 5
    SEP_ID = 4

    def __init__(self, df, train_mode=True, labeled=True):
        self.df = df
        self.train_mode = train_mode
        self.labeled = labeled
        self.tokenizer = XLNetTokenizer.from_pretrained('../input/xlnet-base-pytorch/xlnet-base-cased-spiece.model')

    def __getitem__(self, index):
        """Return ``(token_ids, seg_ids, masks, cls_index[, labels])`` for row ``index``."""
        row = self.df.iloc[index]
        token_ids, seg_ids, masks, cls_index = self.get_token_ids(row)
        if self.labeled:
            labels = self.get_label(row)
            return token_ids, seg_ids, masks, cls_index, labels
        return token_ids, seg_ids, masks, cls_index

    def __len__(self):
        return len(self.df)

    def select_tokens(self, tokens, max_num):
        """Shorten ``tokens`` to at most ``max_num`` items.

        Train mode drops one random contiguous span; eval mode keeps the first
        ``max_num // 2`` and the last ``max_num - max_num // 2`` tokens.
        """
        if len(tokens) <= max_num:
            return tokens
        if max_num <= 0:
            # Without this guard eval mode returns the *whole* list
            # (tokens[-0:] == tokens) and train mode raises (randint(0, -1)).
            return []
        if self.train_mode:
            num_remove = len(tokens) - max_num
            remove_start = random.randint(0, len(tokens) - num_remove - 1)
            return tokens[:remove_start] + tokens[remove_start + num_remove:]
        head = max_num // 2
        return tokens[:head] + tokens[-(max_num - head):]

    def preprocessing_text(self, text):
        """Strip tabs/newlines and expand contractions."""
        # NOTE(review): line breaks are removed without inserting a space, so
        # "foo\nbar" becomes "foobar". Kept as-is for parity with training.
        text = re.sub('\t|\n|\r', '', text)
        text = decontract(text)
        return text

    def trim_input(self, title, question, answer, max_sequence_length=MAX_LEN,
                   t_max_len=30, q_max_len=239, a_max_len=239):
        """Tokenize the three fields and trim them to fit ``max_sequence_length``.

        The default budgets (30 + 239 + 239 plus 4 special tokens) sum to 512.
        When trimming is needed, budget a field does not use is handed on:
        unused title budget is split between answer and question, then unused
        answer budget goes to the question (or unused question budget to the
        answer).

        Returns:
            ``(title_tokens, question_tokens, answer_tokens)``.

        Raises:
            ValueError: if trimming is needed and the redistributed budgets do
                not add up to ``max_sequence_length`` (non-default budgets).
        """
        t = self.tokenizer.tokenize(self.preprocessing_text(title))
        q = self.tokenizer.tokenize(self.preprocessing_text(question))
        a = self.tokenizer.tokenize(self.preprocessing_text(answer))

        t_len = len(t)
        q_len = len(q)
        a_len = len(a)

        # +4 accounts for the three <sep> tokens and the trailing <cls>.
        if (t_len + q_len + a_len + 4) > max_sequence_length:

            if t_max_len > t_len:
                t_new_len = t_len
                a_max_len = a_max_len + floor((t_max_len - t_len) / 2)
                q_max_len = q_max_len + ceil((t_max_len - t_len) / 2)
            else:
                t_new_len = t_max_len

            if a_max_len > a_len:
                a_new_len = a_len
                q_new_len = q_max_len + (a_max_len - a_len)
            elif q_max_len > q_len:
                a_new_len = a_max_len + (q_max_len - q_len)
                q_new_len = q_len
            else:
                a_new_len = a_max_len
                q_new_len = q_max_len

            if t_new_len + a_new_len + q_new_len + 4 != max_sequence_length:
                raise ValueError("New sequence length should be %d, but is %d"
                                 % (max_sequence_length, (t_new_len + a_new_len + q_new_len + 4)))

            t = self.select_tokens(t, t_new_len)
            q = self.select_tokens(q, q_new_len)
            a = self.select_tokens(a, a_new_len)

        return t, q, a

    def get_token_ids(self, row):
        """Encode one row into ``(ids, seg_ids, masks, cls_index)``.

        ``cls_index`` is the position of ``<cls>`` (the last real token), taken
        before padding is appended.
        """
        t_tokens, q_tokens, a_tokens = self.trim_input(row.question_title, row.question_body, row.answer)

        tokens = t_tokens + ['<sep>'] + q_tokens + ['<sep>'] + a_tokens + ['<sep>'] + ['<cls>']
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        cls_index = torch.tensor(len(token_ids) - 1)
        if len(token_ids) < MAX_LEN:
            token_ids += [self.PAD_ID] * (MAX_LEN - len(token_ids))
        ids = torch.tensor(token_ids)
        seg_ids = self.get_seg_ids(ids)
        masks = self.get_masks(ids)
        return ids, seg_ids, masks, cls_index

    def get_seg_ids(self, ids):
        """Segment 0 up to and including the second separator, segment 1 afterwards.

        Title and question therefore share segment 0; the answer, the final
        separator, ``<cls>`` and any padding get segment 1. ``ids`` is 1-D.
        """
        is_sep = (ids == self.SEP_ID).long()
        # Number of separators strictly before each position.
        seps_before = torch.cumsum(is_sep, dim=0) - is_sep
        return (seps_before >= 2).to(ids.dtype)

    def get_masks(self, ids):
        """1 for real tokens, 0 for padding (int64)."""
        return (ids != self.PAD_ID).long()

    def get_label(self, row):
        """Targets of ``row`` in ``TARGET_COLS`` order as a float32 tensor."""
        return torch.tensor(row[TARGET_COLS].values.astype(np.float32))

def get_test_loader(batch_size=32, model='bert', df=None, num_workers=2):
    """Build an ordered, unlabeled DataLoader over the competition test set.

    Args:
        batch_size: rows per batch.
        model: ``'bert'`` or ``'xlnet'``; selects the dataset / tokenizer.
        df: optional pre-loaded test frame; ``test.csv`` is read when None.
        num_workers: number of DataLoader worker processes.

    Returns:
        A DataLoader yielding batches in ``df`` row order (no shuffling, no
        dropped tail), so every test row is predicted exactly once.

    Raises:
        ValueError: if ``model`` is not ``'bert'`` or ``'xlnet'``.
    """
    # Validate first: previously any unrecognised name silently meant XLNet.
    if model not in ('bert', 'xlnet'):
        raise ValueError(f"model must be 'bert' or 'xlnet', got {model!r}")
    if df is None:
        df = pd.read_csv('../input/google-quest-challenge/test.csv')

    dataset_cls = QuestDatasetBert if model == 'bert' else QuestDatasetXLNet
    ds_test = dataset_cls(df, train_mode=False, labeled=False)

    loader = torch.utils.data.DataLoader(ds_test, batch_size=batch_size, shuffle=False,
                                         num_workers=num_workers, drop_last=False)
    return loader

In [None]:
class BertForQA(nn.Module):
    """BERT encoder plus a linear head on the mean [CLS] state of the last four layers.

    Args:
        net_bert: encoder called as ``net_bert(input_ids, attention_mask=...,
            token_type_ids=...)`` returning a tuple whose third item is the
            sequence of per-layer hidden states (so the config must enable
            ``output_hidden_states``).
        n_classes: number of outputs (raw logits; sigmoid is applied later).
        hidden_size: width of the encoder hidden states (768 for the base model).
    """

    def __init__(self, net_bert, n_classes, hidden_size=768):
        super(BertForQA, self).__init__()

        self.bert = net_bert
        self.dropout = nn.Dropout(0.5)
        self.cls = nn.Linear(in_features=hidden_size, out_features=n_classes)

        nn.init.normal_(self.cls.weight, std=0.02)
        # Zero bias: nn.init.normal_(bias, 0) treats 0 as the *mean* and draws
        # the bias from N(0, 1), which is not a sensible head initialisation.
        nn.init.zeros_(self.cls.bias)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """Return logits of shape ``(batch, n_classes)``."""
        # Keyword arguments so the call does not depend on the encoder's
        # positional parameter order.
        _, _, hidden_states = self.bert(input_ids, attention_mask=attention_mask,
                                        token_type_ids=token_type_ids)

        # [CLS] vector (position 0) of each of the last four layers -> (batch, 4, hidden)
        cls_states = torch.stack([hidden_states[-i][:, 0] for i in range(1, 5)], dim=1)
        pooled_output = self.dropout(cls_states.mean(dim=1))
        return self.cls(pooled_output)

In [None]:
class XLNetForQA(nn.Module):
    """XLNet encoder plus a linear head on the mean ``<cls>`` state of the last four layers.

    Args:
        xlnet: encoder called with ``input_ids`` / ``attention_mask`` /
            ``token_type_ids`` keywords, returning a 2-tuple whose second item
            is the sequence of per-layer hidden states.
        n_classes: number of outputs (raw logits; sigmoid is applied later).
        hidden_size: width of the encoder hidden states (768 for the base model).
    """

    def __init__(self, xlnet, n_classes, hidden_size=768):
        super(XLNetForQA, self).__init__()

        self.xlnet = xlnet
        self.dropout = nn.Dropout(0.5)
        self.cls = nn.Linear(in_features=hidden_size, out_features=n_classes)

        nn.init.normal_(self.cls.weight, std=0.02)
        # Zero bias: nn.init.normal_(bias, 0) treats 0 as the *mean* and draws
        # the bias from N(0, 1), which is not a sensible head initialisation.
        nn.init.zeros_(self.cls.bias)

    def forward(self, input_ids, token_type_ids, attention_mask, cls_index):
        """Return logits of shape ``(batch, n_classes)``.

        ``cls_index`` holds, per sequence, the integer position of ``<cls>``.
        """
        _, hidden_states = self.xlnet(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids)

        width = hidden_states[-1].size(-1)
        # (batch,) -> (batch, 1, hidden) index selecting the <cls> row of each sequence.
        gather_idx = cls_index[:, None, None].expand(-1, -1, width)
        cls_states = torch.cat([hidden_states[-i].gather(-2, gather_idx) for i in range(1, 5)], dim=1)
        pooled_output = self.dropout(cls_states.mean(dim=1))
        return self.cls(pooled_output)

In [None]:
def build_model(model='bert', n_classes=30):
    """Instantiate the pretrained backbone plus a fresh regression head.

    Args:
        model: ``'bert'`` or ``'xlnet'``.
        n_classes: number of output targets.

    Returns:
        A ``BertForQA`` or ``XLNetForQA`` module (callers load fine-tuned
        weights on top with ``load_state_dict``).

    Raises:
        ValueError: if ``model`` is not ``'bert'`` or ``'xlnet'``.
    """
    # Previously any unrecognised name silently built XLNet.
    if model not in ('bert', 'xlnet'):
        raise ValueError(f"model must be 'bert' or 'xlnet', got {model!r}")

    if model == 'bert':
        config = BertConfig.from_pretrained('../input/btp-uncased-base/bert_config.json', output_hidden_states=True)
        net_bert = BertModel.from_pretrained('../input/btp-uncased-base/pytorch_model.bin', config=config)
        net = BertForQA(net_bert, n_classes)
    else:
        config = XLNetConfig.from_pretrained('../input/xlnet-base-pytorch/xlnet-base-cased-config.json', output_hidden_states=True)
        xlnet = XLNetModel.from_pretrained('../input/xlnet-base-pytorch/xlnet-base-cased-pytorch_model.bin', config=config)
        net = XLNetForQA(xlnet, n_classes)

    return net

In [None]:
# Target columns whose predictions `postprocess` snaps onto discrete label
# values; columns not listed here are left as raw sigmoid scores.
improved_cols = [
    'question_conversational', 
    'question_has_commonly_accepted_answer',
    'question_interestingness_self',
    'question_not_really_a_question',
    'question_type_choice',
    'question_type_compare',
    'question_type_consequence',
    'question_type_definition',
    'question_type_entity',
    'answer_plausible',
    'answer_relevance',
]
# Per-target cut points for `pd.cut`, aligned by POSITION with TARGET_COLS
# (`postprocess` indexes this list with the enumerate() index of TARGET_COLS;
# the comment above each entry names its column). With -inf/+inf added as the
# outer edges, a column with k distinct training label values needs exactly
# k - 1 cut points, because pd.cut requires len(labels) == len(bins) - 1.
# NOTE(review): the thresholds look tuned offline (e.g. on out-of-fold
# predictions); their provenance is not recorded here -- TODO document.
# Entries for columns absent from improved_cols (question_fact_seeking,
# question_type_instructions, answer_satisfaction) are currently unused.
bestcutting = [
    # question_asker_intent_understanding
    [],
    # question_body_critical
    [], 
    # question_conversational
    [0.07143090909002797, 0.12004401497150083, 0.1848436328636153, 0.393906661243123], 
    # question_expect_short_answer
    [], 
    # question_fact_seeking
    [0.5079822811938757, 0.6784007255713459, 0.8821817039873604, 0.9433617221024944], 
    # question_has_commonly_accepted_answer
    [0.3371636921591598, 0.5581504204597405, 0.7077497707969119, 0.7828705361913083], 
    # question_interestingness_others
    [], 
    # question_interestingness_self
    [0.4956358187787945, 0.5224040768530395, 0.5255965284962879, 0.5793678554651578, 0.609287211890339, 0.616010882506188, 0.6196306504341735, 0.6676136621794896], 
    # question_multi_intent
    [], 
    # question_not_really_a_question
    [0.006889458941723597, 0.3260944763731216, 0.41788196092069185, 0.8651878898710491], 
    # question_opinion_seeking
    [], 
    # question_type_choice
    [0.11795989772505418, 0.24754185187571828, 0.4745327419425606, 0.7234122358154191], 
    # question_type_compare
    [0.09410719980735521, 0.15601674060996373, 0.282369801654087, 0.5313848085789701], 
    # question_type_consequence
    [0.0004399460062622443, 0.03194983660092762, 0.07003471591529131, 0.1503817577814426], 
    # question_type_definition
    [0.14375842154645047, 0.2616577902050483, 0.2625648734101294, 0.396704338334258], 
    # question_type_entity
    [0.09127758816465084, 0.17097449682745675, 0.4003532557591638, 0.9950764536726562], 
    # question_type_instructions
    [0.16265486023189857, 0.32208644678446063, 0.5417455427354888, 0.7751835127308364], 
    # question_type_procedure
    [], 
    # question_type_reason_explanation
    [], 
    # question_type_spelling
    [],
    # question_well_written
    [], 
    # answer_helpful
    [], 
    # answer_level_of_information
    [], 
    # answer_plausible
    [0.05277303723757072, 0.08296633934736848, 0.2892567828157349, 0.8733971791156417, 0.9508973759039053, 0.9536983669289478, 0.9693328391611057, 0.9815554213789836], 
    # answer_relevance
    [0.675877531406042, 0.7222527113500082, 0.8276284368386688, 0.9215102556176764, 0.939746161126185, 0.9528639928103086, 0.9671258882221032, 0.978652006112756], 
    # answer_satisfaction
    [0.06008624214559094, 0.08468818485434321, 0.17139508646282006, 0.42753923139658206, 0.47748050195355696, 0.5020362112065385, 0.6191407747833618, 0.6221662654090617, 0.7453504487470863, 0.8057253638493509, 0.8163920374283427, 0.8472975110653058, 0.8670318719402241, 0.8748592375242844, 0.9165127791415058, 0.9982222527210189], 
    # answer_type_instructions
    [], 
    # answer_type_procedure
    [], 
    # answer_type_reason_explanation
    [], 
    # answer_well_written
    []
]

In [None]:
def postprocess(pred, train_df=None, target_cols=None, cols=None, cuts=None):
    """Snap selected prediction columns onto the label values seen in training.

    For each target listed in ``cols``, predictions are binned with
    ``pd.cut`` on the tuned cut points and every bin is replaced by the
    matching sorted unique training label. If binning would leave a column
    with a single value, its Spearman correlation would be undefined, so the
    raw predictions for that column are kept instead.

    Args:
        pred: float array whose column ``i`` holds predictions for
            ``target_cols[i]``; modified in place and returned.
        train_df: training frame providing label values; ``train.csv`` is read
            when None.
        target_cols: column order of ``pred``; defaults to ``TARGET_COLS``.
        cols: targets to post-process; defaults to ``improved_cols``.
        cuts: cut-point lists aligned with ``target_cols``; defaults to
            ``bestcutting``.

    Returns:
        ``pred`` with the selected columns discretised.
    """
    if train_df is None:
        train_df = pd.read_csv('../input/google-quest-challenge/train.csv')
    if target_cols is None:
        target_cols = TARGET_COLS
    if cols is None:
        cols = improved_cols
    if cuts is None:
        cuts = bestcutting

    for i, col in enumerate(target_cols):
        if col not in cols:
            continue
        labels = np.sort(train_df[col].unique())
        bins = [-np.inf] + list(cuts[i]) + [np.inf]
        binned = np.asarray(pd.cut(pred[:, i], bins, labels=labels), dtype=float)
        # A single surviving value carries no ranking information; keep raw scores.
        if pd.Series(binned).nunique() > 1:
            pred[:, i] = binned
    return pred

In [None]:
def predict_bert(net_trained, test_dl, device=None):
    """Run a BERT-style model over ``test_dl`` and return sigmoid probabilities.

    Args:
        net_trained: model called as ``net(token_ids, token_type_ids=...,
            attention_mask=...)`` returning logits.
        test_dl: iterable of batches whose first three items are token ids,
            segment ids and attention masks.
        device: torch device to run on; defaults to CUDA when available and
            CPU otherwise.

    Returns:
        NumPy array with one row per example, in loader order.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net_trained.eval()
    net_trained.to(device)

    preds = []
    with torch.no_grad():
        for batch in test_dl:
            token_ids, seg_ids, masks = (t.to(device) for t in batch[:3])
            outputs = net_trained(token_ids, token_type_ids=seg_ids, attention_mask=masks)
            preds.append(torch.sigmoid(outputs).cpu())

    return torch.cat(preds, 0).numpy()

def predict_xlnet(net_trained, test_dl, device=None):
    """Run an XLNet-style model over ``test_dl`` and return sigmoid probabilities.

    Args:
        net_trained: model called as ``net(token_ids, token_type_ids=...,
            attention_mask=..., cls_index=...)`` returning logits.
        test_dl: iterable of batches whose first four items are token ids,
            segment ids, attention masks and ``<cls>`` positions.
        device: torch device to run on; defaults to CUDA when available and
            CPU otherwise.

    Returns:
        NumPy array with one row per example, in loader order.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net_trained.eval()
    net_trained.to(device)

    preds = []
    with torch.no_grad():
        for batch in test_dl:
            token_ids, seg_ids, masks, cls_idx = (t.to(device) for t in batch[:4])
            outputs = net_trained(token_ids, token_type_ids=seg_ids, attention_mask=masks, cls_index=cls_idx)
            preds.append(torch.sigmoid(outputs).cpu())

    return torch.cat(preds, 0).numpy()

In [None]:
predictions = list()  # one (n_test_rows, n_targets) array per model fold

In [None]:
%%time
# Re-initialised here so that re-running this cell does not double-count folds.
predictions = []
N_FOLDS = 10
ENSEMBLE = [
    ('bert', '../input/google-quest-bert-base-train/bert_fold{}.pth', predict_bert),
    ('xlnet', '../input/google-quest-xlnet-train/xlnet_fold{}.pth', predict_xlnet),
]

for model_name, weights_template, predict_fn in ENSEMBLE:
    # The loader does not depend on the fold (eval-mode trimming is
    # deterministic), so build it once per architecture instead of per fold.
    test_dl = get_test_loader(model=model_name)
    for fold in range(1, N_FOLDS + 1):
        net = build_model(model=model_name)
        # Load onto CPU first; predict_fn moves the model to the target device.
        net.load_state_dict(torch.load(weights_template.format(fold), map_location='cpu'))
        predictions.append(predict_fn(net, test_dl))
        # Free the fold's weights before building the next model.
        del net
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

In [None]:
df = pd.read_csv('../input/google-quest-challenge/test.csv')
# Rows from the English-language Q&A sites get a fixed spelling score of 0.5.
n = df['url'].apply(
    lambda url: any(site in url for site in ('ell.stackexchange.com', 'english.stackexchange.com'))
).tolist()
spelling = [0.5 if from_english_site else 0. for from_english_site in n]

In [None]:
# Model predictions (and `spelling`) follow test.csv row order, because
# get_test_loader reads that file and never shuffles. The template rows must
# be in the same order, otherwise predictions would silently be misassigned.
test_qa_ids = pd.read_csv('../input/google-quest-challenge/test.csv', usecols=['qa_id'])['qa_id']
assert np.array_equal(sub['qa_id'].to_numpy(), test_qa_ids.to_numpy()), \
    'sample_submission rows are not in test.csv order'

sub[TARGET_COLS] = postprocess(np.mean(predictions, axis=0))
sub['question_type_spelling'] = spelling
sub.to_csv('submission.csv', index=False)
sub.head()