In [None]:
from bertopic import BERTopic
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import os
from datetime import datetime
from umap import UMAP
import matplotlib.pyplot as plt
import seaborn as sns
from deep_translator import GoogleTranslator
from sklearn.preprocessing import StandardScaler
import warnings
warnings.filterwarnings('ignore')

In [None]:
import sys

# Make the repository root importable. Guarded so that re-running this cell
# does not keep appending the same entry to sys.path.
_REPO_ROOT = '../..'
if _REPO_ROOT not in sys.path:
    sys.path.append(_REPO_ROOT)

def auto_save_figure(fig, figure_name, save_formats=('png',),
                     width=1000, height=700, scale=2, output_dir="../../outputs/plots"):
    """
    Save a Plotly or Matplotlib figure to ``output_dir`` in one or more formats.

    Files are named ``<YYYYmmdd_HHMMSS>_<clean_name>.<format>``, where
    ``clean_name`` is ``figure_name`` with every character other than letters,
    digits, spaces, ``-`` and ``_`` removed, spaces turned into ``_``, and
    lower-cased.

    Parameters
    ----------
    fig : Plotly figure (has ``write_image`` and ``write_html``) or
        Matplotlib figure (has ``savefig``).
    figure_name : str
        Human-readable name used in log messages and to build the filename.
    save_formats : iterable of str, or a single str
        Any of 'html', 'png', 'pdf', 'svg', 'jpeg', 'jpg'. A bare string is
        treated as one format. Unsupported formats are reported and skipped.
    width, height : int
        Pixel size passed to Plotly's ``write_image`` (unused for Matplotlib).
    scale : number
        Plotly image scale factor; for Matplotlib, ``dpi = 100 * scale``.
    output_dir : str
        Destination directory; created if it does not exist.

    Returns
    -------
    dict
        Maps each format that was actually written to its absolute path.
        Empty when ``fig`` is None or not a supported figure type.
    """
    image_formats = ('png', 'pdf', 'svg', 'jpeg', 'jpg')

    if fig is None:
        print(f"❌ No figure provided for '{figure_name}'")
        return {}

    # Detect figure type by duck typing
    is_plotly = hasattr(fig, 'write_image') and hasattr(fig, 'write_html')
    is_matplotlib = hasattr(fig, 'savefig')

    if not is_plotly and not is_matplotlib:
        print(f"❌ Unsupported figure type for '{figure_name}': {type(fig)}")
        return {}

    # A bare string would otherwise be iterated character by character.
    if isinstance(save_formats, str):
        save_formats = [save_formats]

    os.makedirs(output_dir, exist_ok=True)

    # Timestamp keeps repeated saves of the same figure from overwriting each other.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Loop-invariant: sanitise the figure name once for all formats.
    clean_name = "".join(c for c in figure_name if c.isalnum() or c in (' ', '-', '_')).rstrip()
    clean_name = clean_name.replace(' ', '_').lower()

    saved_files = {}

    for format_type in save_formats:
        if format_type != 'html' and format_type not in image_formats:
            # Previously such formats were recorded as "saved" although nothing was written.
            print(f"⚠️  Unsupported format '{format_type}' for '{figure_name}', skipping")
            continue

        filename = f"{timestamp}_{clean_name}.{format_type}"
        filepath = os.path.join(output_dir, filename)

        try:
            if is_plotly:
                if format_type == 'html':
                    fig.write_html(filepath)
                    print(f"📊 Saved {figure_name} as HTML: {filename}")
                else:
                    fig.write_image(
                        filepath,
                        width=width,
                        height=height,
                        scale=scale,
                        format=format_type
                    )
                    print(f"🖼️  Saved {figure_name} as {format_type.upper()}: {filename}")
            else:
                if format_type == 'html':
                    # Matplotlib -> HTML needs the optional mpld3 package.
                    try:
                        import mpld3
                    except ImportError:
                        print(f"⚠️  Skipping HTML for matplotlib figure '{figure_name}' (mpld3 not available)")
                        continue
                    html_str = mpld3.fig_to_html(fig)
                    with open(filepath, 'w', encoding='utf-8') as f:
                        f.write(html_str)
                    print(f"📊 Saved {figure_name} as HTML: {filename}")
                else:
                    fig.savefig(
                        filepath,
                        format=format_type,
                        dpi=100 * scale,  # map the Plotly-style scale onto DPI
                        bbox_inches='tight',
                        facecolor='white',
                        edgecolor='none'
                    )
                    print(f"🖼️  Saved {figure_name} as {format_type.upper()}: {filename}")

            saved_files[format_type] = os.path.abspath(filepath)

        except Exception as e:
            # Best-effort: report and continue with the remaining formats.
            print(f"❌ Error saving {figure_name} as {format_type}: {e}")

    if saved_files:
        print(f"✅ Total saved: {len(saved_files)} file(s) for '{figure_name}'")
        print(f"📁 Location: {os.path.abspath(output_dir)}")
        print("-" * 60)

    return saved_files

_DEFAULT_EMBEDDING_MODEL = "PORTULAN/serafim-900m-portuguese-pt-sentence-encoder"

# Sentence encoders already instantiated, keyed by model name, so that loading
# several BERTopic models does not re-create the same large encoder each time.
_embedding_model_cache = {}


def load_bert_model(path, embedding_model_name=_DEFAULT_EMBEDDING_MODEL):
    """
    Load a saved BERTopic model from ``path`` with a sentence-transformer encoder.

    The encoder named ``embedding_model_name`` (default: the Portuguese Serafim
    900M sentence encoder) is created once and reused for later calls with the
    same name.

    Returns the loaded ``BERTopic`` instance.
    """
    encoder = _embedding_model_cache.get(embedding_model_name)
    if encoder is None:
        encoder = SentenceTransformer(embedding_model_name)
        _embedding_model_cache[embedding_model_name] = encoder
    return BERTopic.load(path, embedding_model=encoder)

In [None]:
# Plot style and human-readable labels for each analysis group.
# All groups share the same marker shape and base size and are told apart by
# colour only. Note: 'size' is not read by the combined scatter plot, which
# sizes each marker by its topic's document count instead.
_SHARED_MARKER_STYLE = {"symbol": "circle", "size": 12}

groups_config = {
    "Female_ADHD": {
        "color": "#e6194b",  # Red
        **_SHARED_MARKER_STYLE,
        "display_name": "Women with ADHD",
        "description": "Portuguese women diagnosed with ADHD"
    },
    "Female_noADHD": {
        "color": "#3cb44b",  # Green
        **_SHARED_MARKER_STYLE,
        "display_name": "Women without ADHD",
        "description": "Portuguese women without ADHD diagnosis"
    },
    "ADHD": {
        "color": "#4363d8",  # Blue
        **_SHARED_MARKER_STYLE,
        "display_name": "Participants with ADHD",
        "description": "All participants (men and women) with ADHD diagnosis"
    },
    "noADHD": {
        "color": "#f58231",  # Orange
        **_SHARED_MARKER_STYLE,
        "display_name": "Participants without ADHD",
        "description": "All participants (men and women) without ADHD diagnosis"
    }
}

print("📋 Group configurations:")
for group, config in groups_config.items():
    print(f"  {config['display_name']}: {config['color']} {config['symbol']}")

# Function to find and load models
def find_model_folder(group_name, base_dir="../../data/adhd-beliefs-pt/bertopic_final"):
    """
    Return the path of the saved BERTopic model folder for ``group_name``.

    Looks for sub-directories of ``<base_dir>/<group_name>/``. ``os.listdir``
    returns entries in arbitrary order, so the candidates are sorted and the
    alphabetically first one is used; if there are several, a warning lists
    the alternatives.
    NOTE(review): if folder names are timestamps, "first" means "oldest" —
    confirm that is the intended model.

    Returns None when the group directory is missing or has no sub-directories.
    """
    group_dir = os.path.join(base_dir, group_name)
    if not os.path.isdir(group_dir):
        return None

    folders = sorted(
        name for name in os.listdir(group_dir)
        if os.path.isdir(os.path.join(group_dir, name))
    )
    if not folders:
        return None

    if len(folders) > 1:
        print(f"    ⚠️  {len(folders)} model folders found for {group_name}; "
              f"using '{folders[0]}' (others: {folders[1:]})")
    return os.path.join(group_dir, folders[0])

# Load every group's saved BERTopic model; groups that fail are reported and skipped.
models = {}
model_paths = {}

print("\n🔄 Loading BERTopic models...")
for group_name in groups_config:
    model_path = find_model_folder(group_name)
    if not model_path:
        print(f"    ⚠️  Model path not found for {group_name}")
        continue

    print(f"  Loading {group_name} from {model_path}")
    try:
        loaded_model = load_bert_model(model_path)
    except Exception as err:
        # Best-effort: one broken model should not stop the others from loading.
        print(f"    ❌ Failed to load {group_name}: {err}")
        continue

    models[group_name] = loaded_model
    model_paths[group_name] = model_path
    print(f"    ✅ {group_name} loaded successfully")

print(f"\n📊 Successfully loaded {len(models)} models: {list(models)}")

In [None]:
def _fallback_topic_labels(topic_model, real_topics):
    """
    Build one label per row of ``real_topics`` when the figure carries no text.

    Preference per topic: the ``CustomName`` column, then ``Name``, then the top
    three words from ``topic_model.get_topics()``, then ``"Topic <id>"``.
    """
    topics_dict = None  # fetched lazily, at most once
    labels = []
    for _, row in real_topics.iterrows():
        topic_id = row['Topic']
        if 'CustomName' in row and pd.notna(row['CustomName']):
            label = row['CustomName']
        elif 'Name' in row and pd.notna(row['Name']):
            label = row['Name']
        else:
            if topics_dict is None:
                topics_dict = topic_model.get_topics()
            if topic_id in topics_dict:
                label = ', '.join(word for word, _score in topics_dict[topic_id][:3])
            else:
                label = f"Topic {topic_id}"
        labels.append(label)
    return labels


def extract_topic_representations(topic_model, group_name):
    """
    Extract 2-D topic coordinates from BERTopic's ``visualize_topics`` figure.

    The coordinates are read from the first trace of the figure returned by
    ``topic_model.visualize_topics(custom_labels=True)``.
    NOTE(review): whether BERTopic reuses cached coordinates or recomputes a
    projection of the topic embeddings on each call depends on the installed
    version — confirm before relying on reproducibility.

    Parameters
    ----------
    topic_model : BERTopic
        A fitted/loaded topic model.
    group_name : str
        Group identifier, used in messages and stored in the result.

    Returns
    -------
    dict or None
        Keys 'group', 'topic_ids', 'topic_labels', 'topic_counts' (lists, one
        entry per non-outlier topic) and 'topic_positions_2d' (array with one
        (x, y) row per topic, in the same order). None when there are no
        non-outlier topics, the figure has no data, the number of plotted
        points does not match the number of topics, or any error occurs.
    """
    topic_info = topic_model.get_topic_info()
    # NOTE(review): assumes visualize_topics() plots non-outlier topics in
    # ascending topic-ID order — confirm for the installed BERTopic version.
    # Sorting keeps ids/counts/labels aligned row-for-row with the points.
    real_topics = (
        topic_info[topic_info['Topic'] != -1]
        .sort_values('Topic')
        .reset_index(drop=True)
    )
    n_topics = len(real_topics)

    if n_topics == 0:
        print(f"⚠️  No real topics found for {group_name}")
        return None

    try:
        fig_temp = topic_model.visualize_topics(custom_labels=True)
        if fig_temp is None or not fig_temp.data:
            print(f"  ❌ visualize_topics returned no data for {group_name}")
            return None

        # The first trace holds one point per topic.
        trace_data = fig_temp.data[0]
        x_coords = list(trace_data.x)
        y_coords = list(trace_data.y)

        # Refuse to silently mis-pair coordinates with topics.
        if len(x_coords) != n_topics or len(y_coords) != n_topics:
            print(f"  ❌ {group_name}: figure has {len(x_coords)} points "
                  f"but there are {n_topics} topics")
            return None

        trace_text = getattr(trace_data, 'text', None)
        if (trace_text is not None and not isinstance(trace_text, str)
                and len(trace_text) == n_topics):
            topic_labels = list(trace_text)
        else:
            topic_labels = _fallback_topic_labels(topic_model, real_topics)

        metadata = {
            'group': group_name,
            'topic_ids': real_topics['Topic'].tolist(),
            'topic_labels': topic_labels,
            'topic_counts': (
                real_topics['Count'].tolist()
                if 'Count' in real_topics.columns
                else [0] * n_topics
            ),
            'topic_positions_2d': np.column_stack((x_coords, y_coords)),
        }

        print(f"  📊 {group_name}: {n_topics} topics, exact UMAP coordinates extracted")
        return metadata

    except Exception as e:
        print(f"  ❌ Could not extract pre-computed coordinates for {group_name}: {e}")
        return None

In [None]:
def _make_size_scaler(counts, min_size=8, max_size=30):
    """
    Return a function mapping a document count to a marker diameter.

    Counts are mapped linearly from ``min(counts)`` -> ``min_size`` to
    ``max(counts)`` -> ``max_size``. If all counts are equal, every marker gets
    the midpoint size; if ``counts`` is empty, every marker gets size 12.
    """
    if not counts:
        print("⚠️  No document counts found, using default sizes")
        return lambda count: 12

    min_count = min(counts)
    max_count = max(counts)
    print(f"📊 Document count range: {min_count} - {max_count}")

    if max_count == min_count:  # avoid division by zero
        midpoint = (min_size + max_size) / 2
        return lambda count: midpoint

    count_span = max_count - min_count

    def scale_size(count):
        normalized = (count - min_count) / count_span
        return min_size + normalized * (max_size - min_size)

    return scale_size


def create_combined_visualization_from_umap(all_representations, groups_config,
                                            min_size=8, max_size=30):
    """
    Overlay the 2-D topic positions of several groups in one Plotly scatter.

    Each group becomes one trace whose legend entry is its ``display_name`` and
    whose colour/symbol come from ``groups_config``. Marker diameters are scaled
    linearly by document count on one scale shared by all plotted groups, so
    sizes are comparable across groups.

    Parameters
    ----------
    all_representations : dict
        group name -> dict with 'topic_positions_2d' (array indexed as
        ``[:, 0]`` / ``[:, 1]``), 'topic_ids', 'topic_labels', 'topic_counts',
        as produced by ``extract_topic_representations``. Groups without
        positions, or with zero topics, are skipped.
    groups_config : dict
        group name -> dict with 'display_name', 'color' and 'symbol'.
    min_size, max_size : number
        Marker diameter for the smallest / largest document count.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    print("🎨 Creating combined visualization from existing UMAP representations...")

    fig = go.Figure()

    # Only groups that will actually be drawn feed the shared size scale; empty
    # groups are dropped (min()/max() over their sizes would raise).
    plottable = {
        group_name: repr_data
        for group_name, repr_data in all_representations.items()
        if 'topic_positions_2d' in repr_data and len(repr_data['topic_positions_2d']) > 0
    }
    all_counts = [count for repr_data in plottable.values() for count in repr_data['topic_counts']]
    scale_size = _make_size_scaler(all_counts, min_size, max_size)

    # Add one trace per group
    for group_name, repr_data in plottable.items():
        positions_2d = repr_data['topic_positions_2d']
        group_x = positions_2d[:, 0]
        group_y = positions_2d[:, 1]
        topic_ids = repr_data['topic_ids']
        topic_labels = repr_data['topic_labels']
        topic_counts = repr_data['topic_counts']
        config = groups_config[group_name]

        marker_sizes = [scale_size(count) for count in topic_counts]

        hover_text = [
            f"<b>{config['display_name']}</b><br>"
            f"Group: {group_name}<br>"
            f"Topic {topic_ids[i]}<br>"
            f"Label: {topic_labels[i]}<br>"
            f"<b>Documents: {topic_counts[i]}</b><br>"
            f"Position: ({group_x[i]:.2f}, {group_y[i]:.2f})"
            for i in range(len(group_x))
        ]

        fig.add_trace(go.Scatter(
            x=group_x,
            y=group_y,
            mode='markers',
            marker=dict(
                color=config['color'],
                symbol=config['symbol'],
                size=marker_sizes,  # per-topic, scaled by document count
                line=dict(width=1, color='rgba(0,0,0,0.3)'),
                opacity=0.8,
                sizemode='diameter'  # size is the marker diameter in px
            ),
            name=config['display_name'],  # descriptive legend entry
            text=hover_text,
            hovertemplate='%{text}<extra></extra>',
            showlegend=True
        ))

        if marker_sizes:
            print(f"  {config['display_name']}: {len(marker_sizes)} topics, sizes {min(marker_sizes):.1f}-{max(marker_sizes):.1f}")

    fig.update_layout(
        title=dict(
            # Trailing <br> leaves a blank second line below the title text.
            text="<b>Combined Topic Clusters Across Groups</b><br>",
            x=0.5,
            xanchor='center',
            font=dict(size=18)
        ),
        xaxis=dict(
            title="D1",
            showgrid=True,
            gridcolor='rgba(0,0,0,0.1)',
            zeroline=True,
            zerolinecolor='rgba(0,0,0,0.3)',
            showline=True,
            linewidth=1,
            linecolor='black'
        ),
        yaxis=dict(
            title="D2",
            showgrid=True,
            gridcolor='rgba(0,0,0,0.1)',
            zeroline=True,
            zerolinecolor='rgba(0,0,0,0.3)',
            showline=True,
            linewidth=1,
            linecolor='black'
        ),
        legend=dict(
            orientation="v",
            yanchor="top",
            y=1,
            xanchor="left",
            x=1.02,  # just outside the plot area, to the right
            bgcolor='rgba(255,255,255,0.9)',
            bordercolor='rgba(0,0,0,0.3)',
            borderwidth=1,
            font=dict(size=12)
        ),
        width=1400,  # extra width for the outside legend
        height=800,
        plot_bgcolor='white',
        paper_bgcolor='white',
        margin=dict(l=80, r=200, t=120, b=80)  # r leaves room for the legend; t for the two-line title
    )

    return fig

In [None]:
# Pull 2-D topic coordinates out of every loaded model; groups that fail are skipped.
print("🔄 Extracting topic representations from existing UMAP...")
all_representations = {}

for group_name, model in models.items():
    config = groups_config[group_name]
    print(f"\n🔍 Processing {config['display_name']} ({group_name})...")

    repr_data = extract_topic_representations(model, group_name)
    if repr_data is None:
        print(f"  ❌ Failed to extract representations for {config['display_name']}")
        continue

    all_representations[group_name] = repr_data
    print(f"  ✅ {config['display_name']}: {len(repr_data['topic_ids'])} topics")

print(f"\n📈 Successfully extracted representations from {len(all_representations)} groups")

# Build the combined scatter and render it inline
print("🎨 Displaying Combined Topic Clusters with Descriptive Labels:")
fig_combined = create_combined_visualization_from_umap(all_representations, groups_config)
display(fig_combined)

In [None]:
# Save visualizations
print("💾 Saving visualizations...")

# Save UMAP version; auto_save_figure returns {format: path} for what was actually written.
requested_formats = ['png', 'html']
umap_files = auto_save_figure(
    fig_combined,
    "combined_topic_clusters_umap_all_groups",
    save_formats=requested_formats,
    width=1200,
    height=800,
    scale=2,
    output_dir="../../outputs/combined_visualizations"
)

# Only claim success if every requested format was written.
missing_formats = [fmt for fmt in requested_formats if fmt not in umap_files]
if missing_formats:
    print(f"⚠️  Some visualizations were not saved (missing formats: {missing_formats})")
else:
    print("✅ All visualizations saved successfully!")