In [1]:
import random
from random import sample

import altair as alt
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from doc_embed_torch import DocumentEmbeddingTrainer, load_run_config, VOCAB_SIZE

In [2]:
# Identifier of the training run to load; it is passed both to the trainer
# constructor and to load_mlm below.
run_code = "CjP6bpvz"
trainer = DocumentEmbeddingTrainer(run_code=run_code)
# Load the MLM for this run.
# NOTE(review): presumably restores the trained weights for a vocabulary of
# VOCAB_SIZE tokens — confirm against doc_embed_torch.
trainer.load_mlm(run_code, VOCAB_SIZE)

Preparing the masked dataset ...
Done preparing the masked dataset.
Preparing the model for quantization ...


In [5]:
# Pick a random subset of training-set positions whose embeddings will be
# compared pairwise.
SAMPLE_SIZE = 100  # number of documents to draw
SAMPLE_SEED = 0    # fixed seed so Restart & Run All selects the same documents

n_docs = len(trainer.train_dataset)
# A private Random instance keeps the draw reproducible without touching the
# global RNG state. The size is capped because asking sample() for more items
# than the population holds raises ValueError.
indices = random.Random(SAMPLE_SEED).sample(range(n_docs), min(SAMPLE_SIZE, n_docs))
print(indices[:10])

[47655, 67014, 37460, 45467, 34703, 6515, 2092, 25410, 13217, 33236]


In [8]:
# Pairwise distances between the document embeddings of the sampled dataset
# positions in `indices`.
#
# Results:
#   embeddings[i]     -> trainer.model(trainer.train_dataset[i], return_doc_embedding=True)
#   distances[(i, j)] -> float distance for every unordered pair, keyed and
#                        inserted in the order the pair is first met when
#                        scanning `indices` (i before j)

# Drop repeated positions while keeping first-seen order, so every unordered
# pair is visited exactly once and no document is paired with itself.
unique_indices = list(dict.fromkeys(indices))

distances = dict()
embeddings = dict()

# Inference only: outside no_grad, each cached embedding would hold on to its
# autograd graph whenever the model's parameters require gradients.
with torch.no_grad():
    # Fetch and embed each document once, rather than re-reading the dataset
    # for both members of every pair.
    for i in tqdm(unique_indices):
        embeddings[i] = trainer.model(trainer.train_dataset[i], return_doc_embedding=True)

    for pos, i in enumerate(unique_indices):
        for j in unique_indices[pos + 1:]:
            # p determines the Minkowski order. 2 is Euclidean, 1 is Manhattan, etc.
            distances[(i, j)] = torch.cdist(
                embeddings[i].unsqueeze(0), embeddings[j].unsqueeze(0), p=2
            ).item()

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:13<00:00,  7.26it/s]


In [12]:
# Collect every pairwise distance into a Series and show its summary
# statistics (count, mean, std, min, quartiles, max).
s = pd.Series(list(distances.values()))
s.describe()

count    4950.000000
mean        0.052788
std         0.020364
min         0.022364
25%         0.039627
50%         0.043556
75%         0.059977
max         0.125888
dtype: float64

In [28]:
# 80th percentile of the pairwise distances (quantile 0.8 == percentile 80).
np.quantile(list(distances.values()), 0.8)

0.07526061534881592

In [16]:
# Wrap the Series `s` in a one-column DataFrame, naming the column 'value'
# at construction time, then preview the first rows.
df = pd.DataFrame({'value': s})
df.head()

Unnamed: 0,value
0,0.044161
1,0.03801
2,0.083194
3,0.037413
4,0.042177


In [17]:
# Histogram of the 'value' column of `df`: x is binned into at most 100 bins,
# y is the number of rows falling in each bin.
alt.Chart(df).mark_bar().encode(
    x=alt.X('value', bin=alt.Bin(maxbins=100)),
    y=alt.Y('count()'),
)