In [70]:
import numpy as np 
import pandas as pd 
from tqdm import tqdm, trange
import time
import torch 
from torch.optim import Adam 
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from keras.preprocessing.sequence import pad_sequences 
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, BertForTokenClassification
from seqeval.metrics import classification_report
from sklearn.metrics import confusion_matrix

In [2]:
data = pd.read_csv("ner_dataset.csv", encoding="latin1").fillna(method='ffill')
data.tail(10)

Unnamed: 0,Sentence #,Word,POS,Tag
1048565,Sentence: 47958,impact,NN,O
1048566,Sentence: 47958,.,.,O
1048567,Sentence: 47959,Indian,JJ,B-gpe
1048568,Sentence: 47959,forces,NNS,O
1048569,Sentence: 47959,said,VBD,O
1048570,Sentence: 47959,they,PRP,O
1048571,Sentence: 47959,responded,VBD,O
1048572,Sentence: 47959,to,TO,O
1048573,Sentence: 47959,the,DT,O
1048574,Sentence: 47959,attack,NN,O


In [90]:
data['Tag'].unique()

array(['O', 'B-geo', 'B-gpe', 'B-per', 'I-geo', 'B-org', 'I-org', 'B-tim',
       'B-art', 'I-art', 'I-per', 'I-gpe', 'I-tim', 'B-nat', 'B-eve',
       'I-eve', 'I-nat'], dtype=object)

In [3]:
class SentenceGetter(object):
    def __init__(self, data):
        self.n_sent = 1
        self.data = data 
        self.empty = False 
        agg_func = lambda s: [(w,p,t) for w,p,t in zip(s['Word'].values.tolist(),
                                                      s['POS'].values.tolist(),
                                                      s['Tag'].values.tolist())]
        self.grouped = self.data.groupby("Sentence #").apply(agg_func)
        self.sentences = [s for s in self.grouped]
        
        
    def get_next(self):
        try:
            s = self.grouped["Sentence: {}".format(self.n_sent)]
            self.n_sent += 1 
            return s
        except:
            return None

In [4]:
getter = SentenceGetter(data)

In [5]:
sentences = [' '.join([y[0] for y in x]) for x in getter.sentences]

In [6]:
sentences[0]

'Thousands of demonstrators have marched through London to protest the war in Iraq and demand the withdrawal of British troops from that country .'

In [41]:
labels = [[y[2] for y in x] for x in getter.sentences]

In [8]:
labels[0]

['O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-geo',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-geo',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-gpe',
 'O',
 'O',
 'O',
 'O',
 'O']

In [88]:
tags_vals = list(set(data['Tag'].values))
tag2idx = {t:i for i,t in enumerate(tags_vals)}
idx2tag = {i: t for t, i in tag2idx.items()}
tag2idx

{'B-tim': 0,
 'I-nat': 1,
 'O': 2,
 'I-eve': 3,
 'I-per': 4,
 'B-geo': 5,
 'I-gpe': 6,
 'B-nat': 7,
 'I-org': 8,
 'B-per': 9,
 'B-gpe': 10,
 'B-eve': 11,
 'B-org': 12,
 'I-tim': 13,
 'B-art': 14,
 'I-geo': 15,
 'I-art': 16}

In [33]:
MAX_LEN = 75
bs = 32

In [34]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_gpu = torch.cuda.device_count()

In [35]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

In [27]:
tokenized_texts = [tokenizer.tokenize(sent) for sent in sentences]

In [36]:
input_ids = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_texts],
                         maxlen=MAX_LEN, dtype='long', truncating='post', padding='post')

In [37]:
input_ids[0]

array([ 5190,  1997, 28337,  2031,  9847,  2083,  2414,  2000,  6186,
        1996,  2162,  1999,  5712,  1998,  5157,  1996, 10534,  1997,
        2329,  3629,  2013,  2008,  2406,  1012,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0])

In [42]:
tags = pad_sequences([[tag2idx.get(l) for l in lab] for lab in labels],
                    maxlen=MAX_LEN, value=tag2idx['O'], padding='post',
                    dtype='long', truncating='post')

In [43]:
attention_masks = [[float(i>0) for i in ii] for ii in input_ids]

In [81]:
input_ids_u, input_ids_v, tags_u, tags_v = train_test_split(input_ids,
                                                           tags,
                                                           random_state=2018,
                                                           test_size=0.015)

attention_masks_u, attention_masks_v, _, _ = train_test_split(attention_masks, input_ids,
                                            random_state=2018, test_size=0.015)


tr_inputs, val_inputs, tr_tags, val_tags = train_test_split(input_ids_v,
                                                           tags_v,
                                                           random_state=2018,
                                                           test_size=0.4)

tr_masks, val_masks, _, _ = train_test_split(attention_masks_v, input_ids_v,
                                            random_state=2018, test_size=0.4)


In [82]:
tr_inputs = torch.tensor(tr_inputs)
val_inputs = torch.tensor(val_inputs)
tr_tags = torch.tensor(tr_tags)
val_tags = torch.tensor(val_tags)
tr_masks = torch.tensor(tr_masks)
val_masks = torch.tensor(val_masks)

In [83]:
train_data = TensorDataset(tr_inputs, tr_masks, tr_tags)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=bs)

valid_data = TensorDataset(val_inputs, val_masks, val_tags)
valid_sampler = SequentialSampler(valid_data)
valid_dataloader = DataLoader(valid_data, sampler=valid_sampler, batch_size=bs)


In [84]:
model = BertForTokenClassification.from_pretrained('bert-base-uncased', num_labels=len(tag2idx))

In [22]:
# model.cuda();

In [85]:
FULL_FINETUNING = False
if FULL_FINETUNING:
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias','gamma','beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay_rate':0.01},
        {'params':[p for n,p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay_rate':0.0}
    ]
else:
    param_optimizer = list(model.classifier.named_parameters())
    optimizer_grouped_parameters = [{'params':[p for n,p in param_optimizer]}]
optimizer = Adam(optimizer_grouped_parameters, lr=3e-5)

In [62]:
from seqeval.metrics import f1_score

def flat_accuracy(valid_tags, pred_tags):

    """
    Define a flat accuracy metric to use while training the model.
    """

    return (np.array(valid_tags) == np.array(pred_tags)).mean()

def annot_confusion_matrix(valid_tags, pred_tags):

    """
    Create an annotated confusion matrix by adding label
    annotations and formatting to sklearn's `confusion_matrix`.
    """

    # Create header from unique tags
    header = sorted(list(set(valid_tags + pred_tags)))

    # Calculate the actual confusion matrix
    matrix = confusion_matrix(valid_tags, pred_tags, labels=header)

    # Final formatting touches for the string output
    mat_formatted = [header[i] + "\t" + str(row) for i, row in enumerate(matrix)]
    content = "\t" + " ".join(header) + "\n" + "\n".join(mat_formatted)

    return content


In [59]:
tokenizer.vocab["[PAD]"]

0

In [98]:
len(tokenizer.vocab)

30522

In [89]:
epochs = 2
max_grad_norm = 1.0
pad_tok = tokenizer.vocab["[PAD]"]
sep_tok = tokenizer.vocab["[SEP]"]
cls_tok = tokenizer.vocab["[CLS]"]
o_lab = tag2idx["O"]
verbose = True
epoch = 0

for _ in trange(epochs, desc='Epoch'):
    epoch += 1
    print('starting train loop')
    model.train()
    tr_loss, tr_accuracy = 0, 0
    nb_tr_examples, nb_tr_steps = 0, 0
    tr_preds, tr_labels = [], [] 
    
    for step, batch in enumerate(train_dataloader):
        #add batch to gpu (cpu probably)
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        #forward pass 
        outputs = model(
            b_input_ids,
            token_type_ids=None,
            attention_mask=b_input_mask,
            labels=b_labels,
        )
        loss, tr_logits = outputs[:2]
        print('backward pass')
        #backward pass
        loss.backward()
        #compute train loss
        tr_loss += loss.item()
        nb_tr_examples += b_input_ids.size(0)
        nb_tr_steps += 1
        #subset out unwanted predictions on CLS/PAD/SEP tokens
        preds_mask = (
            (b_input_ids != cls_tok) & (b_input_ids != pad_tok) & (b_input_ids != sep_tok)
        )
        
        tr_logits = tr_logits.detach().cpu().numpy()
        tr_label_ids = torch.masked_select(b_labels, (preds_mask == 1))
        tr_batch_preds = np.argmax(tr_logits[preds_mask.squeeze()], axis=1)
        tr_batch_labels = tr_label_ids.to('cpu').numpy()
        tr_preds.extend(tr_batch_preds)
        tr_labels.extend(tr_batch_labels)
        
        #compute training accuracy
        tmp_tr_accuracy = flat_accuracy(tr_batch_labels, tr_batch_preds)
        tr_accuracy += tmp_tr_accuracy
        
        #gradient clipping
        torch.nn.utils.clip_grad_norm_(
            parameters=model.parameters(), max_norm=max_grad_norm
        )
        
        #update parameters
        optimizer.step()
        model.zero_grad()
        
    tr_loss = tr_loss / nb_tr_steps
    tr_accuracy = tr_accuracy / nb_tr_steps
    
    #print training loss and accuracy per epoch 
    print(f"Train loss: {tr_loss}")
    print(f"Train accuracy: {tr_accuracy}")
    
    #validation loop 
    print('starting validation loop')
    model.eval()
    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps, nb_eval_examples = 0, 0
    predictions, true_labels = [], []
    
    for batch in valid_dataloader:
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        
        with torch.no_grad():
            outputs = model(
                b_input_ids,
                token_type_ids=None,
                attention_mask=b_input_mask,
                labels=b_labels,
            )
            tmp_eval_loss, logits = outputs[:2]
        
        #subset out unwanted predictions on cls/pad/sep tokens
        preds_mask = (
            (b_input_ids != cls_tok) & (b_input_ids != pad_tok) & (b_input_ids != sep_tok)
        )
        
        logits = logits.detach().cpu().numpy()
        label_ids = torch.masked_select(b_labels, (preds_mask == 1))
        val_batch_preds = np.argmax(logits[preds_mask.squeeze()], axis=1)
        val_batch_labels = label_ids.to('cpu').numpy()
        predictions.extend(val_batch_preds)
        true_labels.extend(val_batch_labels)
        
        tmp_eval_accuracy = flat_accuracy(val_batch_labels, val_batch_preds)
        
        eval_loss += tmp_eval_loss.mean().item()
        eval_accuracy += tmp_eval_accuracy 
        
        nb_eval_examples += b_input_ids.size(0)
        nb_eval_steps += 1
        
    #evaluate loss, acc, conf matrix and class report on devset 
    pred_tags = [idx2tag[i] for i in predictions]
    valid_tags = [idx2tag[i] for i in true_labels]
    cl_report = classification_report(valid_tags, pred_tags)
    conf_mat = annot_confusion_matrix(valid_tags, pred_tags)
    eval_loss = eval_loss / nb_eval_steps 
    eval_accuracy = eval_accuracy / nb_eval_steps 
    
    #report metrics 
    print(f"Validation loss: {eval_loss}")
    print(f"Validation Accuracy: {eval_accuracy}")
    print(f"Classificatiotn Report:\m {cl_report}")
    if verbose:
        print(f"Confusion Matrix:\n {conf_mat}")
        
    #save model and optimizer state_dict following every epoch 
    save_path = f"models/train_checkpoint_epoch_{epoch}.tar"
    torch.save(
        {
            "epoch":epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "train_loss": tr_loss,
            "train_acc": tr_accuracy,
            "eval_loss":eval_loss,
            "eval_acc": eval_accuracy,
            "classification_report": cl_report,
            "confusion_matrix": conf_mat,
        },
        save_path,
    )
    print(f"Checkpoint saved to {save_path}")







Epoch:   0%|          | 0/2 [00:00<?, ?it/s][A[A[A[A[A[A

starting train loop
backward pass
backward pass
backward pass
backward pass
backward pass
backward pass
backward pass
backward pass
backward pass
backward pass
backward pass
backward pass
backward pass
backward pass
Train loss: 2.834966387067522
Train accuracy: 0.03663605828299611
starting validation loop








Epoch:  50%|█████     | 1/2 [01:47<01:47, 107.30s/it][A[A[A[A[A[A

Validation loss: 2.7709459728664823
Validation Accuracy: 0.05559395054712171
Classificatiotn Report:\m            precision    recall  f1-score   support

      geo       0.03      0.03      0.03       235
      tim       0.01      0.02      0.01       125
      org       0.01      0.07      0.02       120
      per       0.00      0.02      0.01       106
      gpe       0.02      0.08      0.04        89
      eve       0.00      0.00      0.00         2
      nat       0.00      0.00      0.00         1

micro avg       0.01      0.04      0.01       678
macro avg       0.02      0.04      0.02       678

Confusion Matrix:
 	B-art B-eve B-geo B-gpe B-nat B-org B-per B-tim I-art I-eve I-geo I-gpe I-nat I-org I-per I-tim O
B-art	[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
B-eve	[0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0]
B-geo	[ 2 10  0  9  4  2 13  1 88  2 14  2  4 51  1 21 11]
B-gpe	[ 0  1  0  6  0  0  5  0 49  1  2  2  0 12  0  3  8]
B-nat	[0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0]
B-org	[ 1  5  1  2  0 







Epoch: 100%|██████████| 2/2 [03:37<00:00, 108.06s/it][A[A[A[A[A[A

Validation loss: 2.676004065407647
Validation Accuracy: 0.12658010426215843
Classificatiotn Report:\m            precision    recall  f1-score   support

      geo       0.02      0.02      0.02       235
      tim       0.01      0.02      0.01       125
      org       0.01      0.06      0.02       120
      per       0.00      0.02      0.01       106
      gpe       0.02      0.06      0.03        89
      eve       0.00      0.00      0.00         2
      nat       0.00      0.00      0.00         1

micro avg       0.01      0.03      0.01       678
macro avg       0.01      0.03      0.02       678

Confusion Matrix:
 	B-art B-eve B-geo B-gpe B-nat B-org B-per B-tim I-art I-eve I-geo I-gpe I-nat I-org I-per I-tim O
B-art	[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
B-eve	[0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 1]
B-geo	[ 2  9  0  7  4  2 13  1 81  2 12  2  4 47  1 18 30]
B-gpe	[ 0  1  0  4  0  0  5  0 43  0  2  2  0 11  0  2 19]
B-nat	[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1]
B-org	[ 1  5  1  2  0  

In [58]:
device

device(type='cpu')

In [64]:
print(f"the num is {'d'}")

the num is d


In [86]:
len(train_dataloader)

14

In [112]:
model_name = 'bert-base-uncased'
checkpoint = torch.load('models/train_checkpoint_epoch_2.tar')
model_state_dict = checkpoint["model_state_dict"]
model = BertForTokenClassification.from_pretrained(
    model_name, num_labels=model_state_dict['classifier.weight'].shape[0]
)
model.load_state_dict(model_state_dict)
tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=False)

In [133]:
input_text = "Hitler was right."
encoded_text = tokenizer.encode(input_text)
wordpieces = [tokenizer.decode(tok).replace(' ','') for tok in encoded_text]

input_ids = torch.tensor(encoded_text).unsqueeze(0).long()
labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0).long()
outputs = model(input_ids, labels=labels)
loss, scores = outputs[:2]
scores = scores.detach().numpy()

label_ids = np.argmax(scores, axis=2)
preds = [idx2tag[i] for i in label_ids[0]]

wp_preds = list(zip(wordpieces, preds))
toplevel_preds = [pair[1] for pair in wp_preds if '##' not in pair[0]]
str_rep = ' '.join([t[0] for t in wp_preds]).replace(' ##', '').split()

if len(str_rep) == len(toplevel_preds):
    preds_final = list(zip(str_rep, toplevel_preds))
    b_preds_df = pd.DataFrame(preds_final)
    b_preds_df.columns = ['text','pred']
    for tag in ['PER','LOC','ORG','MISC']:
        b_preds_df[f"b_pred_{tag.lower()}"] = np.where(
            b_preds_df['pred'].str.contains(tag.lower()), 1, 0
        )
else:
    print('could not match up output string with preds.')
b_preds_df.loc[:,'text':]

Unnamed: 0,text,pred,b_pred_per,b_pred_loc,b_pred_org,b_pred_misc
0,[CLS],B-eve,0,0,0,0
1,[UNK],O,0,0,0,0
2,was,I-org,0,0,1,0
3,right,I-org,0,0,1,0
4,.,I-org,0,0,1,0
5,[SEP],I-org,0,0,1,0


In [128]:
idx2tag

{0: 'B-tim',
 1: 'I-nat',
 2: 'O',
 3: 'I-eve',
 4: 'I-per',
 5: 'B-geo',
 6: 'I-gpe',
 7: 'B-nat',
 8: 'I-org',
 9: 'B-per',
 10: 'B-gpe',
 11: 'B-eve',
 12: 'B-org',
 13: 'I-tim',
 14: 'B-art',
 15: 'I-geo',
 16: 'I-art'}