動作確認

In [None]:
import gc
gc.enable()

import sys
sys.path.append("../input/tez-lib/")

import os
import random
from tqdm import tqdm

import numpy as np
import pandas as pd
from scipy.special import softmax
import pickle

import tez
import torch
import torch.nn as nn
from joblib import Parallel, delayed
from transformers import AutoConfig, AutoModel, AutoTokenizer


# Config

In [None]:
class Config:
    """Settings for this OOF-evaluation notebook (used through one instance, ``cfg``).

    Holds input paths, the tokenizer used for every essay, parallelism, the
    random seed and the post-processing thresholds read by ``thresh_prob``,
    ``thresh_min_token`` and ``link_class``.
    """
    input_dir = '../input/feedback-prize-2021'  # essays are read from <input_dir>/train/<id>.txt
    
    model_longformer = '../input/longformerlarge4096/longformer-large-4096'
    model_led = '../input/led-large'
    model_deberta = '../input/deberta-large-download/deberta-large'
    
    # Token rows per document in create_id_df (documents are padded to this).
    max_len_test_longformer = 1600
    # NOTE(review): not referenced in the visible part of the notebook.
    max_len_test_led = 1024
    
    # Tokenizer for all essays (the trailing comment is an earlier Longformer choice).
    tokenizer = '../input/deberta-large-download/deberta-large' #'../input/longformerlarge4096/longformer-large-4096'
    num_jobs = 4  # joblib workers used by prepare_test_data
    seed = 1  # passed to seed_everything
    
    # NOTE(review): not referenced in the visible part of the notebook; the
    # predictions evaluated here are loaded from OOF pickles (create_oof).
    model_ckp_path = [
        #  (kind, num_labels, model_path, weight)
        # exp lf-large, lf-base, led-large, bb-large, ...
        
        #('lf-large', 22, '../input/2022021410-lf-bie-bin/model_0.bin', 1.0),
        #('lf-large', 22, '../input/2022021410-lf-bie-bin/model_1.bin', 1.0),
        #('lf-large', 22, '../input/2022021410-lf-bie-bin/model_2.bin', 1.0),
        #('lf-large', 22, '../input/2022021410-lf-bie-bin/model_3.bin', 1.0),
        #('lf-large', 22, '../input/2022021410-lf-bie-bin/model_4.bin', 1.0),
        
        #('lf-large', 15, '../input/2022020906-aug-bin/model_0.bin', 1.0),
        #('lf-large', 15, '../input/2022020906-aug-bin/model_1.bin', 1.0),
        #('lf-large', 15, '../input/2022020906-aug-bin/model_2.bin', 1.0),
        #('lf-large', 15, '../input/2022020906-aug-bin/model_3.bin', 1.0),
        #('lf-large', 15, '../input/2022020906-aug-bin/model_4.bin', 1.0),

        #('led-large', 24, '../input/fb-exp037-led/model_0.bin', 1.0),
        #('led-large', 24, '../input/fb-exp037-led/model_1.bin', 1.0),
        #('led-large', 24, '../input/fb-exp037-led/model_2.bin', 1.0),
        #('led-large', 24, '../input/fb-exp037-led/model_3.bin', 1.0),
        #('led-large', 24, '../input/fb-exp037-led/model_4.bin', 1.0),

        #('led-large', 24, '../input/fb-exp017-led/model_0.bin', 1.0),
        #('led-large', 24, '../input/fb-exp017-led/model_1.bin', 1.0),
        #('led-large', 24, '../input/fb-exp017-led/model_2.bin', 1.0),
        #('led-large', 24, '../input/fb-exp017-led/model_3.bin', 1.0),
        #('deberta-large', 22, '../input/fb-exp045-deberta/model_0.bin', 1.0),
        ('deberta-large', 15, '../input/2022022709-deberta-large-boe-bin/model_0.bin', 1.0),
    ]

    # thresh_prob: drop a predicted segment of this class when its mean top-1
    # token probability is below the value (trailing comments look like
    # earlier values).
    proba_thresh = {
        "Lead": 0.6, # 0.7
        "Position": 0.4, # 0.55
        "Evidence": 0.65,
        "Claim": 0.55,
        "Concluding Statement": 0.6, # 0.7
        "Counterclaim": 0.5,
        "Rebuttal": 0.55,
    }
    # thresh_min_token: drop a segment of this class with fewer words than this.
    min_token_thresh = {
        "Lead": 5, # 9
        "Position": 4, # 5
        "Evidence": 14,
        "Claim": 2,
        "Concluding Statement": 7, # 11
        "Counterclaim": 6,
        "Rebuttal": 4,
    }
    # link_class thresh2: an adjacent following segment of this class is merged
    # into the current run only if it starts at most this many word indices
    # after the run's first word.
    link = {
        'Evidence': 40,
        'Counterclaim': 200,
        'Rebuttal': 200,
    }


cfg = Config()    
# BIO tags of the 15-label scheme, in class-id order: a B-/I- pair for each of
# the seven discourse types, then "O" (outside any discourse).
_BIO_TAGS = [
    "B-Lead", "I-Lead",
    "B-Position", "I-Position",
    "B-Evidence", "I-Evidence",
    "B-Claim", "I-Claim",
    "B-Concluding Statement", "I-Concluding Statement",
    "B-Counterclaim", "I-Counterclaim",
    "B-Rebuttal", "I-Rebuttal",
    "O",
]

# tag -> class id (0..14); "PAD" is mapped to -100.
target_id_map = {tag: tag_id for tag_id, tag in enumerate(_BIO_TAGS)}
target_id_map["PAD"] = -100

# class id -> tag (inverse of target_id_map).
id_target_map = {tag_id: tag for tag, tag_id in target_id_map.items()}


# Modules

## Utils

In [None]:
def get_test_text(ids):
    """Return the full text of essay ``ids``.

    NOTE: the file is read from ``<cfg.input_dir>/train/`` even though the
    name says "test" — this notebook evaluates on the train essays.
    """
    essay_path = f'{cfg.input_dir}/train/{ids}.txt'
    with open(essay_path, 'r') as essay_file:
        return essay_file.read()


def seed_everything(seed: int) -> None:
    """Seed the random number generators used by this notebook.

    Seeds Python's ``random``, NumPy and torch, and sets ``PYTHONHASHSEED``
    (which only affects interpreters started afterwards). When CUDA is
    available it also seeds every GPU and forces deterministic cuDNN kernels
    with benchmark autotuning turned off.
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
        
        
# Seed every RNG once, before any data preparation or model work.
seed_everything(seed=cfg.seed)


## Pre Process

In [None]:
def _prepare_test_data_helper(tokenizer, ids):
    """Tokenize each essay id in ``ids``; return one sample dict per id, in order.

    Encoding uses no special tokens and truncates to 1600 tokens. Each sample
    keeps the raw text and the per-token character offsets, so that token
    predictions can be mapped back onto the text later.
    """
    def _encode(doc_id):
        text = get_test_text(doc_id)
        encoding = tokenizer.encode_plus(
            text,
            add_special_tokens=False,
            return_offsets_mapping=True,
            max_length=1600,
            truncation=True,
        )
        return {
            "id": doc_id,
            "input_ids": encoding["input_ids"],
            "text": text,
            "offset_mapping": encoding["offset_mapping"],
        }

    return [_encode(doc_id) for doc_id in ids]


def prepare_test_data(ids, tokenizer, num_jobs):
    """Tokenize essays ``ids`` in parallel; return their samples in ``ids`` order.

    ``ids`` is split into one chunk per worker. Each chunk is processed by
    ``_prepare_test_data_helper`` through joblib's ``multiprocessing``
    backend, and the per-chunk results are concatenated in chunk order
    (joblib returns results in submission order).

    Args:
        ids: array-like of essay ids.
        tokenizer: tokenizer passed through to ``_prepare_test_data_helper``.
        num_jobs: joblib ``n_jobs``; a positive value also sets the number of
            chunks.
    """
    # One chunk per worker, so no worker sits idle when num_jobs exceeds a
    # fixed chunk count. For non-positive / None n_jobs (joblib's
    # "relative to CPU count" convention) fall back to one chunk per CPU.
    if num_jobs is not None and num_jobs > 0:
        n_chunks = num_jobs
    else:
        n_chunks = os.cpu_count() or 1
    ids_splits = np.array_split(ids, n_chunks)

    results = Parallel(n_jobs=num_jobs, backend="multiprocessing")(
        delayed(_prepare_test_data_helper)(tokenizer, chunk) for chunk in ids_splits
    )

    test_samples = []
    for result in results:
        test_samples.extend(result)
    return test_samples

## Dataset

In [None]:
class FeedbackTestDataset:
    """Inference dataset over pre-tokenized samples.

    Item ``idx`` is ``[CLS] + samples[idx]["input_ids"] + [SEP]``, truncated
    before the SEP so that the total length is at most ``max_len``, together
    with an all-ones attention mask of the same length.
    """

    def __init__(self, samples, max_len, tokenizer):
        self.samples = samples
        self.max_len = max_len
        self.tokenizer = tokenizer
        self.length = len(samples)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        framed = [self.tokenizer.cls_token_id, *self.samples[idx]["input_ids"]]
        # Leave room for the closing SEP token.
        framed = framed[: self.max_len - 1]
        token_ids = framed + [self.tokenizer.sep_token_id]
        return {"ids": token_ids, "mask": [1] * len(token_ids)}
    
    
class Collate:
    """Pad a batch of variable-length samples into ``torch.long`` tensors.

    ``ids`` are padded with ``tokenizer.pad_token_id`` and ``mask`` with 0, up
    to the longest sample in the batch. Padding goes after the tokens when
    ``tokenizer.padding_side == "right"`` and before them otherwise.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        id_rows = [item["ids"] for item in batch]
        mask_rows = [item["mask"] for item in batch]

        width = max(len(row) for row in id_rows)
        pad_right = self.tokenizer.padding_side == "right"
        pad_id = self.tokenizer.pad_token_id

        def pad(row, fill):
            filler = [fill] * (width - len(row))
            return row + filler if pad_right else filler + row

        return {
            "ids": torch.tensor([pad(row, pad_id) for row in id_rows], dtype=torch.long),
            "mask": torch.tensor([pad(row, 0) for row in mask_rows], dtype=torch.long),
        }

## Model

In [None]:
class FeedbackModel(tez.Model):
    """Per-token classifier: transformer backbone + linear ``output`` head + softmax.

    The backbone is built with ``AutoModel.from_config``, so it is randomly
    initialised here; trained weights presumably get loaded afterwards from a
    checkpoint (not visible in this cell).
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        self.model_name = model_name
        self.num_labels = num_labels

        overrides = {
            'output_hidden_states': True,
            # NOTE(review): add_pooling_layer is normally a model-constructor
            # argument, not a config field — confirm it has any effect here.
            'add_pooling_layer': False,
            'num_labels': num_labels,
            # Longformer-style dropout / layer-norm settings.
            'attention_probs_dropout_prob': 0.1,
            'hidden_dropout_prob': 0.1,
            'layer_norm_eps': 1e-7,
            # LED-style dropout / layerdrop settings (all disabled).
            'activation_dropout': 0.0,
            'attention_dropout': 0.0,
            'classif_dropout': 0.0,
            'classifier_dropout': 0.0,
            'decoder_layerdrop': 0.0,
            'encoder_layerdrop': 0.0,
        }
        self.config = AutoConfig.from_pretrained(model_name)
        self.config.update(overrides)

        # Keep the attribute names `transformer` / `output`: they presumably
        # define the state-dict keys of the saved checkpoints.
        self.transformer = AutoModel.from_config(self.config)
        self.output = nn.Linear(self.config.hidden_size, self.num_labels)

    def forward(self, ids, mask):
        """Return (per-token label probabilities, dummy loss 0, empty metrics dict)."""
        hidden = self.transformer(ids, mask).last_hidden_state
        probabilities = torch.softmax(self.output(hidden), dim=-1)
        return probabilities, 0, {}

    
class FeedbackModelV2(tez.Model):
    """Same as ``FeedbackModel`` but the linear head is stored as ``fc``.

    The backbone is built with ``AutoModel.from_config`` (random init here);
    trained weights presumably get loaded afterwards from a checkpoint.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        self.model_name = model_name
        self.num_labels = num_labels

        overrides = {
            'output_hidden_states': True,
            # NOTE(review): add_pooling_layer is normally a model-constructor
            # argument, not a config field — confirm it has any effect here.
            'add_pooling_layer': False,
            'num_labels': num_labels,
            # Longformer-style dropout / layer-norm settings.
            'attention_probs_dropout_prob': 0.1,
            'hidden_dropout_prob': 0.1,
            'layer_norm_eps': 1e-7,
            # LED-style dropout / layerdrop settings (all disabled).
            'activation_dropout': 0.0,
            'attention_dropout': 0.0,
            'classif_dropout': 0.0,
            'classifier_dropout': 0.0,
            'decoder_layerdrop': 0.0,
            'encoder_layerdrop': 0.0,
        }
        self.config = AutoConfig.from_pretrained(model_name)
        self.config.update(overrides)

        # Keep the attribute names `transformer` / `fc`: they presumably
        # define the state-dict keys of the saved checkpoints.
        self.transformer = AutoModel.from_config(self.config)
        self.fc = nn.Linear(self.config.hidden_size, self.num_labels)

    def forward(self, ids, mask):
        """Return (per-token label probabilities, dummy loss 0, empty metrics dict)."""
        hidden = self.transformer(ids, mask).last_hidden_state
        probabilities = torch.softmax(self.fc(hidden), dim=-1)
        return probabilities, 0, {}


## Inference

In [None]:
# Short model-family tags by index: 0 longformer, 1 LED, 2 DeBERTa, 3 ensemble.
data_name = dict(enumerate(("lf", "led", "deb", "ens")))

In [None]:
def inference(raw_preds, samples, tokenizer, cfg, bs, fold_num=0):
    """Attach the arg-max BIO tag and its probability per token to each sample.

    ``raw_preds`` is an iterable of batch arrays indexed
    ``[sample, token, label]``. Batches are flattened in order, so row ``k``
    overall belongs to ``samples[k]``. Token position 0 is the special token
    prepended by the dataset and is skipped.

    Sets ``pred_class`` (tag strings) and ``pred_prob`` on each sample dict.
    ``samples.copy()`` is shallow, so the dicts inside ``samples`` are updated
    in place; ``samples`` itself is returned. ``tokenizer``, ``cfg``, ``bs``
    and ``fold_num`` are unused. NOTE: a later cell redefines ``inference``.
    """
    tag_rows, prob_rows = [], []
    for batch in raw_preds:
        tag_rows.extend(np.argmax(batch, axis=2).tolist())
        prob_rows.extend(np.max(batch, axis=2).tolist())

    annotated = samples.copy()
    for i in range(len(samples)):
        # Drop position 0 (special token).
        annotated[i]['pred_class'] = [id_target_map[tag_id] for tag_id in tag_rows[i][1:]]
        annotated[i]['pred_prob'] = prob_rows[i][1:]

    return samples


def reshape_preds_24_to_15(preds):
    """Collapse a 24-label probability array ``[n, seq, 24]`` to the 15-label scheme.

    Each output label is the sum of the input labels listed for it in
    ``groups``. The result is a new float64 array of shape ``[n, seq, 15]``.
    """
    groups = (
        (0,), (1, 2), (3,), (4, 5), (6,), (7, 8), (9,), (10, 11),
        (12,), (13, 14), (15,), (16, 17), (18,), (19, 20), (21, 22, 23),
    )
    collapsed = np.zeros((preds.shape[0], preds.shape[1], len(groups)))
    for target, sources in enumerate(groups):
        # Sum in the input dtype first, then add onto the float64 zeros.
        total = preds[:, :, sources[0]]
        for src in sources[1:]:
            total = total + preds[:, :, src]
        collapsed[:, :, target] += total
    return collapsed

def reshape_preds_22_to_15(preds):
    """Collapse a 22-label probability array ``[n, seq, 22]`` to the 15-label scheme.

    Each output label is the sum of the input labels listed for it in
    ``groups``. The result is a new float64 array of shape ``[n, seq, 15]``.
    """
    groups = (
        (0,), (1, 2), (3,), (4, 5), (6,), (7, 8), (9,), (10, 11),
        (12,), (13, 14), (15,), (16, 17), (18,), (19, 20), (21,),
    )
    collapsed = np.zeros((preds.shape[0], preds.shape[1], len(groups)))
    for target, sources in enumerate(groups):
        # Sum in the input dtype first, then add onto the float64 zeros.
        total = preds[:, :, sources[0]]
        for src in sources[1:]:
            total = total + preds[:, :, src]
        collapsed[:, :, target] += total
    return collapsed

## Post Process

In [None]:
def post_process(samples):
    """Turn annotated samples into the final ``id, class, predictionstring`` table.

    Pipeline: create_submission -> thresh_prob -> thresh_min_token ->
    get_max_prob -> link_class for Evidence, Counterclaim and Rebuttal (each
    with its ``cfg.link`` limit). The result gets a fresh 0..n-1 index.
    """
    sub = create_submission(samples)
    for step in (thresh_prob, thresh_min_token):
        sub = step(sub, cfg)
    sub = get_max_prob(sub)
    for discourse_type in ('Evidence', 'Counterclaim', 'Rebuttal'):
        sub = link_class(sub, discourse_type, cfg.link[discourse_type])
    return sub.reset_index(drop=True)[['id', 'class', 'predictionstring']]


def _bio_to_label(tag):
    """Strip the ``B-``/``I-`` prefix from a BIO tag; ``'O'`` stays ``'O'``."""
    return 'O' if tag == 'O' else tag[2:]


def create_submission(samples):
    """Group per-token predictions into discourse segments, one row per segment.

    Each sample must carry ``id``, ``text``, ``offset_mapping`` and the
    per-token lists ``pred_class``/``pred_prob``, ``second_class``/
    ``second_prob`` and ``third_class``/``third_prob`` (BIO tag strings and
    probabilities). The sample dicts are not modified.

    After smoothing ``pred_class`` with ``fix_list_``, a segment starts at a
    token and extends over the following tokens whose top-1 tag is
    ``I-<label>`` for the starting token's label (or ``'O'`` for an ``'O'``
    run). Character spans are converted to whitespace-delimited word indices
    (``predictionstring``). ``'O'`` segments and single-token segments are
    not emitted.

    Returns a DataFrame with columns id, class, second_class, third_class,
    predictionstring, prob, second_prob, third_prob (means over the segment's
    tokens) and len (number of words). Rows with an empty predictionstring
    are removed.
    """
    columns = ['id', 'class', 'second_class', 'third_class', 'predictionstring',
               'prob', 'second_prob', 'third_prob']
    frames = []
    for sample in samples:
        sample_id = sample['id']
        sample_text = sample['text']
        offset_mapping = sample['offset_mapping']
        n_tokens = len(offset_mapping)

        # Work on copies so fixing/padding does not alter the caller's samples.
        pred_class = fix_list_(list(sample['pred_class']))
        second_class = list(sample['second_class'])
        third_class = list(sample['third_class'])
        pred_prob = list(sample['pred_prob'])
        second_prob = list(sample['second_prob'])
        third_prob = list(sample['third_prob'])

        # Predictions shorter than the token offsets (e.g. a truncated model
        # input): treat the uncovered tail as 'O' with probability 0.
        if len(pred_class) < n_tokens - 1:
            for tags in (pred_class, second_class, third_class):
                tags.extend(['O'] * (n_tokens - len(tags)))
            for probs in (pred_prob, second_prob, third_prob):
                probs.extend([0] * (n_tokens - len(probs)))

        n_words = len(sample_text.split())
        rows = []
        idx = 0
        # Only tokens 0..n_tokens-2 are visited, by both loops below.
        while idx < n_tokens - 1:
            start = offset_mapping[idx][0]
            label = _bio_to_label(pred_class[idx])
            second_label = _bio_to_label(second_class[idx])
            third_label = _bio_to_label(third_class[idx])
            matching_label = 'O' if label == 'O' else f'I-{label}'

            probs = [pred_prob[idx]]
            second_probs = [second_prob[idx]]
            third_probs = [third_prob[idx]]
            # Reset per segment: an `end` left over from an earlier segment
            # (or an earlier sample) must never be reused.
            end = None
            idx += 1
            while idx < n_tokens - 1 and pred_class[idx] == matching_label:
                end = offset_mapping[idx][1]
                probs.append(pred_prob[idx])
                second_probs.append(second_prob[idx])
                third_probs.append(third_prob[idx])
                idx += 1

            # Single-token segments have no end offset and are skipped; they
            # would be one word long, below every min_token_thresh value.
            if end is None or label == 'O':
                continue

            word_start = len(sample_text[:start].split())
            word_end = min(word_start + len(sample_text[start:end].split()), n_words)
            ps = " ".join(str(x) for x in range(word_start, word_end))
            n = len(probs)
            rows.append((sample_id, label, second_label, third_label, ps,
                         sum(probs) / n, sum(second_probs) / n, sum(third_probs) / n))

        frames.append(pd.DataFrame(rows, columns=columns))

    if not frames:
        # pd.concat([]) raises; keep the column layout for empty input.
        frames = [pd.DataFrame(columns=columns)]
    sub = pd.concat(frames).reset_index(drop=True)
    sub['len'] = sub['predictionstring'].apply(lambda x: len(x.split()))
    sub = sub[sub['len'] > 0]
    return sub

def fix_list(pred_list):
    """Fill short interruptions inside long runs of one tag (modifies ``pred_list``).

    For each tag in ``class_list`` (seven inside tags and ``'O'``), a position
    is flagged when its token is another tag from ``class_list`` and it
    directly follows more than two consecutive tokens of the current tag.
    For a flagged position ``ind`` with look-ahead ``k = fix_thresholds[tag]``
    (only when ``ind + k + 1 < len(pred_list)``), find the largest ``j`` in
    ``k..1`` such that positions ``ind + j`` and ``ind + j + 1`` both carry
    the tag, then overwrite positions ``ind .. ind + j - 1`` with the tag.

    Returns ``pred_list`` (the same list object).
    """
    class_list = ["I-Lead", "I-Position", "I-Evidence", "I-Claim",
                  "I-Concluding Statement", "I-Counterclaim", "I-Rebuttal", "O"]

    # Maximum look-ahead (tokens) when searching for the run to resume.
    fix_thresholds = {
        "I-Lead": 2,
        "I-Concluding Statement": 2,
        "I-Evidence": 1,
        "I-Position": 2,
        "I-Claim": 1,
        "I-Counterclaim": 5,
        "I-Rebuttal": 7,
        "O": 1,
    }

    for class_ in class_list:
        flg_index = []
        # Exact comparison; a substring test (`col not in class_`) would
        # misfire for any tag contained in another.
        out_class = [col for col in class_list if col != class_]
        counter = 0

        for token_id, token in enumerate(pred_list):
            # Flag the first other-tag token after MORE than two consecutive
            # `class_` tokens.
            if counter > 2 and token in out_class:
                flg_index.append(token_id)
                counter = 0

            if token == class_:
                counter += 1
            else:
                counter = 0

        look_ahead = fix_thresholds[class_]
        for ind in flg_index:
            if ind + look_ahead + 1 >= len(pred_list):
                continue
            for j in range(look_ahead, 0, -1):
                if pred_list[ind + j] == class_ and pred_list[ind + j + 1] == class_:
                    for i in range(j):
                        pred_list[ind + i] = class_
                    break

    return pred_list

def fix_list_(pred_list):
    """Patch one-token interruptions inside runs of an inside tag (modifies ``pred_list``).

    For each of the seven ``I-`` tags, a position is flagged when it holds a
    different ``I-`` tag and directly follows more than two consecutive tokens
    of the current tag. A flagged position is overwritten with the current tag
    when the next two positions both hold it again. ``'O'`` and ``B-`` tokens
    never trigger a fix. Returns ``pred_list`` (the same list object).
    """
    inside_tags = ["I-Lead", "I-Position", "I-Evidence", "I-Claim",
                   "I-Concluding Statement", "I-Counterclaim", "I-Rebuttal"]

    for tag in inside_tags:
        others = set(inside_tags) - {tag}
        flagged = []
        run_length = 0
        for pos, token in enumerate(pred_list):
            if run_length > 2 and token in others:
                flagged.append(pos)
                run_length = 0
            run_length = run_length + 1 if token == tag else 0

        size = len(pred_list)
        for pos in flagged:
            resumes = pos + 2 < size and pred_list[pos + 1] == tag and pred_list[pos + 2] == tag
            if resumes:
                pred_list[pos] = tag

    return pred_list


def jn(pst, start, end):
    """Join ``pst[start:end]`` with spaces, skipping the ``-1`` separators."""
    kept = (word for word in pst[start:end] if word != -1)
    return " ".join(map(str, kept))


def link_class(oof, discourse_type, thresh2):
    """Merge adjacent predicted segments of ``discourse_type`` within each essay.

    Per essay, the segments of that class are walked in row order. A segment
    joins the current merged run when two conditions hold:
      * its first word index is at most one past the previous segment's last
        word index, and
      * it starts no more than ``thresh2`` word indices after the run's first
        word.
    Otherwise a new run starts.

    Rows of other classes pass through unchanged. They are combined with an
    outer merge on the shared columns, so merged rows carry NaN in extra
    columns such as ``prob``. The result is ordered by essay (first
    appearance in ``oof``), then by first word index within each essay.
    An empty ``oof`` is returned unchanged.
    """
    if not len(oof):
        return oof
    id_list = oof['id'].unique().tolist()
    thresh = 1  # a word-index gap larger than this breaks a run
    eoof = oof[oof['class'] == f"{discourse_type}"]
    neoof = oof[oof['class'] != f"{discourse_type}"]

    # Group once instead of re-filtering `eoof` for every essay id.
    segments_by_id = {doc_id: group for doc_id, group in eoof.groupby('id', sort=False)}

    retval = []
    for idv in oof['id'].unique():
        q = segments_by_id.get(idv)
        if q is None:
            continue
        # Flatten word indices of all segments; -1 marks each segment start.
        pst = []
        for ps in q['predictionstring']:
            pst.append(-1)
            pst.extend(int(x) for x in ps.split())
        start, end = 1, 1
        for i in range(2, len(pst)):
            end = i
            if pst[i] == -1 and (
                pst[i + 1] > pst[end - 1] + thresh or pst[i + 1] - pst[start] > thresh2
            ):
                retval.append((idv, discourse_type, jn(pst, start, end)))
                start = i + 1
        retval.append((idv, discourse_type, jn(pst, start, end + 1)))

    roof = pd.DataFrame(retval, columns=['id', 'class', 'predictionstring'])
    roof = roof.merge(neoof, how='outer')

    # One sort by (essay position, first word index) instead of querying the
    # frame once per essay. Ties keep their merged order.
    doc_pos = {doc_id: pos for pos, doc_id in enumerate(id_list)}
    roof = roof.assign(
        _doc_pos=roof['id'].map(doc_pos),
        _start=roof['predictionstring'].apply(lambda x: int(x.split(' ')[0])),
    )
    roof = roof[roof['_doc_pos'].notna()]
    roof = roof.sort_values(['_doc_pos', '_start'], kind='mergesort')
    return roof.drop(columns=['_doc_pos', '_start'])


def thresh_prob(df, cfg):
    """Drop predictions whose mean probability is below a per-class threshold.

    Two-word ``Claim`` predictions are penalised by 0.1 before the comparison;
    the penalty stays in the returned ``prob`` column. Thresholds come from
    ``cfg.proba_thresh``, and classes missing from it are never dropped.
    Rows are returned sorted by index. The input frame is not modified.
    """
    # Work on an explicit copy: no chained assignment on a filtered slice.
    df = df.copy()
    short_claim = (df['class'] == 'Claim') & (df['len'] == 2)
    df.loc[short_claim, 'prob'] -= 0.1
    df = df.sort_index()

    # Per-row threshold (NaN for unlisted classes -> comparison False -> kept).
    # A boolean mask avoids label-based drops, which would also remove
    # unrelated rows sharing an index label.
    min_prob = df['class'].map(cfg.proba_thresh)
    return df[~(df['prob'] < min_prob)]

# add
def thresh_second_prob(df, cfg):
    """Drop predictions whose mean second-best probability exceeds a per-class limit.

    Limits come from ``cfg.second_proba_thresh`` (class -> max second_prob).
    Rows are returned sorted by index.
    NOTE(review): ``Config`` in this notebook does not define
    ``second_proba_thresh``; callers must supply it.
    """
    out = df.sort_index()
    for cls_name, limit in cfg.second_proba_thresh.items():
        of_class = out['class'] == cls_name
        too_ambiguous = out.loc[of_class & (out['second_prob'] > limit)].index
        out = out.drop(too_ambiguous)
    return out

# add
def thresh_third_prob(df, cfg):
    """Drop predictions whose mean third-best probability exceeds a per-class limit.

    Limits come from ``cfg.third_proba_thresh`` when the config defines it.
    Otherwise ``cfg.second_proba_thresh`` is used, which keeps existing
    configs working. Rows are returned sorted by index.
    """
    # Reading only the second-rank limits for third-rank probabilities looks
    # like a copy-paste slip; allow a dedicated setting.
    limits = getattr(cfg, 'third_proba_thresh', None)
    if limits is None:
        limits = cfg.second_proba_thresh

    df = df.sort_index()
    for k, v in limits.items():
        idx = df.loc[df['class'] == k].query(f'third_prob > {v}').index
        df = df.drop(idx)
    return df


def thresh_min_token(df, cfg):
    """Drop predictions with fewer words than the per-class minimum.

    First (re)computes ``df['len']`` as the number of space-separated parts of
    ``predictionstring``; this column is written onto the INPUT frame. Then,
    for each class in ``cfg.min_token_thresh``, rows with ``len`` below the
    minimum are dropped. Returns the filtered frame.
    """
    df['len'] = df['predictionstring'].apply(lambda ps: len(ps.split(' ')))
    for cls_name, min_len in cfg.min_token_thresh.items():
        of_class = df['class'] == cls_name
        too_short = df.loc[of_class & (df['len'] < min_len)].index
        df = df.drop(too_short)
    return df


def get_max_prob(sub):
    """Keep only the most probable Lead / Position / Concluding Statement per essay.

    Casts ``sub['prob']`` to float on the INPUT frame. For those three
    classes, one row per (id, class) is kept: the one with the highest
    ``prob``. Rows of all other classes are appended unchanged after them.
    """
    sub['prob'] = sub['prob'].astype(float)
    is_singleton = sub['class'].isin(['Lead', 'Position', 'Concluding Statement'])
    singleton_rows = sub[is_singleton]
    best_rows = singleton_rows.loc[singleton_rows.groupby(['id', 'class'])['prob'].idxmax(), :]
    return pd.concat([best_rows, sub[~is_singleton]])


In [None]:
# =============================
# Metrics
# =============================   
def calc_overlap(row):
    """Return ``[|T∩P| / |T|, |T∩P| / |P|]`` for the word-index sets of ``row``.

    T and P are the sets of space-separated tokens in
    ``row["predictionstring_true"]`` and ``row["predictionstring_pred"]``.

    ref: https://www.kaggle.com/robikscube/student-writing-competition-twitch
    """
    pred_words = set(row["predictionstring_pred"].split(" "))
    true_words = set(row["predictionstring_true"].split(" "))
    shared = len(pred_words & true_words)
    return [shared / len(true_words), shared / len(pred_words)]


def get_micro_f1_score(pred_df, true_df):
    """Micro F1 of predicted vs. ground-truth segments.

    Rows are paired within the same ``id`` and class (``class`` in
    ``pred_df``, ``discourse_type`` in ``true_df``). A pair is a potential
    true positive when both ``calc_overlap`` ratios are >= 0.5. For each
    ground truth, the pair with the largest of its two overlaps wins.
    Unmatched predictions are false positives and unmatched ground truths
    are false negatives.

    Returns ``TP / (TP + 0.5 * (FP + FN))``.
    """
    true_df = (true_df[["id", "discourse_type", "predictionstring"]]
               .reset_index(drop=True)
               .copy())
    pred_df = pred_df[["id", "class", "predictionstring"]].reset_index(drop=True).copy()
    pred_df["pred_id"] = pred_df.index
    true_df["true_id"] = true_df.index

    # 1. all ground truths and predictions for a given class are compared.
    joined_df = pred_df.merge(
        true_df,
        left_on=["id", "class"],
        right_on=["id", "discourse_type"],
        how="outer",
        suffixes=("_pred", "_true"),
    )
    # Assign rather than fillna(inplace=True) on a selected column: that
    # chained in-place update does not reach `joined_df` under Copy-on-Write.
    joined_df["predictionstring_true"] = joined_df["predictionstring_true"].fillna(" ")
    joined_df["predictionstring_pred"] = joined_df["predictionstring_pred"].fillna(" ")
    overlaps = joined_df.apply(calc_overlap, axis=1)

    # 2. If the overlap between the ground truth and prediction is >= 0.5,
    #    and the overlap between the prediction and the ground truth >= 0.5,
    #    the prediction is a match and considered a true positive.
    #    If multiple matches exist, the match with the highest pair of overlaps is taken.
    # calc_overlap returns a [float, float] list; index it directly (no eval).
    joined_df["overlap1"] = overlaps.map(lambda pair: pair[0])
    joined_df["overlap2"] = overlaps.map(lambda pair: pair[1])

    joined_df["potential_TP"] = (joined_df["overlap1"] >= 0.5) & (joined_df["overlap2"] >= 0.5)
    joined_df["max_overlap"] = joined_df[["overlap1", "overlap2"]].max(axis=1)
    tp_pred_ids = (
        joined_df.query("potential_TP")
        .sort_values("max_overlap", ascending=False)
        .groupby(["id", "predictionstring_true"])
        .first()["pred_id"]
        .values
    )

    # 3. Any unmatched ground truths are false negatives
    #    and any unmatched predictions are false positives.
    # Set lookups instead of scanning an array per id (O(n) instead of O(n^2)).
    # NOTE(review): NaN ids produced by the outer join (rows with no
    # counterpart) are never in these sets, so each adds one FP / FN; this
    # matches the array-based check and is kept so scores stay comparable.
    tp_set = set(tp_pred_ids.tolist())
    fp_pred_ids = [p for p in joined_df["pred_id"].unique() if p not in tp_set]
    matched_gt_ids = set(joined_df.query("potential_TP")["true_id"].unique().tolist())
    unmatched_gt_ids = [c for c in joined_df["true_id"].unique() if c not in matched_gt_ids]

    # Get numbers of each type
    TP = len(tp_pred_ids)
    FP = len(fp_pred_ids)
    FN = len(unmatched_gt_ids)
    # calc microF1
    my_f1_score = TP / (TP + 0.5 * (FP + FN))
    return my_f1_score


def get_score(pred_df, true_df, return_class_scores=False):
    """Mean of per-class micro F1 over the discourse types present in ``true_df``.

    Per-class scores and their mean are rounded to 4 decimals. Returns the
    mean, or ``(mean, {class: score})`` when ``return_class_scores`` is true.
    """
    preds = pred_df[["id", "class", "predictionstring"]].reset_index(drop=True).copy()
    class_scores = {}
    for discourse_type, true_subset in true_df.groupby("discourse_type"):
        class_preds = preds.loc[preds["class"] == discourse_type].reset_index(drop=True).copy()
        raw_score = get_micro_f1_score(class_preds, true_subset)
        class_scores[discourse_type] = np.round(raw_score, decimals=4)

    score = np.round(np.mean(list(class_scores.values())), decimals=4)
    if not return_class_scores:
        return score
    return score, class_scores

# main

In [None]:
# Evaluate on the full train set (ground truth + essay ids).
test_df = pd.read_csv('../input/feedback-prize-2021/train.csv')

# Single-fold alternative:
#test_df = pd.read_csv('../input/fb-train-folds/train_folds.csv')
#test_df = test_df[test_df['kfold']==fold_num].reset_index(drop=True)
test_ids = test_df['id'].unique()
# NOTE(review): the reason this essay is excluded is not recorded here.
test_ids = test_ids[test_ids != 'AD005493F9BF']
# for debug
# test_ids = test_ids[:100]
test_df = test_df.loc[test_df['id'].isin(test_ids)]

tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer)
test_samples = prepare_test_data(test_ids, tokenizer, cfg.num_jobs)

In [None]:
fold_df = pd.read_csv('../input/fb-train-folds/train_folds.csv')
fold_df = fold_df.loc[:, ['id', 'kfold']].drop_duplicates()
# essay id -> CV fold number
fold_dict = {doc_id: fold for doc_id, fold in zip(fold_df['id'], fold_df['kfold'])}

In [None]:
%%time

def create_id_df(samples):
    """Build a token-level frame with ``cfg.max_len_test_longformer`` rows per sample.

    Columns:
      * ``offset``: the token's character span, ``(9999, 9999)`` on padding rows.
      * ``id``: the essay id.
      * ``token_num``: the position 0..max_len-1.
      * ``token_len``: the number of real tokens in the sample.
      * ``token_key``: ``"<id>_<token_num>"``.
      * ``offset_start``: the first element of ``offset``.
    """
    width = cfg.max_len_test_longformer
    pad_span = (9999, 9999)
    offsets, doc_ids, token_nums, token_lens = [], [], [], []

    for sample in tqdm(samples):
        mapping = sample['offset_mapping']
        offsets.extend(mapping + [pad_span] * (width - len(mapping)))
        doc_ids.extend([sample['id']] * width)
        token_nums.extend(range(width))
        token_lens.extend([len(mapping)] * width)

    id_df = pd.DataFrame()
    id_df['offset'] = offsets
    id_df['id'] = doc_ids
    id_df['token_num'] = token_nums
    id_df['token_len'] = token_lens
    id_df['token_key'] = id_df['id'].astype(str) + '_' + id_df['token_num'].astype(str)
    id_df['offset_start'] = id_df['offset'].map(lambda span: span[0])
    return id_df

In [None]:
%%time
# essay id -> full essay text, for every evaluated id.
txt_dict = {doc_id: get_test_text(doc_id) for doc_id in tqdm(test_ids)}

In [None]:
paths = []
scores = []

# OOF prediction pickles in the ensemble. `weight` has the same length and
# presumably pairs weight[i] with paths_[i].
paths_ = [
    # '../input/oof-pred-108/deberta-large-108.pkl',
    '../input/fb-oof-dict/oof_dict_yyama_exp045.pickle',
    # '../input/fb-oof-dict/oof_dict_makotu_021410-lf-bie-bin.pickle',
    '../input/fb-oof-dict/oof_dict_makabe_feedback-101_LSTM_2head.pickle',
    # '../input/fb-oof-dict/oof_dict_makotu_030209-deberta-large-mnli-boe-bin.pickle',
    '../input/fb-oof-dict/oof_dict_022709-deberta-large-boe-bin.pickle',
    '../input/fb-oof-dict/oof_dict_makabe_feedback-096.pickle',
    '../input/convert-fp64-to-fp16-inference98/oof_dict_makabe_feedback-098_LSTM.pickle',
    '../input/fb-oof-dict/oof_dict_yyama_exp058-wo-mlm.pickle',
    '../input/fb-oof-dict/oof_dict_yyama_fb-exp037-led.pickle',
    '../input/fb-oof-dict-exp062-xlarge-f24/oof_dict_yyama_fb_exp06x_xlarge_fold2-4.pickle',
]

weight = [1.8, 2.6, 1.3, 2.1, 2.7, 1, 1.6, 1.4]

In [None]:
%%time

def inference(oof, samples, tokenizer, cfg, fold_num=0):
    """Attach the top-3 BIO tags and their probabilities per token to each sample.

    ``oof`` is indexed ``oof[i, token, label]``, with row ``i`` belonging to
    ``samples[i]``. Token position 0 is the special token prepended by the
    dataset and is skipped.

    Sets these keys on each sample dict, in this order:
      * ``pred_class``/``pred_prob``: the best label.
      * ``second_class``/``second_prob``: the second-best label.
      * ``third_class``/``third_prob``: the third-best label.
    Tags are strings from ``id_target_map``. Probabilities are Python floats,
    so later averaging (``create_submission``) runs in float64 even when
    ``oof`` is float16 (as ``create_oof`` can produce).

    Returns a shallow copy of ``samples``. The dicts are shared, so
    ``samples`` is updated as well. ``tokenizer``, ``cfg`` and ``fold_num``
    are unused.
    """
    annotated = samples.copy()
    rank_keys = (
        ('pred_class', 'pred_prob'),
        ('second_class', 'second_prob'),
        ('third_class', 'third_prob'),
    )

    for i in tqdm(range(len(samples))):
        # Drop position 0 (special token) and sort one document at a time,
        # instead of holding full-size argsort/sort copies of the whole OOF array.
        doc = oof[i, 1:]
        order = np.argsort(doc, axis=-1)
        ranked = np.sort(doc, axis=-1)
        for rank, (class_key, prob_key) in enumerate(rank_keys, start=1):
            # .tolist() gives Python ints/floats rather than numpy (float16) scalars.
            annotated[i][class_key] = [id_target_map[c] for c in order[:, -rank].tolist()]
            annotated[i][prob_key] = ranked[:, -rank].tolist()

    return annotated

#sample1 = inference(oof, test_samples, tokenizer, cfg, fold_num=0)

In [None]:
import re

def word_list(text):
    """Split ``text`` into words, each keeping the whitespace that follows it.

    Word ``k`` of the result is word ``k`` of ``text.split()``. Leading
    whitespace is not attached to any word. The last word is kept even when
    the text does not end in whitespace. ``''.join(result)`` equals ``text``
    minus its leading whitespace.
    """
    # Raw string: '\s' in a plain literal is an invalid escape sequence.
    pieces = re.split(r'(\s+)', text)
    # pieces alternates word, separator, word, ... ; the first/last word may be ''.
    words = []
    for i in range(0, len(pieces), 2):
        word = pieces[i]
        if not word:
            continue
        separator = pieces[i + 1] if i + 1 < len(pieces) else ''
        words.append(word + separator)
    return words

def discourse_offset(text, lst, start, end):
    """Return the (first, last) character positions of words ``lst[start:end+1]`` in ``text``.

    ``lst`` is a word list as produced by ``word_list``: each word keeps its
    trailing whitespace, so the last position includes that whitespace.
    The search begins where word ``start`` can first occur, so an identical
    phrase earlier in the essay is not matched by mistake. If the phrase is
    not found, the first position is -1.
    """
    words = ''.join(lst[start:end + 1])
    # Characters covered by the words before `start`: the phrase cannot
    # begin earlier than this (it may begin later only by leading whitespace).
    search_from = sum(len(w) for w in lst[:start])
    begin = text.find(words, search_from)
    return (begin, begin + len(words) - 1)

def add_offset(sub):
    """Add character offsets and a per-essay segment counter to ``sub`` (in place).

    Adds or overwrites these columns:
      * ``text``: the essay text from ``txt_dict``.
      * ``offset``: the ``discourse_offset`` of the segment's first..last word.
      * ``start``/``end``: the first/last character position taken from ``offset``.
      * ``discourse_num``: the 0-based position of the row among rows with the
        same ``id``.
    Returns ``sub``.
    """
    sub['start'] = sub['predictionstring'].map(lambda x: int(x.split()[0]))
    sub['end'] = sub['predictionstring'].map(lambda x: int(x.split()[-1]))
    sub['text'] = sub['id'].map(txt_dict)

    # word_list() splits the whole essay; do it once per essay instead of
    # once per predicted segment.
    words_by_id = {}

    def _row_offset(row):
        words = words_by_id.get(row['id'])
        if words is None:
            words = word_list(row['text'])
            words_by_id[row['id']] = words
        return discourse_offset(row['text'], words, row['start'], row['end'])

    sub['offset'] = sub.apply(_row_offset, axis=1)

    sub['start'] = sub['offset'].map(lambda x: x[0])
    sub['end'] = sub['offset'].map(lambda x: x[1])
    sub['discourse_num'] = sub.groupby('id')['class'].cumcount()

    return sub

In [None]:
#import sys
#sys.path.append("../input/pandas-bj/")
#!pip install ../input/pandas-bj/pandas_bj-0.1.1-py3-none-any.whl

In [None]:
#oof = oof[oof['token_key'].isin(id_df_2['token_key'])]

In [None]:
#%%time

#oof['discourse_key'] = oof['token_key'].map(dict(zip(id_df_2['token_key'], id_df_2['discourse_key'])))

In [None]:
def create_oof(path):
    """Load one model's OOF prediction pickle and stack it in ``test_ids`` order.

    The pickle at *path* is loaded as a dict keyed by essay id. The per-essay
    arrays are stacked with ``np.vstack`` following the global ``test_ids``.

    - If *path* contains ``'xlarge'``, ids missing from the dict are filled
      with an all-NaN float16 array of shape
      ``(1, cfg.max_len_test_longformer, 15)`` so the stack stays aligned.
    - If *path* contains ``'108'``, the stacked result is cast to float16.

    Raises ``KeyError`` when another model's dict lacks an id in ``test_ids``.
    ``pickle.load`` executes arbitrary code, so *path* must be a trusted file.
    """
    # `with` closes the file handle (the original open() was never closed).
    with open(path, 'rb') as f:
        oof_dict = pickle.load(f)

    if 'xlarge' in path:
        # NOTE(review): presumably this model was only run on some folds, so
        # its missing essays are filled with NaN.
        nan_pred = np.full((1, cfg.max_len_test_longformer, 15), np.nan, dtype=np.float16)
        for id_ in test_ids:
            if id_ not in oof_dict:
                oof_dict[id_] = nan_pred

    oof = np.vstack([oof_dict[id_] for id_ in test_ids])

    # Substring match on the whole path, kept from the original selection rule.
    if '108' in path:
        oof = oof.astype(np.float16)
    return oof

def create_sub_and_id_df(oof):
    """Decode *oof* into discourse predictions and map tokens onto them.

    Runs ``inference`` and ``create_submission`` on *oof*, then adds character
    offsets with ``add_offset``. It then joins the global token table ``id_df``
    with the predicted spans: a token row is kept when its ``offset_start``
    lies inside ``[start, end]`` of a span of the same essay.

    Returns
    -------
    sub : DataFrame
        Predicted discourses with ``discourse_key = '<id>_<discourse_num>'``.
    id_df_2 : DataFrame
        Matched ``id_df`` rows with the span columns and ``discourse_key``.
    """
    sample1 = inference(oof, test_samples, tokenizer, cfg, fold_num=0)
    sub = create_submission(sample1)
    sub = add_offset(sub)

    spans = sub[['id', 'discourse_num', 'start', 'end']]

    # Merge in chunks to bound the size of the intermediate frame.
    # NOTE(review): the row count per chunk assumes id_df holds
    # cfg.max_len_test_longformer rows per essay — confirm with create_id_df.
    essays_per_chunk = 300
    chunk_rows = essays_per_chunk * cfg.max_len_test_longformer
    pieces = []
    for i in tqdm(range(len(test_ids) // essays_per_chunk + 1)):
        chunk = id_df.iloc[i * chunk_rows:(i + 1) * chunk_rows, :]
        merged = chunk.merge(spans, on=['id'], how='left').reset_index(drop=True)
        inside = (merged['offset_start'] >= merged['start']) & (merged['offset_start'] <= merged['end'])
        pieces.append(merged[inside])
    # One concat instead of repeated DataFrame.append, which is quadratic and
    # was removed in pandas 2.0. `pieces` is never empty: the range above
    # always has at least one step.
    id_df_2 = pd.concat(pieces)

    id_df_2['discourse_key'] = id_df_2['id'].astype(str) + '_' + id_df_2['discourse_num'].astype(str)
    sub['discourse_key'] = sub['id'].astype(str) + '_' + sub['discourse_num'].astype(str)

    return sub, id_df_2

In [None]:
def create_prob_agg_df(oof):
    """Aggregate token-level class probabilities per discourse.

    *oof* must have a ``discourse_key`` column and one column per class name.
    The class columns used are the keys of the global ``target_id_map`` except
    ``'PAD'``. For each class, the mean, max and min and the 0.2 / 0.8
    quantiles are computed per ``discourse_key``.

    Returns a float16 DataFrame indexed by ``discourse_key`` with flat
    columns named ``'<class>_<stat>'``.
    """
    value_list = [k for k in target_id_map.keys() if k != 'PAD']

    dfg = oof.groupby('discourse_key')[value_list]

    # NumPy callables (not 'max'/'min' strings) are kept deliberately. The stat
    # labels are derived from the callables' names (e.g. 'amax' on older NumPy)
    # and end up in the feature names, so switching could rename features.
    dfg_stats = dfg.agg([np.mean, np.max, np.min]).stack()
    dfg_quantiles = dfg.quantile([0.2, 0.8])

    # pd.concat is the exact replacement for DataFrame.append (removed in
    # pandas 2.0).
    dfg_stats = pd.concat([dfg_stats, dfg_quantiles]).sort_index()
    prob_agg_df = dfg_stats.unstack().astype(np.float16)

    prob_agg_df.columns = [col[0] + '_' + str(col[1]) for col in prob_agg_df.columns.values]

    print(dfg.head())

    return prob_agg_df

#oof[:480000].groupby('discourse_key').agg({k: ['mean', 'min', 'max', q20, q80] for k in list(target_id_map.keys())[:1] if k != 'PAD'})

In [None]:
#%%time
#prob_agg_df = create_prob_agg_df(oof)

In [None]:
def create_feats(path, id_df_2):
    """Build per-discourse probability features from one model's OOF file.

    The model's token predictions (``create_oof(path)``) are flattened to one
    row per token and labelled with class names via ``id_target_map``. Each row
    gets its ``token_key`` from the global ``id_df``. Only tokens present in
    *id_df_2* are kept, and each is tagged with its ``discourse_key``.
    The result of ``create_prob_agg_df`` is returned with every column
    prefixed by the model name derived from the file name.
    """
    model_name = path.split('/')[-1].replace(".pickle", "").replace("oof_dict_", "")

    preds = create_oof(path)
    n_essays, n_tokens, n_classes = preds.shape
    token_probs = pd.DataFrame(preds.reshape(n_essays * n_tokens, n_classes))
    token_probs.columns = token_probs.columns.map(id_target_map)
    token_probs['token_key'] = id_df['token_key']

    keep = token_probs['token_key'].isin(id_df_2['token_key'])
    token_probs = token_probs[keep]
    token_to_discourse = dict(zip(id_df_2['token_key'], id_df_2['discourse_key']))
    token_probs['discourse_key'] = token_probs['token_key'].map(token_to_discourse)

    feats = create_prob_agg_df(token_probs)
    feats.columns = ['_'.join((model_name, col)) for col in feats.columns]
    return feats

In [None]:
def create_sub_and_feats(path, model_paths):
    """Decode one model's predictions and attach features from several models.

    The discourses are produced from the OOF file at *path*. For every file in
    *model_paths*, that model's per-discourse aggregate features are
    left-merged onto them by ``discourse_key``. A ``text_len`` column (length
    of the essay text) is added before returning.
    """
    sub, id_df_2 = create_sub_and_id_df(create_oof(path))

    for feats in (create_feats(p, id_df_2) for p in model_paths):
        sub = sub.merge(feats.reset_index(), on='discourse_key', how='left')
    # Release the token/discourse map before the final column is built.
    del id_df_2

    sub['text_len'] = sub['text'].map(len)
    return sub

In [None]:
def create_ensemble_sub_and_feats(model_paths, weight):
    """Decode the weighted ensemble of several models and attach their features.

    The OOF arrays of *model_paths* are averaged with the matching *weight*
    entries, normalized by their sum. The result is decoded into discourses
    with ``create_sub_and_id_df``. Each model's aggregate features are then
    left-merged by ``discourse_key``, and ``text_len`` is added.

    Raises
    ------
    ValueError
        If *model_paths* and *weight* differ in length (``zip`` would
        otherwise silently drop models while still dividing by the full
        weight sum), or if the weights sum to zero.
    """
    if len(model_paths) != len(weight):
        raise ValueError(
            f'model_paths and weight must have the same length '
            f'({len(model_paths)} != {len(weight)})')
    total_weight = np.sum(weight)
    if total_weight == 0:
        raise ValueError('weight must not sum to zero')

    # Weighted average in a float64 accumulator.
    # NOTE(review): NaN-filled essays from a partial model (see create_oof)
    # propagate into the average for those essays.
    oof = np.zeros([len(test_ids), cfg.max_len_test_longformer, 15])
    for path, w in zip(model_paths, weight):
        oof += create_oof(path) * w / total_weight

    sub, id_df_2 = create_sub_and_id_df(oof)

    for model_path in model_paths:
        feats = create_feats(model_path, id_df_2)
        sub = sub.merge(feats.reset_index(), on='discourse_key', how='left')
    del id_df_2

    sub['text_len'] = sub['text'].map(len)

    return sub

In [None]:
def post_process_sub(sub):
    """Run the rule-based post-processing chain on a submission frame.

    The steps are applied in order:

    1. ``thresh_prob`` and ``thresh_min_token``, both configured by ``cfg``.
    2. ``get_max_prob``.
    3. ``link_class`` for Evidence, Counterclaim and Rebuttal, with the
       settings from ``cfg.link``.

    The result is returned with a fresh index, restricted to the columns
    ``id, class, predictionstring, prob, len``.
    """
    for step in (thresh_prob, thresh_min_token):
        sub = step(sub, cfg)
    sub = get_max_prob(sub)
    for cls in ('Evidence', 'Counterclaim', 'Rebuttal'):
        sub = link_class(sub, cls, cfg.link[cls])
    return sub.reset_index(drop=True)[['id', 'class', 'predictionstring', 'prob', 'len']]

In [None]:
# NOTE(review): `paths` and `scores` are not referenced in the cells shown
# here — presumably leftovers; confirm before removing.
paths = []
scores = []
# OOF prediction pickles of the models to ensemble; commented-out entries are
# excluded from the current run.
paths_ = [
    #'../input/oof-pred-108/deberta-large-108.pkl',
    '../input/fb-oof-dict/oof_dict_yyama_exp045.pickle',
 #'../input/fb-oof-dict/oof_dict_makotu_021410-lf-bie-bin.pickle',
 '../input/fb-oof-dict/oof_dict_makabe_feedback-101_LSTM_2head.pickle',
 #'../input/fb-oof-dict/oof_dict_makotu_030209-deberta-large-mnli-boe-bin.pickle',
 '../input/fb-oof-dict/oof_dict_022709-deberta-large-boe-bin.pickle',
 '../input/fb-oof-dict/oof_dict_makabe_feedback-096.pickle',
 '../input/convert-fp64-to-fp16-inference98/oof_dict_makabe_feedback-098_LSTM.pickle',
 '../input/fb-oof-dict/oof_dict_yyama_exp058-wo-mlm.pickle',
'../input/fb-oof-dict/oof_dict_yyama_fb-exp037-led.pickle',
'../input/fb-oof-dict-exp062-xlarge-f24/oof_dict_yyama_fb_exp06x_xlarge_fold2-4.pickle']

# Ensemble weights. create_ensemble_sub_and_feats pairs them with `paths_` by
# position (zip), so both lists must have the same length and order.
weight = [1.1, 2.6, 1.3, 2.1, 2.7, 1, 1.6, 1.4]

In [None]:
# Recorded evaluation results kept as notes; the bare tuples are no-op
# expressions. Format presumably matches get_score(..., return_class_scores=True):
# (overall score, per-class scores).
# exp045 1.8  (NOTE(review): presumably the weight tried for exp045 — confirm)
(0.7158, {'Claim': 0.6847, 'Concluding Statement': 0.8736, 'Counterclaim': 0.5908, 'Evidence': 0.7785, 'Lead': 0.844, 'Position': 0.7418, 'Rebuttal': 0.4974})

# exp045 1.1
(0.7162, {'Claim': 0.6849, 'Concluding Statement': 0.8734, 'Counterclaim': 0.5902, 'Evidence': 0.7788, 'Lead': 0.8441, 'Position': 0.7419, 'Rebuttal': 0.5})

In [None]:
%%time
# Write one feature file per base model: feats_<model_name>.pkl.
# NOTE(review): uses `test_samples` / `test_ids`, which are built in the cell
# below; this cell relies on an earlier execution of that cell.
id_df = create_id_df(test_samples)
for path in paths_:
    # File name without the 'oof_dict_' prefix and '.pickle' suffix.
    model_name = path.split('/')[-1].replace(".pickle", "").replace("oof_dict_", "")
    print(model_name)
    # Discourses decoded from this model, with aggregate features from every model in paths_.
    sub = create_sub_and_feats(path, paths_)
    sub['model_name'] = model_name
    sub['fold'] = sub['id'].map(fold_dict)
    sub.to_pickle(f'feats_{model_name}.pkl')
    
    #post_sub = post_process_sub(sub[sub.fold.isin([2,3])][['id', 'class', 'predictionstring', 'prob', 'len']])
    #print(get_score(post_sub, test_df, return_class_scores=True))
    
    # Free the frame before the next model is processed.
    del sub
#    del sub

In [None]:
%%time

# Despite its name, `test_df` is built from train.csv: the OOF predictions are
# scored against the training labels.
test_df = pd.read_csv('../input/feedback-prize-2021/train.csv')
test_ids = test_df['id'].unique()
# NOTE(review): this essay is excluded explicitly; the reason is not recorded here.
test_ids = test_ids[~(test_ids=='AD005493F9BF')]
test_df = test_df[test_df['id'].isin(test_ids)]
test_df['fold'] = test_df['id'].map(fold_dict)
# Keep folds 2-4 only (presumably the folds every OOF file covers — confirm).
test_df = test_df[test_df.fold.isin([2,3,4])]

#test_df = test_df[test_df['id'].isin(test_df['id'].unique()[:10])]

test_ids = test_df['id'].unique()

# Rebuild the tokenized samples and the token table for the selected essays;
# create_oof / create_sub_and_id_df read these globals.
tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer)
test_samples = prepare_test_data(test_ids, tokenizer, cfg.num_jobs)
id_df = create_id_df(test_samples)

# Weighted ensemble of all models in paths_, with per-model features attached.
sub = create_ensemble_sub_and_feats(paths_, weight)
sub['fold'] = sub['id'].map(fold_dict)
sub.to_pickle(f'feats_ensemble_0314.pkl')

# NOTE(review): sub presumably only contains essays from test_ids (already
# folds 2-4), so this fold filter is likely a no-op.
post_sub = post_process_sub(sub[sub.fold.isin([2,3,4])][['id', 'class', 'predictionstring', 'prob', 'len']])
print(get_score(post_sub, test_df, return_class_scores=True))

In [None]:
# Display the ensemble feature frame in the notebook.
sub

In [None]:
# Earlier weight-search results kept as notes (the bare tuples are no-op
# expressions). These 7-element weight lists predate the current 8-model `paths_`.
# NOTE(review): the last two entries list the same weights but different
# scores; some other setting presumably changed between them.
#weight = [0.5, 1, 1, 1, 1, 1, 1.5]
(0.714, {'Claim': 0.6843, 'Concluding Statement': 0.8727, 'Counterclaim': 0.5866, 'Evidence': 0.7808, 'Lead': 0.8431, 'Position': 0.7423, 'Rebuttal': 0.4881})

#weight = [0.5, 1, 1, 1, 1, 1, 0]
(0.7136, {'Claim': 0.685, 'Concluding Statement': 0.8722, 'Counterclaim': 0.5858, 'Evidence': 0.7806, 'Lead': 0.8438, 'Position': 0.7385, 'Rebuttal': 0.4891})

#weight = [0.5, 1, 1, 1, 1, 1, 0.75]
(0.7141, {'Claim': 0.6849, 'Concluding Statement': 0.8728, 'Counterclaim': 0.5866, 'Evidence': 0.7808, 'Lead': 0.8441, 'Position': 0.7418, 'Rebuttal': 0.4879})

#weight = [0.5, 1, 1, 1, 1, 1, 0.75]
(0.7135, {'Claim': 0.6844, 'Concluding Statement': 0.8714, 'Counterclaim': 0.5856, 'Evidence': 0.7806, 'Lead': 0.8429, 'Position': 0.7413, 'Rebuttal': 0.488})
