In [1]:
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer, util
from sklearn.cluster import DBSCAN, OPTICS, KMeans
from sklearn import metrics
import hdbscan
import time
import umap
import umap.plot

In [2]:
# Read the hscode parquet file into a DataFrame; every later cell derives from `df`.
df = pd.read_parquet(
    path='dataset/semantic_similarity/hscode.parquet',
)

In [3]:
# Keep only the rows whose HSYear equals 2022.
df_hscode2022 = df.loc[df['HSYear'].eq(2022)]

In [4]:
# Build one long table of (hscode4_text, HSCode2) pairs by stacking the two
# description columns under a shared column name.
df_hscode4_th = (
    df_hscode2022
    .loc[:, ['thdescriptions4', 'HSCode2']]
    .rename(columns={'thdescriptions4': 'hscode4_text'})
)

df_hscode4_en = (
    df_hscode2022
    .loc[:, ['endescriptions4', 'HSCode2']]
    .rename(columns={'endescriptions4': 'hscode4_text'})
)

# Stack both frames, drop exact duplicate rows, then discard rows with no text.
# The index is intentionally left as-is (not reset); later cells address rows
# by position via .iloc.
df_hscode4 = (
    pd.concat([df_hscode4_th, df_hscode4_en])
    .drop_duplicates()
    .loc[lambda frame: frame['hscode4_text'].notna()]
)

In [6]:
# Instantiate the sentence-embedding model by name; it is used below to encode
# every description string in df_hscode4.
model = SentenceTransformer('distiluse-base-multilingual-cased-v2')

In [7]:
# Encode all description strings in one call. With convert_to_tensor=True the
# result is a torch tensor (see the torch.Size output of the next cell), with
# one row per row of df_hscode4, in the same order.
X = model.encode(df_hscode4['hscode4_text'].values, convert_to_tensor=True)

In [8]:
# Display the embedding matrix dimensions as a sanity check.
X.shape

torch.Size([2427, 512])

In [9]:
# Alias (same object, no copy) for the embeddings; find_community reads this name.
corpus_embeddings = X

In [32]:

def find_accuracy(series, n=2, noise_label=-1):
    """Return the share of a group's rows that fall in its ``n`` largest clusters.

    The function counts how often each cluster label occurs in ``series``,
    ignores the noise label when picking the largest clusters, and divides the
    combined size of the top ``n`` clusters by the total number of rows
    (noise rows included in the denominator).

    Parameters
    ----------
    series : pandas.Series
        Cluster labels for the rows of one group.
    n : int, default 2
        How many of the largest non-noise clusters count as "correct".
    noise_label : hashable, default -1
        Label that marks rows not assigned to any cluster.

    Returns
    -------
    float
        Value in [0, 1]; ``nan`` when ``series`` has no countable labels.
    """
    # value_counts() sorts by frequency, largest first, so a positional slice
    # of the first ``n`` entries selects the ``n`` biggest clusters.
    count_series = series.value_counts()
    sum_all = count_series.sum()
    if sum_all == 0:
        # Empty (or all-null) group: the ratio is undefined. Return NaN
        # explicitly instead of relying on a 0/0 division and its warning.
        return float('nan')

    count_series_no_noise = count_series[count_series.index != noise_label]
    sum_top_n = count_series_no_noise.iloc[:n].sum()
    return sum_top_n / sum_all

In [107]:
def find_community(threshold, min_community_size=5):
    """Cluster the corpus embeddings at one threshold and score the result.

    Runs ``util.community_detection`` on the module-level ``corpus_embeddings``
    and writes the resulting label for each row into the
    ``cluster_group_id`` column of the module-level ``df_hscode4`` (rows that
    belong to no community keep the label -1). Note that this overwrites the
    column on the shared frame on every call; use the returned copy to keep a
    snapshot for a given threshold.

    Parameters
    ----------
    threshold : float
        Passed through as the ``threshold`` argument of
        ``util.community_detection``.
    min_community_size : int, default 5
        Passed through as ``min_community_size``.

    Returns
    -------
    tuple
        ``(acc_each_hscode, acc_mean, silhouette, df_copy)`` where
        ``acc_each_hscode`` is ``find_accuracy`` applied per ``HSCode2`` group,
        ``acc_mean`` is its mean, ``silhouette`` is the euclidean silhouette
        score of the labelling (0.0 when fewer than two distinct labels
        exist), and ``df_copy`` is a copy of ``df_hscode4`` with the labels.
    """
    clusters = util.community_detection(
        corpus_embeddings,
        min_community_size=min_community_size,
        threshold=threshold,
    )

    # Reset every row to "noise" first so labels from a previous call cannot
    # leak into this one.
    df_hscode4['cluster_group_id'] = -1
    label_col = df_hscode4.columns.get_loc('cluster_group_id')
    # Each cluster is used as a list of row positions, hence .iloc.
    for i, cluster in enumerate(clusters):
        df_hscode4.iloc[cluster, label_col] = i

    acc_each_hscode = df_hscode4.groupby('HSCode2')['cluster_group_id'].apply(find_accuracy)
    acc_mean = acc_each_hscode.mean()

    labels = df_hscode4['cluster_group_id'].values
    n_group = df_hscode4['cluster_group_id'].nunique()
    if n_group < 2:
        # The silhouette score needs at least two distinct labels.
        silhouette = 0.0
    else:
        # Score the same embeddings that were clustered. If they are a torch
        # tensor, hand scikit-learn a CPU NumPy array so scoring does not
        # depend on which device the encoder placed the tensor on.
        embeddings = corpus_embeddings
        if hasattr(embeddings, 'detach'):
            embeddings = embeddings.detach().cpu().numpy()
        # The noise label -1 is scored as if it were a regular cluster.
        silhouette = metrics.silhouette_score(embeddings, labels, metric='euclidean')

    return acc_each_hscode, acc_mean, silhouette, df_hscode4.copy()

In [108]:
# Sweep the community-detection threshold and record one summary row per setting.
result_list = []
# The grid is 0.15, 0.20, ..., 0.95 (17 values). It is built with linspace + round
# instead of np.arange(0.15, 1.0, 0.05): with a fractional step, arange yields
# values that are not exactly the intended two-decimal numbers and its element
# count is not guaranteed, which would leak into the saved 'threshold' column.
for threshold in np.round(np.linspace(0.15, 0.95, num=17), 2):
    acc_each_hscode, acc_mean, silhouette, df_result = find_community(threshold=threshold)

    # Rows per label, largest first; the noise label -1 is counted like any other.
    n_each_cluster = df_result['cluster_group_id'].value_counts()
    n_top_20_cluster = n_each_cluster.iloc[:20].values
    # Despite the name, this is the number of distinct labels, not a row count.
    cluster_size = n_each_cluster.shape[0]

    result_list.append({
        'threshold': threshold,
        'acc_mean': acc_mean,
        'silhouette': silhouette,
        'cluster_size': cluster_size,
        'n_top_20_cluster': n_top_20_cluster,
    })
    print(f'threshold: {threshold:.04}, acc_mean: {acc_mean:.04}, silhouette: {silhouette:.04}, cluster_size: {cluster_size}')

threshold: 0.15, acc_mean: 0.8958, silhouette: -0.1145, cluster_size: 34
threshold: 0.2, acc_mean: 0.8375, silhouette: -0.07093, cluster_size: 47
threshold: 0.25, acc_mean: 0.733, silhouette: -0.08876, cluster_size: 64
threshold: 0.3, acc_mean: 0.6011, silhouette: -0.0828, cluster_size: 88
threshold: 0.35, acc_mean: 0.482, silhouette: -0.08549, cluster_size: 115
threshold: 0.4, acc_mean: 0.4446, silhouette: -0.07822, cluster_size: 136
threshold: 0.45, acc_mean: 0.4052, silhouette: -0.07011, cluster_size: 158
threshold: 0.5, acc_mean: 0.3854, silhouette: -0.06904, cluster_size: 167
threshold: 0.55, acc_mean: 0.3503, silhouette: -0.07348, cluster_size: 165
threshold: 0.6, acc_mean: 0.328, silhouette: -0.08341, cluster_size: 148
threshold: 0.65, acc_mean: 0.2414, silhouette: -0.124, cluster_size: 107
threshold: 0.7, acc_mean: 0.1574, silhouette: -0.1593, cluster_size: 64
threshold: 0.75, acc_mean: 0.1047, silhouette: -0.1789, cluster_size: 42
threshold: 0.8, acc_mean: 0.04468, silhouette:

In [109]:
# Turn the per-threshold summaries into a table (one row per threshold) and
# save it as tab-separated text. The file name's spelling is kept exactly
# as-is so anything that already reads this path keeps working.
df_result_all = pd.DataFrame(data=result_list)
df_result_all.to_csv(
    path_or_buf='commnity_detection_output.csv',
    sep='\t',
    index=False,
)

In [100]:
# acc_each_hscode, acc_mean, silhouette, df_result = find_community(threshold=0.9)