In [1]:
import pandas as pd
import gc
import numpy as np
import os
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from sklearn.model_selection import KFold, train_test_split
from sklearn.metrics import accuracy_score
from tqdm import tqdm

# Show up to 300 columns when a DataFrame is displayed. The fully qualified key is
# used on purpose: the short pattern 'max_columns' also matches
# 'styler.render.max_columns' on pandas >= 1.4, which makes set_option raise
# OptionError ("Pattern matched multiple keys").
pd.set_option('display.max_columns', 300)

In [2]:
# Tag vocabulary of the token classifier. The position of a tag in this list is
# its integer class id ('O' is 0, then a B-/I- pair per discourse class).
output_labels = [
    'O',
    'B-Lead', 'I-Lead',
    'B-Position', 'I-Position',
    'B-Claim', 'I-Claim',
    'B-Counterclaim', 'I-Counterclaim',
    'B-Rebuttal', 'I-Rebuttal',
    'B-Evidence', 'I-Evidence',
    'B-Concluding', 'I-Concluding',
]

# Tag -> class name used in the submission frame (the B-/I- prefix is dropped).
replace_labels = {
    'O': 'O',
    'B-Lead': 'Lead',
    'I-Lead': 'Lead',
    'B-Position': 'Position',
    'I-Position': 'Position',
    'B-Claim': 'Claim',
    'I-Claim': 'Claim',
    'B-Counterclaim': 'Counterclaim',
    'I-Counterclaim': 'Counterclaim',
    'B-Rebuttal': 'Rebuttal',
    'I-Rebuttal': 'Rebuttal',
    'B-Evidence': 'Evidence',
    'I-Evidence': 'Evidence',
    'B-Concluding': 'Concluding Statement',
    'I-Concluding': 'Concluding Statement',
}

num_labels = len(output_labels)

# Two-way lookup between a tag string and its integer id.
key2val = dict(zip(output_labels, range(num_labels)))
val2key = dict(enumerate(output_labels))

# Run on the GPU whenever torch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
def read_data(path):
    """Load every file found in ``path`` into a two-column DataFrame.

    Args:
        path: Directory whose entries are read as text files.

    Returns:
        DataFrame with one row per file and the columns ``id`` (the file name up
        to its first ``.``) and ``text`` (the full file content).

    Files are visited in sorted name order. ``os.listdir`` gives no ordering
    guarantee, so without the sort the row order - and therefore any seeded
    split applied to the result - could differ between machines.
    """
    records = []
    for filename in sorted(os.listdir(path)):
        # Explicit encoding so the decoded text does not depend on the
        # platform's default locale.
        with open(os.path.join(path, filename), encoding='utf-8') as file:
            records.append((filename.split('.')[0], file.read()))
    return pd.DataFrame(records, columns=['id', 'text'])


def create_entities(df, df_labels):
    """Attach a per-word BIO tag string to every essay in ``df``.

    Args:
        df: DataFrame with the columns ``id`` and ``text``. It is modified in
            place (an ``entities`` column is added) and also returned.
        df_labels: DataFrame with the columns ``id``, ``discourse_start``,
            ``discourse_end`` and ``discourse_type``. The two offset columns are
            cast to ``int`` and used as character positions into ``text``.

    Returns:
        ``df`` with the new ``entities`` column: one space-separated tag per
        whitespace-delimited word of ``text``. The first word of a discourse is
        tagged ``B-<type>``, the remaining ones ``I-<type>``, everything else
        ``O``. ``<type>`` is the first word of ``discourse_type``
        (``'Concluding Statement'`` becomes ``Concluding``).
    """
    # Group the annotations once instead of re-filtering df_labels three times
    # for every essay.
    spans_by_id = {
        file_id: list(zip(group['discourse_start'].astype('int'),
                          group['discourse_end'].astype('int'),
                          group['discourse_type']))
        for file_id, group in df_labels.groupby('id', sort=False)
    }

    entities_column = []
    # Iterating the column values (rather than df.loc[i]) keeps this independent
    # of what the frame's index looks like.
    for file_id, text in zip(df['id'], df['text']):
        entities = ['O'] * len(text.split())
        for char_start, char_end, discourse_type in spans_by_id.get(file_id, ()):
            # Number of words that end before the discourse starts == index of
            # its first word.
            word_start = len(text[:char_start].split())
            # Number of words up to the end of the discourse == exclusive end
            # index. The earlier version subtracted 1 here *and* used an
            # exclusive range, so the last word of each discourse stayed 'O'.
            word_end = len(text[:char_end].split())
            tag = discourse_type.split()[0]
            for word_pos in range(word_start, word_end):
                prefix = 'B' if word_pos == word_start else 'I'
                entities[word_pos] = f'{prefix}-{tag}'
        entities_column.append(' '.join(entities))

    df['entities'] = entities_column
    return df


def text_accuracy(preds, labels):
    """Accuracy over all positions whose label is not -100.

    Args:
        preds: Sequence of rows of predicted class ids.
        labels: Sequence of rows of reference class ids, indexable at every
            position that exists in ``preds``; entries equal to -100 are skipped.

    Returns:
        The value of ``accuracy_score`` computed on the flattened, filtered
        label / prediction pairs.
    """
    # Collect (prediction, label) pairs row by row, dropping ignored positions.
    kept_pairs = [
        (preds[row][col], labels[row][col])
        for row in range(len(preds))
        for col in range(len(preds[row]))
        if labels[row][col] != -100
    ]

    # Flat float arrays, in the same row-major order as the pairs were collected.
    metric_preds = np.array([pred for pred, _ in kept_pairs], dtype=float)
    metric_labels = np.array([label for _, label in kept_pairs], dtype=float)

    return accuracy_score(metric_labels, metric_preds)


def preds2submission(preds, file_id_list):
    """Turn per-position class ids into one row per predicted segment.

    Args:
        preds: Sequence of rows; ``preds[i][j]`` is the class id (a key of
            ``val2key``) predicted for position ``j`` of document ``i``.
        file_id_list: Document ids, in the same order as ``preds``.

    Returns:
        DataFrame with the columns ``id``, ``class`` and ``predictionstring``.
        Each maximal run of consecutive positions that map to the same non-'O'
        class (via ``replace_labels``) yields one row; ``predictionstring`` is
        the space-separated list of the positions in that run.

    Raises:
        ValueError: If there are fewer ids than prediction rows.

    NOTE(review): the indices written to ``predictionstring`` are positions
    within a prediction row. When ``preds`` comes from ``TextModel.predict`` a
    row has one entry per tokenizer position (``max_length`` of them), not one
    per whitespace word - confirm that this is what the consumer expects.
    """
    if len(file_id_list) < len(preds):
        raise ValueError('file_id_list has fewer entries than preds')

    id_list, label_list, predictionstring_list = [], [], []
    for file_id, doc_preds in zip(file_id_list, preds):
        classes = [replace_labels[val2key[class_id]] for class_id in doc_preds]

        run_start = 0
        # Walk one step past the end so the final run is flushed by the same
        # code path as every other run.
        for pos in range(1, len(classes) + 1):
            if pos < len(classes) and classes[pos] == classes[run_start]:
                continue
            if classes[run_start] != 'O':
                id_list.append(file_id)
                label_list.append(classes[run_start])
                # Every run gets its own freshly built string. The earlier
                # version only reset the string when the previous class was not
                # 'O', so a run following 'O' lost its first index and was
                # appended to the previous run's string.
                predictionstring_list.append(' '.join(str(k) for k in range(run_start, pos)))
            run_start = pos

    print(f'id: {len(id_list)}')
    print(f'class: {len(label_list)}')
    print(f'pred: {len(predictionstring_list)}')
    df = pd.DataFrame({'id': id_list, 'class': label_list, 'predictionstring': predictionstring_list})
    return df

In [4]:
class TextDataset(Dataset):
    """Torch dataset that tokenizes whitespace-split essays on access.

    Args:
        data: Mapping/DataFrame with a ``text`` column and, when ``is_label`` is
            true, an ``entities`` column holding one space-separated tag per
            whitespace word of ``text``.
        tokenizer: Callable invoked as ``tokenizer(words, is_split_into_words=True,
            truncation=True, padding='max_length', max_length=...)``; its result
            must offer ``items()`` and, when labels are requested, ``word_ids()``
            and item assignment.
        max_length: Passed to the tokenizer as padding / truncation length.
        is_label: Whether to attach a ``labels`` entry to every item.
    """

    def __init__(self, data, tokenizer, max_length, is_label=False):
        # Plain lists make __getitem__ purely positional. A DataLoader asks for
        # positions 0..len-1, while indexing a pandas Series with [] looks up
        # index *labels* and fails for any frame whose index is not 0..n-1.
        self.text = list(data['text'])
        self.entities = list(data['entities']) if is_label else None
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.is_label = is_label

    def __len__(self):
        """Number of essays."""
        return len(self.text)

    def __getitem__(self, index):
        """Tokenize essay ``index`` and return its fields as tensors.

        Returns:
            Dict with one tensor per key of the tokenizer output plus, when
            ``is_label`` is set, ``labels``: the id (``key2val``) of the tag of
            the word each token belongs to, or -100 where ``word_ids()`` reports
            ``None`` (the value ``text_accuracy`` in this notebook skips).
        """
        encoding = self.tokenizer(self.text[index].split(), is_split_into_words=True, truncation=True,
                                  padding='max_length', max_length=self.max_length)

        if self.is_label:
            word_tags = self.entities[index].split()
            # Every token of a word receives that word's tag id.
            encoding['labels'] = [
                -100 if word_id is None else key2val[word_tags[word_id]]
                for word_id in encoding.word_ids()
            ]

        return {key: torch.as_tensor(val) for key, val in encoding.items()}


class TextModel:
    """K-fold ensemble wrapper around a Hugging Face token-classification model.

    ``fit`` trains one model per fold and keeps all of them in ``self.models``;
    ``predict`` averages the logits of the stored models and returns the arg-max
    class id for every tokenizer position.

    Args:
        model_name: Name or path handed to ``AutoTokenizer.from_pretrained`` and
            ``AutoModelForTokenClassification.from_pretrained``.
        max_length: Sequence length every essay is padded / truncated to.
        num_folds: Number of ``KFold`` splits, i.e. models trained by ``fit``.
    """

    def __init__(self, model_name, max_length, num_folds=5):
        self.models = []
        self.model_name = model_name
        self.num_folds = num_folds
        self.tokenizer = None
        self.max_length = max_length

    def _make_loader(self, data, batch_size, is_label):
        """Build an unshuffled DataLoader over ``data`` using the fitted tokenizer."""
        dataset = TextDataset(data, self.tokenizer, max_length=self.max_length, is_label=is_label)
        # Pinned host memory only helps host-to-GPU copies; skip it on CPU runs.
        return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False,
                          pin_memory=(device == 'cuda'))

    @staticmethod
    def _batch_logits(model, batch):
        """Return the logits ``model`` produces for one collated batch."""
        return model(input_ids=batch['input_ids'].to(device),
                     attention_mask=batch['attention_mask'].to(device),
                     return_dict=False)[0]

    def fit(self, data, batch_size=4, lr=2.5e-5, epochs=1):
        """Train one model per fold and collect out-of-fold predictions.

        Args:
            data: DataFrame with ``text`` and ``entities`` columns.
            batch_size: Batch size for the training and validation loaders.
            lr: Learning rate of the Adam optimizer.
            epochs: Number of passes over each fold's training part.

        Returns:
            ``(full_preds, full_labels)``, two ``(len(data), max_length)`` float
            arrays. Row ``i`` holds the class ids predicted for essay ``i`` by the
            model that had it in its validation fold, and the matching label ids
            (-100 at positions the dataset marks as ignored).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, add_prefix_space=True)
        # Start from a clean ensemble so a second fit() does not mix model sets.
        self.models = []

        full_preds = np.zeros((len(data), self.max_length))
        full_labels = np.zeros((len(data), self.max_length))

        kfold = KFold(n_splits=self.num_folds, shuffle=True, random_state=0)
        for trn_ind, val_ind in kfold.split(data):
            # KFold yields positions, so select with iloc (loc would look up
            # index labels and break on a frame without a 0..n-1 index).
            data_train = data.iloc[trn_ind].reset_index(drop=True)
            data_val = data.iloc[val_ind].reset_index(drop=True)

            # NOTE(review): training batches are served in a fixed order
            # (shuffle=False), as before - confirm that this is intended.
            loader_train = self._make_loader(data_train, batch_size, is_label=True)
            loader_val = self._make_loader(data_val, batch_size, is_label=True)

            model = AutoModelForTokenClassification.from_pretrained(self.model_name, num_labels=num_labels).to(device)
            optimizer = optim.Adam(params=model.parameters(), lr=lr)

            for epoch in range(epochs):
                print(f'epoch: {epoch+1}')
                model.train()
                for batch in tqdm(loader_train):
                    # With labels and return_dict=False the first output is the loss.
                    loss = model(input_ids=batch['input_ids'].to(device),
                                 attention_mask=batch['attention_mask'].to(device),
                                 labels=batch['labels'].to(device),
                                 return_dict=False)[0]

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # Refresh this fold's out-of-fold rows after every epoch; the
                # values left at the end come from the last epoch.
                model.eval()
                with torch.no_grad():
                    current_pos = 0
                    for batch in tqdm(loader_val):
                        current_batch_size = len(batch['input_ids'])
                        rows = val_ind[current_pos:current_pos + current_batch_size]
                        logit_preds = self._batch_logits(model, batch)
                        full_preds[rows] = torch.argmax(logit_preds, dim=-1).cpu().numpy()
                        full_labels[rows] = batch['labels'].cpu().numpy()
                        current_pos += current_batch_size

            # Keep every fold's model. This used to run once after the loop, so
            # only the last fold's model survived although predict() averaged as
            # if all folds were present. All models stay on `device`.
            self.models.append(model)
            del optimizer
            torch.cuda.empty_cache()
            gc.collect()

        return full_preds, full_labels

    def predict(self, data, batch_size=4, is_label=False):
        """Predict class ids with the averaged logits of all fitted models.

        Args:
            data: DataFrame with a ``text`` column (plus ``entities`` when
                ``is_label`` is true).
            batch_size: Batch size of the inference loader.
            is_label: Also return the label ids built from ``entities``.

        Returns:
            ``(preds, labels)``: ``preds`` is a ``(len(data), max_length)`` array
            of arg-max class ids; ``labels`` is a float array of the same shape
            with the label ids when ``is_label`` is true, all zeros otherwise.

        Raises:
            RuntimeError: If no model has been trained yet.
        """
        if not self.models:
            raise RuntimeError('TextModel.predict() requires fit() to be called first')

        loader = self._make_loader(data, batch_size, is_label)

        preds = np.zeros((len(data), self.max_length, num_labels))
        labels = np.zeros((len(data), self.max_length))

        with torch.no_grad():
            for model in self.models:
                model.eval()
                current_pos = 0
                for batch in tqdm(loader):
                    current_batch_size = len(batch['input_ids'])
                    # Size the slice by the actual batch so a short final batch
                    # lines up.
                    rows = slice(current_pos, current_pos + current_batch_size)
                    preds[rows] += self._batch_logits(model, batch).cpu().numpy()
                    if is_label:
                        labels[rows] = batch['labels'].cpu().numpy()
                    current_pos += current_batch_size

        # Average over the models that actually exist, not the configured count.
        preds = preds / len(self.models)
        return np.argmax(preds, axis=2), labels

In [5]:
# Annotation table and the raw training essays.
df_labels = pd.read_csv('../input/feedback-prize-2021/train.csv')
df_train = read_data('../input/feedback-prize-2021/train')

# Seeded 80/20 split into a cross-validation part and a hold-out part.
df_train, df_val = train_test_split(df_train, test_size=0.2, shuffle=True, random_state=0)

# Both parts get a fresh 0..n-1 index before the per-word tags are attached.
df_train = create_entities(df_train.reset_index(drop=True), df_labels)
df_val = create_entities(df_val.reset_index(drop=True), df_labels)

df_test = read_data('../input/feedback-prize-2021/test')

# Cross-validated training, then inference on the hold-out and the test essays.
model = TextModel(model_name='../input/huggingfacebigbirdrobertabase', max_length=1024)
val_preds, val_labels = model.fit(df_train)
preds, labels = model.predict(df_val, is_label=True)
test_preds, _ = model.predict(df_test)

In [None]:
# Accuracy of the out-of-fold predictions returned by model.fit; positions whose
# label is -100 are skipped inside text_accuracy.
text_accuracy(val_preds, val_labels)

In [None]:
# Convert the hold-out predictions (model.predict on df_val) into
# id / class / predictionstring rows and display the result.
# NOTE(review): test_preds / df_test from the previous cell are not used here;
# the frame is built from df_val - confirm that this is intended before the
# frame is written out as a submission.
submission_df = preds2submission(preds, df_val['id'])
submission_df

In [None]:
# Write the frame to submission.csv in the working directory, without the index column.
submission_df.to_csv('submission.csv', index=False)

In [None]:
"""arr_length = 0
active_preds = dict()
active_labels = dict()
for i in range(len(preds)):
    active_preds[i] = []
    active_labels[i] = []
    for j in range(len(preds[i])):
        if labels[i][j] != -100:
            active_preds[i].append(preds[i][j])
            active_labels[i].append(labels[i][j])
            arr_length += 1

metric_preds = np.zeros(arr_length)
metric_labels = np.zeros(arr_length)
current_pos = 0
for i in range(len(active_preds)):
    for j in range(len(active_preds[i])):
        metric_preds[current_pos] = active_preds[i][j]
        metric_labels[current_pos] = active_labels[i][j]
        current_pos += 1
accuracy_score(metric_labels, metric_preds)"""