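# Clustering algorithm

Clusters the usages of each target word by their precomputed embeddings (KMeans or Affinity Propagation). Each cluster that contains a labelled usage is named after that usage's known `sense_id`. The predicted senses are then written to `./predictions/`.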

In [149]:
import json
import os
import warnings

import numpy as np
import pandas as pd
import torch
from sklearn.cluster import KMeans, AffinityPropagation
from sklearn.metrics import silhouette_score
warnings.filterwarnings('ignore')

## 1. Load the data and the embeddings

In [150]:
# Input split to cluster; the matching embeddings file shares its base name.
FILE_TO_READ = './data/dev-testing/axolotl.dev.fi.tsv'
EMBEDDING_TYPE = 'glosses' # 'examples', 'glosses' or 'concatenated'
PRINT_WORDS = False
CLUSTERING_METHOD = 'KMeans' # 'KMeans' or 'AffinityPropagation'

# Fail fast on typos: `clustering()` silently falls back to AffinityPropagation
# for any method name other than 'KMeans'.
assert EMBEDDING_TYPE in ('examples', 'glosses', 'concatenated'), f"Unknown EMBEDDING_TYPE: {EMBEDDING_TYPE}"
assert CLUSTERING_METHOD in ('KMeans', 'AffinityPropagation'), f"Unknown CLUSTERING_METHOD: {CLUSTERING_METHOD}"

# 'axolotl.dev.fi.tsv' -> filename 'axolotl.dev.fi', language 'fi'.
# Only the basename is parsed so dots in directory names cannot leak in.
basename = FILE_TO_READ.split('/')[-1]
filename = '.'.join(basename.split('.')[:-1])
language = filename.split('.')[-1]
embeddings_file = f"./embeddings/{EMBEDDING_TYPE}/{filename}.json"
language, embeddings_file

('fi', './embeddings/glosses/axolotl.dev.fi.json')

In [151]:
# One row per usage; `sense_id` and `gloss` can be missing (NaN) for some usages.
df = pd.read_table(FILE_TO_READ)
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6554 entries, 0 to 6553
Data columns (total 9 columns):
 #   Column                Non-Null Count  Dtype 
---  ------                --------------  ----- 
 0   usage_id              6554 non-null   object
 1   word                  6554 non-null   object
 2   orth                  6554 non-null   object
 3   sense_id              3203 non-null   object
 4   gloss                 3203 non-null   object
 5   example               6554 non-null   object
 6   indices_target_token  6554 non-null   object
 7   date                  6554 non-null   int64 
 8   period                6554 non-null   object
dtypes: int64(1), object(8)
memory usage: 461.0+ KB


In [152]:
# Precomputed embeddings: a JSON list with one vector per usage, assumed to be
# in the same order as the rows of the TSV (checked only by count here).
with open(embeddings_file) as fh:
    embeddings_list = json.load(fh)

embeddings = torch.tensor(embeddings_list)
assert embeddings.shape[0] == len(df), "Embeddings count must be the same as the df length"
embeddings.shape

torch.Size([6554, 768])

In [153]:
# Store one row of `embeddings` per usage, matched to df rows by position.
df['embedding'] = [vector for vector in embeddings]
# Spot-check the alignment on the first row.
assert torch.equal(df['embedding'][0], embeddings[0])

## 2. Cluster each word's usages by their embeddings

In [154]:
def get_silhouette_score(tensors, labels):
    """Mean silhouette coefficient of a clustering, with Euclidean distance.

    Args:
        tensors: iterable of torch tensors, one per sample; each one is
            flattened into a row of the feature matrix.
        labels: cluster label of each sample, aligned with ``tensors``.

    Returns:
        The score computed by ``sklearn.metrics.silhouette_score`` (which
        raises ValueError unless 2 <= number of labels <= n_samples - 1).
    """
    rows = [t.flatten().numpy() for t in tensors]
    return silhouette_score(np.array(rows), labels=labels, metric='euclidean')

def KMeans_clustering(df):
    """Cluster the usages of one word with KMeans, choosing k by silhouette.

    Candidate k values start at the number of distinct known senses and go up
    to (excluding) that number plus the count of usages without a
    ``sense_id``. The upper bound reflects that every unlabelled usage could
    introduce a new sense. The k with the highest silhouette wins; ties go to
    the first k seen.

    Args:
        df: usages of a single word with ``sense_id`` (NaN when unknown) and
            ``embedding`` (one torch tensor per row) columns. Not modified.

    Returns:
        A copy of ``df`` with an added ``cluster`` column (KMeans labels of
        the best k).
    """
    df = df.copy()
    features = df['embedding'].tolist()
    min_senses = df['sense_id'].nunique()
    max_senses = min_senses + df['sense_id'].isnull().sum()

    # KMeans needs k >= 1, and there must be at least one candidate.
    # Two cases previously broke: with no unlabelled usages the range was
    # empty (KeyError on 'cluster_0'), and with no labelled usage it started
    # at k=0.
    lowest_k = max(min_senses, 1)
    highest_k = max(max_senses, lowest_k + 1)  # exclusive bound

    best_score = float('-inf')
    best_n = None
    best_labels = None
    for n in range(lowest_k, highest_k):
        kmeans = KMeans(n_clusters=n, random_state=0, n_init='auto')
        kmeans.fit(features)
        if n > 1:
            try:
                silhouette_avg = get_silhouette_score(df['embedding'], kmeans.labels_)
            except ValueError:
                # The silhouette is undefined when KMeans ends up with too few
                # distinct labels. This happens with glosses, because several
                # usages can share exactly the same embedding. Such a k is
                # deliberately treated as the best possible fit.
                silhouette_avg = 1e6
        else:
            # TODO: confirm 0 is the right score here. The silhouette is
            # undefined for a single cluster. 0 is the neutral value on its
            # [-1, 1] scale, so k=1 wins only if every k > 1 scores negative.
            silhouette_avg = 0
        if silhouette_avg > best_score:
            best_score, best_n, best_labels = silhouette_avg, n, kmeans.labels_

    if PRINT_WORDS:
        print("Best number of clusters:", best_n, f"[{min_senses}-{max_senses}]")
    df['cluster'] = best_labels
    return df

def AffinityPropagation_clustering(df):
    """Cluster the usages of one word with Affinity Propagation.

    The algorithm picks the number of clusters itself.

    Args:
        df: usages of a single word with an ``embedding`` column (one torch
            tensor per row). Not modified.

    Returns:
        A copy of ``df`` with an added ``cluster`` column holding the labels.
    """
    # Pin the seed. Recent scikit-learn versions default random_state to None
    # (it used to be fixed at 0), which makes the clusters vary between runs.
    # 0 also matches the seed used by KMeans_clustering.
    ap = AffinityPropagation(random_state=0)
    ap.fit(df['embedding'].tolist())
    # NOTE(review): on non-convergence scikit-learn only emits a warning
    # (silenced at the top of this notebook), and the labels may be degenerate.
    df = df.copy()
    df['cluster'] = ap.labels_
    return df

def clustering(df, method="AffinityPropagation"):
    """Cluster one word's usages and name each cluster after a known sense.

    Args:
        df: usages of a single word with ``sense_id`` (NaN when unknown) and
            ``embedding`` columns.
        method: "KMeans" uses KMeans_clustering; any other value falls back
            to AffinityPropagation_clustering.

    Returns:
        A frame with the same rows and without ``embedding``, where
        ``sense_id`` holds the prediction. Each cluster is named after the
        ``sense_id`` of its first labelled usage (in index order). Clusters
        with no labelled usage keep their integer cluster id.
    """
    if method == "KMeans":
        df_cl = KMeans_clustering(df)
    else:
        df_cl = AffinityPropagation_clustering(df)

    labelled = df_cl['sense_id'].notna()
    # First labelled usage of each cluster names it (setdefault keeps the first).
    # NOTE(review): labelled usages are renamed as well, so a known sense_id
    # can be replaced by the sense that named its cluster. Confirm this is
    # intended.
    cluster_names = {}
    for cluster_id, sense in zip(df_cl.loc[labelled, 'cluster'], df_cl.loc[labelled, 'sense_id']):
        cluster_names.setdefault(cluster_id, sense)

    # Build the new column in one go instead of writing strings into the
    # integer 'cluster' column, which pandas deprecates as an
    # incompatible-dtype setitem.
    predicted = [cluster_names.get(c, c) for c in df_cl['cluster']]
    return df_cl.assign(sense_id=predicted).drop(columns=['cluster', 'embedding'])

In [155]:
# Cluster each word's usages independently, then restore the original row order.
clustered_groups = []
for word, group in df.groupby('word'):
    if PRINT_WORDS:
        print(f"{word}: ", end="")
    group_cl = clustering(group, method=CLUSTERING_METHOD)
    if len(group) != len(group_cl):
        print(f"{len(group)} != {len(group_cl)} for word {word}")
    clustered_groups.append(group_cl)

# A single concat avoids the quadratic cost of growing a frame inside the loop;
# reindexing by usage_id puts the rows back in the order of the input TSV.
result_df = (
    pd.concat(clustered_groups, ignore_index=True)
    .set_index('usage_id')
    .reindex(df['usage_id'])
    .reset_index()
)

result_df

Unnamed: 0,usage_id,word,orth,sense_id,gloss,example,indices_target_token,date,period
0,dev_fi_0,ajainen,"ajaisen,",ajainen_iMPSFeVQEfY,,"witzauxen ajaisen, Herra sun päälles heitti",10:18,1700,new
1,dev_fi_1,ajainen,ajainen,ajainen_iMPSFeVQEfY,,"ajainen rangaistus nijn aiwan suur, Täs rundel...",0:7,1750,new
2,dev_fi_2,alentaa,"alenna,",alentaa_pJTUhn5-iL4,,"Ei tuki suowa alenna, wara wenhettä caada",14:21,1700,new
3,dev_fi_3,alentaa,alettuja,alentaa_pJTUhn5-iL4,,kynnön pitä [tapahtua niin] – – ettei yhtäkän ...,67:75,1750,new
4,dev_fi_4,alentaa,alendaa,alentaa_pJTUhn5-iL4,,on myös hyödyllinen – – maata Tarhan sisä puol...,64:71,1750,new
...,...,...,...,...,...,...,...,...,...
6549,dev_fi_6549,mieli-suosio,mieli suosjost,mieli-suosio_hpsE0n4lHYs,"suostumus, hyväksyntä; suopeus, suosio; vapaae...","Tytär täsä dandzais taitavast, Herodexen mieli...",41:55,1600,old
6550,dev_fi_6550,mieli-suosio,miel suosjo,mieli-suosio_hpsE0n4lHYs,"suostumus, hyväksyntä; suopeus, suosio; vapaae...","Herrall cunnja corkjudhes, Maasa raoha ihmises...",58:69,1600,old
6551,dev_fi_6551,mieli-suosio,mielisuosio,mieli-suosio_hpsE0n4lHYs,"suostumus, hyväksyntä; suopeus, suosio; vapaae...",P. Hengen erinomainen – – hywä tahto ia mielis...,40:51,1600,old
6552,dev_fi_6552,mieli-suosio,Mielisuosion,mieli-suosio_hpsE0n4lHYs,"suostumus, hyväksyntä; suopeus, suosio; vapaae...",Minä N. N. lupan – – että minä ilman catzomist...,65:77,1650,old


In [156]:
# Save the predictions as TSV: one file per (split, embedding type, clustering method).
output_file = f'./predictions/{filename}-{EMBEDDING_TYPE}-{CLUSTERING_METHOD}.tsv'
# Create ./predictions on a fresh checkout instead of failing inside to_csv.
os.makedirs(os.path.dirname(output_file), exist_ok=True)
result_df.to_csv(output_file, sep='\t', index=False)