# In [None]:
# RUFUS/Mutect2 Union Upset Plot

# Note: to create the multi-colored plot, this plot was generated five times with a different color each time, and the resulting images were then cropped and combined.

# Numbers copied from upset_counts.txt output generated by mut_ruf_isec_all.sh
# Maps an '&'-joined combination of sample names to its count.
# Insertion order is kept as in the original listing.
union_isecs = {
    # Single-sample keys
    'WashU': 9526,
    'UW': 4527,
    'NYGC': 1264,
    'Broad': 1045,
    'BCM': 4087,
    # Two-sample combinations
    'UW&WashU': 644,
    'NYGC&WashU': 119,
    'NYGC&UW': 95,
    'Broad&WashU': 37,
    'Broad&UW': 43,
    'Broad&NYGC': 24,
    'BCM&WashU': 656,
    'BCM&UW': 396,
    'BCM&NYGC': 84,
    'BCM&Broad': 42,
    # Three-sample combinations
    'NYGC&UW&WashU': 92,
    'Broad&UW&WashU': 54,
    'Broad&NYGC&WashU': 18,
    'Broad&NYGC&UW': 6,
    'BCM&UW&WashU': 658,
    'BCM&NYGC&WashU': 103,
    'BCM&NYGC&UW': 70,
    'BCM&Broad&WashU': 54,
    'BCM&Broad&UW': 44,
    'BCM&Broad&NYGC': 11,
    # Four-sample combinations
    'Broad&NYGC&UW&WashU': 38,
    'BCM&NYGC&UW&WashU': 584,
    'BCM&Broad&UW&WashU': 255,
    'BCM&Broad&NYGC&WashU': 38,
    'BCM&Broad&NYGC&UW': 24,
    # All five samples
    'BCM&Broad&NYGC&UW&WashU': 4700,
}

# Upset plots
import pandas as pd
import numpy as np
from upsetplot import UpSet
import matplotlib.pyplot as plt
import warnings

# Suppress pandas future warnings related to inplace operations
warnings.filterwarnings('ignore', category=FutureWarning)

def create_upset_from_intersection_counts(intersection_data, sample_names=None):
    """Turn '&'-keyed intersection counts into a boolean-MultiIndexed Series.

    Args:
        intersection_data: dict mapping an '&'-joined string of sample names
            to a count.
        sample_names: Ordered list of names used as MultiIndex levels. When
            None, the sorted set of every name found in the keys is used.

    Returns:
        A ``(series, sample_names)`` tuple. ``series`` holds the positive
        counts, sorted in descending order, indexed by one boolean level per
        sample name (True when the name is part of the key).
    """
    if sample_names is None:
        # Collect every distinct name that appears in any key.
        sample_names = sorted(
            {name for key in intersection_data.keys() for name in key.split('&')}
        )

    print(f"Detected samples: {sample_names}")

    memberships = []
    sizes = []
    for key, size in intersection_data.items():
        members = set(key.split('&'))
        # One boolean per level, in sample_names order.
        memberships.append(tuple(name in members for name in sample_names))
        sizes.append(size)

    index = pd.MultiIndex.from_tuples(memberships, names=sample_names)
    series = pd.Series(data=sizes, index=index, dtype=int)

    # Drop zero counts, then put the largest intersections first.
    series = series[series > 0].sort_values(ascending=False)

    return series, sample_names

def _draw_colored_matrix(matrix_ax, intersections, sample_colors, missing_color='black'):
    """Redraw the membership matrix on ``matrix_ax`` with one dot color per sample.

    ``intersections`` is the Series held by the UpSet object, so columns and
    rows follow the order the UpSet object itself uses for its bars.
    """
    row_names = list(intersections.index.names)
    # A single-level index yields scalars rather than tuples; normalize.
    memberships = [m if isinstance(m, tuple) else (m,) for m in intersections.index]

    # Clear existing dots and redraw with custom colors
    matrix_ax.clear()

    for col_idx, membership in enumerate(memberships):
        present_rows = []
        for row_idx, (sample_name, is_present) in enumerate(zip(row_names, membership)):
            if is_present:
                present_rows.append(row_idx)
                # Filled dot in the sample's color; samples missing from the
                # mapping fall back to missing_color instead of raising.
                matrix_ax.scatter(col_idx, row_idx,
                                  c=sample_colors.get(sample_name, missing_color),
                                  s=100,
                                  edgecolors='black',
                                  linewidths=0.5,
                                  zorder=3)
            else:
                # Small gray dot for non-membership
                matrix_ax.scatter(col_idx, row_idx,
                                  c='lightgray',
                                  s=30,
                                  alpha=0.3,
                                  zorder=1)

        # Vertical connector between the first and last member rows
        if len(present_rows) > 1:
            matrix_ax.plot([col_idx, col_idx], [present_rows[0], present_rows[-1]],
                           'black', linewidth=2, alpha=0.6, zorder=2)

    matrix_ax.set_xlim(-0.5, len(memberships) - 0.5)
    # NOTE(review): row i is drawn at y=i and the y axis is NOT inverted, on
    # the assumption that upsetplot places its totals bars the same way; an
    # inverted axis here could put labels next to the wrong totals bar.
    # Confirm against the installed upsetplot version.
    matrix_ax.set_ylim(-0.5, len(row_names) - 0.5)
    matrix_ax.set_xticks([])
    matrix_ax.set_yticks(range(len(row_names)))
    matrix_ax.set_yticklabels(row_names)
    matrix_ax.grid(True, alpha=0.3)
    matrix_ax.set_xlabel('')


def plot_upset(series, sample_names, title="UpSet Plot", figsize=(13, 7.5), sample_colors=None, **kwargs):
    """Render an UpSet plot for ``series``, optionally coloring matrix dots per sample.

    Args:
        series: Series of subset sizes indexed by a boolean MultiIndex with
            one level per sample (the shape produced by
            create_upset_from_intersection_counts).
        sample_names: Sample names. Kept for interface compatibility; the row
            order of a recolored matrix is read from the UpSet object so it
            matches the rest of the plot.
        title: Figure title.
        figsize: Figure size in inches. The default equals the size that was
            previously hard-coded, so existing output is unchanged; an
            explicit value is now honored.
        sample_colors: Optional dict mapping sample name to a matplotlib
            color. When given, the matrix is redrawn with these colors and a
            legend is added. When None or empty, upsetplot's own rendering is
            left untouched.
        **kwargs: Overrides for the default UpSet options below.

    Returns:
        The UpSet object, or None when plotting failed. A failure is reported
        with a RuntimeWarning instead of being dropped silently.
    """
    # Default UpSet parameters (can be overridden by kwargs)
    upset_params = {
        'show_counts': True,
        'sort_by': 'degree',
        #'facecolor': '#d94701',
        'facecolor': '#016FB9',
        'shading_color': 'lightgray',
        'sort_categories_by': 'cardinality',
        'min_subset_size': 1,  # Show all intersections with at least 1 variant
        'element_size': None,
        'totals_plot_elements': 6,  # Limit number of total elements shown
    }
    upset_params.update(kwargs)

    with warnings.catch_warnings():
        # Suppress the tight_layout warning
        warnings.filterwarnings('ignore', message='This figure includes Axes that are not compatible with tight_layout')

        try:
            upset = UpSet(series, **upset_params)

            # UpSet.plot() returns a dict of axes (not the Figure), so keep
            # our own reference to the figure for title/legend/layout calls.
            fig = plt.figure(figsize=figsize)
            subplots = upset.plot(fig=fig)

            if sample_colors:
                # Use the UpSet object's own (filtered, sorted) intersections
                # so the redrawn columns line up with the bars above them.
                _draw_colored_matrix(subplots['matrix'], upset.intersections, sample_colors)

            # Add title with manual positioning to avoid layout issues
            fig.suptitle(title, fontsize=14, y=0.98)

            # Use subplots_adjust instead of tight_layout for better control
            fig.subplots_adjust(top=0.92, bottom=0.1, left=0.1, right=0.95, hspace=0.3, wspace=0.3)

            if sample_colors:
                # Add a color legend for samples only
                from matplotlib.patches import Patch
                legend_elements = [Patch(facecolor=color, edgecolor='black', label=name)
                                   for name, color in sample_colors.items()]
                fig.legend(handles=legend_elements,
                           loc='upper left',
                           bbox_to_anchor=(0.02, 0.98),
                           title='Samples')

            plt.show()

            return upset

        except Exception as e:
            # Still best-effort (callers receive None), but say what went wrong.
            warnings.warn(f"plot_upset failed: {type(e).__name__}: {e}",
                          RuntimeWarning, stacklevel=2)
            return None

def validate_intersection_data(intersection_data):
    """Check that ``intersection_data`` maps '&'-joined sample names to counts.

    Args:
        intersection_data: Expected to be a non-empty dict whose keys are
            strings of one or more sample names joined by '&' and whose
            values are non-negative ints or floats.

    Returns:
        True when every check passes.

    Raises:
        ValueError: If the argument is not a dict, is empty, has a non-string
            key, a non-numeric or negative value, or a key with a blank
            sample name (including an empty or whitespace-only key).
    """
    if not isinstance(intersection_data, dict):
        raise ValueError("intersection_data must be a dictionary")

    if not intersection_data:
        raise ValueError("intersection_data cannot be empty")

    for key, value in intersection_data.items():
        if not isinstance(key, str):
            raise ValueError(f"All keys must be strings, found: {type(key)}")

        if not isinstance(value, (int, float)):
            raise ValueError(f"All values must be numeric, found: {type(value)} for key '{key}'")

        if value < 0:
            raise ValueError(f"All values must be non-negative, found: {value} for key '{key}'")

        # Every '&'-separated part must be a non-blank name. Applying the same
        # rule to single-sample keys also rejects '' and whitespace-only keys,
        # which would otherwise be accepted as a sample with an empty name.
        if any(not part.strip() for part in key.split('&')):
            raise ValueError(f"Invalid intersection format: '{key}'. Check for empty sample names.")

    return True

def angle_upset_labels(fig, angle=30):
    """Rotate the x tick labels of the first axes that holds bar-like patches.

    NOTE: a later definition with the same name further down this file
    replaces this one when the module is executed.

    Args:
        fig: Figure-like object providing ``get_axes()`` and
            ``subplots_adjust()``.
        angle: Rotation applied to the x tick labels, in degrees.

    Returns:
        The same ``fig`` that was passed in.
    """
    # Lazily scan the axes; the first one with at least one patch exposing
    # ``get_width`` is treated as the bar chart.
    candidates = (
        ax for ax in fig.get_axes()
        if hasattr(ax, 'patches')
        and len(ax.patches) > 0
        and any(hasattr(patch, 'get_width') for patch in ax.patches)
    )
    bar_ax = next(candidates, None)

    if bar_ax is not None:
        bar_ax.set_xticklabels(bar_ax.get_xticklabels(),
                               rotation=angle,
                               ha='right',
                               va='top')
        # Leave room at the bottom for the angled labels
        fig.subplots_adjust(bottom=0.2)

    return fig


def main():
    """Validate ``union_isecs``, build the count series, and draw the UpSet plot.

    Returns:
        A ``(series, upset)`` tuple: the Series built from ``union_isecs`` and
        whatever ``plot_upset`` returned.
    """
    # Fail early on malformed input
    validate_intersection_data(union_isecs)

    # Explicit category list (it would be inferred from the keys if omitted).
    # 'Truth Set' occurs in no key of union_isecs, so that level is False for
    # every intersection.
    categories = ['BCM', 'Broad', 'NYGC', 'UW', 'WashU', 'Truth Set']

    series, category_order = create_upset_from_intersection_counts(
        union_isecs, categories
    )

    # Everything except the title is forwarded to UpSet through plot_upset's
    # **kwargs.
    plot_options = {
        'title': "Variant Intersections Called By RUFUS In GCC Combinations Not In Truth Set",
        'min_subset_size': 1,            # keep every intersection with at least 1 variant
        'show_counts': True,             # print the count on each bar
        'sort_by': '-degree',            # order subsets by degree, reversed by the '-' prefix
        'sort_categories_by': '-input',  # categories in reversed input order
    }
    upset = plot_upset(series, category_order, **plot_options)

    return series, upset

def angle_upset_labels(fig, angle=30):
    """Rotate the x tick labels of the axes that holds the most bars.

    The bar chart is taken to be the axes with the largest number of patches
    exposing both ``get_width`` and ``get_height``. Progress and debugging
    information is printed.

    Args:
        fig: Figure-like object providing ``get_axes()`` and
            ``subplots_adjust()``.
        angle: Rotation applied to the x tick labels, in degrees.

    Returns:
        The same ``fig`` that was passed in.
    """
    axes = fig.get_axes()

    # Find the bar chart axis - it's usually the one with the most patches (bars)
    bar_chart_ax = None
    max_patches = 0
    for ax in axes:
        rect_count = sum(
            1 for p in getattr(ax, 'patches', ())
            if hasattr(p, 'get_width') and hasattr(p, 'get_height')
        )
        if rect_count > max_patches:
            max_patches = rect_count
            bar_chart_ax = ax

    if bar_chart_ax is None:
        print("Could not find bar chart axis")
        # Debug: print info about all axes. getattr keeps this branch from
        # raising on an axes-like object without a ``patches`` attribute.
        for i, ax in enumerate(axes):
            print(f"Axis {i}: {len(getattr(ax, 'patches', ()))} patches, {len(ax.get_xticklabels())} x-labels")
        return fig

    print(f"Found bar chart axis with {max_patches} bars")

    # Re-apply the current label texts with the requested rotation
    current_labels = [label.get_text() for label in bar_chart_ax.get_xticklabels()]
    print(f"Current labels: {current_labels[:5]}...")  # Show first 5 for debugging

    bar_chart_ax.set_xticklabels(current_labels,
                                 rotation=angle,
                                 ha='right',
                                 va='top')

    # Adjust bottom margin to accommodate angled labels
    fig.subplots_adjust(bottom=0.25)
    print(f"Applied {angle}° rotation to x-axis labels")

    return fig

if __name__ == "__main__":
    # Build and draw the plot; the results stay available as module globals.
    series, upset = main()
    # Angle the bar-chart tick labels on whichever figure is current.
    angle_upset_labels(fig=plt.gcf(), angle=30)