## Behind the scenes: edge clustering

In [5]:
import numpy as np
import os
import json
import spacy
from sklearn.manifold import TSNE
from sklearn.decomposition import IncrementalPCA, PCA
from sklearn.metrics import silhouette_score
from sklearn.mixture import GaussianMixture
from scipy.cluster.hierarchy import dendrogram
import matplotlib.pyplot as plt
import numpy as np
from sklearn.feature_selection import VarianceThreshold
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from adjustText import adjust_text


In [None]:
#######
# Which aggregated embedding set to read from the JSON file.
aggregation_method = 'mean' # Options: 'min', 'max', 'median', 'mean'
#######

# Read the embeddings: a mapping keyed first by aggregation method, then by
# company name.
with open("input/relation_embeddings.json", 'r') as file:
    company_embeddings = json.load(file)

selected_embeddings = company_embeddings[aggregation_method]
companies = list(selected_embeddings)
X = np.array(list(selected_embeddings.values()))

# Fit one full-covariance GMM per candidate K and report the silhouette score
# of the resulting hard assignment.
# NOTE(review): fitting and scoring here use the raw embeddings, whereas
# make_GMM_plot_for_relations standardizes the features before fitting —
# confirm the two are meant to differ.
print('Silhouette Scores:')
for K in range(2, 10):
    gmm = GaussianMixture(
        n_components=K,
        covariance_type='full',
        random_state=42,  # fixed seed so repeated runs give the same labels
    )
    gmm_labels = gmm.fit_predict(X)
    sil_score = silhouette_score(X, gmm_labels)
    print(f"- K = {K}, Score = {sil_score:.3f}")

In [None]:
def make_GMM_plot_for_relations(n_clusters, aggregation_method, exclude_firms=None, box_position='right'):
    """Cluster relation embeddings with a GMM and plot them in 2-D PCA space.

    Reads ``input/relation_embeddings.json``, takes the embeddings stored
    under ``aggregation_method``, optionally drops the companies named in
    ``exclude_firms``, standardizes the features, fits a full-covariance
    Gaussian mixture and scatters the first two principal components coloured
    by cluster label. Every point is annotated with its company name and an
    info box reports the method, the number of clusters and the silhouette
    score. The figure is saved under ``results/gmm_edges/<aggregation_method>/``
    (the directory is created if missing) and then shown.

    Args:
        n_clusters: Number of mixture components to fit.
        aggregation_method: Top-level key of the JSON file to read,
            e.g. ``'mean'``. Also used in the output directory and file name.
        exclude_firms: Optional collection of company names to remove before
            clustering. ``None`` keeps every company. When not ``None`` the
            output file name gets an ``_excluded_<count>_square`` suffix.
        box_position: Placement of the info box: ``'right'`` (default),
            ``'left'``, ``'down_right'`` or ``'down_left'``. Any other value
            falls back to the default placement.

    Returns:
        None. The function works through its side effects (file + plot).
    """
    # Load data. JSON is UTF-8 by specification, so do not depend on the
    # platform's default text encoding.
    with open("input/relation_embeddings.json", 'r', encoding='utf-8') as file:
        company_embeddings = json.load(file)

    embeddings = company_embeddings[aggregation_method]
    companies = list(embeddings)
    X = np.array(list(embeddings.values()))

    # Remove the excluded firms; rows of X and entries of `companies` must stay
    # aligned, so both are filtered with the same mask.
    if exclude_firms is not None:
        excluded = set(exclude_firms)  # O(1) membership tests
        keep = np.array([name not in excluded for name in companies], dtype=bool)
        X = X[keep]
        companies = [name for name, kept in zip(companies, keep) if kept]

    # Normalize data
    scaler = StandardScaler()
    X_normalized = scaler.fit_transform(X)

    # Apply GMM clustering (fixed seed for reproducible labels)
    gmm = GaussianMixture(n_components=n_clusters, covariance_type='full', random_state=42)
    gmm_labels = gmm.fit_predict(X_normalized)

    # The labels were obtained in the standardized feature space, so the
    # silhouette score has to be measured in that same space; scoring on the
    # raw X would rate the clustering with distances the model never saw.
    sil_score = silhouette_score(X_normalized, gmm_labels)

    # PCA for dimensionality reduction (visualisation only)
    pca = PCA(n_components=2)
    X_reduced = pca.fit_transform(X_normalized)

    # NOTE(review): square canvas; an earlier note in this code says the wide
    # layout should be (10, 6) — confirm which one is intended.
    plt.figure(figsize=(6, 6))
    plt.scatter(X_reduced[:, 0], X_reduced[:, 1], c=gmm_labels, cmap='viridis', s=50, alpha=0.7)

    # Add box info
    textstr = f"Method: {aggregation_method}\nNum clusters: {n_clusters}\nSil score: {sil_score:.3f}"
    props = dict(boxstyle='square,pad=0.5', facecolor='lightgrey', edgecolor='black', alpha=0.75)

    # Anchor of the box's top-left corner in axes coordinates.
    default_position = (0.85, 0.95)
    box_positions = {
        'left': (0.03, 0.95),
        'down_right': (0.85, 0.17),  # (0.75, 0.13) for square
        'down_left': (0.03, 0.17),   # (0.05, 0.13) for square
    }
    position = box_positions.get(box_position, default_position)

    plt.text(position[0], position[1], textstr, transform=plt.gca().transAxes, fontsize=10,
             verticalalignment='top', bbox=props)

    # Annotate company names, then let adjustText move the labels apart.
    texts = [
        plt.text(X_reduced[i, 0], X_reduced[i, 1], company_name, fontsize=9)
        for i, company_name in enumerate(companies)
    ]
    adjust_text(texts, arrowprops=dict(arrowstyle='-', color='gray', lw=0.5))

    os.makedirs(f'results/gmm_edges/{aggregation_method}', exist_ok=True)

    # No extension is given, so matplotlib picks its default output format.
    filename = f'results/gmm_edges/{aggregation_method}/{aggregation_method}_relations_K_{n_clusters}'
    if exclude_firms is not None:
        filename += f'_excluded_{len(exclude_firms)}_square'

    print(f"Saving {filename}")

    plt.xlabel('PCA Component 1')
    plt.ylabel('PCA Component 2')
    plt.tight_layout()
    plt.savefig(filename)
    plt.show()

####################
# Run configuration for the plot below.
K = 5
aggregation_method = 'mean'

# Companies removed before clustering. Use `exclude_firms = None` to keep all
# of them, or drop 'ARC' for the variant that only adds 'Westinghouse'.
exclude_firms = [
    'General Atomics',
    'Ultra Safe Nuclear Corporation',
    'Babcock and Wilcox',
    'Westinghouse',
    'ARC',
]

make_GMM_plot_for_relations(
    n_clusters=K,
    aggregation_method=aggregation_method,
    exclude_firms=exclude_firms,
    box_position='down_left',
)