# Getting the Embedding Sequences 
For the following models:
1) NLP Baseline
2) KG Baseline 
3) STonKGs

In [1]:
# Imports 
import getpass
import os
import sys
import time

import json
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd 
import seaborn as sns
import torch
import umap
from matplotlib.ticker import FuncFormatter
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import BertModel, BertTokenizer

from stonkgs.constants import (
    CELL_LINE_DIR,
    CELL_TYPE_DIR,
    EMBEDDINGS_PATH,
    DISEASE_DIR,
    LOCATION_DIR,
    MISC_DIR,
    NLP_MODEL_TYPE,
    ORGAN_DIR,
    PRETRAINED_STONKGS_DUMMY_PATH,
    RANDOM_WALKS_PATH,
    SPECIES_DIR,
    VISUALIZATIONS_DIR,
)
from stonkgs.models.kg_baseline_model import _prepare_df, INDRAEntityDataset
from stonkgs.models.nlp_baseline_model import INDRAEvidenceDataset
from stonkgs.models.stonkgs_model import STonKGsForPreTraining

Record details

In [2]:
print(getpass.getuser())
print(sys.version)
print(time.asctime())

hbalabin
3.8.8 (default, Feb 24 2021, 21:46:12) 
[GCC 7.3.0]
Fri Jul  2 11:31:51 2021


## 0. Helper functions

In [3]:
def preprocess_stonkgs_data(unprocessed_df):
    sep_id = 102
    kg_name_to_idx = {key: i for i, key in enumerate(embeddings_dict.keys())}
    
    # Convert random walk sequences to list of numeric indices
    random_walk_idx_dict = {k: [kg_name_to_idx[node] for node in v] for k, v in random_walks_dict.items()}
    
    # Get the length of the text or entity embedding sequences (2 random walks + 2 = entity embedding sequence length)
    random_walk_length = len(next(iter(random_walk_idx_dict.values())))
    half_length = random_walk_length * 2 + 2
    
    # Initialize the preprocessed data
    fine_tuning_preprocessed = []

    # Log progress with a progress bar
    for _, row in tqdm(
        unprocessed_df.iterrows(),
        total=unprocessed_df.shape[0],
        desc='Preprocessing the fine-tuning dataset',
    ):
        # 1. "Token type IDs": 0 for text tokens, 1 for entity tokens
        token_type_ids = [0] * half_length + [1] * half_length

        # 2. Tokenization for getting the input ids and attention masks for the text
        # Use encode_plus to also get the attention mask ("padding" mask)
        encoded_text = tokenizer.encode_plus(
            row['evidence'],
            padding='max_length',
            truncation=True,
            max_length=half_length,
        )
        text_token_ids = encoded_text['input_ids']
        text_attention_mask = encoded_text['attention_mask']

        # 3. Get the random walks sequence and the node indices, add the SEP (usually with id=102) in between
        # Use a sequence of UNK tokens if the node is not contained in the dictionary of the nodes from pre-training
        random_w_source = random_walk_idx_dict[
            row['source']
        ] if row['source'] in random_walk_idx_dict.keys() else [unk_id] * random_walk_length
        random_w_target = random_walk_idx_dict[
            row['target']
        ] if row['target'] in random_walk_idx_dict.keys() else [unk_id] * random_walk_length
        random_w_ids = random_w_source + [sep_id] + random_w_target + [sep_id]

        # 4. Total attention mask (attention mask is all 1 for the entity sequence)
        attention_mask = text_attention_mask + [1] * half_length

        # 5. Total input_ids = half text ids + half entity ids
        input_ids = text_token_ids + random_w_ids

        # Add all the features to the preprocessed data
        fine_tuning_preprocessed.append({
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids,  # Remove the MLM, ELM and NSP labels since it's not needed anymore
        })

    # Put the preprocessed data into a dataframe
    fine_tuning_preprocessed_df = pd.DataFrame(fine_tuning_preprocessed)

    return fine_tuning_preprocessed_df

## 1. Load some example sequences

In [4]:
task_dir = SPECIES_DIR
number_unique_tags = 3
dataset_version = "species_no_duplicates.tsv"
number_entries = 800

In [5]:
task_specific_dataset = pd.read_csv(os.path.join(task_dir, dataset_version), sep="\t", index_col=None)
if "Unnamed: 0" in task_specific_dataset.columns.values:
    task_specific_dataset.drop(columns=["Unnamed: 0"], inplace=True)

Filter out unseen nodes

In [6]:
embeddings_dict = _prepare_df(EMBEDDINGS_PATH)
random_walks_dict = _prepare_df(RANDOM_WALKS_PATH)

In [7]:
task_specific_dataset = task_specific_dataset[
    task_specific_dataset['source'].isin(embeddings_dict.keys()) & task_specific_dataset['target'].isin(embeddings_dict.keys())
].reset_index(drop=True)

In [8]:
task_specific_dataset['class'].value_counts()

9606     18633
10090     2857
10116      275
Name: class, dtype: int64

Sample the present classes equally

In [9]:
sampled_df = pd.DataFrame()

for cls in np.unique(task_specific_dataset["class"]):
    cls_specific_samples = task_specific_dataset[task_specific_dataset['class'] == cls].sample(
        n=number_entries//number_unique_tags)
    sampled_df = sampled_df.append(cls_specific_samples)
    
sampled_df.reset_index(drop=True, inplace=True)

In [10]:
sampled_df["class"].value_counts()

9606     266
10090    266
10116    266
Name: class, dtype: int64

## 2. Load all three models 
1) NLP Baseline

In [11]:
nlp_baseline = BertModel.from_pretrained(NLP_MODEL_TYPE)
tokenizer = BertTokenizer.from_pretrained(NLP_MODEL_TYPE, model_max_length=512)
labels = sampled_df["class"].tolist()

In [12]:
evidence = tokenizer(sampled_df.iloc[0]["evidence"], return_tensors="pt", padding='max_length', truncation=True)
dummy_nlp_example = (nlp_baseline(**evidence, output_hidden_states=True, return_dict=True).pooler_output[0], torch.tensor(labels[0]))
print(dummy_nlp_example[0].shape)
dummy_nlp_example

torch.Size([768])


(tensor([-0.0516, -0.0117, -0.9619, -0.9606,  0.9807,  0.1207, -0.0949, -0.9607,
         -0.0948, -0.0359, -0.8395, -0.2051, -0.9999,  0.8026, -0.0210,  0.0333,
          0.9846, -0.0474,  0.9717,  0.0659,  0.0233,  0.9511, -0.0293,  1.0000,
          0.0578,  0.0700,  0.0301, -0.9661, -0.0026, -0.0895, -0.0419, -0.9713,
          0.9998,  0.8023, -0.0731, -0.0128,  0.0747, -0.1218,  0.9463, -1.0000,
         -0.0105, -0.9161, -0.1222, -0.0921,  1.0000, -0.1230,  0.0273,  0.1322,
         -0.0319, -0.9620,  0.0723, -0.9997, -0.9993, -0.9608, -0.8467,  0.2483,
          0.9755,  0.0596, -0.9983,  0.0961,  0.9726,  0.0426, -0.1675, -0.1576,
         -0.1715, -0.0422, -0.9993,  0.1058,  0.0729,  0.0361,  0.0322,  0.0136,
         -0.8089, -0.3076, -0.0858,  0.0897,  0.1244,  0.9798, -0.6409, -0.8102,
          0.1011,  0.9724,  0.9803, -0.0667, -0.0280,  0.0850, -0.9993,  0.8490,
          0.0353, -0.1529,  0.9996, -0.5020, -0.0668, -0.9990, -0.9738, -0.9382,
          0.0194,  0.0484,  

2) KG Baseline (embedding dict)

Since it's based on static embeddings, we only need the "INDRAEntityDataset"

In [13]:
kg_baseline = INDRAEntityDataset(
    embeddings_dict,
    random_walks_dict,
    sampled_df["source"],
    sampled_df["target"],
    sampled_df["class"],
)

In [17]:
dummy_kg_example = (torch.max(kg_baseline[0][0], dim=0).values, kg_baseline[0][1])
print(dummy_kg_example[0].shape)
dummy_kg_example

torch.Size([768])


(tensor([1.2293, 1.0236, 1.1940, 1.1248, 1.3703, 1.5178, 1.3103, 1.0333, 1.1640,
         0.8238, 1.2085, 1.2224, 1.0132, 1.0950, 1.2097, 1.1800, 1.2709, 1.3301,
         0.8403, 1.1180, 0.9705, 0.9043, 1.4194, 1.1620, 1.8903, 1.0841, 1.6645,
         0.6901, 1.4775, 0.8453, 0.9264, 0.8676, 1.3635, 0.7040, 0.5950, 0.8344,
         1.0458, 1.0762, 0.8174, 1.6167, 1.4221, 1.2033, 1.2731, 1.1344, 1.7208,
         1.4036, 1.7233, 1.2920, 1.0149, 1.3701, 0.5808, 0.9727, 1.2350, 0.8186,
         1.0659, 1.0553, 1.6057, 1.7285, 1.2958, 1.4424, 1.2029, 1.4429, 1.0703,
         0.9583, 1.0092, 0.8927, 1.0301, 1.6472, 1.3931, 1.3301, 1.3097, 1.2086,
         1.3953, 1.0689, 1.2527, 1.3271, 1.4885, 1.1854, 1.3332, 0.9092, 1.7597,
         1.0955, 1.0021, 1.6596, 1.2740, 1.1206, 1.6186, 1.2661, 1.0492, 0.9483,
         1.8316, 0.8491, 1.4526, 1.2925, 1.2920, 0.7024, 1.2879, 1.3619, 1.5463,
         1.7380, 1.1124, 0.7886, 0.8572, 1.4648, 0.8843, 0.7579, 0.9341, 1.1778,
         1.2215, 0.7408, 1.0

3) STonKGs (LARGE)

In [14]:
stonkgs = STonKGsForPreTraining.from_pretrained(
    pretrained_model_name_or_path=PRETRAINED_STONKGS_DUMMY_PATH,
)

In [16]:
stonkgs_data = preprocess_stonkgs_data(sampled_df)

Preprocessing the fine-tuning dataset: 100%|██████████| 798/798 [00:00<00:00, 1136.91it/s]


In [17]:
data_entry = {key: torch.tensor([value]) for key, value in dict(stonkgs_data.iloc[0]).items()}
dummy_stonkgs_example = (stonkgs(**data_entry, return_dict=True).pooler_output[0], torch.tensor(labels[0]))
print(dummy_stonkgs_example[0].shape)
dummy_stonkgs_example

torch.Size([768])


(tensor([-0.1893,  0.0368,  0.1354, -0.9218, -0.0083,  0.1043, -0.9507, -0.1795,
          0.9881,  0.1231, -0.8932,  0.1262, -1.0000,  0.9992,  0.8193, -0.2461,
          0.1265,  0.2537, -0.9846,  0.4080,  0.1113,  0.4275,  0.3919,  0.3051,
         -0.6842,  0.8513, -0.9997, -0.3284, -0.9122, -0.1575, -0.9655, -0.0572,
          0.9998, -0.0074,  0.1943, -0.0920,  0.9995,  0.9998,  0.6955,  0.9929,
         -0.9966, -0.1195, -0.9285, -0.7006, -0.1586,  0.1327, -0.4811, -0.3151,
         -0.9493,  0.0792,  0.9892,  0.3913,  0.9987,  0.0928,  0.2667,  0.8162,
          0.0374,  0.1394,  0.3588, -0.1755,  0.1335,  0.2674, -0.6229, -0.0437,
          0.1301,  0.2411,  0.1050,  0.9995, -0.2979,  0.9064, -0.2565, -0.9796,
         -0.9500, -0.7333, -0.0052, -0.7721, -0.0592, -0.0552,  0.9994, -0.0421,
         -0.6122, -0.1663, -0.5688, -0.0700,  0.9960, -0.9965, -0.9714, -0.0526,
         -0.2881,  0.9996, -0.4040, -0.8263, -0.0431, -0.9957,  0.0339,  0.9985,
         -0.5564, -0.3928, -

## 3. Get the embeddings

1. NLP embeddings

In [18]:
def get_nlp_embeddings(list_of_indices):
    """Returns a list of (embedding_sequence, label) pairs."""
    all_embed_sequences = []
    
    for idx in list_of_indices:
        nlp_evidence = tokenizer(sampled_df.iloc[idx]["evidence"], return_tensors="pt", padding='max_length', truncation=True)
        nlp_hidden_states = (nlp_baseline(**nlp_evidence, output_hidden_states=True).pooler_output[0],
                             torch.tensor(sampled_df.iloc[idx]["class"]))
        all_embed_sequences.append(nlp_hidden_states)
        
    return all_embed_sequences

2. KG embeddings

In [19]:
def get_kg_embeddings(list_of_indices):
    """Returns a list of (embedding_sequence, label) pairs."""
    all_embed_sequences = []
    
    for idx in list_of_indices:
        all_embed_sequences.append((torch.max(kg_baseline[idx][0], dim=0).values, kg_baseline[idx][1]))
        
    return all_embed_sequences

3. STonKGs

In [20]:
def get_stonkgs_embeddings(list_of_indices):
    """Returns a list of (embedding_sequence, label) pairs."""
    all_embed_sequences = []
    
    for idx in list_of_indices:
        data_entry = {key: torch.tensor([value]) for key, value in dict(stonkgs_data.iloc[idx]).items()}
        stonkgs_hidden_states = (stonkgs(**data_entry, return_dict=True).pooler_output[0],
                                 torch.tensor(sampled_df.iloc[idx]["class"]))
        all_embed_sequences.append(stonkgs_hidden_states)
        
    return all_embed_sequences

Testing the functions

In [21]:
get_nlp_embeddings([1,5,15])
get_kg_embeddings([105,150,400])
get_kg_embeddings([50,20,600])





[(tensor([0.8199, 1.0339, 1.2151, 1.1758, 1.2501, 1.5135, 1.1532, 0.9345, 1.1966,
          0.7390, 1.4022, 1.2629, 1.1392, 1.0948, 0.8903, 1.1800, 0.6760, 1.3727,
          1.0618, 1.1497, 1.1216, 0.9651, 1.0477, 0.9641, 1.6039, 1.2026, 1.4565,
          0.6601, 1.0177, 0.9012, 1.1241, 0.8775, 1.3370, 0.7827, 0.9532, 1.3188,
          1.5685, 1.1096, 1.1083, 1.4297, 1.7709, 1.5323, 1.5550, 1.0358, 1.8667,
          1.3972, 1.4358, 1.2755, 1.0149, 1.3661, 0.9928, 1.0757, 1.0334, 1.2195,
          1.3659, 0.9006, 1.6616, 1.3232, 1.4256, 1.4424, 1.2029, 1.7379, 1.2585,
          0.8998, 0.8873, 1.4440, 1.0620, 1.1747, 1.1599, 1.3955, 1.5962, 0.8583,
          1.3953, 1.7699, 1.6168, 1.0860, 1.2218, 1.3588, 1.3704, 1.1862, 1.8271,
          1.1568, 1.2779, 1.9205, 0.8217, 1.3620, 1.4497, 1.2661, 0.8700, 1.2905,
          3.4755, 0.9639, 1.5833, 1.4184, 0.8765, 0.8023, 1.2879, 1.3099, 1.5188,
          1.7380, 1.1449, 0.8658, 0.7098, 1.5686, 0.8454, 1.0747, 0.9146, 1.4942,
          1.1165