### This notebook contains the analyses performed on the bacterial datasets' embeddings

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.lines import Line2D
import torch, json, re
from tqdm import tqdm

from scipy.cluster.hierarchy import dendrogram, linkage

from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.preprocessing import LabelEncoder
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.cluster import AgglomerativeClustering
from sklearn.model_selection import StratifiedKFold

from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.optimizers import Adam

# Use the first CUDA GPU when PyTorch reports one is available; otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

### Load and clean data

In [None]:
# Read the DNA and protein embedding tables from CSV files
# (the paths are relative to the current working directory).
dna_df = pd.read_csv('dna_embeddings.csv')
prot_df = pd.read_csv('protein_embeddings.csv')

In [None]:
# load embeddings as numpy arrays and remove redundant columns
def load_as_np(df):
    """Decode the JSON-encoded embeddings into NumPy arrays and drop helper columns.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with an ``embeddings_json`` column whose cells are JSON strings;
        each cell is decoded with ``json.loads`` and wrapped in ``np.array``.

    Returns
    -------
    pandas.DataFrame
        A new frame holding the decoded arrays in an ``embeddings`` column, without
        the ``Unnamed: 0``, ``embeddings_np``, ``embeddings_tensor`` and
        ``embeddings_json`` columns (those that are absent are simply skipped).
        The input frame is left untouched.

    Raises
    ------
    KeyError
        If ``df`` has no ``embeddings_json`` column.
    """
    parsed = df['embeddings_json'].apply(lambda x: np.array(json.loads(x)))
    drop_cols = ['Unnamed: 0', 'embeddings_np', 'embeddings_tensor', 'embeddings_json']
    # Build a new frame instead of mutating the caller's one in place, so the raw
    # table stays available and the function has no hidden side effects.
    return df.drop(columns=drop_cols, errors='ignore').assign(embeddings=parsed)

# Decode the embeddings of both tables (DNA first, then protein).
dna_df, prot_df = (load_as_np(frame) for frame in (dna_df, prot_df))

In [None]:
def clean_organism(value):
    """Normalise an organism name by removing strain details and stray characters.

    The steps are applied in this order:

    1. drop everything from a standalone ``str.`` token onward;
    2. remove square brackets;
    3. collapse runs of whitespace and trim the ends;
    4. drop a trailing standalone ``sp.`` token.

    Parameters
    ----------
    value : str
        Raw organism name. Non-string values (e.g. NaN for a missing cell) are
        returned unchanged.

    Returns
    -------
    str
        The cleaned organism name.
    """
    if not isinstance(value, str):
        return value

    # Remove anything from 'str.' onward, including 'str.'. The word boundary keeps
    # the pattern from firing inside longer abbreviations such as 'substr.'.
    value = re.sub(r'\s*\bstr\..*', '', value)

    # Remove square brackets
    value = value.replace('[', '').replace(']', '')

    # Replace multiple spaces with a single space
    value = re.sub(r'\s+', ' ', value).strip()

    # Strip only a standalone trailing 'sp.' token. Blindly slicing four characters
    # off any name ending in 'sp.' would corrupt names ending in e.g. 'subsp.'.
    value = re.sub(r'(?:^|\s)sp\.$', '', value)
    return value

# Normalise the organism names of both tables (the column is overwritten in place).
for frame in (dna_df, prot_df):
    frame['organism'] = frame['organism'].apply(clean_organism)

In [None]:
# keep only the rows whose organism value is shared with at least one other row
def filter_df(df):
    """Return the rows of ``df`` whose ``organism`` value occurs more than once.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with an ``organism`` column.

    Returns
    -------
    pandas.DataFrame
        The matching rows, with their original index labels and order.
    """
    # Number of occurrences of each row's organism within this frame
    occurrences = df['organism'].map(df['organism'].value_counts())
    return df[occurrences > 1]

# Drop singleton organisms from both tables (DNA first, then protein).
dna_df, prot_df = map(filter_df, (dna_df, prot_df))

### Random forest classifier - predict organism or predict gene

In [None]:
# Build the random-forest subsets: keep only organisms and genes with more than 5 rows.
# Removing rare genes can push an organism back under the threshold (and vice versa),
# so the counts are recomputed on the *filtered* DNA table on every pass, until a pass
# removes nothing. Previously the counts always came from the unfiltered `dna_df`, which
# made every pass identical and let under-represented classes through.
# The keep-lists are derived from the DNA table and applied to both tables.
dna_df_rfc, prot_df_rfc = dna_df, prot_df
size_before = len(dna_df_rfc)
while True:
    organism_counts = dna_df_rfc['organism'].value_counts()
    organisms_to_keep = organism_counts[organism_counts > 5].index.tolist()
    dna_df_rfc = dna_df_rfc[dna_df_rfc['organism'].isin(organisms_to_keep)]
    prot_df_rfc = prot_df_rfc[prot_df_rfc['organism'].isin(organisms_to_keep)]

    gene_counts = dna_df_rfc['gene'].value_counts()
    genes_to_keep = gene_counts[gene_counts > 5].index.tolist()
    dna_df_rfc = dna_df_rfc[dna_df_rfc['gene'].isin(genes_to_keep)]
    prot_df_rfc = prot_df_rfc[prot_df_rfc['gene'].isin(genes_to_keep)]

    # Fixed point: this pass removed no DNA rows, so every remaining class is large enough
    if len(dna_df_rfc) == size_before:
        break
    size_before = len(dna_df_rfc)

In [None]:
def print_measures(y_test, y_pred):
    """Print scikit-learn's classification report for the given predictions.

    Parameters
    ----------
    y_test : array-like
        True labels.
    y_pred : array-like
        Predicted labels, aligned with ``y_test``.
    """
    report = classification_report(y_test, y_pred)
    print(f"Classification Report:\n{report}")

In [None]:
def _fit_and_report_rfc(df, col_to_pred, test_size, n_estimators, random_state):
    """Train a random forest on one embedding table and print its hold-out report."""
    # One feature column per embedding dimension
    X = pd.DataFrame(df['embeddings'].tolist())
    y = df[col_to_pred]

    # Split the data into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    # Train the classifier and evaluate it on the held-out rows
    clf = RandomForestClassifier(n_estimators=n_estimators, random_state=random_state)
    clf.fit(X_train, y_train)
    print_measures(y_test, clf.predict(X_test))


def random_forest_classifier(col_to_pred, test_size=0.3, n_estimators=100, random_state=42):
    """Predict ``col_to_pred`` from the embeddings with a random forest, per modality.

    Runs the same train/evaluate pipeline on ``dna_df_rfc`` and then on
    ``prot_df_rfc`` (both read from the notebook's global namespace) and prints a
    labelled classification report for each.

    Parameters
    ----------
    col_to_pred : str
        Name of the label column to predict (the notebook uses ``'organism'``
        and ``'gene'``).
    test_size : float, default 0.3
        Fraction of the rows held out for evaluation.
    n_estimators : int, default 100
        Number of trees in the forest.
    random_state : int, default 42
        Seed used for both the split and the forest.
    """
    for source, df in (('DNA', dna_df_rfc), ('Protein', prot_df_rfc)):
        # Label each report so the two outputs can be told apart
        print(f"{source} embeddings - predicting '{col_to_pred}'")
        _fit_and_report_rfc(df, col_to_pred, test_size, n_estimators, random_state)

In [None]:
# Predict the organism from the embeddings
random_forest_classifier('organism')

In [None]:
# Predict the gene from the embeddings
random_forest_classifier('gene')

### Hierarchical clustering (by organism or gene)

In [None]:
# Genes
def genes_dendogram(df, color_threshold, model):
    """Plot and save a Ward-linkage dendrogram of the per-gene mean embeddings.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with ``gene`` and ``embeddings`` columns.
    color_threshold : float
        Forwarded to ``scipy.cluster.hierarchy.dendrogram``.
    model : str
        Model name shown in the title and embedded in the output file name.
    """
    # Average the embedding vectors of every gene into one representative vector
    gene_means = (
        df.groupby('gene')['embeddings']
        .apply(lambda vectors: np.mean(np.array(vectors).tolist(), axis=0))
        .reset_index()
    )

    # NOTE(review): linkage() is given the square cosine-similarity matrix as a plain
    # 2-D array, which SciPy treats as observation vectors (one row per gene) rather
    # than as a distance matrix - confirm this is the intended clustering input.
    similarity = cosine_similarity(np.array(gene_means['embeddings'].tolist()))
    Z = linkage(similarity, method='ward')

    # Plot the Dendrogram
    plt.figure(figsize=(20, 5))
    dendrogram(
        Z,
        labels=gene_means['gene'].values,
        leaf_rotation=0,
        leaf_font_size=12,
        color_threshold=color_threshold,
    )
    plt.title(f"Dendrogram of Genes Hierarchical Clustering\n{model}")
    plt.xlabel('Gene')
    plt.ylabel('Distance')
    plt.savefig(f"gene_dendogram_{model}", dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
# Gene dendrograms for both models; each model uses its own colour threshold.
for frame, threshold, model_name in (
    (dna_df, 0.7, 'DNABERT-2'),
    (prot_df, 1.5, 'ProteinBERT'),
):
    genes_dendogram(frame, color_threshold=threshold, model=model_name)

In [None]:
# Organisms
def organism_dendogram(df, color_threshold, model):
    """Plot and save a Ward-linkage dendrogram of the per-organism mean embeddings.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with ``organism`` and ``embeddings`` columns.
    color_threshold : float
        Forwarded to ``scipy.cluster.hierarchy.dendrogram``.
    model : str
        Model name shown in the title and embedded in the output file name.
    """
    # Average the embedding vectors of every organism into one representative vector
    organism_means = (
        df.groupby('organism')['embeddings']
        .apply(lambda vectors: np.mean(np.array(vectors).tolist(), axis=0))
        .reset_index()
    )

    # NOTE(review): linkage() is given the square cosine-similarity matrix as a plain
    # 2-D array, which SciPy treats as observation vectors (one row per organism)
    # rather than as a distance matrix - confirm this is the intended clustering input.
    similarity = cosine_similarity(np.array(organism_means['embeddings'].tolist()))
    Z = linkage(similarity, method='ward')

    # Plot the Dendrogram
    plt.figure(figsize=(20, 5))
    dendrogram(
        Z,
        labels=organism_means['organism'].values,
        leaf_rotation=90,
        leaf_font_size=3,
        color_threshold=color_threshold,
    )
    plt.title(f"Dendrogram of Hierarchical Clustering\n{model}")
    plt.xlabel('Organism')
    plt.ylabel('Distance')
    plt.savefig(f"organism_dendogram_{model}", dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
# Organism dendrograms for both models; each model uses its own colour threshold.
for frame, threshold, model_name in (
    (dna_df, 15, 'DNABERT-2'),
    (prot_df, 25, 'ProteinBERT'),
):
    organism_dendogram(frame, color_threshold=threshold, model=model_name)

### Neural Network

In [None]:
def create_model(input_dim, num_classes):
    """Build and compile a small fully connected softmax classifier.

    Architecture: Dense(128, relu) -> Dropout(0.3) -> Dense(64, relu) ->
    Dropout(0.2) -> Dense(num_classes, softmax).

    Parameters
    ----------
    input_dim : int
        Length of one input feature vector.
    num_classes : int
        Number of output classes.

    Returns
    -------
    A compiled Keras ``Sequential`` model using Adam (learning rate 0.001),
    sparse categorical cross-entropy loss and the accuracy metric.
    """
    layers = [
        Dense(128, activation='relu', input_shape=(input_dim,)),
        Dropout(0.3),
        Dense(64, activation='relu'),
        Dropout(0.2),
        Dense(num_classes, activation='softmax'),
    ]
    model = Sequential(layers)

    optimizer = Adam(learning_rate=0.001)
    # The "sparse" loss takes integer class ids rather than one-hot vectors
    model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model

In [None]:
def run_organism_and_genus_histories(df, label_encoder, create_model_fn, n_splits=5, epochs=100):
    """Cross-validate a classifier on three label sets derived from ``df``.

    The label sets are the full organism name, the first whitespace-separated
    token of the organism name (used as the genus), and the gene. For each one
    a stratified K-fold split is run and a fresh model is trained per fold with
    early stopping on the validation loss.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with ``embeddings``, ``organism`` and ``gene`` columns.
    label_encoder : object with ``fit_transform``
        Used to turn each label column into integer ids. It is refit for every
        label set, so on return it is fitted on ``df['gene']``.
    create_model_fn : callable
        Called as ``create_model_fn(input_dim, num_classes)``; must return a
        model exposing a Keras-style ``fit``.
    n_splits : int, default 5
        Number of cross-validation folds.
    epochs : int, default 100
        Maximum number of training epochs per fold.

    Returns
    -------
    dict
        Maps each label-set title to the list of per-fold objects returned by
        ``model.fit``.
    """
    features = np.array(df['embeddings'].tolist())
    n_features = features.shape[1]

    # Encode labels
    genus = df['organism'].str.split().str[0]
    targets = {
        "Organism Prediction": label_encoder.fit_transform(df['organism']),
        "Organism Prediction (Genus Only)": label_encoder.fit_transform(genus),
        "Gene": label_encoder.fit_transform(df['gene']),
    }

    all_histories = {}
    for task_name, labels in targets.items():
        # The output layer is sized from the full label set, not from a single fold
        n_classes = len(np.unique(labels))
        splitter = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

        fold_histories = []
        for fold, (train_idx, val_idx) in enumerate(splitter.split(features, labels), start=1):
            print(fold)
            model = create_model_fn(n_features, n_classes)
            early_stop = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
            fold_history = model.fit(
                features[train_idx], labels[train_idx],
                validation_data=(features[val_idx], labels[val_idx]),
                epochs=epochs,
                batch_size=32,
                callbacks=[early_stop],
                verbose=0
            )
            fold_histories.append(fold_history)

        all_histories[task_name] = fold_histories

    return all_histories

In [None]:
def plot_cross_validation_average(dup_hist_lst, dup_titles, suptitle, fig_name):
    """Plot mean ± SD cross-validation loss and accuracy curves for two groups of runs.

    The left panel shows loss and the right panel accuracy. Training curves are
    dashed, validation curves solid, and each curve has a shaded ±1 SD band
    computed across folds.

    Parameters
    ----------
    dup_hist_lst : sequence of two lists
        ``dup_hist_lst[0]`` is drawn in shades of blue and ``dup_hist_lst[1]`` in
        shades of red. Each inner list holds one entry per task, and each entry
        is the list of per-fold objects exposing a ``.history`` dict with the
        keys ``'loss'``, ``'val_loss'``, ``'accuracy'`` and ``'val_accuracy'``.
    dup_titles : sequence of two lists of str
        Legend labels, parallel to ``dup_hist_lst``.
    suptitle : str
        Figure title.
    fig_name : str
        Path the figure is saved to.
    """
    # Folds can have different lengths, so every curve is truncated to the shortest
    # run before averaging epoch by epoch.
    min_epochs = min(
        len(h.history['loss'])
        for group in dup_hist_lst
        for fold_histories in group
        for h in fold_histories
    )
    epochs = np.arange(1, min_epochs + 1)

    fig, axes = plt.subplots(1, 2, figsize=(18, 5))

    # One shade per task within each group
    dna_colors = cm.Blues(np.linspace(0.5, 0.9, len(dup_hist_lst[0])))
    protein_colors = cm.Reds(np.linspace(0.2, 0.4, len(dup_hist_lst[1])))
    all_colors = [dna_colors, protein_colors]

    panels = [('loss', 'val_loss', 'Loss'), ('accuracy', 'val_accuracy', 'Accuracy')]
    legend_elements = []

    for group, titles, colors in zip(dup_hist_lst, dup_titles, all_colors):
        for fold_histories, title, color in zip(group, titles, colors):
            for ax, (m_train, m_val, _) in zip(axes, panels):
                for metric, linestyle in ((m_train, '--'), (m_val, '-')):
                    curves = np.array([h.history[metric][:min_epochs] for h in fold_histories])
                    mean = curves.mean(axis=0)
                    std = curves.std(axis=0)
                    ax.plot(epochs, mean, linestyle=linestyle, color=color)
                    ax.fill_between(epochs, mean - std, mean + std, alpha=0.2, color=color)

            # One train/validation legend pair per task, in drawing order
            legend_elements.append(Line2D([0], [0], color=color, linestyle='--',
                                          label=f"Train, {title}"))
            legend_elements.append(Line2D([0], [0], color=color, linestyle='-',
                                          label=f"Validation, {title}"))

    for ax, (_, _, metric_title) in zip(axes, panels):
        ax.set_title(f'{metric_title} (Mean ± SD)')
        ax.set_xlabel("Epochs")
        ax.set_ylabel(metric_title)

    fig.legend(handles=legend_elements, loc='lower center', ncol=4, bbox_to_anchor=(0.5, -0.05))
    fig.suptitle(suptitle, fontsize=14)
    # Leave room under the axes for the figure-level legend
    fig.subplots_adjust(bottom=0.15)
    fig.savefig(fig_name, dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
# One label encoder per modality, then cross-validate on DNA followed by protein.
label_encoder, prot_label_encoder = LabelEncoder(), LabelEncoder()
histories = run_organism_and_genus_histories(
    df=dna_df, label_encoder=label_encoder, create_model_fn=create_model
)
prot_histories = run_organism_and_genus_histories(
    df=prot_df, label_encoder=prot_label_encoder, create_model_fn=create_model
)

In [None]:
# Per-task fold histories in insertion order: full organism name, genus only, gene
histories_val = [*histories.values()]
prot_histories_val = [*prot_histories.values()]

plot_specs = [
    # by organism
    (
        [histories_val[:2], prot_histories_val[:2]],
        [['DNA, full name', 'DNA, genus only'], ['Protein, full name', 'Protein, genus only']],
        "Cross Validation Performance: Full Organism Name vs. Genus Only",
        'cross_validation_organism.png',
    ),
    # by gene
    (
        [[histories_val[2]], [prot_histories_val[2]]],
        [['DNA, gene prediction'], ['Protein, gene prediction']],
        "Cross validation Performance: Gene Prediction",
        'cross_validation_gene.png',
    ),
]

for dup_lst, titles, suptitle, fig_name in plot_specs:
    plot_cross_validation_average(dup_lst, titles, suptitle, fig_name)

### K-means clustering

In [None]:
# Elbow method: fit KMeans for a range of k on both embedding sets and plot the inertia

X_protein = np.array(prot_df['embeddings'].tolist())
X_dna = np.array(dna_df['embeddings'].tolist())

# Candidate cluster counts: 1 up to (but excluding) twice the number of distinct genes
K = range(1, 2 * len(dna_df["gene"].unique()))
inertia_protein, inertia_dna = [], []

for k in tqdm(K):
    kmeans_protein = KMeans(n_clusters=k, random_state=42).fit(X_protein)
    inertia_protein.append(kmeans_protein.inertia_)
    kmeans_dna = KMeans(n_clusters=k, random_state=42).fit(X_dna)
    inertia_dna.append(kmeans_dna.inertia_)

# One panel per model: protein on the left, DNA on the right
plt.figure(figsize=(12, 6))
elbow_panels = [
    (inertia_protein, 'Elbow Method for ProteinBERT'),
    (inertia_dna, 'Elbow Method for DNABERT2'),
]
for panel_no, (inertia, panel_title) in enumerate(elbow_panels, start=1):
    plt.subplot(1, 2, panel_no)
    plt.plot(K, inertia, 'bx-')
    plt.xlabel('k')
    plt.ylabel('Inertia')
    plt.title(panel_title)
plt.savefig("Elbow_method_bacteria.png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# KMeans cluster assignments plus a 2-D t-SNE projection for each embedding set
# NOTE(review): presumably 10 was chosen from the elbow plots - confirm.
num_clusters = 10

# Cluster labels are computed in the original embedding space
kmeans_protein = KMeans(n_clusters=num_clusters, random_state=42)
protein_labels = kmeans_protein.fit_predict(X_protein)
kmeans_dna = KMeans(n_clusters=num_clusters, random_state=42)
dna_labels = kmeans_dna.fit_predict(X_dna)

# The same TSNE object is refit for each embedding set (protein first, then DNA)
tsne = TSNE(n_components=2, random_state=42)
X_protein_tsne = tsne.fit_transform(X_protein)
X_dna_tsne = tsne.fit_transform(X_dna)

In [None]:
# Side-by-side t-SNE scatter plots, coloured by KMeans cluster label
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
plt.suptitle("t-SNE of KMeans Clustering", fontsize=15)
ax1, ax2 = ax

tsne_panels = [
    # Protein
    (ax1, X_protein_tsne, protein_labels, "ProteinBERT", "A."),
    # DNA
    (ax2, X_dna_tsne, dna_labels, "DNABERT-2", "B."),
]
for axis, coords, cluster_ids, model_name, panel_tag in tsne_panels:
    axis.scatter(coords[:, 0], coords[:, 1], c=cluster_ids, cmap='viridis', s=20)
    axis.set_title(model_name)
    axis.set_xlabel("t-SNE 1")
    axis.set_ylabel("t-SNE 2")
    # Panel letter placed just outside the top-left corner, in axes coordinates
    axis.text(-0.1, 1.01, panel_tag, fontsize=15, ha='center', va='center', transform=axis.transAxes)

plt.tight_layout()
plt.savefig("tSNE.png", dpi=300, bbox_inches='tight')
plt.show()