In [None]:
%load_ext autoreload
%autoreload 2

# Save nearby clusters

For each cluster, save a list of nearby clusters so that we don't have to calculate nearby clusters each time.

Use a bi-encoder followed by a cross-encoder to determine which clusters are nearby.

In [None]:
from collections import defaultdict, Counter
import json
import math
import os
import random
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# from py4j.java_gateway import JavaGateway
from sentence_transformers.cross_encoder import CrossEncoder
from sklearn.metrics.pairwise import cosine_similarity
from statistics import harmonic_mean
import torch
from tqdm.auto import tqdm

from src.data.normalize import normalize
from src.data.utils import read_csv
from src.models.biencoder import BiEncoder
from src.models.tokenizer import get_tokenize_function_and_vocab
from src.models.utils import get_cross_encoder_score, top_similar_names

In [None]:
# configure
given_surname = "given"

# candidate selection: the bi-encoder threshold/limit pick candidate clusters,
# the cross-encoder threshold decides which candidates are kept as "nearby"
be_score_threshold = 0.4
be_limit = 50  # 100
ce_score_threshold = 0.01

# tokenizer settings (max_tokens and subwords_path feed the tokenize function,
# tokenizer_max_length is passed to the cross-encoder as max_length)
max_tokens = 10
tokenizer_max_length = 32
subwords_path = f"../data/models/fs-{given_surname}-subword-tokenizer-2000f.json"

# model locations
be_model_type = 'cecommon+0+aug-0-1'
be_model_path = f"../data/models/bi_encoder-{given_surname}-{be_model_type}.pth"
ce_model_dir = f"../data/models/cross-encoder-{given_surname}-10m-265-same-all"

# clustering parameters; in this notebook they are only used to build the file names below
scorer = "ce"
linkage = "average"
similarity_threshold = 0.1
cluster_freq_normalizer = "none"

# input: clusters file; output: nearby-clusters file
clusters_path = f"../data/processed/clusters_{given_surname}-{scorer}-{linkage}-{similarity_threshold}-{cluster_freq_normalizer}-augmented.json"
nearby_clusters_path = f"../data/processed/nearby_clusters_{given_surname}-{scorer}-{linkage}-{similarity_threshold}-{cluster_freq_normalizer}.json"

In [None]:
# Report GPU status. The device-property and memory queries require a CUDA
# device, so they are only issued when one is available; on a CPU-only machine
# the cell now prints False instead of raising.
torch.cuda.empty_cache()
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print("cuda total", torch.cuda.get_device_properties(0).total_memory)
    print("cuda reserved", torch.cuda.memory_reserved(0))
    print("cuda allocated", torch.cuda.memory_allocated(0))

## Load data

In [None]:
# Read the clusters file and index it four ways, all sharing the same ordering:
#   cluster_position        label -> row index into cluster_centroids
#   cluster_names           label -> set of names in the cluster
#   cluster_centroids       list of centroid arrays, one per cluster
#   cluster_centroid_labels array of labels, aligned with cluster_centroids
with open(clusters_path, 'r') as f:
    clusters = json.load(f)  # cluster label -> names, centroid

cluster_position = {}
cluster_names = {}
cluster_centroids = []
labels_in_order = []
for position, (label, cluster) in enumerate(clusters.items()):
    cluster_position[label] = position
    labels_in_order.append(label)
    cluster_centroids.append(np.array(cluster['centroid']))
    cluster_names[label] = set(cluster['names'])
cluster_centroid_labels = np.array(labels_in_order)

# number of clusters, total number of names across clusters
total_names = sum(map(len, cluster_names.values()))
print(len(clusters), total_names)

In [None]:
# load tokenize function and its vocabulary (configured by max_tokens and subwords_path);
# the cell output is the vocabulary size
tokenize, tokenizer_vocab = get_tokenize_function_and_vocab(max_tokens=max_tokens, subwords_path=subwords_path)
len(tokenizer_vocab)

In [None]:
# load bi-encoder model
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
# NOTE(review): no map_location is passed, so a checkpoint saved from a GPU may fail
# to load on a CPU-only machine -- confirm the intended device.
be_model = torch.load(be_model_path)
# presumably a torch.nn.Module: eval() switches it to inference mode, and as the
# last expression its repr is displayed as the cell output
be_model.eval()

In [None]:
# load cross encoder model
# ce_model.predict is what get_cross_encoder_cluster_score uses to score name pairs;
# tokenizer_max_length is handed to CrossEncoder as its max_length
ce_model = CrossEncoder(ce_model_dir, max_length=tokenizer_max_length)

## Find nearby clusters

In [None]:
def sample_names(cluster):
    """
    Return up to 8 names from the given cluster as one space-separated string.

    `cluster` is a cluster label used to look up `cluster_names`; a falsy label
    (e.g. '' or None) yields an empty string. The names come from a set, so
    which names are picked and their order are not guaranteed.
    """
    if cluster:
        names = list(cluster_names[cluster])
        return ' '.join(names[:8])
    return ''

In [None]:
def get_nearest_bi_encoder_cluster_scores(cluster, threshold, limit):
    """
    Find the clusters whose centroids are most similar to the centroid of `cluster`.

    Looks up the centroid of `cluster` and hands it, together with all cluster
    centroids and their labels, to `top_similar_names`; that function's result is
    returned unchanged (callers in this notebook unpack it as `labels, scores`).
    `threshold` and `limit` are forwarded as-is -- presumably the minimum
    similarity and the maximum number of results; confirm against top_similar_names.
    """
    centroid = cluster_centroids[cluster_position[cluster]]
    return top_similar_names(
        centroid,
        cluster_centroids,
        cluster_centroid_labels,
        threshold,
        limit,
    )

In [None]:
def get_cross_encoder_cluster_score(cluster, other_cluster):
    """
    Average cross-encoder similarity between two clusters.

    Every name in `cluster` is paired with every name in `other_cluster`.
    Identical names count as a score of 1.0 without consulting the model. Every
    other pair is scored by `ce_model` in both directions, and the two
    directional scores are combined with a harmonic mean.

    Args:
        cluster: label of the first cluster (key of `cluster_names`).
        other_cluster: label of the second cluster (key of `cluster_names`).

    Returns:
        The mean pair score as a plain Python float, so that it can be written
        with `json.dump`. Returns 0.0 when either cluster has no names.

    Raises:
        KeyError: if either label is not in `cluster_names`.
        Any exception raised by `ce_model.predict` or `harmonic_mean` (the latter
        rejects negative scores) is propagated to the caller.
    """
    names = cluster_names[cluster]
    other_names = cluster_names[other_cluster]
    # every (name, other_name) combination contributes exactly one score
    total_pairs = len(names) * len(other_names)
    if total_pairs == 0:
        # nothing to compare; avoid dividing by zero
        return 0.0

    pairs = []
    total_score = 0.0
    for cluster_name in names:
        for other_name in other_names:
            if cluster_name == other_name:
                total_score += 1.0
            else:
                # score both directions; they are adjacent in `pairs`
                pairs.append((cluster_name, other_name))
                pairs.append((other_name, cluster_name))
    if pairs:
        scores = ce_model.predict(pairs)
        for ix in range(0, len(scores), 2):
            # convert to Python floats so the running total never becomes a
            # NumPy scalar (e.g. float32), which json cannot serialize
            forward, backward = float(scores[ix]), float(scores[ix + 1])
            total_score += float(harmonic_mean([forward, backward]))
    return total_score / total_pairs

In [None]:
# test
cluster = 'richard/richard'
# get the 10 nearest clusters according to the bi-encoder
# (results go into their own names so the loaded `clusters` dict is not overwritten)
nearest_clusters, nearest_scores = get_nearest_bi_encoder_cluster_scores(cluster, be_score_threshold, limit=10)
print(*zip(nearest_clusters, nearest_scores))

## Analyze bi-encoder and cross-encoder scores

In [None]:
# For a sample of clusters, pair each bi-encoder score with the cross-encoder
# score of the same cluster pair, and count how many pairs reach each
# cross-encoder threshold. Pairs scoring below low_threshold are dropped.
be_scores = []
ce_scores = []
low_threshold = 0.01
med_threshold = 0.02
ce_low_scores = 0
ce_med_scores = 0
# cluster_names has one entry per loaded cluster; using it avoids relying on the
# global name `clusters`, which is easy to rebind in exploratory cells
all_cluster_labels = list(cluster_names.keys())
# seeded, local RNG so the sample (and the plot built from it) is reproducible;
# the sample size is capped so small cluster files do not raise ValueError
sample_rng = random.Random(42)
sampled_clusters = sample_rng.sample(all_cluster_labels, min(100, len(all_cluster_labels)))
for cluster in tqdm(sampled_clusters):
    nears, scores = get_nearest_bi_encoder_cluster_scores(cluster, be_score_threshold, be_limit)
    for near, score in zip(nears, scores):
        if cluster == near:
            continue
        try:
            ce_score = get_cross_encoder_cluster_score(cluster, near)
        except Exception as e:
            # best effort: report the failing pair and keep going
            print(cluster, near, e)
            continue
        if ce_score < low_threshold:
            continue
        be_scores.append(score)
        ce_scores.append(ce_score)
        # everything that reaches this point is already >= low_threshold
        ce_low_scores += 1
        if ce_score >= med_threshold:
            ce_med_scores += 1
print(low_threshold, ce_low_scores, med_threshold, ce_med_scores)

In [None]:
# Scatter of bi-encoder score (x) against cross-encoder score (y) for the sampled
# cluster pairs, zoomed to the low-score region.
fig, ax = plt.subplots()
ax.scatter(be_scores, ce_scores, s=1, alpha=1)
ax.set_xlim(0.3, 0.5)
ax.set_ylim(0.0, 0.1)
ax.set_xlabel("bi-encoder score")
ax.set_ylabel("cross-encoder score")
ax.set_title("Bi-encoder vs. cross-encoder cluster scores")
plt.show()

## Compute nearby clusters

In [None]:
# For every cluster: take the bi-encoder's candidate clusters, re-score each
# candidate with the cross-encoder, keep those scoring at least ce_score_threshold,
# and store them sorted by descending cross-encoder score.
near_clusters = {}
total_near_clusters = 0
# cluster_names has one entry per loaded cluster; using it avoids relying on the
# global name `clusters`, which is easy to rebind in exploratory cells
for cluster in tqdm(cluster_names.keys(), mininterval=1.0):
    near_scores = []
    nears, _be_scores = get_nearest_bi_encoder_cluster_scores(cluster, be_score_threshold, be_limit)
    for near in nears:
        if cluster == near:
            continue
        try:
            # plain float: a NumPy scalar such as float32 cannot be written by json.dump
            ce_score = float(get_cross_encoder_cluster_score(cluster, near))
        except Exception as e:
            # best effort: report the failing pair and keep going
            print(cluster, near, e)
            continue
        if ce_score < ce_score_threshold:
            continue
        # plain str for the label, for the same JSON-serialization reason
        near_scores.append((str(near), ce_score))
        total_near_clusters += 1
    near_clusters[cluster] = sorted(near_scores, key=lambda x: x[1], reverse=True)
total_near_clusters

## Save nearby clusters

In [None]:
# write the cluster -> [(nearby cluster, score), ...] mapping as indented JSON
with open(nearby_clusters_path, 'w') as out_file:
    json.dump(near_clusters, out_file, indent=2)