In [None]:
import os
import sys
from pathlib import Path

# Put the sibling ../src/ directory on the import path. The path is anchored at
# Path().resolve(), i.e. the current working directory, so this assumes the
# kernel was started from the notebook's own folder.
# NOTE(review): the project-local imports below presumably live in ../src — confirm.
sys.path.append(os.path.join(Path().resolve(), "../src/"))

from analyst import Analyst
from config import ModelConfig, TrainerConfig
from dataset_manager import load_dataset_manager
from model import load_model

In [None]:
# Model configuration with its defaults; trainer configuration selecting the
# dataset by name and pointing at the (relative) model / dataset cache folders.
model_config = ModelConfig()
trainer_config = TrainerConfig(
    model_dir="../cache/model",
    dataset_dir="../cache/dataset",
    dataset_name="toydata-seq-lengths",
)

# Everything the dataset loader needs comes from the two config objects.
dataset_options = {
    "dataset_name": trainer_config.dataset_name,
    "dataset_dir": trainer_config.dataset_dir,
    "load_dataset": trainer_config.load_dataset,
    "save_dataset": trainer_config.save_dataset,
    "window_size": model_config.window_size,
}
dataset_manager = load_dataset_manager(**dataset_options)

# The model is loaded against the dataset manager created above.
model = load_model(
    model_config=model_config,
    trainer_config=trainer_config,
    dataset_manager=dataset_manager,
)

In [None]:
# Build the analysis helper from the loaded model and its dataset manager.
analyst = Analyst(model, dataset_manager)

In [None]:
# Draw the similarity heatmap provided by Analyst.
# NOTE(review): which embeddings are compared is decided inside Analyst — not visible here.
analyst.visualize_similarity_heatmap()

In [None]:
# Group the norm of each sequence embedding by the length of its raw sequence:
# norms[length] -> list of embedding norms for all sequences of that length.
# NOTE(review): row i of model.seq_embedding is assumed to correspond to the
# i-th entry of dataset_manager.raw_sequences — confirm the ordering matches.
norms = {}

# The key is unused; it is named `_seq_id` (not `id`) so the builtin `id()` is
# not shadowed in the notebook's global namespace.
for row, (_seq_id, seq) in enumerate(dataset_manager.raw_sequences.items()):
    embedding_norm = model.seq_embedding[row].norm().item()
    norms.setdefault(len(seq), []).append(embedding_norm)

In [None]:
# Offset subtracted from every sequence length before plotting.
# NOTE(review): 20 presumably corresponds to a fixed length baseline of the toy
# dataset (e.g. its shortest sequence) — confirm.
LENGTH_OFFSET = 20

# x / y: one point per sequence (shifted length, embedding norm).
# x_means / y_means: one point per distinct length (shifted length, mean norm).
x, y = [], []
x_means, y_means = [], []

# Iterate in ascending length order: `norms` is filled in dataset encounter
# order, and the mean series is later drawn as a connected line, which would
# zigzag if the lengths were not visited left to right.
for seq_len, length_norms in sorted(norms.items()):
    shifted_len = seq_len - LENGTH_OFFSET
    x.extend([shifted_len] * len(length_norms))
    y.extend(length_norms)
    x_means.append(shifted_len)
    y_means.append(sum(length_norms) / len(length_norms))

In [None]:
import matplotlib.pyplot as plt

# Scatter of every sequence's embedding norm against its shifted length, with
# the per-length mean overlaid as a red line.
fig, ax = plt.subplots()
ax.scatter(x, y, marker="x", label="per sequence")
ax.plot(x_means, y_means, c="red", marker="o", label="mean per length")
ax.set_xticks([25, 50, 75, 100])
ax.set_xlabel("sequence length - 20")
ax.set_ylabel("sequence embedding norm")
ax.legend()
# Explicit show() so the cell renders only the figure, not the repr of the
# last call's return value.
plt.show()