# In [1]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from dash import Dash, dash_table, dcc, html, Input, Output
from jupyter_dash import JupyterDash
import pycountry_convert as pc


# Load the shipment data and the country -> region lookup table.
data_url = 'https://query.data.world/s/ubyf7tmt7vkvimvz3u5kehrkx2rpwf?dws=00000'
df1 = pd.read_csv(data_url, encoding='latin-1')
df2 = pd.read_csv('./continents.csv')

# Attach a "region" to every shipment by matching df1's Country against the
# "name" column of continents.csv. This is an inner join: shipments whose
# Country has no exact match in continents.csv are dropped. After the join the
# "name" key column just duplicates "Country", so it is removed (drop() returns
# a new frame; the old `inplace=True` attempt returned None).
df = pd.merge(
    df1, df2[["name", "region"]], left_on='Country', right_on='name'
).drop(columns='name')

# Create the Dash app
app = JupyterDash(__name__)


# Layout
# Dash passes `style` dicts to React, which expects camelCased CSS property
# names (textAlign, not text-align).
_CENTERED = {'textAlign': 'center'}

_KEY_LINES = (
    "Count = Number of Shipments",
    "ACT = Artemisinin-Based Combination Therapy",
    "ANTM = Anti-Malarial Medicine",
    "ARV = Anti-Retroviral Treatment",
    "HRDT = HIV Rapid Diagnostic Test",
    "MRDT = Malarial Rapid Diagnostic Test",
)

app.layout = html.Div([

    html.H1("Supply Chain Health Commodity Shipment Data", style=_CENTERED),
    html.Br(),

    html.H2("Number of Shipments by Country", style=_CENTERED),

    # Multi-select country picker (alphabetical); starts with nothing selected.
    dcc.Dropdown(
        id="slct_country",
        options=[{"label": country, "value": country}
                 for country in sorted(df['Country'].unique())],
        style={'color': '#111111'},
        value=[],
        multi=True,
    ),

    html.Br(),
    html.Div(id='output_container', children=[]),
    html.Br(),

    html.Div([
        dcc.Graph(id='map', figure={})
    ]),

    html.Div([
        html.H2("Shipment Details by Region", style=_CENTERED),
        html.H4("Level of Detail:", style=_CENTERED),
        html.P("Global View  |  Region  |  Country  |  Product Group  |  Product Sub Classification", style=_CENTERED),
        html.P("(For example: Global View  |  Americas  |  Guyana  |  ARV  |  Adult, Pediatric )", style=_CENTERED),
        dcc.Graph(id='sunburst', figure={}),
        html.H4("Key:", style=_CENTERED),
        # Abbreviation legend for the sunburst labels.
        *[html.P(line, style=_CENTERED) for line in _KEY_LINES],
    ])

])

def _add_country_stats(frame):
    """Return a copy of *frame* with normalised column names and per-country stats.

    Column names that are strings have spaces replaced by underscores. Then the
    following columns are added, each holding the same value on every row that
    shares a Country:

    * ``Shipment_count`` -- number of rows (shipments) for that Country.
    * ``Perc_sc`` -- that count as a percentage of all rows.
    * ``Product_groups`` -- distinct ``Product_Group`` values for that Country.
    * ``Sub_class`` -- distinct ``Sub_Classification`` values for that Country.

    *frame* itself is not modified.
    """
    stats = frame.rename(
        columns=lambda c: c.replace(' ', '_') if isinstance(c, str) else c
    )

    country = stats['Country']
    # One value_counts pass instead of a full-column comparison per country.
    stats['Shipment_count'] = country.map(country.value_counts())

    # Each country's share of all shipments. Dividing by the row count (not by
    # Shipment_count.sum(), which is the sum of squared per-country counts
    # because the count repeats on every row) gives a true percentage.
    stats['Perc_sc'] = stats['Shipment_count'] / len(stats) * 100

    # Broadcast each country's distinct product groups / sub-classifications
    # back onto all of its rows.
    for source, target in (('Product_Group', 'Product_groups'),
                           ('Sub_Classification', 'Sub_class')):
        distinct = dict(stats.groupby('Country')[source].unique())
        stats[target] = [distinct[c] for c in stats['Country']]

    return stats


def _build_map(filtered_data):
    """Build the world choropleth of shipment counts for the given rows."""
    fig = px.choropleth(
        data_frame=filtered_data,
        locationmode='country names',
        locations='Country',
        scope="world",
        color='Shipment_count',
        hover_data=['Country', 'Shipment_count', 'Product_groups'],
        color_continuous_scale=px.colors.sequential.Jet,
        template='plotly_dark',
        height=800,
        width=1450
    )

    # Overlay country names as text labels; hovering is left to the choropleth.
    fig.add_scattergeo(
        locations=filtered_data['Country'],
        locationmode="country names",
        text=filtered_data['Country'],
        mode='text',
        hoverinfo='skip',
        textposition="top right",
        textfont=dict(
            size=10,
            color="LightGray"
        )
    )

    fig.update_layout(
        coloraxis_colorbar=dict(
            title='Shipment Count'
        ))
    return fig


def _build_sunburst(stats):
    """Build the region > country > product group > sub-class sunburst.

    Uses every row of *stats* (not the dropdown selection) and shows at most
    two levels at a time (``maxdepth=2``).
    """
    fig = px.sunburst(
        stats,
        path=['region', 'Country', 'Product_Group', 'Sub_Classification'],
        template='plotly_dark',
        color='region',
        height=700)

    fig.update_traces(
        maxdepth=2,
        selector=dict(type='sunburst'),
        domain=dict(column=5)
    )

    # NOTE(review): 'region' looks categorical, in which case plotly colours it
    # with a discrete sequence and this colorbar title is never shown — confirm.
    fig.update_layout(
        coloraxis_colorbar=dict(
            title='Continent'
        ))
    return fig


@app.callback(
    [Output(component_id='output_container', component_property='children'),
     Output(component_id='map', component_property='figure'),
     Output(component_id='sunburst', component_property='figure')],
    [Input(component_id='slct_country', component_property='value')]
)
def update_graph(option_slctd):
    """Refresh the selection text, the country map and the region sunburst.

    Parameters
    ----------
    option_slctd : list of str or None
        Countries chosen in the ``slct_country`` dropdown. A ``None`` value is
        treated as an empty selection so ``Series.isin`` does not raise.

    Returns
    -------
    tuple
        ``(status_text, map_figure, sunburst_figure)``, matching the three
        ``Output``s registered on this callback.
    """
    selected = option_slctd or []
    container = "Selected countries: {}".format(selected)

    stats = _add_country_stats(df)
    filtered_data = stats[stats['Country'].isin(selected)]

    return container, _build_map(filtered_data), _build_sunburst(stats)

# Start the local development server for the dashboard on port 8053,
# with Dash debug mode enabled.
if __name__ == '__main__':
    app.run_server(
        host='localhost',
        port=8053,
        debug=True,
    )

# Output: Dash app running on http://localhost:8053/
