<a href="https://colab.research.google.com/github/anon/ILCiteR/blob/main/SciBERT_embeddings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Map every text context in all_contexts
# to its SciBERT embedding

In [2]:
%%capture
!pip install transformers

In [3]:
import json
import pickle
import tqdm
from transformers import BertTokenizer, BertModel
import torch
import numpy

In [4]:
# Number embedded in the names of the files this notebook reads and writes
# (all_contexts_<test_count>.json and context_to_embed_<test_count>.pkl)
test_count = 500

# Folder prefix for those files. It is joined to the file names by plain
# string concatenation, so the trailing '/' is required.
# NOTE(review): relative path that looks like a mounted Google Drive folder;
# confirm Drive is mounted before running, no mount step is visible here.
location = 'drive/My Drive/cite_reco_s2orc/full/experiments/'

In [5]:
def get_all_contexts():
  """Load the contexts JSON file selected by the notebook settings.

  Reads `<location>all_contexts_<test_count>.json`, where `location` and
  `test_count` are the module-level settings.

  Returns:
    The deserialized JSON content, exactly as `json.load` produces it.

  Raises:
    OSError: if the file cannot be opened for reading.
    json.JSONDecodeError: if the file does not contain valid JSON.
  """
  contexts_path = location + 'all_contexts_' + str(test_count) + '.json'

  # Read-only mode: the file is never written here, and 'r+' would fail on
  # read-only files. UTF-8 is stated explicitly so decoding does not depend
  # on the platform's default locale.
  with open(contexts_path, 'r', encoding = 'utf-8') as f:
    all_contexts = json.load(f)

  return all_contexts

def dump_context_to_embeddings(context_to_embeddings):
  """Pickle the context-to-embedding mapping into the experiments folder.

  The target is `<location>context_to_embed_<test_count>.pkl`, built from the
  module-level `location` and `test_count`. The file is opened in binary
  write mode, so an existing file with that name is overwritten.

  Args:
    context_to_embeddings: the object to serialize with `pickle.dump`.
  """
  output_path = ''.join([location, 'context_to_embed_', str(test_count), '.pkl'])

  with open(output_path, 'wb') as pickle_file:
    pickle.dump(context_to_embeddings, pickle_file)

In [6]:
%%capture

# Load SciBERT tokenizer and model
tokenizer = BertTokenizer.from_pretrained('allenai/scibert_scivocab_cased')
model = BertModel.from_pretrained('allenai/scibert_scivocab_cased')

In [7]:
def SciBERT_embeddings(contexts_list, batch_size = 1):
  """Compute one SciBERT [CLS] embedding per context.

  Uses the module-level `tokenizer` and `model`. Every context is truncated
  to at most 128 tokens before the forward pass, and gradients are disabled.

  Args:
    contexts_list: iterable of context strings.
    batch_size: number of contexts encoded per forward pass. The default of 1
      encodes one context at a time. Larger values pad each batch to its
      longest member, which is typically faster; the resulting vectors may
      differ from the unbatched ones by floating-point noise.

  Returns:
    A list with one numpy array per context, in input order. Each array is
    the final-layer hidden state of the first ([CLS]) token.

  Raises:
    ValueError: if batch_size is smaller than 1.
  """
  if batch_size < 1:
    raise ValueError('batch_size must be a positive integer')

  # Materialize once so any iterable can be sliced into batches
  contexts = list(contexts_list)

  # Follow whatever device the model currently lives on (CPU or GPU)
  device = next(model.parameters()).device

  embeddings_list = []

  with tqdm.tqdm(total = len(contexts)) as progress_bar:
    for start in range(0, len(contexts), batch_size):
      batch = contexts[start:start + batch_size]

      # Tokenize the batch and place the tensors next to the model
      tokenized_input = tokenizer(batch, padding = True, truncation = True, max_length = 128, return_tensors = 'pt').to(device)

      # Forward pass through the SciBERT model to obtain embeddings
      with torch.no_grad():
        outputs = model(**tokenized_input)

      # Extract the [CLS] token's embedding from the final hidden layer
      cls_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()

      # Copy every row: without the copy each stored vector is a view that
      # keeps the whole hidden-state buffer of its batch alive in memory.
      for cls_embedding in cls_embeddings:
        # Shape: (768, )
        embeddings_list.append(cls_embedding.copy())

      progress_bar.update(len(batch))

  return embeddings_list

In [8]:
def create_context_to_embed_map(contexts_list, embeddings_list):
  """Pair each context with the embedding found at the same position.

  Args:
    contexts_list: iterable of hashable contexts.
    embeddings_list: sequence indexed by position; must hold at least as many
      items as `contexts_list`, otherwise the lookup raises IndexError.

  Returns:
    A dict from context to embedding. When a context occurs more than once,
    the embedding of its last occurrence is kept.
  """
  return {
    context: embeddings_list[position]
    for position, context in enumerate(contexts_list)
  }

In [9]:
# Unit Test
# Smoke test of the embedding pipeline on three short strings.

# Disabled by default; set to True to run the checks below
run_unit_test = False

if run_unit_test:
  contexts_list = ['BERT', 'GPT', 'Large Language Model']
  embeddings_list = SciBERT_embeddings(contexts_list)
  context_to_embed_map = create_context_to_embed_map(contexts_list, embeddings_list)

  # Verify the results before anything is written to disk
  assert len(embeddings_list) == len(contexts_list), 'expected one embedding per context'
  assert all(embedding.shape == (768, ) for embedding in embeddings_list), 'unexpected embedding shape'
  assert set(context_to_embed_map) == set(contexts_list), 'map keys must be the contexts'

  # NOTE(review): this writes to the same pickle file as the full run, so it
  # replaces any existing full embedding map with this three-entry one.
  dump_context_to_embeddings(context_to_embed_map)

In [10]:
# Generating SciBERT embeddings for all contexts
# Pipeline: load contexts -> embed each one -> pair context with vector -> pickle

all_contexts = get_all_contexts()
embeddings_list = SciBERT_embeddings(all_contexts)
context_to_embed_map = create_context_to_embed_map(all_contexts, embeddings_list)
dump_context_to_embeddings(context_to_embed_map)

# An empty line first, then the completion notice
for status_line in ('', 'Completed.'):
  tqdm.tqdm.write(status_line)

100%|██████████| 64765/64765 [2:40:23<00:00,  6.73it/s]



Completed.


In [11]:
# That's it