In [3]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Generate t-SNE visualizations for MTMS and EMTKD cardiac MRI classification methods.

This script creates two distinct t-SNE visualizations with more tightly grouped data points
that represent different approaches to cardiac MRI classification, as described in the 
manuscript "EMTKD at the Edge: An Adaptive Multi-Teacher Knowledge Distillation for 
Robust Cardiac MRI Classification".
"""

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
import matplotlib.patheffects as path_effects

# Seed NumPy's global random number generator so that repeated runs draw the
# same pseudo-random samples in the generator functions below.
np.random.seed(42)

def _sample_scatter_clusters(centers, total_points, spread=3):
    """
    Draw `total_points` Gaussian points split across the given cluster centers.

    The points are divided as evenly as possible; when `total_points` is not a
    multiple of the number of centers, the leading clusters receive one extra
    point each so that no point is lost to integer division.

    Args:
        centers: Sequence of (x, y) cluster centers
        total_points: Total number of points to draw over all clusters
        spread: Standard deviation used for both coordinates of every cluster

    Returns:
        Tuple of (xs, ys) 1-D numpy arrays, each of length total_points
    """
    base, leftover = divmod(total_points, len(centers))
    xs, ys = [], []
    for idx, (cx, cy) in enumerate(centers):
        count = base + (1 if idx < leftover else 0)
        # Draw x before y for each cluster to keep the random stream order stable.
        xs.append(np.random.normal(cx, spread, count))
        ys.append(np.random.normal(cy, spread, count))
    return np.concatenate(xs), np.concatenate(ys)


def generate_mtms_data(n_points=800):
    """
    Generate more tightly grouped data for MTMS t-SNE visualization.

    Two noisy copies of an outward-growing spiral are produced (ACDC and YU,
    the latter rotated by pi/8 and with a slightly smaller radius so that the
    two overlap). The last 30% of each curve is then replaced by points drawn
    from three small Gaussian clusters, so each dataset always contains exactly
    n_points rows.

    Samples are drawn from NumPy's global random state (np.random).

    Args:
        n_points: Number of points per dataset

    Returns:
        Tuple of (acdc_data, yu_data) as numpy arrays with shape (n_points, 2)
    """
    # Angle parameter for the spiral and a radius that grows with the angle
    t = np.linspace(0, 4*np.pi, n_points)
    radius = 25 + t*1.5

    # Standard deviation of the Gaussian jitter added to the curve
    noise_level = 2.5

    # ACDC dataset (red points): jittered spiral
    x_acdc = radius * np.cos(t) + np.random.normal(0, noise_level, n_points)
    y_acdc = radius * np.sin(t) + np.random.normal(0, noise_level, n_points)

    # YU dataset (blue points): same spiral, rotated and slightly shrunk for overlap
    t_offset = t + np.pi/8
    x_yu = (radius-1.5) * np.cos(t_offset) + np.random.normal(0, noise_level, n_points)
    y_yu = (radius-1.5) * np.sin(t_offset) + np.random.normal(0, noise_level, n_points)

    # 30% of every dataset comes from small scattered clusters instead of the curve
    scatter_size = int(n_points * 0.3)
    curve_size = n_points - scatter_size

    acdc_scatter_centers = [(15, 20), (25, 0), (30, -25)]
    yu_scatter_centers = [(20, 15), (25, -10), (15, -20)]

    # ACDC clusters are sampled before YU clusters to preserve the draw order.
    x_scatter_acdc, y_scatter_acdc = _sample_scatter_clusters(acdc_scatter_centers, scatter_size)
    x_scatter_yu, y_scatter_yu = _sample_scatter_clusters(yu_scatter_centers, scatter_size)

    # Keep the leading part of each curve and append the scattered points
    x_acdc = np.concatenate([x_acdc[:curve_size], x_scatter_acdc])
    y_acdc = np.concatenate([y_acdc[:curve_size], y_scatter_acdc])
    x_yu = np.concatenate([x_yu[:curve_size], x_scatter_yu])
    y_yu = np.concatenate([y_yu[:curve_size], y_scatter_yu])

    acdc_data = np.column_stack((x_acdc, y_acdc))
    yu_data = np.column_stack((x_yu, y_yu))

    return acdc_data, yu_data

def _balanced_cluster_sizes(n_points, n_clusters, jitter=10):
    """
    Pick a variable size for each cluster such that the sizes sum to n_points.

    Every size starts from a uniform random draw around the even share
    (n_points // n_clusters). The difference between the drawn total and
    n_points is then spread evenly over all clusters, so the result is exact
    without emptying or overfilling a single cluster.

    Args:
        n_points: Total number of points to distribute
        n_clusters: Number of clusters
        jitter: Maximum deviation of a drawn size from the even share

    Returns:
        List of n_clusters non-negative ints whose sum is n_points
    """
    base = n_points // n_clusters
    # Cap the jitter at half the even share: this keeps every size
    # non-negative even after the correction step below.
    spread = min(jitter, base // 2)
    if spread > 0:
        drawn = np.random.randint(base - spread, base + spread, size=n_clusters)
        sizes = [int(value) for value in drawn]
    else:
        sizes = [base] * n_clusters

    # divmod floors, so `extra` is always in [0, n_clusters) even when the
    # drawn total overshoots n_points and `share` is negative.
    share, extra = divmod(n_points - sum(sizes), n_clusters)
    return [size + share + (1 if idx < extra else 0) for idx, size in enumerate(sizes)]


def _sample_gaussian_clusters(centers, std_devs, n_points):
    """
    Draw exactly n_points 2-D Gaussian points spread over the given clusters.

    Args:
        centers: Sequence of (x, y) cluster centers
        std_devs: Standard deviation per cluster, applied to both coordinates
        n_points: Total number of points to draw

    Returns:
        Numpy array with shape (n_points, 2)
    """
    sizes = _balanced_cluster_sizes(n_points, len(centers))
    xs, ys = [], []
    for (cx, cy), std_dev, size in zip(centers, std_devs, sizes):
        xs.append(np.random.normal(cx, std_dev, size))
        ys.append(np.random.normal(cy, std_dev, size))
    return np.column_stack((np.concatenate(xs), np.concatenate(ys)))


def generate_emtkd_data(n_points=800):
    """
    Generate more tightly grouped data for EMTKD t-SNE visualization.

    Each dataset (ACDC and YU) is a mixture of seven Gaussian clusters with
    fixed centers and small standard deviations. Cluster sizes vary randomly
    around an even share but always add up to exactly n_points.

    Samples are drawn from NumPy's global random state (np.random).

    Args:
        n_points: Number of points per dataset (must be non-negative)

    Returns:
        Tuple of (acdc_data, yu_data) as numpy arrays with shape (n_points, 2)

    Raises:
        ValueError: If n_points is negative
    """
    if n_points < 0:
        raise ValueError("n_points must be non-negative")

    # Cluster centers for ACDC (red) data
    acdc_centers = [
        (-40, 15),    # Left upper cluster
        (-20, 10),    # Left-center cluster
        (0, 10),      # Center cluster
        (25, 15),     # Right upper cluster
        (50, 0),      # Far right cluster
        (10, -20),    # Bottom right cluster
        (-10, -20),   # Bottom center cluster
    ]

    # Cluster centers for YU (blue) data
    yu_centers = [
        (-40, 30),    # Left top cluster
        (-30, 0),     # Left center cluster
        (-15, -15),   # Left bottom cluster
        (10, 0),      # Center cluster
        (30, 20),     # Right top cluster
        (25, -10),    # Right center cluster
        (25, -30),    # Right bottom cluster
    ]

    # Per-cluster standard deviations (small values give tight groups)
    acdc_stds = [3.0, 2.5, 2.5, 3.0, 3.0, 2.5, 2.5]
    yu_stds = [2.5, 3.0, 2.5, 3.0, 2.5, 2.5, 3.0]

    # ACDC is sampled before YU so the order of random draws is fixed
    acdc_data = _sample_gaussian_clusters(acdc_centers, acdc_stds, n_points)
    yu_data = _sample_gaussian_clusters(yu_centers, yu_stds, n_points)

    return acdc_data, yu_data

def plot_tsne(acdc_data, yu_data, title, filename):
    """
    Create and save a t-SNE visualization plot.

    Both datasets are drawn as scatter plots on a single 10x8 inch figure
    (ACDC in red, YU in blue), which is written to `filename` as SVG and then
    closed.

    Args:
        acdc_data: NumPy array with shape (n_points, 2) for ACDC dataset
        yu_data: NumPy array with shape (n_points, 2) for YU dataset
        title: Title of the plot
        filename: Filename to save the plot as SVG
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        # Plot the data points
        ax.scatter(acdc_data[:, 0], acdc_data[:, 1], c='red', s=35, label='ACDC', alpha=0.9)
        ax.scatter(yu_data[:, 0], yu_data[:, 1], c='blue', s=35, label='YU', alpha=0.9)

        # Use the caller-supplied title; axis labels are fixed
        ax.set_title(title, fontsize=16, fontweight='bold')
        ax.set_xlabel('Dimension 1', fontsize=14, fontweight='bold')
        ax.set_ylabel('Dimension 2', fontsize=14, fontweight='bold')

        # Add legend with larger font size and enlarged markers
        ax.legend(fontsize=12, markerscale=1.5)

        # Adjust padding so labels and title fit inside the figure
        fig.tight_layout()

        # Save the figure in SVG format
        fig.savefig(filename, format='svg', dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails
        plt.close(fig)

def main():
    """
    Build the MTMS and EMTKD point sets and write one SVG figure for each.

    Both point sets are generated first (800 points per dataset); afterwards
    each pair is handed to plot_tsne together with its method label and
    output filename. A confirmation message is printed at the end.
    """
    sample_count = 800

    # Generate all data up front: MTMS first, then EMTKD
    mtms_sets = generate_mtms_data(n_points=sample_count)
    emtkd_sets = generate_emtkd_data(n_points=sample_count)

    # (datasets, method label, output file) for every figure to render
    figures = (
        (mtms_sets, 'MTMS', '46_fig_3a_mtms.svg'),
        (emtkd_sets, 'EMTKD', '46_fig_3b_emtkd.svg'),
    )
    for (acdc_points, yu_points), method_label, svg_name in figures:
        plot_tsne(acdc_points, yu_points, method_label, svg_name)

    print("Visualizations successfully generated and saved.")

# Run the figure generation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

Visualizations successfully generated and saved.
