In [None]:
import altair as alt
import pandas as pd

# To generate the five S4 stacked bar charts, this code was run five times,
# once with the input data for each of the five allele frequency thresholds
# (1+, 2+, 5+, 8+ and 10+ supporting PacBio reads).
# E.g. to generate the chart for the threshold of 2:
# - set `reads` to "2"
# - make sure union_pos_2plus is the series plotted as `positive_validated`
# - make sure union_homog_neg_2plus is the series plotted as `negative_validated`

# Data derived from transpose_counts.sh output.
# Each list holds five counts, matched by position to the categories
# Σ C(5,1), Σ C(5,2), Σ C(5,3), Σ C(5,4), Σ C(5,5) of the `data` DataFrame.
# union_pos_<N>plus: BLT50-validated counts at the union-called sites for the
#     <N>+ PacBio read threshold (plotted as `positive_validated`).
# union_homog_neg_<N>plus: BLT50-validated counts at the same sites in the
#     unrelated negative control (plotted as `negative_validated`).
union_pos_1plus=[4778, 1710, 1071, 929, 4694]
union_homog_neg_1plus=[1434, 147, 51, 46, 79]
union_pos_2plus=[3921, 1634, 1056, 925, 4694]
union_homog_neg_2plus=[1007, 104, 29, 32, 32]
union_pos_5plus=[2208, 1119, 869, 886, 4691]
union_homog_neg_5plus=[798, 84, 20, 23, 15]
union_pos_8plus=[1383, 555, 553, 774, 4667]
union_homog_neg_8plus=[694, 78, 19, 22, 13]
union_pos_10plus=[1144, 324, 379, 702, 4644]
union_homog_neg_10plus=[657, 72, 18, 19, 12]
# union_coverage: per-category total, used for both the positive and the
# negative side of the chart.
union_coverage=[20374, 2132, 1109, 939, 4699] # Took the lower of the two coverages

# Bar colors, matched position-by-position to the chart's color-scale domain
# (validated_positive, unvalidated_positive, validated_negative, unvalidated_negative).
purple_range = ['#756bb1', '#9e9cad', '#a6611a', '#dfc27d']

# Read-support threshold for this chart; fill in per respective threshold.
# The validated series are looked up from this value below, so the plotted
# data always matches the threshold shown in the title.
reads = "8"
title = f"Validation Rates Using {reads}+ PacBio Reads"
sub = "Identical Negative Control Sites in ST001-1A Files (250x)"

# threshold -> (validated counts at union-called sites,
#               validated counts at the same sites in the unrelated control)
validated_by_threshold = {
    "1": (union_pos_1plus, union_homog_neg_1plus),
    "2": (union_pos_2plus, union_homog_neg_2plus),
    "5": (union_pos_5plus, union_homog_neg_5plus),
    "8": (union_pos_8plus, union_homog_neg_8plus),
    "10": (union_pos_10plus, union_homog_neg_10plus),
}
try:
    pos_validated, neg_validated = validated_by_threshold[reads]
except KeyError:
    raise ValueError(
        f"No validated counts for threshold {reads!r}; "
        f"expected one of {list(validated_by_threshold)}"
    ) from None

# One row per category; both sides use union_coverage as their total.
data = pd.DataFrame({
    'category': ['Σ C(5,1)', 'Σ C(5,2)', 'Σ C(5,3)', 'Σ C(5,4)', 'Σ C(5,5)'],
    'positive_total': union_coverage,
    'positive_validated': pos_validated,  # BLT50 validated variants for actual sites
    'negative_total': union_coverage,
    'negative_validated': neg_validated,  # BLT50 validated variants in same site in unrelated control
})

# Reshape `data` (one row per category) into long form for the diverging
# stacked bar chart. Each category yields four rows, in this order:
# positive unvalidated, positive validated, negative unvalidated, negative
# validated. Negative-side counts are negated so those bars are drawn below
# zero. `stack_order` 0 (unvalidated) forms the base of each stack and
# 1 (validated) is stacked on top of it.
data_stacked = []
for _, record in data.iterrows():
    for side in ('positive', 'negative'):
        side_total = record[f'{side}_total']
        side_validated = record[f'{side}_validated']
        # Percentage of this side's total that was validated (0 if no total).
        if side_total > 0:
            side_rate = side_validated / side_total * 100
        else:
            side_rate = 0
        segments = (
            (0, 'unvalidated', side_total - side_validated),
            (1, 'validated', side_validated),
        )
        for order, kind, amount in segments:
            data_stacked.append({
                'category': record['category'],
                'direction': side,
                'variant_type': kind,
                'variant_direction': f'{kind}_{side}',
                'count': amount if side == 'positive' else -amount,
                'stack_order': order,
                'validation_rate': side_rate,
                'total_count': side_total,
                'validated_count': side_validated,
            })

data_long = pd.DataFrame(data_stacked)

# Diverging stacked bar chart built from data_long: union-called sites are
# drawn above zero and negative-control sites below it.

# Left-to-right order of the categories on the x axis.
bar_category_order = ['Σ C(5,5)', 'Σ C(5,4)', 'Σ C(5,3)', 'Σ C(5,2)', 'Σ C(5,1)']

bar_x = alt.X(
    'category:O',
    sort=bar_category_order,
    title='Category',
    axis=alt.Axis(labelAngle=60, labelFontSize=20, titleFontSize=22),
)
# Signed counts stacked from zero; tick labels show absolute values so the
# negative-control half reads as positive counts.
bar_y = alt.Y(
    'count:Q',
    title='Count',
    scale=alt.Scale(domain=[-5500, 5500]),  # Adjust domain as needed
    stack='zero',
    axis=alt.Axis(labelFontSize=20, titleFontSize=22, labelExpr="abs(datum.value)"),
)
# Fixed color per segment; the legend is hidden. To show it instead, pass
# legend=alt.Legend(title="Variant Type", orient='bottom', direction='horizontal',
#     symbolType='square', symbolSize=400, titleFontSize=22, labelFontSize=16,
#     labelExpr="datum.value == 'unvalidated_positive' ? 'Non-Validated' : datum.value == 'validated_positive' ? 'Validated' : datum.value == 'unvalidated_negative' ? 'Non-Validated Control' : 'Validated Control'")
bar_color = alt.Color(
    'variant_direction:N',
    scale=alt.Scale(
        domain=['validated_positive', 'unvalidated_positive',
                'validated_negative', 'unvalidated_negative'],
        range=purple_range,
    ),
    legend=None,
)

chart = (
    alt.Chart(data_long)
    .mark_bar()
    .encode(
        x=bar_x,
        y=bar_y,
        color=bar_color,
        order=alt.Order('stack_order:O'),
        tooltip=['category:O', 'direction:N', 'variant_type:N', 'count:Q', 'validation_rate:Q'],
    )
    .properties(
        width=1200,
        height=600,
        title=alt.TitleParams(
            text=title,
            # subtitle=sub,
            fontSize=24,
            fontWeight='bold',
            subtitleFontSize=16,
            dy=-10,
            dx=100,
            anchor='start',
        ),
    )
)

# Build one percentage label per category and side, formatted as
# "<rate>% (<validated>/<total>)". Positive-side labels sit just above the
# top of the positive stack; negative-side labels sit just below the bottom
# of the negated stack.
label_offset = 50  # gap, in y-axis data units, between a bar's end and its label
label_data = []
# First row per unique category, without re-filtering `data` for each one.
for _, row in data.drop_duplicates(subset='category').iterrows():
    for side, sign in (('positive', 1), ('negative', -1)):
        total = row[f'{side}_total']
        validated = row[f'{side}_validated']
        rate = (validated / total * 100) if total > 0 else 0
        label_data.append({
            'category': row['category'],
            # `total` is the full stack height; push the label past its end.
            'y_position': sign * (total + label_offset),
            'percentage': f"{rate:.1f}% ({validated}/{total})",
            'direction': side,
        })

label_df = pd.DataFrame(label_data)

# Text layer showing the validation-rate labels from label_df.
# The x sort must repeat the bar chart's custom category order. When layers
# share a scale but disagree on sort, Vega-Lite logs a "conflicting sort
# properties" warning and falls back to ascending order. That would flip the
# axis to Σ C(5,1) ... Σ C(5,5) instead of the intended Σ C(5,5) ... Σ C(5,1).
percentage_labels = alt.Chart(label_df).mark_text(
    align='center',
    baseline='middle',
    fontSize=18,
    color='black',
    # Nudge positive-side labels up and negative-side labels down, away from the bars.
    dy=alt.expr("datum.direction == 'positive' ? -8 : 8"),
).encode(
    x=alt.X('category:O',
            sort=['Σ C(5,5)', 'Σ C(5,4)', 'Σ C(5,3)', 'Σ C(5,2)', 'Σ C(5,1)']),
    y=alt.Y('y_position:Q'),
    text=alt.Text('percentage:N'),
)

# Horizontal reference rule at y = 0, the boundary between the union-called
# bars (positive counts) and the negative-control bars (negated counts).
zero_line = (
    alt.Chart(pd.DataFrame({'y': [0]}))
    .mark_rule(color='black', strokeWidth=1)
    .encode(y='y:Q')
)

# Rotated labels naming the two halves of the chart. The y positions are in
# data units. x is pinned to pixel 0 via alt.value, so the x_pos column is
# not used by the encoding.
section_label_data = pd.DataFrame({
    'label': ['Union Called Sites', 'Negative Ctrl Sites'],
    'y_position': [20000, -8000],
    'x_pos': [0, 0],
})

# NOTE(review): this explicit y domain ([-21000, 21000]) differs from the bar
# chart's ([-5500, 5500]) although the layers share one y scale; confirm
# which range the rendered chart actually uses.
section_labels = alt.Chart(section_label_data).mark_text(
    align='right',
    baseline='middle',
    fontSize=18,
    fontWeight='bold',
    color='#333333',
    angle=270,  # vertical text, like a y-axis title
    dy=-70,     # offset intended to move the labels left of the plot area
).encode(
    x=alt.value(0),
    y=alt.Y('y_position:Q', scale=alt.Scale(domain=[-21000, 21000])),
    text=alt.Text('label:N'),
)

# Overlay bars, zero rule, percentage labels and section labels (in that
# drawing order), with color scales resolved independently per layer.
layers = [chart, zero_line, percentage_labels, section_labels]
final_chart = alt.layer(*layers).resolve_scale(color='independent')

final_chart.show()

# Optional: save the chart instead of (or in addition to) showing it.
# final_chart.save('bidirectional_bar_chart.html')
# final_chart.save('bidirectional_bar_chart.png', scale_factor=2.0)