# In[ ]:
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import numpy as np
import pandas as pd
import torch
import os
import glob

def plot_tsne(embeddings, labels, file_path, output_dir='analysis/icassp'):
    """Project embeddings to 2-D with t-SNE and save a labeled scatter plot as PDF.

    Points are colored by a binary label: 0 -> green ("Alive patients"),
    1 -> red ("Dead patients").

    Args:
        embeddings: Array-like of samples accepted by
            ``sklearn.manifold.TSNE.fit_transform`` (one row per sample).
        labels: Array-like with one label per sample. Each label must hold a
            single value convertible to int 0 or 1; per-sample arrays of size 1
            (e.g. shape ``(n_samples, 1)``) are flattened to scalars.
        file_path: Path of the source file; its base name without extension
            names the output PDF (``<stem>_tsne.pdf``).
        output_dir: Directory the PDF is written to; created if missing.

    Returns:
        None. Prints the output path on success, or a skip message when
        fewer than 2 samples are available.

    Raises:
        ValueError: If a label holds more than one value, or is not 0/1.
    """
    # Truncate both inputs to a common length so every point has a label.
    min_len = min(len(embeddings), len(labels))
    if min_len < 2:
        # t-SNE cannot embed fewer than 2 points.
        print(f"Skipping {file_path}: t-SNE needs at least 2 samples, got {min_len}.")
        return
    embeddings = embeddings[:min_len]

    # Normalize labels to a flat int vector. int() on a size-1 ndarray is
    # deprecated in NumPy >= 1.25 and fails outright for larger arrays.
    label_arr = np.asarray(labels[:min_len]).reshape(min_len, -1)
    if label_arr.shape[1] != 1:
        raise ValueError(
            f"Expected one label per sample, got {label_arr.shape[1]} values per sample."
        )
    label_ids = label_arr[:, 0].astype(int)
    if not np.isin(label_ids, (0, 1)).all():
        raise ValueError("Labels must be 0 (alive) or 1 (dead).")

    # sklearn requires perplexity < n_samples; keep its default (30) when possible.
    tsne = TSNE(n_components=2, perplexity=min(30.0, min_len - 1), random_state=42)
    embeddings_2d = tsne.fit_transform(embeddings)

    colors = ['#2ecc71', '#e74c3c']  # index 0 = alive (green), 1 = dead (red)
    fig = plt.figure(figsize=(8, 7))
    try:
        plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1],
                    c=[colors[i] for i in label_ids], alpha=0.7)

        # Proxy artists so the legend shows one entry per class.
        legend_elements = [plt.Line2D([0], [0], marker='o', color='w', label='Alive patients',
                                      markerfacecolor='#2ecc71', markersize=10),
                           plt.Line2D([0], [0], marker='o', color='w', label='Dead patients',
                                      markerfacecolor='#e74c3c', markersize=10)]
        plt.legend(handles=legend_elements, loc="upper right", fontsize=14)

        plt.xlabel('t-SNE feature 1', fontsize=12)
        plt.ylabel('t-SNE feature 2', fontsize=12)
        plt.tick_params(axis='both', which='major', labelsize=12)

        # Derive the name from the stem so a non-.pkl input can never map onto
        # its own file name (which would overwrite the input with the PDF).
        os.makedirs(output_dir, exist_ok=True)
        stem = os.path.splitext(os.path.basename(file_path))[0]
        output_file = os.path.join(output_dir, f"{stem}_tsne.pdf")
        plt.tight_layout()
        plt.savefig(output_file, format='pdf', dpi=500)
    finally:
        # Always release the figure, even if saving fails, to avoid leaking memory.
        plt.close(fig)

    print(f"t-SNE plot saved to {output_file}")

def process_pickle_file(file_path):
    """Load a pickled dict with 'embeddings' and 'labels' and plot its t-SNE.

    Prints a summary of the loaded object. If it is a dict containing both
    ``'embeddings'`` and ``'labels'``, converts any torch tensors to NumPy and
    passes them to ``plot_tsne``. Any other structure is reported and skipped.

    Security: ``pd.read_pickle`` unpickles arbitrary objects and can execute
    code while loading; only run this on trusted files.

    Args:
        file_path: Path to the pickle file.

    Returns:
        None.
    """
    data = pd.read_pickle(file_path)

    # Print the structure of the loaded data.
    print(f"\nProcessing file: {file_path}")
    print("Data type:", type(data))
    if isinstance(data, dict):
        print("Keys:", data.keys())
        for key, value in data.items():
            if hasattr(value, 'shape'):
                size = value.shape
            else:
                try:
                    size = len(value)
                except TypeError:
                    # Scalars / other unsized metadata must not crash the summary.
                    size = 'n/a'
            print(f"{key}: type {type(value)}, shape {size}")

    if not (isinstance(data, dict) and 'embeddings' in data and 'labels' in data):
        print("Unexpected data structure. Please check the contents of the pickle file.")
        return

    def to_numpy(x):
        # detach()/cpu() make this work for grad-tracking and GPU tensors,
        # where a bare .numpy() raises.
        return x.detach().cpu().numpy() if torch.is_tensor(x) else x

    embeddings = to_numpy(data['embeddings'])
    labels = data['labels']
    if torch.is_tensor(labels):
        # NOTE(review): assumes label tensors are 3-D and keeps index 0 of
        # axis 1; non-tensor labels are passed through unsliced - confirm
        # against whatever produces these files.
        labels = to_numpy(labels)[:, 0, :]

    plot_tsne(embeddings, labels, file_path)

# Process every *.pkl file in input_dir. Files are sorted so the processing
# and log order is deterministic (glob returns them in filesystem order).
input_dir = 'analysis/icassp'  # Directory scanned for *.pkl inputs.
pickle_files = sorted(glob.glob(os.path.join(input_dir, '*.pkl')))
if not pickle_files:
    print(f"No .pkl files found in {input_dir}")
for pickle_file in pickle_files:
    process_pickle_file(pickle_file)