In [None]:
import random
import os
import time
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
from sklearn.model_selection import *
from transformers import *

In [None]:
# Run configuration shared by the cells of this notebook. The cells shown
# here read 'seed', 'model', 'max_len' and 'valid_bs'; the other keys are not
# referenced by any visible cell.
CFG = dict(
    fold_num=5,
    seed=42,
    model='../input/huggingfacebigbirdrobertabase',
    max_len=1024,
    epochs=5,
    train_bs=24,
    valid_bs=32,
    lr=2e-5,
    num_workers=0,
    weight_decay=1e-6,
)

In [None]:
def seed_everything(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    CUDA generators), exports ``PYTHONHASHSEED`` into the environment, and
    puts cuDNN into deterministic, non-benchmarking mode.

    Args:
        seed: Value passed to each seeding call (stringified for the
            environment variable).
    """
    # Environment first, then the three libraries; the calls are independent.
    os.environ['PYTHONHASHSEED'] = str(seed)

    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning for run-to-run repeatability.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Fix all RNG seeds up front, using the seed from the config cell.
seed_everything(CFG['seed'])

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Load the sample submission; its column names are reused further down to
# label the final submission frame.
sample_submission_path = '../input/feedback-prize-2021/sample_submission.csv'
test_df = pd.read_csv(sample_submission_path)
test_df

In [None]:
# Read every file in the test directory into a DataFrame with one row per
# file: 'id' is the file name with '.txt' removed, 'text' is the file content
# split on whitespace into a list of words.
test_dir = '../input/feedback-prize-2021/test'

test_names, raw_texts = [], []
# os.listdir returns entries in arbitrary order; sorting keeps the row order
# (and therefore the order of the submission rows) stable across runs.
for fname in tqdm(sorted(os.listdir(test_dir))):
    test_names.append(fname.replace('.txt', ''))
    # The context manager closes each handle as soon as the file is read,
    # instead of leaving one open descriptor per essay.
    with open(os.path.join(test_dir, fname), 'r') as fh:
        raw_texts.append(fh.read())

# `test_texts` is only ever bound to the DataFrame (the raw strings live in
# `raw_texts`), so the name never changes type between statements.
test_texts = pd.DataFrame({'id': test_names, 'text': raw_texts})
test_texts['text'] = test_texts['text'].apply(lambda essay: essay.split())
test_texts

In [None]:
# Tokenizer loaded from the same checkpoint directory as the model.
# NOTE(review): add_prefix_space=True is presumably there because collate_fn
# feeds pre-split words (is_split_into_words=True) — confirm it is honoured
# by the tokenizer class this checkpoint resolves to.
tokenizer = AutoTokenizer.from_pretrained(
    CFG['model'],
    add_prefix_space=True,
)

In [None]:
class MyDataset(Dataset):
    """Map-style dataset serving the ``text`` column of a DataFrame.

    Items are returned exactly as stored in the column; tokenization is left
    to the DataLoader's collate function.
    """

    def __init__(self, df):
        """Keep a reference to ``df``, which must expose a ``text`` column."""
        self.df = df

    def __len__(self):
        """Return the number of rows in the wrapped DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the ``text`` entry at positional index ``idx``."""
        text_column = self.df.text.values
        return text_column[idx]

In [None]:
def collate_fn(data):
    """Tokenize one batch of pre-split texts for the DataLoader.

    Args:
        data: The list of samples gathered by the DataLoader; it is handed to
            the tokenizer as text that is already split into words.

    Returns:
        The tokenizer output (PyTorch tensors, padded and truncated to
        ``CFG['max_len']``) with an extra ``"word_ids"`` entry holding, for
        each sample, the result of ``word_ids(batch_index=i)``.
    """
    encoded = tokenizer(
        data,
        is_split_into_words=True,
        truncation=True,
        padding='max_length',
        max_length=CFG['max_len'],
        return_tensors='pt',
    )

    # Stored as plain Python lists rather than tensors; the inference loop
    # strips this key before moving the batch to the device.
    encoded["word_ids"] = [
        encoded.word_ids(batch_index=sample_idx)
        for sample_idx in range(len(data))
    ]

    return encoded

In [None]:
# Batches are served in DataFrame row order (shuffle=False); the
# post-processing cell relies on prediction i belonging to row i.
# NOTE(review): num_workers is hard-coded to 4 here and does not read
# CFG['num_workers'] (which is 0) — confirm which value is intended.
test_dataset = MyDataset(test_texts)
test_loader = DataLoader(
    test_dataset,
    batch_size=CFG['valid_bs'],
    shuffle=False,
    num_workers=4,
    collate_fn=collate_fn,
)

# Pull a single batch to inspect the tokenizer output.
batch = next(iter(test_loader))
batch

In [None]:
# Build the token-classification model with 15 output classes (one per entry
# of the `labels` list used in post-processing) and restore the fine-tuned
# weights from the saved checkpoint.
checkpoint_path = '../input/feedback-bigbird/bigbird-roberta-base_fold_0.pt'

model = AutoModelForTokenClassification.from_pretrained(CFG['model'], num_labels=15).to(device)

# map_location=device remaps the stored tensors onto whichever device was
# selected above. Without it, a checkpoint saved from a GPU cannot be
# deserialized when `device` is the CPU fallback.
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict)

# Inference only: switch dropout etc. to evaluation behaviour.
model.eval()

In [None]:
# Run the model over the whole test loader, collecting for every sample the
# arg-max class id of each token (`y_pred`) and the matching word ids (`words`).
y_pred = []
words = []

progress = tqdm(test_loader, total=len(test_loader), position=0, leave=True)
with torch.no_grad():
    for batch in progress:
        # word_ids is a list of Python lists, so it is set aside before the
        # remaining entries are moved to the device.
        words.extend(batch['word_ids'])

        model_inputs = {}
        for key, value in batch.items():
            if key != 'word_ids':
                model_inputs[key] = value.to(device)

        logits = model(**model_inputs).logits
        token_classes = logits.argmax(-1).cpu().numpy()
        y_pred.extend(token_classes)

# Stack the per-sample rows into a single array.
y_pred = np.array(y_pred)

In [None]:
# Class-id -> tag lookup: index 0 is 'O', followed by a 'B-'/'I-' pair for
# each class name, in this exact order.
_class_names = ['Lead', 'Position', 'Claim', 'Counterclaim',
                'Rebuttal', 'Evidence', 'Concluding Statement']
labels = ['O'] + [prefix + name for name in _class_names for prefix in ('B-', 'I-')]

In [None]:
# Turn per-token class ids into submission rows of the form
# (essay id, class name, space-separated word indices).

# A span is kept only when it covers MORE than this many words.
MIN_SPAN_WORDS = 5


def _entity_class(tag):
    """Strip the 'B-' / 'I-' prefix from a tag, leaving the class name."""
    return tag.replace('B-', '').replace('I-', '')


def _tag_words(n_words, token_word_ids, token_class_ids):
    """Project token-level class ids onto words.

    Returns a list of ``n_words`` tags. Words no token maps to stay ''.
    When several tokens map to the same word, the last token's tag wins.
    """
    word_tags = [''] * n_words
    for word_idx, class_id in zip(token_word_ids, token_class_ids):
        # Tokens whose word id is None carry no word and are skipped.
        if word_idx is not None:
            word_tags[word_idx] = labels[class_id]
    return word_tags


def _extract_spans(word_tags, min_words=MIN_SPAN_WORDS):
    """Collect (class name, prediction string) pairs from word-level tags.

    A span opens at a tag that is neither 'O', '' nor an 'I-' tag, and
    extends over the following words while their class name (prefix
    stripped) equals the opening one. Spans of ``min_words`` words or fewer
    are dropped.
    """
    spans = []
    start = 0
    while start < len(word_tags):
        tag = word_tags[start]
        if tag == 'O' or tag == '' or tag[0] == 'I':
            start += 1
            continue

        class_name = _entity_class(tag)
        end = start + 1
        while end < len(word_tags) and _entity_class(word_tags[end]) == class_name:
            end += 1

        if end - start > min_words:
            spans.append((class_name, ' '.join(map(str, range(start, end)))))

        start = end
    return spans


final_preds = []

essay_ids = test_texts.id.values
essay_words = test_texts.text.values

for i in tqdm(range(len(test_texts))):
    # Row i of the DataFrame lines up with y_pred[i] / words[i] because the
    # DataLoader was built with shuffle=False.
    word_tags = _tag_words(len(essay_words[i]), words[i], y_pred[i])
    for class_name, prediction_string in _extract_spans(word_tags):
        final_preds.append((essay_ids[i], class_name, prediction_string))

# Preview the first row, without raising IndexError when no span survived.
final_preds[0] if final_preds else None

In [None]:
# Build the submission frame using the sample submission's column names.
# The columns go to the constructor rather than being assigned afterwards:
# with an empty `final_preds`, pd.DataFrame([]) has zero columns and the
# later `sub.columns = ...` assignment raises a length-mismatch ValueError.
sub = pd.DataFrame(final_preds, columns=test_df.columns)
sub

In [None]:
# Write the submission to the working directory, without the index column.
submission_path = 'submission.csv'
sub.to_csv(submission_path, index=False)