# Beam search inference

I did not have enough computational resources to train new models, so I only improved the inference step of https://www.kaggle.com/abhishek/two-longformers-are-better-than-1.  
I defined a scoring function over candidate segmentations of each essay and used beam search to find the prediction that maximizes it.

In [None]:
import gc
gc.enable()

import sys
sys.path.append("../input/tez-lib/")

import os
import bisect
import copy
import time

from scipy.stats import norm
import numpy as np
import pandas as pd
import tez
import torch
import torch.nn as nn
from joblib import Parallel, delayed
from transformers import AutoConfig, AutoModel, AutoTokenizer

In [None]:
# Token-tag vocabulary. The seven discourse classes get consecutive id pairs in
# the order below: the even id is the "B-" (begin) tag and the following odd id
# is the matching "I-" (inside) tag. "O" (outside any span) is 14 and "PAD" is
# -100 (presumably the loss ignore index used at training time).
target_id_map = {
    f"{prefix}-{label}": 2 * class_idx + is_inside
    for class_idx, label in enumerate(
        ("Lead", "Position", "Evidence", "Claim", "Concluding Statement", "Counterclaim", "Rebuttal")
    )
    for is_inside, prefix in enumerate(("B", "I"))
}
target_id_map["O"] = 14
target_id_map["PAD"] = -100


# Reverse lookup: tag id -> tag string.
id_target_map = dict((tag_id, tag) for tag, tag_id in target_id_map.items())

class args1:
    """Settings for the first checkpoint family (ensemble folds 0-4).

    Used as a plain namespace in this notebook: the class is never
    instantiated, its attributes are read directly (e.g. ``args1.model``).
    """

    # Data and weights
    input_path = "../input/feedback-prize-2021/"
    model = "../input/longformerlarge4096/longformer-large-4096/"  # backbone config + tokenizer
    tez_model = "../input/fblongformerlarge1536/"  # directory holding model_<k>.bin
    output = "."

    # Inference
    batch_size = 8
    max_len = 4096
    
class args2:
    """Settings for the second checkpoint family (ensemble folds 5-9).

    Its checkpoints are stored as ``model_0.bin`` .. ``model_4.bin`` and are
    loaded for folds 5-9. Used as a plain namespace, never instantiated.
    """

    # Data and weights
    input_path = "../input/feedback-prize-2021/"
    model = "../input/longformerlarge4096/longformer-large-4096/"  # backbone config
    tez_model = "../input/tez-fb-large/"  # directory holding model_<k>.bin
    output = "."

    # Inference
    batch_size = 8
    max_len = 4096

In [None]:
class FeedbackDataset:
    """Map-style dataset that frames pre-tokenized essays for the model.

    Each sample is a dict with an ``input_ids`` list (tokenized without special
    tokens). An item is ``{"ids": [...], "mask": [...]}``: the tokenizer's CLS id
    is prepended, the sequence is clipped to ``max_len - 1`` ids, and the SEP id
    is appended, so for positive ``max_len`` the result has at most ``max_len``
    ids. The mask is all ones; padding is added later by the collate function.
    """

    def __init__(self, samples, max_len, tokenizer):
        self.samples = samples
        self.max_len = max_len
        self.tokenizer = tokenizer
        self.length = len(samples)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        tok = self.tokenizer
        # Slicing builds a new list, so the stored sample is never mutated.
        framed = ([tok.cls_token_id] + self.samples[idx]["input_ids"])[: self.max_len - 1]
        framed.append(tok.sep_token_id)
        return {"ids": framed, "mask": [1] * len(framed)}

In [None]:
class Collate:
    """Collate function that pads a batch of dataset items to a common length.

    ``ids`` are padded with the tokenizer's ``pad_token_id`` and ``mask`` with
    0, on the side given by ``tokenizer.padding_side`` (right when it equals
    "right", otherwise left). Both are returned as ``torch.long`` tensors.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        id_rows = [item["ids"] for item in batch]
        mask_rows = [item["mask"] for item in batch]
        width = max(len(row) for row in id_rows)  # longest sequence in this batch
        pad_right = self.tokenizer.padding_side == "right"
        pad_id = self.tokenizer.pad_token_id

        def pad(row, fill):
            filler = [fill] * (width - len(row))
            return row + filler if pad_right else filler + row

        return {
            "ids": torch.tensor([pad(row, pad_id) for row in id_rows], dtype=torch.long),
            "mask": torch.tensor([pad(row, 0) for row in mask_rows], dtype=torch.long),
        }

In [None]:
class FeedbackModel(tez.Model):
    """Transformer backbone plus a linear token-classification head.

    The backbone is created with ``AutoModel.from_config``, so no pretrained
    weights are loaded here; the notebook loads them afterwards via
    ``model.load(...)``. ``forward`` returns per-token softmax probabilities
    over ``num_labels`` classes as the first element of a ``(probs, 0, {})``
    triple (presumably the ``(output, loss, metrics)`` shape tez expects —
    confirm against the tez version in use).
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        self.model_name = model_name
        self.num_labels = num_labels

        config = AutoConfig.from_pretrained(model_name)
        # NOTE(review): these overrides presumably mirror the training-time
        # config. "add_pooling_layer" is normally a model-constructor argument
        # rather than a config field, so it may have no effect here — leave it
        # as is unless the saved checkpoints are verified to still load.
        overrides = {
            "output_hidden_states": True,
            "hidden_dropout_prob": 0.18,
            "layer_norm_eps": 17589e-7,
            "add_pooling_layer": False,
        }
        config.update(overrides)

        self.transformer = AutoModel.from_config(config)
        self.output = nn.Linear(config.hidden_size, self.num_labels)

    def forward(self, ids, mask):
        hidden = self.transformer(ids, mask).last_hidden_state
        probs = torch.softmax(self.output(hidden), dim=-1)
        return probs, 0, {}

In [None]:
def _prepare_test_data_helper(args, tokenizer, ids):
    test_samples = []
    for idx in ids:
        filename = os.path.join(args.input_path, "test", idx + ".txt")
        with open(filename, "r") as f:
            text = f.read()

        encoded_text = tokenizer.encode_plus(
            text,
            add_special_tokens=False,
            return_offsets_mapping=True,
        )
        input_ids = encoded_text["input_ids"]
        offset_mapping = encoded_text["offset_mapping"]

        sample = {
            "id": idx,
            "input_ids": input_ids,
            "text": text,
            "offset_mapping": offset_mapping,
        }

        test_samples.append(sample)
    return test_samples


def prepare_test_data(df, tokenizer, args, n_jobs=4):
    """Tokenize every essay referenced by ``df`` using parallel worker processes.

    The unique ids of ``df["id"]`` are split into ``n_jobs`` contiguous chunks,
    and each chunk is handled by ``_prepare_test_data_helper`` in its own
    process.

    Args:
        df: DataFrame with an ``id`` column.
        tokenizer: tokenizer passed through to the helper (it must be
            picklable for the multiprocessing backend).
        args: config namespace passed through to the helper.
        n_jobs: number of chunks and worker processes; must be positive. It
            defaults to the previously hard-coded value of 4.

    Returns:
        A flat list of sample dicts in the original id order.
    """
    ids = df["id"].unique()
    # One chunk per worker; chunks may be empty when there are fewer ids than workers.
    id_chunks = np.array_split(ids, n_jobs)

    results = Parallel(n_jobs=n_jobs, backend="multiprocessing")(
        delayed(_prepare_test_data_helper)(args, tokenizer, chunk) for chunk in id_chunks
    )
    # Parallel returns results in submission order, so flattening keeps id order.
    return [sample for chunk_samples in results for sample in chunk_samples]

In [None]:
# Ensemble 10 checkpoints over the test essays. Folds 0-4 use args1's weights
# (model_0..model_4) and folds 5-9 use args2's weights (also stored as
# model_0..model_4). Each fold's per-batch probabilities are scaled by 1/n_fold
# and summed into raw_preds, which ends up holding the fold-average as one
# float16 array per batch.
df = pd.read_csv(os.path.join("../input/feedback-prize-2021/", "sample_submission.csv"))
df_ids = df["id"].unique()

tokenizer = AutoTokenizer.from_pretrained(args1.model)
test_samples = prepare_test_data(df, tokenizer, args1)
collate = Collate(tokenizer=tokenizer)

# The dataset does not depend on the fold, so build it once.
test_dataset = FeedbackDataset(test_samples, args1.max_len, tokenizer)

raw_preds = []
n_fold = 10
for fold_ in range(n_fold):
    fold_args, model_idx = (args1, fold_) if fold_ < 5 else (args2, fold_ - 5)

    model = FeedbackModel(model_name=fold_args.model, num_labels=len(target_id_map) - 1)
    model.load(os.path.join(fold_args.tez_model, f"model_{model_idx}.bin"), weights_only=True)
    preds_iter = model.predict(test_dataset, batch_size=fold_args.batch_size, n_jobs=-1, collate_fn=collate)

    # The accumulation assumes model.predict yields batches in dataset order
    # with identical padding every fold. Both args use batch_size 8 and the
    # same collate; presumably tez does not shuffle at predict time — confirm.
    for batch_idx, preds in enumerate(preds_iter):
        preds = preds.astype(np.float16) / n_fold
        if fold_ == 0:
            raw_preds.append(preds)
        else:
            raw_preds[batch_idx] += preds

    # Drop the references before cleaning up. Otherwise the model (and any
    # device memory it holds) is still alive when empty_cache runs, and it is
    # only released once the next fold's model replaces it.
    del model, preds_iter
    gc.collect()
    torch.cuda.empty_cache()

## Beam search

In [None]:
B_MIX_RATIO = 0.5  # weight applied to "B-" (begin) scores when they are mixed with "I-" scores
BEAM_WIDTH = 3  # number of candidate division lists kept after each beam-search step


def beam_inference(
    raw_preds,
    samples,
    transition_cost=0.08,
    transition_arr_path="../input/th-arr/transition_arr.csv",
):
    """Turn ensembled token probabilities into a submission using beam search.

    For every essay, token indices where a segment may start ("divisions") are
    derived from the per-token argmax, and a confident subset seeds the search.
    Beam search then adds one candidate division at a time, keeping the
    ``BEAM_WIDTH`` best lists per step. Each segmentation is scored with
    ``total_score``: the sum of mixed per-token log-scores for the implied tag
    sequence plus transition bonuses and penalties. The best segmentation seen
    is kept. Each of its segments is labelled by the class with the highest
    mean mixed log-score, and labelled spans become word-index prediction
    strings that are filtered by mean confidence and minimum word count.

    Args:
        raw_preds: list of per-batch arrays indexed ``[sample, token, label]``
            with label ids as in ``target_id_map`` (0-14). Token position 0
            (the CLS token) is dropped. Assumes the flattened rows of all
            batches are in the same order as ``samples``.
        samples: list of dicts with ``id``, ``text`` and ``offset_mapping``
            (character ``(start, end)`` per token). Each dict is updated in
            place with ``preds`` (tag strings) and ``pred_scores`` (per-token
            max probability).
        transition_cost: penalty subtracted for every tag change in
            ``total_score``.
        transition_arr_path: whitespace-delimited matrix read with
            ``np.loadtxt``. It is indexed by the ``target_id_map2`` ids (0-8)
            of the (previous class, next class) pair.

    Returns:
        DataFrame with columns ``id``, ``class`` and ``predictionstring``.

    Prints progress, timing totals and the mean best beam score.
    """
    target_id_map = {
        "B-Lead": 0,
        "I-Lead": 1,
        "B-Position": 2,
        "I-Position": 3,
        "B-Evidence": 4,
        "I-Evidence": 5,
        "B-Claim": 6,
        "I-Claim": 7,
        "B-Concluding Statement": 8,
        "I-Concluding Statement": 9,
        "B-Counterclaim": 10,
        "I-Counterclaim": 11,
        "B-Rebuttal": 12,
        "I-Rebuttal": 13,
        "O": 14,
        "PAD": -100,
    }
    # Row/column ids of transition_arr ("init" = before the first token).
    target_id_map2 = {
        "init": 0,
        "Lead": 1,
        "Position": 2,
        "Evidence": 3,
        "Claim": 4,
        "Concluding Statement": 5,
        "Counterclaim": 6,
        "Rebuttal": 7,
        "O": 8,
    }
    # Column order of the 8-column per-class slices below (7 classes, then "O").
    target_id_map_onlyclass = {
        "Lead": 0,
        "Position": 1,
        "Evidence": 2,
        "Claim": 3,
        "Concluding Statement": 4,
        "Counterclaim": 5,
        "Rebuttal": 6,
        "O": 7,
    }
    id_target_map = {v: k for k, v in target_id_map.items()}
    id_target_map_onlyclass = {v: k for k, v in target_id_map_onlyclass.items()}

    # Column indices into the 15-label axis.
    b_cols = [0, 2, 4, 6, 8, 10, 12]  # "B-" tags
    i_cols_and_o = [1, 3, 5, 7, 9, 11, 13, 14]  # "I-" tags, then "O"
    b_cols_and_o = [0, 2, 4, 6, 8, 10, 12, 14]  # "B-" tags, then "O"

    # Minimum mean token confidence for a predicted span to be kept.
    proba_thresh = {
        k: v * 0.85
        for k, v in {
            "Lead": 0.7,
            "Position": 0.55,
            "Evidence": 0.65,
            "Claim": 0.55,
            "Concluding Statement": 0.7,
            "Counterclaim": 0.5,
            "Rebuttal": 0.55,
        }.items()
    }

    # Per-class minimum span length in words: the full value is used as a soft
    # penalty in total_score; half of it is the hard filter on the submission.
    original_min_thresh = {
        "Lead": 9,
        "Position": 5,
        "Evidence": 14,
        "Claim": 3,
        "Concluding Statement": 11,
        "Counterclaim": 6,
        "Rebuttal": 4,
        "O": -100,
    }
    min_thresh = {k: v // 2 for k, v in original_min_thresh.items()}

    # Flatten the batches into one row per essay.
    final_scores = []
    final_raw_preds = []
    for rp in raw_preds:
        final_scores.extend(np.max(rp, axis=2).tolist())
        final_raw_preds.extend(rp)

    transition_arr = np.loadtxt(transition_arr_path)

    total_time1 = 0.0  # wall time spent in the beam-search loop
    total_time2 = 0.0  # part of it spent building tag sequences
    score_list = []
    for j in range(len(samples)):
        print(f"\r {j}/{len(samples)}", end="")
        raw_pred = final_raw_preds[j][1:]  # drop the CLS position
        if j == 0:
            print(raw_pred.shape)

        # --- candidate boundaries ------------------------------------------
        # division_cand_list: every token index where a segment may start.
        # init_division_list: the confident subset used to seed the search
        # (entering a strongly-predicted "O" run, and leaving it again).
        division_cand_list = [0]
        init_division_list = [0]
        pre = id_target_map[np.argmax(raw_pred[0])]
        in_O = False
        for k, rp_each in enumerate(raw_pred[1:], start=1):
            pred = id_target_map[np.argmax(rp_each)]
            rp_B_each = rp_each[b_cols]

            if pred == "O" and pre != "O":
                division_cand_list.append(k)
                if rp_each[target_id_map["O"]] >= 0.9:
                    init_division_list.append(k)
                    in_O = True
            elif pred != "O" and pre == "O":
                division_cand_list.append(k)
                if in_O:
                    init_division_list.append(k)
                    in_O = False
            elif np.max(rp_B_each) >= 0.1:
                division_cand_list.append(k)

            pre = pred

        division_cand_list.append(raw_pred.shape[0])
        init_division_list.append(raw_pred.shape[0])

        def total_score(raw_pred, id_pred, transition_cost=transition_cost):
            """Score a full tag-id sequence: summed per-token scores plus transition terms."""
            pre = -1
            pre_state = "init"
            transition_score = 0.0
            count = 0  # length of the run that ends at the next tag change
            for id_ in id_pred:
                count += 1
                if pre != id_:
                    transition_score -= transition_cost
                    state = id_target_map[id_]
                    if state != "O":
                        state = state[2:]
                    # A "B-" tag followed by an "I-" tag continues the span it
                    # just opened, so only other changes get the class-transition
                    # bonus and the length penalty.
                    if pre == -1 or not (id_target_map[pre].startswith("B-") and id_target_map[id_].startswith("I-")):
                        transition_score += 0.2 * transition_arr[target_id_map2[pre_state], target_id_map2[state]]
                        # NOTE(review): this compares the minimum length of the
                        # class being *entered* with the length of the run that
                        # just ended, and the last run is never penalised. It
                        # looks unintended, but the constants were tuned with
                        # it, so it is kept as is.
                        transition_score -= 0.2 * max(0, original_min_thresh[state] - count)
                    count = 0
                pre = id_
                pre_state = state

            return np.sum(raw_pred[np.arange(len(id_pred)), id_pred]) + transition_score

        # -log(1 + eps - p): grows as p -> 1. The 1e-3 keeps the argument
        # positive for p <= 1. "B-" columns are down-weighted by B_MIX_RATIO.
        log_raw_pred = -np.log(1.0 + 1e-3 - raw_pred)
        modified_log_raw_pred = np.copy(log_raw_pred)
        modified_log_raw_pred[:, b_cols] *= B_MIX_RATIO
        raw_pred_I = raw_pred[:, i_cols_and_o]
        log_raw_pred_I = log_raw_pred[:, i_cols_and_o]
        log_raw_pred_B = log_raw_pred[:, b_cols_and_o]

        # Precompute the per-class probability sums of every candidate segment.
        sum_dict = {}
        for i1, s in enumerate(division_cand_list[:-1]):
            for e in division_cand_list[i1 + 1:]:
                sum_dict[(s, e)] = np.sum(raw_pred_I[s:e], axis=0)

        # --- beam search over division lists --------------------------------
        division_list = init_division_list
        best_division_list = division_list
        max_score = 0.0
        division_beam = [(max_score, tuple(division_list))]
        while len(division_list) < len(division_cand_list):
            s_time = time.time()

            next_division_beam = []
            for _, division_tuple in division_beam:
                division_list = list(division_tuple)
                diff = set(division_cand_list) - set(division_list)
                for x in diff:
                    new_division_list = copy.copy(division_list)
                    index = bisect.bisect_left(new_division_list, x)
                    new_division_list.insert(index, x)
                    pred_list = []

                    time2 = time.time()
                    for s, e in zip(new_division_list[:-1], new_division_list[1:]):
                        pred_I_each = id_target_map_onlyclass[np.argmax(sum_dict[(s, e)])]
                        if pred_I_each == "O":
                            pred_list.append(np.array((e - s) * [target_id_map["O"]]))
                        else:
                            pred_list.append(np.array([target_id_map["B-" + pred_I_each]] + (e - s - 1) * [target_id_map["I-" + pred_I_each]]))
                    total_time2 += time.time() - time2

                    pred = np.concatenate(pred_list)

                    score = total_score(modified_log_raw_pred, pred)
                    next_division_beam.append((score, tuple(new_division_list)))
                    if score > max_score:
                        max_score = score
                        best_division_list = copy.copy(new_division_list)

            # The same list can be reached from several parents: deduplicate,
            # then keep the BEAM_WIDTH best-scoring lists.
            next_division_beam = list(set(next_division_beam))
            division_beam = sorted(next_division_beam, key=lambda x: x[0])[-BEAM_WIDTH:]

            total_time1 += time.time() - s_time
        division_list = best_division_list
        score_list.append(max_score)

        # --- label each segment of the best division list --------------------
        tt = []
        for seg_start, seg_end in zip(division_list[:-1], division_list[1:]):
            chunk_mean = np.mean(
                log_raw_pred_I[seg_start:seg_end] + B_MIX_RATIO * log_raw_pred_B[seg_start:seg_end],
                axis=0,
            )
            pred_I = id_target_map_onlyclass[np.argmax(chunk_mean)]

            if pred_I == "O":
                chunk = ["O"] * (seg_end - seg_start)
            else:
                chunk = ["B-" + pred_I] + ["I-" + pred_I] * (seg_end - seg_start - 1)
            tt.extend(chunk)

        samples[j]["preds"] = tt
        samples[j]["pred_scores"] = final_scores[j][1:]

    # --- tags -> word-index prediction strings --------------------------------
    submission = []
    for sample in samples:
        preds = sample["preds"]
        offset_mapping = sample["offset_mapping"]
        sample_id = sample["id"]
        sample_text = sample["text"]
        sample_pred_scores = sample["pred_scores"]
        n_words = len(sample_text.split())

        # Pad preds to the length of offset_mapping (e.g. when the model input
        # was truncated and fewer tokens were predicted).
        if len(preds) < len(offset_mapping):
            preds = preds + ["O"] * (len(offset_mapping) - len(preds))
            sample_pred_scores = sample_pred_scores + [0] * (len(offset_mapping) - len(sample_pred_scores))

        # Group tokens into phrases: a phrase starts at any token and extends
        # over the following "I-<label>" tokens ("O" tokens for an O phrase).
        idx = 0
        phrase_preds = []
        while idx < len(offset_mapping):
            # Take `end` from the phrase's own first token so that a
            # single-token phrase gets its real span, not a stale `end` left
            # over from an earlier phrase, sample or loop.
            start, end = offset_mapping[idx]
            label = "O" if preds[idx] == "O" else preds[idx][2:]
            matching_label = "O" if label == "O" else f"I-{label}"
            phrase_scores = [sample_pred_scores[idx]]
            idx += 1

            while idx < len(offset_mapping) and preds[idx] == matching_label:
                _, end = offset_mapping[idx]
                phrase_scores.append(sample_pred_scores[idx])
                idx += 1

            phrase_preds.append((sample_text[start:end], start, end, label, phrase_scores))

        temp_df = []
        for _phrase, start, end, label, phrase_scores in phrase_preds:
            if label == "O":
                continue
            if sum(phrase_scores) / len(phrase_scores) < proba_thresh[label]:
                continue
            word_start = len(sample_text[:start].split())
            word_end = min(word_start + len(sample_text[start:end].split()), n_words)
            ps = " ".join(str(x) for x in range(word_start, word_end))
            temp_df.append((sample_id, label, ps))

        submission.append(pd.DataFrame(temp_df, columns=["id", "class", "predictionstring"]))

    if submission:
        submission = pd.concat(submission).reset_index(drop=True)
    else:
        # pd.concat raises on an empty list; keep the output schema instead.
        submission = pd.DataFrame(columns=["id", "class", "predictionstring"])
    submission["len"] = submission.predictionstring.apply(lambda x: len(x.split()))

    def threshold(df):
        """Drop spans shorter than their class's minimum word count (min_thresh)."""
        df = df.copy()
        for key, value in min_thresh.items():
            index = df.loc[df["class"] == key].query(f"len<{value}").index
            df.drop(index, inplace=True)
        return df

    submission = threshold(submission)

    print("=" * 30)
    print(total_time1)
    print(total_time2)

    print(f"score average = {np.mean(np.array(score_list))}")

    # The helper "len" column is not part of the submission format.
    submission = submission.drop(columns=["len"])
    return submission

In [None]:
def jn(pst, start, end):
    """Return the items of ``pst[start:end]`` as one space-separated string."""
    return " ".join(map(str, pst[start:end]))


def link_evidence(oof, thresh=1, thresh2=26):
    """Merge nearby "Evidence" predictions of the same essay into longer spans.

    The Evidence rows of each essay are visited in their existing order. A row
    is appended to the current span when it starts at most ``thresh`` words
    after the previous row ends, and at most ``thresh2`` words after the
    current span's first word. Otherwise the current span is emitted and a new
    one starts. Rows with an empty predictionstring are skipped.

    Args:
        oof: DataFrame with ``id``, ``class`` and ``predictionstring`` columns
            (space-separated word indices).
        thresh: largest allowed gap, in words, between consecutive pieces.
        thresh2: largest allowed distance from the span's first word to the
            start of the next piece.

    Returns:
        The outer merge of the linked Evidence rows with all non-Evidence rows
        of ``oof``.
    """
    evidence = oof[oof["class"] == "Evidence"]
    others = oof[oof["class"] != "Evidence"]

    linked = []
    for essay_id, group in evidence.groupby("id", sort=False):
        span = []  # word indices of the span being built
        prev_last = None  # last word of the previously visited piece
        for pred_string in group["predictionstring"]:
            words = [int(x) for x in pred_string.split()]
            if not words:
                continue  # an empty piece has no position to link on
            if span and (words[0] > prev_last + thresh or words[0] - span[0] > thresh2):
                linked.append((essay_id, "Evidence", " ".join(map(str, span))))
                span = []
            span.extend(words)
            prev_last = words[-1]
        if span:
            linked.append((essay_id, "Evidence", " ".join(map(str, span))))

    roof = pd.DataFrame(linked, columns=["id", "class", "predictionstring"])
    return roof.merge(others, how="outer")
    

In [None]:
# Decode the fold-averaged probabilities, merge nearby Evidence spans, and
# write the competition submission file.
submission = link_evidence(beam_inference(raw_preds, test_samples))
submission.to_csv("submission.csv", index=False)

In [None]:
submission.head(n=5)  # preview the first rows of the submission