In [None]:
import gc
gc.enable()

import sys
sys.path.append("../input/tez-lib/")
sys.path.append("../input/pytorch-crf/")
import os
from tqdm import tqdm
from torchcrf import CRF
import numpy as np
import pandas as pd
import tez
import torch
import torch.nn as nn
# from torch.optim import Optimizer
from torch.optim import Adam
from joblib import Parallel, delayed
from torch.utils.data import DataLoader, Dataset
from transformers import AutoConfig, AutoModel, AutoTokenizer

In [None]:
# BIO tag -> class id. Each discourse type gets a consecutive pair of ids:
# "B-<type>" (first token of a span) = 2*k, "I-<type>" (continuation) = 2*k + 1.
# "O" marks tokens outside any span; "PAD" (-100) is a padding/ignore id.
target_id_map = {
    f"{prefix}-{kind}": 2 * position + offset
    for position, kind in enumerate(
        ("Lead", "Position", "Evidence", "Claim", "Concluding Statement", "Counterclaim", "Rebuttal")
    )
    for offset, prefix in enumerate(("B", "I"))
}
target_id_map["O"] = 14
target_id_map["PAD"] = -100


# Inverse lookup: class id -> BIO tag string.
id_target_map = dict(zip(target_id_map.values(), target_id_map.keys()))

class args1:
    """Paths and inference settings for the fblongformerlarge1536 checkpoints."""

    # Data and backbone locations.
    input_path = "../input/feedback-prize-2021/"
    model = "../input/longformerlarge4096/longformer-large-4096/"
    # Directory holding the fine-tuned tez weight files.
    tez_model = "../input/fblongformerlarge1536/"
    output = "."
    # Inference settings.
    max_len = 2048
    batch_size = 8


class args2:
    """Same settings as args1, but pointing at the tez-fb-large checkpoints."""

    # Data and backbone locations.
    input_path = "../input/feedback-prize-2021/"
    model = "../input/longformerlarge4096/longformer-large-4096/"
    # Directory holding the fine-tuned tez weight files.
    tez_model = "../input/tez-fb-large/"
    output = "."
    # Inference settings.
    max_len = 2048
    batch_size = 8

In [None]:
class FeedbackDataset:
    """Inference dataset over pre-tokenized essays.

    Each item wraps the sample's ``input_ids`` as ``[CLS] + tokens + [SEP]``,
    clipping the tokens so the result is at most ``max_len`` long. Items are
    unpadded Python lists; padding is left to the collate function.

    Args:
        samples: sequence of dicts, each with an ``input_ids`` list.
        max_len: maximum length including the [CLS] and [SEP] tokens.
        tokenizer: provides ``cls_token_id`` and ``sep_token_id``.
    """

    def __init__(self, samples, max_len, tokenizer):
        self.samples = samples
        self.max_len = max_len
        self.tokenizer = tokenizer
        self.length = len(samples)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        tok = self.tokenizer
        # Prepend [CLS], then clip to max_len - 1 so that appending [SEP]
        # keeps the total within max_len. Slicing copies, so the stored
        # sample is never mutated by the append below.
        ids = ([tok.cls_token_id] + self.samples[idx]["input_ids"])[: self.max_len - 1]
        ids.append(tok.sep_token_id)
        return {"ids": ids, "mask": [1] * len(ids)}

In [None]:
class Collate:
    """Pad a batch of FeedbackDataset items into rectangular LongTensors.

    Args:
        tokenizer: provides ``padding_side`` ("right" pads after the tokens;
            any other value pads before them) and ``pad_token_id``.
        min_len: minimum padded length. Defaults to 1860, the fixed length this
            collator has always padded to. A batch containing a longer item is
            padded to that item's length. Without this, rows would be ragged and
            ``torch.tensor`` would raise.

    Returns (from ``__call__``):
        dict with ``ids`` and ``mask`` LongTensors of shape
        (len(batch), max(min_len, longest item)).
    """

    def __init__(self, tokenizer, min_len=1860):
        self.tokenizer = tokenizer
        self.min_len = min_len

    def __call__(self, batch):
        output = dict()
        output["ids"] = [sample["ids"] for sample in batch]
        output["mask"] = [sample["mask"] for sample in batch]

        # Pad to the longest item of this batch, but never below min_len.
        batch_max = max([self.min_len] + [len(ids) for ids in output["ids"]])
        pad_id = self.tokenizer.pad_token_id

        if self.tokenizer.padding_side == "right":
            output["ids"] = [s + (batch_max - len(s)) * [pad_id] for s in output["ids"]]
            output["mask"] = [s + (batch_max - len(s)) * [0] for s in output["mask"]]
        else:
            output["ids"] = [(batch_max - len(s)) * [pad_id] + s for s in output["ids"]]
            output["mask"] = [(batch_max - len(s)) * [0] + s for s in output["mask"]]

        output["ids"] = torch.tensor(output["ids"], dtype=torch.long)
        output["mask"] = torch.tensor(output["mask"], dtype=torch.long)
        return output

In [None]:
class FeedbackModel(tez.Model):
    """Transformer backbone plus a linear per-token classification head.

    The backbone is built from the config at ``model_name`` only. No pretrained
    weights are loaded here; callers restore them with ``load``.

    Args:
        model_name: path or hub name whose config defines the backbone.
        num_labels: number of output classes per token.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        self.model_name = model_name
        self.num_labels = num_labels

        config = AutoConfig.from_pretrained(model_name)
        config.update(
            dict(
                output_hidden_states=True,
                hidden_dropout_prob=0.1,
                layer_norm_eps=1e-7,
                # NOTE(review): add_pooling_layer is normally a model-constructor
                # argument rather than a config field, so this entry likely has
                # no effect on from_config — confirm.
                add_pooling_layer=False,
            )
        )
        self.transformer = AutoModel.from_config(config)
        self.output = nn.Linear(config.hidden_size, self.num_labels)

    def get_sequence_output(self, ids, mask):
        """Return raw (un-normalised) per-token logits from the classification head."""
        last_hidden = self.transformer(ids, mask).last_hidden_state
        return self.output(last_hidden)

    def forward(self, ids, mask):
        """Return per-token softmax probabilities as a tez-style (output, loss, metrics) triple."""
        probabilities = torch.softmax(self.get_sequence_output(ids, mask), dim=-1)
        return probabilities, 0, {}
    
class CRFModel(tez.Model):
    """Linear-chain CRF stacked on a FeedbackModel's per-token logits.

    Args:
        longformer_model: module exposing ``get_sequence_output(ids, mask)``,
            whose emission scores are laid out (batch, seq, num_labels), the
            usual HF ``last_hidden_state`` layout followed by a Linear head.
        num_labels: number of CRF tags.
    """

    def __init__(self, longformer_model, num_labels):
        super().__init__()
        self.longformer_model = longformer_model
        # The emissions are batch-major. torchcrf defaults to batch_first=False,
        # i.e. (seq, batch, tags). With that default the CRF would read the batch
        # axis as time and each token position as a separate sequence, so the
        # transition scores would never see real neighbouring tokens.
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, ids, mask, label=None):
        """Run the backbone and the CRF.

        Training (``label`` given): return
        ``(log_likelihood, -log_likelihood, {})``. The log-likelihood is summed
        over the batch (torchcrf's default reduction).

        Inference (``label`` is None): return the Viterbi label ids twice, as a
        LongTensor laid out (seq, batch). This is the layout this model has
        always returned, and the prediction post-processing transposes it back
        to (batch, seq).
        """
        emissions = self.longformer_model.get_sequence_output(ids, mask)
        if label is not None:
            log_likelihood = self.crf(emissions, label)
            return log_likelihood, -log_likelihood, {}
        best_paths = torch.tensor(self.crf.decode(emissions), dtype=torch.long)  # (batch, seq)
        best_paths = best_paths.t().contiguous()  # (seq, batch)
        return best_paths, best_paths.clone(), {}

In [None]:
def _prepare_train_data_helper(args, tokenizer, df, ids):
    """Tokenize training essays and build token-level BIO label ids.

    Each essay ``<args.input_path>/train/<id>.txt`` (trailing whitespace
    stripped) is split at the ``discourse_start``/``discourse_end`` character
    offsets of its rows in ``df``:

    - Text between annotations is labelled "O".
    - Each annotated span gets "B-<type>" on its first token and "I-<type>" on
      the rest.

    Every fragment is tokenized separately, without special tokens.

    Args:
        args: config object providing ``input_path``.
        tokenizer: tokenizer exposing ``encode_plus``.
        df: annotation DataFrame with ``id``, ``discourse_start``,
            ``discourse_end`` and ``discourse_type`` columns.
        ids: iterable of essay ids to process.

    Returns:
        list of dicts with keys ``id``, ``input_ids``, ``text`` and ``label``,
        where ``label`` is aligned one-to-one with ``input_ids``.

    Raises:
        KeyError: if a ``discourse_type`` has no B-/I- entry in ``target_id_map``.
            Previously this silently produced ``None`` labels.
    """
    label_o = target_id_map["O"]

    def encode(fragment):
        # Fragments are encoded independently, so no [CLS]/[SEP] here.
        return tokenizer.encode_plus(fragment, add_special_tokens=False)["input_ids"]

    train_samples = []
    for idx in ids:
        filename = os.path.join(args.input_path, "train", idx + ".txt")
        with open(filename, "r") as f:
            text = f.read()
        text = text.rstrip()
        idx_info = df[df["id"] == idx]
        former_end = 0
        train_X = []
        train_Y = []
        for _, row in idx_info.iterrows():
            start = int(row["discourse_start"])
            end = int(row["discourse_end"])
            discourse_type = row["discourse_type"]

            # Unannotated gap before this span.
            gap_ids = encode(text[former_end:start])
            train_X += gap_ids
            train_Y += [label_o] * len(gap_ids)

            # The annotated span itself. A span that tokenizes to nothing
            # (e.g. empty, or lying past the stripped text) must not emit a
            # "B-" label, or ids and labels fall out of step.
            span_ids = encode(text[start:end])
            if span_ids:
                train_X += span_ids
                train_Y.append(target_id_map["B-" + discourse_type])
                train_Y += [target_id_map["I-" + discourse_type]] * (len(span_ids) - 1)
            former_end = end

        # Trailing unannotated text after the last span.
        tail_ids = encode(text[former_end:])
        train_X += tail_ids
        train_Y += [label_o] * len(tail_ids)

        assert len(train_X) == len(train_Y)
        train_samples.append(
            {
                "id": idx,
                "input_ids": train_X,
                "text": text,
                "label": train_Y,
            }
        )
    return train_samples


def prepare_train_data(df, tokenizer, args):
    """Build training samples for every essay id in ``df`` using 4 worker processes.

    The unique ids are split into 4 chunks. Each chunk is handled by
    ``_prepare_train_data_helper``, and the per-chunk results are concatenated
    in chunk order.
    """
    chunks = np.array_split(df["id"].unique(), 4)
    per_chunk = Parallel(n_jobs=4, backend="multiprocessing")(
        delayed(_prepare_train_data_helper)(args, tokenizer, df, chunk) for chunk in chunks
    )
    return [sample for chunk_samples in per_chunk for sample in chunk_samples]

In [None]:
def _prepare_test_data_helper(args, tokenizer, ids):
    test_samples = []
    for idx in ids:
        filename = os.path.join(args.input_path, "test", idx + ".txt")
        with open(filename, "r") as f:
            text = f.read()

        encoded_text = tokenizer.encode_plus(
            text,
            add_special_tokens=False,
            return_offsets_mapping=True,
        )
        input_ids = encoded_text["input_ids"]
        offset_mapping = encoded_text["offset_mapping"]
        sample = {
            "id": idx,
            "input_ids": input_ids,
            "text": text,
            "offset_mapping": offset_mapping,
        }

        test_samples.append(sample)
    return test_samples


def prepare_test_data(df, tokenizer, args):
    """Build inference samples for every essay id in ``df`` using 4 worker processes.

    The unique ids are split into 4 chunks. Each chunk is handled by
    ``_prepare_test_data_helper``, and the per-chunk results are concatenated
    in chunk order.
    """
    chunks = np.array_split(df["id"].unique(), 4)
    per_chunk = Parallel(n_jobs=4, backend="multiprocessing")(
        delayed(_prepare_test_data_helper)(args, tokenizer, chunk) for chunk in chunks
    )
    return [sample for chunk_samples in per_chunk for sample in chunk_samples]

In [None]:
class TrainDataSet(Dataset):
    """Fixed-length training dataset for the CRF head.

    Each item is ``(input_ids, mask, label)``: three LongTensors of length
    ``max_len``.

    - ``input_ids`` is ``[CLS] + tokens + [SEP]`` followed by padding.
    - ``mask`` is 1 on real tokens and 0 on padding.
    - ``label`` is "O" on [CLS], [SEP] and padding.

    Args:
        train_samples: list of dicts with equally long ``input_ids`` and ``label`` lists.
        tokenizer: provides ``cls_token_id``, ``sep_token_id`` and ``pad_token_id``.
        max_len: total sequence length after truncation and padding.
        device: device the returned tensors are moved to (default "cuda",
            matching the previous hard-coded ``.cuda()``).
    """

    def __init__(self, train_samples, tokenizer, max_len=1860, device="cuda"):
        self.train_samples = train_samples
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.device = device

    def __getitem__(self, index):
        data = self.train_samples[index]
        label_o = target_id_map.get('O')

        # Reserve two positions for [CLS] and [SEP] so that [SEP] survives
        # truncation, consistent with FeedbackDataset at inference time.
        body = max(self.max_len - 2, 0)
        input_ids = (
            [self.tokenizer.cls_token_id]
            + data["input_ids"][:body]
            + [self.tokenizer.sep_token_id]
        )
        label = [label_o] + data["label"][:body] + [label_o]
        mask = [1] * len(input_ids)

        pad = self.max_len - len(input_ids)
        input_ids += [self.tokenizer.pad_token_id] * pad
        mask += [0] * pad
        label += [label_o] * pad

        assert len(input_ids) == len(mask) == len(label)
        return (
            torch.tensor(input_ids, dtype=torch.long).to(self.device),
            torch.tensor(mask, dtype=torch.long).to(self.device),
            torch.tensor(label, dtype=torch.long).to(self.device),
        )

    def __len__(self):
        return len(self.train_samples)

In [None]:
# Load the fine-tuned Longformer token classifier and freeze it; only the CRF
# layer stacked on top receives gradients. The label count excludes "PAD".
longformer_model = FeedbackModel(model_name=args1.model, num_labels=len(target_id_map) - 1)
longformer_model.load(os.path.join(args1.tez_model, "model_1.bin"), weights_only=True)
longformer_model.requires_grad_(False)
model = CRFModel(longformer_model, len(target_id_map) - 1)

# Optimise the CRF parameters only.
train_para = [param for name, param in model.named_parameters() if name.startswith('crf')]

model = model.cuda()
optimizer = Adam(train_para, 0.01, betas=(0.9, 0.999), eps=1e-6)

In [None]:
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=args1.model)  # same backbone as the models

In [None]:
# Training annotations (one row per discourse element) -> tokenized, labelled samples.
df = pd.read_csv(os.path.join(args1.input_path, "train.csv"))
df_ids = df["id"].unique()

train_samples = prepare_train_data(df, tokenizer, args1)
train_dataset = TrainDataSet(train_samples, tokenizer)

In [None]:
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)  # one essay per step; accumulation happens in the loop

In [None]:
# Train the CRF layer with gradient accumulation: the DataLoader yields one
# sample at a time, and the optimizer steps once every `batch_size` samples.
# Each sample's loss is scaled by 1/batch_size, so the accumulated gradient is
# the mean over the effective batch. (Previously the loss was scaled by 16 but
# the optimizer stepped every 10 samples, starting with sample 0.)
loss_mean = 0
batch_size = 16
model.train()
optimizer.zero_grad()
for epoch in range(1):
    pending = 0  # samples whose gradients are accumulated but not yet applied
    for i, data in tqdm(enumerate(train_loader), total=len(train_loader)):
        inputs, mask, label = data
        output, crf_loss, _ = model(inputs, mask, label)
        loss = crf_loss / batch_size
        loss.backward()
        pending += 1
        if pending == batch_size:
            optimizer.step()
            optimizer.zero_grad()
            pending = 0

        # Running mean of the unscaled per-sample CRF loss within this epoch.
        loss_mean = (loss_mean * i + crf_loss.item()) / (i + 1)
        if i % 100 == 0:
            print(loss_mean)
    # Apply gradients from a final, partial accumulation window.
    if pending:
        optimizer.step()
        optimizer.zero_grad()
    print(loss_mean)

In [None]:
# Persist the trained CRFModel (frozen backbone + CRF layer) via tez's save;
# presumably this stores the full module state_dict — confirm against tez.
model.save('crf1.bin')

In [None]:
# Rebuild the CRF-on-Longformer architecture from config only (random weights).
# Its parameters are restored from the saved checkpoint in a later cell, so the
# backbone weights are not loaded separately here.
longformer_model = FeedbackModel(model_name=args1.model, num_labels=len(target_id_map) - 1)
model = CRFModel(longformer_model, len(target_id_map) - 1).cuda()

In [None]:
# The sample submission lists the ids of the essays to predict.
df = pd.read_csv(os.path.join(args1.input_path, "sample_submission.csv"))
df_ids = df["id"].unique()

In [None]:
# Restore the CRF model weights saved by the training cell above.
model.load('./crf1.bin')
# model.load('../input/crfversion/crf1.bin')  # presumably: a previously trained checkpoint uploaded as a Kaggle dataset

In [None]:
collate = Collate(tokenizer)  # pads each inference batch into rectangular tensors

In [None]:
test_samples = prepare_test_data(df=df, tokenizer=tokenizer, args=args1)


In [None]:
raw_preds = []
test_dataset = FeedbackDataset(test_samples, args1.max_len, tokenizer)

# nn.Modules start in training mode, so the freshly rebuilt model would run the
# backbone with dropout active. Switch to eval mode explicitly; this is a no-op
# if tez's predict also does it.
model.eval()

# Iterated below: one batch of CRF-decoded label ids per step.
preds_iter = model.predict(test_dataset, batch_size=args1.batch_size, n_jobs=0, collate_fn=collate)

In [None]:
# Drain the prediction iterator, storing each batch as float16 to save memory
# (the values are small integer label ids, which float16 represents exactly).
raw_preds.extend(batch.astype(np.float16) for batch in preds_iter)

In [None]:
# Each raw_preds entry is one batch of decoded label ids laid out (seq, batch)
# by CRFModel. Transpose it and keep one id list per sample, in dataset order.
final_preds = []
for rp in raw_preds:
    # Transpose the *current* batch. The previous code always used
    # raw_preds[0], repeating the first batch's predictions for every batch.
    for pred in rp.T:
        final_preds.append(pred.tolist())


In [None]:
# final_preds = []
# final_scores = []

# for rp in raw_preds:
#     pred_class = np.argmax(rp, axis=2)
#     pred_scrs = np.max(rp, axis=2)
#     for pred, pred_scr in zip(pred_class, pred_scrs):
#         pred = pred.tolist()
#         pred_scr = pred_scr.tolist()
#         final_preds.append(pred)
#         final_scores.append(pred_scr)


In [None]:
# Convert each sample's predicted label ids back to BIO tag strings, skipping
# position 0 (the [CLS] token prepended by FeedbackDataset). Ids missing from
# id_target_map fall back to the "O" tag: the decoding loop slices tags as
# strings, so the previous integer fallback (14) would crash there.
for j, sample in enumerate(test_samples):
    sample["preds"] = [id_target_map.get(p, "O") for p in final_preds[j][1:]]

In [None]:
def jn(pst, start, end):
    """Join word indices ``pst[start:end]`` into a space-separated predictionstring.

    ``-1`` entries are the row separators that ``link_evidence`` inserts between
    predictions. They are skipped so a merged span never contains a literal "-1".
    """
    return " ".join(str(x) for x in pst[start:end] if x != -1)


def link_evidence(oof, thresh=1, thresh2=26):
    """Merge consecutive, nearly adjacent Evidence predictions within each essay.

    Evidence rows of each essay are visited in DataFrame order and accumulated
    into a span. The next row starts a new span when either of these holds:

    - its first word index is more than ``thresh`` words after the current
      span's last word, or
    - its first word index is more than ``thresh2`` words after the current
      span's first word.

    Otherwise it is appended to the current span. Non-Evidence rows pass
    through unchanged.

    Args:
        oof: DataFrame with ``id``, ``class`` and ``predictionstring`` columns,
            where predictionstring holds space-separated word indices.
        thresh: maximum gap allowed between a span's last word and the next
            row's first word.
        thresh2: maximum distance allowed from a span's first word to the next
            row's first word.

    Returns:
        DataFrame of the merged Evidence rows, outer-merged with the untouched
        non-Evidence rows. Works for any number of essays, including none.
    """
    evidence = oof[oof["class"] == "Evidence"]
    others = oof[oof["class"] != "Evidence"]

    def as_string(words):
        return " ".join(str(w) for w in words)

    merged_rows = []
    for essay_id in oof["id"].unique():
        rows = evidence[evidence["id"] == essay_id]
        current = []
        for words_str in rows["predictionstring"]:
            words = [int(x) for x in words_str.split()]
            if not words:
                continue
            if current and (words[0] > current[-1] + thresh or words[0] - current[0] > thresh2):
                merged_rows.append((essay_id, "Evidence", as_string(current)))
                current = []
            current.extend(words)
        if current:
            merged_rows.append((essay_id, "Evidence", as_string(current)))

    merged = pd.DataFrame(merged_rows, columns=["id", "class", "predictionstring"])
    return merged.merge(others, how="outer")

In [None]:
# Per-class post-processing thresholds, listed in label-map order.
# proba_thresh: minimum mean token probability for a predicted span (its
#   filter is commented out in the submission loop, so it is currently unused).
# min_thresh: minimum span length, in words, for a prediction to be submitted.
proba_thresh = dict(
    zip(
        ["Lead", "Position", "Evidence", "Claim", "Concluding Statement", "Counterclaim", "Rebuttal"],
        [0.7, 0.55, 0.65, 0.55, 0.7, 0.5, 0.55],
    )
)

min_thresh = dict(
    zip(
        ["Lead", "Position", "Evidence", "Claim", "Concluding Statement", "Counterclaim", "Rebuttal"],
        [9, 5, 14, 3, 11, 6, 4],
    )
)


In [None]:
# Decode per-token BIO tags into word-index spans. Each sample contributes one
# DataFrame of (id, class, predictionstring) rows whose spans meet the
# per-class minimum word count.
submission = []
for sample in test_samples:
    preds = sample["preds"]
    offset_mapping = sample["offset_mapping"]
    sample_id = sample["id"]
    sample_text = sample["text"]
    n_words = len(sample_text.split())

    # Predictions can be shorter than the tokenized essay (the model input is
    # truncated); treat the uncovered tail as "O".
    if len(preds) < len(offset_mapping):
        preds = preds + ["O"] * (len(offset_mapping) - len(preds))

    # Group tokens into phrases. A phrase starts at any token and extends over
    # the following tokens tagged "I-<label>" (or "O" for a run of O tokens).
    # `end` is initialised from the phrase's first token. The old code relied
    # on `"end" in locals()` and reused a stale `end` from an earlier phrase or
    # sample for single-token phrases.
    phrase_preds = []
    idx = 0
    while idx < len(offset_mapping):
        start, end = offset_mapping[idx]
        label = preds[idx][2:] if preds[idx] != "O" else "O"
        matching_label = "O" if label == "O" else f"I-{label}"
        idx += 1
        while idx < len(offset_mapping) and preds[idx] == matching_label:
            _, end = offset_mapping[idx]
            idx += 1
        phrase_preds.append((sample_text[start:end], start, end, label))

    temp_df = []
    for phrase, start, end, label in phrase_preds:
        if label == "O":
            continue
        # Convert the character span to whitespace-delimited word indices.
        word_start = len(sample_text[:start].split())
        word_end = min(word_start + len(sample_text[start:end].split()), n_words)
        ps = " ".join(str(x) for x in range(word_start, word_end))
        # NOTE: proba_thresh is not applied here; the CRF decoder yields label
        # ids only, with no per-token scores.
        if len(ps.split()) >= min_thresh[label]:
            temp_df.append((sample_id, label, ps))

    submission.append(pd.DataFrame(temp_df, columns=["id", "class", "predictionstring"]))

In [None]:
# Combine the per-essay frames, merge adjacent Evidence spans, and write the file.
submission = link_evidence(pd.concat(submission).reset_index(drop=True))
submission.to_csv("submission.csv", index=False)

In [None]:
submission.head(5)