In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer, AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm


In [None]:
class Config:
    """Central hyper-parameter and path configuration read by the cells below."""

    # Local copy of the pretrained encoder checkpoint (directory name suggests Longformer-base-4096).
    model_name = '../input/allenailongformerbase4096'
    # Maximum number of sub-word tokens per essay; longer essays are truncated by MyDataset.
    max_length = 1024
    train_batch_size = 4
    valid_batch_size = 4
    epochs = 5
    # Per-epoch learning rates. NOTE(review): not referenced anywhere in the visible code.
    learning_rates = [2e-5, 2e-5, 2e-5, 2e-5, 2e-6]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    seed = 318
    # Loaded once at class-definition time. add_prefix_space=True is presumably required
    # because the text is fed as pre-split words (is_split_into_words=True) to a BPE tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
    # Learning rate given to AdamW in `train`. The previous 1e-3 was 50x the values listed in
    # `learning_rates` and far outside the usual range for fine-tuning a pretrained transformer.
    bert_lr = 2e-5

In [None]:
# Span-level annotations: one row per discourse element.
train_df = pd.read_csv('../input/feedback-prize-2021/train.csv')
# Preview only the first rows instead of rendering the whole annotation table.
train_df.head()

In [None]:
import os

# Load every test essay: one row per file, id = file name without the ".txt" suffix.
TEST_DIR = '../input/feedback-prize-2021/test'
test_names, test_texts = [], []
# os.listdir order is filesystem-dependent; sort so the row order is reproducible.
for fname in sorted(os.listdir(TEST_DIR)):
    test_names.append(fname.replace('.txt', ''))
    # The context manager closes each handle instead of leaking one per file.
    with open(os.path.join(TEST_DIR, fname), encoding='utf-8') as fh:
        test_texts.append(fh.read())
test_texts = pd.DataFrame({"id": test_names, 'text': test_texts})
test_texts.head()

In [None]:
import os

# Load every training essay: one row per file, id = file name without the ".txt" suffix.
TRAIN_DIR = '../input/feedback-prize-2021/train'
train_names, train_texts = [], []
# Sorted so that row order (and therefore IDS and the seeded train/valid split built
# from it later) does not depend on the filesystem's directory listing order.
for fname in sorted(os.listdir(TRAIN_DIR)):
    train_names.append(fname.replace('.txt', ''))
    # The context manager closes each handle instead of leaking one per file.
    with open(os.path.join(TRAIN_DIR, fname), encoding='utf-8') as fh:
        train_texts.append(fh.read())
train_texts = pd.DataFrame({"id": train_names, 'text': train_texts})
train_texts.head()

In [None]:
# Build one word-level BIO tag list per essay from the span annotations in train_df.
# Group the annotations by essay id once, instead of re-filtering the whole of
# train_df for every essay (quadratic in the number of rows).
annotations_by_id = {essay_id: grp for essay_id, grp in train_df.groupby('id')}

all_entities = []
for essay_id, text in tqdm(zip(train_texts['id'], train_texts['text']), total=len(train_texts)):
    # Every whitespace-separated word starts as "outside" any discourse span.
    entities = ['O'] * len(text.split())
    essay_spans = annotations_by_id.get(essay_id)
    if essay_spans is not None:
        # groupby keeps train_df's row order inside each group, so overlapping spans
        # overwrite each other in the same order as before.
        for discourse, prediction in zip(essay_spans['discourse_type'], essay_spans['predictionstring']):
            word_idx = [int(x) for x in prediction.split()]
            if not word_idx:
                continue  # empty predictionstring: nothing to tag
            entities[word_idx[0]] = f'B-{discourse}'
            for k in word_idx[1:]:
                entities[k] = f'I-{discourse}'
    all_entities.append(entities)
train_texts['entities'] = all_entities
    


In [None]:
# Preview essays with their newly attached word-level tags.
train_texts.head()

In [None]:
# BIO tag vocabulary: "O" followed by a B-/I- pair for each discourse type, in this order.
_DISCOURSE_TYPES = ['Lead', 'Position', 'Claim', 'Counterclaim',
                    'Rebuttal', 'Evidence', 'Concluding Statement']
output_labels = ['O'] + [f'{prefix}-{dtype}' for dtype in _DISCOURSE_TYPES for prefix in ('B', 'I')]

# Two-way lookup between tag strings and integer class ids.
labels_to_ids = {label: idx for idx, label in enumerate(output_labels)}
ids_to_labels = dict(enumerate(output_labels))

In [None]:
# If True, every sub-token of a word gets that word's label; otherwise only the first
# sub-token is labelled and the rest get -100.
LABEL_ALL_SUBTOKENS = False


class MyDataset(Dataset):
    """Token-classification dataset over essays with word-level BIO tags.

    Each essay is split on whitespace, tokenized as pre-split words, padded/truncated
    to ``max_len`` sub-tokens, and (for training) its word tags are projected onto
    sub-tokens via ``labels_to_ids``.

    Args:
        dataframe: frame with a ``text`` column and, when ``get_wids`` is False, an
            ``entities`` column holding one tag string per whitespace-separated word.
            Rows are accessed by position, so the index need not be 0..n-1.
        get_wids: if True (inference), no labels are built and each item carries a
            ``wids`` tensor mapping every sub-token to its word index (-1 for tokens
            with no word, i.e. special/padding tokens).
        tokenizer: optional tokenizer override; defaults to ``Config.tokenizer``.
        max_len: optional sequence-length override; defaults to ``Config.max_length``.
    """

    def __init__(self, dataframe, get_wids, tokenizer=None, max_len=None):
        self.data = dataframe
        self.tokenizer = tokenizer if tokenizer is not None else Config.tokenizer
        self.max_len = max_len if max_len is not None else Config.max_length
        self.get_wids = get_wids  # True for validation/inference: no labels

    def __getitem__(self, index):
        """Return the encoded essay at position ``index`` as a dict of tensors."""
        text = self.data['text'].iloc[index]
        word_labels = None if self.get_wids else self.data['entities'].iloc[index]

        encoding = self.tokenizer(text.split(),
                                  is_split_into_words=True,
                                  padding='max_length',
                                  truncation=True,
                                  max_length=self.max_len)
        # Entry i is the index of the word sub-token i came from, or None when the
        # sub-token belongs to no word.
        word_ids = encoding.word_ids()

        if not self.get_wids:
            encoding['labels'] = self._align_labels(word_ids, word_labels)

        item = {key: torch.as_tensor(val) for key, val in encoding.items()}
        if self.get_wids:
            item['wids'] = torch.as_tensor([-1 if w is None else w for w in word_ids])
        return item

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _align_labels(word_ids, word_labels):
        """Project word-level tags onto sub-tokens.

        -100 is nn.CrossEntropyLoss's default ignore_index, so those positions do
        not contribute to the loss used in training.
        """
        label_ids = []
        previous_word_idx = None
        for word_idx in word_ids:
            if word_idx is None:
                label_ids.append(-100)
            elif word_idx != previous_word_idx or LABEL_ALL_SUBTOKENS:
                label_ids.append(labels_to_ids[word_labels[word_idx]])
            else:
                label_ids.append(-100)
            previous_word_idx = word_idx
        return label_ids

In [None]:
import random
IDS = train_texts['id'].unique()  # essay ids in train_texts row order; the split below indexes into this

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and every CUDA device) RNGs and make cuDNN deterministic."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
set_seed(seed=Config.seed)  # seed everything before the random split and model initialisation below

In [None]:
# Hold out 10% of essays (by position in IDS) for validation, drawn from the global NumPy RNG.
n_ids = len(IDS)
all_positions = np.arange(n_ids)
train_idx = np.random.choice(all_positions, int(0.9 * n_ids), replace=False)
valid_idx = np.setdiff1d(all_positions, train_idx)

In [None]:
# Split essays by id into a training frame and a held-out frame (reported as "TEST"
# below, but it is the validation portion of the training data), then wrap both.
data = train_texts[['id', 'text', 'entities']]
in_train = data['id'].isin(IDS[train_idx])
in_valid = data['id'].isin(IDS[valid_idx])
train_frame = data.loc[in_train, ['text', 'entities']].reset_index(drop=True)
valid_frame = data.loc[in_valid].reset_index(drop=True)

for title, frame in (("FULL", data), ("TRAIN", train_frame), ("TEST", valid_frame)):
    print(f"{title} Dataset: {frame.shape}")

train_dataset = MyDataset(train_frame, False)
test_dataset = MyDataset(valid_frame, False)

In [None]:
# Both loaders share worker and pinned-memory settings; only the training data is shuffled.
_common_loader_args = {'num_workers': 2, 'pin_memory': True}
train_dataloader = DataLoader(train_dataset, batch_size=Config.train_batch_size,
                              shuffle=True, **_common_loader_args)
test_dataloader = DataLoader(test_dataset, batch_size=Config.valid_batch_size,
                             shuffle=False, **_common_loader_args)

In [None]:
class Model(nn.Module):
    """Pretrained encoder followed by dropout and a per-token linear classification head.

    ``forward`` returns raw (unnormalised) logits, which is what ``nn.CrossEntropyLoss``
    in the training loop expects; it applies log-softmax internally.
    """

    def __init__(self, num_labels=15):
        super().__init__()
        self.bert = AutoModel.from_pretrained(Config.model_name)
        # Size the head from the encoder's config instead of hard-coding 768.
        self.ffn1 = nn.Linear(self.bert.config.hidden_size, num_labels)
        # Not applied in forward(); available for turning logits into per-token
        # probabilities (over the last dimension) at inference time.
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, attention_mask):
        """Return per-token class logits.

        Args:
            input_ids: token-id tensor, presumably (batch, seq_len).
            attention_mask: mask tensor with the same shape as ``input_ids``.

        Returns:
            Tensor whose last dimension has ``num_labels`` entries, presumably
            (batch, seq_len, num_labels).
        """
        hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask)['last_hidden_state']
        # No softmax here: nn.Softmax() without `dim` normalises a 3-D tensor over dim 0
        # (the batch), and the loss needs unnormalised logits anyway.
        return self.ffn1(self.dropout(hidden))

In [None]:
def train_one_epoch(model, optimizer, scheduler, dataloader, device, epoch):
    """Run one training pass over ``dataloader`` and return the mean per-sample loss.

    The optimizer and LR scheduler are both stepped once per batch. Each batch must
    provide ``input_ids``, ``attention_mask`` and ``labels``; label positions equal to
    -100 are skipped by nn.CrossEntropyLoss (its default ``ignore_index``).

    Returns:
        float: loss averaged over samples, or 0.0 if the dataloader yields no batches.
    """
    model.train()
    # The loss module is stateless, so build it once instead of once per batch.
    loss_fn = nn.CrossEntropyLoss()
    running_loss = 0.0
    data_size = 0
    epoch_loss = 0.0  # defined even when the dataloader is empty
    bar = tqdm(dataloader, total=len(dataloader))
    for batch in bar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        batch_size = input_ids.size(0)

        output = model(input_ids, attention_mask)
        # CrossEntropyLoss expects the class dimension second, so move the model's
        # last (class) axis to position 1.
        loss = loss_fn(output.permute(0, 2, 1), labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        running_loss += loss.item() * batch_size
        data_size += batch_size
        epoch_loss = running_loss / data_size
        bar.set_postfix(EPOCH=epoch, TRAINING_LOSS="{:.6f}".format(epoch_loss))
    return epoch_loss



In [None]:
def valid_one_epoch(model, data_loader, device, epoch):
    """Evaluate ``model`` on ``data_loader`` without gradients and return the mean per-sample loss.

    Each batch must provide ``input_ids``, ``attention_mask`` and ``labels``; label
    positions equal to -100 are skipped by nn.CrossEntropyLoss (default ``ignore_index``).

    Returns:
        float: loss averaged over samples, or 0.0 if the dataloader yields no batches.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    running_loss = 0.0
    data_size = 0
    epoch_loss = 0.0  # defined even when the dataloader is empty
    bar = tqdm(data_loader, total=len(data_loader))
    # One no_grad context for the whole loop instead of re-entering it per batch.
    with torch.no_grad():
        for batch in bar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            batch_size = input_ids.size(0)

            output = model(input_ids, attention_mask)
            # Move the class axis to position 1, as CrossEntropyLoss expects.
            loss = loss_fn(output.permute(0, 2, 1), labels)

            running_loss += loss.item() * batch_size
            data_size += batch_size
            epoch_loss = running_loss / data_size
            bar.set_postfix(EPOCH=epoch, VALIDATION_LOSS="{:.6f}".format(epoch_loss))
    return epoch_loss




In [None]:
import copy
from collections import defaultdict


def train(model, num_epochs, training_dataloader, validation_dataloader, device,
          save_path='model.bin'):
    """Fine-tune ``model`` and keep the weights with the lowest validation loss.

    Whenever the validation loss improves, the model's state_dict is copied in memory
    and written to ``save_path``. After the last epoch those best weights are loaded
    back into ``model``.

    Args:
        model: a ``Model`` instance (must expose ``bert`` and ``ffn1``).
        num_epochs: number of passes over ``training_dataloader``.
        training_dataloader / validation_dataloader: DataLoaders of MyDataset items.
        device: device the batches are moved to.
        save_path: checkpoint file for the best state_dict.

    Returns:
        (model, history): the model holding the best weights, and a dict with
        per-epoch 'TRAINING_LOSS' and 'VALIDATION_LOSS' lists.
    """
    best_model_weights = copy.deepcopy(model.state_dict())
    best_epoch_loss = float('inf')
    history = defaultdict(list)
    optimizer = AdamW([
        {'params': model.bert.parameters()},
        {'params': model.ffn1.parameters()},
    ], lr=Config.bert_lr)

    # The scheduler is stepped once per batch, so its length is batches-per-epoch times
    # the epochs actually run (not Config.epochs), as an integer.
    num_train_steps = len(training_dataloader) * num_epochs
    num_warm_steps = int(num_train_steps * 0.1)
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=num_warm_steps,
                                                num_training_steps=num_train_steps)

    for epoch in range(1, num_epochs + 1):
        train_epoch_loss = train_one_epoch(model, optimizer, scheduler,
                                           training_dataloader, device, epoch)
        valid_epoch_loss = valid_one_epoch(model, validation_dataloader, device, epoch)
        history['TRAINING_LOSS'].append(train_epoch_loss)
        history['VALIDATION_LOSS'].append(valid_epoch_loss)

        if valid_epoch_loss < best_epoch_loss:
            print("The best loss was {}, current loss was {}".format(best_epoch_loss, valid_epoch_loss))
            best_epoch_loss = valid_epoch_loss
            # state_dict must be *called*; passing the bound method would deep-copy and
            # pickle the whole module instead of just its tensors.
            best_model_weights = copy.deepcopy(model.state_dict())
            torch.save(best_model_weights, save_path)

    # Return the best checkpoint rather than whatever the last epoch produced.
    model.load_state_dict(best_model_weights)
    return model, history


In [None]:
model = Model()
model.to(Config.device)
# Keep what train() returns (the best-weight model and the loss history) instead of discarding it.
model, history = train(model, Config.epochs, train_dataloader, test_dataloader, Config.device)
pd.DataFrame(history)

In [None]:
# model = Model()
# model.to(Config.device)


In [None]:
# test_names, test_texts = [], []
# for f in list(os.listdir('../input/feedback-prize-2021/test')):
#     test_names.append(f.replace('.txt',''))
#     test_texts.append(open('../input/feedback-prize-2021/test/'+f).read())

# test_texts = pd.DataFrame({'id':test_names, 'text':test_texts})
# test_dataset = MyDataset(test_texts, True)
# test_texts_loader = DataLoader(test_dataset, 
#                                batch_size=Config.valid_batch_size,
#                                shuffle=False,
#                                num_workers=2,
#                                pin_memory=True
#                                )

In [None]:
# def inference(model, batch):
#     input_ids = batch['input_ids'].to(Config.device)
#     attention_mask = batch['attention_mask'].to(Config.device)
#     outputs = model(input_ids, attention_mask)
    
#     all_preds = torch.argmax(outputs, axis=-1).cpu().numpy()
#     print(all_preds.shape)
#     predictions = []
#     for k, text_preds in enumerate(all_preds):
#         token_preds = [ids_to_labels[i] for i in text_preds]

#         prediction = []
#         word_ids = batch['wids'][k].numpy()
#         previous_word_idx = -1
#         for idx, word_idx in enumerate(word_ids):
#             if word_idx == -1:
#                 pass
            
#             elif word_idx != previous_word_idx:
#                 prediction.append(token_preds[idx])
#                 previous_word_idx = word_idx
#         predictions.append(prediction)
#     return predictions



In [None]:
# def get_predictions(model, df, dataloader):
#     model.eval()
#     y_pred2 = []
#     for i, batch in enumerate(dataloader):
#         labels = inference(model, batch)
#         y_pred2.extend(labels)
#     final_preds2 = []
#     for i in range(len(df)):
#         idx = df.id.values[i]
#         pred = y_pred2[i]
#         preds = []
#         j = 0
#         while j < len(pred):
#             cls = pred[j]
#             if cls == 'O': j += 1
#             else: cls = cls.replace("B","I")
#             end = j + 1

#             while end < len(pred) and pred[end] == cls:
#                 end += 1
#             if cls != "O" and cls != "" and end - j >7:
#                 final_preds2.append((idx,cls.replace("I-","")," ".join(map(str, list(range(j, end))))))
            
#             j = end
        
#     oof = pd.DataFrame(final_preds2)
#     oof.columns = ['id','class','predictionstring']
#     return oof
    


In [None]:
# df = get_predictions(model, test_texts, test_texts_loader)

In [None]:
# tokenizer = AutoTokenizer.from_pretrained("google/bigbird-roberta-base",add_prefix_space=True)
# model = AutoModel.from_pretrained(Config.model_name)

In [None]:
# sentences = "Go Home RIGHT NOW!"
# inputs = tokenizer(sentences.split(),
#                    is_split_into_words=True,
#                    padding='max_length',
# #                    max_length=1024)

In [None]:
# inputs.word_ids()

In [None]:
# output = model(torch.tensor([inputs['input_ids']]),torch.tensor([inputs['attention_mask']]))

In [None]:
# output.keys()