# Credit to
- [CHRIS DEOTTE, PyTorch - BigBird - NER - [CV 0.615]](https://www.kaggle.com/cdeotte/pytorch-bigbird-ner-cv-0-615/notebook)
- [CPMP, Faster Metric Computation](https://www.kaggle.com/cpmpml/faster-metric-computation/notebook)

### Training notebook is here:
[feedback2022_pytorch lightning [Train]](https://www.kaggle.com/fangyu67/feedback2022-pytorch-lightning-train)

### If you find this useful, please upvote :)

In [None]:
from tqdm.auto import tqdm
import os
import random
import numpy as np
import pandas as pd

import gc
pd.set_option('display.max_columns', None)
gc.enable()

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
#from torch.utils.data import RandomSampler, SequentialSampler,TensorDataset
from torch.optim.lr_scheduler import OneCycleLR#,CosineAnnealingLR
#from torch.optim import lr_scheduler

from pytorch_lightning import LightningModule, LightningDataModule,Trainer
from pytorch_lightning.callbacks import ModelCheckpoint,LearningRateMonitor

# transformer
from transformers import AutoTokenizer, AutoModel, AdamW,AutoConfig,AutoModelForTokenClassification


def seed_everything(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA
    devices), and configures cuDNN for reproducible kernel selection.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one; this call is a
    # no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning (benchmark=True) may select different kernels from run to
    # run, which defeats the deterministic flag above, so keep it off.
    torch.backends.cudnn.benchmark = False
    # Exported for child processes; the hash seed of the already-running
    # interpreter is not changed by this assignment.
    os.environ['PYTHONHASHSEED'] = str(seed)


# Data

In [None]:
# https://www.kaggle.com/raghavendrakotala/fine-tunned-on-roberta-base-as-ner-problem-0-533
# Load every test essay into a DataFrame with one row per file:
#   id   -> file name without the '.txt' extension
#   text -> file contents, with a space inserted after every comma
test_dir = '../input/feedback-prize-2021/test'
text_names, test_texts = [], []
for f in tqdm(os.listdir(test_dir)):
    text_names.append(f.replace('.txt', ''))
    # The context manager closes each handle immediately instead of leaving
    # it to the garbage collector.
    with open(os.path.join(test_dir, f), 'r') as fh:
        text = fh.read()
    # NOTE(review): presumably mirrors the preprocessing used at training
    # time so that whitespace splitting yields the same words - confirm.
    text = text.replace(",", ", ")
    test_texts.append(text)
test_text_df = pd.DataFrame({'id': text_names, 'text': test_texts})
test_text_df.head()

In [None]:
# CREATE DICTIONARIES THAT WE CAN USE DURING TRAIN AND INFER
# Label vocabulary: 'O' first, then a B-/I- pair for each discourse type.
# A label's position in this list is its integer class id.
output_labels = [
    'O',
    'B-Lead', 'I-Lead',
    'B-Position', 'I-Position',
    'B-Claim', 'I-Claim',
    'B-Counterclaim', 'I-Counterclaim',
    'B-Rebuttal', 'I-Rebuttal',
    'B-Evidence', 'I-Evidence',
    'B-Concluding Statement', 'I-Concluding Statement',
]

# id -> label (0 -> 'O', 1 -> 'B-Lead', ..., 14 -> 'I-Concluding Statement')
ids_to_labels = dict(enumerate(output_labels))
# label -> id, the inverse of the mapping above
labels_to_ids = {label: idx for idx, label in ids_to_labels.items()}

# Data Class

In [None]:
class Dataset(Dataset):
    """Inference dataset that tokenizes one essay per item.

    NOTE: the class keeps the name ``Dataset`` (shadowing the torch base
    class it inherits from) because ``DataModule`` instantiates it under
    this name.

    Args:
        dataframe: DataFrame with a ``text`` column holding the raw essays.
        tokenizer: Callable tokenizer accepting a list of words together with
            ``is_split_into_words``, ``padding``, ``truncation`` and
            ``max_length`` keyword arguments, and returning a mapping with
            ``input_ids`` and ``attention_mask``.
        max_len: Length every sequence is padded or truncated to.
    """

    def __init__(self, dataframe, tokenizer, max_len):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of essays."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the tokenized essay at row position ``index``.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` as ``torch.long``
            tensors.
        """
        # Positional lookup: a label-based lookup (``self.data.text[index]``)
        # fails when the DataFrame index is not 0..n-1, e.g. after filtering.
        text = self.data['text'].iloc[index]

        # Tokenize on whitespace-separated words so that sub-tokens can be
        # mapped back to word positions afterwards.
        encoding = self.tokenizer(
            text.split(),
            is_split_into_words=True,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
        )

        return {
            'input_ids': torch.tensor(encoding['input_ids'], dtype=torch.long),
            'attention_mask': torch.tensor(encoding['attention_mask'], dtype=torch.long),
        }


class DataModule(LightningDataModule):
    """Lightning data module that serves the test essays for prediction.

    Args:
        test_df: DataFrame of test essays, passed through to ``Dataset``.
        tokenizer: Tokenizer passed through to ``Dataset``.
        cfg: Configuration object; ``max_length``, ``batch_size`` and
            ``num_workers`` are read from it. Despite the ``None`` default,
            a real config is required before ``setup`` is called.
    """

    def __init__(self, test_df, tokenizer, cfg=None):
        super().__init__()
        self.test_df = test_df
        self.cfg = cfg
        self.tokenizer = tokenizer

    def setup(self, stage=None):
        """Create the prediction dataset.

        Args:
            stage: Stage name. ``'predict'`` builds the test dataset, and so
                does ``None`` so that a manual ``setup()`` call works.
                Any other stage (e.g. ``'fit'``) does nothing because this
                module only supports inference.
        """
        if stage in (None, 'predict'):
            self.test_ds = Dataset(self.test_df, self.tokenizer, self.cfg.max_length)

    def predict_dataloader(self):
        """Return a DataLoader over the test dataset in its original order."""
        return DataLoader(
            self.test_ds,
            batch_size=self.cfg.batch_size,
            # Order must be preserved so predictions line up with the rows
            # of the test DataFrame.
            shuffle=False,
            num_workers=self.cfg.num_workers,
            pin_memory=True,
        )

# Model

In [None]:
class ModelModule(LightningModule):
    """Lightning wrapper around a token-classification transformer.

    Args:
        cfg: Mapping of settings; only ``cfg['modelpath']`` is read here. It
            must be a directory path string containing ``config.json`` and
            ``pytorch_model.bin``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        model_dir = self.cfg['modelpath']
        hf_config = AutoConfig.from_pretrained(model_dir + '/config.json')
        self.model = AutoModelForTokenClassification.from_pretrained(
            model_dir + '/pytorch_model.bin',
            config=hf_config,
        )

    def forward(self, input_ids, attention_mask):
        """Run the wrapped model and return its ``logits`` output."""
        return self.model(input_ids=input_ids, attention_mask=attention_mask).logits

    def predict_step(self, batch, batch_idx):
        """Return the logits for one batch as a NumPy array.

        ``batch`` must provide ``input_ids`` and ``attention_mask``.
        The original author notes the logits shape as (N, seq, label).
        """
        token_logits = self(batch['input_ids'], batch['attention_mask'])
        return token_logits.detach().cpu().numpy()

# inference

In [None]:
class CFG:
    """Inference settings for this notebook.

    All values are assigned as instance attributes inside ``__init__`` on
    purpose: the notebook passes ``CFG1.__dict__`` to ``ModelModule``, and
    only instance attributes appear in that dictionary.
    """

    def __init__(self):
        # --- runtime ---
        self.n_procs = 1          # not referenced elsewhere in this notebook
        self.num_workers = 2      # DataLoader worker processes
        self.precision = 16       # forwarded to Trainer(precision=...)
        self.seed = 2022          # forwarded to seed_everything

        # --- model / data ---
        self.modelpath = '../input/py-bigbird-v26'  # holds config.json and pytorch_model.bin
        self.ckppath = '../input/feedback2022/BigBird-ep5-len1024.ckpt'  # checkpoint with 'state_dict'
        self.tokpath = '../input/py-bigbird-v26'    # tokenizer directory
        self.max_length = 1024    # tokenizer padding / truncation length
        self.num_labels = 15      # not referenced elsewhere in this notebook
        self.batch_size = 4       # DataLoader batch size

# Single configuration instance used by the rest of the notebook; seed all
# random number generators from it before any model or data object is built.
CFG1 = CFG()
seed_everything(seed=CFG1.seed)

In [None]:

print(CFG1.ckppath)

# Build the network from the saved config, then restore the fine-tuned
# weights stored under the checkpoint's 'state_dict' key.
model = ModelModule(CFG1.__dict__)
# map_location='cpu' deserializes the tensors onto the CPU instead of the
# device they were saved from. The values are copied into the freshly built
# model, and the Trainer moves the model to the GPU afterwards.
checkpoint = torch.load(CFG1.ckppath, map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
del checkpoint  # release the extra copy of the weights

tokenizer = AutoTokenizer.from_pretrained(CFG1.tokpath, add_prefix_space=True)
test_loader = DataModule(test_text_df, tokenizer, CFG1)

trainer = Trainer(gpus=1, precision=CFG1.precision, num_sanity_val_steps=0)

# trainer.predict returns one array per batch; stack them along the batch axis.
preds = trainer.predict(model, datamodule=test_loader)
preds = np.concatenate(preds)  # (N,seq,label)


#torch.cuda.empty_cache()
#del test_loader,trainer,model   
#gc.collect()
#!free -m

In [None]:
# Most likely class id per token: argmax over the label axis -> (N, seq)
preds_max1 = preds.argmax(axis=2)
# Essay ids, in the same row order as the predictions -> (N)
text_id = test_text_df['id'].to_numpy()
# Re-tokenize each essay with the same settings used by the dataset to get,
# for every token position, the index of the word it came from -> (N, seq)
word_ids = [
    tokenizer(
        essay.split(),
        is_split_into_words=True,
        padding='max_length',
        truncation=True,
        max_length=CFG1.max_length,
    ).word_ids()
    for essay in test_text_df['text']
]

# Build the prediction strings

In [None]:
def get_prediction(preds, word_ids, text_id, id2label=None, min_len=5):
    """Turn token-level class ids into word-span submission rows.

    For each essay, the prediction of the first sub-token of every word is
    taken as the word's label. Consecutive words are then merged into a span
    that starts at any non-'O' label and continues while the following words
    carry the matching 'I-' label.

    Args:
        preds: Per-essay sequences of class ids, one id per token position.
        word_ids: Per-essay sequences giving, for each token position, the
            index of the source word, or ``None`` for tokens that belong to
            no word.
        text_id: Essay ids, aligned with ``preds`` and ``word_ids``.
        id2label: Mapping from class id to label string. Defaults to the
            module-level ``ids_to_labels``.
        min_len: Only spans strictly longer than this many words are kept.

    Returns:
        DataFrame with columns ``id``, ``class`` (label without the 'I-'
        prefix) and ``predictionstring`` (space-separated word indices).
    """
    if id2label is None:
        id2label = ids_to_labels

    columns = ['id', 'class', 'predictionstring']
    # Rows are collected in a list and the frame is built once at the end:
    # DataFrame.append copied the whole frame on every call and no longer
    # exists in pandas 2.x.
    rows = []

    for k, id_ in enumerate(text_id):
        pred_ = [id2label[i] for i in preds[k]]

        # Word-level labels: keep only the first sub-token of each word.
        prediction = []
        previous_word_idx = -1
        for idx, word_idx in enumerate(word_ids[k]):
            if word_idx is not None and word_idx != previous_word_idx:
                prediction.append(pred_[idx])
                previous_word_idx = word_idx

        j = 0
        while j < len(prediction):
            if prediction[j] == 'O':
                j += 1
                continue
            # Swap only the leading tag letter so that a 'B-X' start and its
            # 'I-X' continuation compare equal; a blanket replace('B', 'I')
            # would also rewrite capital B's inside the class name.
            inside_label = 'I' + prediction[j][1:]
            end = j + 1
            while end < len(prediction) and prediction[end] == inside_label:
                end += 1
            if end - j > min_len:
                rows.append([
                    id_,
                    inside_label[2:],  # drop the 'I-' prefix
                    ' '.join(map(str, range(j, end))),
                ])
            j = end

    return pd.DataFrame(rows, columns=columns)
        

In [None]:
# Convert the per-token argmax ids into the word-span submission table.
sub = get_prediction(preds=preds_max1, word_ids=word_ids, text_id=text_id)
#sub.head()

In [None]:
# Write the submission file without the DataFrame index, then preview it.
submission_path = "submission.csv"
sub.to_csv(submission_path, index=False)
sub.head()