# Extract hidden representations from flair's pretrained Chunking model

In [6]:
from flair.data import Sentence
from flair.models import SequenceTagger
from flair.models import TextClassifier

from flair.data import TaggedCorpus
from flair.data_fetcher import  NLPTaskDataFetcher, NLPTask
import torch

In [1]:
from pymongo import MongoClient

# Connect to MongoDB; no arguments are given, so pymongo's default
# connection settings are used.
client = MongoClient()

# Database that the cells below create their collections in.
db = client['glvis_db']

# Extract hidden representations from flair's pretrained Chunking model

In [None]:
# Load flair's pretrained chunking sequence tagger, identified by the name 'chunk'.
chunk_tagger = SequenceTagger.load('chunk')

In [None]:
# Bare expression: the notebook echoes the tagger object as the cell output,
# which is how the layer names hooked below (embedding2nn, linear) can be inspected.
chunk_tagger

In [None]:
# Load the CoNLL-2000 corpus through flair's data fetcher; its sentences are
# the inputs fed to the tagger in the extraction loops below.
corpus: TaggedCorpus = NLPTaskDataFetcher.load_corpus(NLPTask.CONLL_2000)

In [None]:
# Tag layer to work with; the same 'np' label is used for get_spans() below.
tag_type = 'np'
# Build the tag dictionary for that tag type from the corpus.
# NOTE(review): tag_dictionary is not referenced by any other visible cell.
tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)

In [None]:
# Number of sentences returned by get_all_sentences(), i.e. how many
# iterations each extraction loop below performs.
len(corpus.get_all_sentences())

### First linear layer (the layer after embedding)

In [None]:
# Collection that receives one document per predicted span with the
# embedding2nn layer input.
db_col = db['flair_chunk_embedding2nn']

In [None]:
# Destructive: drops the collection so the loop below inserts into an empty
# one instead of appending to documents from a previous run.
db_col.drop()

In [None]:
# Run the chunker over every sentence of the corpus and store one MongoDB
# document per predicted span, holding the input that the `embedding2nn`
# layer received for that span's tokens.
layer = chunk_tagger.embedding2nn
# Read the width of the layer input from the module instead of hard-coding it.
# Assumes `embedding2nn` is a torch.nn.Linear (which exposes `in_features`).
state_width = layer.in_features

for i, sentence in enumerate(corpus.get_all_sentences()):
    print(f'Start sentence {i}')

    # Buffer the hook fills with the layer input: one row per token and a
    # batch dimension of 1, because a single sentence is predicted at a time.
    hidden_states = torch.zeros(len(sentence), 1, state_width)

    def hook(module, inputs):
        # A forward pre-hook receives the layer's positional inputs as a
        # tuple; the first element is the tensor about to enter the layer.
        hidden_states.copy_(inputs[0].data)

    h = layer.register_forward_pre_hook(hook)
    try:
        chunk_tagger.predict(sentence)
    finally:
        # Always detach the hook: if predict() raised, a hook left on the
        # model would fire on every later forward pass and copy into a
        # buffer sized for this sentence.
        h.remove()

    # Information to store: the chunk spans, their predicted labels,
    # probabilities and hidden states.
    # If a span covers several tokens, take the average of their hidden
    # states and record the number of tokens in the span.
    for span in sentence.get_spans('np'):
        # token.idx is treated as 1-based here, hence the shift to 0-based rows.
        idx = [token.idx - 1 for token in span.tokens]

        entry = {
            'text': span.text,
            'tag': span.tag,
            'score': span.score,
            'token_num': len(span.tokens),
            # Select the single batch element explicitly, then average the
            # span's token rows into one vector of length `state_width`.
            'embedding2nn': hidden_states[idx, 0, :].mean(dim=0).tolist(),
        }

        db_col.insert_one(entry)

    print(f'Finish sentence {i}')

### Last linear layer

In [None]:
# Collection that receives one document per predicted span with the input of
# the tagger's `linear` layer. Rebinds db_col, replacing the earlier collection handle.
db_col = db['flair_chunk_linear']

In [None]:
# Destructive: drops the collection so the loop below inserts into an empty
# one instead of appending to documents from a previous run.
db_col.drop()

In [None]:
# Run the chunker over every sentence of the corpus and store one MongoDB
# document per predicted span, holding the input that the tagger's `linear`
# layer received for that span's tokens.
layer = chunk_tagger.linear
# Read the width of the layer input from the module instead of hard-coding it.
# Assumes `linear` is a torch.nn.Linear (which exposes `in_features`).
state_width = layer.in_features

for i, sentence in enumerate(corpus.get_all_sentences()):
    print(f'Start sentence {i}')

    # Buffer the hook fills with the layer input: one row per token and a
    # batch dimension of 1, because a single sentence is predicted at a time.
    hidden_states = torch.zeros(len(sentence), 1, state_width)

    def hook(module, inputs):
        # A forward pre-hook receives the layer's positional inputs as a
        # tuple; the first element is the tensor about to enter the layer.
        hidden_states.copy_(inputs[0].data)

    h = layer.register_forward_pre_hook(hook)
    try:
        chunk_tagger.predict(sentence)
    finally:
        # Always detach the hook: if predict() raised, a hook left on the
        # model would fire on every later forward pass and copy into a
        # buffer sized for this sentence.
        h.remove()

    # Information to store: the chunk spans, their predicted labels,
    # probabilities and hidden states.
    # If a span covers several tokens, take the average of their hidden
    # states and record the number of tokens in the span.
    for span in sentence.get_spans('np'):
        # token.idx is treated as 1-based here, hence the shift to 0-based rows.
        idx = [token.idx - 1 for token in span.tokens]

        entry = {
            'text': span.text,
            'tag': span.tag,
            'score': span.score,
            'token_num': len(span.tokens),
            # Select the single batch element explicitly, then average the
            # span's token rows into one vector of length `state_width`.
            'linear_layer_state': hidden_states[idx, 0, :].mean(dim=0).tolist(),
        }

        db_col.insert_one(entry)

    print(f'Finish sentence {i}')