In [1]:
### ToDo:
# 2. Add numbers to chord diagram, or a heatmap

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import plotly.graph_objects as go
import random

import networkx as nx

import holoviews as hv
from holoviews import opts
from bokeh.io import output_file, show
hv.extension('bokeh')

In [4]:
import plotly
plotly.__version__  # installed plotly version (plotly.graph_objects builds the Sankey figures in this notebook)

'5.22.0'

In [4]:
# Site-specific absolute path to the GRS analysis directory.
# TODO: make this configurable instead of hard-coding a home directory.
prefix = '/slade/home/pl450/Uveitis/GRS/python_analysis/'

# Load the GRS table, then drop rows with no recorded Sex.
data = pd.read_csv(f'{prefix}allGRS_forROCAUC_220524.tsv',
                   sep='\t', low_memory=False)
data = data.dropna(subset=['Sex'])

In [5]:
# Replace missing values with 0 in every 'first_*' column except the 'uve' ones.
columns_to_fill = [col for col in data.columns
                   if 'first_' in col and 'uve' not in col]

# Assign the result back explicitly: data.loc[:, <list>] returns a copy, so
# calling fillna(inplace=True) on it would leave `data` itself unchanged.
if columns_to_fill:
    data[columns_to_fill] = data[columns_to_fill].fillna(0)

In [8]:
def hex_to_rgba(hex_color, opacity):
    """Convert a hex colour string to a plotly-style ``rgba(...)`` string.

    Parameters
    ----------
    hex_color : str
        ``'#rrggbb'`` / ``'rrggbb'``, or the CSS shorthand ``'#rgb'`` / ``'rgb'``.
        Leading/trailing ``'#'`` characters are ignored.
    opacity : float
        Alpha value, inserted verbatim into the result.

    Returns
    -------
    str
        For example ``'rgba(255,99,71,0.4)'``.

    Raises
    ------
    ValueError
        If the colour channels are not valid hexadecimal.
    """
    digits = hex_color.strip('#')
    if len(digits) == 3:
        # CSS shorthand: 'abc' -> 'aabbcc'.
        digits = ''.join(ch * 2 for ch in digits)

    r, g, b = (int(digits[pos:pos + 2], 16) for pos in (0, 2, 4))
    return f'rgba({r},{g},{b},{opacity})'


def filter_data(data, filter_column, filter_value):
    """Return the rows of ``data`` whose ``filter_column`` equals ``filter_value``."""
    keep = data[filter_column] == filter_value
    return data.loc[keep]

def calculate_transitions(data, disease_list, pre_conditions_list, adjust=False,
                          counts_only=True):
    """Count patient flows into and out of the central 'Uveitis' node.

    Inflows: for every name in ``pre_conditions_list`` there is a link
    ``<name> -> Uveitis`` whose value is the number of rows with
    ``first_<name> == True``.
    Outflows: restricted to rows with ``uve_any == True``, every name in
    ``disease_list`` gives a link ``Uveitis -> ' <name>'`` whose value is the
    number of rows with ``first_uve_<name> == True``. Post-uveitis node names
    carry a leading space so they stay distinct from the pre-uveitis node of
    the same condition.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``uve_any`` and the ``first_<name>`` /
        ``first_uve_<name>`` columns named above.
    disease_list : list of str
        Conditions counted after uveitis.
    pre_conditions_list : list of str
        Conditions counted before uveitis.
    adjust : bool, default False
        If True and total inflow differs from total outflow, scale every
        inflow link by ``outflow / inflow`` (truncated to int).
    counts_only : bool, default True
        If True (the default, kept for existing callers), return only
        ``condition_counts``. Pass False to get the full Sankey inputs.

    Returns
    -------
    dict or tuple
        ``condition_counts`` when ``counts_only`` is True, otherwise
        ``(source, target, value, condition_indices, condition_counts)``.
        ``condition_counts`` always holds the unadjusted counts.
    """
    uveitis_idx = 0
    condition_indices = {'Uveitis': uveitis_idx}
    condition_counts = {}
    source, target, value = [], [], []

    def node_index(label):
        # Allocate node indices in first-seen order; 'Uveitis' is pre-seeded as 0.
        if label not in condition_indices:
            condition_indices[label] = len(condition_indices)
        return condition_indices[label]

    # Pre-uveitis conditions flow INTO the Uveitis node.
    for condition in pre_conditions_list:
        n_patients = len(data[data[f'first_{condition}'] == True])  # noqa: E712 (elementwise)
        source.append(node_index(condition))
        target.append(uveitis_idx)
        value.append(n_patients)
        condition_counts[condition] = n_patients

    # Post-uveitis conditions flow OUT of the Uveitis node.
    # 'uve_any' is assumed to flag presence of uveitis.
    uveitis_cases = data[data['uve_any'] == True]  # noqa: E712
    for condition in disease_list:
        label = f' {condition}'
        column = f'first_uve_{condition}'
        n_patients = len(uveitis_cases[uveitis_cases[column] == True])  # noqa: E712
        source.append(uveitis_idx)
        target.append(node_index(label))
        value.append(n_patients)
        condition_counts[label] = n_patients

    if counts_only:
        return condition_counts

    # Optionally rescale inflows so the Uveitis node's inflow matches its outflow.
    inflow = sum(v for v, t in zip(value, target) if t == uveitis_idx)
    outflow = sum(v for v, s in zip(value, source) if s == uveitis_idx)
    if adjust and inflow != outflow:
        factor = outflow / inflow if inflow else 0
        value = [int(v * factor) if t == uveitis_idx else v
                 for v, t in zip(value, target)]

    return source, target, value, condition_indices, condition_counts



def prepare_diagram_elements(condition_indices, condition_counts,
                             total_patients, color_map=None):
    """Build Sankey node labels and colours, ordered by node index.

    The 'Uveitis' node is always emitted first. Every other node is labelled
    ``"<name>\\n<count> (<pct>)"``, where pct = count / total_patients, or
    just ``<name>`` when its count is 0.

    Parameters
    ----------
    condition_indices : dict[str, int]
        Node name -> node index.
    condition_counts : dict[str, int]
        Node name -> patient count. Missing names count as 0.
    total_patients : int
        Denominator for the percentages.
    color_map : dict[str, str], optional
        Node name -> hex colour. Defaults to the notebook-level
        ``condition_colors``. Unmapped nodes get ``'#98f5d7'``, and an
        unmapped 'Uveitis' gets ``'#66c2a5'``.

    Returns
    -------
    (list of str, list of str)
        Labels and colours in node-index order.
    """
    palette = condition_colors if color_map is None else color_map

    labels = [f"Uveitis\n{total_patients} (100%)"]
    colors = [palette.get('Uveitis', '#66c2a5')]  # Default teal for Uveitis

    for condition, _ in sorted(condition_indices.items(), key=lambda item: item[1]):
        if condition == 'Uveitis':
            continue  # already emitted first
        count = condition_counts.get(condition, 0)
        if count > 0:
            labels.append(f"{condition}\n{count} ({count/total_patients:.1%})")
        else:
            labels.append(condition)
        # Fall back to a neutral colour for conditions without an entry.
        colors.append(palette.get(condition, "#98f5d7"))

    return labels, colors


def create_sankey(source, target, value, labels, colors):
    """Build a plotly Sankey figure from parallel link lists and node metadata.

    Parameters
    ----------
    source, target : list of int
        Link endpoints as node indices.
    value : list of int
        Link widths.
    labels, colors : list of str
        Per-node labels and hex colours, indexed by node index.

    Returns
    -------
    (plotly.graph_objects.Figure, dict)
        The figure, plus a dict of its raw inputs (``source``, ``target``,
        ``value``, ``label``, ``color``).
    """
    # Colour each link after its endpoint that is NOT node 0 (the 'Uveitis'
    # hub in calculate_transitions): inflows take the source node's colour and
    # outflows take the target node's colour. Deriving this per link keeps
    # colours right even when links do not line up one-to-one with nodes.
    hub = 0
    link_colors = [
        hex_to_rgba(colors[tgt] if src == hub else colors[src], 0.4)
        for src, tgt in zip(source, target)
    ]

    fig = go.Figure(data=[go.Sankey(
        node=dict(pad=15, thickness=20, line=dict(color="black", width=0.4),
                  label=labels, color=colors),
        link=dict(source=source, target=target, value=value, color=link_colors)
    )])
    fig.update_layout(title_text="Uveitis to IMIDs Sankey", font_size=16)
    return fig, dict(source=source, target=target, value=value, label=labels, color=colors)

def save_sankey_diagram(fig, filename="uveitis_sankey_diagram.html"):
    """Write ``fig`` to ``filename`` with ``fig.write_html`` and report the path."""
    target_path = filename
    fig.write_html(target_path)
    print(f"Diagram saved as {target_path}")

def create_sankey_diagram(data, disease_list, pre_conditions_list,
                          filter_column='uve_any', filter_value=True,
                          filename="uveitis_sankey_diagram.html",
                          return_out = False, adjust=False):
    """Filter ``data``, build the uveitis Sankey diagram and save it as HTML.

    Rows are kept where ``data[filter_column] == filter_value``. Node
    percentages are relative to the number of kept rows.

    Returns
    -------
    tuple or None
        When ``return_out`` is True: ``(source, target, value,
        condition_indices, condition_counts, labels, colors, links)``.
        Otherwise None.
    """
    filtered_data = filter_data(data, filter_column, filter_value)

    # Ask for the full link lists, not just the per-condition counts.
    source, target, value, condition_indices, condition_counts = calculate_transitions(
        filtered_data, disease_list, pre_conditions_list, adjust, counts_only=False)

    labels, colors = prepare_diagram_elements(
        condition_indices, condition_counts, len(filtered_data))
    fig, links = create_sankey(source, target, value, labels, colors)
    save_sankey_diagram(fig, filename)

    if return_out:
        return (source, target, value, condition_indices, condition_counts,
                labels, colors, links)
    return None



In [9]:
# IMIDs tracked relative to uveitis. Each name X is looked up as the columns
# 'first_X' / 'first_uve_X' in `data`.
disease_list = [
    'AS', 'RA', 'Behcets', 'Crohns', 'Psoriasis',
    'Sarcoid', 'Sjo', 'SLE', 'UC', 'Weg',
    'CS', 'GCA', 'MS',
]

In [21]:
# Rename 'undif_uve' to 'first_Undif' so that 'Undif' follows the
# 'first_<condition>' column convention used for pre-uveitis conditions.
# Return a new frame rather than renaming in place.
data = data.rename(columns={'undif_uve': 'first_Undif'})

In [22]:
# first_uve_Undif is 1 for patients with first_Undif == 1 who have no incident
# IMID in any 'first_uve_<disease>' column, and 0 otherwise.
incident_imid = [f'first_uve_{disease}' for disease in disease_list]
data['first_uve_Undif'] = (
    (data[incident_imid].sum(axis=1) == 0)
    & (data['first_Undif'] == 1)
).astype('int64')

In [11]:
# Compare these two counts to see how many patients who were undifferentiated
# at presentation later developed one of the IMIDs.
# data.first_Undif.value_counts()
# data.first_uve_Undif.value_counts()

In [12]:
condition_colors = {
    'AS': '#ff6347',         # Tomato
    'RA': '#4682b4',         # Steel blue
    'Behcets': '#da70d6',    # Orchid
    'Crohns': '#32cd32',     # Lime green
    'Psoriasis': '#ffa500',  # Orange
    'Sarcoid': '#e6beff',    # Lavender
    'Sjo': '#ff4500',        # Orange red
    'SLE': '#ffea00',        # Yellow
    'UC': '#6a5acd',         # Slate blue
    'Weg': '#db7093',        # Pale violet red
    'CS': '#ffc0cb',         # Pink
    'GCA': '#46f0f0',        # Cyan
    'MS': '#9A6324',         # Brown
    'Uveitis': '#66c2a5',    # Teal
    'Undif': '#66c2a5',
}

# Post-uveitis Sankey nodes are named with a leading space (' AS', ' RA', ...),
# so register every colour under that key as well.
condition_colors.update(
    {f" {name}": hex_code for name, hex_code in condition_colors.items()}
)


In [24]:
# Conditions that can precede uveitis in the Sankey: every IMID plus 'Undif'.
pre_conditions_list = [*disease_list, 'Undif']

# create_sankey_diagram(
#     data, pre_conditions_list, pre_conditions_list,
#     filename='uveitis_sankey_diagram_050824.html')


In [32]:
# Total post-uveitis incident IMIDs: sum the ' <disease>' counts for every
# disease in disease_list (excluding the ' Undif' bucket). Looking the keys up
# explicitly avoids depending on a positional slice of the counts dict.
transition_counts = calculate_transitions(data, pre_conditions_list, pre_conditions_list, False)
sum(transition_counts[f' {disease}'] for disease in disease_list)

396

In [172]:
def create_chord_diagram(data, output_f='chord_diagram.html',
                         condition_columns=None, expanded=True,
                         show_fig=True, title = "Chord Diagram",
                         color_map=None):
    """Draw a HoloViews/Bokeh chord diagram of condition co-occurrence.

    Parameters
    ----------
    data : pandas.DataFrame
        One row per patient. ``condition_columns`` hold presence flags; they
        are coerced to numbers, and missing or non-numeric entries count as 0.
    output_f : str
        HTML file the diagram is written to.
    condition_columns : list of str, optional
        Flag columns to include. Defaults to every column ending in ``'_any'``.
        ``'uve_any'`` is labelled ``'Uveitis'``, and any other ``'<X>_any'`` is
        labelled ``'<X>'``.
    expanded : bool
        If True, the diagonal holds each condition's total count, so each
        condition also gets an edge to itself. If False, the diagonal is zeroed.
    show_fig : bool
        If True, display the diagram with ``bokeh.io.show``. Otherwise only
        save it to ``output_f``.
    title : str
        Plot title.
    color_map : dict[str, str], optional
        Label -> hex colour. Defaults to the notebook-level ``condition_colors``.
        Labels without an entry are drawn ``'#999999'``.
    """
    if condition_columns is None:
        condition_columns = [col for col in data.columns if col.endswith('_any')]
    palette = condition_colors if color_map is None else color_map

    # Display labels. removesuffix drops exactly '_any', whereas
    # str.strip('_any') strips any of '_', 'a', 'n', 'y' from both ends
    # (e.g. 'Myopathy_any' -> 'Myopath').
    names = ['Uveitis' if col == 'uve_any' else col.removesuffix('_any')
             for col in condition_columns]
    custom_palette = [palette.get(name, '#999999') for name in names]

    # Use only the flag columns. Coerce to numbers, count missing as 0, and cast
    # to float so that boolean flags are summed (not OR-ed) by the product below.
    flags = (data[condition_columns]
             .apply(pd.to_numeric, errors='coerce')
             .fillna(0)
             .astype(float)
             .to_numpy())

    # co_occurrence_matrix[i, j] = number of rows flagged for both i and j.
    co_occurrence_matrix = np.dot(flags.T, flags)
    if expanded:
        # Diagonal = total occurrences of each condition.
        np.fill_diagonal(co_occurrence_matrix, flags.sum(axis=0))
    else:
        # Drop self-occurrences for the non-expanded view.
        np.fill_diagonal(co_occurrence_matrix, 0)

    # Keep the upper triangle (diagonal included) so each pair yields one edge.
    co_occurrence_matrix = np.triu(co_occurrence_matrix)
    co_occurrence_df = pd.DataFrame(co_occurrence_matrix, index=names,
                                    columns=names).astype(int)

    # Node table indexed by label. Row totals are assigned positionally: a
    # RangeIndex-ed Series would align against the label index and give NaN.
    nodes = pd.DataFrame({'name': names,
                          'value': co_occurrence_df.sum(axis=1).to_numpy()},
                         index=names)

    edges = co_occurrence_df.stack().reset_index()
    edges.columns = ['source', 'target', 'value']
    edges = edges[edges['value'] > 0]  # Filter out zero values

    # Node and edge data for the HoloViews Chord element.
    hv_nodes = hv.Dataset(nodes, 'index', ['name', 'value'])
    hv_edges = hv.Dataset(edges, ['source', 'target'], 'value')

    # NOTE(review): cmap/edge_cmap are ordered lists; confirm HoloViews assigns
    # them to categories in `names` order so each condition gets its own colour.
    chord = hv.Chord((hv_edges, hv_nodes)).opts(
        opts.Chord(
            cmap=custom_palette, edge_cmap=custom_palette,
            labels='name',
            edge_color=hv.dim('source').str(),
            node_color=hv.dim('index').str(),
            width=800, height=800,
            fontsize={'labels': '10pt', 'ticks': '8pt'},
            title=title
        )
    )

    output_file(output_f)
    if show_fig:
        show(hv.render(chord))
    else:
        # output_file only sets bokeh's output target. Without show(), nothing
        # would write the file, so save it explicitly before reporting success.
        hv.save(chord, output_f, backend='bokeh')
    print(f"Diagram saved as {output_f}")

In [173]:
# Uveitis plus the IMIDs shown in the chord diagrams.
condition_columns = [
    'uve_any',
    'AS_any', 'MS_any', 'RA_any', 'UC_any',
    'Crohns_any', 'Psoriasis_any', 'Sarcoid_any', 'SLE_any',
]

In [174]:
# Chord diagram restricted to patients with uveitis (uve_any == 1).
create_chord_diagram(
    data.loc[data['uve_any'] == 1],
    output_f='uveitis_chord_diagram_Expanded_300724.html',
    condition_columns=condition_columns,
    expanded=True,
)

Diagram saved as uveitis_chord_diagram_Expanded_300724.html


In [175]:
# IMID co-occurrence among patients with uveitis (uveitis node itself omitted).
condition_columns = [
    'AS_any', 'MS_any', 'RA_any', 'UC_any',
    'Crohns_any', 'Psoriasis_any', 'Sarcoid_any', 'SLE_any',
]
create_chord_diagram(
    data.loc[data['uve_any'] == 1],
    output_f='Withinuveitis_chord_diagram_Expanded_300724.html',
    condition_columns=condition_columns,
    expanded=True,
)

Diagram saved as Withinuveitis_chord_diagram_Expanded_300724.html


In [176]:
# Uveitis and IMID co-occurrence across the whole cohort (no uveitis filter).
condition_columns = [
    'uve_any',
    'AS_any', 'MS_any', 'RA_any', 'UC_any',
    'Crohns_any', 'Psoriasis_any', 'Sarcoid_any', 'SLE_any',
]
create_chord_diagram(
    data,
    output_f='ALLUKBB_chord_diagram_Expanded_300724.html',
    condition_columns=condition_columns,
    expanded=True,
)

Diagram saved as ALLUKBB_chord_diagram_Expanded_300724.html
