# Imports

In [1]:
import os
import gc
import pickle

In [2]:
from flair.data import Sentence
from flair.models import SequenceTagger
from flair.models import TextClassifier

from flair.data import TaggedCorpus
from flair.data_fetcher import  NLPTaskDataFetcher, NLPTask

import torch

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [4]:
from pymongo import MongoClient

# Connect to MongoDB using pymongo's default host/port (no arguments given) and
# select the database in which the extracted hidden representations are stored.
client = MongoClient()
db = client.get_database('glvis_db')

# Extract hidden representations from flair's pretrained NER model

In [6]:
# Load flair's pretrained NER tagger; the log output shows the 'ner' alias
# resolving to the CoNLL-03 English model file (en-ner-conll03-v0.4.pt).
ner_tagger = SequenceTagger.load('ner')

2019-03-23 17:18:52,068 loading file /home/snie/.flair/models/en-ner-conll03-v0.4.pt


In [7]:
# Show the module tree to find the layers to hook (embedding2nn and linear
# are the ones captured further down).
ner_tagger

SequenceTagger(
  (embeddings): StackedEmbeddings(
    (list_embedding_0): WordEmbeddings()
    (list_embedding_1): FlairEmbeddings(
      (lm): LanguageModel(
        (drop): Dropout(p=0.05)
        (encoder): Embedding(300, 100)
        (rnn): LSTM(100, 2048)
        (decoder): Linear(in_features=2048, out_features=300, bias=True)
      )
    )
    (list_embedding_2): FlairEmbeddings(
      (lm): LanguageModel(
        (drop): Dropout(p=0.05)
        (encoder): Embedding(300, 100)
        (rnn): LSTM(100, 2048)
        (decoder): Linear(in_features=2048, out_features=300, bias=True)
      )
    )
  )
  (word_dropout): WordDropout()
  (locked_dropout): LockedDropout()
  (embedding2nn): Linear(in_features=4196, out_features=4196, bias=True)
  (rnn): LSTM(4196, 256, bidirectional=True)
  (linear): Linear(in_features=512, out_features=20, bias=True)
)

In [8]:
# CoNLL-03 NER data; as the log shows, flair reads it from <base_path>/conll_03
# (eng.train / eng.testa / eng.testb).
conll_base_path = 'data/conll/'
corpus = NLPTaskDataFetcher.load_corpus(NLPTask.CONLL_03, base_path=conll_base_path)

2019-03-23 17:19:06,783 Reading data from data/conll/conll_03
2019-03-23 17:19:06,784 Train: data/conll/conll_03/eng.train
2019-03-23 17:19:06,784 Dev: data/conll/conll_03/eng.testa
2019-03-23 17:19:06,784 Test: data/conll/conll_03/eng.testb


In [9]:
# Build the NER tag dictionary from the corpus.
# NOTE(review): not referenced by any of the cells shown here — possibly a
# leftover; confirm nothing else depends on it before removing.
tag_dictionary = corpus.make_tag_dictionary(tag_type='ner')

In [10]:
# Number of sentences the extraction loops iterate over (presumably
# train + dev + test combined — confirm against flair's get_all_sentences).
len(corpus.get_all_sentences())

22137

## First linear layer (`embedding2nn`)

In [5]:
# Target collection: one document per predicted entity span, holding the
# span-averaged input to ner_tagger.embedding2nn.
db_col = db.get_collection('flair_ner_embedding2nn')

In [None]:
# Destructive: deletes every document in the collection so that re-running the
# extraction loop (which only inserts) does not accumulate duplicate entries.
db_col.drop()

In [None]:
# For every corpus sentence: run the tagger, capture the input to
# ner_tagger.embedding2nn with a forward pre-hook, and store one document per
# predicted entity span in db_col.
embedding2nn_dim = ner_tagger.embedding2nn.in_features  # 4196 for this model (see repr above)
log_every = 1000  # progress message every N sentences instead of two lines per sentence

for i, sentence in enumerate(corpus.get_all_sentences()):
    # Buffer filled by the hook, laid out (tokens, batch=1, features). This assumes
    # predict() feeds the sentence as a single sequence-first batch — TODO confirm
    # for the flair version in use; copy_ below fails loudly on a shape mismatch.
    hidden_states = torch.zeros(len(sentence), 1, embedding2nn_dim)

    def capture_input(module, inputs):
        # Forward pre-hooks receive the layer's positional inputs as a tuple.
        hidden_states.copy_(inputs[0].detach())

    handle = ner_tagger.embedding2nn.register_forward_pre_hook(capture_input)
    try:
        ner_tagger.predict(sentence)
    finally:
        # Always detach the hook: if predict() raised, a leftover hook would keep
        # firing into this stale buffer on every later forward pass.
        handle.remove()

    # Information to store: the named entities, their predicted labels,
    # probabilities and hidden states. For multi-word entities the hidden state is
    # averaged over the entity's tokens and the token count is recorded.
    entries = []
    for span in sentence.get_spans('ner'):
        rows = [token.idx - 1 for token in span.tokens]  # token.idx appears to be 1-based
        entries.append({
            'text': span.text,
            'tag': span.tag,
            'score': span.score,
            'token_num': len(span.tokens),
            'embedding2nn': hidden_states[rows, :, :].mean(dim=0).squeeze().tolist(),
        })

    # One round trip per sentence; insert_many rejects an empty document list.
    if entries:
        db_col.insert_many(entries)

    if (i + 1) % log_every == 0:
        print(f'Finish sentence {i}')

## Last linear layer (`linear`)

In [None]:
# Target collection: one document per predicted entity span, holding the
# span-averaged input to ner_tagger.linear.
# NOTE(review): unlike the embedding2nn section, this collection is never dropped
# before its extraction loop, so re-running that loop appends duplicate documents.
db_col = db.get_collection('flair_ner_linear')

In [None]:
# For every corpus sentence: run the tagger, capture the input to the final
# ner_tagger.linear layer with a forward pre-hook, and store one document per
# predicted entity span in db_col.
linear_dim = ner_tagger.linear.in_features  # 512 for this model (see repr above)
log_every = 1000  # progress message every N sentences instead of two lines per sentence

for i, sentence in enumerate(corpus.get_all_sentences()):
    # Buffer filled by the hook, laid out (tokens, batch=1, features). This assumes
    # predict() feeds the sentence as a single sequence-first batch — TODO confirm
    # for the flair version in use; copy_ below fails loudly on a shape mismatch.
    hidden_states = torch.zeros(len(sentence), 1, linear_dim)

    def capture_input(module, inputs):
        # Forward pre-hooks receive the layer's positional inputs as a tuple.
        hidden_states.copy_(inputs[0].detach())

    handle = ner_tagger.linear.register_forward_pre_hook(capture_input)
    try:
        ner_tagger.predict(sentence)
    finally:
        # Always detach the hook: if predict() raised, a leftover hook would keep
        # firing into this stale buffer on every later forward pass.
        handle.remove()

    # Information to store: the named entities, their predicted labels,
    # probabilities and hidden states. For multi-word entities the hidden state is
    # averaged over the entity's tokens and the token count is recorded.
    entries = []
    for span in sentence.get_spans('ner'):
        rows = [token.idx - 1 for token in span.tokens]  # token.idx appears to be 1-based
        entries.append({
            'text': span.text,
            'tag': span.tag,
            'score': span.score,
            'token_num': len(span.tokens),
            'linear_layer_state': hidden_states[rows, :, :].mean(dim=0).squeeze().tolist(),
        })

    # One round trip per sentence; insert_many rejects an empty document list.
    if entries:
        db_col.insert_many(entries)

    if (i + 1) % log_every == 0:
        print(f'Finish sentence {i}')