In [1]:
# Import packages
 
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial import distance
from tqdm import tqdm
import glob
from sklearn.cluster import KMeans
from statistics import mode



In [2]:
def calc_am(cluster):
    """ Caclulate the arithmetic mean of a cluster. """
    center_x = np.mean(cluster[:,0])
    center_y = np.mean(cluster[:,1])

    return np.array([center_x, center_y])

In [3]:
def calc_gm(cluster):
    """ Caclulate the geometric mean of a cluster. """
    center_x = gmean(cluster[:,0])
    center_y = gmean(cluster[:,1])

    return np.array([center_x, center_y])

In [4]:
def collect_cluster(cluster_no, embedding, label_embedding_df):
    """ Collect all points in a cluster. """
    indices = np.where(label_embedding_df['syll_id'] == cluster_no)[0]
    cluster_points = embedding[indices]

    return cluster_points


In [5]:
def collect_bird_cluster(bird_id, label_embedding_df):
    """ Collect all points in a cluster. """
    bird_clusters = label_embedding_df.loc[label_embedding_df['bird_id']==bird_id, 'syll_id'].values
    bird_clusters =  np.unique(bird_clusters)

    return bird_clusters


In [6]:
def find_bird_id_of_cluster(cluster_no, label_embedding_df):
    """ Collect all points in a cluster. """
    cluster_rows = label_embedding_df[label_embedding_df['syll_id'] == cluster_id]
    bird_id = cluster_rows['bird_id'].values[0]

    return bird_id

In [7]:
def inertia(points, centroid):
    sum_dist = 0
    for p in points:
        sum_dist += np.linalg.norm(centroid - p)
    return sum_dist/points.shape[0]


In [8]:
def collect_kmeans_cluster(cluster_points):

    kmeans = KMeans(n_clusters=2, random_state=0, n_init="auto").fit(cluster_points)
    kmeans_labels = kmeans.labels_
    kmeans_centroids = kmeans.cluster_centers_

    # fig = plt.figure()

    # sp = plt.scatter(cluster_points[:,0], cluster_points[:,1], c=kmeans_labels, cmap='Dark2', marker=',', alpha=.5)
    # plt.scatter(kmeans_centroids[:,0], kmeans_centroids[:,1], marker='X', color='k')
    # plt.colorbar(sp)

    main_cluster = mode(kmeans_labels)
    # plt.scatter(kmeans_centroids[main_cluster,0], kmeans_centroids[main_cluster,1], marker='X', color='brown')
    # main_cluster=0

    # inertia0 = inertia(cluster_points[np.where(kmeans_labels==0)], kmeans_centroids[0])
    # inertia1 = inertia(cluster_points[np.where(kmeans_labels==1)], kmeans_centroids[1])

    # if inertia0>inertia1: main_cluster=1
    # else: main_cluster=0
    # plt.scatter(kmeans_centroids[main_cluster,0], kmeans_centroids[main_cluster,1], marker='x', color='yellow', alpha=.5)
    # plt.title(main_cluster)

    return cluster_points[np.where(kmeans_labels==main_cluster)]





In [9]:

data_folder = "/home/remya/Work/AlamTest/Alam JC/Oct 2024/Alam tests/Main script/Data/"
results_folder ="/home/remya/Work/AlamTest/Alam JC/Oct 2024/Alam tests/Aux scripts/RseedTest/Results/"
embeddings_folder = results_folder + '/Embeddings/'
centroids_folder = results_folder + '/KCentroids/'


In [10]:
normalise = False
n_syllables = 150 # Tutored

In [11]:
embedding_files = glob.glob(embeddings_folder + 'embedding_gen_*.csv')


In [12]:

cmap = plt.cm.Dark2.colors
cmap *= 40
cmap  =  cmap[::-1]

In [13]:
embedding_files[0]

'/home/remya/Work/AlamTest/Alam JC/Oct 2024/Alam tests/Aux scripts/RseedTest/Results//Embeddings/embedding_gen_13315092.csv'

In [14]:
for embedding_file in embedding_files:
    rseed = embedding_file.split('.')[0].split('_')[-1]
    results_filename_tag = '_' + rseed
    rseed = int(rseed)



    label_embedding_df = pd.read_csv(embedding_file, index_col=0)
    embedding = label_embedding_df[['x', 'y']].values


    if normalise == True:
        embedding = (embedding-embedding.min())
        embedding = embedding/embedding.max()

    cluster_labels = label_embedding_df['syll_id'].unique()
    cluster_labels.sort()

    syll_labels = label_embedding_df['syll_id'].values
    bird_labels = label_embedding_df['bird_id'].values

    # Find the centroid = arithmetic mean of each cluster
    cluster_ameans  = np.zeros((n_syllables, 2))
    cluster_kmeans  = np.zeros((n_syllables, 2))
    cluster_bird_id =  np.zeros((n_syllables, 1))
    cluster_ids = np.zeros((n_syllables, 1)) + 300
    for cluster_id in cluster_labels:
        cluster_points = collect_cluster(cluster_id, embedding, label_embedding_df)
        kmeans_cluster_points = collect_kmeans_cluster(cluster_points)

        cluster_ameans[cluster_id] = calc_am(cluster_points)
        cluster_kmeans[cluster_id] = calc_am(kmeans_cluster_points)
        cluster_bird_id[cluster_id] = find_bird_id_of_cluster(cluster_id, label_embedding_df)
        cluster_ids[cluster_id] = cluster_id

    centroid_info = {
        'cluster_id': cluster_ids[:,0],
        'centroid_x': cluster_ameans[:,0],
        'centroid_y': cluster_ameans[:,1],
        'kcentroid_x': cluster_kmeans[:,0],
        'kcentroid_y': cluster_kmeans[:,1],
        'bird_id': cluster_bird_id[:,0]
    }



    centroid_df = pd.DataFrame(centroid_info)
    centroid_df['range_0_min'] = embedding[:,0].min()
    centroid_df['range_1_min'] = embedding[:,1].min()
    centroid_df['range_0_max'] = embedding[:,0].max()
    centroid_df['range_1_max'] = embedding[:,1].max()
    centroid_df.to_csv(centroids_folder + 'centroids' + '_normalised'*(normalise) + '_' + results_filename_tag +'.csv', index=False) 
    print("Saved to ", centroids_folder + 'centroids' + '_normalised'*(normalise) + '_' + results_filename_tag +'.csv') 

    fig = plt.figure(figsize=(10, 10))
    plt.scatter(embedding[:, 0], embedding[:, 1], c=bird_labels, cmap='Spectral', s=.1, alpha=.01)



    # plt.scatter(cluster_ameans[:,0], cluster_ameans[:,1], c=cluster_bird_id, edgecolor='k', s=40, marker='X')
    ps = plt.scatter(cluster_kmeans[:,0], cluster_kmeans[:,1], c=cluster_bird_id, cmap='Spectral', s=40, marker='x', alpha=.5, label='Kcentroid')

    #     plt.text(cluster_ameans[bird_cluster_id,0]+1, cluster_ameans[bird_cluster_id,1]+1, str(bird_cluster_id),
    #                 color=cmap[bci])

    
    plt.title('All tutored birds with rseed '  + str(rseed))


    plt.colorbar(ps)

    plt.xlim(embedding[:,0].min()-5, embedding[:,0].max()+5)
    plt.ylim(embedding[:,1].min()-5, embedding[:,1].max()+5)

    plt.legend(fontsize=10, bbox_to_anchor=(1, 0.9))

    plt.tight_layout()

    plt.savefig(centroids_folder + 'centroids' + '_normalised'*(normalise) +  '_' + results_filename_tag  +'.png', dpi=300)

    plt.close()
    








Saved to  /home/remya/Work/AlamTest/Alam JC/Oct 2024/Alam tests/Aux scripts/RseedTest/Results//KCentroids/centroids__13315092.csv
Saved to  /home/remya/Work/AlamTest/Alam JC/Oct 2024/Alam tests/Aux scripts/RseedTest/Results//KCentroids/centroids__96319575.csv
Saved to  /home/remya/Work/AlamTest/Alam JC/Oct 2024/Alam tests/Aux scripts/RseedTest/Results//KCentroids/centroids__88409749.csv
Saved to  /home/remya/Work/AlamTest/Alam JC/Oct 2024/Alam tests/Aux scripts/RseedTest/Results//KCentroids/centroids__85652971.csv
Saved to  /home/remya/Work/AlamTest/Alam JC/Oct 2024/Alam tests/Aux scripts/RseedTest/Results//KCentroids/centroids__93410762.csv
Saved to  /home/remya/Work/AlamTest/Alam JC/Oct 2024/Alam tests/Aux scripts/RseedTest/Results//KCentroids/centroids__13953367.csv
Saved to  /home/remya/Work/AlamTest/Alam JC/Oct 2024/Alam tests/Aux scripts/RseedTest/Results//KCentroids/centroids__21081788.csv
Saved to  /home/remya/Work/AlamTest/Alam JC/Oct 2024/Alam tests/Aux scripts/RseedTest/Resu

In [15]:
centroid_df

Unnamed: 0,cluster_id,centroid_x,centroid_y,kcentroid_x,kcentroid_y,bird_id,range_0_min,range_1_min,range_0_max,range_1_max
0,0.0,-8.965124,-1.494228,-9.197216,-1.537718,0.0,-15.758223,-16.897474,18.409811,17.584328
1,1.0,-7.512445,6.310543,-7.576684,6.367309,0.0,-15.758223,-16.897474,18.409811,17.584328
2,2.0,1.940302,16.791844,1.919608,17.192378,0.0,-15.758223,-16.897474,18.409811,17.584328
3,3.0,-1.600823,-11.229536,-1.963580,-12.373137,0.0,-15.758223,-16.897474,18.409811,17.584328
4,4.0,-1.472903,3.724189,-2.461621,4.369026,0.0,-15.758223,-16.897474,18.409811,17.584328
...,...,...,...,...,...,...,...,...,...,...
145,145.0,-5.474650,5.070977,-5.524012,5.104523,30.0,-15.758223,-16.897474,18.409811,17.584328
146,146.0,-11.443208,9.094238,-11.875408,9.339500,30.0,-15.758223,-16.897474,18.409811,17.584328
147,147.0,-5.480967,-0.899654,-5.556455,-0.784547,30.0,-15.758223,-16.897474,18.409811,17.584328
148,148.0,9.852521,-9.172308,9.915235,-9.225741,30.0,-15.758223,-16.897474,18.409811,17.584328
