In [None]:
import inspect
import json
import os
import pickle
import re
import string
import time
from collections import Counter, defaultdict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from nltk.tokenize import sent_tokenize, word_tokenize
from sklearn.metrics import confusion_matrix, f1_score
from spacy.lang.en.stop_words import STOP_WORDS
from torch.nn.functional import softmax


In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

In [None]:
# Root of the competition inputs; train/test folders below are derived from it.
COMPETITION_DIR = '../input/coleridgeinitiative-show-us-the-data'

config = {
    'label_batch_size': 64,      # LabeledIterator minibatch size
    'unlabel_batch_size': 64,    # NOTE(review): not referenced in the code shown here
    'test_batch_size': 150,      # NOTE(review): presumably the inference DataLoader batch size
    'max_seq_len': 100,          # tokens per sentence after padding/truncation
    'glove_dim': 100,            # width of the GloVe vectors
    'num_labels': 3,             # BIO classes: O, B, I
    'ner_format': "BIO",
    'grad_clip_val': 1,          # max gradient norm in Trainer.train_ops
    'base_lr': 1e-5,             # NOTE(review): not referenced in the code shown here
    'max_lr': 1e-4,              # AdamW lr and OneCycleLR peak lr
    'T': 0.6,                    # NOTE(review): not referenced in the code shown here — confirm purpose
    'weight_decay': 1e-5,
    'num_iterations': 50,        # fine-tuning steps (also OneCycleLR total_steps)
    'print_every': 10,
    'eval_every': 10,
    'overlap_len': 20,           # token overlap between windows of long sentences
    'glove_path': '../input/glove6b/glove.6B.100d.txt',
    'unlabeled_datapath': '../input/dataset/train_data.pkl',
    'train_folder': COMPETITION_DIR + '/train',
    'test_folder': COMPETITION_DIR + '/test',
    'ss_folder': '../input/coleridge-semisuperviseddata',
    'model_path': '../input/pretrain1/best_fscore_model.pt',
    'test_dir': COMPETITION_DIR + '/test',
}

# Model hyper-parameters. Values shared with `config` are derived from it so the
# two dictionaries cannot drift apart.
model_params = {
    "pre_embedd_dim": config['glove_dim'],
    'word_shape_size': 7,        # number of ids produced by get_word_shape (0..6)
    "word_shape_embedd_dim": 20,
    "hdim": 128,
    "proj_dim": 512,
    "out_dim": config['num_labels'],
    'max_seq_len': config['max_seq_len'],
}

In [None]:
# NOTE(review): this repeats the device selection from the setup cell and can be deleted.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

The Jaccard-similarity cleaning used later in this notebook is adapted from https://www.kaggle.com/tungmphung/coleridge-matching-bert-ner

# Utils

In [None]:
# Candidate spans from get_new_entities that contain any of these tokens
# (exact, case-sensitive token match) are discarded.
exclude_entities = [
    'table', 'fig', 'provide',
    'data', 'mri', 'result',
]


def get_word_properties(word):
    """Return a dict of orthographic flags describing ``word``.

    Keys: ``word``, ``is_alpha``, ``is_title``, ``is_upper``, ``is_lower``,
    ``has_upper``, ``is_number``, ``is_stopword``, ``is_punct`` and ``alpha_1``
    (first character is alphabetic).

    ``has_upper`` is True for all-caps and title-case words, and for purely
    alphabetic mixed-case words such as "cDNA".
    ``is_punct`` is a substring test against ``string.punctuation``, so runs
    such as "()" also count.
    Raises IndexError for an empty string.
    """
    is_alpha = word.isalpha()
    is_title = word.istitle()
    is_upper = word.isupper()
    is_lower = word.islower()

    has_upper = is_upper or is_title
    if is_alpha and not has_upper and not is_lower:
        # Alphabetic but neither lower, upper nor title case: look for any capital.
        has_upper = any(ch.isupper() for ch in word)

    return {
        'word': word,
        'is_alpha': is_alpha,
        'is_title': is_title,
        'is_upper': is_upper,
        'is_lower': is_lower,
        'has_upper': has_upper,
        'is_number': word.isnumeric(),
        'is_stopword': word.lower() in STOP_WORDS,
        'is_punct': word in string.punctuation,
        'alpha_1': word[0].isalpha(),
    }


def get_candidate_index(i, word_props, word_len, candidates):
    """Return the candidate tag for token ``i`` given the tag of token ``i - 1``.

    Tags: 1 starts a candidate span, 2 continues one, 0 is outside any span.

    * After a 0 tag, a capitalised non-stop-word starts a span (1).
    * Inside a span (previous tag 1 or 2):
      * '(', ')' and '-' continue it;
      * other punctuation ends it;
      * capitalised words continue it;
      * a lowercase word followed by another lowercase word ends it;
      * a lowercase connector ('in', 'for', 'of', 'the', 'and') followed by a
        capitalised word continues it.
    """
    token = word_props[i]
    previous_tag = candidates[i - 1]
    lowered = token['word'].lower()

    if previous_tag == 0:
        return 1 if token['has_upper'] and not token['is_stopword'] else 0
    if previous_tag not in (1, 2):
        return 0

    if token['is_punct']:
        return 2 if lowered in ('(', ')', '-') else 0
    if token['has_upper']:
        return 2
    if token['is_lower'] and i + 1 < word_len:
        following = word_props[i + 1]
        if following['is_lower']:
            return 0
        if following['has_upper'] and lowered in ('in', 'for', 'of', 'the', 'and'):
            return 2
    return 0
    
def get_candidate_entities(sentence):
    """Tag the tokens of ``sentence`` as candidate dataset-name spans.

    Returns ``(candidates, cwords)``:

    * ``candidates[k]`` is 1 for a span start, 2 for a continuation and 0
      otherwise (see get_candidate_index).
    * ``cwords`` holds the joined span strings that are longer than 2
      characters and do not start with a digit.

    A leading numeric token is dropped before tagging. In that case
    ``candidates`` is one element shorter than ``word_tokenize(sentence)``.
    Sentences of five or fewer tokens get all-zero tags. The span that starts
    at the first word is discarded.
    """
    words = [word.strip() for word in word_tokenize(sentence)]
    words = [word for word in words if len(word) > 0]
    # Guard the empty sentence before looking at the first token.
    if words and words[0].isnumeric():
        words = words[1:]
    words_len = len(words)
    candidates = [0] * words_len
    cwords = []
    if words_len <= 5:
        return (candidates, cwords)

    word_props = [get_word_properties(word) for word in words]

    first = word_props[0]
    if (not first['is_stopword']) and first['has_upper'] and word_props[1]['word'] != ',':
        candidates[0] = 1

    for i in range(1, words_len):
        candidates[i] = get_candidate_index(i, word_props, words_len, candidates)

    # Remove the span that begins at the first word (up to the next span start);
    # presumably because sentence-initial capitalisation is not informative.
    if candidates[0] == 1:
        candidates[0] = 0
        for i in range(1, words_len):
            if candidates[i] == 1:
                break
            candidates[i] = 0

    # Collect the text of each span: s/e are the first/last token of the open span.
    s = e = -1
    for i in range(words_len):
        if candidates[i] == 1 and s == -1:
            s = e = i
        elif candidates[i] == 1:
            # A new span starts immediately: close the open one and reset both ends.
            cwords.append(' '.join(words[s:e + 1]))
            s = e = i
        elif candidates[i] == 2:
            e = i
        elif s != -1 and candidates[i] == 0:
            cwords.append(' '.join(words[s:e + 1]))
            s = e = -1
    if s != -1:
        # Close a span that runs to the end of the sentence.
        cwords.append(' '.join(words[s:e + 1]))

    cwords = [w for w in cwords if len(w) > 2 and not w[0].isnumeric()]
    return (candidates, cwords)


def get_abbrevation(sentence_words, sentence_len, i):
    """Return the parenthesised abbreviation that starts at token ``i``, or ''.

    For tokens '(', 'NCES', ')' starting at ``i`` this returns '(NCES)'.
    Multi-token contents are joined with single spaces, e.g. '(U.S. Census)'.

    Returns '' when:

    * token ``i`` is not '(' (or ``i`` is past the end);
    * there is no closing ')';
    * the parentheses are empty;
    * the content is a single one-character token;
    * the first content token does not start with an uppercase letter.
    """
    if i >= sentence_len or sentence_words[i] != '(':
        return ''
    try:
        j = sentence_words.index(')', i + 1, sentence_len)
    except ValueError:
        # Unclosed parenthesis: not an abbreviation (the rest of the sentence
        # used to be returned here).
        return ''

    abbr = [word.strip() for word in sentence_words[i + 1:j]]
    abbr = [word for word in abbr if len(word) != 0]
    if not abbr:
        return ''
    # Reject a lone one-character token such as '(A)'. This is the presumed
    # intent of the former `abbr[0] == 1`, which compared a str with an int.
    if len(abbr) == 1 and len(abbr[0]) == 1:
        return ''
    if not abbr[0][0].isupper():
        return ''
    return '(' + ' '.join(abbr) + ')'
        

def get_abbrevated_labels(sentence, all_labels):
    """For each label whose tokens occur in ``sentence``, return ``label + ' ' + abbreviation``.

    Only the first token-level occurrence of each label is used. The
    abbreviation comes from get_abbrevation at the token right after the match.
    When no abbreviation follows, the entry is ``label + ' '`` (trailing space).
    NOTE(review): callers may not want those entries. Labels whose tokens never
    appear contiguously in the sentence produce no entry.
    """
    sentence_words = [w for w in word_tokenize(sentence) if w != '']
    sentence_len = len(sentence_words)
    abbr_labels = []
    for label in all_labels:
        label_words = word_tokenize(label)
        n = len(label_words)
        for i in range(sentence_len):
            # A slice comparison is bounds-safe. The element-wise loop raised
            # IndexError when a partial match ran past the end of the sentence.
            if sentence_words[i:i + n] == label_words:
                abbr = get_abbrevation(sentence_words, sentence_len, i + n)
                abbr_labels.append(label + " " + abbr)
                break
    return abbr_labels



def _align_candidates(candidates, sent_words):
    """Return a copy of ``candidates`` with exactly one tag per token of ``sent_words``."""
    words_len = len(sent_words)
    aligned = list(candidates)
    # get_candidate_entities drops a leading numeric token, which shifts its tags
    # by one; put a 0 tag back for that token.
    if words_len > 0 and len(aligned) == words_len - 1 and sent_words[0].isnumeric():
        aligned.insert(0, 0)
    # Any tags still missing are treated as 0 (non-candidate); extras are dropped.
    return aligned[:words_len] + [0] * (words_len - len(aligned))


def _mark_label_occurrences(sent_words, all_labels):
    """Return per-token ids: each label match gets a fresh id (1, 2, ...); 0 means no label."""
    entity_markers = [0] * len(sent_words)
    ent_id = 1
    for label in all_labels:
        lwords = word_tokenize(label)
        n = len(lwords)
        for i in range(len(sent_words)):
            # A slice comparison cannot run past the end of the sentence.
            if sent_words[i:i + n] == lwords:
                for k in range(i, i + n):
                    entity_markers[k] = ent_id
                ent_id += 1
    return entity_markers


def _entity_before(i, sent_words, entity_markers, candidates):
    """Candidate span left of a ',' that directly precedes label token ``i``; None if there is none."""
    if not (i - 1 >= 0 and entity_markers[i - 1] == 0 and sent_words[i - 1] == ','):
        return None
    # `tol` drops when a non-candidate token (other than ')') is seen before the
    # span's last token is found. Once it reaches 0 the scan is abandoned.
    tol = 1
    s = e = -1
    for j in range(i - 2, -1, -1):
        if tol == 0 or entity_markers[j] != 0:
            s = e = -1
            break
        if e == -1 and candidates[j] == 0 and sent_words[j] != ')':
            tol -= 1
        if s != -1 and candidates[j] == 1:
            s = j
            break
        if e == -1 and candidates[j] != 0:
            e = s = j
            tol = 2
            if candidates[j] == 1:
                break
        if candidates[j] == 2:
            s = j
    return sent_words[s:e + 1] if s != -1 else None


def _entity_after(i, sent_words, entity_markers, candidates):
    """Candidate span right of a ',' or 'and' that directly follows label token ``i``; None if there is none."""
    words_len = len(sent_words)
    if not (i + 1 < words_len and entity_markers[i + 1] == 0
            and sent_words[i + 1] in (',', 'and')):
        return None
    tol = 1
    s = e = -1
    for j in range(i + 2, words_len):
        if tol == 0 or entity_markers[j] != 0:
            s = e = -1
            break
        if s == -1 and candidates[j] == 0 and sent_words[j] != '(':
            tol -= 1
        if s == -1 and candidates[j] != 0:
            # The first candidate token opens the span. It must not also be tested
            # as a closer: that produced an empty span whenever it was tagged 1.
            s = e = j
            tol = 2
        elif s != -1 and candidates[j] in (0, 1):
            # A non-candidate token or a new span start closes the current span.
            e = j - 1
            break
        else:
            e = j
    return sent_words[s:e + 1] if s != -1 else None


def get_new_entities(candidates, all_labels, sentence):
    """Find candidate spans listed next to known labels in ``sentence``.

    For every token covered by a label from ``all_labels``, look for a span of
    candidate tokens:

    * just before a ',' that precedes the label, or
    * just after a ',' / 'and' that follows it.

    ``candidates`` are the tags get_candidate_entities produced for the same
    sentence.

    Returns a list of token lists. Empty spans, and spans containing any token
    from ``exclude_entities``, are dropped. Errors are no longer silently
    swallowed; the out-of-range lookups that used to raise are now bounds-safe.
    """
    if len(all_labels) == 0:
        return []
    sent_words = word_tokenize(sentence)
    aligned = _align_candidates(candidates, sent_words)
    entity_markers = _mark_label_occurrences(sent_words, all_labels)

    new_ents = []
    for i, marker in enumerate(entity_markers):
        if marker == 0:
            continue
        for ent in (_entity_before(i, sent_words, entity_markers, aligned),
                    _entity_after(i, sent_words, entity_markers, aligned)):
            if ent:
                new_ents.append(ent)

    # Whole-token, case-sensitive exclusion.
    return [ent for ent in new_ents if not any(f in ent for f in exclude_entities)]

def get_datalabels(pub_id):
    """Collect known dataset labels and nearby candidate entities for one test publication.

    Returns ``(all_labels, new_entities)``:

    * ``all_labels`` lists (with repeats) every entry of ``database`` found as a
      substring of a sentence, followed by that sentence's abbreviated forms.
    * ``new_entities`` are token lists from get_new_entities, computed only for
      sentences mentioning 'data' or 'study'. get_new_entities receives every
      label collected so far, not just the current sentence's labels.
    """
    all_labels = []
    new_entities = []
    for section in get_publications_data(pub_id, config['test_dir']):
        for sentence in sent_tokenize(section['text']):
            sentence_labels = [label for label in database if label in sentence]
            if not sentence_labels:
                continue
            abbreviated = get_abbrevated_labels(sentence, sentence_labels)
            all_labels.extend(sentence_labels)
            all_labels.extend(abbreviated)

            # Keyword filter taken from https://www.kaggle.com/tungmphung/pytorch-bert-for-named-entity-recognition/comments
            lowered = sentence.lower()
            if 'data' in lowered or 'study' in lowered:
                candidates, _ = get_candidate_entities(sentence)
                new_entities.extend(get_new_entities(candidates, all_labels, sentence))
    return (all_labels, new_entities)


In [None]:
def update_annotation(annots, s, l):
    """Write a BIO span of length ``l`` starting at index ``s`` into ``annots`` in place.

    The first position becomes 'B' and the rest 'I'. Positions already tagged
    'I' are left unchanged, so a span nested inside an earlier one does not
    introduce a 'B'.
    """
    for pos in range(s, s + l):
        if annots[pos] != 'I':
            annots[pos] = 'B' if pos == s else 'I'

def get_annotated_data(data):
    """Tokenise ``data['sentence']`` and BIO-tag every occurrence of ``data['labels']``.

    A missing, None or empty ``'labels'`` entry means no labels.
    Returns ``(words, annots)``, where ``annots`` holds one 'B'/'I'/'O' tag per token.
    """
    sentence = data['sentence']
    raw_labels = data.get('labels', None) or []
    words = word_tokenize(sentence)
    annots = ['O'] * len(words)
    tokenized_labels = [word_tokenize(label) for label in raw_labels]

    for start in range(len(words)):
        for label_words in tokenized_labels:
            if words[start:start + len(label_words)] == label_words:
                update_annotation(annots, start, len(label_words))
    return (words, annots)

def read_file(filepath):
    """Return the whole contents of the text file at ``filepath``."""
    with open(filepath) as handle:
        return handle.read()

def read_json_file(filepath):
    """Parse the JSON document stored at ``filepath`` and return the result."""
    return json.loads(read_file(filepath))
def read_pickle(filepath):
    """Unpickle and return the object stored at ``filepath``.

    Security: pickle.load can execute arbitrary code, so only use this on trusted files.
    """
    with open(filepath, 'rb') as handle:
        return pickle.load(handle)

def get_word_shape(word):
    """Map a token to a small integer id based on its first character.

    1 = uppercase, 2 = lowercase, 3 = numeric, 4 = ',', 5 = '(', 6 = ')'.
    0 covers empty strings and anything else.
    """
    if len(word) == 0:
        return 0
    first = word[0]
    if first.isupper():
        return 1
    if first.islower():
        return 2
    if first.isnumeric():
        return 3
    return {',': 4, '(': 5, ')': 6}.get(first, 0)


def get_publications_data(pub_id, dirname):
    """Load and return the JSON document stored at ``<dirname>/<pub_id>``."""
    path = dirname + '/' + '{}'.format(pub_id)
    with open(path) as handle:
        return json.load(handle)

def get_short_sentences(sentence):
    """Split a long sentence into overlapping token windows.

    Sentences shorter than ``config['max_seq_len']`` tokens are returned
    unchanged as ``[sentence]``. Longer ones are cut into windows of
    ``max_seq_len`` tokens with stride ``max_seq_len - overlap_len``. Each
    window is re-joined with single spaces, and windows of 10 or fewer tokens
    are skipped.
    """
    tokens = word_tokenize(sentence)
    window = config['max_seq_len']
    if len(tokens) < window:
        return [sentence]

    stride = window - config['overlap_len']
    chunks = (tokens[start:start + window] for start in range(0, len(tokens), stride))
    return [' '.join(chunk) for chunk in chunks if len(chunk) > 10]

def is_eligible_sentence(sentence_words):
    """True when the sentence has more than 5 lowercase tokens and fewer than 300 tokens.

    The lowercase requirement filters out tables, reference lists and other
    mostly-capitalised fragments.
    """
    lowercase_tokens = sum(1 for token in sentence_words if token.islower())
    return lowercase_tokens > 5 and len(sentence_words) < 300

def get_sentences(pub_id, text):
    """Return tokenised sentence windows of ``text`` that are worth running inference on.

    Each sentence is first split into windows by get_short_sentences. A window
    is kept when is_eligible_sentence accepts it and its lowercase text
    mentions 'study' or 'data'. ``pub_id`` is not used.
    """
    windows = [chunk
               for sentence in sent_tokenize(text)
               for chunk in get_short_sentences(sentence)]

    infer_sentences = []
    for chunk in windows:
        tokens = word_tokenize(chunk)
        if not is_eligible_sentence(tokens):
            continue
        lowered = chunk.lower()
        if 'study' in lowered or 'data' in lowered:
            infer_sentences.append(tokens)
    return infer_sentences

# Processing Data

In [None]:
class PreprocessData:
    """Load GloVe vectors, the labeled sentences and the unlabeled publications.

    All paths come from the ``config`` mapping given to the constructor.
    """
    def __init__(self, config):
        self.train_folder = config['train_folder']
        self.ss_folder = config['ss_folder']
        self.glove_path = config['glove_path']
        self.unlabled_path = config['unlabeled_datapath']

    def download_glove(self):
        """Parse the GloVe text file at ``glove_path`` into a ``{word: float64 vector}`` dict.

        Despite the name, nothing is downloaded. Blank lines are skipped.
        """
        glove_embeddings = {}
        # GloVe files are UTF-8; be explicit so parsing does not depend on the locale.
        with open(self.glove_path, encoding='utf-8') as file:
            for line in file:
                parts = line.split()
                if not parts:
                    continue
                word, values = parts[0], parts[1:]
                # `np.float` was removed in NumPy 1.24. It was an alias of the
                # builtin float, i.e. float64.
                glove_embeddings[word] = np.array(values).astype(np.float64)
        return glove_embeddings

    def get_labeled_data(self):
        """Return ``(pos_data, neg_data)`` record lists from the semi-supervised folder."""
        pos_data = read_json_file(os.path.join(self.ss_folder, 'annotation.txt'))
        neg_data = read_json_file(os.path.join(self.ss_folder, 'negative_sentences_corrected.txt'))
        return (pos_data, neg_data)

    def get_data(self):
        """Return ``(glove_embeddings, labeled_data, unlabeled_data)``.

        * ``labeled_data`` is ``{'pos_data': [...], 'neg_data': [...]}`` of
          ``(words, annots)`` pairs.
        * ``unlabeled_data`` lists the unlabeled records whose ``pub_id`` does
          not appear in either labeled set.
        """
        (pos_data, neg_data) = self.get_labeled_data()
        unlabeled_data = read_pickle(self.unlabled_path)

        # Keep publications that already have labels out of the unlabeled pool.
        for records in (pos_data, neg_data):
            for record in records:
                pub_id = record.get('pub_id', None)
                if unlabeled_data.get(pub_id, None):
                    unlabeled_data.pop(pub_id)

        unlabeled_data = list(unlabeled_data.values())
        labeled_data = {
            'pos_data': [get_annotated_data(record) for record in pos_data],
            'neg_data': [get_annotated_data(record) for record in neg_data],
        }
        glove_embeddings = self.download_glove()
        return (glove_embeddings, labeled_data, unlabeled_data)

# Load GloVe vectors, the labeled positive/negative sentences and the unlabeled
# publications that are not in either labeled set.
glove_embeddings, labeled_data, unlabeled_data = PreprocessData(config).get_data()

# Dataset & DataIterators

In [None]:
class LabeledIterator:
    """Serve padded tensor minibatches of annotated sentences.

    ``labeled_data`` holds ``'pos_data'`` and ``'neg_data'`` lists of
    ``(words, annots)`` pairs, where annots are 'B'/'I'/'O' tags. The last
    positive and the last negative record form the 'val' split; everything
    else is the 'train' split.
    """
    def __init__(self, config, glove_embeddings, labeled_data):
        self.batch_size = config['label_batch_size']
        self.glove_dim = config['glove_dim']
        self.num_labels = config['num_labels']
        self.max_seq_len = config['max_seq_len']
        self.glove_embeddings = glove_embeddings
        self.labeled_data = labeled_data

        pos_data = labeled_data['pos_data']
        neg_data = labeled_data['neg_data']
        # Slicing selects the same records as the former index-membership
        # filters, without their O(n^2) `i in list` checks.
        self.train_data = list(pos_data[:-1]) + list(neg_data[:-1])
        self.val_data = list(pos_data[-1:]) + list(neg_data[-1:])

        self.num_train_records = len(self.train_data)
        self.num_val_records = len(self.val_data)

    def make_shuffle_data(self, data, shuffle):
        """Shuffle ``data`` in place when ``shuffle`` is true."""
        if shuffle:
            np.random.shuffle(data)

    def _num_records(self, mode):
        """Number of records in split ``mode``; raises ValueError for an unknown split."""
        if mode == 'val':
            return self.num_val_records
        if mode == 'train':
            return self.num_train_records
        raise ValueError("mode must be 'train' or 'val', got {!r}".format(mode))

    def get_data_by_mode(self, mode, shuffle):
        """Return ``(num_records, records)`` for split ``mode``, shuffling it in place first if asked."""
        num_records = self._num_records(mode)
        data = self.train_data if mode == 'train' else self.val_data
        self.make_shuffle_data(data, shuffle)
        return num_records, data

    def convert_annotation_to_label(self, annot):
        """Map a BIO tag to its class id: 'B' -> 1, 'I' -> 2, anything else -> 0."""
        if annot == 'B':
            return 1
        elif annot == 'I':
            return 2
        return 0

    def convert_rawdata_to_tensors(self, mbs):
        """Pad or truncate a list of ``(words, annots)`` pairs to ``max_seq_len`` and build tensors.

        Returns ``(X, X_embedd, y_ents, y_bios, slens)``:

        * ``X``: float32 (batch, max_seq_len, glove_dim) GloVe vectors, zero for
          OOV words and padding.
        * ``X_embedd``: long (batch, max_seq_len) word-shape ids.
        * ``y_ents``: float32 (batch, max_seq_len), 1.0 on B/I tokens.
        * ``y_bios``: long (batch, max_seq_len), 0=O, 1=B, 2=I, 3=padding.
        * ``slens``: long (batch,) unpadded lengths, capped at max_seq_len.
        """
        mb_size = len(mbs)
        X = torch.zeros((mb_size, self.max_seq_len, self.glove_dim), dtype=torch.float32)
        X_embedd = torch.zeros((mb_size, self.max_seq_len), dtype=torch.long)
        y_bios = torch.full((mb_size, self.max_seq_len), 3, dtype=torch.long)
        y_ents = torch.zeros((mb_size, self.max_seq_len), dtype=torch.float32)
        slens = torch.zeros(mb_size, dtype=torch.long)

        for i, (words, annots) in enumerate(mbs):
            words_len = min(len(words), self.max_seq_len)
            slens[i] = words_len
            for j in range(words_len):
                X_embedd[i, j] = get_word_shape(words[j])
                label = self.convert_annotation_to_label(annots[j])
                y_bios[i, j] = label
                if label in (1, 2):
                    y_ents[i, j] = 1
                word = words[j].lower()
                if word in self.glove_embeddings:
                    X[i, j] = torch.tensor(self.glove_embeddings[word], dtype=torch.float32)
        return (X, X_embedd, y_ents, y_bios, slens)

    def get_raw_minibatch(self, mode='val', shuffle=False):
        """Yield successive ``batch_size`` slices of the raw records in split ``mode`` (one pass)."""
        num_records, data = self.get_data_by_mode(mode, shuffle)
        for i in range(0, num_records, self.batch_size):
            yield data[i:i + self.batch_size]

    def get_minibatch(self, mode='val', shuffle=False):
        """Yield tensor minibatches for one pass over split ``mode``."""
        for mbs in self.get_raw_minibatch(mode, shuffle):
            yield self.convert_rawdata_to_tensors(mbs)

    def get_infinite_minibatch(self, mode='val', shuffle=False):
        """Cycle over split ``mode`` forever, reshuffling on every pass when ``shuffle`` is set."""
        if self._num_records(mode) == 0:
            # Without this guard the loop below would spin forever without yielding.
            raise ValueError("no {} records to iterate over".format(mode))
        while True:
            yield from self.get_minibatch(mode, shuffle)

    def __iter__(self):
        """Endless shuffled training minibatches."""
        yield from self.get_infinite_minibatch('train', True)

In [None]:
class TestDataset(torch.utils.data.Dataset):
    """Map-style dataset that turns pre-tokenised sentences into model inputs.

    ``df`` needs a ``sentence_words`` column of token lists. Each item is
    ``(slen, X, X_embedd)``:

    * ``slen``: the sentence length capped at ``max_seq_len``.
    * ``X``: (max_seq_len, glove_dim) GloVe vectors, zero for OOV words and padding.
    * ``X_embedd``: (max_seq_len,) word-shape ids.
    """
    def __init__(self, df, max_seq_len, glove_dim, glove_embeddings=None):
        self.df = df
        self.max_seq_len = max_seq_len
        self.glove_dim = glove_dim
        # Optional explicit embedding table. None falls back to the notebook-global
        # `glove_embeddings`, as before.
        self.glove_embeddings = glove_embeddings

    def __getitem__(self, idx):
        sentence_words = self.df.iloc[idx].sentence_words
        embeddings = self.glove_embeddings if self.glove_embeddings is not None else glove_embeddings

        X = torch.zeros((self.max_seq_len, self.glove_dim))
        X_embedd = torch.zeros(self.max_seq_len, dtype=torch.long)
        slen = torch.tensor(min(self.max_seq_len, len(sentence_words)), dtype=torch.long)
        for i, word in enumerate(sentence_words[:self.max_seq_len]):
            X_embedd[i] = get_word_shape(word)
            lowered = word.lower()
            if lowered in embeddings:
                X[i] = torch.tensor(embeddings[lowered], dtype=torch.float32)
        return (slen, X, X_embedd)

    def __len__(self):
        return len(self.df)

# Model

In [None]:
class PrimaryHead(nn.Module):
    """Per-token BIO classifier over (batch, seq, 2*hdim) features -> (batch, seq, out_dim) logits."""
    def __init__(self, params):
        super().__init__()
        # Keep submodule names and creation order: checkpoints are loaded as whole
        # pickled modules, and the order fixes the random initialisation.
        hidden = params['proj_dim']
        self.linear = nn.Linear(2 * params['hdim'], hidden)
        self.bn = nn.BatchNorm1d(hidden)
        self.dropout = nn.Dropout(0.3)
        self.relu = nn.ReLU()
        self.out_layer = nn.Linear(hidden, params['out_dim'])

    def forward(self, x):
        # BatchNorm1d normalises dim 1, so move the feature axis there for bn/dropout/relu.
        features = self.linear(x).permute(0, 2, 1)
        features = self.relu(self.dropout(self.bn(features)))
        return self.out_layer(features.permute(0, 2, 1))
    
class AuxilaryHead(nn.Module):
    """Per-token entity/non-entity scorer: (batch, seq, 2*hdim) features -> (batch, seq, 1) logits."""
    def __init__(self, params):
        super().__init__()
        # Keep submodule names and creation order: checkpoints are loaded as whole
        # pickled modules, and the order fixes the random initialisation.
        hidden = params['proj_dim']
        self.linear = nn.Linear(2 * params['hdim'], hidden)
        self.bn = nn.BatchNorm1d(hidden)
        self.dropout = nn.Dropout(0.3)
        self.relu = nn.ReLU()
        self.out_layer = nn.Linear(hidden, 1)

    def forward(self, x):
        # BatchNorm1d normalises dim 1, so move the feature axis there for bn/dropout/relu.
        features = self.linear(x).permute(0, 2, 1)
        features = self.relu(self.dropout(self.bn(features)))
        return self.out_layer(features.permute(0, 2, 1))

class Model(nn.Module):
    """Bidirectional 2-layer GRU tagger over GloVe vectors plus word-shape embeddings.

    Two heads read the GRU features: AuxilaryHead gives one entity logit per
    token, and PrimaryHead gives BIO logits.
    """
    def __init__(self, params):
        super().__init__()
        # Keep attribute names and creation order for pickled-checkpoint compatibility.
        self.params = params
        self.embedd_layer = nn.Embedding(params['word_shape_size'], params['word_shape_embedd_dim'],
                                         max_norm=1, padding_idx=0)
        gru_input_dim = params['pre_embedd_dim'] + params['word_shape_embedd_dim']
        self.gru = nn.GRU(gru_input_dim, params['hdim'], num_layers=2,
                          bidirectional=True, dropout=0.3, batch_first=True)
        self.primary_head = PrimaryHead(params)
        self.aux_head = AuxilaryHead(params)

    def forward(self, x, x_embedd):
        """Run the tagger.

        ``x`` is (batch, seq, pre_embedd_dim) floats; ``x_embedd`` is (batch, seq)
        word-shape ids. Returns ``(entity_logits, bio_logits)``.
        """
        shape_vectors = self.embedd_layer(x_embedd)
        # With batch_first=True the bidirectional output is already (batch, seq, 2*hdim),
        # forward and backward states concatenated; splitting and re-concatenating the
        # two directions is an identity, so the heads consume it directly.
        features, _ = self.gru(torch.cat([x, shape_vectors], dim=-1))
        return (self.aux_head(features), self.primary_head(features))

# Evaluate

In [None]:
def evaluate(model, mode='train'):
    """Score ``model`` on one split of the labeled data.

    Args:
        model: a Model returning ``(entity_logits, bio_logits)``.
        mode: the LabeledIterator split to score. Defaults to ``'train'``, the
            previously hard-coded behaviour. NOTE(review): that is the training
            split, so the scores are not held-out. The ``'val'`` split holds only
            the last positive and last negative record.

    Returns:
        ``(cm_entity, cm_bio, micro_f1_entity, micro_f1_bio, macro_f1_entity, macro_f1_bio)``,
        computed over non-padding tokens only.
    """
    true_ent_labels, pred_ent_labels = [], []
    true_bio_labels, pred_bio_labels = [], []

    model.eval()
    # Batches are built on the CPU. Send them to wherever the model lives (it is
    # moved to `device` after loading) and bring predictions back for sklearn.
    model_device = next(model.parameters()).device
    labeled_iterator = LabeledIterator(config, glove_embeddings, labeled_data)
    with torch.no_grad():
        for X, X_embedd, y_ents, y_bios, _ in labeled_iterator.get_minibatch(mode=mode, shuffle=False):
            yhat_ent, yhat_bios = model(X.to(model_device), X_embedd.to(model_device))
            # One entity logit per token: > 0 means probability > 0.5.
            yhat_ent = (yhat_ent > 0).to(int).cpu().view(-1)
            yhat_bios = yhat_bios.argmax(dim=-1).cpu().view(-1)
            y_ents = y_ents.view(-1)
            y_bios = y_bios.view(-1)

            # BIO label 3 marks padding beyond each sentence's length.
            valid = y_bios != 3
            true_ent_labels += list(y_ents[valid].numpy())
            pred_ent_labels += list(yhat_ent[valid].numpy())
            true_bio_labels += list(y_bios[valid].numpy())
            pred_bio_labels += list(yhat_bios[valid].numpy())

    cm1 = confusion_matrix(true_ent_labels, pred_ent_labels)
    cm2 = confusion_matrix(true_bio_labels, pred_bio_labels)

    micro_fscore_ent = f1_score(true_ent_labels, pred_ent_labels, average='micro')
    micro_fscore_bio = f1_score(true_bio_labels, pred_bio_labels, average='micro')

    macro_fscore_ent = f1_score(true_ent_labels, pred_ent_labels, average='macro')
    macro_fscore_bio = f1_score(true_bio_labels, pred_bio_labels, average='macro')

    return (cm1, cm2, micro_fscore_ent, micro_fscore_bio, macro_fscore_ent, macro_fscore_bio)

# Fine-Tuning

In [None]:
class Trainer:
    """Fine-tune a Model on the labeled sentences.

    Each step averages three losses:

    * a label-smoothed entity loss on the auxiliary head;
    * a label-smoothed BIO loss on the primary head;
    * a consistency loss that ties the two heads together.

    Relies on the notebook globals ``config``, ``glove_embeddings``,
    ``labeled_data`` and ``evaluate``.
    """

    # Soft BIO targets: row = gold label (0=O, 1=B, 2=I), column = weight on each predicted class.
    _BIO_TARGET_WEIGHTS = (
        (0.95, 0.025, 0.025),
        (0.02, 0.9, 0.08),
        (0.02, 0.08, 0.9),
    )

    def __init__(self, model):
        self.model = model
        self.iter_count = 0
        # Running sums since the last progress print.
        self.total_loss = 0.0
        self.ent_loss = 0.0
        self.bios_loss = 0.0
        self.constraint_loss = 0.0
        # Per-iteration history.
        self.loss_ = []
        self.ent_loss_ = []
        self.bios_loss_ = []
        self.constraint_loss_ = []

        # NOTE(review): criterion1/criterion2 are not used by the custom losses below.
        self.criterion1 = nn.BCEWithLogitsLoss(reduction='mean')
        self.criterion2 = nn.CrossEntropyLoss(ignore_index=3, reduction='mean')

        self.logsoftmax = nn.LogSoftmax(dim=-1)
        self.logsigmoid = nn.LogSigmoid()
        self.sigmoid = nn.Sigmoid()

        self.optimizer = torch.optim.AdamW(model.parameters(), lr=config['max_lr'],
                                           weight_decay=config['weight_decay'])
        self.schedular = torch.optim.lr_scheduler.OneCycleLR(self.optimizer,
                                                             max_lr=config['max_lr'],
                                                             total_steps=config['num_iterations'])
        self.labeledIterator = iter(LabeledIterator(config, glove_embeddings, labeled_data))

    @staticmethod
    def _valid_mask(slens, seq_len, device):
        """Boolean (batch, seq_len) mask that is True at positions inside each sentence."""
        positions = torch.arange(seq_len, device=device)
        return positions.unsqueeze(0) < slens.to(device).unsqueeze(1)

    def get_entity_loss(self, slens, y_ents, logyhat_ent0, logyhat_ent1):
        """Label-smoothed binary entity loss.

        Non-entity tokens get target weights 0.95 on "no entity" and 0.05 on
        "entity"; entity tokens get the reverse. Returns the mean over entity
        tokens plus the mean over non-entity tokens (each mean is 0 when there
        are no such tokens). ``logyhat_ent0``/``logyhat_ent1`` are
        (batch, seq, 1) log-probabilities.
        """
        log0 = logyhat_ent0.squeeze(-1)
        log1 = logyhat_ent1.squeeze(-1)
        valid = self._valid_mask(slens, y_ents.shape[1], y_ents.device)
        neg = valid & (y_ents == 0)
        pos = valid & (y_ents == 1)
        neg_loss = -(0.95 * log0 + 0.05 * log1)[neg].sum() / neg.sum().clamp(min=1)
        pos_loss = -(0.95 * log1 + 0.05 * log0)[pos].sum() / pos.sum().clamp(min=1)
        return pos_loss + neg_loss

    def get_bios_loss(self, slens, y_bios, logyhat_bios):
        """Label-smoothed BIO loss.

        Each in-sentence token contributes ``-sum_k w[y, k] * log p_k``, using
        the weights in ``_BIO_TARGET_WEIGHTS``. Returns the mean over B/I tokens
        plus the mean over O tokens.

        This also fixes two bugs in the old loop:

        * its all-O shortcut indexed the first token instead of class 0;
        * it counted I tokens twice in the positive denominator.
        """
        valid = self._valid_mask(slens, y_bios.shape[1], y_bios.device)
        weights = torch.tensor(self._BIO_TARGET_WEIGHTS, dtype=logyhat_bios.dtype,
                               device=logyhat_bios.device)
        # Padding label 3 has no weight row: clamp it and rely on the masks to exclude it.
        token_loss = -(weights[y_bios.clamp(max=2)] * logyhat_bios).sum(dim=-1)
        neg = valid & (y_bios == 0)
        pos = valid & ((y_bios == 1) | (y_bios == 2))
        neg_loss = token_loss[neg].sum() / neg.sum().clamp(min=1)
        pos_loss = token_loss[pos].sum() / pos.sum().clamp(min=1)
        return pos_loss + neg_loss

    def get_constraint_loss(self, slens, logyhat_bios, logyhat_ent0, logyhat_ent1):
        """Penalise disagreement between the BIO head and the entity head.

        The per-token penalty is
        ``|log P_bio(O) - log P_ent(0)| + |log(P_bio(B) + P_bio(I)) - log P_ent(1)|``.
        It is averaged over each sentence's valid tokens, then over the batch.
        """
        batch_size, seq_len = logyhat_bios.shape[0], logyhat_bios.shape[1]
        valid = self._valid_mask(slens, seq_len, logyhat_bios.device)
        # B and I are disjoint softmax classes, so the BIO head's P(entity) is
        # P(B) + P(I). logaddexp gives its log; adding the log-probs (as before)
        # gave log(P(B) * P(I)).
        log_entity_bio = torch.logaddexp(logyhat_bios[..., 1], logyhat_bios[..., 2])
        diff = (torch.abs(logyhat_bios[..., 0] - logyhat_ent0[..., 0])
                + torch.abs(log_entity_bio - logyhat_ent1[..., 0]))
        diff = diff.masked_fill(~valid, 0.0)
        per_sentence = diff.sum(dim=1) / slens.to(diff.device).clamp(min=1)
        return per_sentence.sum() / max(batch_size, 1)

    def train_ops(self, mbs):
        """Run one optimisation step on minibatch ``mbs``; return the four loss values as floats."""
        self.model.train()
        # Batches are built on the CPU; move them next to the model's parameters.
        model_device = next(self.model.parameters()).device
        X, X_embedd, y_ents, y_bios, slens = (t.to(model_device) for t in mbs)

        yhat_ent, yhat_bios = self.model(X, X_embedd)
        # logsigmoid is the numerically stable log(sigmoid(x)), and
        # log(1 - sigmoid(x)) == logsigmoid(-x).
        logyhat_ent1 = self.logsigmoid(yhat_ent)
        logyhat_ent0 = self.logsigmoid(-yhat_ent)
        logyhat_bios = self.logsoftmax(yhat_bios)

        loss_ents = self.get_entity_loss(slens, y_ents, logyhat_ent0, logyhat_ent1)
        loss_bios = self.get_bios_loss(slens, y_bios, logyhat_bios)
        loss_constraint = self.get_constraint_loss(slens, logyhat_bios, logyhat_ent0, logyhat_ent1)
        loss = (loss_ents + loss_bios + loss_constraint) / 3

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), config['grad_clip_val'])
        self.optimizer.step()
        self.schedular.step()
        return (loss.item(), loss_ents.item(), loss_bios.item(), loss_constraint.item())

    def _report_evaluation(self):
        """Evaluate the model being trained and print confusion matrices and F-scores."""
        print("Evaluating:")
        print("======" * 10)
        (cm1, cm2, micro_fscore_ent, micro_fscore_bio,
         macro_fscore_ent, macro_fscore_bio) = evaluate(self.model)
        print("Confusion Matrix:")
        print(cm1)
        print(cm2)
        print("Micro F-Score ==> {:.3f} -- {:.3f}".format(micro_fscore_ent, micro_fscore_bio))
        print("Macro F-Score ==> {:.3f} -- {:.3f}".format(macro_fscore_ent, macro_fscore_bio))
        print("======" * 10)
        print()

    def finetune(self):
        """Train for ``config['num_iterations']`` steps.

        Every ``eval_every`` steps it evaluates and saves to 'model.pt'; every
        ``print_every`` steps it prints averaged losses. It evaluates once more
        at the end.
        """
        t1 = time.time()
        self.model.train()
        while self.iter_count < config['num_iterations']:
            self.iter_count += 1
            mbs = next(self.labeledIterator)
            (loss, loss_ents, loss_bios, loss_constraint) = self.train_ops(mbs)

            self.total_loss += loss
            self.ent_loss += loss_ents
            self.bios_loss += loss_bios
            self.constraint_loss += loss_constraint

            self.loss_.append(loss)
            self.ent_loss_.append(loss_ents)
            self.bios_loss_.append(loss_bios)
            self.constraint_loss_.append(loss_constraint)

            if self.iter_count % config['eval_every'] == 0:
                self._report_evaluation()
                torch.save(self.model, 'model.pt')

            if self.iter_count % config['print_every'] == 0:
                t2 = time.time()
                print("====" * 10)
                print("Iteration:{} | Time Taken:{:.1f}".format(self.iter_count, (t2 - t1) / 60))
                print("Total loss:{:.4f}".format(self.total_loss / config['print_every']))
                print("Entity loss:{:.4f}".format(self.ent_loss / config['print_every']))
                print("Bios loss:{:.4f}".format(self.bios_loss / config['print_every']))
                print("Constraint loss:{:.4f}".format(self.constraint_loss / config['print_every']))
                t1 = time.time()
                self.total_loss = 0
                self.ent_loss = 0
                self.bios_loss = 0
                self.constraint_loss = 0
                print()
        self._report_evaluation()

In [None]:
# Load the pretrained model. The checkpoint is a whole pickled nn.Module (not a
# state_dict), so the class definitions above must match what was saved.
# map_location lets a checkpoint saved on the GPU load on a CPU-only machine.
load_kwargs = {'map_location': device}
# torch >= 2.6 defaults to weights_only=True, which rejects full-module pickles.
# Older versions do not know the argument, so pass it only when it is supported.
# Only do this for trusted checkpoints: unpickling can execute arbitrary code.
if 'weights_only' in inspect.signature(torch.load).parameters:
    load_kwargs['weights_only'] = False
model = torch.load(config['model_path'], **load_kwargs)
model = model.to(device)

# Set to True to fine-tune the loaded model on the labeled data before inference.
FINETUNE = False
if FINETUNE:
    trainer = Trainer(model)
    trainer.finetune()

# Inference

In [None]:
# One row per publication Id, with the title/label columns collected into lists.
train_df = (
    pd.read_csv('../input/coleridgeinitiative-show-us-the-data/train.csv')
    .groupby('Id')[['dataset_title', 'dataset_label', 'cleaned_label']]
    .agg(list)
    .reset_index()
)

def get_all_datalabels(row):
    """Collect every distinct dataset name recorded for one publication.

    Args:
        row: a row of the grouped ``train_df`` whose ``dataset_title``,
            ``dataset_label`` and ``cleaned_label`` entries are lists of strings.

    Returns:
        list of the unique strings across all three columns (order unspecified).
    """
    dataset_title = row['dataset_title']
    dataset_label = row['dataset_label']
    cleaned_label = row['cleaned_label']
    # Previously dataset_title was concatenated twice and cleaned_label was dropped.
    return list(set(dataset_label + dataset_title + cleaned_label))
    
# Add a column holding the unique dataset names of each publication.
train_df = train_df.assign(all_datalabels=lambda frame: frame.apply(get_all_datalabels, axis=1))
train_df.head()

In [None]:
# Vocabulary of every dataset name that appears anywhere in the training labels.
# set.update grows the set in place; the former `database = database.union(...)`
# copied the whole set on every iteration.
database = set()
for labels in train_df['all_datalabels'].values:
    database.update(labels)
print('Number Of Datasets:', len(database))

In [None]:
class Inference:
    """Turn the tagger's two output heads into per-token label sequences.

    For every token, the entity head is passed through a sigmoid and thresholded,
    and the BIO head is reduced with argmax. A token keeps its BIO class only when
    the entity head fires AND the BIO class is 1 or 2; otherwise it is labelled 0.
    """

    def __init__(self, model):
        self.model = model
        self.sigmoid = nn.Sigmoid()

    def getPredictionLabels(self, dataloader, threshold=0.6):
        """Predict a label list for every sentence yielded by ``dataloader``.

        Args:
            dataloader: iterable of ``(slen, X, X_embedd)`` batches, where
                ``slen[i]`` is the number of tokens to label for sentence ``i``.
                Tensor shapes are whatever ``self.model`` expects — TODO confirm.
            threshold: sigmoid cut-off above which the entity head counts as firing
                (defaults to 0.6, the value that used to be hard-coded).

        Returns:
            list with one inner list per sentence, in dataloader order; each inner
            list has ``slen[i]`` entries, each 0, 1 or 2.
        """
        self.model.eval()
        predLabels = []
        # Iterate the argument; previously the global `test_dataloader` was used instead.
        with torch.no_grad():
            for (slen, X, X_embedd) in dataloader:
                X = X.to(device)
                X_embedd = X_embedd.to(device)
                yent, ybios = self.model(X, X_embedd)
                yent = self.sigmoid(yent.detach().cpu())
                ybios = ybios.detach().cpu().argmax(dim=-1).numpy()

                batch_size = X.shape[0]
                for i in range(batch_size):
                    labels = []
                    for j in range(slen[i]):
                        label1 = (yent[i][j] > threshold).to(int)
                        label2 = ybios[i][j]
                        if label1 == 1 and (label2 == 1 or label2 == 2):
                            labels.append(label2)
                        else:
                            labels.append(0)
                    predLabels.append(labels)
        return predLabels

# Clean Dataset

In [None]:
def jaccard_similarity(s1, s2):
    """Return the Jaccard similarity of the space-separated token sets of two strings.

    Computes |A ∩ B| / |A ∪ B|, where A and B are the sets of tokens produced by
    ``str.split(" ")``. The previous version divided a set intersection by list
    lengths, which under-scored strings containing repeated tokens.

    Returns:
        float in [0, 1]. ``str.split(" ")`` always yields at least one token, so the
        union is never empty.
    """
    tokens1 = set(s1.split(" "))
    tokens2 = set(s2.split(" "))
    return len(tokens1 & tokens2) / len(tokens1 | tokens2)

In [None]:
class CleanDataset:
    """Decode per-token label predictions into cleaned, de-duplicated dataset names.

    Label convention as used by ``predictString``: 1 opens a mention, 2 continues
    it, 0 is outside any mention.
    """

    def __init__(self):
        pass

    def clean_text(self, txt):
        """Lower-case and strip ``txt``, then collapse non-alphanumeric runs to one space."""
        return re.sub('[^A-Za-z0-9]+', ' ', str(txt).lower().strip())

    def is_label_eligible(self, label):
        """Return False for candidates that are too generic or malformed to be a dataset name.

        Rejected: purely numeric strings, single characters, strings starting with
        a digit, and a fixed list of stop-like words. ``label in string.punctuation``
        is a substring test, so it also rejects the empty string.
        """
        label = label.strip()
        if (label.isnumeric() or label in string.punctuation):
            return False
        if len(label) == 1:
            return False
        if len(label) > 1 and label[0].isnumeric():
            return False
        if label in ['of', 'and', 'in', 'or', 'america', 'american',
                     'england', 'study', 'survey', 'data']:
            return False
        return True

    def predictString(self, row):
        """Extract dataset-name spans from one tagged sentence.

        Args:
            row: mapping with ``'labels'`` (per-token ints 0/1/2) and
                ``'sentence_words'`` (the tokens).

        Returns:
            list of strings, each the space-joined tokens of one detected span.

        A span opens on a 1 whose token starts with an uppercase letter. Each
        following 2 extends it, and it closes on the next 0 or at end of sentence.
        Current quirks:
          * token 0 is never inspected — NOTE(review): confirm this is intentional;
          * a 1 seen while a span is already open neither closes nor extends it.
        """
        labels = row['labels']
        words = row['sentence_words']
        start, end = -1, -1

        datalabels = []
        for i, label in enumerate(labels):
            if i == 0:
                continue
            if start == -1 and label == 1 and len(words[i]) > 0 and words[i][0].isupper():
                start, end = i, i
            elif label == 2:
                end = i
            elif label == 0 and start != -1:
                datalabels.append(' '.join(words[start:end + 1]))
                start, end = -1, -1
        if start != -1:
            datalabels.append(' '.join(words[start:end + 1]))
        return datalabels

    def process_prediction(self, datalabels, matched_labels):
        """Merge one publication's predicted and matched names into a PredictionString.

        Args:
            datalabels: list of per-sentence lists of raw predicted names.
            matched_labels: extra raw names to include (e.g. from string matching).

        Returns:
            at most 100 cleaned names, sorted alphabetically and joined with ``'|'``.
            Names whose Jaccard similarity to an already kept (shorter) name is
            0.75 or more are dropped as near-duplicates.
        """
        candidates = [name for sentence_labels in datalabels for name in sentence_labels]
        candidates += matched_labels
        candidates = [self.clean_text(l) for l in candidates]
        candidates = [l for l in candidates if self.is_label_eligible(l)]
        candidates = {l.strip() for l in candidates}

        filtered_labels = []
        # Shortest first, so the shortest variant of a near-duplicate group wins.
        # Ties are broken alphabetically so the result does not depend on set order.
        # (Candidates are already cleaned; re-cleaning them was a no-op that also
        # depended on a global defined in a later cell.)
        for label in sorted(candidates, key=lambda l: (len(l), l)):
            if all(jaccard_similarity(label, kept) < 0.75 for kept in filtered_labels):
                filtered_labels.append(label)

        filtered_labels.sort()
        return '|'.join(filtered_labels[:100])

    def get_datasets(self, test_df, pred_labels, matched_labels):
        """Build one PredictionString per publication Id.

        Adds ``labels`` and ``PredictionString`` columns to ``test_df`` in place,
        then returns a new frame grouped by ``Id``.
        """
        test_df['labels'] = pred_labels
        test_df['PredictionString'] = test_df.apply(self.predictString, axis=1)
        test_df = test_df.groupby('Id')[['PredictionString']].agg(list).reset_index()
        test_df['PredictionString'] = test_df['PredictionString'].apply(self.process_prediction, args=(matched_labels, ))
        return test_df

In [None]:
# Helpers shared by the per-publication inference loop.
cleanDataset = CleanDataset()
inferObj = Inference(model)

In [None]:
def clean_text(txt):
    """Lower-case ``txt`` and replace each run of non-alphanumeric characters with one space.

    ``txt`` is coerced with ``str()`` first, so non-string values are accepted.
    Leading/trailing whitespace is collapsed to a single space but NOT stripped.
    """
    lowered = str(txt).lower()
    return re.sub('[^A-Za-z0-9]+', ' ', lowered)

In [None]:
# Tag every sentence of every test publication and collect one PredictionString per Id.
all_data = []
for idx, pub_id in enumerate(os.listdir(config['test_dir'])):
    doc_id = pub_id.replace('.json', '')
    paper = get_publications_data(pub_id, config['test_dir'])
    # Matching against known dataset names is disabled; only model predictions are used.
    datalabels = []

    sentences = []
    for section in paper:
        sentences += get_sentences(pub_id, section['text'])
    if not sentences:
        all_data.append({'Id': doc_id, "PredictionString": ''})
        continue

    data = [{'Id': doc_id, 'sentence_words': sentence} for sentence in sentences]
    test_df = pd.DataFrame.from_dict(data)
    test_dataset = TestDataset(test_df, config['max_seq_len'], config['glove_dim'])
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, shuffle=False, batch_size=config['test_batch_size'])
    pred_labels = inferObj.getPredictionLabels(test_dataloader)

    # Pad with empty label lists so every sentence row receives an entry.
    shortfall = len(test_df) - len(pred_labels)
    pred_labels.extend([] for _ in range(shortfall))

    test_df = cleanDataset.get_datasets(test_df, pred_labels, datalabels)
    all_data.append({'Id': doc_id, "PredictionString": test_df['PredictionString'].values[0]})

In [None]:
# Make sure every test publication appears in the submission, even ones the
# inference loop did not emit a row for.
all_pub_ids = {pub_id.replace('.json', '') for pub_id in os.listdir(config['test_dir'])}
pred_pub_ids = {record['Id'] for record in all_data}

# Ids are already stripped of ".json". Sorting keeps the appended rows in a
# stable order instead of set-iteration order.
remain_pub_ids = all_pub_ids.difference(pred_pub_ids)
for pub_id in sorted(remain_pub_ids):
    all_data.append({'Id': pub_id, "PredictionString": ''})


submission_df = pd.DataFrame.from_dict(all_data)
submission_df.to_csv('submission.csv', index=False)
submission_df.head()

In [None]:
# Spot-check the first four prediction strings.
for p in submission_df.PredictionString.values[:4]:
    print(p)
    print('=='*10)

In [None]:
# Exploratory check: print sentences of one specific test publication that
# mention "the nces co" (case-insensitive).
target_pub_id = '2f392438-e215-4169-bebf-21ac4ff253e1.json'
if target_pub_id in os.listdir(config['test_dir']):
    paper = get_publications_data(target_pub_id, config['test_dir'])
    for section in paper:
        for s in sent_tokenize(section['text']):
            if 'the nces co' in s.lower():
                print(s)
                print()