In [None]:
import pandas as pd
import gc
import numpy as np
import os
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from sklearn.model_selection import KFold, train_test_split
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm

# Use the fully qualified option name: since pandas 1.4 the bare pattern
# 'max_columns' also matches 'styler.render.max_columns', and set_option
# raises OptionError for a pattern that matches multiple keys.
pd.set_option('display.max_columns', 300)

In [None]:
# BIO tag set: 'O' followed by a B-/I- pair for each discourse type, in this order.
output_labels = ['O'] + [
    f'{prefix}-{entity}'
    for entity in ('Lead', 'Position', 'Claim', 'Counterclaim', 'Rebuttal', 'Evidence', 'Concluding')
    for prefix in ('B', 'I')
]

# BIO tag -> competition class name (B/I collapsed; 'Concluding' is spelled out in full).
replace_labels = {
    tag: 'O' if tag == 'O' else ('Concluding Statement' if tag[2:] == 'Concluding' else tag[2:])
    for tag in output_labels
}

num_labels = len(output_labels)
# val2key: class id -> tag; key2val: tag -> class id.
val2key = dict(enumerate(output_labels))
key2val = {tag: idx for idx, tag in val2key.items()}

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
def read_data(path):
    """Load every regular file in directory ``path`` into a DataFrame.

    Returns a DataFrame with columns:
      - ``id``: the filename up to its first '.'
      - ``text``: the file contents decoded as UTF-8

    Files are read in sorted filename order. ``os.listdir`` guarantees no
    order, and the row order feeds the seeded KFold split downstream, so
    sorting makes runs reproducible. Subdirectories are skipped.
    """
    id_list, text_list = [], []
    for filename in sorted(os.listdir(path)):
        file_path = os.path.join(path, filename)
        if not os.path.isfile(file_path):
            continue
        id_list.append(filename.split('.')[0])
        with open(file_path, encoding='utf-8') as file:
            text_list.append(file.read())
    return pd.DataFrame({'id': id_list, 'text': text_list})


def create_entities(df, df_labels):
    """Attach a word-level BIO tag string to every essay in ``df``.

    Parameters
    ----------
    df : DataFrame with ``id`` and ``text`` columns (as built by ``read_data``).
    df_labels : DataFrame with columns ``id``, ``discourse_start``,
        ``discourse_end`` (character offsets into the matching text) and
        ``discourse_type``.

    Returns
    -------
    The same ``df`` object, with an ``entities`` column added. The column holds
    one tag per whitespace-separated word, space-joined
    (e.g. ``'O B-Claim I-Claim'``). Only the first word of ``discourse_type``
    is used in the tag, so ``'Concluding Statement'`` becomes
    ``B-Concluding``/``I-Concluding``.
    """
    # Group annotations once instead of re-filtering df_labels three times per essay.
    spans_by_id = {
        file_id: list(zip(group['discourse_start'].astype('int'),
                          group['discourse_end'].astype('int'),
                          group['discourse_type']))
        for file_id, group in df_labels.groupby('id', sort=False)
    }

    all_entities = []
    for file_id, text in zip(df['id'], df['text']):
        entities = ['O'] * len(text.split())
        for char_start, char_end, discourse_type in spans_by_id.get(file_id, []):
            entity = discourse_type.split()[0]
            # Number of words (a partially cut word counts as one) before char_start.
            word_start = len(text[:char_start].split())
            # Index of the last word reached by the span -- inclusive.
            word_end = len(text[:char_end].split()) - 1
            # +1 so the span's final word is tagged too (it was previously left 'O').
            for k in range(word_start, word_end + 1):
                entities[k] = f'B-{entity}' if k == word_start else f'I-{entity}'
        all_entities.append(' '.join(entities))

    # Positional assignment: works for any index, not only a RangeIndex.
    df['entities'] = all_entities
    return df


def text_f1(labels, preds, word_ids):
    """Word-level weighted F1 between token-aligned ``labels`` and ``preds``.

    All three arguments are indexed as ``[sequence][token position]``.

    A position contributes to the score only if both hold:
      - its label is not ``-100``;
      - its word id differs from the word id at the preceding position.

    So each word is scored once, on its first sub-token. The "preceding word
    id" starts at ``-100`` and is carried across sequence boundaries.
    """
    kept_preds, kept_labels = [], []
    previous_word_id = -100
    for row, row_preds in enumerate(preds):
        for col in range(len(row_preds)):
            word_id = word_ids[row][col]
            label = labels[row][col]
            if label != -100 and word_id != previous_word_id:
                kept_preds.append(row_preds[col])
                kept_labels.append(label)
            previous_word_id = word_id

    metric_labels = np.array(kept_labels, dtype=float)
    metric_preds = np.array(kept_preds, dtype=float)
    return f1_score(metric_labels, metric_preds, average='weighted')


def _first_subtoken_tags(doc_preds, doc_word_ids, id2label):
    """Yield ``(word_index, BIO tag)`` for the first token of each word, skipping ``-100`` positions."""
    last_word_id = None
    for pred, word_id in zip(doc_preds, doc_word_ids):
        if word_id == -100 or word_id == last_word_id:
            continue
        last_word_id = word_id
        yield int(word_id), id2label[int(pred)]


def _tags_to_entities(word_tags, label2class):
    """Group ``(word_index, tag)`` pairs into a list of ``(class, [word indices])`` entities.

    An entity is a run of words with the same class. A ``B-`` tag, a class
    change, or an ``O`` word ends the current run.
    """
    entities = []
    current = None  # (class, word indices) of the entity being extended; already in `entities`
    for word_index, tag in word_tags:
        cls = label2class[tag]
        if cls == 'O':
            current = None
        elif current is None or tag.startswith('B-') or cls != current[0]:
            current = (cls, [word_index])
            entities.append(current)
        else:
            current[1].append(word_index)
    return entities


def preds2submission(preds, file_id_list, word_ids, id2label=None, label2class=None):
    """Turn token-level class predictions into Feedback Prize submission rows.

    Parameters
    ----------
    preds : per-document sequences of predicted class ids, one per token position.
    file_id_list : document ids, aligned positionally with ``preds``.
    word_ids : per-document sequences aligned with ``preds``. Each entry is the
        source word index of that token, or ``-100`` for special/padding
        positions (as produced by ``TextDataset``).
    id2label : optional mapping from class id to BIO tag; defaults to ``val2key``.
    label2class : optional mapping from BIO tag to discourse class; defaults to
        ``replace_labels``.

    Returns
    -------
    DataFrame with columns ``id``, ``class`` and ``predictionstring``
    (space-joined word indices).

    Only the first sub-token of each word is used. An entity still open at the
    end of a document is always emitted with its own class, whatever is
    predicted on the trailing special tokens.
    """
    if id2label is None:
        id2label = val2key
    if label2class is None:
        label2class = replace_labels

    id_list, class_list, predictionstring_list = [], [], []
    # zip iterates file_id_list positionally, so a Series with any index works.
    for file_id, doc_preds, doc_word_ids in zip(file_id_list, preds, word_ids):
        word_tags = _first_subtoken_tags(doc_preds, doc_word_ids, id2label)
        for cls, words in _tags_to_entities(word_tags, label2class):
            id_list.append(file_id)
            class_list.append(cls)
            predictionstring_list.append(' '.join(map(str, words)))

    return pd.DataFrame({'id': id_list, 'class': class_list, 'predictionstring': predictionstring_list})

In [None]:
class TextDataset(Dataset):
    """Tokenises whitespace-split essays for token classification.

    Each item is a pair ``(encoding, word_ids)``:
      - ``encoding`` maps every tokenizer output field (plus ``labels`` when
        ``is_label``) to a tensor.
      - ``word_ids`` is a NumPy array giving the source word index for each
        token position, or ``-100`` where the tokenizer reports no word.
    """

    def __init__(self, data, tokenizer, max_length, is_label=False):
        self.text = data['text']
        # Space-joined BIO tags, one per word; only read when labels are requested.
        self.entities = data['entities'] if is_label else None
        self.is_label = is_label
        self.max_length = max_length
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.text)

    def __getitem__(self, index):
        encoding = self.tokenizer(
            self.text[index].split(),
            is_split_into_words=True,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
        )
        token_words = encoding.word_ids()
        # -100 marks token positions that map to no source word.
        aligned_words = [-100 if word is None else word for word in token_words]

        if self.is_label:
            tags = self.entities[index].split()
            # Every sub-token inherits the tag id of the word it came from.
            encoding['labels'] = [-100 if word is None else key2val[tags[word]] for word in token_words]

        tensors = {field: torch.as_tensor(value) for field, value in encoding.items()}
        return tensors, np.array(aligned_words)


class TextModel:
    """K-fold token-classification trainer/predictor around a Hugging Face checkpoint.

    ``fit`` trains one model per fold and saves each fold's weights to
    ``./<last path component of model_name><fold number>.pt``. ``predict``
    reloads those files and averages the per-fold logits.
    """

    def __init__(self, model_name, max_length):
        self.model_name = model_name
        self.num_folds = 5
        self.tokenizer = None
        self.max_length = max_length

    def _checkpoint_path(self, fold):
        """Weights file for zero-based ``fold`` (files are numbered from 1)."""
        return f'./{self.model_name.split("/")[-1]}{fold+1}.pt'

    def _load_tokenizer(self):
        """Tokenizer for ``model_name``; add_prefix_space=True because TextDataset feeds pre-split words."""
        return AutoTokenizer.from_pretrained(self.model_name, add_prefix_space=True)

    def _new_model(self):
        """Freshly loaded token-classification model with ``num_labels`` outputs, moved to ``device``."""
        return AutoModelForTokenClassification.from_pretrained(self.model_name, num_labels=num_labels).to(device)

    def _make_loader(self, data, batch_size, is_label, shuffle=False):
        """DataLoader over ``TextDataset(data)`` built with the current tokenizer."""
        dataset = TextDataset(data, self.tokenizer, max_length=self.max_length, is_label=is_label)
        return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True)

    def _evaluate(self, model, loader, n_rows):
        """Argmax class ids, target ids and word ids for every row of a labelled ``loader``.

        ``loader`` must not shuffle, so row k of each returned array matches
        row k of the underlying data.
        """
        model.eval()
        preds = np.zeros((n_rows, self.max_length))
        labels = np.zeros((n_rows, self.max_length))
        words = np.zeros((n_rows, self.max_length))
        current_pos = 0
        with torch.no_grad():
            for batch, word_ids in tqdm(loader):
                batch_labels = batch['labels'].to(device)
                _, logits = model(input_ids=batch['input_ids'].to(device),
                                  attention_mask=batch['attention_mask'].to(device),
                                  labels=batch_labels, return_dict=False)
                end = current_pos + len(batch_labels)
                preds[current_pos:end] = torch.argmax(logits, dim=-1).cpu().numpy()
                labels[current_pos:end] = batch_labels.cpu().numpy()
                words[current_pos:end] = np.asarray(word_ids)
                current_pos = end
        return preds, labels, words

    def fit(self, data, batch_size=4, lr=2.5e-5, epochs=1, num_folds=5):
        """Train ``num_folds`` models with K-fold CV and return out-of-fold predictions.

        ``data`` needs ``text`` and ``entities`` columns (see ``create_entities``).

        Returns ``(full_preds, full_labels)``, each of shape
        ``(len(data), max_length)``:
          - ``full_preds``: argmax class id per token position, from the fold
            model that did not train on that row (after the last epoch);
          - ``full_labels``: the matching target ids, ``-100`` on positions
            with no source word.

        Side effect: writes one weights file per fold (see ``_checkpoint_path``).
        """
        self.num_folds = num_folds
        self.tokenizer = self._load_tokenizer()

        full_preds = np.zeros((len(data), self.max_length))
        full_labels = np.zeros((len(data), self.max_length))

        kfold = KFold(n_splits=self.num_folds, shuffle=True, random_state=0)
        for fold, (trn_ind, val_ind) in enumerate(kfold.split(data)):
            # KFold yields positions, so select with iloc (works for any index).
            data_train = data.iloc[trn_ind].reset_index(drop=True)
            data_val = data.iloc[val_ind].reset_index(drop=True)

            # Shuffle training batches; validation stays ordered so rows line up with val_ind.
            loader_train = self._make_loader(data_train, batch_size, is_label=True, shuffle=True)
            loader_val = self._make_loader(data_val, batch_size, is_label=True)

            model = self._new_model()
            optimizer = optim.Adam(params=model.parameters(), lr=lr)

            for epoch in range(epochs):
                print(f'epoch: {epoch+1}')
                model.train()
                for batch, _ in tqdm(loader_train):
                    loss, _ = model(input_ids=batch['input_ids'].to(device),
                                    attention_mask=batch['attention_mask'].to(device),
                                    labels=batch['labels'].to(device), return_dict=False)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                fold_preds, fold_labels, fold_words = self._evaluate(model, loader_val, len(data_val))
                full_preds[val_ind] = fold_preds
                full_labels[val_ind] = fold_labels
                print(f'f1_score: {text_f1(fold_labels, fold_preds, fold_words)}')

            torch.save(model.state_dict(), self._checkpoint_path(fold))
            # Drop the references first; otherwise empty_cache cannot release this
            # fold's weights/optimizer state before the next fold's model is loaded.
            del model, optimizer
            gc.collect()
            torch.cuda.empty_cache()

        print(f'full_preds shape: {full_preds.shape}')
        return full_preds, full_labels

    def predict(self, data, batch_size=4, is_label=False):
        """Average the saved fold models' logits over ``data`` and return argmax ids.

        Returns ``(preds, labels, words)``, each of shape ``(len(data), max_length)``:
          - ``preds``: predicted class ids;
          - ``labels``: target ids (all zeros unless ``is_label``);
          - ``words``: each token's source word index, ``-100`` where there is none.

        Requires the weights files written by ``fit`` for ``self.num_folds`` folds.
        """
        if self.tokenizer is None:
            # Lets predict() run from previously saved fold weights without calling fit() first.
            self.tokenizer = self._load_tokenizer()
        loader = self._make_loader(data, batch_size, is_label=is_label)

        logit_sum = np.zeros((len(data), self.max_length, num_labels))
        labels = np.zeros((len(data), self.max_length))
        words = np.zeros((len(data), self.max_length))

        with torch.no_grad():
            for fold in range(self.num_folds):
                model = self._new_model()
                # map_location lets weights saved from a GPU load on a CPU-only machine.
                model.load_state_dict(torch.load(self._checkpoint_path(fold), map_location=device))
                model.eval()
                current_pos = 0
                for batch, word_ids in tqdm(loader):
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    end = current_pos + len(input_ids)
                    if is_label:
                        batch_labels = batch['labels'].to(device)
                        _, logit_preds = model(input_ids=input_ids, attention_mask=attention_mask,
                                               labels=batch_labels, return_dict=False)
                        labels[current_pos:end] = batch_labels.cpu().numpy()
                    else:
                        logit_preds = model(input_ids=input_ids, attention_mask=attention_mask,
                                            return_dict=False)[0]
                    logit_sum[current_pos:end] += logit_preds.cpu().numpy()
                    words[current_pos:end] = np.asarray(word_ids)
                    current_pos = end
                del model
                gc.collect()
                torch.cuda.empty_cache()

        return np.argmax(logit_sum / self.num_folds, axis=2), labels, words

In [None]:
DATA_DIR = '../input/feedback-prize-2021'

# Training essays with word-level BIO tags built from the span annotations.
df_labels = pd.read_csv(f'{DATA_DIR}/train.csv')
df_train = read_data(f'{DATA_DIR}/train').reset_index(drop=True)
df_train = create_entities(df_train, df_labels)

df_test = read_data(f'{DATA_DIR}/test')

# Alternative checkpoint: '../input/py-bigbird-v26'
model = TextModel(model_name='../input/huggingfacebigbirdrobertabase', max_length=1024)
# Out-of-fold predictions/labels on the training set; one weights file is saved per fold.
val_preds, val_labels = model.fit(df_train)
# Fold-averaged predictions on the test essays.
test_preds, _, test_word_ids = model.predict(df_test)

In [None]:
#text_f1(labels, preds, word_ids)

In [None]:
# One row per predicted discourse span: essay id, class, space-separated word indices.
submission_df = preds2submission(test_preds, file_id_list=df_test['id'], word_ids=test_word_ids)
submission_df

In [None]:
# Write the submission without the DataFrame's index column.
submission_df.to_csv(path_or_buf='submission.csv', index=False)