""" _summary_

This Jupyter notebook uses Plotly Dash Apps to explore already generated simulations.

- Results from an exemplar simulation:
    - (app_choice = 0) APP: Dropdown - choose model type; Dropdown - choose seeding density
    - show scatter plot of cells (one facet per `pid`)

- Results from replicate simulations:
    - (app_choice = 1) APP: Dropdown - choose model type; Dropdown - choose seeding density
        - histograms of log10 tumour sizes, colour coded by time point
        - line plots of mean log10 tumour size vs time and tumour count vs time; one curve per replicate simulation

    - (app_choice = 2) APP: Dropdown - choose time point; Dropdown - choose seeding density
        - histograms of log10 tumour sizes, colour coded by model type
        - scatter plot of per-simulation log10 tumour size mean vs std, marker size = count 

"""


# Library

In [None]:
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from dash import Dash, callback, Input, Output, State, html, dcc

# Dash Apps

### [1] Exemplar simulation

In [None]:
# Select the exemplar-simulation app (section [1]) for the cells below.
app_choice: int = 0

In [None]:
# Fetch the cell-type configuration (display names, states, colour map and
# marker-size map) from the project settings module; site_types and color_map
# are used by the exemplar-simulation app below.
from classes_and_functions.settings import get_cell_configurations
site_types, sites_states, color_map, markersize_map = get_cell_configurations()

In [None]:
path_to_combined_snapshots = "./files/selected_simulation_snapshots_at_40.csv"
combined_snapshots = pd.read_csv(path_to_combined_snapshots)

In [None]:
# Split each model_condition at its first two underscores: the first two
# tokens form the model type, everything after them is the seeding density.
# (Splitting with maxsplit=2 gives the same strings as a full split + re-join.)
combined_snapshots['model_type'] = combined_snapshots['model_condition'].map(
    lambda condition: '_'.join(condition.split('_', 2)[:2])
)
combined_snapshots['seeding_density'] = combined_snapshots['model_condition'].map(
    lambda condition: '_'.join(condition.split('_', 2)[2:])
)


In [None]:
# ===== plotly dash APP =====
# Exemplar simulation viewer: pick a model type and a seeding density, then
# show every cell's (x, y) position coloured by site type, one facet per pid.

if app_choice == 0:

    external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']
    app = Dash(__name__, external_stylesheets=external_stylesheets)

    app.layout = html.Div(
        [
            html.Div([
                html.Div([
                    # Dropdown - model type
                    html.H3('Select model type...'),
                    dcc.Dropdown(
                        sorted(combined_snapshots['model_type'].unique()),
                        combined_snapshots['model_type'].values[0],
                        placeholder="Select model type...",
                        id='dropdown-model-type'
                    )
                ], className="six columns", style={"width": "25%"}),
                html.Div([
                    # Dropdown - seeding density
                    html.H3('Select tumour seeding density...'),
                    dcc.Dropdown(
                        sorted(combined_snapshots['seeding_density'].unique()),
                        combined_snapshots['seeding_density'].values[0],
                        placeholder="Select seeding density...",
                        id='dropdown-seeding-density'
                    )
                ], className="six columns", style={"width": "25%"}),
            ], className="row"),

            # Scatter plot of cell positions
            html.Div([
                html.Div([
                    dcc.Graph(id='scatter1')
                ]),
            ])
        ]
    )

    # Bind the callback to this app instance rather than the process-wide
    # `dash.callback` registry. The notebook's other apps reuse the same
    # component ids ('scatter1', 'dropdown-model-type', ...), so globally
    # registered callbacks could otherwise end up attached to the wrong app.
    @app.callback(
        [
            Output('scatter1', 'figure'),
        ],
        [
            Input('dropdown-model-type', 'value'),
            Input('dropdown-seeding-density', 'value')
        ]
    )
    def update_plot(value1, value2):
        """Redraw the cell scatter plot for one model type / seeding density.

        Parameters
        ----------
        value1 : model type selected in 'dropdown-model-type'.
        value2 : seeding density selected in 'dropdown-seeding-density'.

        Returns
        -------
        A one-element list holding the Plotly figure: cells at (x, y),
        coloured by site-type name, one facet column per pid.
        """
        df_plot = combined_snapshots.loc[
            (combined_snapshots.model_type == value1)
            & (combined_snapshots.seeding_density == value2)
        ].copy()

        # ===== scatter plots =====
        # Translate site-type codes into the display names used by color_map.
        df_plot["site_type_name"] = df_plot["site_type"].map(
            lambda x: site_types[x]
        )

        scatter1 = px.scatter(
            data_frame=df_plot,
            x='x', y='y',
            color='site_type_name',
            facet_col='pid',
            color_discrete_map=color_map,
        )

        # customize the figure
        scatter1.update_layout(
            template='simple_white',
            width=1600, height=500
        )
        scatter1.update_traces(
            marker=dict(size=3)
        )
        scatter1.update_xaxes(title=dict(text="x", font_family="Arial", font_size=14))
        # Lock the y scale to x so spatial layouts are not distorted.
        scatter1.update_yaxes(
            title=dict(text="y", font_family="Arial", font_size=14),
            scaleanchor="x", scaleratio=1
        )

        return [scatter1]

    if __name__ == '__main__':
        app.run(debug=True, port=8051)

### [2] Results from replicate simulations

In [None]:
# Select which replicate-results app the cells below will launch:
#   1 -> APP1: choose model type + seeding density; distributions/dynamics over time
#   2 -> APP2: choose time point + seeding density; comparison across model types
app_choice = 2

In [None]:
path_to_combined_results = "./files/combined_results_tumour_sizes.csv"
combined_results = pd.read_csv(path_to_combined_results)

In [None]:
# Split each model_condition at its first two underscores: the first two
# tokens form the model type, everything after them is the seeding density.
# (Splitting with maxsplit=2 gives the same strings as a full split + re-join.)
combined_results['model_type'] = combined_results['model_condition'].map(
    lambda condition: '_'.join(condition.split('_', 2)[:2])
)
combined_results['seeding_density'] = combined_results['model_condition'].map(
    lambda condition: '_'.join(condition.split('_', 2)[2:])
)

In [None]:
# ===== plotly dash APP1 =====
# Replicate results for one model type and seeding density: tumour-size
# distributions per time point, plus per-replicate (pid) time courses.

if app_choice == 1:

    external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']
    app1 = Dash(__name__, external_stylesheets=external_stylesheets)

    app1.layout = html.Div(
        [
            html.Div([
                html.Div([
                    # Dropdown - model type
                    html.H3('Select model type...'),
                    dcc.Dropdown(
                        sorted(combined_results['model_type'].unique()),
                        combined_results['model_type'].values[0],
                        placeholder="Select model type...",
                        id='dropdown-model-type'
                    )
                ], className="six columns", style={"width": "25%"}),
                html.Div([
                    # Dropdown - seeding density
                    html.H3('Select tumour seeding density...'),
                    dcc.Dropdown(
                        sorted(combined_results['seeding_density'].unique()),
                        combined_results['seeding_density'].values[0],
                        placeholder="Select seeding density...",
                        id='dropdown-seeding-density'
                    )
                ], className="six columns", style={"width": "25%"}),
            ], className="row"),

            # Scatter plots
            html.Div([
                html.Div([
                    dcc.Graph(id='scatter1')
                ], className="six columns", style={"width": "25%"}),

                html.Div([
                    dcc.Graph(id='scatter2')
                ], className="six columns", style={"width": "25%"}),
            ], className="row"),

            # Histograms
            html.Div([
                html.Div([
                    dcc.Graph(id='histogram1')
                ], className="six columns", style={"width": "25%"}),

                html.Div([
                    dcc.Graph(id='histogram2')
                ], className="six columns", style={"width": "25%"}),
            ], className="row")
        ]
    )

    def _step_histogram_trace(values, name, bins=30):
        """Return a step-line trace of the probability density of `values`.

        np.histogram returns `bins + 1` edges but only `bins` densities, so the
        trace is placed at the bin centres; with the 'hvh' line shape each step
        is then drawn centred on its own bin.
        """
        density, edges = np.histogram(np.asarray(values), bins=bins, density=True)
        centres = 0.5 * (edges[:-1] + edges[1:])
        return go.Scatter(
            x=centres, y=density,
            line=dict(width=1, shape='hvh'),
            name=name,
        )

    def _style_figure(fig, x_title, y_title, width=500, height=400):
        """Apply the shared template, size and Arial axis titles to `fig`; return it."""
        fig.update_layout(template='simple_white', width=width, height=height)
        fig.update_xaxes(title=dict(text=x_title, font_family="Arial", font_size=14))
        fig.update_yaxes(title=dict(text=y_title, font_family="Arial", font_size=14))
        return fig

    # Bind the callback to this app instance rather than the process-wide
    # `dash.callback` registry: the notebook's other apps reuse the same
    # component ids, so global callbacks could be attached to the wrong app.
    @app1.callback(
        [
            Output('histogram1', 'figure'),
            Output('histogram2', 'figure'),
            Output('scatter1', 'figure'),
            Output('scatter2', 'figure'),
        ],
        [
            Input('dropdown-model-type', 'value'),
            Input('dropdown-seeding-density', 'value')
        ]
    )
    def update_plot(value1, value2):
        """Rebuild the four APP1 figures for one model type and seeding density.

        Parameters
        ----------
        value1 : model type selected in 'dropdown-model-type'.
        value2 : seeding density selected in 'dropdown-seeding-density'.

        Returns
        -------
        [histogram1, histogram2, scatter1, scatter2] Plotly figures:
            histogram1 - density of log10(tumour size), one curve per time point;
            histogram2 - same, with sizes divided by that time point's median;
            scatter1   - mean log10(tumour size) vs time, one curve per pid;
            scatter2   - number of tumours vs time, one curve per pid.
        """
        data_subset = combined_results.loc[
            (combined_results.model_type == value1)
            & (combined_results.seeding_density == value2)
        ]

        # ===== histograms of tumour areas (per DBSCAN cluster) =====
        # NOTE(review): np.log10 assumes every tumour size is > 0.
        histogram1_data = []
        histogram2_data = []
        for t in sorted(data_subset['time'].unique()):
            areas = data_subset.loc[data_subset['time'] == t, 'size'].to_numpy()
            # Scale by the median, matching the "scaled by median" axis label
            # and APP2 (this previously used the 30th percentile).
            areas_scaled = areas / np.median(areas)
            histogram1_data.append(_step_histogram_trace(np.log10(areas), f"t={t}"))
            histogram2_data.append(_step_histogram_trace(np.log10(areas_scaled), f"t={t}"))

        # ===== scatter plots: per-time summaries, one curve per replicate =====
        scatter1_data = []
        scatter2_data = []
        for pid in data_subset['pid'].unique():
            pid_rows = data_subset.loc[data_subset['pid'] == pid]
            # groupby sorts by time, so each curve is drawn in time order.
            summary = (
                np.log10(pid_rows['size'])
                .groupby(pid_rows['time'])
                .agg(['count', 'mean'])
            )
            times = summary.index.to_numpy()

            scatter1_data.append(
                go.Scatter(
                    x=times,
                    y=summary['mean'].to_numpy(),
                    name=f"pid={pid}",
                    mode="lines+markers",
                )
            )
            scatter2_data.append(
                go.Scatter(
                    x=times,
                    y=summary['count'].to_numpy(),
                    name=f"pid={pid}",
                    mode="lines+markers",
                )
            )

        # customize the figures
        histogram1 = _style_figure(
            go.Figure(data=histogram1_data),
            "Log10 (tumour size)", "Probability density",
        )
        histogram2 = _style_figure(
            go.Figure(data=histogram2_data),
            "Log10 (tumour size scaled by median)", "Probability density",
        )
        scatter1 = _style_figure(
            go.Figure(data=scatter1_data),
            "Time", "Log10 (tumour size) mean",
        )
        scatter2 = _style_figure(
            go.Figure(data=scatter2_data),
            "Time", "Number of tumours observed",
        )

        return [histogram1, histogram2, scatter1, scatter2]


    if __name__ == '__main__':
        app1.run(debug=True, port=8052)

In [None]:
# ===== plotly dash APP2 =====
# Replicate results at one time point and seeding density, compared across
# model types: tumour-size distributions and per-simulation mean/std.

if app_choice == 2:
    external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']
    app2 = Dash(__name__, external_stylesheets=external_stylesheets)

    app2.layout = html.Div(
        [
            html.Div([
                html.Div([
                    # Dropdown - time point
                    html.H3('Select time point...'),
                    dcc.Dropdown(
                        sorted(combined_results['time'].unique()),
                        combined_results['time'].values[0],
                        placeholder="Select time point...",
                        id='dropdown-time'
                    )
                ], className="six columns"),
                html.Div([
                    # Dropdown - seeding density
                    html.H3('Select tumour seeding density...'),
                    dcc.Dropdown(
                        sorted(combined_results['seeding_density'].unique()),
                        combined_results['seeding_density'].values[0],
                        placeholder="Select seeding density...",
                        id='dropdown-seeding-density'
                    )
                ], className="six columns"),
            ], className="row"),

            # Histograms
            html.Div([
                html.Div([
                    html.H3('Area'),
                    dcc.Graph(id='histogram1')
                ], className="six columns"),

                html.Div([
                    html.H3('Area scaled by median'),
                    dcc.Graph(id='histogram2')
                ], className="six columns"),
            ], className="row"),

            # Scatter plots
            html.Div([
                html.Div([
                    dcc.Graph(id='scatter1')
                ], className="six columns"),

                html.Div([
                    dcc.Graph(id='scatter2')
                ], className="six columns"),
            ], className="row")
        ]
    )

    def _step_histogram_trace(values, name, bins=30):
        """Return a step-line trace of the probability density of `values`.

        np.histogram returns `bins + 1` edges but only `bins` densities, so the
        trace is placed at the bin centres; with the 'hvh' line shape each step
        is then drawn centred on its own bin.
        """
        density, edges = np.histogram(np.asarray(values), bins=bins, density=True)
        centres = 0.5 * (edges[:-1] + edges[1:])
        return go.Scatter(
            x=centres, y=density,
            line=dict(width=1, shape='hvh'),
            name=name,
        )

    def _style_figure(fig, x_title, y_title, width=700, height=400):
        """Apply the shared template, size and Arial axis titles to `fig`; return it."""
        fig.update_layout(template='simple_white', width=width, height=height)
        fig.update_xaxes(title=dict(text=x_title, font_family="Arial", font_size=14))
        fig.update_yaxes(title=dict(text=y_title, font_family="Arial", font_size=14))
        return fig

    def _mean_std_trace(per_pid, column, name):
        """Scatter of per-pid mean (x) vs std (y) of `column`, marker size = count."""
        # NOTE(review): marker_size is the raw tumour count per simulation;
        # large counts give very large markers — consider sizeref if needed.
        return go.Scatter(
            x=per_pid[(column, 'mean')].to_numpy(),
            y=per_pid[(column, 'std')].to_numpy(),
            name=name,
            mode="markers",
            marker_size=per_pid[(column, 'count')].to_numpy(),
        )

    # Bind the callback to this app instance rather than the process-wide
    # `dash.callback` registry: the notebook's other apps reuse the same
    # component ids, so global callbacks could be attached to the wrong app.
    @app2.callback(
        [
            Output('histogram1', 'figure'),
            Output('histogram2', 'figure'),
            Output('scatter1', 'figure'),
            Output('scatter2', 'figure'),
        ],
        [
            Input('dropdown-time', 'value'),
            Input('dropdown-seeding-density', 'value')
        ]
    )
    def update_plot(value1, value2):
        """Rebuild the four APP2 figures for one time point and seeding density.

        Parameters
        ----------
        value1 : time point selected in 'dropdown-time'.
        value2 : seeding density selected in 'dropdown-seeding-density'.

        Returns
        -------
        [histogram1, histogram2, scatter1, scatter2] Plotly figures, one trace
        per model type:
            histogram1 - density of log10(tumour size);
            histogram2 - same, with sizes divided by the model type's median;
            scatter1   - per-pid mean vs std of log10(tumour size);
            scatter2   - per-pid mean vs std of the median-scaled log10 size.
        """
        data_subset = combined_results.loc[
            (combined_results['time'] == value1)
            & (combined_results.seeding_density == value2)
        ]

        histogram1_data = []
        histogram2_data = []
        scatter1_data = []
        scatter2_data = []

        for model_type in sorted(data_subset['model_type'].unique()):
            data_subsubset = data_subset.loc[data_subset['model_type'] == model_type].copy()

            # ===== histograms of tumour areas (per DBSCAN cluster) =====
            # NOTE(review): np.log10 assumes every tumour size is > 0.
            areas = data_subsubset['size'].to_numpy()
            data_subsubset['log10_areas'] = np.log10(areas)
            data_subsubset['log10_areas_scaled'] = np.log10(areas / np.median(areas))

            histogram1_data.append(
                _step_histogram_trace(data_subsubset['log10_areas'], model_type)
            )
            histogram2_data.append(
                _step_histogram_trace(data_subsubset['log10_areas_scaled'], model_type)
            )

            # ===== scatter plots of areas | mean, std (per simulation) =====
            per_pid = (
                data_subsubset
                .groupby('pid')[['log10_areas', 'log10_areas_scaled']]
                .agg(['mean', 'std', 'count'])
            )
            scatter1_data.append(_mean_std_trace(per_pid, 'log10_areas', model_type))
            scatter2_data.append(_mean_std_trace(per_pid, 'log10_areas_scaled', model_type))

        # customize the figures
        histogram1 = _style_figure(
            go.Figure(data=histogram1_data),
            "Log10 (tumour size)", "Probability density",
        )
        histogram2 = _style_figure(
            go.Figure(data=histogram2_data),
            "Log10 (tumour size scaled by median)", "Probability density",
        )
        scatter1 = _style_figure(
            go.Figure(data=scatter1_data),
            "Log10 (tumour size) mean", "Log10 (tumour size) std",
        )
        scatter2 = _style_figure(
            go.Figure(data=scatter2_data),
            "Log10 (tumour size scaled by median) mean",
            "Log10 (tumour size scaled by median) std",
        )

        return [histogram1, histogram2, scatter1, scatter2]

    if __name__ == '__main__':
        app2.run(debug=True, port=8053)