In [29]:
import pandas as pd
import numpy as np
import altair as alt
from random import sample
import torch
from tqdm import tqdm

from doc_embed_torch import DocumentEmbeddingTrainer, load_run_config, VOCAB_SIZE

In [11]:
run_code = "p1CiZIzT"
trainer = DocumentEmbeddingTrainer(run_code="p1CiZIzT")
trainer.load_mlm(run_code, VOCAB_SIZE)

Preparing the masked dataset ...
Done preparing the masked dataset.
Preparing the model for quantization ...


In [43]:
indices = sample(range(len(trainer.train_dataset)), 200)
print(indices[:10])

[39582, 97432, 10137, 47954, 23988, 123218, 101765, 40866, 28672, 6067]


In [44]:
distances = dict()
for i in tqdm(indices):
    for j in indices:
        if (i, j) in distances or (j, i) in distances or i == j:
            continue
        
        doc_i = trainer.train_dataset[i]
        doc_j = trainer.train_dataset[j]
        embed_i = trainer.model(doc_i, return_doc_embedding=True)
        embed_j = trainer.model(doc_j, return_doc_embedding=True)
        
        # p determines the Minkowski order. 2 is Euclidean, 1 is Manhattan. etc.
        distances[(i, j)] = torch.cdist(embed_i.unsqueeze(0), embed_j.unsqueeze(0), p=2).item()

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [08:04<00:00,  2.42s/it]


In [45]:
s = pd.Series(distances.values())
s.describe()

count    19900.000000
mean        30.373340
std          2.682652
min          5.589416
25%         29.380765
50%         30.822581
75%         31.936457
max         36.048267
dtype: float64

In [58]:
# create a DataFrame from your Series s
df = pd.DataFrame({'s': s})
df['bin'] = pd.cut(df['s'], bins=100)
df_agg = df.groupby('bin').size().reset_index(name='count')
df_agg['mid'] = df_agg.bin.apply(lambda x: x.mid)
df_agg = df_agg[['mid', 'count']].copy()

df_agg.head()

Unnamed: 0,mid,count
0,5.7265,1
1,6.0465,0
2,6.351,1
3,6.6555,1
4,6.96,2


In [59]:
# create the histogram chart
alt.Chart(df_agg).mark_bar().encode(
    alt.X('mid'),
    y='count'
).properties(
    height=500,
    width=750,
)