In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import dash
from dash import dcc, html
from dash.dependencies import Input, Output
import time

# ======================
# DATA LOADING & PREP
# ======================

print("Loading dataset...")
df = pd.read_csv("FoodDataset.csv")
print("Successfully loaded Oxford dataset")

# Standardize column names: map short source names onto the canonical names
# used by the rest of the script. Only pairs whose source column is actually
# present are kept, so the mapping is a no-op for already-canonical files.
column_mapping = {
    short: canonical
    for short, canonical in (
        ('ghg', 'mean_ghgs'),
        ('ch4', 'mean_ghgs_ch4'),
        ('n2o', 'mean_ghgs_n2o'),
        ('land', 'mean_land'),
        ('bio', 'mean_bio'),
        ('watscar', 'mean_watscar'),
        ('watuse', 'mean_watuse'),
        ('eut', 'mean_eut'),
        ('acid', 'mean_acid'),
        ('n', 'n_participants'),
        ('sex', 'sex'),
        ('age', 'age_group'),
        ('diet_group', 'diet_group'),
    )
    if short in df.columns
}
df = df.rename(columns=column_mapping)

# Clean and standardize the categorical text columns: lower-case, trimmed.
for text_column in ('diet_group', 'sex'):
    df[text_column] = df[text_column].str.lower().str.strip()

# ======================
# DATA VALIDATION
# ======================

# Print the schema and the distinct categorical values for a quick visual check.
print("\n=== DATA VALIDATION ===")
print("Columns in dataset:", df.columns.tolist())
for heading, column in (
    ("\nUnique diet groups:", 'diet_group'),
    ("Unique sex values:", 'sex'),
    ("Unique age groups:", 'age_group'),
):
    print(heading, df[column].unique())

# ======================
# DATA PROCESSING
# ======================

# Impact categories: each category lists the dataset metric columns it aggregates.
# This table is the single source of truth for the metric -> category relationship.
impact_categories = {
    'Climate_Impact': ['mean_ghgs', 'mean_ghgs_ch4', 'mean_ghgs_n2o'],
    'Land_Impact': ['mean_land', 'mean_bio'],
    'Water_Impact': ['mean_watscar', 'mean_watuse'],
    'Chemical_Impact': ['mean_eut', 'mean_acid']
}

# Reverse lookup metric -> category. Derived from impact_categories (rather than
# written out by hand) so the two tables cannot drift apart when a metric is
# added, removed or moved between categories. Order follows impact_categories.
metric_to_category = {
    metric: category
    for category, metrics in impact_categories.items()
    for metric in metrics
}

# Friendly names for display (used as Sankey node labels)
category_friendly_names = {
    'Climate_Impact': 'Climate Impact',
    'Land_Impact': 'Land & Biodiversity Impact',
    'Water_Impact': 'Water Resources Impact',
    'Chemical_Impact': 'Chemical Pollution Impact'
}

# Enhanced diet categorization
def categorize_diet(diet):
    """Map a raw diet-group value to a display category.

    The value is stringified, lower-cased and trimmed, then tested against an
    ordered list of substring rules; the first matching rule wins. A bare
    'meat' (no quantity marker) is treated as medium meat. Values matching no
    rule are reported on stdout and mapped to 'Other'.
    """
    text = str(diet).lower().strip()

    # Order matters: e.g. 'meat50-99' must be tested before the generic 'meat50'.
    ordered_rules = (
        ('Vegan', lambda s: 'vegan' in s),
        ('Vegetarian', lambda s: 'veggie' in s or 'vegetarian' in s),
        ('Pescatarian', lambda s: 'fish' in s or 'pescatarian' in s),
        ('High Meat (>100g/day)', lambda s: 'meat100' in s or '>100' in s),
        ('Medium Meat (50-99g/day)', lambda s: 'meat50-99' in s or ('50' in s and '99' in s)),
        ('Low Meat (<50g/day)', lambda s: 'meat50' in s or '<50' in s),
        # Catch-all for generic meat entries: default assumption is medium meat
        ('Medium Meat (50-99g/day)', lambda s: 'meat' in s),
    )
    for label, matches in ordered_rules:
        if matches(text):
            return label

    print(f"Unknown diet group: {text}")
    return 'Other'

print("Categorizing diets...")
df['Diet_Category'] = df['diet_group'].apply(categorize_diet)

# Title-case the sex labels; any value other than male/female becomes NaN.
sex_display = {'male': 'Male', 'female': 'Female'}
df['Sex'] = df['sex'].str.lower().str.strip().map(sex_display)
# Age groups are already usable as labels and are copied unchanged.
df['Age_Category'] = df['age_group']

# Min-max normalise every available metric to [0, 1] (stored as '<metric>_norm'),
# then score each impact category as the row-wise mean of its normalised metrics.
print("Calculating impact scores...")
for category, metrics in impact_categories.items():
    available_metrics = [m for m in metrics if m in df.columns]
    if not available_metrics:
        continue
    for metric in available_metrics:
        lowest, highest = df[metric].min(), df[metric].max()
        # A constant column has no range to scale by, so it normalises to 0.
        df[f'{metric}_norm'] = (df[metric] - lowest) / (highest - lowest) if highest > lowest else 0
    df[category] = df[[f'{m}_norm' for m in available_metrics]].mean(axis=1)

# ======================
# IMPROVED COMPREHENSIVE SANKEY DIAGRAM WITH ENHANCED COLOR SCHEME
# ======================

def _with_alpha(rgba, alpha):
    """Return an ``rgba(r, g, b, a)`` colour string with its alpha replaced by *alpha*.

    Strings that are not a four-component ``rgba(...)`` value are returned unchanged.
    """
    text = str(rgba).strip()
    if text.startswith('rgba(') and text.endswith(')'):
        channels = [part.strip() for part in text[len('rgba('):-1].split(',')]
        if len(channels) == 4:
            return f"rgba({channels[0]}, {channels[1]}, {channels[2]}, {alpha})"
    return text


def _sorted_categories(series):
    """Return the distinct non-missing values of *series*, sorted.

    NaN is dropped because ``sorted`` cannot compare a float NaN with strings
    (TypeError), and it could not serve as a Sankey node label anyway.
    """
    return sorted(series.dropna().unique())


def _collapse_replicate_rows(df, metrics):
    """Reduce *df* to one row per demographic group before participants are counted.

    When an ``mc_run_id`` column is present, each (diet, age, sex) group presumably
    appears once per Monte Carlo run with ``n_participants`` repeated on every copy
    (NOTE(review): confirm against the dataset documentation). Summing
    ``n_participants`` over such rows would multiply every count by the number of
    runs, so replicate rows are averaged instead. A group that appears only once
    averages to itself. Without ``mc_run_id`` the needed columns are returned as-is.
    """
    display_keys = ['Diet_Category', 'Age_Category', 'Sex']
    value_cols = ['n_participants', *metrics]
    if 'mc_run_id' not in df.columns:
        return df[display_keys + value_cols]
    # Group on the raw labels too, so distinct source groups sharing a display
    # category (e.g. 'meat' and 'meat50-99') are not merged before counting.
    raw_keys = [col for col in ('diet_group', 'age_group', 'sex') if col in df.columns]
    # dropna=False keeps rows with missing keys; the per-layer groupbys drop them later.
    return (
        df.groupby(display_keys + raw_keys, dropna=False)[value_cols]
        .mean()
        .reset_index()
    )


def _participant_weighted_mean(values, counts):
    """Mean of *values* weighted by *counts*, ignoring rows where either is missing.

    Falls back to the plain mean when the usable counts sum to zero; returns NaN
    when there is nothing to average.
    """
    usable = values.notna() & counts.notna()
    kept_values, kept_counts = values[usable], counts[usable]
    total = kept_counts.sum()
    if total > 0:
        return float((kept_values * kept_counts).sum() / total)
    return float(kept_values.mean())


def create_comprehensive_sankey(df, weights=None, min_link_participants=5):
    """
    Create a comprehensive Sankey diagram with the flow
    Diet → Age → Gender → Individual Metrics → Impact Categories.

    Link widths:
      * Diet → Age and Age → Gender carry participant counts. When the frame has an
        ``mc_run_id`` column, Monte Carlo replicate rows are collapsed first so that
        counts are not multiplied by the number of runs.
      * Gender → Metric carries ``participants × normalised metric × category weight``.
        The participant-weighted average of each metric for that gender is min-max
        normalised over ``df``, using the same formula as the ``*_norm`` columns.
        This makes metrics measured in different units comparable.
      * Metric → Category forwards exactly the flow the metric node receives, so
        metric nodes conserve flow and the weights scale both metric layers.

    Args:
        df: DataFrame with 'Diet_Category', 'Age_Category', 'Sex',
            'n_participants' and any metric columns listed in ``impact_categories``.
        weights: Optional dict of category weights. Categories it omits fall back to
            the defaults (Climate 0.4, Land 0.3, Water 0.2, Chemical 0.1).
        min_link_participants: Diet → Age links with fewer participants than this
            are skipped to reduce clutter.

    Returns:
        plotly.graph_objects.Figure containing the Sankey trace.
    """
    print("Creating comprehensive Sankey diagram with enhanced color scheme...")
    start_time = time.time()

    # Unspecified categories keep their default weight, so partial dicts are safe.
    default_weights = {
        'Climate_Impact': 0.4,
        'Land_Impact': 0.3,
        'Water_Impact': 0.2,
        'Chemical_Impact': 0.1
    }
    weights = {**default_weights, **(weights or {})}

    neutral_color = 'rgba(150, 150, 150, 0.85)'
    link_alpha = 0.65       # links are drawn more transparent than their source node
    min_link_value = 0.1    # floor so tiny metric flows remain visible and hoverable

    # Readable metric labels (also the metric node names)
    metric_labels = {
        'mean_ghgs': 'Greenhouse Gas Emissions',
        'mean_ghgs_ch4': 'Methane (CH4) Emissions',
        'mean_ghgs_n2o': 'Nitrous Oxide (N2O) Emissions',
        'mean_land': 'Land Use',
        'mean_bio': 'Biodiversity Loss',
        'mean_watscar': 'Water Scarcity',
        'mean_watuse': 'Water Usage',
        'mean_eut': 'Eutrophication',
        'mean_acid': 'Acidification'
    }

    # 1. Metrics present in df, paired with their impact category, in
    #    impact_categories order; categories with no present metric are omitted.
    metric_pairs = [
        (metric, category)
        for category, metrics in impact_categories.items()
        for metric in metrics
        if metric in df.columns
    ]
    all_metrics = [metric for metric, _ in metric_pairs]
    impact_category_nodes = list(dict.fromkeys(category for _, category in metric_pairs))

    # 2. One row per demographic group (replicate rows collapsed, see helper)
    groups = _collapse_replicate_rows(df, all_metrics)

    # 3. Nodes in flow order: Diet → Age → Gender → Individual Metrics → Impact Categories
    diet_nodes = _sorted_categories(groups['Diet_Category'])
    age_nodes = _sorted_categories(groups['Age_Category'])
    gender_nodes = _sorted_categories(groups['Sex'])
    all_nodes = (
        diet_nodes +
        age_nodes +
        gender_nodes +
        [metric_labels[m] for m in all_metrics] +
        [category_friendly_names[cat] for cat in impact_category_nodes]
    )

    print(f"Total nodes: {len(all_nodes)}")

    node_indices = {node: i for i, node in enumerate(all_nodes)}

    # 4. Colour palettes
    diet_colors = {
        'Vegan': 'rgba(39, 174, 96, 0.85)',          # Green
        'Vegetarian': 'rgba(142, 202, 100, 0.85)',    # Light green
        'Pescatarian': 'rgba(52, 152, 219, 0.85)',    # Blue
        'Low Meat (<50g/day)': 'rgba(241, 196, 15, 0.85)',  # Yellow
        'Medium Meat (50-99g/day)': 'rgba(230, 126, 34, 0.85)', # Orange
        'High Meat (>100g/day)': 'rgba(231, 76, 60, 0.85)',  # Red
        'Other': 'rgba(189, 195, 199, 0.85)'          # Gray
    }

    # Colour gradient for age groups (from younger to older)
    age_colors = {
        '20-29': 'rgba(108, 92, 231, 0.85)',    # Violet
        '30-39': 'rgba(162, 155, 254, 0.85)',   # Light violet
        '40-49': 'rgba(116, 185, 255, 0.85)',   # Light blue
        '50-59': 'rgba(30, 144, 255, 0.85)',    # Dodger blue
        '60-69': 'rgba(25, 118, 210, 0.85)',    # Medium blue
        '70-79': 'rgba(21, 101, 192, 0.85)'     # Dark blue
    }

    gender_colors = {
        'Male': 'rgba(91, 192, 222, 0.85)',    # Cyan
        'Female': 'rgba(240, 98, 146, 0.85)'   # Pink
    }

    metric_base_colors = {
        # Climate metrics - red variants
        'mean_ghgs': 'rgba(229, 57, 53, 0.85)',       # Red
        'mean_ghgs_ch4': 'rgba(244, 81, 30, 0.85)',   # Orange-red
        'mean_ghgs_n2o': 'rgba(230, 74, 25, 0.75)',   # Lighter orange-red

        # Land metrics - green variants
        'mean_land': 'rgba(56, 142, 60, 0.85)',       # Green
        'mean_bio': 'rgba(104, 159, 56, 0.85)',       # Light green

        # Water metrics - blue variants
        'mean_watscar': 'rgba(3, 169, 244, 0.85)',    # Light blue
        'mean_watuse': 'rgba(21, 101, 192, 0.85)',    # Dark blue

        # Chemical metrics - purple variants
        'mean_eut': 'rgba(156, 39, 176, 0.85)',       # Purple
        'mean_acid': 'rgba(123, 31, 162, 0.85)'       # Dark purple
    }

    # Bold colours for impact categories (final nodes), keyed by display name
    category_colors = {
        'Climate Impact': 'rgba(229, 57, 53, 0.9)',              # Bold red
        'Land & Biodiversity Impact': 'rgba(56, 142, 60, 0.9)',  # Bold green
        'Water Resources Impact': 'rgba(3, 169, 244, 0.9)',      # Bold blue
        'Chemical Pollution Impact': 'rgba(156, 39, 176, 0.9)'   # Bold purple
    }

    # 5. Node colours, built layer by layer in the same order as all_nodes
    node_colors = (
        [diet_colors.get(node, neutral_color) for node in diet_nodes] +
        [age_colors.get(node, neutral_color) for node in age_nodes] +
        [gender_colors.get(node, neutral_color) for node in gender_nodes] +
        [metric_base_colors.get(m, neutral_color) for m in all_metrics] +
        [category_colors.get(category_friendly_names[cat], neutral_color)
         for cat in impact_category_nodes]
    )

    # 6. Links
    links = {
        'source': [],
        'target': [],
        'value': [],
        'color': [],
        'label': []
    }

    def add_link(source_node, target_node, value, base_color, label):
        """Append one link; its colour is the source colour at link_alpha."""
        links['source'].append(node_indices[source_node])
        links['target'].append(node_indices[target_node])
        links['value'].append(value)
        links['color'].append(_with_alpha(base_color, link_alpha))
        links['label'].append(label)

    # Diet -> Age links (participant counts; groupby drops missing keys)
    print("Processing Diet -> Age links...")
    diet_age = groups.groupby(['Diet_Category', 'Age_Category'])['n_participants'].sum()
    for (diet, age), count in diet_age.items():
        if count < min_link_participants:
            continue  # skip very small links to reduce complexity
        add_link(diet, age, count, diet_colors.get(diet, neutral_color),
                 f"{diet} → {age}: {count:,.0f} participants")

    # Age -> Gender links (participant counts)
    print("Processing Age -> Gender links...")
    age_gender = groups.groupby(['Age_Category', 'Sex'])['n_participants'].sum()
    for (age, gender), count in age_gender.items():
        add_link(age, gender, count, age_colors.get(age, neutral_color),
                 f"{age} → {gender}: {count:,.0f} participants")

    # Gender -> Individual Metrics links (participants × normalised metric × weight)
    print("Processing Gender -> Individual Metrics links...")
    # Normalisation range per metric over the input rows, hoisted out of the loops
    metric_ranges = {m: (df[m].min(), df[m].max()) for m in all_metrics}
    metric_inflow = dict.fromkeys(all_metrics, 0.0)
    for gender in gender_nodes:
        rows = groups[groups['Sex'] == gender]
        participants = rows['n_participants'].sum()
        for metric, category in metric_pairs:
            average = _participant_weighted_mean(rows[metric], rows['n_participants'])
            if pd.isna(average):
                continue  # no usable values for this gender/metric
            low, high = metric_ranges[metric]
            normalized = (average - low) / (high - low) if high > low else 0.0
            value = max(participants * normalized * weights[category], min_link_value)
            metric_inflow[metric] += value
            metric_node = metric_labels[metric]
            add_link(gender, metric_node, value, gender_colors.get(gender, neutral_color),
                     f"{gender} → {metric_node}: {average:.2f} avg. impact")

    # Individual Metrics -> Impact Categories links: forward each metric's inflow
    print("Processing Individual Metrics -> Impact Categories links...")
    for metric, category in metric_pairs:
        inflow = metric_inflow[metric]
        if inflow <= 0:
            continue  # metric received no flow (no data), nothing to forward
        metric_node = metric_labels[metric]
        category_node = category_friendly_names[category]
        add_link(metric_node, category_node, inflow,
                 metric_base_colors.get(metric, neutral_color),
                 f"{metric_node} → {category_node}: contributes to {category_node}")

    # 7. Build the Sankey figure
    fig = go.Figure(go.Sankey(
        node=dict(
            pad=20,
            thickness=25,
            line=dict(color="rgba(50, 50, 50, 0.3)", width=0.3),
            label=all_nodes,
            color=node_colors,
            hovertemplate='<b>%{label}</b><extra></extra>'
        ),
        link=dict(
            source=links['source'],
            target=links['target'],
            value=links['value'],
            color=links['color'],
            hovertemplate='<b>%{label}</b><extra></extra>',
            label=links['label']
        )
    ))

    # 8. Title and layout
    fig.update_layout(
        title=dict(
            text="Comprehensive Diet-Demographics-Environmental Impact Flow",
            font=dict(size=24, color="#424242", family="Arial, sans-serif"),
            y=0.98
        ),
        font=dict(
            family="Arial, sans-serif",
            size=14,
            color="#424242"
        ),
        height=900,
        width=1600,
        margin=dict(l=50, r=50, t=100, b=50),
        paper_bgcolor='rgba(250, 250, 250, 1)',  # Light gray background
        plot_bgcolor='rgba(250, 250, 250, 1)'    # Light gray background
    )

    # Subtle watermark-style annotation above the plot area
    fig.add_annotation(
        text="Dietary Impact Analysis",
        x=0.5,
        y=1.05,
        xref="paper",
        yref="paper",
        showarrow=False,
        font=dict(
            family="Arial, sans-serif",
            size=14,
            color="rgba(150, 150, 150, 0.6)"
        )
    )

    print(f"Enhanced Sankey diagram created in {time.time() - start_time:.2f} seconds")
    return fig

# ======================
# INTERACTIVE DASHBOARD WITH WEIGHT CONTROLS
# ======================

def create_interactive_dashboard():
    """
    Create the interactive Dash dashboard for the comprehensive Sankey diagram.

    The layout contains one weight slider per impact category, a "Reset Weights"
    button, the Sankey graph and a help panel. Two callbacks are registered:

    * slider values -> ``create_comprehensive_sankey(df, weights)`` -> graph figure;
    * reset button -> default weights written back into the sliders, which in turn
      triggers the graph callback.

    Returns:
        dash.Dash: the configured application (call ``app.run()`` to serve it).
    """
    app = dash.Dash(__name__)

    # (slider id, label, label colour, default weight). The defaults match the
    # fallback weights of create_comprehensive_sankey.
    slider_specs = [
        ('climate-weight', "Climate Impact Weight:", '#C0392B', 0.4),
        ('land-weight', "Land & Biodiversity Impact Weight:", '#27AE60', 0.3),
        ('water-weight', "Water Resources Impact Weight:", '#3498DB', 0.2),
        ('chemical-weight', "Chemical Pollution Impact Weight:", '#9B59B6', 0.1),
    ]
    slider_ids = [slider_id for slider_id, _, _, _ in slider_specs]
    default_values = [default for _, _, _, default in slider_specs]

    def weight_slider(slider_id, label, color, default):
        """Return one labelled 0.1-1.0 weight slider wrapped in a Div."""
        return html.Div([
            html.Label(label, style={'fontWeight': 'bold', 'color': color}),
            dcc.Slider(
                id=slider_id,
                min=0.1,
                max=1.0,
                step=0.1,
                value=default,
                marks={i/10: str(i/10) for i in range(1, 11)},
                className='custom-slider'
            )
        ], style={'marginBottom': 20})

    app.layout = html.Div([
        html.H1("Diet-Demographics-Environmental Impact Dashboard",
                style={
                    'textAlign': 'center',
                    'fontFamily': 'Arial, sans-serif',
                    'color': '#2C3E50',
                    'marginBottom': '30px',
                    'marginTop': '20px'
                }),

        html.Div(
            [html.H3("Adjust Impact Category Weights", style={'color': '#34495E'})]
            + [weight_slider(*spec) for spec in slider_specs]
            + [html.Button('Reset Weights',
                           id='reset-button',
                           n_clicks=0,
                           style={
                               'backgroundColor': '#34495E',
                               'color': 'white',
                               'border': 'none',
                               'padding': '10px 20px',
                               'borderRadius': '5px',
                               'cursor': 'pointer',
                               'fontSize': '16px'
                           })],
            style={
                'padding': '25px',
                'backgroundColor': '#f8f9fa',
                'borderRadius': '10px',
                'boxShadow': '0 4px 8px rgba(0,0,0,0.1)',
                'marginBottom': '30px'
            }
        ),

        # The main Sankey diagram
        dcc.Graph(id='sankey-graph'),

        # Helper text
        html.Div([
            html.H4("How to Read This Visualization", style={'color': '#34495E'}),
            html.P([
                "This Sankey diagram shows the flow from diet types through demographics to environmental impacts. ",
                "The width of each connection represents its relative magnitude. ",
                "The diagram follows this path: Diet Categories → Age Groups → Gender → Individual Metrics → Impact Categories."
            ], style={'lineHeight': '1.6'}),
            html.P([
                "Use the sliders above to adjust the weights of different environmental impact categories ",
                "and see how they affect the environmental impact assessment."
            ], style={'lineHeight': '1.6'})
        ], style={
            'padding': '25px',
            'backgroundColor': '#f8f9fa',
            'borderRadius': '10px',
            'boxShadow': '0 4px 8px rgba(0,0,0,0.1)',
            'marginTop': '30px'
        }),

    ], style={
        'maxWidth': '1800px',
        'margin': '0 auto',
        'padding': '20px',
        'fontFamily': 'Arial, sans-serif'
    })

    @app.callback(
        Output('sankey-graph', 'figure'),
        [Input(slider_id, 'value') for slider_id in slider_ids]
    )
    def update_graph(climate_weight, land_weight, water_weight, chemical_weight):
        """Redraw the Sankey diagram from the current slider weights."""
        weights = {
            'Climate_Impact': climate_weight,
            'Land_Impact': land_weight,
            'Water_Impact': water_weight,
            'Chemical_Impact': chemical_weight
        }
        return create_comprehensive_sankey(df, weights)

    # Resetting works by pushing the defaults into the sliders; update_graph then
    # redraws from them. prevent_initial_call stops the page load from counting
    # as a click.
    @app.callback(
        [Output(slider_id, 'value') for slider_id in slider_ids],
        [Input('reset-button', 'n_clicks')],
        prevent_initial_call=True
    )
    def reset_weights(n_clicks):
        """Restore every slider to its default weight."""
        return list(default_values)

    return app

# ======================
# EXECUTION
# ======================

print("\n=== GENERATING COMPREHENSIVE VISUALIZATION ===")
try:
    # Static figure: render it and save a standalone HTML copy.
    print("Creating enhanced Sankey diagram with improved color scheme...")
    comp_sankey = create_comprehensive_sankey(df)
    comp_sankey.show()
    comp_sankey.write_html("enhanced_diet_demographic_impact_sankey.html")
    print("Enhanced Sankey diagram saved to HTML file")

    # Interactive dashboard with weight controllers.
    print("\nSetting up interactive dashboard...")
    app = create_interactive_dashboard()
    # app.run is the supported entry point; app.run_server is deprecated in
    # Dash 2 and removed in Dash 3, so the message must not point users at it.
    print("Dashboard ready! Starting it with app.run(debug=True)")

    # Starts the Dash development server. Comment this line out to build only
    # the static figure above.
    app.run(debug=True)

except Exception as e:
    # Top-level boundary for this script/cell: report the error together with
    # its full traceback.
    import traceback
    print(f"Error in visualization creation: {e}")
    traceback.print_exc()

Loading dataset...
Successfully loaded Oxford dataset

=== DATA VALIDATION ===
Columns in dataset: ['mc_run_id', 'grouping', 'mean_ghgs', 'mean_land', 'mean_watscar', 'mean_eut', 'mean_ghgs_ch4', 'mean_ghgs_n2o', 'mean_bio', 'mean_watuse', 'mean_acid', 'sd_ghgs', 'sd_land', 'sd_watscar', 'sd_eut', 'sd_ghgs_ch4', 'sd_ghgs_n2o', 'sd_bio', 'sd_watuse', 'sd_acid', 'n_participants', 'sex', 'diet_group', 'age_group']

Unique diet groups: ['fish' 'meat100' 'meat50' 'meat' 'vegan' 'veggie']
Unique sex values: ['female' 'male']
Unique age groups: ['20-29' '30-39' '40-49' '50-59' '60-69' '70-79']
Categorizing diets...
Calculating impact scores...

=== GENERATING COMPREHENSIVE VISUALIZATION ===
Creating enhanced Sankey diagram with improved color scheme...
Creating comprehensive Sankey diagram with enhanced color scheme...
Total nodes: 27
Processing Diet -> Age links...
Processing Age -> Gender links...
Processing Gender -> Individual Metrics links...
Processing Individual Metrics -> Impact Categ

Enhanced Sankey diagram saved to HTML file

Setting up interactive dashboard...
Dashboard ready! Run app.run_server(debug=True) to start the dashboard


Creating comprehensive Sankey diagram with enhanced color scheme...
Total nodes: 27
Processing Diet -> Age links...
Processing Age -> Gender links...
Processing Gender -> Individual Metrics links...
Processing Individual Metrics -> Impact Categories links...
Enhanced Sankey diagram created in 0.99 seconds
Creating comprehensive Sankey diagram with enhanced color scheme...
Total nodes: 27
Processing Diet -> Age links...
Processing Age -> Gender links...
Processing Gender -> Individual Metrics links...
Processing Individual Metrics -> Impact Categories links...
Enhanced Sankey diagram created in 0.46 seconds
Creating comprehensive Sankey diagram with enhanced color scheme...
Total nodes: 27
Processing Diet -> Age links...
Processing Age -> Gender links...
Processing Gender -> Individual Metrics links...
Processing Individual Metrics -> Impact Categories links...
Enhanced Sankey diagram created in 0.08 seconds
