In [1]:
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import numpy as np
import warnings
warnings.filterwarnings("ignore")

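TCGA clustering
---
Cluster the ER-positive / HER2-negative TCGA tumour samples on a small gene panel with K-means (K=2). For each algorithm, report the silhouette, Calinski–Harabasz and Davies–Bouldin scores, then save the labelling with the best silhouette score to CSV.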

In [2]:
from sklearn.cluster import KMeans, AgglomerativeClustering, SpectralClustering
from sklearn.mixture import GaussianMixture
from sklearn.metrics import silhouette_score, calinski_harabasz_score, davies_bouldin_score
import pandas as pd

def clustering_comparison(dataframe, genes_list, file_name, n_clusters=2, random_state=0):
    """Cluster samples on a gene panel, report quality metrics, and save the best labelling.

    Each enabled algorithm is fit on ``dataframe[genes]`` and scored with the
    silhouette score, the Calinski-Harabasz index and the Davies-Bouldin index.
    The labels of the algorithm with the highest silhouette score are written
    to a ``'cluster'`` column of ``dataframe``. The caller's frame is modified
    in place. The whole frame is then saved to ``file_name`` as CSV without
    the index.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Samples as rows, genes as columns. Assumes the selected gene columns are
        numeric and NaN-free. An algorithm that fails on them is reported and
        skipped.
    genes_list : str or iterable of str
        Gene column names. Surrounding whitespace is stripped. Duplicates are
        ignored (the first occurrence is kept). A single string is treated as
        one gene name.
    file_name : str or path-like
        Destination CSV for the labelled frame.
    n_clusters : int, default 2
        Number of clusters requested from every algorithm.
    random_state : int, default 0
        Seed passed to the stochastic algorithms for reproducibility.

    Returns
    -------
    dict
        Maps each algorithm name that ran successfully to a dict with keys
        ``'labels'``, ``'silhouette_score'``, ``'calinski_harabasz_index'`` and
        ``'davies_bouldin_index'``. If every algorithm failed, the dict is empty
        and nothing is saved.

    Raises
    ------
    ValueError
        If none of the requested genes is a column of ``dataframe``.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(genes_list, str):
        genes_list = [genes_list]

    # dict.fromkeys de-duplicates while keeping the caller's order. Duplicated
    # names would otherwise select the same column twice.
    genes = list(dict.fromkeys(gene.strip() for gene in genes_list))

    # Validate gene existence in dataframe
    genes_in_dataframe = [gene for gene in genes if gene in dataframe.columns]
    if not genes_in_dataframe:
        raise ValueError("None of the requested genes exist in the DataFrame columns.")
    # Sorted so the warning is identical from run to run (set order is not).
    missing_genes = sorted(set(genes) - set(genes_in_dataframe))
    if missing_genes:
        print(f"Warning: The following genes were not found in the DataFrame: {missing_genes}")

    # Extract relevant gene data
    gene_data = dataframe[genes_in_dataframe]

    # Store results
    clustering_results = {}

    # NOTE(review): KMeans' n_init is left at scikit-learn's default. That
    # default changed between releases (10 -> 'auto'). Pin it if results must
    # match across library versions.
    clustering_algorithms = {
        'KMeans': KMeans(n_clusters=n_clusters, random_state=random_state),
        ### uncomment to run all clustering algorithms
        # 'Agglomerative': AgglomerativeClustering(n_clusters=n_clusters),
        # 'Spectral': SpectralClustering(n_clusters=n_clusters, affinity='nearest_neighbors', random_state=random_state),
        # 'GMM': GaussianMixture(n_components=n_clusters, random_state=random_state)
    }

    for name, algorithm in clustering_algorithms.items():
        try:
            if name == 'GMM':
                algorithm.fit(gene_data)
                labels = algorithm.predict(gene_data)
            else:
                labels = algorithm.fit_predict(gene_data)

            silhouette = silhouette_score(gene_data, labels)
            ch_index = calinski_harabasz_score(gene_data, labels)
            db_index = davies_bouldin_score(gene_data, labels)

            clustering_results[name] = {
                'labels': labels,
                'silhouette_score': silhouette,
                'calinski_harabasz_index': ch_index,
                'davies_bouldin_index': db_index
            }

            print(f"{name} clustering:")
            print(f"  Silhouette Score: {silhouette:.4f}")
            print(f"  Calinski-Harabasz Index: {ch_index:.4f}")
            print(f"  Davies-Bouldin Index: {db_index:.4f}")
            print("-" * 40)

        except Exception as e:
            # Best-effort comparison: one failing algorithm should not abort the others.
            print(f"Error in {name} clustering: {e}")

    # Save clusters from the algorithm with the highest silhouette score
    if clustering_results:
        best_algo = max(clustering_results.items(), key=lambda x: x[1]['silhouette_score'])[0]
        dataframe['cluster'] = clustering_results[best_algo]['labels']
        dataframe.to_csv(file_name, index=False)
        print(f"Best clustering method: {best_algo}, saved results to {file_name}")

    return clustering_results

In [3]:
def read_ds(path="TCGA_dataset.csv"):
    """Load the TCGA table and split its tumour samples into ER+/HER2- and the rest.

    Rows whose ``gender`` is ``'MALE'`` or whose ``sample_type`` is
    ``'Solid Tissue Normal'`` are removed first. The remaining rows are split
    into two groups:

    * rows with ``ER_Status_nature2012 == 'Positive'`` and
      ``HER2_Final_Status_nature2012 == 'Negative'``;
    * every other row.

    Parameters
    ----------
    path : str or path-like, default "TCGA_dataset.csv"
        Location of the local TCGA dataset, in CSV format.

    Returns
    -------
    filtered_df, er_pos, others : pandas.DataFrame
        ``er_pos`` and ``others`` are disjoint, and together they make up
        ``filtered_df``.
    """
    df = pd.read_csv(path)
    # Rows with a missing gender / sample_type satisfy these "!=" tests and are kept.
    filtered_df = df[(df['gender'] != 'MALE') & (df['sample_type'] != 'Solid Tissue Normal')]
    er_pos_mask = ((filtered_df['ER_Status_nature2012'] == 'Positive') &
                   (filtered_df['HER2_Final_Status_nature2012'] == 'Negative'))
    er_pos = filtered_df[er_pos_mask]
    # Take the complement by mask, not with drop(er_pos.index). drop() would
    # also remove non-ER+ rows that share an index label with an ER+ row.
    others = filtered_df[~er_pos_mask]
    return filtered_df, er_pos, others

In [4]:
# Load once and show the size of each subset: all tumours, ER+/HER2-, remaining.
filtered_df, er_pos, others = read_ds()
tuple(frame.shape for frame in (filtered_df, er_pos, others))

((1090, 20734), (481, 20734), (609, 20734))

In [5]:
# Cluster the ER+/HER2- samples on the gene panel. This reuses `er_pos` from
# the load cell above instead of re-reading the large CSV. It labels a copy,
# so `er_pos` keeps the shape shown earlier and re-running this cell is
# idempotent.
genes_list = ["CDCA5", "AURKA", "UBE2C", "MKNK2", "CCNB2", "C14orf45", "CYB5D1", "APH1B"]
er_pos_labeled = er_pos.copy()
results = clustering_comparison(er_pos_labeled, genes_list, file_name="labeled_ds_TCGA.csv")

KMeans clustering:
  Silhouette Score: 0.3529
  Calinski-Harabasz Index: 396.0096
  Davies-Bouldin Index: 1.0250
----------------------------------------
Best clustering method: KMeans, saved results to labeled_ds_TCGA.csv


In [9]:
import matplotlib.pyplot as plt
import numpy as np

def plot_clustering(clustering_results, output_file="clustering_metrics_comparison.pdf"):
    """Draw side-by-side bar charts of the clustering quality metrics and save them.

    Parameters
    ----------
    clustering_results : dict
        Maps each algorithm name to a metrics dict with the keys
        ``'silhouette_score'``, ``'calinski_harabasz_index'`` and
        ``'davies_bouldin_index'``; this is the structure that
        ``clustering_comparison`` returns. Entries with a missing key or a NaN
        metric are skipped, with a printed message.
    output_file : str or path-like, default "clustering_metrics_comparison.pdf"
        Destination file. It is always written in PDF format, whatever the
        extension.

    Returns
    -------
    None
        The function returns early, without plotting, when no entry is usable.
    """
    algorithms, silhouette_scores, calinski_scores, davies_scores = [], [], [], []

    # Collect valid metrics
    for alg, metrics in clustering_results.items():
        try:
            s = metrics['silhouette_score']
            c = metrics['calinski_harabasz_index']
            d = metrics['davies_bouldin_index']

            if any(np.isnan([s, c, d])):
                print(f"Skipping {alg} due to NaN in metrics.")
                continue

            algorithms.append(alg)
            silhouette_scores.append(s)
            calinski_scores.append(c)
            davies_scores.append(d)

        except Exception as e:
            print(f"Skipping {alg} due to error: {e}")

    if not algorithms:
        print("No valid clustering results to plot.")
        return

    x = np.arange(len(algorithms))

    # One panel per metric, side by side
    fig, axes = plt.subplots(1, 3, figsize=(16, 5), constrained_layout=True)

    metric_data = [
        ("Silhouette Score (↑)", silhouette_scores, 'skyblue'),
        ("Calinski-Harabasz Index (↑)", calinski_scores, 'lightgreen'),
        ("Davies-Bouldin Index (↓)", davies_scores, 'salmon'),
    ]

    for ax, (title, scores, color) in zip(axes, metric_data):
        bars = ax.bar(x, scores, color=color)
        ax.set_title(title)
        ax.set_ylabel("Score")
        ax.set_xticks(x)
        ax.set_xticklabels(algorithms, rotation=45)

        # Scale the padding by the largest magnitude, so the offset stays
        # positive even when every score is negative (silhouette can be < 0).
        pad = 0.01 * max(abs(score) for score in scores)
        for bar, score in zip(bars, scores):
            height = bar.get_height()
            above = height >= 0
            # Label sits just beyond the bar's end: above positive bars, below negative ones.
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                height + pad if above else height - pad,
                f'{score:.3f}',
                ha='center', va='bottom' if above else 'top', fontsize=9,
            )

    fig.suptitle("Clustering Algorithm Metrics Comparison", fontsize=16)
    fig.savefig(output_file, format='pdf')
    plt.show()

    print(f"Saved plot to: {output_file}")


### uncomment to plot

# plot_clustering(results)