In [None]:
import copy
import gc
import math
import random

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from tqdm import tqdm_notebook as tqdm
from transformers import AutoModel, AutoConfig, AutoTokenizer
from transformers.tokenization_utils import BatchEncoding

In [None]:
def merge_dicts(input):
    """Transpose a sequence of dicts into a dict of lists.

    The key set is taken from the first element; every element is expected
    to provide each of those keys. Values keep the order of ``input``.
    """
    records = input
    template = records[0]
    return {name: [record[name] for record in records] for name in template.keys()}

In [None]:
def make_test_dataset(train_df, test_df, anchor_num = 50):
    """Pair every test row with the same set of anchor rows from ``train_df``.

    Anchors are picked at a fixed stride from ``train_df`` sorted by its
    ``etarget`` column, so they are spread over the whole ``etarget`` range.

    Args:
        train_df: frame with ``excerpt``, ``target`` and ``etarget`` columns.
        test_df: frame with ``excerpt`` and ``id`` columns.
        anchor_num: number of anchors attached to each test row.

    Returns:
        A frame with ``anchor_num`` rows per test row (test-row major, anchors
        in ascending ``etarget`` order) and the columns ``anchor_text``,
        ``anchor_target``, ``anchor_etarget``, ``excerpt``, ``id``.
    """
    anchor_columns = {
        'excerpt': 'anchor_text',
        'target': 'anchor_target',
        'etarget': 'anchor_etarget',
    }
    interval = len(train_df)//anchor_num
    anchor_idx = [interval*x for x in range(anchor_num)]
    # Bug fix: index into the *sorted* frame. The previous code sorted and then
    # immediately re-sliced the unsorted ``train_df``, discarding the sort.
    sampled_train = train_df.sort_values(by='etarget').iloc[anchor_idx]
    anchors = sampled_train[list(anchor_columns)].rename(columns=anchor_columns)

    data = [
        anchors.assign(excerpt=row['excerpt'], id=row['id'])
        for _, row in test_df.iterrows()
    ]
    if not data:
        # pd.concat raises on an empty list; return an empty frame with the same schema.
        return pd.DataFrame(columns=[*anchor_columns.values(), 'excerpt', 'id'])
    return pd.concat(data, ignore_index=True, sort=False)

In [None]:
class MyValidationDataset(Dataset):
    """Dataset yielding one unique excerpt plus the shared anchor batch.

    ``df`` is expected to have the columns ``excerpt``, ``anchor_text``,
    ``anchor_target`` and ``anchor_etarget``. Excerpts and anchor texts are
    de-duplicated (first occurrence kept), so the dataset length is the number
    of distinct excerpts.

    NOTE(review): downstream code that lines predictions up with ``df`` rows
    relies on excerpts and anchor texts being unique per test row / anchor;
    duplicated texts would shrink the outputs - confirm against the caller.

    Args:
        df: pairing frame.
        tokenizer: object exposing ``encode_plus`` and ``batch_encode_plus``.
        exp: stored on the instance; not read by this class.
        max_length: token limit applied to both the excerpt and the anchors.
    """

    def __init__(self, df, tokenizer, exp=False, max_length=512) -> None:
        super().__init__()
        self.df = df
        self.texts = df['excerpt'].drop_duplicates().tolist()
        self.anchor_texts = df['anchor_text'].drop_duplicates().tolist()
        self.anchor_targets = df['anchor_target'].tolist()
        self.anchor_etargets = df['anchor_etarget'].tolist()
        self.tokenizer = tokenizer
        self.exp = exp
        self.max_length = max_length
        # Filled lazily on first access so each DataLoader worker builds its own copy.
        self._anchor_inputs = None

    def _encode_anchors(self):
        """Tokenize the anchor texts once and reuse the result for every item."""
        if self._anchor_inputs is None:
            self._anchor_inputs = self.tokenizer.batch_encode_plus(
                list(self.anchor_texts),
                max_length=self.max_length,
                return_tensors='pt',
                truncation=True,
                padding=True)
        return self._anchor_inputs

    def __getitem__(self, idx):
        """Return ``(excerpt_encoding, anchor_encoding, {})`` for excerpt ``idx``."""
        text = self.texts[idx]

        # Truncate the excerpt to the same limit as the anchors so an over-long
        # text cannot exceed the encoder's position range.
        inputs = self.tokenizer.encode_plus(text, return_tensors='pt',
                                            truncation=True,
                                            max_length=self.max_length)
        # Shallow copy: consumers may move/replace the container's tensors
        # without touching the cached encoding.
        anchor_inputs = copy.copy(self._encode_anchors())
        return inputs, anchor_inputs, {}

    def __len__(self):
        return len(self.texts)

In [None]:
def add_pooling_mask(encode_result):
    """Attach a ``pooling_mask`` that splits an encoded pair into two segments.

    The first row of ``encode_result['special_tokens_mask']`` is scanned from
    position 1 for the first entry equal to 1. Positions before it are marked
    1, that position and everything after it are marked 2.

    Fallbacks:
        * no special token after position 0: only the last position is marked 2;
        * fewer than two tokens: every position is marked 1 (the previous code
          raised ``UnboundLocalError`` here).

    Args:
        encode_result: mapping with a ``special_tokens_mask`` entry whose first
            row is indexable and has a length.

    Returns:
        The same mapping, with ``pooling_mask`` set to a ``(1, length)``
        ``torch.LongTensor``.
    """
    stm = encode_result['special_tokens_mask'][0]
    length = len(stm)
    if length < 2:
        sep_pos = length
    else:
        sep_pos = next(
            (pos for pos in range(1, length) if stm[pos] == 1),
            length - 1,
        )
    mask = torch.LongTensor([[1]*sep_pos+[2]*(length-sep_pos)])
    encode_result['pooling_mask'] = mask
    return encode_result

class OneBertValidationDataset(Dataset):
    """Dataset of (excerpt, anchor) pairs encoded together as one sequence.

    Each row of ``df`` is one item: the row's ``excerpt`` is the first segment
    and its ``anchor_text`` the second. Items are ``(encoding, {})`` where the
    encoding also carries the ``pooling_mask`` added by ``add_pooling_mask``.
    """

    def __init__(self, df, tokenizer, exp=False) -> None:
        super().__init__()
        self.df = df
        # Row-aligned columns, materialised once as plain lists.
        self.texts, self.anchor_texts, self.anchor_targets, self.anchor_etargets = (
            df[column].tolist()
            for column in ('excerpt', 'anchor_text', 'anchor_target', 'anchor_etarget')
        )
        self.tokenizer = tokenizer
        self.exp = exp

    def __getitem__(self, idx):
        # Only the anchor (second segment) is truncated to fit max_length.
        encode_options = dict(
            return_tensors='pt',
            return_special_tokens_mask=True,
            truncation='only_second',
            max_length=512,
            padding=True,
        )
        pair_encoding = self.tokenizer.encode_plus(
            self.texts[idx], self.anchor_texts[idx], **encode_options)
        return add_pooling_mask(pair_encoding), {}

    def __len__(self):
        return len(self.texts)

In [None]:
class MyValidationCollator:
    """Collate ``(inputs, anchor_inputs, labels)`` items into one batch.

    Per-item encodings are right-padded to the longest item in the batch. The
    anchor encoding of the first item is forwarded as the batch's anchors.
    """

    def __init__(self, token_pad_value=0, type_pad_value=1):
        super().__init__()
        self.token_pad_value = token_pad_value
        self.type_pad_value = type_pad_value

    def __call__(self, batch):
        inputs, anchor_inputs, labels = zip(*batch)

        def pad_field(field, fill):
            # Each encoding holds a (1, length) tensor; take row 0 before padding.
            rows = [sample[field][0] for sample in inputs]
            return pad_sequence(rows, batch_first=True, padding_value=fill)

        features = {
            'input_ids': pad_field('input_ids', self.token_pad_value),
            'attention_mask': pad_field('attention_mask', 0),
        }
        if 'token_type_ids' in inputs[0]:
            features['token_type_ids'] = pad_field('token_type_ids', self.type_pad_value)

        merged_labels = {
            name: torch.cat(values, dim=0)
            for name, values in merge_dicts(labels).items()
        }
        model_inputs = {'features': features, 'anchor_features': anchor_inputs[0]}
        return model_inputs, merged_labels


In [None]:
class OneBertCompareCollator:
    """Collate ``(inputs, labels)`` items of pair encodings into one batch.

    ``input_ids``, ``attention_mask`` and ``pooling_mask`` (and
    ``token_type_ids`` when present) are right-padded to the longest item.
    """

    def __init__(self, token_pad_value=0, type_pad_value=1):
        super().__init__()
        self.token_pad_value = token_pad_value
        self.type_pad_value = type_pad_value

    def __call__(self, batch):
        inputs, labels = zip(*batch)

        # Field name -> fill value used for the padded positions.
        fill_values = {
            'input_ids': self.token_pad_value,
            'attention_mask': 0,
            'pooling_mask': 0,
        }
        if 'token_type_ids' in inputs[0]:
            fill_values['token_type_ids'] = self.type_pad_value

        features = {}
        for field, fill in fill_values.items():
            rows = [encoding[field][0] for encoding in inputs]
            features[field] = pad_sequence(rows, batch_first=True, padding_value=fill)

        collated_labels = merge_dicts(labels)
        for name in collated_labels:
            collated_labels[name] = torch.cat(collated_labels[name], dim=0)

        return {'features': features}, collated_labels

In [None]:
class Encoder(nn.Module):
    """Transformer backbone plus a pooling step over its token embeddings.

    Args:
        pretrained_model_name: name/path given to ``from_pretrained``; only
            used when ``config`` is ``None``.
        config: if given, the backbone is built from it with
            ``AutoModel.from_config`` (no pretrained weights are loaded here).
        pooling: ``'cls'`` (second backbone output), ``'mean'`` (mask-weighted
            mean of token embeddings) or ``'att'`` (learned attention pooling).
        grad_checkpoint: enable gradient checkpointing on the backbone.
        **kwargs: optional ``hidden_dropout``, ``attention_dropout`` and
            ``layer_norm_eps`` overrides, applied to the loaded config when
            ``config`` is ``None``.
    """

    def __init__(self, pretrained_model_name, config=None, pooling='cls', grad_checkpoint=False, **kwargs):
        super().__init__()
        if config is not None:
            self.bert = AutoModel.from_config(config)
        else:
            config = AutoConfig.from_pretrained(pretrained_model_name)
            # Bug fixes: the previous code read kwargs['hidden_drop'] (KeyError),
            # tested the misspelled key 'attention_dropput' (override never
            # applied) and used truthiness checks that dropped a 0.0 override.
            if kwargs.get('hidden_dropout') is not None:
                config.hidden_dropout_prob = kwargs['hidden_dropout']
            if kwargs.get('attention_dropout') is not None:
                config.attention_probs_dropout_prob = kwargs['attention_dropout']
            if kwargs.get('layer_norm_eps') is not None:
                config.layer_norm_eps = kwargs['layer_norm_eps']
            self.bert = AutoModel.from_pretrained(pretrained_model_name, config=config)

        if grad_checkpoint:
            # Bug fix: the flag used to be set on ``self.bert.config`` before
            # ``self.bert`` existed, which raised AttributeError.
            if hasattr(self.bert, 'gradient_checkpointing_enable'):
                self.bert.gradient_checkpointing_enable()
            else:
                # Fallback for backbones without the enable() method.
                self.bert.config.gradient_checkpointing = True

        self.pooling = pooling

        self.hidden_size = self.bert.config.hidden_size
        if pooling!='cls':
            # The backbone's own pooler is not used by the other modes.
            self.bert.pooler = nn.Identity()
        self.attention = nn.Sequential(
            nn.Linear(self.hidden_size, 256),
            nn.GELU(),
            nn.Linear(256, 1)
        )

    def forward(self, features):
        """Encode ``features`` and pool the token embeddings.

        Args:
            features: mapping with ``input_ids`` and ``attention_mask`` and
                optionally ``token_type_ids``.

        Returns:
            ``(pooled_out, out)``: the pooled embedding and the per-token
            embeddings (first backbone output).

        Raises:
            ValueError: if ``self.pooling`` is not ``'cls'``, ``'mean'`` or ``'att'``.
        """
        output_states = self.bert(input_ids=features.get('input_ids'),
                                  attention_mask=features.get(
                                      'attention_mask'),
                                  token_type_ids=features.get('token_type_ids'))
        out = output_states[0]  # embedding for all tokens
        if self.pooling == 'cls':
            pooled_out = output_states[1]  # second backbone output
        elif self.pooling == 'mean':
            attention_mask = features['attention_mask']
            input_mask_expanded = attention_mask.unsqueeze(
                -1).expand(out.size()).float()
            sum_embeddings = torch.sum(out * input_mask_expanded, 1)
            # clamp avoids a division by zero for an all-padding row
            sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
            pooled_out = sum_embeddings / sum_mask
        elif self.pooling=='att':
            weights = self.attention(out)
            attention_mask = features['attention_mask'].unsqueeze(
                -1).expand(weights.size())
            # masked positions get -inf so softmax assigns them zero weight
            weights = weights.masked_fill(attention_mask==0, -float('inf'))
            weights = torch.softmax(weights, dim=1)
            pooled_out = torch.sum(out*weights, dim=1)
        else:
            raise ValueError(
                f"Unsupported pooling mode {self.pooling!r}; expected 'cls', 'mean' or 'att'")
        return pooled_out, out

In [None]:
class Comparer(nn.Module):
    """Scores pairs of texts from a shared encoder.

    Each pair is described by the difference of the two pooled embeddings
    concatenated with the difference of two cross-attended sequence summaries
    (see ``inference``); that vector is fed to a small MLP producing one value
    per pair.
    """

    def __init__(self, pretrained_model_name, config=None, pooling='mean', esim=True, grad_checkpoint=False):
        super().__init__()
        self.encoder = Encoder(pretrained_model_name, config=config,
                               pooling=pooling, grad_checkpoint=grad_checkpoint)
        # Stored only; nothing in this class reads `esim`.
        self.esim = esim
        # Input is 2*hidden_size wide: pooled difference + cross-attended difference.
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self.encoder.hidden_size*2, self.encoder.hidden_size),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(self.encoder.hidden_size, 1)
        )

        # Cache for the anchor embeddings, filled by the first forward() call
        # that receives anchor_features.
        self.anchor_pooled_emb, self.anchor_seq_emb = None, None

    def inference(self, seq1, seq2, mask1=None, mask2=None):
        """Cross-attend two batches of sequences and summarise each side.

        Returns ``(new_seq1, new_seq2)``: for each side, the other sequence
        attended from this side's positions, averaged over this side's
        unmasked positions.

        NOTE(review): ``mask1``/``mask2`` default to ``None`` but are used
        unconditionally below, so callers must always pass them.
        NOTE(review): the softmax is taken over all positions; masked positions
        of the attended sequence only have their *values* zeroed, the weights
        are not renormalised.
        """
        # seq1: B*l1*D
        # seq2: B*l2*D
        # mask1: B*l1
        # mask2: B*l2

        # Pairwise dot-product scores: B*l1*l2
        score = torch.bmm(seq1, seq2.permute(0, 2, 1))
        # Attend over seq2 positions (last dim), then mask-weighted mean over seq1 positions.
        new_seq1 = torch.bmm(torch.softmax(score, dim=-1), seq2*mask2.unsqueeze(-1)) #
        new_seq1 = torch.sum(new_seq1*mask1.unsqueeze(-1),dim=1)/torch.sum(mask1, dim=1).unsqueeze(-1)
        # del score1

        # Same in the other direction: attend over seq1 positions (dim=1), average over seq2 positions.
        new_seq2 = torch.bmm(torch.softmax(score, dim=1).permute(0, 2, 1), seq1*mask1.unsqueeze(-1)) #
        new_seq2 = torch.sum(new_seq2*mask2.unsqueeze(-1), dim=1)/torch.sum(mask2, dim=1).unsqueeze(-1)
        return new_seq1, new_seq2


    def forward(self, features, anchor_features=None):
        """Score every (text, partner) pair.

        With ``anchor_features=None`` each batch item is paired with every
        batch item (bs*bs pairs). Otherwise each batch item is paired with
        every anchor (bs*anchor_bs pairs, item-major).

        The anchor embeddings are computed on the first call that receives
        ``anchor_features`` and then reused from the module-level cache on
        every later call; only the attention mask is re-read from the
        ``anchor_features`` argument. Later calls are therefore expected to
        pass the same anchors.

        Returns:
            ``{'pred': tensor}`` with one row per pair and a last dimension of 1.
        """

        pooled_emb, seq_emb = self.encoder(features)

#         etarget_out = self.etarget_head(pooled_emb)

        bs, length, dim = seq_emb.size()
        if anchor_features is None:
            # embeddings: B*D
            # 111,222,333
            pooled_emb1 = pooled_emb.unsqueeze(
                1).expand(-1, bs, -1).reshape(-1, dim)
            # 123,123,123
            pooled_emb2 = pooled_emb.unsqueeze(
                0).expand(bs, -1, -1).reshape(-1, dim)

            # Same two tilings for the per-token embeddings ...
            seq_emb1 = seq_emb.unsqueeze(
                1).expand(-1, bs, -1, -1).reshape(-1, length, dim)
            seq_emb2 = seq_emb.unsqueeze(0).expand(
                bs, -1, -1, -1).reshape(-1, length, dim)

            # ... and for the attention masks.
            mask1 = features['attention_mask'].unsqueeze(
                1).expand(-1, bs, -1).reshape(-1, length)
            mask2 = features['attention_mask'].unsqueeze(
                0).expand(bs, -1, -1).reshape(-1, length)
            new_emb1, new_emb2 = self.inference(
                seq_emb1, seq_emb2, mask1, mask2)
        else:
            if self.anchor_pooled_emb is None:
                # First call with anchors: encode them and keep the result on the module.
                anchor_pooled_emb, anchor_seq_emb = self.encoder(anchor_features)
                self.anchor_pooled_emb, self.anchor_seq_emb = anchor_pooled_emb, anchor_seq_emb
            else:
                anchor_pooled_emb, anchor_seq_emb = self.anchor_pooled_emb, self.anchor_seq_emb
            anchor_bs, _ = anchor_pooled_emb.size()
            # Each text is repeated anchor_bs times in a row (111,222,...) ...
            pooled_emb = pooled_emb.unsqueeze(
                1).expand(-1, anchor_bs, -1).reshape(-1, dim)
            # ... while the whole anchor set is tiled bs times (123,123,...).
            anchor_pooled_emb = anchor_pooled_emb.repeat(bs, 1)
            pooled_emb1 = pooled_emb
            pooled_emb2 = anchor_pooled_emb

            seq_emb1 = seq_emb.unsqueeze(
                1).expand(-1, anchor_bs, -1, -1).reshape(-1, length, dim)
            seq_emb2 = anchor_seq_emb.repeat(bs, 1, 1)

            mask1 = features['attention_mask'].unsqueeze(
                1).expand(-1, anchor_bs, -1).reshape(-1, length)
            mask2 = anchor_features['attention_mask'].repeat(bs, 1)
            new_emb1, new_emb2 = self.inference(seq_emb1, seq_emb2, mask1, mask2)

        # Note the opposite sign conventions of the two halves: (1 - 2) then (2 - 1).
        fc_input = torch.cat(
            [pooled_emb1-pooled_emb2, new_emb2-new_emb1], dim=1)
        # fc_input = pooled_emb1-pooled_emb2+new_emb2-new_emb1], dim=1)
        output = self.fc(fc_input)
        ret = {'pred': output}
        return ret

def _prepare_inputs(inputs):
    """Move every tensor reachable from ``inputs`` to the default CUDA device.

    Values are replaced in place in ``inputs`` (and recursively in nested
    plain dicts); values of any other type are left untouched.

    Args:
        inputs: dict whose values may be tensors, ``BatchEncoding`` objects or
            nested dicts.

    Returns:
        The same ``inputs`` dict.
    """
    for k, v in inputs.items():
        if isinstance(v, torch.Tensor):
            inputs[k] = v.cuda()
        elif isinstance(v, BatchEncoding): # for embedding training
            # Bug fix: BatchEncoding offers ``.to(device)`` but no ``.cuda()``;
            # unknown attribute names on it raise AttributeError.
            inputs[k] = v.to('cuda')
        elif isinstance(v, dict): # for embedding training
            inputs[k] = _prepare_inputs(v)
    return inputs

In [None]:
class OneBertComparer(nn.Module):
    """Scores an (excerpt, anchor) pair encoded as a single sequence.

    The encoder always uses mean pooling internally. With ``sep_pooling`` the
    two segments marked by ``features['pooling_mask']`` (1 and 2) are pooled
    separately - by attention when ``pooling == 'att'``, by mean otherwise -
    and their difference is scored; without it the encoder's pooled embedding
    is scored directly.
    """

    def __init__(self, pretrained_model_name, config=None, pooling='cls', sep_pooling=False):
        super().__init__()
        self.encoder = Encoder(pretrained_model_name,
                               config=config, pooling='mean')
        self.sep_pooling = sep_pooling
        self.pooling = pooling

        width = self.encoder.hidden_size
        # Regression/classification head on the pooled representation.
        self.fc = nn.Sequential(
            nn.Dropout(0.25),
            nn.Linear(width, 256),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(256, 1)
        )
        # Per-token scorer used by attention pooling.
        self.attention = nn.Sequential(
            nn.Linear(width, 256),
            nn.GELU(),
            nn.Linear(256, 1)
        )

    def masked_mean_pooling(self, seq_emb, pooling_mask):
        """Mean of the token embeddings of segment 1 and of segment 2."""

        def segment_mean(segment_id):
            member = (pooling_mask == segment_id)
            summed = torch.sum(seq_emb*member.unsqueeze(-1), dim=1)
            count = torch.sum(member, dim=1, keepdim=True)
            return summed/count

        return segment_mean(1), segment_mean(2)

    def masked_att_pooling(self, seq_emb, pooling_mask):
        """Attention-pool the sequence twice, each time hiding one segment.

        The first result hides positions marked 2, the second hides positions
        marked 1; positions carrying any other mask value stay visible in both.
        """
        scores = self.attention(seq_emb)

        def pool_hiding(segment_id):
            hidden = (pooling_mask == segment_id).unsqueeze(-1).expand(scores.size())
            probs = torch.softmax(scores.masked_fill(hidden, -float('inf')), dim=1)
            return torch.sum(seq_emb*probs, dim=1)

        return pool_hiding(2), pool_hiding(1)

    def forward(self, features):
        """Return ``{'pred': tensor}`` with one score per input sequence."""
        pooled_emb, seq_emb = self.encoder(features)
        if not self.sep_pooling:
            return {'pred': self.fc(pooled_emb)}

        if self.pooling=='att':
            pool = self.masked_att_pooling
        else:
            pool = self.masked_mean_pooling
        emb1, emb2 = pool(seq_emb, features['pooling_mask'])
        # output = self.fc(torch.cat([pooled_emb, sep_pooled_emb], dim=1))
        return {'pred': self.fc(emb1-emb2)}

In [None]:
# Exploratory: list the contents of ../input (the directory all paths below are read from).
!ls ../input

In [None]:
# Training rows with a precomputed `fold` column; the cells below also read
# its `excerpt`, `target` and `etarget` columns to build anchors.
train = pd.read_csv('../input/clrp-compare-base/train_with_folds.csv')
# Rows to predict; the cells below read its `id` and `excerpt` columns.
test = pd.read_csv('../input/commonlitreadabilityprize/test.csv')

# A single tokenizer is shared by every model loaded below.
tokenizer = AutoTokenizer.from_pretrained('../input/clrp-roberta-reg/')

In [None]:
# Backbone configuration used to build every model below before its
# checkpoint weights are loaded with load_state_dict.
config = AutoConfig.from_pretrained('../input/clrp-roberta-reg')

In [None]:
# Comparer, regression output: the model predicts (target - anchor_target)
# for each of 50 anchors per test excerpt. Checkpoints: clrp-roberta-reg.
# This cell also (re)initialises `all_results`, which later cells append to.
all_results = []
for fold in range(5):
    print(f'fold {fold}')
    # Anchors come from the rows outside this fold.
    train_fold = train[train['fold']!=fold]
    new_test = make_test_dataset(train_fold, test, 50)
    valid_set = MyValidationDataset(new_test, tokenizer)
    # Token-type ids are padded with the tokenizer's token-type pad id
    # (previously pad_token_id was passed here by mistake).
    valid_collator = MyValidationCollator(token_pad_value=tokenizer.pad_token_id,
                          type_pad_value=tokenizer.pad_token_type_id)
    loader = torch.utils.data.DataLoader(
        valid_set,
        shuffle=False,
        batch_size=16,
        pin_memory=True,
        drop_last=False,
        num_workers=4,
        collate_fn=valid_collator
    )

    # map_location='cpu': load independently of the device the checkpoint was
    # saved from, and avoid holding a second copy of the weights on the GPU.
    state = torch.load(f'../input/clrp-roberta-reg/best-model-{fold}.pt', map_location='cpu')
    model = Comparer(None, config, pooling='att')
    model.load_state_dict(state['model'])
    del state  # the weights now live in `model`
    model.eval()
    model.cuda()

    results=[]
    with torch.no_grad():
        for inputs, _ in tqdm(loader):
            inputs = _prepare_inputs(inputs)
            with autocast(enabled=True):
                outputs = model(**inputs)
            results.append(outputs)
    predicts = merge_dicts(results)
    for key in predicts.keys():
        predicts[key] = torch.cat(predicts[key], dim=0)
    pred = predicts['pred'].cpu().numpy()

    # One prediction per (excerpt, anchor) row of new_test, in the same order.
    assert len(pred) == len(new_test), (
        f'got {len(pred)} predictions for {len(new_test)} (excerpt, anchor) rows; '
        'are excerpts and anchor texts unique?')

    df = pd.DataFrame()
    df['id']= new_test['id']
    df['pred'] = pred.flatten()
    df['anchor_target'] = new_test['anchor_target']
    # Predicted difference + anchor target = predicted target.
    df['pred_target'] = df['pred']+df['anchor_target']
    all_results.append(df[['id','pred_target']])

    del model
    gc.collect()

In [None]:
# Comparer, binary output: sigmoid(pred) is converted back to an e-target
# relative to each of 50 anchors. Checkpoints: clrp-roberta-v2.
# Appends to `all_results`, so the cell that initialises it must run first.
for fold in range(5):
    print(f'fold {fold}')
    # Anchors come from the rows outside this fold.
    train_fold = train[train['fold']!=fold]
    new_test = make_test_dataset(train_fold, test, 50)
    valid_set = MyValidationDataset(new_test, tokenizer)
    # Token-type ids are padded with the tokenizer's token-type pad id
    # (previously pad_token_id was passed here by mistake).
    valid_collator = MyValidationCollator(token_pad_value=tokenizer.pad_token_id,
                          type_pad_value=tokenizer.pad_token_type_id)
    loader = torch.utils.data.DataLoader(
        valid_set,
        shuffle=False,
        batch_size=16,
        pin_memory=True,
        drop_last=False,
        num_workers=4,
        collate_fn=valid_collator
    )

    # map_location='cpu': load independently of the device the checkpoint was
    # saved from, and avoid holding a second copy of the weights on the GPU.
    state = torch.load(f'../input/clrp-roberta-v2/best-model-{fold}.pt', map_location='cpu')
    model = Comparer(None, config, pooling='att')
    model.load_state_dict(state['model'])
    del state  # the weights now live in `model`
    model.eval()
    model.cuda()

    results=[]
    with torch.no_grad():
        for inputs, _ in tqdm(loader):
            inputs = _prepare_inputs(inputs)
            with autocast(enabled=True):
                outputs = model(**inputs)
            results.append(outputs)
    predicts = merge_dicts(results)
    for key in predicts.keys():
        predicts[key] = torch.cat(predicts[key], dim=0)
    pred = torch.sigmoid(predicts['pred']).cpu().numpy()

    # One prediction per (excerpt, anchor) row of new_test, in the same order.
    assert len(pred) == len(new_test), (
        f'got {len(pred)} predictions for {len(new_test)} (excerpt, anchor) rows; '
        'are excerpts and anchor texts unique?')

    df = pd.DataFrame()
    df['id']= new_test['id']
    df['pred'] = pred.flatten()
    df['anchor_target'] = new_test['anchor_etarget']
    # Equivalent to solving pred = e / (e + anchor) for e.
    df['pred_etarget'] = df['pred']*df['anchor_target']/(1-df['pred'])
    # Clip before the log so pred == 0 or pred == 1 cannot produce -inf/inf.
    df['pred_etarget'] = np.clip(df['pred_etarget'],a_min = 0.025, a_max=5.6)
    df['pred_target'] = np.log(df['pred_etarget'])
    all_results.append(df[['id','pred_target']])

    del model
    gc.collect()

In [None]:
# onebert binary
# OneBertComparer, binary output, 20 anchors per excerpt, folds 0-2 only.
# Checkpoints: clrp-roberta-one. Appends to `all_results`, so the cell that
# initialises it must run first.
for fold in [0,1,2]:
    print(f'fold {fold}')
    # Anchors come from the rows outside this fold.
    train_fold = train[train['fold']!=fold]
    new_test = make_test_dataset(train_fold, test, 20)
    valid_set = OneBertValidationDataset(new_test, tokenizer)
    valid_collator = OneBertCompareCollator(token_pad_value=tokenizer.pad_token_id,
                          type_pad_value=tokenizer.pad_token_type_id)
    loader = torch.utils.data.DataLoader(
        valid_set,
        shuffle=False,
        batch_size=16,
        pin_memory=True,
        drop_last=False,
        num_workers=4,
        collate_fn=valid_collator
    )

    # map_location='cpu': load independently of the device the checkpoint was
    # saved from, and avoid holding a second copy of the weights on the GPU.
    state = torch.load(f'../input/clrp-roberta-one/best-model-{fold}.pt', map_location='cpu')
    model = OneBertComparer(None, config, sep_pooling=True, pooling='att')
    model.load_state_dict(state['model'])
    del state  # the weights now live in `model`
    model.eval()
    model.cuda()

    results=[]
    with torch.no_grad():
        for inputs, _ in tqdm(loader):
            inputs = _prepare_inputs(inputs)
            with autocast(enabled=True):
                outputs = model(**inputs)
            results.append(outputs)
    predicts = merge_dicts(results)
    for key in predicts.keys():
        predicts[key] = torch.cat(predicts[key], dim=0)
    pred = torch.sigmoid(predicts['pred']).cpu().numpy()

    # One prediction per (excerpt, anchor) row of new_test, in the same order.
    assert len(pred) == len(new_test), (
        f'got {len(pred)} predictions for {len(new_test)} (excerpt, anchor) rows')

    df = pd.DataFrame()
    df['id']= new_test['id']
    df['pred'] = pred.flatten()
    df['anchor_target'] = new_test['anchor_etarget']
    # Equivalent to solving pred = e / (e + anchor) for e.
    df['pred_etarget'] = df['pred']*df['anchor_target']/(1-df['pred'])
    # Clip before the log so pred == 0 or pred == 1 cannot produce -inf/inf.
    df['pred_etarget'] = np.clip(df['pred_etarget'],a_min = 0.025, a_max=5.6)
    df['pred_target'] = np.log(df['pred_etarget'])
    all_results.append(df[['id','pred_target']])

    del model
    gc.collect()

In [None]:
# onebert reg
# OneBertComparer, regression output (target - anchor_target), 20 anchors per
# excerpt, folds 2-4 only. Checkpoints: clrp-onebert-reg. Appends to
# `all_results`, so the cell that initialises it must run first.
for fold in [2,3,4]:
    print(f'fold {fold}')
    # Anchors come from the rows outside this fold.
    train_fold = train[train['fold']!=fold]
    new_test = make_test_dataset(train_fold, test, 20)
    valid_set = OneBertValidationDataset(new_test, tokenizer)
    valid_collator = OneBertCompareCollator(token_pad_value=tokenizer.pad_token_id,
                          type_pad_value=tokenizer.pad_token_type_id)
    loader = torch.utils.data.DataLoader(
        valid_set,
        shuffle=False,
        batch_size=16,
        pin_memory=True,
        drop_last=False,
        num_workers=4,
        collate_fn=valid_collator
    )

    # map_location='cpu': load independently of the device the checkpoint was
    # saved from, and avoid holding a second copy of the weights on the GPU.
    state = torch.load(f'../input/clrp-onebert-reg/best-model-{fold}.pt', map_location='cpu')
    model = OneBertComparer(None, config, sep_pooling=True, pooling='att')
    model.load_state_dict(state['model'])
    del state  # the weights now live in `model`
    model.eval()
    model.cuda()

    results=[]
    with torch.no_grad():
        for inputs, _ in tqdm(loader):
            inputs = _prepare_inputs(inputs)
            with autocast(enabled=True):
                outputs = model(**inputs)
            results.append(outputs)
    predicts = merge_dicts(results)
    for key in predicts.keys():
        predicts[key] = torch.cat(predicts[key], dim=0)
    pred = predicts['pred'].cpu().numpy()

    # One prediction per (excerpt, anchor) row of new_test, in the same order.
    assert len(pred) == len(new_test), (
        f'got {len(pred)} predictions for {len(new_test)} (excerpt, anchor) rows')

    df = pd.DataFrame()
    df['id']= new_test['id']
    df['pred'] = pred.flatten()
    df['anchor_target'] = new_test['anchor_target']
    # Predicted difference + anchor target = predicted target.
    df['pred_target'] = df['pred']+df['anchor_target']
    all_results.append(df[['id','pred_target']])

    del model
    gc.collect()

In [None]:
# Stack the per-(id, anchor) predictions of every model/fold appended above and
# average them per test id. Every row has equal weight, so a run that used
# more anchors or more folds contributes proportionally more to the mean.
# groupby sorts by `id`, so the result is not in test.csv order.
df = pd.concat(all_results, ignore_index=True, sort=False)
pred = df.groupby('id')['pred_target'].mean().reset_index()

In [None]:
# Write the averaged prediction as the `target` column of submission.csv
# (columns: id, target) in the current working directory.
pred['target'] = pred['pred_target']
pred[['id','target']].to_csv('submission.csv', index=False)

In [None]:
# Sanity check: number of (id, anchor) prediction rows produced by the first run.
len(all_results[0])

In [None]:
# Preview the first rows of the final per-id predictions.
pred.head()

In [None]:
# NOTE(review): this cell repeats the preview in the previous cell.
pred.head()