# Categories for centroid tests

In [None]:
# Give every distinct source vocabulary an integer id, numbered in order of first appearance.
categ_dict = {
    vocab_id: position
    for position, vocab_id in enumerate(s2c["source_vocabulary_id"].unique())
}

In [None]:
# Add an integer "category" column derived from each row's vocabulary id, then display the frame.
s2c["category"] = s2c["source_vocabulary_id"].map(categ_dict)
s2c

In [None]:
# Independent NumPy copy of the per-row category ids, used for masking below.
categories = s2c["category"].to_numpy(copy=True)

# Mean pooling centroid computation

In [None]:
def compute_centroid(vectors, categories, category):
    """Return the mean-pooled centroid of the rows of ``vectors`` labelled ``category``.

    Parameters
    ----------
    vectors : array-like
        One row per sample (e.g. embeddings).
    categories : array-like, length ``len(vectors)``
        Label of each row of ``vectors``. Plain lists are accepted.
    category
        The label whose rows are averaged.

    Returns
    -------
    numpy.ndarray
        Element-wise mean of the selected rows (taken over axis 0).

    Raises
    ------
    ValueError
        If no row carries ``category``. Previously this returned NaN with a
        RuntimeWarning, which then propagated silently.
    """
    vectors = np.asarray(vectors)
    # np.asarray makes a list compare element-wise instead of collapsing to a single False.
    mask = np.asarray(categories) == category
    if not np.any(mask):
        raise ValueError(f"no vectors found for category {category!r}")
    return vectors[mask].mean(axis=0)

In [None]:
# Mean-pool the target embeddings of each category; keys are inserted in ascending label order.
centroids = {
    label: compute_centroid(targets_emb, categories, label)
    for label in sorted(np.unique(categories))
}


In [None]:
# Stack the per-category centroids into one matrix, keeping the dict's insertion order, and show the row count.
centroids = np.array(list(centroids.values()))
len(centroids)

In [None]:
# Project the target embeddings with the notebook's `compute_pca` helper (defined outside this section).
targets_pcs = compute_pca(targets_emb)

In [None]:
# Plot the target projections, passing the per-row category ids as both `colors` and `names`.
plot_pca(pcs=targets_pcs, colors=categories, names=categories)

In [None]:
# Project and plot the category centroids, one point per sorted category label.
# NOTE: despite its name, `sources_pcs` holds the projections of the centroids.
centroid_labels = sorted(np.unique(categories))
sources_pcs = compute_pca(centroids)
plot_pca(pcs=sources_pcs, colors=centroid_labels, names=centroid_labels)

# Top-1 and top-2 nearest-centroid hit rates

In [None]:
# For every source, rank the centroids and record whether its true category is the
# nearest centroid (top-1) or one of the two nearest (top-2).
# NOTE(review): this compares centroid *row positions* with category labels, which is
# only valid when the labels are exactly 0..K-1 (as produced by `categ_dict` above).
correct_centroid = 0
top2_hit = 0
for i in range(total):
    source_example = sources_emb[i]
    distance, index = norml2_innerproduct(centroids, source_example)
    ranked_centroids = index[0]
    true_category = categories[i]

    if ranked_centroids[0] == true_category:
        correct_centroid += 1
    if true_category in ranked_centroids[:2]:
        top2_hit += 1

In [None]:
# Fraction of sources whose nearest centroid is their own category.
top1_centroid_rate = correct_centroid / total
print(top1_centroid_rate)

In [None]:
# Fraction of sources whose category is among the two nearest centroids.
top2_centroid_rate = top2_hit / total
print(top2_centroid_rate)

# What if we instead predict each source's category by majority vote over its 50 nearest targets?

In [None]:
from scipy.stats import mode

# Predict each source's category as the most frequent category among its nearest
# target embeddings, and count how often that vote matches the true category.
N_VOTERS = 50
correct_category = 0

for i in range(total):
    source_example = sources_emb[i]
    distance, index = norml2_innerproduct(targets_emb, source_example)
    neighbour_categories = categories.take(index[0][:N_VOTERS])
    most_frequent_cat = mode(neighbour_categories).mode
    correct_category += 1 if most_frequent_cat == categories[i] else 0

In [None]:
# Majority-vote category accuracy (displayed as the cell output).
category_vote_rate = correct_category / total
category_vote_rate

# Test centroid-restricted retrieval across models

In [None]:
def test_models_centroid(models: list, sources: list, targets: list, categories, centroids=None):
    """Evaluate retrieval restricted to the category of the nearest centroid.

    For every model: encode ``sources`` and ``targets``, assign each source to the
    category whose centroid is nearest, then search only among the targets of that
    category. Source ``i`` is a hit when target ``i`` is ranked first / within the
    first 5 / within the first 10 of that restricted search.

    Parameters
    ----------
    models : list
        Model names passed to ``SentenceTransformer``.
    sources, targets : list
        Texts aligned by position: ``targets[i]`` is the expected match of ``sources[i]``.
    categories : array-like
        Category label per position; it is used both to select targets and as the
        true category of source ``i``. Lists are accepted.
    centroids : array-like, optional
        Precomputed centroid matrix whose row ``k`` belongs to the ``k``-th smallest
        label in ``categories``. It is only meaningful for embeddings from the model
        that produced it. When ``None`` (the default), centroids are mean-pooled from
        each model's own target embeddings.

    Returns
    -------
    list[str]
        One formatted report per model.

    Raises
    ------
    ValueError
        If ``centroids`` does not have exactly one row per distinct category.
    """
    categories = np.asarray(categories)
    # Sorted distinct labels; centroid row k corresponds to labels[k].
    labels = np.unique(categories)

    results = []

    for plm in tqdm(models, desc="Testing models: "):

        # Track results
        correct_centroid = 0
        top1 = 0
        top5 = 0
        top10 = 0
        total = len(sources)

        # Load the model; retry with remote code enabled if plain loading raises ValueError.
        needs_remote_code = "no"
        try:
            model = SentenceTransformer(plm, trust_remote_code=False)
        except ValueError:
            model = SentenceTransformer(plm, trust_remote_code=True)
            needs_remote_code = "yes"

        # Encode
        sources_emb = model.encode(sources, normalize_embeddings=True)
        targets_emb = model.encode(targets, normalize_embeddings=True)

        if centroids is None:
            # Centroids must live in this model's embedding space, so build them here.
            centroid_emb = np.stack(
                [targets_emb[categories == label].mean(axis=0) for label in labels]
            )
        else:
            centroid_emb = np.asarray(centroids)
            if centroid_emb.shape[0] != labels.shape[0]:
                raise ValueError(
                    f"expected {labels.shape[0]} centroids (one per category), "
                    f"got {centroid_emb.shape[0]}"
                )

        start = time()
        for i in tqdm(range(total), leave=False, desc="Computing dinstances: "):
            source_example = sources_emb[i]

            # Step 1: infer the category from the nearest centroid.
            _, centroid_rank = norml2_innerproduct(centroid_emb, source_example)
            inferred_category = labels[centroid_rank[0][0]]
            if inferred_category == categories[i]:
                correct_centroid += 1

            # Step 2: search only among the targets of the inferred category.
            candidate_ids = np.flatnonzero(categories == inferred_category)
            _, local_rank = norml2_innerproduct(targets_emb[candidate_ids], source_example)
            local_rank = np.asarray(local_rank[0], dtype=np.intp)
            # NOTE(review): drop negative ids in case the search pads missing neighbours
            # with -1 (FAISS does this when k exceeds the number of vectors); this is
            # a no-op if norml2_innerproduct never returns them.
            local_rank = local_rank[local_rank >= 0]
            # Map positions inside the candidate subset back to global target ids.
            ranking = candidate_ids[local_rank]

            # Check matches (a top-1 hit is also a top-5 and top-10 hit).
            if ranking.size and ranking[0] == i:
                top1 += 1
            if i in ranking[:5]:
                top5 += 1
            if i in ranking[:10]:
                top10 += 1
        elapsed_seconds = time() - start

        results.append(f"""
                        plm: {plm};
                        needs remote code: {needs_remote_code};
                        Correct_centroid: {correct_centroid/total:.2%};
                        Top 1 match: {top1/total:.2%};
                        Top 5 match: {top5/total:.2%};
                        Top 10 match: {top10/total:.2%};
                        Total number of tests: {len(sources)},
                        Elapsed seconds: {elapsed_seconds};
                        Predictions per second X 1000: {len(sources)/elapsed_seconds/1000:.2}
                        """)
    return results

In [None]:
# Sentence-transformers checkpoints to evaluate in this section.
list_of_models = [
    "mixedbread-ai/mxbai-embed-large-v1",
    "intfloat/multilingual-e5-small",
    "intfloat/multilingual-e5-large",
    "sentence-transformers/all-MiniLM-L6-v2",
]

In [None]:
# NOTE(review): this calls `test_models`, which is not defined in this part of the
# notebook, rather than the `test_models_centroid` defined just above. Confirm which
# evaluation is intended. A centroid run would also need centroids built from each
# model's own embeddings, since the models differ.
results = test_models(
    list_of_models,
    sources,
    targets,
)

In [None]:
# Print each per-model report. A plain loop avoids building (and displaying as the
# cell output) a throwaway list of `None`s, as the list-comprehension form did.
for report in results:
    print(report)

# Test retrieval restricted to the majority-vote category

In [None]:
def test_models_centroid(models: list, sources: list, targets: list, categories, centroids=None, k_votes: int = 50):
    """Evaluate retrieval restricted to a category chosen by nearest-neighbour vote.

    NOTE(review): this reuses the name ``test_models_centroid``, so once this cell
    has run it replaces the centroid-based function of the same name defined in an
    earlier cell. Consider giving it a distinct name.

    For every model: encode ``sources`` and ``targets``, predict the category of each
    source as the most frequent category (``scipy.stats.mode``) among its ``k_votes``
    nearest targets, then search only among the targets of that category. Source
    ``i`` is a hit when target ``i`` is ranked first / within 5 / within 10.

    Parameters
    ----------
    models : list
        Model names passed to ``SentenceTransformer``.
    sources, targets : list
        Texts aligned by position: ``targets[i]`` is the expected match of ``sources[i]``.
    categories : array-like
        Category label per position; used both for voting/selecting targets and as
        the true category of source ``i``. Lists are accepted.
    centroids : ignored
        Kept only for signature compatibility; the vote does not use centroids.
    k_votes : int, default 50
        Number of nearest targets whose categories vote.

    Returns
    -------
    list[str]
        One formatted report per model. The "Correct_centroid" field holds the
        majority-vote category accuracy.
    """
    from scipy.stats import mode

    categories = np.asarray(categories)

    # Store results
    results = []

    for plm in tqdm(models, desc="Testing models: "):

        # Load the model; retry with remote code enabled if plain loading raises ValueError.
        needs_remote_code = "no"
        try:
            model = SentenceTransformer(plm, trust_remote_code=False)
        except ValueError:
            model = SentenceTransformer(plm, trust_remote_code=True)
            needs_remote_code = "yes"

        sources_emb = model.encode(sources, normalize_embeddings=True)
        targets_emb = model.encode(targets, normalize_embeddings=True)

        # Counts
        correct_category = 0
        top1 = 0
        top5 = 0
        top10 = 0
        total = len(sources)

        start = time()
        for i in tqdm(range(total), leave=False, desc="Computing dinstances: "):
            source_example = sources_emb[i]

            # Step 1: majority vote over the categories of the nearest targets.
            # NOTE(review): this uses `compute_distance`, while the centroid version uses
            # `norml2_innerproduct`; confirm both rank neighbours the same way.
            _, neighbour_rank = compute_distance(targets_emb, source_example)
            votes = categories.take(neighbour_rank[0][:k_votes])
            # Newer SciPy returns a scalar here, older versions a length-1 array; ravel handles both.
            most_frequent_cat = np.ravel(mode(votes).mode)[0]
            if most_frequent_cat == categories[i]:
                correct_category += 1

            # Step 2: search only among the targets of the voted category.
            candidate_ids = np.flatnonzero(categories == most_frequent_cat)
            _, local_rank = compute_distance(targets_emb[candidate_ids], source_example)
            local_rank = np.asarray(local_rank[0], dtype=np.intp)
            # NOTE(review): drop negative ids in case the search pads missing neighbours
            # with -1 (as FAISS does); this is a no-op if compute_distance never returns them.
            local_rank = local_rank[local_rank >= 0]
            # Map positions inside the candidate subset back to global target ids.
            ranking = candidate_ids[local_rank]

            # Check matches (a top-1 hit is also a top-5 and top-10 hit).
            if ranking.size and ranking[0] == i:
                top1 += 1
            if i in ranking[:5]:
                top5 += 1
            if i in ranking[:10]:
                top10 += 1
        elapsed_seconds = time() - start

        results.append(f"""
                        plm: {plm};
                        needs remote code: {needs_remote_code};
                        Correct_centroid: {correct_category/total:.2%};
                        Top 1 match: {top1/total:.2%};
                        Top 5 match: {top5/total:.2%};
                        Top 10 match: {top10/total:.2%};
                        Total number of tests: {len(sources)},
                        Elapsed seconds: {elapsed_seconds};
                        Predictions per second X 1000: {len(sources)/elapsed_seconds/1000:.2}
                        """)
    return results

In [None]:
# Sentence-transformers checkpoints to evaluate in this section.
list_of_models = [
    "microsoft/Multilingual-MiniLM-L12-H384",
    "intfloat/multilingual-e5-small",
    "sentence-transformers/all-MiniLM-L6-v2",
]

In [None]:
# Run the category-vote evaluation defined above under "Test number of hits in each
# category". That version of `test_models_centroid` does not use `centroids`; the
# argument is passed only to satisfy its signature.
results = test_models_centroid(list_of_models, sources, targets, categories, centroids)

In [None]:
# Print each per-model report. A plain loop avoids building (and displaying as the
# cell output) a throwaway list of `None`s, as the list-comprehension form did.
for report in results:
    print(report)