# Extract ScholarBERT Embeddings for materials

In [None]:
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
import numpy as np
import pickle
import torch
import csv


In [None]:
# Hugging Face Hub id (or local directory) of the encoder checkpoint.
model_path = "globuslabs/ScholarBERT-XL"
# NOTE(review): main() opens this path and parses it with csv.reader, so it
# must name a CSV file; the directory "./" cannot be opened. TODO set it.
data_path = "./"
# Destination pickle for the per-name embedding accumulators; "./random.pkl"
# and "./relevant.pkl" are the alternative output names for other runs.
out_file = "./known.pkl"

## 1. Define utility functions for extracting contextualized BERT embeddings

In [None]:
def get_word_idx(sent: str, word: str, encoding):
    """Return the indices of the tokenizer words that make up `word` in `sent`.

    The first case-insensitive occurrence of `word` in `sent` is located as a
    character span [start, start + len(word)). Every word of `encoding` whose
    start offset falls inside that span is selected.

    Args:
        sent: The sentence that was tokenized to produce `encoding`.
        word: The substring to look for; matching is case-insensitive.
        encoding: A tokenizer encoding exposing ``word_to_chars(word_index)``,
            which yields a ``(start, end)`` character span for each word.

    Returns:
        A list of word indices, possibly empty. An empty list means `word`
        does not occur in `sent`, or none of its words are present in
        `encoding` (e.g. they were truncated away).
    """
    # NOTE(review): lower() is used only to locate the span. This assumes
    # lower-casing preserves string length, which holds for ASCII but not for
    # every Unicode character.
    start = sent.lower().find(word.lower())
    if start == -1:
        # Without this guard start=-1 would make the leading words look like
        # a match, and the caller would embed unrelated tokens.
        return []
    end = start + len(word)

    indices = []
    # A sentence can never contain more words than characters, so len(sent)
    # bounds the scan.
    for idx in range(len(sent)):
        # Past the last word, depending on the transformers version, the
        # lookup either returns None or raises TypeError (CharSpan(*None)).
        try:
            span = encoding.word_to_chars(idx)
        except TypeError:
            break
        if span is None:
            break
        word_start, _word_end = span
        if word_start >= end:
            break
        if word_start >= start:
            indices.append(idx)
    return indices

def get_hidden_states(encoded, token_ids_word, model, layers=None):
    """Return the pooled contextual embedding of one word as a NumPy vector.

    Runs `encoded` through `model` with hidden states enabled and sums the
    hidden states of the requested `layers`. It then keeps only the sub-word
    token positions in `token_ids_word` and averages them.

    Args:
        encoded: Keyword inputs for `model` (e.g. a tokenizer BatchEncoding)
            holding a single sequence.
        token_ids_word: Token positions (along the sequence axis) that belong
            to the word of interest.
        model: Callable returning an object with a `hidden_states` sequence.
        layers: Indices into `hidden_states` to sum; defaults to the last four.

    Returns:
        A 1-D NumPy array with one entry per hidden unit.
    """
    layers = [-4, -3, -2, -1] if layers is None else layers
    with torch.no_grad():
        output = model(**encoded, output_hidden_states=True)
    states = output.hidden_states
    # Hidden states are (batch, seq_len, hidden) per the transformers docs.
    # Take the single sequence explicitly: squeeze() would also drop the
    # seq_len axis when the input has only one token.
    summed = torch.stack([states[i] for i in layers]).sum(dim=0)[0]
    # Only select the tokens that constitute the requested word.
    word_tokens_output = summed[token_ids_word]
    return word_tokens_output.mean(dim=0).cpu().numpy()


def get_word_vector(sent, word, tokenizer, model, layers, device=None):
    """Return the contextual embedding of `word` inside `sent`, or None.

    Steps:
    1. Tokenize `sent`, truncated to 512 tokens.
    2. Find the words that form the first case-insensitive occurrence of
       `word` via `get_word_idx`.
    3. Collect the sub-word token positions that belong to those words.
    4. Pool their hidden states with `get_hidden_states`.

    Args:
        sent: Context sentence.
        word: Surface form to embed.
        tokenizer: Tokenizer providing `encode_plus`; its encodings must
            support `word_ids()` / `word_to_chars()` (fast tokenizers do).
        model: The encoder model.
        layers: Hidden-state layer indices passed to `get_hidden_states`.
        device: Device to move the inputs to. Defaults to the model's device.

    Returns:
        A NumPy vector, or None when `word` cannot be located among the
        tokens that survived truncation.
    """
    if device is None:
        # Infer the device from the model. main() defines `device` only as a
        # local, so there is no module-level name to fall back on.
        device = getattr(model, "device", None)
        if device is None:
            device = next(model.parameters()).device
    encoded = tokenizer.encode_plus(sent, is_split_into_words=False, max_length=512,
                                    truncation=True, return_tensors="pt").to(device)
    indices = get_word_idx(sent, word, encoded)
    if not indices:
        return None
    wanted = set(indices)
    # word_ids() maps every token position to its word index; special tokens
    # map to None and therefore never match.
    token_ids_word = [pos for pos, word_id in enumerate(encoded.word_ids())
                      if word_id in wanted]
    if not token_ids_word:
        return None
    return get_hidden_states(encoded, token_ids_word, model, layers)

## 2. Extract contextualized embeddings

In [None]:
def main(layers=None):
    """Embed every mention listed in the CSV at `data_path` and pickle the result.

    `data_path` is read as a CSV with a header row followed by rows of
    (filename, line, name, molecule, text). For each row, the contextual
    embedding of `name` within `text` is computed with `get_word_vector`.
    Embeddings are accumulated per `name` as ``[count, summed_vector]``;
    divide the sum by the count to obtain the mean. The resulting dict is
    pickled to `out_file` and returned.

    Args:
        layers: Hidden-state layer indices to sum; defaults to the last four.

    Returns:
        dict mapping name -> [count, summed embedding].
    """
    if torch.cuda.is_available():
        print('Running on GPU')
        device = 'cuda'
    else:
        print('Running on CPU')
        device = 'cpu'

    # Use last four layers by default
    layers = [-4, -3, -2, -1] if layers is None else layers
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModel.from_pretrained(model_path, output_hidden_states=True)
    model.to(device)
    # Inference only; from_pretrained already does this, made explicit here.
    model.eval()

    emb_dict = {}
    # newline='' is what the csv module requires so that quoted fields
    # containing line breaks are parsed correctly.
    with open(data_path, 'r', newline='', encoding='utf-8') as csvfile:
        data_reader = csv.reader(csvfile)
        # Skip the header row; an empty file simply produces no embeddings
        # instead of raising StopIteration.
        if next(data_reader, None) is not None:
            for _filename, _line, name, _molecule, text in tqdm(data_reader):
                word_embedding = get_word_vector(text, name, tokenizer, model, layers)
                if word_embedding is None:
                    continue
                entry = emb_dict.get(name)
                if entry is None:
                    emb_dict[name] = [1, word_embedding]
                else:
                    entry[0] += 1
                    entry[1] += word_embedding

    with open(out_file, 'wb') as fp_out:
        pickle.dump(emb_dict, fp_out)
    return emb_dict