In [7]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Function to create sample data if remote data loading fails
# def create_sample_data():
#     print("Creating sample data since remote data couldn't be loaded...")
#     n_samples = 100
#     data = {
#         'mc_run_id': range(1, n_samples+1),
#         'grouping': np.random.choice(['Group A', 'Group B', 'Group C'], n_samples),
#         'mean_ghgs': np.random.normal(10, 3, n_samples),
#         'mean_land': np.random.normal(5, 1.5, n_samples),
#         'mean_watscar': np.random.normal(8, 2, n_samples),
#         'mean_eut': np.random.normal(4, 1, n_samples),
#         'mean_ghgs_ch4': np.random.normal(3, 0.8, n_samples),
#         'mean_ghgs_n2o': np.random.normal(2, 0.5, n_samples),
#         'mean_bio': np.random.normal(6, 1.8, n_samples),
#         'mean_watuse': np.random.normal(7, 2.5, n_samples),
#         'mean_acid': np.random.normal(3, 1.2, n_samples),
#         'sd_ghgs': np.random.uniform(0.5, 1.5, n_samples),
#         'sd_land': np.random.uniform(0.3, 0.9, n_samples),
#         'sd_watscar': np.random.uniform(0.4, 1.2, n_samples),
#         'sd_eut': np.random.uniform(0.2, 0.6, n_samples),
#         'sd_ghgs_ch4': np.random.uniform(0.1, 0.4, n_samples),
#         'sd_ghgs_n2o': np.random.uniform(0.1, 0.3, n_samples),
#         'sd_bio': np.random.uniform(0.3, 0.9, n_samples),
#         'sd_watuse': np.random.uniform(0.5, 1.5, n_samples),
#         'sd_acid': np.random.uniform(0.2, 0.8, n_samples),
#         'n_participants': np.random.randint(10, 100, n_samples),
#         'sex': np.random.choice(['M', 'F'], n_samples),
#         'diet_group': np.random.choice(['vegan', 'vegetarian', 'fish', 'meat50-99', 'meat100'], n_samples),
#         'age_group': np.random.choice(['20-30', '30-40', '40-50', '50-60', '60+'], n_samples)
#     }
#     return pd.DataFrame(data)

# Load the dataset and validate that it has the expected shape.
url = "FoodDataset.csv"
try:
    df = pd.read_csv(url)
    if 'diet_group' not in df.columns or len(df.columns) < 10:
        raise ValueError("Downloaded data does not contain expected columns")
except Exception as e:
    print(f"Error loading remote dataset: {e}")
    # There is no fallback dataset (the sample-data generator is commented
    # out), so continuing would either hit a NameError on `df` or silently
    # reuse a stale/invalid frame left in the kernel. Stop here instead.
    raise
print("Successfully loaded remote dataset")

print("\nAvailable columns:", df.columns.tolist())
all_numeric_cols = ['mean_ghgs', 'mean_land', 'mean_watscar', 'mean_eut', 'mean_ghgs_ch4', 
                    'mean_ghgs_n2o', 'mean_bio', 'mean_acid', 'mean_watuse']
# Only the impact columns actually present in the file are analysed.
numeric_cols = [col for col in all_numeric_cols if col in df.columns]
print(f"\nUsing numeric columns: {numeric_cols}")

# Coerce each impact column to numeric; unparseable cells become NaN and are
# then replaced by that column's mean.
for col in numeric_cols:
    df[col] = pd.to_numeric(df[col], errors='coerce')
    df[col] = df[col].fillna(df[col].mean())

def categorize_diet(diet):
    """Map a raw diet-group code to a display category.

    Matching is case-insensitive and substring based; the first rule that
    matches wins. The rules agree with the second ``categorize_diet``
    defined later in this notebook: ``meat50`` / ``<50`` is the low-meat
    group and a bare ``meat`` code is the medium-meat group (previously both
    collapsed into "Low Meat", leaving "Medium Meat" unreachable for the
    short codes).

    Args:
        diet: Raw value from the diet column; converted with ``str()``.

    Returns:
        One of 'Vegan', 'Vegetarian', 'Pescatarian', 'High Meat (>100g/day)',
        'Medium Meat (50-99g/day)', 'Low Meat (<50g/day)' or 'Other'.
    """
    diet_str = str(diet).lower().strip()
    if 'vegan' in diet_str:
        return 'Vegan'
    if any(term in diet_str for term in ['vegetarian', 'veggie']):
        return 'Vegetarian'
    if any(term in diet_str for term in ['fish', 'pescatarian']):
        return 'Pescatarian'
    if any(term in diet_str for term in ['meat100', '>100']):
        return 'High Meat (>100g/day)'
    # 'meat50-99' must be tested before 'meat50', which it contains.
    if any(term in diet_str for term in ['meat50-99', '50-99']):
        return 'Medium Meat (50-99g/day)'
    if any(term in diet_str for term in ['meat50', '<50']):
        return 'Low Meat (<50g/day)'
    if 'meat' in diet_str:
        # Generic meat code with no quantity suffix: treated as medium meat.
        return 'Medium Meat (50-99g/day)'
    return 'Other'

# Derive the display diet category; fall back to the 'grouping' column, then
# to a constant placeholder, when 'diet_group' is missing.
try:
    df['Diet_Category'] = df['diet_group'].apply(categorize_diet)
except KeyError:
    print("Warning: 'diet_group' column not found. Using 'grouping' as fallback.")
    if 'grouping' in df.columns:
        df['Diet_Category'] = df['grouping'].astype(str).apply(categorize_diet)
    else:
        print("Creating placeholder Diet_Category")
        df['Diet_Category'] = 'Unknown'

# Bucket ages by the first number found in each age label (the lower bound of
# a label such as "30-40").
try:
    # Raw string avoids the invalid "\d" escape warning; expand=False yields a
    # Series rather than a one-column DataFrame.
    df['Age_Group'] = df['age_group'].str.extract(r'(\d+)', expand=False).astype(int)
    # right=False makes the bins [0,30), [30,40), ... so a lower bound of 30
    # is labelled '30-40' and 60 is labelled 'Over 60'. With the default
    # right-closed bins every band was shifted one label down.
    df['Age_Category'] = pd.cut(df['Age_Group'], 
                                bins=[0, 30, 40, 50, 60, 100],
                                right=False,
                                labels=['Under 30', '30-40', '40-50', '50-60', 'Over 60'])
except Exception as e:
    print(f"Warning when processing age groups: {e}")
    df['Age_Category'] = 'Unknown Age'
    
# NOTE(review): the placeholder below is random and unseeded, so any
# sex-based result derived from it is synthetic and not reproducible.
if 'sex' not in df.columns:
    print("Warning: 'sex' column not found. Creating placeholder.")
    df['sex'] = np.random.choice(['M', 'F'], len(df))

# Min-max scale every impact column into a sibling "<col>_norm" column.
# A column with no spread gets a constant 0.
for col in numeric_cols:
    col_norm = f"{col}_norm"
    low, high = df[col].min(), df[col].max()
    if high > low:
        df[col_norm] = (df[col] - low) / (high - low)
    else:
        df[col_norm] = 0

available_norm_cols = [f"{col}_norm" for col in numeric_cols if f"{col}_norm" in df.columns]

# Nominal weights of the composite score, in the order they are summed.
base_weights = {
    'mean_ghgs_norm': 0.25,
    'mean_land_norm': 0.20,
    'mean_watscar_norm': 0.15,
    'mean_watuse_norm': 0.15,
    'mean_eut_norm': 0.10,
    'mean_bio_norm': 0.10,
    'mean_acid_norm': 0.05,
}
# Keep only the weights whose normalised column exists, then rescale them so
# they sum to 1.
weights = {name: value for name, value in base_weights.items() if name in available_norm_cols}
total_weight = sum(weights.values())
weights = {name: value / total_weight for name, value in weights.items()}

# Weighted sum of the normalised columns (0 when no weighted column exists).
impact_score = 0
for name, value in weights.items():
    impact_score = impact_score + df[name] * value
df['Environmental_Impact_Score'] = impact_score

if 'n_participants' not in df.columns:
    print("Warning: 'n_participants' column not found. Creating placeholder.")
    df['n_participants'] = np.random.randint(10, 100, len(df))

# Aggregate to one row per (diet, age band, sex): total participants and mean
# impacts. observed=True matters because Age_Category is categorical when it
# comes from pd.cut: without it, groupby also emits combinations that never
# occur (zero participants, NaN means), which would turn into spurious links.
sankey_df = df.groupby(['Diet_Category', 'Age_Category', 'sex'], observed=True).agg({
    'n_participants': 'sum',
    'Environmental_Impact_Score': 'mean',
    **{col: 'mean' for col in numeric_cols}
}).reset_index()

diet_categories = sankey_df['Diet_Category'].unique()
age_categories = sankey_df['Age_Category'].unique()
genders = sankey_df['sex'].unique()

diet_color_map = {
    'Vegan': '#01665e',
    'Vegetarian': '#5ab4ac',
    'Pescatarian': '#80cdc1',
    'Low Meat (<50g/day)': '#dfc27d',
    'Medium Meat (50-99g/day)': '#bf812d',
    'High Meat (>100g/day)': '#8c510a',
    'Other': '#c7eae5',
    'Unknown': '#cccccc'
}

# Node list layout: diets, then age bands, then sexes, then impact nodes.
# The *_start_idx offsets record where each layer begins.
nodes = list(diet_categories)
node_colors = [diet_color_map.get(diet, '#c7eae5') for diet in diet_categories]

age_start_idx = len(nodes)
nodes.extend(age_categories)
node_colors.extend(["#c6dbef"] * len(age_categories))

gender_start_idx = len(nodes)
nodes.extend(genders)
node_colors.extend(["#9ecae1"] * len(genders))

impact_start_idx = len(nodes)
# (source column, node label) in display order; only present columns get a node.
impact_node_specs = [
    ('mean_land', "Land Use"),
    ('mean_ghgs', "GHG Emissions"),
    ('mean_watscar', "Water Scarcity"),
    ('mean_bio', "Biodiversity"),
    ('mean_watuse', "Water Usage"),
    ('mean_acid', "Acidification"),
    ('mean_eut', "Eutrophication"),
]
env_impacts = [label for col, label in impact_node_specs if col in numeric_cols]
nodes.extend(env_impacts)
node_colors.extend(["#2171b5"] * len(env_impacts))

# Label -> node index lookups, built once instead of list.index() per row.
diet_node_idx = {diet: pos for pos, diet in enumerate(diet_categories)}
age_node_idx = {age: age_start_idx + pos for pos, age in enumerate(age_categories)}
gender_node_idx = {gender: gender_start_idx + pos for pos, gender in enumerate(genders)}

links_source = []
links_target = []
links_value = []
links_color = []

# Layer 1: diet -> age band, sized by participants, coloured by diet.
for _, row in sankey_df.iterrows():
    links_source.append(diet_node_idx[row['Diet_Category']])
    links_target.append(age_node_idx[row['Age_Category']])
    links_value.append(row['n_participants'])
    links_color.append(diet_color_map.get(row['Diet_Category'], '#c7eae5'))

# Layer 2: age band -> sex, sized by participants.
for _, row in sankey_df.iterrows():
    links_source.append(age_node_idx[row['Age_Category']])
    links_target.append(gender_node_idx[row['sex']])
    links_value.append(row['n_participants'])
    links_color.append('#c6dbef')

def calculate_impact_intensity(value, min_val, max_val):
    """Scale ``value`` from the range [min_val, max_val] to an int in 0..255.

    Args:
        value: The measurement to scale.
        min_val: Lower end of the reference range.
        max_val: Upper end of the reference range.

    Returns:
        ``int(255 * (value - min_val) / (max_val - min_val))`` clamped to
        [0, 255]. Returns the neutral fallback 100 when any input is NaN,
        when the range is empty (``max_val == min_val``), or when the inputs
        are not usable numbers.
    """
    try:
        if np.isnan(value) or np.isnan(min_val) or np.isnan(max_val):
            return 100
        if max_val == min_val:
            return 100
        intensity = int(255 * (value - min_val) / (max_val - min_val))
        return max(0, min(255, intensity))
    except (TypeError, ValueError, ArithmeticError):
        # Non-numeric input (TypeError), int() of NaN (ValueError) or of an
        # infinity (OverflowError). A bare except would also have swallowed
        # KeyboardInterrupt and SystemExit.
        return 100

# Layer 3: sex -> impact node. One link per (group row, impact metric).
# Each spec is (source column, impact node label, divisor applied to
# mean * participants, rgba template with {i} standing for the 0-255 intensity).
impact_link_specs = [
    ('mean_land', "Land Use", 1000, 'rgba({i}, 50, 50, 0.6)'),
    ('mean_ghgs', "GHG Emissions", 1000, 'rgba(50, {i}, 50, 0.6)'),
    ('mean_watscar', "Water Scarcity", 50000, 'rgba(50, 50, {i}, 0.6)'),
    ('mean_bio', "Biodiversity", 5000, 'rgba({i}, {i}, 50, 0.6)'),
    ('mean_watuse', "Water Usage", 50000, 'rgba({i}, {i}, {i}, 0.6)'),
    ('mean_acid', "Acidification", 5000, 'rgba({i}, 50, {i}, 0.6)'),
    ('mean_eut', "Eutrophication", 5000, 'rgba(50, {i}, {i}, 0.6)'),
]

# Resolve the target node and the column-wide min/max once per metric; they
# do not change from row to row.
active_impact_links = [
    (col, impact_start_idx + env_impacts.index(label), divisor, color_template,
     df[col].min(), df[col].max())
    for col, label, divisor, color_template in impact_link_specs
    if col in numeric_cols and label in env_impacts
]

for i, row in sankey_df.iterrows():
    gender_idx = list(genders).index(row['sex']) + gender_start_idx
    for col, impact_idx, divisor, color_template, col_low, col_high in active_impact_links:
        links_source.append(gender_idx)
        links_target.append(impact_idx)
        scaled_value = row[col] * row['n_participants'] / divisor
        # Floor at 1 so that small flows remain visible in the diagram.
        links_value.append(max(1, scaled_value))
        impact_intensity = calculate_impact_intensity(row[col], col_low, col_high)
        links_color.append(color_template.format(i=impact_intensity))

# Assemble the Sankey trace from the node and link lists built above.
sankey_node_style = {
    'pad': 15,
    'thickness': 20,
    'line': {'color': "black", 'width': 0.5},
    'label': nodes,
    'color': node_colors,
}
sankey_link_style = {
    'source': links_source,
    'target': links_target,
    'value': links_value,
    'color': links_color,
}
fig = go.Figure(data=[go.Sankey(node=sankey_node_style, link=sankey_link_style)])

# Title, font and canvas size.
fig.update_layout(
    title_text="Diet Choices Flow to Environmental Impacts: A Sankey Diagram Analysis",
    font_size=12,
    height=800,
    width=1200,
    margin={'l': 25, 'r': 25, 't': 50, 'b': 25},
)

# Diets shown on the radar chart: the preferred four when present, otherwise
# the first four categories found in the data.
top_diets = ['Vegan', 'Vegetarian', 'Pescatarian', 'High Meat (>100g/day)']
top_diets = [candidate for candidate in top_diets if candidate in df['Diet_Category'].unique()]
if len(top_diets) == 0:
    top_diets = df['Diet_Category'].unique()[:4]

# Mean of every impact metric per selected diet.
radar_source = df[df['Diet_Category'].isin(top_diets)]
radar_df = (
    radar_source
    .groupby('Diet_Category')
    .agg({col: 'mean' for col in numeric_cols})
    .reset_index()
)

# Rescale each metric to 0..1 across the selected diets only; a metric with
# no spread becomes a constant 0.
metrics = numeric_cols
for col in metrics:
    col_min = radar_df[col].min()
    col_max = radar_df[col].max()
    if not col_max > col_min:
        radar_df[col] = 0
    else:
        radar_df[col] = (radar_df[col] - col_min) / (col_max - col_min)

# Single polar subplot that holds one closed trace per diet.
radar_fig = make_subplots(rows=1, cols=1, specs=[[dict(type='polar')]])

# Column name -> axis label shown around the radar.
metric_labels = dict(
    mean_ghgs='GHG Emissions',
    mean_land='Land Use',
    mean_watscar='Water Scarcity',
    mean_eut='Eutrophication',
    mean_bio='Biodiversity Loss',
    mean_acid='Acidification',
    mean_watuse='Water Usage',
    mean_ghgs_ch4='Methane Emissions',
    mean_ghgs_n2o='Nitrous Oxide',
)

for diet in top_diets:
    diet_data = radar_df[radar_df['Diet_Category'] == diet]
    if diet_data.empty:
        continue
    # Repeat the first point at the end so the polygon closes.
    values = diet_data[metrics].values.flatten().tolist()
    values.append(values[0])
    theta_labels = [metric_labels.get(metric, metric) for metric in metrics]
    theta_labels.append(theta_labels[0])
    diet_trace = go.Scatterpolar(
        r=values,
        theta=theta_labels,
        fill='toself',
        name=diet,
        line_color=diet_color_map.get(diet),
    )
    radar_fig.add_trace(diet_trace)

# Fixed 0..1 radial axis (the metrics were rescaled to that range) and a
# horizontal legend centred below the plot.
radar_fig.update_layout(
    polar={'radialaxis': {'visible': True, 'range': [0, 1]}},
    title="Environmental Impact Profile by Diet Type",
    showlegend=True,
    legend={
        'orientation': "h",
        'yanchor': "bottom",
        'y': -0.2,
        'xanchor': "center",
        'x': 0.5,
    },
    height=600,
    width=600,
)

# Render both figures inline, then export each one as a standalone HTML file.
for figure in (fig, radar_fig):
    figure.show()

for figure, html_path in ((fig, "sankey_diet_environment.html"),
                          (radar_fig, "radar_diet_environment.html")):
    figure.write_html(html_path)

def _extreme_diet_insights(diet_summary, numeric_cols):
    """Return one sentence per metric naming the lowest- and highest-impact diet."""
    found = []
    for col in numeric_cols:
        col_name = f"{col}_mean"
        if col_name not in diet_summary.columns:
            continue
        means = diet_summary[col_name]
        if means.isna().all():
            # Nothing to rank (also covers an empty summary).
            continue
        best_diet = diet_summary.loc[means.idxmin(), 'Diet_Category']
        worst_diet = diet_summary.loc[means.idxmax(), 'Diet_Category']
        lowest, highest = means.min(), means.max()
        metric = col.replace('mean_', '')
        if lowest > 0:
            diff_pct = (highest - lowest) / lowest * 100
            found.append(f"For {metric}, {best_diet} has the lowest impact while "
                         f"{worst_diet} has the highest impact (difference of {diff_pct:.1f}%).")
        else:
            # A percentage relative to a zero or negative minimum is meaningless.
            found.append(f"For {metric}, {best_diet} has the lowest impact while "
                         f"{worst_diet} has the highest impact.")
    return found


def _plant_based_rank_insights(diet_summary, numeric_cols):
    """Flag plant-based diets that are out-performed by other diets on a metric."""
    plant_based = [
        diet for diet in diet_summary['Diet_Category'].unique()
        if any(term in diet.lower() for term in ['vegan', 'vegetarian', 'plant'])
    ]
    found = []
    if not plant_based:
        return found
    for col in numeric_cols:
        col_name = f"{col}_mean"
        if col_name not in diet_summary.columns:
            continue
        # Diet names ordered from lowest to highest mean for this metric.
        ranked_diets = diet_summary.sort_values(by=col_name)['Diet_Category'].tolist()
        best_diet = ranked_diets[0]
        for diet in plant_based:
            if diet == best_diet:
                continue
            # Position in the sorted ranking. The pre-sort index label must not
            # be used here: it is not a rank and gave the wrong slice.
            position = ranked_diets.index(diet)
            better_diets = ranked_diets[:position]
            if better_diets:
                found.append(f"Unexpected finding: {diet} performs worse on {col.replace('mean_', '')} "
                             f"than {', '.join(better_diets)}.")
    return found


def _age_group_insights(df):
    """Return, per age band, the diet with the largest share of rows."""
    found = []
    try:
        age_diet = pd.crosstab(df['Age_Category'], df['Diet_Category'], normalize="index") * 100
        for age in age_diet.index:
            top_diet = age_diet.loc[age].idxmax()
            proportion = age_diet.loc[age, top_diet]
            found.append(f"The {age} age group predominantly follows {top_diet} diets ({proportion:.1f}%).")
    except Exception as exc:
        # Best-effort section: report the problem instead of hiding it.
        print(f"Warning: age-group insights skipped: {exc}")
    return found


def _sex_difference_insights(df, numeric_cols):
    """Report metrics whose per-sex means differ by more than 10% of their average."""
    found = []
    try:
        gender_impact = df.groupby('sex')[numeric_cols].mean()
        genders = gender_impact.index.tolist()
        for col in numeric_cols:
            if col not in gender_impact.columns:
                continue
            metric = col.replace('mean_', '')
            for i in range(len(genders)):
                for j in range(i + 1, len(genders)):
                    gender1, gender2 = genders[i], genders[j]
                    val1 = gender_impact.loc[gender1, col]
                    val2 = gender_impact.loc[gender2, col]
                    if not min(val1, val2) > 0:
                        continue
                    threshold = np.mean([val1, val2]) * 0.1
                    if not abs(val1 - val2) > threshold:
                        continue
                    if val1 > val2:
                        diff_pct = (val1 - val2) / val2 * 100
                        found.append(f"{gender1} have {diff_pct:.1f}% higher {metric} "
                                     f"impact than {gender2}.")
                    else:
                        diff_pct = (val2 - val1) / val1 * 100
                        found.append(f"{gender2} have {diff_pct:.1f}% higher {metric} "
                                     f"impact than {gender1}.")
    except Exception as exc:
        # Best-effort section: report the problem instead of hiding it.
        print(f"Warning: sex-difference insights skipped: {exc}")
    return found


def _relative_impact_insights(diet_summary, numeric_cols):
    """Express each diet's metric means as a percentage of the highest-scoring diet's."""
    found = []
    reference_diet = diet_summary.loc[diet_summary['env_impact_mean'].idxmax(), 'Diet_Category']
    reference_impacts = {}
    for col in numeric_cols:
        col_name = f"{col}_mean"
        if col_name in diet_summary.columns:
            ref_val = diet_summary.loc[diet_summary['Diet_Category'] == reference_diet, col_name].values
            # Only strictly positive reference values can serve as a denominator.
            if len(ref_val) > 0 and ref_val[0] > 0:
                reference_impacts[col] = ref_val[0]
    for diet in diet_summary['Diet_Category'].unique():
        if diet == reference_diet:
            continue
        relative_impacts = []
        for col in numeric_cols:
            col_name = f"{col}_mean"
            if col in reference_impacts and col_name in diet_summary.columns:
                diet_val = diet_summary.loc[diet_summary['Diet_Category'] == diet, col_name].values[0]
                pct = (diet_val / reference_impacts[col]) * 100
                relative_impacts.append(f"{col.replace('mean_', '')}: {pct:.1f}%")
        if relative_impacts:
            found.append(f"{diet} environmental impacts relative to {reference_diet}: {', '.join(relative_impacts)}.")
    return found


def analyze_environmental_impact(df, numeric_cols):
    """Produce plain-English insight sentences about impacts by diet.

    Args:
        df: Frame containing 'Diet_Category', 'n_participants',
            'Environmental_Impact_Score' and every column in ``numeric_cols``.
            'Age_Category' and 'sex' are optional and enable extra sections.
        numeric_cols: Names of the impact metric columns to analyse.

    Returns:
        A list of sentences, in this order: lowest/highest diet per metric,
        plant-based diets out-performed by other diets, dominant diet per age
        band, differences between sexes, and each diet's impacts relative to
        the diet with the highest mean Environmental_Impact_Score.
    """
    diet_summary = df.groupby('Diet_Category').agg(
        n_participants=('n_participants', 'sum'),
        env_impact_mean=('Environmental_Impact_Score', 'mean'),
        **{f"{col}_mean": (col, 'mean') for col in numeric_cols}
    ).reset_index()

    insights = []
    insights.extend(_extreme_diet_insights(diet_summary, numeric_cols))
    insights.extend(_plant_based_rank_insights(diet_summary, numeric_cols))
    if 'Age_Category' in df.columns:
        insights.extend(_age_group_insights(df))
    if 'sex' in df.columns and len(df['sex'].unique()) > 1:
        insights.extend(_sex_difference_insights(df, numeric_cols))
    if len(diet_summary) > 1:
        insights.extend(_relative_impact_insights(diet_summary, numeric_cols))
    return insights

def format_summary(insights):
    """Group insight sentences under headings and render a Markdown-style report.

    Each sentence is assigned to the first heading whose marker phrase it
    contains; sentences matching none go under "Other Findings". Headings
    appear in the order in which their first sentence was encountered.

    Args:
        insights: Iterable of insight sentences.

    Returns:
        The report text: a title line followed by one bulleted section per
        non-empty heading.
    """
    def heading_for(text):
        if "has the lowest impact" in text:
            return "Impact Comparisons"
        if "Unexpected finding" in text:
            return "Unexpected Patterns"
        if "predominantly follows" in text:
            return "Demographic Patterns"
        if "higher" in text and "impact than" in text:
            return "Comparative Impacts"
        if "relative to" in text:
            return "Relative Impacts"
        return "Other Findings"

    grouped = {}
    for text in insights:
        grouped.setdefault(heading_for(text), []).append(text)

    pieces = ["# Data-Driven Analysis of Environmental Impacts by Diet Type\n\n"]
    for heading, entries in grouped.items():
        pieces.append(f"## {heading}\n")
        pieces.extend(f"- {entry}\n" for entry in entries)
        pieces.append("\n")
    return "".join(pieces)

# Build the insight sentences from the prepared frame, render them as a
# Markdown-style report, echo the report, and save a copy next to the notebook.
insights = analyze_environmental_impact(df, numeric_cols)
summary_output = format_summary(insights)
print(summary_output)
# No encoding argument is passed, so the file uses the platform default encoding.
with open("diet_environmental_analysis.txt", "w") as f:
    f.write(summary_output)


Successfully loaded remote dataset

Available columns: ['mc_run_id', 'grouping', 'mean_ghgs', 'mean_land', 'mean_watscar', 'mean_eut', 'mean_ghgs_ch4', 'mean_ghgs_n2o', 'mean_bio', 'mean_watuse', 'mean_acid', 'sd_ghgs', 'sd_land', 'sd_watscar', 'sd_eut', 'sd_ghgs_ch4', 'sd_ghgs_n2o', 'sd_bio', 'sd_watuse', 'sd_acid', 'n_participants', 'sex', 'diet_group', 'age_group']

Using numeric columns: ['mean_ghgs', 'mean_land', 'mean_watscar', 'mean_eut', 'mean_ghgs_ch4', 'mean_ghgs_n2o', 'mean_bio', 'mean_acid', 'mean_watuse']


# Data-Driven Analysis of Environmental Impacts by Diet Type

## Impact Comparisons
- For ghgs, Vegan has the lowest impact while High Meat (>100g/day) has the highest impact (difference of 434.1%).
- For land, Vegan has the lowest impact while High Meat (>100g/day) has the highest impact (difference of 544.1%).
- For watscar, Vegan has the lowest impact while High Meat (>100g/day) has the highest impact (difference of 79.4%).
- For eut, Vegan has the lowest impact while High Meat (>100g/day) has the highest impact (difference of 352.4%).
- For ghgs_ch4, Vegan has the lowest impact while High Meat (>100g/day) has the highest impact (difference of 1855.5%).
- For ghgs_n2o, Vegan has the lowest impact while High Meat (>100g/day) has the highest impact (difference of 366.7%).
- For bio, Vegan has the lowest impact while High Meat (>100g/day) has the highest impact (difference of 315.8%).
- For acid, Vegan has the lowest impact while High Meat (>100g/day) has the highest impact (difference

In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import dash
from dash import dcc, html
from dash.dependencies import Input, Output
import time

# ======================
# DATA LOADING & PREP
# ======================

print("Loading dataset...")
df = pd.read_csv("FoodDataset.csv")
print("Successfully loaded Oxford dataset")

# Standardize column names: short name -> canonical name, restricted to the
# short names that actually occur in the file.
column_mapping = {
    short_name: canonical_name
    for short_name, canonical_name in {
        'ghg': 'mean_ghgs',
        'ch4': 'mean_ghgs_ch4',
        'n2o': 'mean_ghgs_n2o',
        'land': 'mean_land',
        'bio': 'mean_bio',
        'watscar': 'mean_watscar',
        'watuse': 'mean_watuse',
        'eut': 'mean_eut',
        'acid': 'mean_acid',
        'n': 'n_participants',
        'sex': 'sex',
        'age': 'age_group',
        'diet_group': 'diet_group',
    }.items()
    if short_name in df.columns
}
df = df.rename(columns=column_mapping)

# Clean and standardize values: lower-case and trim the two text keys.
for text_column in ('diet_group', 'sex'):
    df[text_column] = df[text_column].str.lower().str.strip()

# ======================
# DATA VALIDATION
# ======================

print("\n=== DATA VALIDATION ===")
print("Columns in dataset:", df.columns.tolist())
for caption, column_name in (
    ("\nUnique diet groups:", 'diet_group'),
    ("Unique sex values:", 'sex'),
    ("Unique age groups:", 'age_group'),
):
    print(caption, df[column_name].unique())

# ======================
# DATA PROCESSING
# ======================

# Impact categories: category key -> the metric columns that feed it.
impact_categories = dict(
    Climate_Impact=['mean_ghgs', 'mean_ghgs_ch4', 'mean_ghgs_n2o'],
    Land_Impact=['mean_land', 'mean_bio'],
    Water_Impact=['mean_watscar', 'mean_watuse'],
    Chemical_Impact=['mean_eut', 'mean_acid'],
)

# Reverse lookup (metric column -> category key), derived from the table
# above so the two can never disagree.
metric_to_category = {
    metric_name: category_key
    for category_key, metric_names in impact_categories.items()
    for metric_name in metric_names
}

# Friendly names for display
category_friendly_names = dict(
    Climate_Impact='Climate Impact',
    Land_Impact='Land & Biodiversity Impact',
    Water_Impact='Water Resources Impact',
    Chemical_Impact='Chemical Pollution Impact',
)

# Enhanced diet categorization
def categorize_diet(diet):
    """Translate a raw diet-group code into its display category.

    The value is stringified, lower-cased and trimmed, then checked against
    an ordered list of substring rules; the first matching rule decides.
    Values matching no rule are reported on stdout and mapped to 'Other'.

    Args:
        diet: Raw diet-group value.

    Returns:
        The display category string.
    """
    code = str(diet).lower().strip()
    ordered_rules = (
        (lambda s: 'vegan' in s, 'Vegan'),
        (lambda s: 'veggie' in s or 'vegetarian' in s, 'Vegetarian'),
        (lambda s: 'fish' in s or 'pescatarian' in s, 'Pescatarian'),
        (lambda s: 'meat100' in s or '>100' in s, 'High Meat (>100g/day)'),
        (lambda s: 'meat50-99' in s or ('50' in s and '99' in s), 'Medium Meat (50-99g/day)'),
        (lambda s: 'meat50' in s or '<50' in s, 'Low Meat (<50g/day)'),
        # Generic meat entries with no quantity default to the medium band.
        (lambda s: 'meat' in s, 'Medium Meat (50-99g/day)'),
    )
    for rule_matches, category in ordered_rules:
        if rule_matches(code):
            return category
    print(f"Unknown diet group: {code}")
    return 'Other'

print("Categorizing diets...")
df['Diet_Category'] = df['diet_group'].apply(categorize_diet)

# Convert sex values to Male/Female. Single-letter codes are accepted as well,
# and anything unrecognised becomes 'Unknown' rather than NaN: a NaN here
# would make sorted(df['Sex'].unique()) fail when strings and floats are
# compared while the Sankey nodes are built.
sex_display_names = {'male': 'Male', 'm': 'Male', 'female': 'Female', 'f': 'Female'}
df['Sex'] = df['sex'].str.lower().str.strip().map(sex_display_names).fillna('Unknown')
df['Age_Category'] = df['age_group']

# Calculate impact scores: min-max normalise each available metric, then
# average the normalised metrics of a category into the category column.
print("Calculating impact scores...")
for category, metrics in impact_categories.items():
    available_metrics = [m for m in metrics if m in df.columns]
    if not available_metrics:
        # No source column for this category: no score column is created.
        continue
    normalised_columns = []
    for metric in available_metrics:
        normalised_name = f'{metric}_norm'
        metric_low, metric_high = df[metric].min(), df[metric].max()
        if metric_high > metric_low:  # Avoid div by zero
            df[normalised_name] = (df[metric] - metric_low) / (metric_high - metric_low)
        else:
            df[normalised_name] = 0  # Default if all values are the same
        normalised_columns.append(normalised_name)
    df[category] = df[normalised_columns].mean(axis=1)

# ======================
# IMPROVED COMPREHENSIVE SANKEY DIAGRAM WITH ENHANCED COLOR SCHEME
# ======================

def _with_alpha(color, alpha):
    """Return `color` with its alpha channel set to `alpha`.

    Only strings of the form 'rgba(r, g, b, a)' are rewritten; any other
    value is returned unchanged.
    """
    if not (isinstance(color, str) and color.startswith('rgba(') and color.endswith(')')):
        return color
    rgb_part, separator, _old_alpha = color.rpartition(',')
    if not separator:
        return color
    return f"{rgb_part}, {alpha})"


def create_comprehensive_sankey(df, weights=None):
    """
    Creates a comprehensive Sankey diagram showing:
    Diet → Age → Gender → Individual Metrics → Impact Categories

    Besides its arguments, the function reads the module-level mappings
    `impact_categories`, `category_friendly_names` and `metric_to_category`.

    Args:
        df: DataFrame with the data. The columns 'Diet_Category',
            'Age_Category', 'Sex' and 'n_participants' are read, plus every
            metric column listed in `impact_categories` that is present.
        weights: Dict with category weights (optional). Missing keys and
            None values fall back to the defaults
            (Climate 0.4, Land 0.3, Water 0.2, Chemical 0.1).

    Returns:
        A plotly `go.Figure` holding the Sankey trace.
    """
    print("Creating comprehensive Sankey diagram with enhanced color scheme...")
    start_time = time.time()

    default_weights = {
        'Climate_Impact': 0.4,
        'Land_Impact': 0.3,
        'Water_Impact': 0.2,
        'Chemical_Impact': 0.1
    }
    # Overlay the caller's weights on the defaults so a partial dict (or a
    # None value) no longer breaks the weighting step further down.
    supplied_weights = weights or {}
    weights = {
        **default_weights,
        **{key: value for key, value in supplied_weights.items() if value is not None}
    }

    # 1. Prepare nodes
    # NaN labels are dropped: they cannot be sorted together with strings and
    # groupby() excludes them from the links anyway.
    diet_nodes = sorted(df['Diet_Category'].dropna().unique())
    age_nodes = sorted(df['Age_Category'].dropna().unique())
    gender_nodes = sorted(df['Sex'].dropna().unique())

    # Create more readable metric labels
    metric_labels = {
        'mean_ghgs': 'Greenhouse Gas Emissions',
        'mean_ghgs_ch4': 'Methane (CH4) Emissions',
        'mean_ghgs_n2o': 'Nitrous Oxide (N2O) Emissions',
        'mean_land': 'Land Use',
        'mean_bio': 'Biodiversity Loss',
        'mean_watscar': 'Water Scarcity',
        'mean_watuse': 'Water Usage',
        'mean_eut': 'Eutrophication',
        'mean_acid': 'Acidification'
    }

    # Impact categories that have at least one metric column in the frame
    impact_categories_available = {k: v for k, v in impact_categories.items()
                                  if any(metric in df.columns for metric in v)}
    impact_category_nodes = list(impact_categories_available.keys())

    # Individual metrics, in category order, restricted to existing columns
    all_metrics = [
        m
        for metrics in impact_categories_available.values()
        for m in metrics
        if m in df.columns
    ]

    # 2. Create nodes list in order of flow:
    # Diet → Age → Gender → Individual Metrics → Impact Categories
    all_nodes = (
        diet_nodes +
        age_nodes +
        gender_nodes +
        [metric_labels[m] for m in all_metrics] +
        [category_friendly_names[cat] for cat in impact_category_nodes]
    )

    print(f"Total nodes: {len(all_nodes)}")

    # Create node indices dictionary
    node_indices = {node: i for i, node in enumerate(all_nodes)}

    # 3. Define color mappings
    diet_colors = {
        'Vegan': 'rgba(39, 174, 96, 0.85)',          # Green
        'Vegetarian': 'rgba(142, 202, 100, 0.85)',    # Light green
        'Pescatarian': 'rgba(52, 152, 219, 0.85)',    # Blue
        'Low Meat (<50g/day)': 'rgba(241, 196, 15, 0.85)',  # Yellow
        'Medium Meat (50-99g/day)': 'rgba(230, 126, 34, 0.85)', # Orange
        'High Meat (>100g/day)': 'rgba(231, 76, 60, 0.85)',  # Red
        'Other': 'rgba(189, 195, 199, 0.85)'          # Gray
    }

    # Color gradient for age groups (from younger to older)
    age_colors = {
        '20-29': 'rgba(108, 92, 231, 0.85)',    # Violet
        '30-39': 'rgba(162, 155, 254, 0.85)',   # Light violet
        '40-49': 'rgba(116, 185, 255, 0.85)',   # Light blue
        '50-59': 'rgba(30, 144, 255, 0.85)',    # Dodger blue
        '60-69': 'rgba(25, 118, 210, 0.85)',    # Medium blue
        '70-79': 'rgba(21, 101, 192, 0.85)'     # Dark blue
    }

    gender_colors = {
        'Male': 'rgba(91, 192, 222, 0.85)',    # Cyan
        'Female': 'rgba(240, 98, 146, 0.85)'   # Pink
    }

    metric_base_colors = {
        # Climate metrics - red variants
        'mean_ghgs': 'rgba(229, 57, 53, 0.85)',       # Red
        'mean_ghgs_ch4': 'rgba(244, 81, 30, 0.85)',   # Orange-red
        'mean_ghgs_n2o': 'rgba(230, 74, 25, 0.75)',   # Lighter orange-red

        # Land metrics - green variants
        'mean_land': 'rgba(56, 142, 60, 0.85)',       # Green
        'mean_bio': 'rgba(104, 159, 56, 0.85)',       # Light green

        # Water metrics - blue variants
        'mean_watscar': 'rgba(3, 169, 244, 0.85)',    # Light blue
        'mean_watuse': 'rgba(21, 101, 192, 0.85)',    # Dark blue

        # Chemical metrics - purple variants
        'mean_eut': 'rgba(156, 39, 176, 0.85)',       # Purple
        'mean_acid': 'rgba(123, 31, 162, 0.85)'       # Dark purple
    }

    # Bold, distinctive colors for impact categories (final nodes)
    category_colors = {
        'Climate Impact': 'rgba(229, 57, 53, 0.9)',              # Bold red
        'Land & Biodiversity Impact': 'rgba(56, 142, 60, 0.9)',  # Bold green
        'Water Resources Impact': 'rgba(3, 169, 244, 0.9)',      # Bold blue
        'Chemical Pollution Impact': 'rgba(156, 39, 176, 0.9)'   # Bold purple
    }

    # 4. Assign node colors
    # Single label -> color lookup. Later entries win, so the merge order
    # reproduces the precedence diet > age > gender > category > metric.
    default_node_color = 'rgba(150, 150, 150, 0.85)'
    metric_label_colors = {
        label: metric_base_colors.get(metric, default_node_color)
        for metric, label in metric_labels.items()
    }
    node_palette = {
        **metric_label_colors,
        **category_colors,
        **gender_colors,
        **age_colors,
        **diet_colors
    }
    node_colors = [node_palette.get(node, default_node_color) for node in all_nodes]

    # 5. Create links, drawn more transparent than the nodes they leave
    links = {
        'source': [],
        'target': [],
        'value': [],
        'color': [],
        'label': []
    }
    link_alpha = 0.65
    fallback_link_color = 'rgba(150, 150, 150, 0.4)'

    def link_color(palette, key):
        # Rewrite the alpha channel explicitly instead of string-replacing
        # '0.85', which left colors with any other alpha unchanged.
        base_color = palette.get(key)
        if base_color is None:
            return fallback_link_color
        return _with_alpha(base_color, link_alpha)

    def add_link(source_label, target_label, value, color, label):
        links['source'].append(node_indices[source_label])
        links['target'].append(node_indices[target_label])
        links['value'].append(value)
        links['color'].append(color)
        links['label'].append(label)

    # Diet -> Age links
    print("Processing Diet -> Age links...")
    diet_age = df.groupby(['Diet_Category', 'Age_Category'])['n_participants'].sum().reset_index()
    for row in diet_age.itertuples(index=False):
        # Skip very small links to reduce complexity
        if row.n_participants < 5:
            continue
        add_link(
            row.Diet_Category,
            row.Age_Category,
            row.n_participants,
            link_color(diet_colors, row.Diet_Category),
            f"{row.Diet_Category} → {row.Age_Category}: {row.n_participants} participants"
        )

    # Age -> Gender links
    print("Processing Age -> Gender links...")
    age_gender = df.groupby(['Age_Category', 'Sex'])['n_participants'].sum().reset_index()
    for row in age_gender.itertuples(index=False):
        add_link(
            row.Age_Category,
            row.Sex,
            row.n_participants,
            link_color(age_colors, row.Age_Category),
            f"{row.Age_Category} → {row.Sex}: {row.n_participants} participants"
        )

    # Gender -> Individual Metrics links
    print("Processing Gender -> Individual Metrics links...")
    for gender in gender_nodes:
        gender_data = df[df['Sex'] == gender]
        participants = gender_data['n_participants'].sum()

        # dict.fromkeys keeps the order but links each metric once per gender
        for metric in dict.fromkeys(all_metrics):
            value = gender_data[metric].mean()
            metric_node = metric_labels[metric]

            # Scale the mean by group size so the link is visible, with a
            # floor of 0.1 for very small flows
            scaled_value = value * participants / 10
            if scaled_value < 0.1:
                scaled_value = 0.1

            add_link(
                gender,
                metric_node,
                scaled_value,
                link_color(gender_colors, gender),
                f"{gender} → {metric_node}: {value:.2f} avg. impact"
            )

    # Individual Metrics -> Impact Categories links
    print("Processing Individual Metrics -> Impact Categories links...")
    for metric in all_metrics:
        if metric not in metric_to_category:
            continue
        category = metric_to_category[metric]
        category_node = category_friendly_names[category]
        metric_node = metric_labels[metric]

        # Overall average of the metric, scaled by the category weight
        scaled_value = df[metric].mean() * weights[category] * 1000

        add_link(
            metric_node,
            category_node,
            scaled_value,
            link_color(metric_base_colors, metric),
            f"{metric_node} → {category_node}: contributes to {category_node}"
        )

    # 6. Create the Sankey diagram
    fig = go.Figure(go.Sankey(
        node=dict(
            pad=20,
            thickness=25,
            line=dict(color="rgba(50, 50, 50, 0.3)", width=0.3),
            label=all_nodes,
            color=node_colors,
            hovertemplate='<b>%{label}</b><extra></extra>'  # Bold labels in hover
        ),
        link=dict(
            source=links['source'],
            target=links['target'],
            value=links['value'],
            color=links['color'],
            hovertemplate='<b>%{label}</b><extra></extra>',  # Bold labels in hover
            label=links['label']
        )
    ))

    # 7. Title and layout settings
    fig.update_layout(
        title=dict(
            text="Comprehensive Diet-Demographics-Environmental Impact Flow",
            font=dict(size=24, color="#424242", family="Arial, sans-serif"),
            y=0.98
        ),
        font=dict(
            family="Arial, sans-serif",
            size=14,
            color="#424242"
        ),
        height=900,
        width=1600,
        margin=dict(l=50, r=50, t=100, b=50),
        paper_bgcolor='rgba(250, 250, 250, 1)',  # Light gray background
        plot_bgcolor='rgba(250, 250, 250, 1)'    # Light gray background
    )

    # Add subtle watermark/annotation
    fig.add_annotation(
        text="Dietary Impact Analysis",
        x=0.5,
        y=1.05,
        xref="paper",
        yref="paper",
        showarrow=False,
        font=dict(
            family="Arial, sans-serif",
            size=14,
            color="rgba(150, 150, 150, 0.6)"
        )
    )

    print(f"Enhanced Sankey diagram created in {time.time() - start_time:.2f} seconds")
    return fig

# ======================
# INTERACTIVE DASHBOARD WITH WEIGHT CONTROLS
# ======================

def create_interactive_dashboard():
    """
    Creates an interactive dashboard with sliders to adjust impact weights

    The layout holds one weight slider per impact category, a reset button,
    the Sankey graph and a short reading guide. Two callbacks are registered:
    one redraws the Sankey diagram (from the module-level `df`) whenever a
    slider changes, the other puts every slider back to its default weight
    when "Reset Weights" is clicked.

    Returns:
        The configured `dash.Dash` app (not yet running).
    """
    app = dash.Dash(__name__)

    # One row per slider:
    # (component id, weights key, label text, label color, default weight)
    slider_specs = [
        ('climate-weight', 'Climate_Impact', "Climate Impact Weight:", '#C0392B', 0.4),
        ('land-weight', 'Land_Impact', "Land & Biodiversity Impact Weight:", '#27AE60', 0.3),
        ('water-weight', 'Water_Impact', "Water Resources Impact Weight:", '#3498DB', 0.2),
        ('chemical-weight', 'Chemical_Impact', "Chemical Pollution Impact Weight:", '#9B59B6', 0.1),
    ]

    def weight_slider(slider_id, label_text, label_color, default):
        # Labelled 0.1-1.0 slider block; identical for every category
        return html.Div([
            html.Label(label_text, style={'fontWeight': 'bold', 'color': label_color}),
            dcc.Slider(
                id=slider_id,
                min=0.1,
                max=1.0,
                step=0.1,
                value=default,
                marks={i/10: str(i/10) for i in range(1, 11)},
                className='custom-slider'
            )
        ], style={'marginBottom': 20})

    panel_style = {
        'padding': '25px',
        'backgroundColor': '#f8f9fa',
        'borderRadius': '10px',
        'boxShadow': '0 4px 8px rgba(0,0,0,0.1)'
    }

    weight_controls = (
        [html.H3("Adjust Impact Category Weights", style={'color': '#34495E'})]
        + [
            weight_slider(slider_id, label_text, label_color, default)
            for slider_id, _key, label_text, label_color, default in slider_specs
        ]
        + [
            html.Button('Reset Weights',
                        id='reset-button',
                        n_clicks=0,
                        style={
                            'backgroundColor': '#34495E',
                            'color': 'white',
                            'border': 'none',
                            'padding': '10px 20px',
                            'borderRadius': '5px',
                            'cursor': 'pointer',
                            'fontSize': '16px'
                        })
        ]
    )

    app.layout = html.Div([
        html.H1("Diet-Demographics-Environmental Impact Dashboard",
                style={
                    'textAlign': 'center',
                    'fontFamily': 'Arial, sans-serif',
                    'color': '#2C3E50',
                    'marginBottom': '30px',
                    'marginTop': '20px'
                }),

        html.Div(weight_controls, style={**panel_style, 'marginBottom': '30px'}),

        # The main Sankey diagram
        dcc.Graph(id='sankey-graph'),

        # Helper text
        html.Div([
            html.H4("How to Read This Visualization", style={'color': '#34495E'}),
            html.P([
                "This Sankey diagram shows the flow from diet types through demographics to environmental impacts. ",
                "The width of each connection represents its relative magnitude. ",
                "The diagram follows this path: Diet Categories → Age Groups → Gender → Individual Metrics → Impact Categories."
            ], style={'lineHeight': '1.6'}),
            html.P([
                "Use the sliders above to adjust the weights of different environmental impact categories ",
                "and see how they affect the environmental impact assessment."
            ], style={'lineHeight': '1.6'})
        ], style={**panel_style, 'marginTop': '30px'}),

    ], style={
        'maxWidth': '1800px',
        'margin': '0 auto',
        'padding': '20px',
        'fontFamily': 'Arial, sans-serif'
    })

    @app.callback(
        Output('sankey-graph', 'figure'),
        [Input(slider_id, 'value') for slider_id, *_rest in slider_specs]
    )
    def update_graph(*slider_values):
        # Slider values arrive in slider_specs order; a missing value falls
        # back to that slider's default weight.
        weights = {
            key: (value if value is not None else default)
            for (_id, key, _label, _color, default), value in zip(slider_specs, slider_values)
        }

        # Create the comprehensive Sankey diagram with current weights
        return create_comprehensive_sankey(df, weights)

    @app.callback(
        [Output(slider_id, 'value') for slider_id, *_rest in slider_specs],
        Input('reset-button', 'n_clicks'),
        prevent_initial_call=True
    )
    def reset_weights(n_clicks):
        # Writing the defaults back to the sliders also re-triggers
        # update_graph through their 'value' inputs.
        return [default for *_rest, default in slider_specs]

    return app

# ======================
# EXECUTION
# ======================

print("\n=== GENERATING COMPREHENSIVE VISUALIZATION ===")
try:
    # Static version of the comprehensive Sankey diagram (default weights):
    # shown inline and exported as a standalone HTML file.
    print("Creating enhanced Sankey diagram with improved color scheme...")
    comp_sankey = create_comprehensive_sankey(df)
    comp_sankey.show()
    comp_sankey.write_html("enhanced_diet_demographic_impact_sankey.html")
    print("Enhanced Sankey diagram saved to HTML file")

    # Interactive dashboard with weight controllers
    print("\nSetting up interactive dashboard...")
    app = create_interactive_dashboard()
    print("Dashboard ready! Starting it with app.run(debug=True)...")

    # Start the dashboard. app.run() is the entry point to use here;
    # app.run_server() is obsolete in Dash 3.x.
    # Comment this line out to build the app without serving it.
    app.run(debug=True)

except Exception as e:
    # Best-effort cell: report the failure with a full traceback instead of
    # aborting the notebook run.
    print(f"Error in visualization creation: {e}")
    import traceback
    traceback.print_exc()

Loading dataset...
Successfully loaded Oxford dataset

=== DATA VALIDATION ===
Columns in dataset: ['mc_run_id', 'grouping', 'mean_ghgs', 'mean_land', 'mean_watscar', 'mean_eut', 'mean_ghgs_ch4', 'mean_ghgs_n2o', 'mean_bio', 'mean_watuse', 'mean_acid', 'sd_ghgs', 'sd_land', 'sd_watscar', 'sd_eut', 'sd_ghgs_ch4', 'sd_ghgs_n2o', 'sd_bio', 'sd_watuse', 'sd_acid', 'n_participants', 'sex', 'diet_group', 'age_group']

Unique diet groups: ['fish' 'meat100' 'meat50' 'meat' 'vegan' 'veggie']
Unique sex values: ['female' 'male']
Unique age groups: ['20-29' '30-39' '40-49' '50-59' '60-69' '70-79']
Categorizing diets...
Calculating impact scores...

=== GENERATING COMPREHENSIVE VISUALIZATION ===
Creating enhanced Sankey diagram with improved color scheme...
Creating comprehensive Sankey diagram with enhanced color scheme...
Total nodes: 27
Processing Diet -> Age links...
Processing Age -> Gender links...
Processing Gender -> Individual Metrics links...
Processing Individual Metrics -> Impact Categ

Enhanced Sankey diagram saved to HTML file

Setting up interactive dashboard...
Dashboard ready! Run app.run_server(debug=True) to start the dashboard


Creating comprehensive Sankey diagram with enhanced color scheme...
Total nodes: 27
Processing Diet -> Age links...
Processing Age -> Gender links...
Processing Gender -> Individual Metrics links...
Processing Individual Metrics -> Impact Categories links...
Enhanced Sankey diagram created in 0.99 seconds
Creating comprehensive Sankey diagram with enhanced color scheme...
Total nodes: 27
Processing Diet -> Age links...
Processing Age -> Gender links...
Processing Gender -> Individual Metrics links...
Processing Individual Metrics -> Impact Categories links...
Enhanced Sankey diagram created in 0.46 seconds
Creating comprehensive Sankey diagram with enhanced color scheme...
Total nodes: 27
Processing Diet -> Age links...
Processing Age -> Gender links...
Processing Gender -> Individual Metrics links...
Processing Individual Metrics -> Impact Categories links...
Enhanced Sankey diagram created in 0.08 seconds


In [10]:
!pip install dash

Collecting dash
  Downloading dash-3.0.4-py3-none-any.whl.metadata (10 kB)
Collecting retrying (from dash)
  Downloading retrying-1.3.4-py3-none-any.whl.metadata (6.9 kB)
Downloading dash-3.0.4-py3-none-any.whl (7.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.9/7.9 MB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading retrying-1.3.4-py3-none-any.whl (11 kB)
Installing collected packages: retrying, dash
Successfully installed dash-3.0.4 retrying-1.3.4

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
