In [1]:
import pandas as pd
import os
import altair as alt
import numpy as np
import plotnine as pt
import geopandas as gpd
import plotly

In [2]:
aser = pd.read_csv('/Users/trishapunamiya/Desktop/LSE/Data Viz/Project/Raw Data/ASER_State_Data.csv') 

In [7]:
aser.columns

Index(['State',
       'Govt school % Children (aged 6-14) enrolled in govt schools 2018',
       'Govt school % Children (aged 6-14) enrolled in govt schools 2022',
       'Govt school % Children (aged 6-14) enrolled in govt schools 2024',
       'Not in school % Children (aged 15-16) not enrolled in school 2018',
       'Not in school % Children (aged 15-16) not enrolled in school 2022',
       'Not in school % Children (aged 15-16) not enrolled in school 2024',
       'Std III: Learning levels % Children who can read Std II level text 2018',
       'Std III: Learning levels % Children who can read Std II level text 2022',
       'Std III: Learning levels % Children who can read Std II level text 2024',
       'Std III: Learning levels % Children who can do at least subtraction 2018',
       'Std III: Learning levels % Children who can do at least subtraction 2022',
       'Std III: Learning levels % Children who can do at least subtraction 2024',
       'Std V: Learning levels % Chi

For all ASER survey charts mention that some states have been excluded because no data from survey.

In [3]:
# filtering to 2024
columns_2024 = [col for col in aser.columns if '2024' in col]
aser_2024 = aser[['State'] + columns_2024]

In [12]:
# Enable data transformer for larger datasets
alt.data_transformers.enable('default', max_rows=None)

# Prepare data
df = aser_2024
df = df[df['State'] != 'All India']

# Get 2024 columns
reading_cols = [c for c in df.columns if '2024' in c and 'read' in c.lower()]
math_cols = [c for c in df.columns if '2024' in c and ('division' in c or 'subtraction' in c)]

# Process reading
df_r = df[['State'] + reading_cols].copy()
df_r.columns = ['State', 'Std III', 'Std V', 'Std VIII']
df_r = df_r.sort_values('Std V', ascending=False)

# Process math  
df_m = df[['State'] + math_cols].copy()
df_m.columns = ['State', 'Std III', 'Std V', 'Std VIII']
df_m['State'] = pd.Categorical(df_m['State'], categories=df_r['State'], ordered=True)

# Melt
reading_long = df_r.melt(id_vars='State', var_name='Standard', value_name='Value')
reading_long['Type'] = 'Reading'

math_long = df_m.melt(id_vars='State', var_name='Standard', value_name='Value')
math_long['Type'] = 'Math'

# Combine
combined = pd.concat([reading_long, math_long])

# Create heatmap with your color scheme
heatmap = alt.Chart(combined).mark_rect(stroke='white', strokeWidth=0.2).encode(
    x=alt.X('Standard:N', title=None, axis=alt.Axis(labelFontSize=12, labelAngle=-45)), # make x-axis labels slanted
    y=alt.Y('State:N', title=None, sort=df_r['State'].tolist(), axis=alt.Axis(labelFontSize=10)),
    color=alt.Color('Value:Q',
                    scale=alt.Scale(
                        domain=[0, 25, 50, 75, 100],
                        range=['#ba7798', '#c9909e', '#d8a9a4', '#9fceb7', '#53826b']
                    ),
                    legend=alt.Legend(title='Proficiency %', orient='right', titleFontSize=12)),
    tooltip=[
        alt.Tooltip('State:N', title='State'),
        alt.Tooltip('Standard:N', title='Standard'),
        alt.Tooltip('Value:Q', format='.1f', title='Proficiency %'),
        alt.Tooltip('Type:N', title='Subject')
    ]
).properties(
    width=300,
    height=350
)



# Combine and facet
chart = (heatmap ).facet(
    column=alt.Column('Type:N', 
                      title=None,
                      header=alt.Header(labelFontSize=16, labelFontWeight='bold'))
).properties(
    title={
        "text": "ASER 2024: Learning Outcomes Across Indian States",
        "subtitle": "Reading: % who can read Std II text | Math: Std III = Subtraction, Std V & VIII = Division",
        "fontSize": 18,
        "fontWeight": "bold",
        "anchor": "middle"
    }
).configure_view(
    strokeWidth=0
).configure_title(
    fontSize=18,
    anchor='middle',
    subtitleFontSize=12,
    subtitleColor='gray'
)


chart

# Save chart
#chart.save('Project_JSON/aser_learning_outcomes_heatmap.json')

In [5]:
import pandas as pd
import altair as alt

df = aser
df = df[df['State'] != 'All India']

region_map = {
    'Andhra Pradesh': 'South', 'Arunachal Pradesh': 'Northeast',
    'Assam': 'Northeast', 'Bihar': 'North', 'Chhattisgarh': 'Central',
    'Gujarat': 'West', 'Haryana': 'North', 'Himachal Pradesh': 'North',
    'Jammu and Kashmir': 'North', 'Jharkhand': 'East', 'Karnataka': 'South',
    'Kerala': 'South', 'Madhya Pradesh': 'Central', 'Maharashtra': 'West',
    'Meghalaya': 'Northeast', 'Mizoram': 'Northeast', 'Nagaland': 'Northeast',
    'Odisha': 'East', 'Punjab': 'North', 'Rajasthan': 'North',
    'Sikkim': 'Northeast', 'Tamil Nadu': 'South', 'Telangana': 'South',
    'Tripura': 'Northeast', 'Uttar Pradesh': 'North', 'Uttarakhand': 'North',
    'West Bengal': 'East'
}

df['Region'] = df['State'].map(region_map)

enrollment_col = [c for c in df.columns if '2024' in c and 'enrolled in govt schools' in c][0]
reading_col = [c for c in df.columns if '2024' in c and 'Std V' in c and ('division' in c or 'subtraction' in c)] [0]
not_in_school_col = [c for c in df.columns if '2024' in c and 'not enrolled in school' in c][0]

bubble_data = pd.DataFrame({
    'State': df['State'],
    'Region': df['Region'],
    'Enrollment': df[enrollment_col],
    'Reading': df[reading_col],
    'Not_in_School': df[not_in_school_col]
})

region_colors = {
    'North': '#77ba99',
    'South': '#ba7798',
    'East': '#9d8bb3',
    'West': '#b39d8b',
    'Central': '#8bb3a6',
    'Northeast': '#b3a68b'
}

chart = alt.Chart(bubble_data).mark_circle(opacity=0.7).encode(
    x=alt.X('Enrollment:Q', 
            title='% Enrolled in Govt Schools', 
            scale=alt.Scale(domain=[35, 95])),
    y=alt.Y('Reading:Q', 
            title='% Math Proficiency (Std V)', 
            scale=alt.Scale(domain=[0, 75])),
    size=alt.Size('Not_in_School:Q', 
                  title='% Not in School', 
                  scale=alt.Scale(range=[100, 1000])),
    color=alt.Color('Region:N', 
                    scale=alt.Scale(domain=list(region_colors.keys()), 
                                   range=list(region_colors.values()))),
    tooltip=[
        alt.Tooltip('State:N', title='State'),
        alt.Tooltip('Region:N', title='Region'),
        alt.Tooltip('Enrollment:Q', format='.1f', title='Enrollment %'),
        alt.Tooltip('Reading:Q', format='.1f', title='Reading %'),
        alt.Tooltip('Not_in_School:Q', format='.1f', title='Not in School %')
    ]
).properties(
    width=700,
    height=500,
    title='Quality-Access Gap in Indian Education (2024)'
)

chart

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy


In [6]:
# Read data
df = aser_2024
df = df[df['State'] != 'All India']

# Define regions
region_map = {
    'Andhra Pradesh': 'South', 'Arunachal Pradesh': 'Northeast',
    'Assam': 'Northeast', 'Bihar': 'North', 'Chhattisgarh': 'Central',
    'Gujarat': 'West', 'Haryana': 'North', 'Himachal Pradesh': 'North',
    'Jammu and Kashmir': 'North', 'Jharkhand': 'East', 'Karnataka': 'South',
    'Kerala': 'South', 'Madhya Pradesh': 'Central', 'Maharashtra': 'West',
    'Meghalaya': 'Northeast', 'Mizoram': 'Northeast', 'Nagaland': 'Northeast',
    'Odisha': 'East', 'Punjab': 'North', 'Rajasthan': 'North',
    'Sikkim': 'Northeast', 'Tamil Nadu': 'South', 'Telangana': 'South',
    'Tripura': 'Northeast', 'Uttar Pradesh': 'North', 'Uttarakhand': 'North',
    'West Bengal': 'East'
}

df['Region'] = df['State'].map(region_map)

# Get 2024 columns for all standards
# Reading
std3_reading_col = [c for c in df.columns if '2024' in c and 'Std III' in c and 'read' in c][0]
std5_reading_col = [c for c in df.columns if '2024' in c and 'Std V' in c and 'read' in c][0]
std8_reading_col = [c for c in df.columns if '2024' in c and 'Std VIII' in c and 'read' in c][0]

# Math
std3_math_col = [c for c in df.columns if '2024' in c and 'Std III' in c and 'subtraction' in c][0]
std5_math_col = [c for c in df.columns if '2024' in c and 'Std V' in c and 'division' in c][0]
std8_math_col = [c for c in df.columns if '2024' in c and 'Std VIII' in c and 'division' in c][0]

# Enrollment
enrollment_col = [c for c in df.columns if '2024' in c and 'enrolled in govt schools' in c][0]

# Create dataframe with all standards
comparison_data = pd.DataFrame({
    'State': df['State'],
    'Region': df['Region'],
    'Enrollment': df[enrollment_col],
    'Std_III_Reading': df[std3_reading_col],
    'Std_III_Math': df[std3_math_col],
    'Std_V_Reading': df[std5_reading_col],
    'Std_V_Math': df[std5_math_col],
    'Std_VIII_Reading': df[std8_reading_col],
    'Std_VIII_Math': df[std8_math_col]
})

# Reshape for each standard
std3_data = comparison_data[['State', 'Region', 'Enrollment', 'Std_III_Reading', 'Std_III_Math']].copy()
std3_data.columns = ['State', 'Region', 'Enrollment', 'Reading', 'Math']
std3_data['Standard'] = 'Std III'

std5_data = comparison_data[['State', 'Region', 'Enrollment', 'Std_V_Reading', 'Std_V_Math']].copy()
std5_data.columns = ['State', 'Region', 'Enrollment', 'Reading', 'Math']
std5_data['Standard'] = 'Std V'

std8_data = comparison_data[['State', 'Region', 'Enrollment', 'Std_VIII_Reading', 'Std_VIII_Math']].copy()
std8_data.columns = ['State', 'Region', 'Enrollment', 'Reading', 'Math']
std8_data['Standard'] = 'Std VIII'

# Combine all standards
all_data = pd.concat([std3_data, std5_data, std8_data])

# Your color palette
region_colors = {
    'North': '#77ba99',
    'South': '#ba7798',
    'East': '#9d8bb3',
    'West': '#b39d8b',
    'Central': '#8bb3a6',
    'Northeast': '#b3a68b'
}

# Create dropdown selection
dropdown = alt.binding_select(
    options=['Std III', 'Std V', 'Std VIII'],
    name='Standard: '
)
selection = alt.selection_point(
    fields=['Standard'],
    bind=dropdown,
    value='Std V'  # Default to Std V
)

# Create bubble chart
bubble_chart = alt.Chart(all_data).mark_circle(
    opacity=0.7,
    stroke='white',
    strokeWidth=1
).encode(
    x=alt.X('Reading:Q',
            title='% Reading Proficiency',
            scale=alt.Scale(domain=[0, 100]),
            axis=alt.Axis(titleFontSize=13, labelFontSize=11)),
    y=alt.Y('Math:Q',
            title='% Math Proficiency',
            scale=alt.Scale(domain=[0, 100]),
            axis=alt.Axis(titleFontSize=13, labelFontSize=11)),
    size=alt.Size('Enrollment:Q',
                  title='% Enrollment',
                  scale=alt.Scale(range=[100, 1000]),
                  legend=alt.Legend(titleFontSize=11, labelFontSize=10)),
    color=alt.Color('Region:N',
                    scale=alt.Scale(
                        domain=list(region_colors.keys()),
                        range=list(region_colors.values())
                    ),
                    legend=alt.Legend(titleFontSize=12, labelFontSize=11)),
    tooltip=[
        alt.Tooltip('State:N', title='State'),
        alt.Tooltip('Region:N', title='Region'),
        alt.Tooltip('Standard:N', title='Standard'),
        alt.Tooltip('Reading:Q', title='Reading %', format='.1f'),
        alt.Tooltip('Math:Q', title='Math %', format='.1f'),
        alt.Tooltip('Enrollment:Q', title='Enrollment %', format='.1f')
    ]
).add_params(
    selection
).transform_filter(
    selection
)

# Add diagonal reference line (where reading = math)
diagonal_data = pd.DataFrame({'x': [0, 100], 'y': [0, 100]})
diagonal = alt.Chart(diagonal_data).mark_line(
    strokeDash=[5, 5],
    color='gray',
    opacity=0.5
).encode(
    x='x:Q',
    y='y:Q'
)

# Combine
final_chart = (diagonal + bubble_chart).properties(
    width=700,
    height=600,
    title={
        "text": "Reading vs Math Proficiency Across Indian States (2024)",
        "subtitle": "Points above diagonal = stronger math skills | Points below = stronger reading skills | Select standard from dropdown",
        "fontSize": 18,
        "fontWeight": "bold",
        "anchor": "start",
        "subtitleFontSize": 13,
        "subtitleColor": "gray"
    }
).configure_view(
    strokeWidth=0
).configure_axis(
    gridColor='#e0e0e0'
)



final_chart

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy


In [8]:
import pandas as pd
import altair as alt
from scipy import stats

# Read data
df = aser_2024
df = df[df['State'] != 'All India']

# Define regions
region_map = {
    'Andhra Pradesh': 'South', 'Arunachal Pradesh': 'Northeast',
    'Assam': 'Northeast', 'Bihar': 'North', 'Chhattisgarh': 'Central',
    'Gujarat': 'West', 'Haryana': 'North', 'Himachal Pradesh': 'North',
    'Jammu and Kashmir': 'North', 'Jharkhand': 'East', 'Karnataka': 'South',
    'Kerala': 'South', 'Madhya Pradesh': 'Central', 'Maharashtra': 'West',
    'Meghalaya': 'Northeast', 'Mizoram': 'Northeast', 'Nagaland': 'Northeast',
    'Odisha': 'East', 'Punjab': 'North', 'Rajasthan': 'North',
    'Sikkim': 'Northeast', 'Tamil Nadu': 'South', 'Telangana': 'South',
    'Tripura': 'Northeast', 'Uttar Pradesh': 'North', 'Uttarakhand': 'North',
    'West Bengal': 'East'
}

df['Region'] = df['State'].map(region_map)

# Get 2024 columns
std3_reading_col = [c for c in df.columns if '2024' in c and 'Std III' in c and 'read' in c][0]
std5_reading_col = [c for c in df.columns if '2024' in c and 'Std V' in c and 'read' in c][0]
std8_reading_col = [c for c in df.columns if '2024' in c and 'Std VIII' in c and 'read' in c][0]

std3_math_col = [c for c in df.columns if '2024' in c and 'Std III' in c and 'subtraction' in c][0]
std5_math_col = [c for c in df.columns if '2024' in c and 'Std V' in c and 'division' in c][0]
std8_math_col = [c for c in df.columns if '2024' in c and 'Std VIII' in c and 'division' in c][0]

enrollment_col = [c for c in df.columns if '2024' in c and 'enrolled in govt schools' in c][0]

# Create dataframe
comparison_data = pd.DataFrame({
    'State': df['State'],
    'Region': df['Region'],
    'Enrollment': df[enrollment_col],
    'Std_III_Reading': df[std3_reading_col],
    'Std_III_Math': df[std3_math_col],
    'Std_V_Reading': df[std5_reading_col],
    'Std_V_Math': df[std5_math_col],
    'Std_VIII_Reading': df[std8_reading_col],
    'Std_VIII_Math': df[std8_math_col]
})

# Reshape for each standard
std3_data = comparison_data[['State', 'Region', 'Enrollment', 'Std_III_Reading', 'Std_III_Math']].copy()
std3_data.columns = ['State', 'Region', 'Enrollment', 'Reading', 'Math']
std3_data['Standard'] = 'Std III'

std5_data = comparison_data[['State', 'Region', 'Enrollment', 'Std_V_Reading', 'Std_V_Math']].copy()
std5_data.columns = ['State', 'Region', 'Enrollment', 'Reading', 'Math']
std5_data['Standard'] = 'Std V'

std8_data = comparison_data[['State', 'Region', 'Enrollment', 'Std_VIII_Reading', 'Std_VIII_Math']].copy()
std8_data.columns = ['State', 'Region', 'Enrollment', 'Reading', 'Math']
std8_data['Standard'] = 'Std VIII'

# Combine
all_data = pd.concat([std3_data, std5_data, std8_data])

# Calculate R² for each standard
correlations = []
for std in ['Std III', 'Std V', 'Std VIII']:
    std_data = all_data[all_data['Standard'] == std].dropna(subset=['Reading', 'Math'])
    if len(std_data) > 0:
        r, p = stats.pearsonr(std_data['Reading'], std_data['Math'])
        correlations.append({
            'Standard': std,
            'r_squared': r**2,
            'r_text': f'R² = {r**2:.3f}',
            'x_pos': 5,
            'y_pos': 95
        })

corr_df = pd.DataFrame(correlations)

# Your color palette
region_colors = {
    'North': '#77ba99',
    'South': '#ba7798',
    'East': '#9d8bb3',
    'West': '#b39d8b',
    'Central': '#8bb3a6',
    'Northeast': '#b3a68b'
}

# Create dropdown
dropdown = alt.binding_select(
    options=['Std III', 'Std V', 'Std VIII'],
    name='Standard: '
)
selection = alt.selection_point(
    fields=['Standard'],
    bind=dropdown,
    value='Std V'
)

# Bubble chart
bubble_chart = alt.Chart(all_data).mark_circle(
    opacity=0.7,
    stroke='white',
    strokeWidth=1
).encode(
    x=alt.X('Reading:Q',
            title='% Reading Proficiency',
            scale=alt.Scale(domain=[0, 100]),
            axis=alt.Axis(titleFontSize=13, labelFontSize=11)),
    y=alt.Y('Math:Q',
            title='% Math Proficiency',
            scale=alt.Scale(domain=[0, 100]),
            axis=alt.Axis(titleFontSize=13, labelFontSize=11)),
    size=alt.Size('Enrollment:Q',
                  title='% Enrollment',
                  scale=alt.Scale(range=[100, 1000]),
                  legend=alt.Legend(titleFontSize=11, labelFontSize=10)),
    color=alt.Color('Region:N',
                    scale=alt.Scale(
                        domain=list(region_colors.keys()),
                        range=list(region_colors.values())
                    ),
                    legend=alt.Legend(titleFontSize=12, labelFontSize=11)),
    tooltip=[
        alt.Tooltip('State:N', title='State'),
        alt.Tooltip('Region:N', title='Region'),
        alt.Tooltip('Standard:N', title='Standard'),
        alt.Tooltip('Reading:Q', title='Reading %', format='.1f'),
        alt.Tooltip('Math:Q', title='Math %', format='.1f'),
        alt.Tooltip('Enrollment:Q', title='Enrollment %', format='.1f')
    ]
).add_params(selection).transform_filter(selection)

# Regression line
regression = alt.Chart(all_data).mark_line(
    color='#ba7798',
    strokeWidth=3,
    opacity=0.8
).encode(
    x='Reading:Q',
    y='Math:Q'
).transform_filter(selection).transform_regression('Reading', 'Math', method='linear')

# R² text annotation (updates with dropdown)
r_squared_text = alt.Chart(corr_df).mark_text(
    align='left',
    baseline='top',
    dx=0,
    dy=0,
    fontSize=16,
    fontWeight='bold',
    color='#ba7798'
).encode(
    x=alt.value(20),  # Pixel position from left
    y=alt.value(20),  # Pixel position from top
    text='r_text:N'
).transform_filter(selection)

# Combine all layers
final_chart = (bubble_chart + regression + r_squared_text).properties(
    width=700,
    height=600,
    title={
        "text": "Reading vs Math Proficiency Across Indian States (2024)",
        "subtitle": "Do states that excel at reading also excel at math? Trendline shows correlation. Select standard from dropdown.",
        "fontSize": 18,
        "fontWeight": "bold",
        "anchor": "start",
        "subtitleFontSize": 13,
        "subtitleColor": "gray"
    }
).configure_view(
    strokeWidth=0
).configure_axis(
    gridColor='#e0e0e0'
)

# Save

final_chart

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy


In [13]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from scipy import stats
import numpy as np

# ===================================================================
# LOAD AND PREPARE INFRASTRUCTURE DATA (2024)
# ===================================================================

# Read in school facility data for 2024
df_fac_2024 = pd.read_csv('/Users/trishapunamiya/Desktop/LSE/Data Viz/Project/Raw Data/2024/100_fac.csv')
df_profile_2024 = pd.read_csv('/Users/trishapunamiya/Desktop/LSE/Data Viz/Project/Raw Data/2024/100_prof1.csv')

# Merge facility data with state info
df_sch_2024 = df_fac_2024.merge(
    df_profile_2024[['pseudocode', 'state']], 
    on='pseudocode', 
    how='left'
)

# Calculate infrastructure index - only count FUNCTIONAL facilities
state_infrastructure = df_sch_2024.groupby('state').agg({
    'electricity_availability': lambda x: (x == 1).sum() / len(x) * 100,
    'internet': lambda x: (x == 1).sum() / len(x) * 100,
    'library_availability': lambda x: (x == 1).sum() / len(x) * 100,
    'total_girls_func_toilet': lambda x: (x > 0).sum() / len(x) * 100,
    'pseudocode': 'count'
}).reset_index()

# Rename columns
state_infrastructure.rename(columns={
    'electricity_availability': 'pct_electricity_functional',
    'internet': 'pct_internet',
    'library_availability': 'pct_library',
    'total_girls_func_toilet': 'pct_girls_toilets',
    'pseudocode': 'num_schools'
}, inplace=True)

# Create composite infrastructure index
state_infrastructure['infrastructure_index'] = state_infrastructure[[
    'pct_electricity_functional', 'pct_internet', 'pct_library', 'pct_girls_toilets'
]].mean(axis=1)

# Clean state names
name_mapping = {
    "ANDAMAN & NICOBAR ISLANDS": "Andaman and Nicobar",
    "ANDHRA PRADESH": "Andhra Pradesh",
    "ARUNACHAL PRADESH": "Arunachal Pradesh",
    "ASSAM": "Assam",
    "BIHAR": "Bihar",
    "CHANDIGARH": "Chandigarh",
    "CHHATTISGARH": "Chhattisgarh",
    "DADRA & NAGAR HAVELI AND DAMAN & DIU": "Dadra and Nagar Haveli and Daman and Diu",
    "DELHI": "Delhi",
    "GOA": "Goa",
    "GUJARAT": "Gujarat",
    "HARYANA": "Haryana",
    "HIMACHAL PRADESH": "Himachal Pradesh",
    "JAMMU & KASHMIR": "Jammu and Kashmir",
    "JHARKHAND": "Jharkhand",
    "KARNATAKA": "Karnataka",
    "KERALA": "Kerala",
    "LADAKH": "Ladakh",
    "LAKSHADWEEP": "Lakshadweep",
    "MADHYA PRADESH": "Madhya Pradesh",
    "MAHARASHTRA": "Maharashtra",
    "MANIPUR": "Manipur",
    "MEGHALAYA": "Meghalaya",
    "MIZORAM": "Mizoram",
    "NAGALAND": "Nagaland",
    "ODISHA": "Odisha",
    "PUDUCHERRY": "Puducherry",
    "PUNJAB": "Punjab",
    "RAJASTHAN": "Rajasthan",
    "SIKKIM": "Sikkim",
    "TAMIL NADU": "Tamil Nadu",
    "TELANGANA": "Telangana",
    "TRIPURA": "Tripura",
    "UTTAR PRADESH": "Uttar Pradesh",
    "UTTARAKHAND": "Uttarakhand",
    "WEST BENGAL": "West Bengal"
}

state_infrastructure['state'] = state_infrastructure['state'].map(name_mapping)

# ===================================================================
# LOAD AND PREPARE ASER DATA (2024)
# ===================================================================

# Load your ASER data
aser_2024 = aser_2024  # Update path
df = aser_2024[aser_2024['State'] != 'All India'].copy()

# Get 2024 columns for reading across all standards
std3_reading_col = [c for c in df.columns if '2024' in c and 'Std III' in c and 'read' in c][0]
std5_reading_col = [c for c in df.columns if '2024' in c and 'Std V' in c and 'read' in c][0]
std8_reading_col = [c for c in df.columns if '2024' in c and 'Std VIII' in c and 'read' in c][0]

# Calculate composite learning score (average reading across standards)
df['learning_score'] = df[[std3_reading_col, std5_reading_col, std8_reading_col]].mean(axis=1)

# Keep only State and learning_score
learning_data = df[['State', 'learning_score']].copy()

# ===================================================================
# MERGE INFRASTRUCTURE AND LEARNING DATA
# ===================================================================

merged_data = state_infrastructure.merge(
    learning_data,
    left_on='state',
    right_on='State',
    how='inner'
)

# Check merge results
print(f"\nMerged data: {len(merged_data)} states")
print(f"States in infrastructure: {state_infrastructure['state'].nunique()}")
print(f"States in ASER: {learning_data['State'].nunique()}")

# ===================================================================
# CALCULATE CORRELATION STATISTICS
# ===================================================================

correlation = merged_data['infrastructure_index'].corr(merged_data['learning_score'])
r_squared = correlation ** 2

# Calculate trend line
slope, intercept, r_value, p_value, std_err = stats.linregress(
    merged_data['infrastructure_index'], 
    merged_data['learning_score']
)

# Create trend line data
x_trend = np.linspace(
    merged_data['infrastructure_index'].min(), 
    merged_data['infrastructure_index'].max(), 
    100
)
y_trend = slope * x_trend + intercept

# ===================================================================
# CREATE SCATTER PLOT
# ===================================================================

fig = go.Figure()

# Add scatter points
fig.add_trace(go.Scatter(
    x=merged_data['infrastructure_index'],
    y=merged_data['learning_score'],
    mode='markers+text',
    marker=dict(
        size=10,
        color='#3b82f6',
        opacity=0.7,
        line=dict(width=1, color='#1e40af')
    ),
    text=merged_data['state'],
    textposition='top center',
    textfont=dict(size=8, color='#374151'),
    name='States',
    hovertemplate='<b>%{text}</b><br>' +
                  'Infrastructure: %{x:.1f}%<br>' +
                  'Reading Score: %{y:.1f}%<br>' +
                  '<extra></extra>'
))

# Add trend line
fig.add_trace(go.Scatter(
    x=x_trend,
    y=y_trend,
    mode='lines',
    line=dict(color='#ef4444', width=2, dash='dash'),
    name=f'Trend Line (r={correlation:.3f})',
    hoverinfo='skip'
))

# Update layout
fig.update_layout(
    title=dict(
        text=f'School Infrastructure vs Learning Outcomes (2024)<br>' +
             f'<sub>Correlation: r = {correlation:.3f}, R² = {r_squared:.3f}, p = {p_value:.4f}</sub>',
        x=0.5,
        xanchor='center',
        font=dict(size=18, color='#1f2937')
    ),
    xaxis=dict(
        title=dict(
            text='Composite Infrastructure Index (%)<br><sub>(Electricity, Internet, Library, Girls\' Toilets)</sub>',
            font=dict(size=13, color='#374151')
        ),
        showgrid=True,
        gridcolor='#e5e7eb',
        range=[
            merged_data['infrastructure_index'].min() - 3, 
            merged_data['infrastructure_index'].max() + 3
        ]
    ),
    yaxis=dict(
        title=dict(
            text='Composite Reading Proficiency Score (%)<br><sub>(Average across Std III, V, VIII)</sub>',
            font=dict(size=13, color='#374151')
        ),
        showgrid=True,
        gridcolor='#e5e7eb',
        range=[
            merged_data['learning_score'].min() - 3, 
            merged_data['learning_score'].max() + 3
        ]
    ),
    plot_bgcolor='white',
    hovermode='closest',
    showlegend=True,
    legend=dict(
        x=0.02, 
        y=0.98, 
        bgcolor='rgba(255,255,255,0.9)',
        bordercolor='#d1d5db',
        borderwidth=1
    ),
    width=1000,
    height=650,
    margin=dict(t=100, b=100, l=100, r=40)
)

# Display the plot
fig.show()

# ===================================================================
# PRINT STATISTICS
# ===================================================================

print(f"\n{'='*70}")
print(f"CORRELATION ANALYSIS: INFRASTRUCTURE vs LEARNING OUTCOMES (2024)")
print(f"{'='*70}")
print(f"Correlation Coefficient (r): {correlation:.4f}")
print(f"R-squared (R²): {r_squared:.4f}")
print(f"P-value: {p_value:.6f}")
print(f"Standard Error: {std_err:.4f}")
print(f"\nTrend Line Equation: y = {slope:.3f}x + {intercept:.3f}")

print(f"\nInterpretation:")
if p_value < 0.05:
    print(f"✓ Statistically significant (p < 0.05)")
else:
    print(f"✗ Not statistically significant (p >= 0.05)")

if correlation > 0.7:
    print("✓ Strong positive correlation - better infrastructure is strongly")
    print("  associated with higher learning outcomes")
elif correlation > 0.4:
    print("✓ Moderate positive correlation - infrastructure improvements tend")
    print("  to correlate with better learning outcomes")
else:
    print("○ Weak positive correlation - infrastructure alone may not fully")
    print("  explain learning outcome differences")
print(f"{'='*70}\n")

# ===================================================================
# DISPLAY TOP AND BOTTOM PERFORMERS
# ===================================================================

print("\nTop 5 States (Infrastructure Index):")
print(merged_data.nlargest(5, 'infrastructure_index')[['state', 'infrastructure_index', 'learning_score']])

print("\nBottom 5 States (Infrastructure Index):")
print(merged_data.nsmallest(5, 'infrastructure_index')[['state', 'infrastructure_index', 'learning_score']])

print("\nTop 5 States (Learning Score):")
print(merged_data.nlargest(5, 'learning_score')[['state', 'infrastructure_index', 'learning_score']])

print("\nBottom 5 States (Learning Score):")
print(merged_data.nsmallest(5, 'learning_score')[['state', 'infrastructure_index', 'learning_score']])

# ===================================================================
# OPTIONAL: SAVE OUTPUT
# ===================================================================

# Save the figure
# fig.write_html("infrastructure_vs_learning_2024.html")
# fig.write_image("infrastructure_vs_learning_2024.png", width=1000, height=650)

# Save merged data
# merged_data.to_csv('infrastructure_learning_correlation_2024.csv', index=False)


Merged data: 27 states
States in infrastructure: 36
States in ASER: 27



CORRELATION ANALYSIS: INFRASTRUCTURE vs LEARNING OUTCOMES (2024)
Correlation Coefficient (r): 0.1897
R-squared (R²): 0.0360
P-value: 0.343259
Standard Error: 0.1349

Trend Line Equation: y = 0.130x + 39.530

Interpretation:
✗ Not statistically significant (p >= 0.05)
○ Weak positive correlation - infrastructure alone may not fully
  explain learning outcome differences


Top 5 States (Infrastructure Index):
             state  infrastructure_index  learning_score
0   Andhra Pradesh             99.039010       36.533333
11          Kerala             96.782382       65.366667
5          Gujarat             95.969918       49.333333
17          Odisha             94.472103       58.733333
6          Haryana             93.423853       63.400000

Bottom 5 States (Infrastructure Index):
                state  infrastructure_index  learning_score
14          Meghalaya             35.375677       45.866667
1   Arunachal Pradesh             49.899350       45.466667
23            Tripura    

In [14]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from scipy import stats
import numpy as np

# ===================================================================
# LOAD AND PREPARE INFRASTRUCTURE DATA (2024)
# ===================================================================

# Read in school facility data for 2024
df_fac_2024 = pd.read_csv('/Users/trishapunamiya/Desktop/LSE/Data Viz/Project/Raw Data/2024/100_fac.csv')
df_profile_2024 = pd.read_csv('/Users/trishapunamiya/Desktop/LSE/Data Viz/Project/Raw Data/2024/100_prof1.csv')

# Merge facility data with state info
df_sch_2024 = df_fac_2024.merge(
    df_profile_2024[['pseudocode', 'state']], 
    on='pseudocode', 
    how='left'
)

# Calculate infrastructure index - only count FUNCTIONAL facilities
state_infrastructure = df_sch_2024.groupby('state').agg({
    'electricity_availability': lambda x: (x == 1).sum() / len(x) * 100,
    'internet': lambda x: (x == 1).sum() / len(x) * 100,
    'library_availability': lambda x: (x == 1).sum() / len(x) * 100,
    'total_girls_func_toilet': lambda x: (x > 0).sum() / len(x) * 100,
    'pseudocode': 'count'
}).reset_index()

# Rename columns
state_infrastructure.rename(columns={
    'electricity_availability': 'pct_electricity_functional',
    'internet': 'pct_internet',
    'library_availability': 'pct_library',
    'total_girls_func_toilet': 'pct_girls_toilets',
    'pseudocode': 'num_schools'
}, inplace=True)

# Create composite infrastructure index
state_infrastructure['infrastructure_index'] = state_infrastructure[[
    'pct_electricity_functional', 'pct_internet', 'pct_library', 'pct_girls_toilets'
]].mean(axis=1)

# Clean state names
name_mapping = {
    "ANDAMAN & NICOBAR ISLANDS": "Andaman and Nicobar",
    "ANDHRA PRADESH": "Andhra Pradesh",
    "ARUNACHAL PRADESH": "Arunachal Pradesh",
    "ASSAM": "Assam",
    "BIHAR": "Bihar",
    "CHANDIGARH": "Chandigarh",
    "CHHATTISGARH": "Chhattisgarh",
    "DADRA & NAGAR HAVELI AND DAMAN & DIU": "Dadra and Nagar Haveli and Daman and Diu",
    "DELHI": "Delhi",
    "GOA": "Goa",
    "GUJARAT": "Gujarat",
    "HARYANA": "Haryana",
    "HIMACHAL PRADESH": "Himachal Pradesh",
    "JAMMU & KASHMIR": "Jammu and Kashmir",
    "JHARKHAND": "Jharkhand",
    "KARNATAKA": "Karnataka",
    "KERALA": "Kerala",
    "LADAKH": "Ladakh",
    "LAKSHADWEEP": "Lakshadweep",
    "MADHYA PRADESH": "Madhya Pradesh",
    "MAHARASHTRA": "Maharashtra",
    "MANIPUR": "Manipur",
    "MEGHALAYA": "Meghalaya",
    "MIZORAM": "Mizoram",
    "NAGALAND": "Nagaland",
    "ODISHA": "Odisha",
    "PUDUCHERRY": "Puducherry",
    "PUNJAB": "Punjab",
    "RAJASTHAN": "Rajasthan",
    "SIKKIM": "Sikkim",
    "TAMIL NADU": "Tamil Nadu",
    "TELANGANA": "Telangana",
    "TRIPURA": "Tripura",
    "UTTAR PRADESH": "Uttar Pradesh",
    "UTTARAKHAND": "Uttarakhand",
    "WEST BENGAL": "West Bengal"
}

state_infrastructure['state'] = state_infrastructure['state'].map(name_mapping)

# ===================================================================
# LOAD AND PREPARE ASER DATA (2024)
# ===================================================================

# Load your ASER data
aser_2024 = aser_2024
df = aser_2024[aser_2024['State'] != 'All India'].copy()

# Get 2024 columns for reading across all standards
std3_reading_col = [c for c in df.columns if '2024' in c and 'Std III' in c and 'read' in c][0]
std5_reading_col = [c for c in df.columns if '2024' in c and 'Std V' in c and 'read' in c][0]
std8_reading_col = [c for c in df.columns if '2024' in c and 'Std VIII' in c and 'read' in c][0]

# Get 2024 columns for math across all standards
std3_math_col = [c for c in df.columns if '2024' in c and 'Std III' in c and 'subtraction' in c][0]
std5_math_col = [c for c in df.columns if '2024' in c and 'Std V' in c and 'division' in c][0]
std8_math_col = [c for c in df.columns if '2024' in c and 'Std VIII' in c and 'division' in c][0]

# Calculate composite scores (average across standards)
df['reading_score'] = df[[std3_reading_col, std5_reading_col, std8_reading_col]].mean(axis=1)
df['math_score'] = df[[std3_math_col, std5_math_col, std8_math_col]].mean(axis=1)

# Keep State and both scores
learning_data = df[['State', 'reading_score', 'math_score']].copy()

# ===================================================================
# MERGE INFRASTRUCTURE AND LEARNING DATA
# ===================================================================

merged_data = state_infrastructure.merge(
    learning_data,
    left_on='state',
    right_on='State',
    how='inner'
)

# Check merge results
print(f"\nMerged data: {len(merged_data)} states")
print(f"States in infrastructure: {state_infrastructure['state'].nunique()}")
print(f"States in ASER: {learning_data['State'].nunique()}")

# ===================================================================
# FUNCTION TO CREATE CORRELATION PLOT
# ===================================================================

def create_correlation_plot(data, x_col, y_col, title_subject, y_label):
    """Create a correlation scatter plot"""
    
    # Calculate correlation statistics
    correlation = data[x_col].corr(data[y_col])
    r_squared = correlation ** 2
    
    # Calculate trend line
    slope, intercept, r_value, p_value, std_err = stats.linregress(data[x_col], data[y_col])
    
    # Create trend line data
    x_trend = np.linspace(data[x_col].min(), data[x_col].max(), 100)
    y_trend = slope * x_trend + intercept
    
    # Create figure
    fig = go.Figure()
    
    # Add scatter points
    fig.add_trace(go.Scatter(
        x=data[x_col],
        y=data[y_col],
        mode='markers+text',
        marker=dict(
            size=10,
            color='#3b82f6',
            opacity=0.7,
            line=dict(width=1, color='#1e40af')
        ),
        text=data['state'],
        textposition='top center',
        textfont=dict(size=8, color='#374151'),
        name='States',
        hovertemplate='<b>%{text}</b><br>' +
                      'Infrastructure: %{x:.1f}%<br>' +
                      f'{title_subject} Score: ' + '%{y:.1f}%<br>' +
                      '<extra></extra>'
    ))
    
    # Add trend line
    fig.add_trace(go.Scatter(
        x=x_trend,
        y=y_trend,
        mode='lines',
        line=dict(color='#ef4444', width=2, dash='dash'),
        name=f'Trend Line (r={correlation:.3f})',
        hoverinfo='skip'
    ))
    
    # Update layout
    fig.update_layout(
        title=dict(
            text=f'School Infrastructure vs {title_subject} Outcomes (2024)<br>' +
                 f'<sub>Correlation: r = {correlation:.3f}, R² = {r_squared:.3f}, p = {p_value:.4f}</sub>',
            x=0.5,
            xanchor='center',
            font=dict(size=18, color='#1f2937')
        ),
        xaxis=dict(
            title=dict(
                text='Composite Infrastructure Index (%)<br><sub>(Electricity, Internet, Library, Girls\' Toilets)</sub>',
                font=dict(size=13, color='#374151')
            ),
            showgrid=True,
            gridcolor='#e5e7eb',
            range=[data[x_col].min() - 3, data[x_col].max() + 3]
        ),
        yaxis=dict(
            title=dict(
                text=f'{y_label}<br><sub>(Average across Std III, V, VIII)</sub>',
                font=dict(size=13, color='#374151')
            ),
            showgrid=True,
            gridcolor='#e5e7eb',
            range=[data[y_col].min() - 3, data[y_col].max() + 3]
        ),
        plot_bgcolor='white',
        hovermode='closest',
        showlegend=True,
        legend=dict(
            x=0.02, 
            y=0.98, 
            bgcolor='rgba(255,255,255,0.9)',
            bordercolor='#d1d5db',
            borderwidth=1
        ),
        width=1000,
        height=650,
        margin=dict(t=100, b=100, l=100, r=40)
    )
    
    return fig, correlation, r_squared, p_value, slope, intercept, std_err

# ===================================================================
# CREATE READING CORRELATION PLOT
# ===================================================================

print("\n" + "="*70)
print("READING OUTCOMES ANALYSIS")
print("="*70)

fig_reading, corr_reading, r2_reading, p_reading, slope_reading, intercept_reading, stderr_reading = create_correlation_plot(
    merged_data,
    'infrastructure_index',
    'reading_score',
    'Reading',
    'Composite Reading Proficiency Score (%)'
)

fig_reading.show()

# Print statistics
print(f"Correlation Coefficient (r): {corr_reading:.4f}")
print(f"R-squared (R²): {r2_reading:.4f}")
print(f"P-value: {p_reading:.6f}")
print(f"Standard Error: {stderr_reading:.4f}")
print(f"Trend Line Equation: y = {slope_reading:.3f}x + {intercept_reading:.3f}")

print(f"\nInterpretation:")
if p_reading < 0.05:
    print(f"✓ Statistically significant (p < 0.05)")
else:
    print(f"✗ Not statistically significant (p >= 0.05)")

if corr_reading > 0.7:
    print("✓ Strong positive correlation")
elif corr_reading > 0.4:
    print("✓ Moderate positive correlation")
else:
    print("○ Weak positive correlation")

# ===================================================================
# CREATE MATH CORRELATION PLOT
# ===================================================================

print("\n" + "="*70)
print("MATH OUTCOMES ANALYSIS")
print("="*70)

fig_math, corr_math, r2_math, p_math, slope_math, intercept_math, stderr_math = create_correlation_plot(
    merged_data,
    'infrastructure_index',
    'math_score',
    'Math',
    'Composite Math Proficiency Score (%)'
)

fig_math.show()

# Print statistics
print(f"Correlation Coefficient (r): {corr_math:.4f}")
print(f"R-squared (R²): {r2_math:.4f}")
print(f"P-value: {p_math:.6f}")
print(f"Standard Error: {stderr_math:.4f}")
print(f"Trend Line Equation: y = {slope_math:.3f}x + {intercept_math:.3f}")

print(f"\nInterpretation:")
if p_math < 0.05:
    print(f"✓ Statistically significant (p < 0.05)")
else:
    print(f"✗ Not statistically significant (p >= 0.05)")

if corr_math > 0.7:
    print("✓ Strong positive correlation")
elif corr_math > 0.4:
    print("✓ Moderate positive correlation")
else:
    print("○ Weak positive correlation")

# ===================================================================
# COMPARISON SUMMARY
# ===================================================================

print("\n" + "="*70)
print("COMPARISON: READING vs MATH CORRELATIONS")
print("="*70)
print(f"Reading Correlation (r): {corr_reading:.4f} | Math Correlation (r): {corr_math:.4f}")
print(f"Reading R²: {r2_reading:.4f} | Math R²: {r2_math:.4f}")
print(f"Reading p-value: {p_reading:.4f} | Math p-value: {p_math:.4f}")

if abs(corr_reading - corr_math) < 0.05:
    print("\nBoth subjects show similar weak correlations with infrastructure.")
elif corr_math > corr_reading:
    print(f"\nMath shows a stronger correlation with infrastructure (+{corr_math - corr_reading:.3f}).")
else:
    print(f"\nReading shows a stronger correlation with infrastructure (+{corr_reading - corr_math:.3f}).")

print("="*70)

# ===================================================================
# OPTIONAL: SAVE OUTPUT
# ===================================================================

# Save the figures
# fig_reading.write_html("infrastructure_vs_reading_2024.html")
# fig_math.write_html("infrastructure_vs_math_2024.html")
# fig_reading.write_image("infrastructure_vs_reading_2024.png", width=1000, height=650)
# fig_math.write_image("infrastructure_vs_math_2024.png", width=1000, height=650)

# Save merged data
# merged_data.to_csv('infrastructure_learning_correlation_2024.csv', index=False)


Merged data: 27 states
States in infrastructure: 36
States in ASER: 27

READING OUTCOMES ANALYSIS


Correlation Coefficient (r): 0.1897
R-squared (R²): 0.0360
P-value: 0.343259
Standard Error: 0.1349
Trend Line Equation: y = 0.130x + 39.530

Interpretation:
✗ Not statistically significant (p >= 0.05)
○ Weak positive correlation

MATH OUTCOMES ANALYSIS


Correlation Coefficient (r): 0.2150
R-squared (R²): 0.0462
P-value: 0.281437
Standard Error: 0.1350
Trend Line Equation: y = 0.149x + 23.909

Interpretation:
✗ Not statistically significant (p >= 0.05)
○ Weak positive correlation

COMPARISON: READING vs MATH CORRELATIONS
Reading Correlation (r): 0.1897 | Math Correlation (r): 0.2150
Reading R²: 0.0360 | Math R²: 0.0462
Reading p-value: 0.3433 | Math p-value: 0.2814

Both subjects show similar weak correlations with infrastructure.


In [21]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from scipy import stats
import numpy as np

# ===================================================================
# LOAD AND PREPARE INFRASTRUCTURE DATA (2024)
# ===================================================================

# Read in school facility data for 2024
df_fac_2024 = pd.read_csv('/Users/trishapunamiya/Desktop/LSE/Data Viz/Project/Raw Data/2024/100_fac.csv')
df_profile_2024 = pd.read_csv('/Users/trishapunamiya/Desktop/LSE/Data Viz/Project/Raw Data/2024/100_prof1.csv')

# Merge facility data with state info
df_sch_2024 = df_fac_2024.merge(
    df_profile_2024[['pseudocode', 'state']], 
    on='pseudocode', 
    how='left'
)

# Calculate infrastructure index - only count FUNCTIONAL facilities
state_infrastructure = df_sch_2024.groupby('state').agg({
    'electricity_availability': lambda x: (x == 1).sum() / len(x) * 100,
    'internet': lambda x: (x == 1).sum() / len(x) * 100,
    'library_availability': lambda x: (x == 1).sum() / len(x) * 100,
    'total_girls_func_toilet': lambda x: (x > 0).sum() / len(x) * 100,
    'pseudocode': 'count'
}).reset_index()

# Rename columns
state_infrastructure.rename(columns={
    'electricity_availability': 'pct_electricity_functional',
    'internet': 'pct_internet',
    'library_availability': 'pct_library',
    'total_girls_func_toilet': 'pct_girls_toilets',
    'pseudocode': 'num_schools'
}, inplace=True)

# Create composite infrastructure index
state_infrastructure['infrastructure_index'] = state_infrastructure[[
    'pct_electricity_functional', 'pct_internet', 'pct_library', 'pct_girls_toilets'
]].mean(axis=1)

# Clean state names
name_mapping = {
    "ANDAMAN & NICOBAR ISLANDS": "Andaman and Nicobar",
    "ANDHRA PRADESH": "Andhra Pradesh",
    "ARUNACHAL PRADESH": "Arunachal Pradesh",
    "ASSAM": "Assam",
    "BIHAR": "Bihar",
    "CHANDIGARH": "Chandigarh",
    "CHHATTISGARH": "Chhattisgarh",
    "DADRA & NAGAR HAVELI AND DAMAN & DIU": "Dadra and Nagar Haveli and Daman and Diu",
    "DELHI": "Delhi",
    "GOA": "Goa",
    "GUJARAT": "Gujarat",
    "HARYANA": "Haryana",
    "HIMACHAL PRADESH": "Himachal Pradesh",
    "JAMMU & KASHMIR": "Jammu and Kashmir",
    "JHARKHAND": "Jharkhand",
    "KARNATAKA": "Karnataka",
    "KERALA": "Kerala",
    "LADAKH": "Ladakh",
    "LAKSHADWEEP": "Lakshadweep",
    "MADHYA PRADESH": "Madhya Pradesh",
    "MAHARASHTRA": "Maharashtra",
    "MANIPUR": "Manipur",
    "MEGHALAYA": "Meghalaya",
    "MIZORAM": "Mizoram",
    "NAGALAND": "Nagaland",
    "ODISHA": "Odisha",
    "PUDUCHERRY": "Puducherry",
    "PUNJAB": "Punjab",
    "RAJASTHAN": "Rajasthan",
    "SIKKIM": "Sikkim",
    "TAMIL NADU": "Tamil Nadu",
    "TELANGANA": "Telangana",
    "TRIPURA": "Tripura",
    "UTTAR PRADESH": "Uttar Pradesh",
    "UTTARAKHAND": "Uttarakhand",
    "WEST BENGAL": "West Bengal"
}

state_infrastructure['state'] = state_infrastructure['state'].map(name_mapping)

# ===================================================================
# LOAD AND PREPARE ASER DATA (2024)
# ===================================================================

# Load your ASER data
aser_2024 = aser_2024
df = aser_2024[aser_2024['State'] != 'All India'].copy()

# Get 2024 columns for reading across all standards
std3_reading_col = [c for c in df.columns if '2024' in c and 'Std III' in c and 'read' in c][0]
std5_reading_col = [c for c in df.columns if '2024' in c and 'Std V' in c and 'read' in c][0]
std8_reading_col = [c for c in df.columns if '2024' in c and 'Std VIII' in c and 'read' in c][0]

# Get 2024 columns for math across all standards
std3_math_col = [c for c in df.columns if '2024' in c and 'Std III' in c and 'subtraction' in c][0]
std5_math_col = [c for c in df.columns if '2024' in c and 'Std V' in c and 'division' in c][0]
std8_math_col = [c for c in df.columns if '2024' in c and 'Std VIII' in c and 'division' in c][0]

# Calculate composite scores (average across standards)
df['reading_score'] = df[[std3_reading_col, std5_reading_col, std8_reading_col]].mean(axis=1)
df['math_score'] = df[[std3_math_col, std5_math_col, std8_math_col]].mean(axis=1)

# Keep State and both scores
learning_data = df[['State', 'reading_score', 'math_score']].copy()

# ===================================================================
# CALCULATE STUDENT-TEACHER RATIO BY STATE
# ===================================================================

# Load enrollment data
df_enr_2024 = pd.read_csv('/Users/trishapunamiya/Desktop/LSE/Data Viz/Project/Raw Data/2024/100_enr1.csv')

# Merge enrollment with state info
df_enr_merged = df_enr_2024.merge(
    df_profile_2024[['pseudocode', 'state']], 
    on='pseudocode', 
    how='left'
)

# Class columns for boys and girls
boy_cols = ['cpp_b', 'c1_b', 'c2_b', 'c3_b', 'c4_b', 'c5_b', 'c6_b', 
            'c7_b', 'c8_b', 'c9_b', 'c10_b', 'c11_b', 'c12_b']
girl_cols = ['cpp_g', 'c1_g', 'c2_g', 'c3_g', 'c4_g', 'c5_g', 'c6_g', 
             'c7_g', 'c8_g', 'c9_g', 'c10_g', 'c11_g', 'c12_g']

# Calculate total students for each school
df_enr_merged['total_boys'] = df_enr_merged[boy_cols].sum(axis=1)
df_enr_merged['total_girls'] = df_enr_merged[girl_cols].sum(axis=1)
df_enr_merged['total_students'] = df_enr_merged['total_boys'] + df_enr_merged['total_girls']

# Aggregate students by state
state_students = df_enr_merged.groupby('state').agg({
    'total_students': 'sum'
}).reset_index()

# Load teacher data
df_teacher_2024 = pd.read_csv('/Users/trishapunamiya/Desktop/LSE/Data Viz/Project/Raw Data/2024/100_tch.csv')

# Merge teacher data with state info
df_teacher_merged = df_teacher_2024.merge(
    df_profile_2024[['pseudocode', 'state']], 
    on='pseudocode', 
    how='left'
)

# Aggregate teachers by state
state_teachers = df_teacher_merged.groupby('state').agg({
    'total_tch': 'sum'
}).reset_index()

# Merge teachers and students
state_str = state_teachers.merge(state_students, on='state', how='inner')

# Calculate student-teacher ratio
state_str['student_teacher_ratio'] = state_str['total_students'] / state_str['total_tch']

# Clean state names using same mapping
state_str['state'] = state_str['state'].map(name_mapping)

print(f"\nStudent-Teacher Ratio data: {len(state_str)} states")
print("\nSample STR data:")
print(state_str[['state', 'student_teacher_ratio', 'total_students', 'total_tch']].head(10))
print("\nSTR Summary Statistics:")
print(state_str['student_teacher_ratio'].describe())

# ===================================================================
# MERGE INFRASTRUCTURE AND LEARNING DATA
# ===================================================================

merged_data = state_infrastructure.merge(
    learning_data,
    left_on='state',
    right_on='State',
    how='inner'
)

# Check merge results
print(f"\nMerged infrastructure data: {len(merged_data)} states")
print(f"States in infrastructure: {state_infrastructure['state'].nunique()}")
print(f"States in ASER: {learning_data['State'].nunique()}")

# ===================================================================
# MERGE STR DATA WITH LEARNING DATA
# ===================================================================

merged_str_data = state_str.merge(
    learning_data,
    left_on='state',
    right_on='State',
    how='inner'
)

print(f"\nMerged STR data: {len(merged_str_data)} states")

# ===================================================================
# FUNCTION TO CREATE CORRELATION PLOT
# ===================================================================

def create_correlation_plot(data, x_col, y_col, title_subject, y_label):
    """Create a correlation scatter plot"""
    
    # Calculate correlation statistics
    correlation = data[x_col].corr(data[y_col])
    r_squared = correlation ** 2
    
    # Calculate trend line
    slope, intercept, r_value, p_value, std_err = stats.linregress(data[x_col], data[y_col])
    
    # Create trend line data
    x_trend = np.linspace(data[x_col].min(), data[x_col].max(), 100)
    y_trend = slope * x_trend + intercept
    
    # Create figure
    fig = go.Figure()
    
    # Add scatter points
    fig.add_trace(go.Scatter(
        x=data[x_col],
        y=data[y_col],
        mode='markers+text',
        marker=dict(
            size=10,
            color='#3b82f6',
            opacity=0.7,
            line=dict(width=1, color='#1e40af')
        ),
        text=data['state'],
        textposition='top center',
        textfont=dict(size=8, color='#374151'),
        name='States',
        hovertemplate='<b>%{text}</b><br>' +
                      'Infrastructure: %{x:.1f}%<br>' +
                      f'{title_subject} Score: ' + '%{y:.1f}%<br>' +
                      '<extra></extra>'
    ))
    
    # Add trend line
    fig.add_trace(go.Scatter(
        x=x_trend,
        y=y_trend,
        mode='lines',
        line=dict(color='#ef4444', width=2, dash='dash'),
        name=f'Trend Line (r={correlation:.3f})',
        hoverinfo='skip'
    ))
    
    # Update layout
    fig.update_layout(
        title=dict(
            text=f'School Infrastructure vs {title_subject} Outcomes (2024)<br>' +
                 f'<sub>Correlation: r = {correlation:.3f}, R² = {r_squared:.3f}, p = {p_value:.4f}</sub>',
            x=0.5,
            xanchor='center',
            font=dict(size=18, color='#1f2937')
        ),
        xaxis=dict(
            title=dict(
                text='Composite Infrastructure Index (%)<br><sub>(Electricity, Internet, Library, Girls\' Toilets)</sub>',
                font=dict(size=13, color='#374151')
            ),
            showgrid=True,
            gridcolor='#e5e7eb',
            range=[data[x_col].min() - 3, data[x_col].max() + 3]
        ),
        yaxis=dict(
            title=dict(
                text=f'{y_label}<br><sub>(Average across Std III, V, VIII)</sub>',
                font=dict(size=13, color='#374151')
            ),
            showgrid=True,
            gridcolor='#e5e7eb',
            range=[data[y_col].min() - 3, data[y_col].max() + 3]
        ),
        plot_bgcolor='white',
        hovermode='closest',
        showlegend=True,
        legend=dict(
            x=0.02, 
            y=0.98, 
            bgcolor='rgba(255,255,255,0.9)',
            bordercolor='#d1d5db',
            borderwidth=1
        ),
        width=1000,
        height=650,
        margin=dict(t=100, b=100, l=100, r=40)
    )
    
    return fig, correlation, r_squared, p_value, slope, intercept, std_err

# ===================================================================
# CREATE READING CORRELATION PLOT
# ===================================================================

print("\n" + "="*70)
print("READING OUTCOMES ANALYSIS")
print("="*70)

fig_reading, corr_reading, r2_reading, p_reading, slope_reading, intercept_reading, stderr_reading = create_correlation_plot(
    merged_data,
    'infrastructure_index',
    'reading_score',
    'Reading',
    'Composite Reading Proficiency Score (%)'
)

fig_reading.show()

# Print statistics
print(f"Correlation Coefficient (r): {corr_reading:.4f}")
print(f"R-squared (R²): {r2_reading:.4f}")
print(f"P-value: {p_reading:.6f}")
print(f"Standard Error: {stderr_reading:.4f}")
print(f"Trend Line Equation: y = {slope_reading:.3f}x + {intercept_reading:.3f}")

print(f"\nInterpretation:")
if p_reading < 0.05:
    print(f"✓ Statistically significant (p < 0.05)")
else:
    print(f"✗ Not statistically significant (p >= 0.05)")

if corr_reading > 0.7:
    print("✓ Strong positive correlation")
elif corr_reading > 0.4:
    print("✓ Moderate positive correlation")
else:
    print("○ Weak positive correlation")

# ===================================================================
# CREATE MATH CORRELATION PLOT
# ===================================================================

print("\n" + "="*70)
print("MATH OUTCOMES ANALYSIS")
print("="*70)

fig_math, corr_math, r2_math, p_math, slope_math, intercept_math, stderr_math = create_correlation_plot(
    merged_data,
    'infrastructure_index',
    'math_score',
    'Math',
    'Composite Math Proficiency Score (%)'
)

fig_math.show()

# Print statistics
print(f"Correlation Coefficient (r): {corr_math:.4f}")
print(f"R-squared (R²): {r2_math:.4f}")
print(f"P-value: {p_math:.6f}")
print(f"Standard Error: {stderr_math:.4f}")
print(f"Trend Line Equation: y = {slope_math:.3f}x + {intercept_math:.3f}")

print(f"\nInterpretation:")
if p_math < 0.05:
    print(f"✓ Statistically significant (p < 0.05)")
else:
    print(f"✗ Not statistically significant (p >= 0.05)")

if corr_math > 0.7:
    print("✓ Strong positive correlation")
elif corr_math > 0.4:
    print("✓ Moderate positive correlation")
else:
    print("○ Weak positive correlation")

# ===================================================================
# CREATE STUDENT-TEACHER RATIO PLOTS
# ===================================================================

def create_str_correlation_plot(data, x_col, y_col, title_subject, y_label):
    """Create a correlation scatter plot for student-teacher ratio"""
    
    # Calculate correlation statistics
    correlation = data[x_col].corr(data[y_col])
    r_squared = correlation ** 2
    
    # Calculate trend line
    slope, intercept, r_value, p_value, std_err = stats.linregress(data[x_col], data[y_col])
    
    # Create trend line data
    x_trend = np.linspace(data[x_col].min(), data[x_col].max(), 100)
    y_trend = slope * x_trend + intercept
    
    # Create figure
    fig = go.Figure()
    
    # Add scatter points
    fig.add_trace(go.Scatter(
        x=data[x_col],
        y=data[y_col],
        mode='markers+text',
        marker=dict(
            size=10,
            color='#10b981',
            opacity=0.7,
            line=dict(width=1, color='#059669')
        ),
        text=data['state'],
        textposition='top center',
        textfont=dict(size=8, color='#374151'),
        name='States',
        hovertemplate='<b>%{text}</b><br>' +
                      'Student-Teacher Ratio: %{x:.1f}<br>' +
                      f'{title_subject} Score: ' + '%{y:.1f}%<br>' +
                      '<extra></extra>'
    ))
    
    # Add trend line
    fig.add_trace(go.Scatter(
        x=x_trend,
        y=y_trend,
        mode='lines',
        line=dict(color='#ef4444', width=2, dash='dash'),
        name=f'Trend Line (r={correlation:.3f})',
        hoverinfo='skip'
    ))
    
    # Update layout
    fig.update_layout(
        title=dict(
            text=f'Student-Teacher Ratio vs {title_subject} Outcomes (2024)<br>' +
                 f'<sub>Correlation: r = {correlation:.3f}, R² = {r_squared:.3f}, p = {p_value:.4f}</sub>',
            x=0.5,
            xanchor='center',
            font=dict(size=18, color='#1f2937')
        ),
        xaxis=dict(
            title=dict(
                text='Student-Teacher Ratio<br><sub>(Lower is better - fewer students per teacher)</sub>',
                font=dict(size=13, color='#374151')
            ),
            showgrid=True,
            gridcolor='#e5e7eb'
        ),
        yaxis=dict(
            title=dict(
                text=f'{y_label}<br><sub>(Average across Std III, V, VIII)</sub>',
                font=dict(size=13, color='#374151')
            ),
            showgrid=True,
            gridcolor='#e5e7eb',
            range=[data[y_col].min() - 3, data[y_col].max() + 3]
        ),
        plot_bgcolor='white',
        hovermode='closest',
        showlegend=True,
        legend=dict(
            x=0.02, 
            y=0.98, 
            bgcolor='rgba(255,255,255,0.9)',
            bordercolor='#d1d5db',
            borderwidth=1
        ),
        width=1000,
        height=650,
        margin=dict(t=100, b=100, l=100, r=40)
    )
    
    return fig, correlation, r_squared, p_value, slope, intercept, std_err

# ===================================================================
# READING vs STUDENT-TEACHER RATIO
# ===================================================================

print("\n" + "="*70)
print("STUDENT-TEACHER RATIO vs READING OUTCOMES")
print("="*70)

fig_str_reading, corr_str_reading, r2_str_reading, p_str_reading, slope_str_reading, intercept_str_reading, stderr_str_reading = create_str_correlation_plot(
    merged_str_data,
    'student_teacher_ratio',
    'reading_score',
    'Reading',
    'Composite Reading Proficiency Score (%)'
)

fig_str_reading.show()

print(f"Correlation Coefficient (r): {corr_str_reading:.4f}")
print(f"R-squared (R²): {r2_str_reading:.4f}")
print(f"P-value: {p_str_reading:.6f}")
print(f"Standard Error: {stderr_str_reading:.4f}")

print(f"\nInterpretation:")
if p_str_reading < 0.05:
    print(f"✓ Statistically significant (p < 0.05)")
else:
    print(f"✗ Not statistically significant (p >= 0.05)")

if abs(corr_str_reading) > 0.7:
    print(f"✓ Strong {'negative' if corr_str_reading < 0 else 'positive'} correlation")
elif abs(corr_str_reading) > 0.4:
    print(f"✓ Moderate {'negative' if corr_str_reading < 0 else 'positive'} correlation")
else:
    print(f"○ Weak {'negative' if corr_str_reading < 0 else 'positive'} correlation")

# ===================================================================
# MATH vs STUDENT-TEACHER RATIO
# ===================================================================

print("\n" + "="*70)
print("STUDENT-TEACHER RATIO vs MATH OUTCOMES")
print("="*70)

fig_str_math, corr_str_math, r2_str_math, p_str_math, slope_str_math, intercept_str_math, stderr_str_math = create_str_correlation_plot(
    merged_str_data,
    'student_teacher_ratio',
    'math_score',
    'Math',
    'Composite Math Proficiency Score (%)'
)

fig_str_math.show()

print(f"Correlation Coefficient (r): {corr_str_math:.4f}")
print(f"R-squared (R²): {r2_str_math:.4f}")
print(f"P-value: {p_str_math:.6f}")
print(f"Standard Error: {stderr_str_math:.4f}")

print(f"\nInterpretation:")
if p_str_math < 0.05:
    print(f"✓ Statistically significant (p < 0.05)")
else:
    print(f"✗ Not statistically significant (p >= 0.05)")

if abs(corr_str_math) > 0.7:
    print(f"✓ Strong {'negative' if corr_str_math < 0 else 'positive'} correlation")
elif abs(corr_str_math) > 0.4:
    print(f"✓ Moderate {'negative' if corr_str_math < 0 else 'positive'} correlation")
else:
    print(f"○ Weak {'negative' if corr_str_math < 0 else 'positive'} correlation")

# ===================================================================
# FINAL COMPARISON SUMMARY
# ===================================================================

print("\n" + "="*70)
print("FINAL COMPARISON: ALL CORRELATIONS")
print("="*70)
print("\nINFRASTRUCTURE CORRELATIONS:")
print(f"  Reading: r = {corr_reading:.4f}, R² = {r2_reading:.4f}, p = {p_reading:.4f}")
print(f"  Math:    r = {corr_math:.4f}, R² = {r2_math:.4f}, p = {p_math:.4f}")

print("\nSTUDENT-TEACHER RATIO CORRELATIONS:")
print(f"  Reading: r = {corr_str_reading:.4f}, R² = {r2_str_reading:.4f}, p = {p_str_reading:.4f}")
print(f"  Math:    r = {corr_str_math:.4f}, R² = {r2_str_math:.4f}, p = {p_str_math:.4f}")

print("\nKEY FINDINGS:")
if abs(corr_str_reading) > abs(corr_reading) or abs(corr_str_math) > abs(corr_math):
    print("• Student-teacher ratio shows stronger correlation with outcomes than infrastructure")
else:
    print("• Infrastructure shows stronger correlation with outcomes than student-teacher ratio")

if p_str_reading < 0.05 or p_str_math < 0.05:
    print("• Student-teacher ratio correlations are statistically significant")
else:
    print("• Student-teacher ratio correlations are NOT statistically significant")
    
print("="*70)

# ===================================================================
# OPTIONAL: SAVE OUTPUT
# ===================================================================

# Save all figures
# fig_reading.write_html("infrastructure_vs_reading_2024.html")
# fig_math.write_html("infrastructure_vs_math_2024.html")
# fig_str_reading.write_html("str_vs_reading_2024.html")
# fig_str_math.write_html("str_vs_math_2024.html")

# Save merged data
# merged_data.to_csv('infrastructure_learning_correlation_2024.csv', index=False)
# merged_str_data.to_csv('str_learning_correlation_2024.csv', index=False)


Student-Teacher Ratio data: 36 states

Sample STR data:
                                      state  student_teacher_ratio  \
0                       Andaman and Nicobar              16.745108   
1                            Andhra Pradesh              46.886272   
2                         Arunachal Pradesh              26.142254   
3                                     Assam              44.475124   
4                                     Bihar              59.358718   
5                                Chandigarh              30.140750   
6                              Chhattisgarh              37.372101   
7  Dadra and Nagar Haveli and Daman and Diu              33.775340   
8                                     Delhi              34.531730   
9                                       Goa              29.438536   

   total_students  total_tch  
0          102681       6132  
1        16068910     342721  
2          656615      25117  
3        15142490     340471  
4        41997243

Correlation Coefficient (r): 0.1897
R-squared (R²): 0.0360
P-value: 0.343259
Standard Error: 0.1349
Trend Line Equation: y = 0.130x + 39.530

Interpretation:
✗ Not statistically significant (p >= 0.05)
○ Weak positive correlation

MATH OUTCOMES ANALYSIS


Correlation Coefficient (r): 0.2150
R-squared (R²): 0.0462
P-value: 0.281437
Standard Error: 0.1350
Trend Line Equation: y = 0.149x + 23.909

Interpretation:
✗ Not statistically significant (p >= 0.05)
○ Weak positive correlation

STUDENT-TEACHER RATIO vs READING OUTCOMES


Correlation Coefficient (r): -0.3008
R-squared (R²): 0.0905
P-value: 0.127367
Standard Error: 0.1724

Interpretation:
✗ Not statistically significant (p >= 0.05)
○ Weak negative correlation

STUDENT-TEACHER RATIO vs MATH OUTCOMES


Correlation Coefficient (r): -0.1099
R-squared (R²): 0.0121
P-value: 0.585384
Standard Error: 0.1809

Interpretation:
✗ Not statistically significant (p >= 0.05)
○ Weak negative correlation

FINAL COMPARISON: ALL CORRELATIONS

INFRASTRUCTURE CORRELATIONS:
  Reading: r = 0.1897, R² = 0.0360, p = 0.3433
  Math:    r = 0.2150, R² = 0.0462, p = 0.2814

STUDENT-TEACHER RATIO CORRELATIONS:
  Reading: r = -0.3008, R² = 0.0905, p = 0.1274
  Math:    r = -0.1099, R² = 0.0121, p = 0.5854

KEY FINDINGS:
• Student-teacher ratio shows stronger correlation with outcomes than infrastructure
• Student-teacher ratio correlations are NOT statistically significant


In [32]:
# ===================================================================
# MERGE INFRASTRUCTURE AND LEARNING DATA
# ===================================================================

merged_data = state_infrastructure.merge(
    learning_data,
    left_on='state',
    right_on='State',
    how='inner'
)

# Add total enrollment from student data for bubble size
merged_data = merged_data.merge(
    state_students[['state', 'total_students']],
    left_on='state',
    right_on='state',
    how='left'
)

# Check what columns we have
print(f"\nColumns in merged_data: {merged_data.columns.tolist()}")

# Check merge results
print(f"\nMerged infrastructure data: {len(merged_data)} states")
print(f"States in infrastructure: {state_infrastructure['state'].nunique()}")
print(f"States in ASER: {learning_data['State'].nunique()}")

# ===================================================================
# MERGE STR DATA WITH LEARNING DATA
# ===================================================================

merged_str_data = state_str.merge(
    learning_data,
    left_on='state',
    right_on='State',
    how='inner'
)

# Add total students for bubble size
merged_str_data = merged_str_data.merge(
    state_students[['state', 'total_students']],
    left_on='state',
    right_on='state',
    how='left',
    suffixes=('', '_y')
)

# Drop duplicate column if exists
if 'total_students_y' in merged_str_data.columns:
    merged_str_data = merged_str_data.drop('total_students_y', axis=1)

print(f"\nMerged STR data: {len(merged_str_data)} states")
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from scipy import stats
import numpy as np

# ===================================================================
# LOAD AND PREPARE INFRASTRUCTURE DATA (2024)
# ===================================================================

# Read in school facility data for 2024
df_fac_2024 = pd.read_csv('/Users/trishapunamiya/Desktop/LSE/Data Viz/Project/Raw Data/2024/100_fac.csv')
df_profile_2024 = pd.read_csv('/Users/trishapunamiya/Desktop/LSE/Data Viz/Project/Raw Data/2024/100_prof1.csv')

# Merge facility data with state info
df_sch_2024 = df_fac_2024.merge(
    df_profile_2024[['pseudocode', 'state']], 
    on='pseudocode', 
    how='left'
)

# Calculate infrastructure index - only count FUNCTIONAL facilities
state_infrastructure = df_sch_2024.groupby('state').agg({
    'electricity_availability': lambda x: (x == 1).sum() / len(x) * 100,
    'internet': lambda x: (x == 1).sum() / len(x) * 100,
    'library_availability': lambda x: (x == 1).sum() / len(x) * 100,
    'total_girls_func_toilet': lambda x: (x > 0).sum() / len(x) * 100,
    'pseudocode': 'count'
}).reset_index()

# Rename columns
state_infrastructure.rename(columns={
    'electricity_availability': 'pct_electricity_functional',
    'internet': 'pct_internet',
    'library_availability': 'pct_library',
    'total_girls_func_toilet': 'pct_girls_toilets',
    'pseudocode': 'num_schools'
}, inplace=True)

# Create composite infrastructure index
state_infrastructure['infrastructure_index'] = state_infrastructure[[
    'pct_electricity_functional', 'pct_internet', 'pct_library', 'pct_girls_toilets'
]].mean(axis=1)

# Clean state names
name_mapping = {
    "ANDAMAN & NICOBAR ISLANDS": "Andaman and Nicobar",
    "ANDHRA PRADESH": "Andhra Pradesh",
    "ARUNACHAL PRADESH": "Arunachal Pradesh",
    "ASSAM": "Assam",
    "BIHAR": "Bihar",
    "CHANDIGARH": "Chandigarh",
    "CHHATTISGARH": "Chhattisgarh",
    "DADRA & NAGAR HAVELI AND DAMAN & DIU": "Dadra and Nagar Haveli and Daman and Diu",
    "DELHI": "Delhi",
    "GOA": "Goa",
    "GUJARAT": "Gujarat",
    "HARYANA": "Haryana",
    "HIMACHAL PRADESH": "Himachal Pradesh",
    "JAMMU & KASHMIR": "Jammu and Kashmir",
    "JHARKHAND": "Jharkhand",
    "KARNATAKA": "Karnataka",
    "KERALA": "Kerala",
    "LADAKH": "Ladakh",
    "LAKSHADWEEP": "Lakshadweep",
    "MADHYA PRADESH": "Madhya Pradesh",
    "MAHARASHTRA": "Maharashtra",
    "MANIPUR": "Manipur",
    "MEGHALAYA": "Meghalaya",
    "MIZORAM": "Mizoram",
    "NAGALAND": "Nagaland",
    "ODISHA": "Odisha",
    "PUDUCHERRY": "Puducherry",
    "PUNJAB": "Punjab",
    "RAJASTHAN": "Rajasthan",
    "SIKKIM": "Sikkim",
    "TAMIL NADU": "Tamil Nadu",
    "TELANGANA": "Telangana",
    "TRIPURA": "Tripura",
    "UTTAR PRADESH": "Uttar Pradesh",
    "UTTARAKHAND": "Uttarakhand",
    "WEST BENGAL": "West Bengal"
}

state_infrastructure['state'] = state_infrastructure['state'].map(name_mapping)

# ===================================================================
# LOAD AND PREPARE ASER DATA (2024)
# ===================================================================

# Load your ASER data
df = aser_2024[aser_2024['State'] != 'All India'].copy()

# Get 2024 columns for reading across all standards
std3_reading_col = [c for c in df.columns if '2024' in c and 'Std III' in c and 'read' in c][0]
std5_reading_col = [c for c in df.columns if '2024' in c and 'Std V' in c and 'read' in c][0]
std8_reading_col = [c for c in df.columns if '2024' in c and 'Std VIII' in c and 'read' in c][0]

# Get 2024 columns for math across all standards
std3_math_col = [c for c in df.columns if '2024' in c and 'Std III' in c and 'subtraction' in c][0]
std5_math_col = [c for c in df.columns if '2024' in c and 'Std V' in c and 'division' in c][0]
std8_math_col = [c for c in df.columns if '2024' in c and 'Std VIII' in c and 'division' in c][0]

# Calculate composite scores (average across standards)
df['reading_score'] = df[[std3_reading_col, std5_reading_col, std8_reading_col]].mean(axis=1)
df['math_score'] = df[[std3_math_col, std5_math_col, std8_math_col]].mean(axis=1)

# Keep State and both scores
learning_data = df[['State', 'reading_score', 'math_score']].copy()

# Add region mapping
region_map = {
    'Andhra Pradesh': 'South', 'Arunachal Pradesh': 'Northeast',
    'Assam': 'Northeast', 'Bihar': 'North', 'Chhattisgarh': 'Central',
    'Gujarat': 'West', 'Haryana': 'North', 'Himachal Pradesh': 'North',
    'Jammu and Kashmir': 'North', 'Jharkhand': 'East', 'Karnataka': 'South',
    'Kerala': 'South', 'Madhya Pradesh': 'Central', 'Maharashtra': 'West',
    'Meghalaya': 'Northeast', 'Mizoram': 'Northeast', 'Nagaland': 'Northeast',
    'Odisha': 'East', 'Punjab': 'North', 'Rajasthan': 'North',
    'Sikkim': 'Northeast', 'Tamil Nadu': 'South', 'Telangana': 'South',
    'Tripura': 'Northeast', 'Uttar Pradesh': 'North', 'Uttarakhand': 'North',
    'West Bengal': 'East'
}

learning_data['Region'] = learning_data['State'].map(region_map)

# ===================================================================
# CALCULATE STUDENT-TEACHER RATIO BY STATE
# ===================================================================

# Load enrollment data
df_enr_2024 = pd.read_csv('/Users/trishapunamiya/Desktop/LSE/Data Viz/Project/Raw Data/2024/100_enr1.csv')

# Merge enrollment with state info
df_enr_merged = df_enr_2024.merge(
    df_profile_2024[['pseudocode', 'state']], 
    on='pseudocode', 
    how='left'
)

# Class columns for boys and girls
boy_cols = ['cpp_b', 'c1_b', 'c2_b', 'c3_b', 'c4_b', 'c5_b', 'c6_b', 
            'c7_b', 'c8_b', 'c9_b', 'c10_b', 'c11_b', 'c12_b']
girl_cols = ['cpp_g', 'c1_g', 'c2_g', 'c3_g', 'c4_g', 'c5_g', 'c6_g', 
             'c7_g', 'c8_g', 'c9_g', 'c10_g', 'c11_g', 'c12_g']

# Calculate total students for each school
df_enr_merged['total_boys'] = df_enr_merged[boy_cols].sum(axis=1)
df_enr_merged['total_girls'] = df_enr_merged[girl_cols].sum(axis=1)
df_enr_merged['total_students'] = df_enr_merged['total_boys'] + df_enr_merged['total_girls']

# Aggregate students by state
state_students = df_enr_merged.groupby('state').agg({
    'total_students': 'sum'
}).reset_index()

# Load teacher data
df_teacher_2024 = pd.read_csv('/Users/trishapunamiya/Desktop/LSE/Data Viz/Project/Raw Data/2024/100_tch.csv')

# Merge teacher data with state info
df_teacher_merged = df_teacher_2024.merge(
    df_profile_2024[['pseudocode', 'state']], 
    on='pseudocode', 
    how='left'
)

# Aggregate teachers by state
state_teachers = df_teacher_merged.groupby('state').agg({
    'total_tch': 'sum'
}).reset_index()

# Merge teachers and students
state_str = state_teachers.merge(state_students, on='state', how='inner')

# Calculate student-teacher ratio
state_str['student_teacher_ratio'] = state_str['total_students'] / state_str['total_tch']

# Clean state names using same mapping
state_str['state'] = state_str['state'].map(name_mapping)

print(f"\nStudent-Teacher Ratio data: {len(state_str)} states")
print("\nSample STR data:")
print(state_str[['state', 'student_teacher_ratio', 'total_students', 'total_tch']].head(10))
print("\nSTR Summary Statistics:")
print(state_str['student_teacher_ratio'].describe())

# ===================================================================
# MERGE INFRASTRUCTURE AND LEARNING DATA
# ===================================================================

merged_data = state_infrastructure.merge(
    learning_data,
    left_on='state',
    right_on='State',
    how='inner'
)

# Check merge results
print(f"\nMerged infrastructure data: {len(merged_data)} states")
print(f"States in infrastructure: {state_infrastructure['state'].nunique()}")
print(f"States in ASER: {learning_data['State'].nunique()}")

# ===================================================================
# MERGE STR DATA WITH LEARNING DATA
# ===================================================================

merged_str_data = state_str.merge(
    learning_data,
    left_on='state',
    right_on='State',
    how='inner'
)

print(f"\nMerged STR data: {len(merged_str_data)} states")

# ===================================================================
# FUNCTION TO CREATE CORRELATION PLOT WITH REGIONS
# ===================================================================

def create_correlation_plot(data, x_col, y_col, title_subject, y_label):
    """Create a correlation scatter plot with region colors and enrollment size"""
    
    # Calculate correlation statistics
    correlation = data[x_col].corr(data[y_col])
    r_squared = correlation ** 2
    
    # Calculate trend line
    slope, intercept, r_value, p_value, std_err = stats.linregress(data[x_col], data[y_col])
    
    # Create trend line data
    x_trend = np.linspace(data[x_col].min(), data[x_col].max(), 100)
    y_trend = slope * x_trend + intercept
    
    # Region colors
    region_colors = {
        'North': '#ef4444',      # Red
        'South': '#3b82f6',      # Blue
        'East': '#10b981',       # Green
        'West': '#f59e0b',       # Orange
        'Central': '#8b5cf6',    # Purple
        'Northeast': '#ec4899'   # Pink
    }
    
    # Check if total_students column exists
    has_enrollment = 'total_students' in data.columns
    
    # Create figure
    fig = go.Figure()
    
    # Add scatter points by region
    for region in data['Region'].unique():
        region_data = data[data['Region'] == region]
        
        # Create custom hover text with enrollment
        hover_text = []
        for idx, row in region_data.iterrows():
            hover_base = (
                f"<b>{row['state']}</b><br>" +
                f"{x_col.replace('_', ' ').title()}: {row[x_col]:.1f}<br>" +
                f"{title_subject} Score: {row[y_col]:.1f}%"
            )
            if has_enrollment:
                hover_base += f"<br>Students: {row['total_students']:,.0f}"
            hover_text.append(hover_base)
        
        # Determine marker size
        if has_enrollment:
            marker_size = region_data['total_students'] / 1000000
        else:
            marker_size = 10
        
        fig.add_trace(go.Scatter(
            x=region_data[x_col],
            y=region_data[y_col],
            mode='markers+text',
            marker=dict(
                size=marker_size,
                color=region_colors.get(region, '#6b7280'),
                opacity=0.7,
                line=dict(width=1, color='white'),
                sizemode='diameter',
                sizemin=4
            ),
            text=region_data['state'],
            textposition='top center',
            textfont=dict(size=8, color='#374151'),
            name=region,
            hovertext=hover_text,
            hoverinfo='text',
            legendgroup=region,
            showlegend=True
        ))
    
    # Add trend line
    fig.add_trace(go.Scatter(
        x=x_trend,
        y=y_trend,
        mode='lines',
        line=dict(color='#374151', width=2, dash='dash'),
        name=f'Trend (r={correlation:.3f})',
        hoverinfo='skip',
        showlegend=True
    ))
    
    # Update layout
    fig.update_layout(
        title=dict(
            text=f'School Infrastructure vs {title_subject} Outcomes (2024)<br>' +
                 f'<sub>Correlation: r = {correlation:.3f}, R² = {r_squared:.3f}, p = {p_value:.4f}</sub>',
            x=0.5,
            xanchor='center',
            font=dict(size=18, color='#1f2937')
        ),
        xaxis=dict(
            title=dict(
                text='Composite Infrastructure Index (%)<br><sub>(Electricity, Internet, Library, Girls\' Toilets)</sub>',
                font=dict(size=13, color='#374151')
            ),
            showgrid=True,
            gridcolor='#e5e7eb',
            range=[data[x_col].min() - 5, data[x_col].max() + 5]
        ),
        yaxis=dict(
            title=dict(
                text=f'{y_label}<br><sub>(Average across Std III, V, VIII)</sub>',
                font=dict(size=13, color='#374151')
            ),
            showgrid=True,
            gridcolor='#e5e7eb',
            range=[data[y_col].min() - 5, data[y_col].max() + 5]
        ),
        plot_bgcolor='white',
        hovermode='closest',
        legend=dict(
            x=1.02,
            y=1,
            xanchor='left',
            yanchor='top',
            bgcolor='rgba(255,255,255,0.9)',
            bordercolor='#d1d5db',
            borderwidth=1,
            title=dict(text='Region', font=dict(size=12, color='#374151'))
        ),
        width=1100,
        height=650,
        margin=dict(t=100, b=100, l=100, r=180)
    )
    
    return fig, correlation, r_squared, p_value, slope, intercept, std_err

# ===================================================================
# CREATE READING CORRELATION PLOT
# ===================================================================

print("\n" + "="*70)
print("READING OUTCOMES ANALYSIS")
print("="*70)

# Debug: Check columns before calling function
print(f"\nDEBUG - Columns in merged_data: {merged_data.columns.tolist()}")
print(f"DEBUG - First few rows of merged_data:")
print(merged_data[['state', 'infrastructure_index', 'reading_score']].head())

fig_reading, corr_reading, r2_reading, p_reading, slope_reading, intercept_reading, stderr_reading = create_correlation_plot(
    merged_data,
    'infrastructure_index',
    'reading_score',
    'Reading',
    'Composite Reading Proficiency Score (%)'
)

fig_reading.show()

# Print statistics
print(f"Correlation Coefficient (r): {corr_reading:.4f}")
print(f"R-squared (R²): {r2_reading:.4f}")
print(f"P-value: {p_reading:.6f}")
print(f"Standard Error: {stderr_reading:.4f}")
print(f"Trend Line Equation: y = {slope_reading:.3f}x + {intercept_reading:.3f}")

print(f"\nInterpretation:")
if p_reading < 0.05:
    print(f"✓ Statistically significant (p < 0.05)")
else:
    print(f"✗ Not statistically significant (p >= 0.05)")

if corr_reading > 0.7:
    print("✓ Strong positive correlation")
elif corr_reading > 0.4:
    print("✓ Moderate positive correlation")
else:
    print("○ Weak positive correlation")

# ===================================================================
# CREATE MATH CORRELATION PLOT
# ===================================================================

print("\n" + "="*70)
print("MATH OUTCOMES ANALYSIS")
print("="*70)

fig_math, corr_math, r2_math, p_math, slope_math, intercept_math, stderr_math = create_correlation_plot(
    merged_data,
    'infrastructure_index',
    'math_score',
    'Math',
    'Composite Math Proficiency Score (%)'
)

fig_math.show()

# Print statistics
print(f"Correlation Coefficient (r): {corr_math:.4f}")
print(f"R-squared (R²): {r2_math:.4f}")
print(f"P-value: {p_math:.6f}")
print(f"Standard Error: {stderr_math:.4f}")
print(f"Trend Line Equation: y = {slope_math:.3f}x + {intercept_math:.3f}")

print(f"\nInterpretation:")
if p_math < 0.05:
    print(f"✓ Statistically significant (p < 0.05)")
else:
    print(f"✗ Not statistically significant (p >= 0.05)")

if corr_math > 0.7:
    print("✓ Strong positive correlation")
elif corr_math > 0.4:
    print("✓ Moderate positive correlation")
else:
    print("○ Weak positive correlation")

# ===================================================================
# CREATE STUDENT-TEACHER RATIO PLOTS
# ===================================================================

def create_str_correlation_plot(data, x_col, y_col, title_subject, y_label):
    """Create a correlation scatter plot for student-teacher ratio with region colors"""
    
    # Calculate correlation statistics
    correlation = data[x_col].corr(data[y_col])
    r_squared = correlation ** 2
    
    # Calculate trend line
    slope, intercept, r_value, p_value, std_err = stats.linregress(data[x_col], data[y_col])
    
    # Create trend line data
    x_trend = np.linspace(data[x_col].min(), data[x_col].max(), 100)
    y_trend = slope * x_trend + intercept
    
    # Region colors
    region_colors = {
        'North': '#ef4444',      # Red
        'South': '#3b82f6',      # Blue
        'East': '#10b981',       # Green
        'West': '#f59e0b',       # Orange
        'Central': '#8b5cf6',    # Purple
        'Northeast': '#ec4899'   # Pink
    }
    
    # Check if total_students column exists
    has_enrollment = 'total_students' in data.columns
    
    # Create figure
    fig = go.Figure()
    
    # Add scatter points by region
    for region in data['Region'].unique():
        region_data = data[data['Region'] == region]
        
        # Create custom hover text with enrollment
        hover_text = []
        for idx, row in region_data.iterrows():
            hover_base = (
                f"<b>{row['state']}</b><br>" +
                f"Student-Teacher Ratio: {row[x_col]:.1f}<br>" +
                f"{title_subject} Score: {row[y_col]:.1f}%"
            )
            if has_enrollment:
                hover_base += f"<br>Students: {row['total_students']:,.0f}"
            hover_text.append(hover_base)
        
        # Determine marker size
        if has_enrollment:
            marker_size = region_data['total_students'] / 1000000
        else:
            marker_size = 10
        
        fig.add_trace(go.Scatter(
            x=region_data[x_col],
            y=region_data[y_col],
            mode='markers+text',
            marker=dict(
                size=marker_size,
                color=region_colors.get(region, '#6b7280'),
                opacity=0.7,
                line=dict(width=1, color='white'),
                sizemode='diameter',
                sizemin=4
            ),
            text=region_data['state'],
            textposition='top center',
            textfont=dict(size=8, color='#374151'),
            name=region,
            hovertext=hover_text,
            hoverinfo='text',
            legendgroup=region,
            showlegend=True
        ))
    
    # Add trend line
    fig.add_trace(go.Scatter(
        x=x_trend,
        y=y_trend,
        mode='lines',
        line=dict(color='#374151', width=2, dash='dash'),
        name=f'Trend (r={correlation:.3f})',
        hoverinfo='skip',
        showlegend=True
    ))
    
    # Update layout
    fig.update_layout(
        title=dict(
            text=f'Student-Teacher Ratio vs {title_subject} Outcomes (2024)<br>' +
                 f'<sub>Correlation: r = {correlation:.3f}, R² = {r_squared:.3f}, p = {p_value:.4f}</sub>',
            x=0.5,
            xanchor='center',
            font=dict(size=18, color='#1f2937')
        ),
        xaxis=dict(
            title=dict(
                text='Student-Teacher Ratio<br><sub>(Lower is better - fewer students per teacher)</sub>',
                font=dict(size=13, color='#374151')
            ),
            showgrid=True,
            gridcolor='#e5e7eb'
        ),
        yaxis=dict(
            title=dict(
                text=f'{y_label}<br><sub>(Average across Std III, V, VIII)</sub>',
                font=dict(size=13, color='#374151')
            ),
            showgrid=True,
            gridcolor='#e5e7eb',
            range=[data[y_col].min() - 5, data[y_col].max() + 5]
        ),
        plot_bgcolor='white',
        hovermode='closest',
        legend=dict(
            x=1.02,
            y=1,
            xanchor='left',
            yanchor='top',
            bgcolor='rgba(255,255,255,0.9)',
            bordercolor='#d1d5db',
            borderwidth=1,
            title=dict(text='Region', font=dict(size=12, color='#374151'))
        ),
        width=1100,
        height=650,
        margin=dict(t=100, b=100, l=100, r=180)
    )
    
    return fig, correlation, r_squared, p_value, slope, intercept, std_err

# ===================================================================
# READING vs STUDENT-TEACHER RATIO
# ===================================================================

print("\n" + "="*70)
print("STUDENT-TEACHER RATIO vs READING OUTCOMES")
print("="*70)

fig_str_reading, corr_str_reading, r2_str_reading, p_str_reading, slope_str_reading, intercept_str_reading, stderr_str_reading = create_str_correlation_plot(
    merged_str_data,
    'student_teacher_ratio',
    'reading_score',
    'Reading',
    'Composite Reading Proficiency Score (%)'
)

fig_str_reading.show()

print(f"Correlation Coefficient (r): {corr_str_reading:.4f}")
print(f"R-squared (R²): {r2_str_reading:.4f}")
print(f"P-value: {p_str_reading:.6f}")
print(f"Standard Error: {stderr_str_reading:.4f}")

print(f"\nInterpretation:")
if p_str_reading < 0.05:
    print(f"✓ Statistically significant (p < 0.05)")
else:
    print(f"✗ Not statistically significant (p >= 0.05)")

if abs(corr_str_reading) > 0.7:
    print(f"✓ Strong {'negative' if corr_str_reading < 0 else 'positive'} correlation")
elif abs(corr_str_reading) > 0.4:
    print(f"✓ Moderate {'negative' if corr_str_reading < 0 else 'positive'} correlation")
else:
    print(f"○ Weak {'negative' if corr_str_reading < 0 else 'positive'} correlation")

# ===================================================================
# MATH vs STUDENT-TEACHER RATIO
# ===================================================================

print("\n" + "="*70)
print("STUDENT-TEACHER RATIO vs MATH OUTCOMES")
print("="*70)

fig_str_math, corr_str_math, r2_str_math, p_str_math, slope_str_math, intercept_str_math, stderr_str_math = create_str_correlation_plot(
    merged_str_data,
    'student_teacher_ratio',
    'math_score',
    'Math',
    'Composite Math Proficiency Score (%)'
)

fig_str_math.show()

print(f"Correlation Coefficient (r): {corr_str_math:.4f}")
print(f"R-squared (R²): {r2_str_math:.4f}")
print(f"P-value: {p_str_math:.6f}")
print(f"Standard Error: {stderr_str_math:.4f}")

print(f"\nInterpretation:")
if p_str_math < 0.05:
    print(f"✓ Statistically significant (p < 0.05)")
else:
    print(f"✗ Not statistically significant (p >= 0.05)")

if abs(corr_str_math) > 0.7:
    print(f"✓ Strong {'negative' if corr_str_math < 0 else 'positive'} correlation")
elif abs(corr_str_math) > 0.4:
    print(f"✓ Moderate {'negative' if corr_str_math < 0 else 'positive'} correlation")
else:
    print(f"○ Weak {'negative' if corr_str_math < 0 else 'positive'} correlation")

# ===================================================================
# FINAL COMPARISON SUMMARY
# ===================================================================

print("\n" + "="*70)
print("FINAL COMPARISON: ALL CORRELATIONS")
print("="*70)
print("\nINFRASTRUCTURE CORRELATIONS:")
print(f"  Reading: r = {corr_reading:.4f}, R² = {r2_reading:.4f}, p = {p_reading:.4f}")
print(f"  Math:    r = {corr_math:.4f}, R² = {r2_math:.4f}, p = {p_math:.4f}")

print("\nSTUDENT-TEACHER RATIO CORRELATIONS:")
print(f"  Reading: r = {corr_str_reading:.4f}, R² = {r2_str_reading:.4f}, p = {p_str_reading:.4f}")
print(f"  Math:    r = {corr_str_math:.4f}, R² = {r2_str_math:.4f}, p = {p_str_math:.4f}")

print("\nKEY FINDINGS:")
if abs(corr_str_reading) > abs(corr_reading) or abs(corr_str_math) > abs(corr_math):
    print("• Student-teacher ratio shows stronger correlation with outcomes than infrastructure")
else:
    print("• Infrastructure shows stronger correlation with outcomes than student-teacher ratio")

if p_str_reading < 0.05 or p_str_math < 0.05:
    print("• Student-teacher ratio correlations are statistically significant")
else:
    print("• Student-teacher ratio correlations are NOT statistically significant")
    
print("="*70)

# ===================================================================
# OPTIONAL: SAVE OUTPUT
# ===================================================================

# Save all figures
# fig_reading.write_html("infrastructure_vs_reading_2024.html")
# fig_math.write_html("infrastructure_vs_math_2024.html")
# fig_str_reading.write_html("str_vs_reading_2024.html")
# fig_str_math.write_html("str_vs_math_2024.html")

# Save merged data
# merged_data.to_csv('infrastructure_learning_correlation_2024.csv', index=False)
# merged_str_data.to_csv('str_learning_correlation_2024.csv', index=False)


Columns in merged_data: ['state', 'pct_electricity_functional', 'pct_internet', 'pct_library', 'pct_girls_toilets', 'num_schools', 'infrastructure_index', 'State', 'reading_score', 'math_score', 'Region', 'total_students']

Merged infrastructure data: 27 states
States in infrastructure: 36
States in ASER: 27

Merged STR data: 27 states

Student-Teacher Ratio data: 36 states

Sample STR data:
                                      state  student_teacher_ratio  \
0                       Andaman and Nicobar              16.745108   
1                            Andhra Pradesh              46.886272   
2                         Arunachal Pradesh              26.142254   
3                                     Assam              44.475124   
4                                     Bihar              59.358718   
5                                Chandigarh              30.140750   
6                              Chhattisgarh              37.372101   
7  Dadra and Nagar Haveli and Daman and Diu 

Correlation Coefficient (r): 0.1897
R-squared (R²): 0.0360
P-value: 0.343259
Standard Error: 0.1349
Trend Line Equation: y = 0.130x + 39.530

Interpretation:
✗ Not statistically significant (p >= 0.05)
○ Weak positive correlation

MATH OUTCOMES ANALYSIS


Correlation Coefficient (r): 0.2150
R-squared (R²): 0.0462
P-value: 0.281437
Standard Error: 0.1350
Trend Line Equation: y = 0.149x + 23.909

Interpretation:
✗ Not statistically significant (p >= 0.05)
○ Weak positive correlation

STUDENT-TEACHER RATIO vs READING OUTCOMES


Correlation Coefficient (r): -0.3008
R-squared (R²): 0.0905
P-value: 0.127367
Standard Error: 0.1724

Interpretation:
✗ Not statistically significant (p >= 0.05)
○ Weak negative correlation

STUDENT-TEACHER RATIO vs MATH OUTCOMES


Correlation Coefficient (r): -0.1099
R-squared (R²): 0.0121
P-value: 0.585384
Standard Error: 0.1809

Interpretation:
✗ Not statistically significant (p >= 0.05)
○ Weak negative correlation

FINAL COMPARISON: ALL CORRELATIONS

INFRASTRUCTURE CORRELATIONS:
  Reading: r = 0.1897, R² = 0.0360, p = 0.3433
  Math:    r = 0.2150, R² = 0.0462, p = 0.2814

STUDENT-TEACHER RATIO CORRELATIONS:
  Reading: r = -0.3008, R² = 0.0905, p = 0.1274
  Math:    r = -0.1099, R² = 0.0121, p = 0.5854

KEY FINDINGS:
• Student-teacher ratio shows stronger correlation with outcomes than infrastructure
• Student-teacher ratio correlations are NOT statistically significant


In [None]:
import pandas as pd
import altair as alt
from scipy import stats
import numpy as np

# Enable Altair to handle larger datasets
alt.data_transformers.disable_max_rows()

# ===================================================================
# LOAD AND PREPARE INFRASTRUCTURE DATA (2024)
# ===================================================================

# Read in school facility data for 2024
df_fac_2024 = pd.read_csv('/Users/trishapunamiya/Desktop/LSE/Data Viz/Project/Raw Data/2024/100_fac.csv')
df_profile_2024 = pd.read_csv('/Users/trishapunamiya/Desktop/LSE/Data Viz/Project/Raw Data/2024/100_prof1.csv')

# Merge facility data with state info
df_sch_2024 = df_fac_2024.merge(
    df_profile_2024[['pseudocode', 'state']], 
    on='pseudocode', 
    how='left'
)

# Calculate infrastructure index - only count FUNCTIONAL facilities
state_infrastructure = df_sch_2024.groupby('state').agg({
    'electricity_availability': lambda x: (x == 1).sum() / len(x) * 100,
    'internet': lambda x: (x == 1).sum() / len(x) * 100,
    'library_availability': lambda x: (x == 1).sum() / len(x) * 100,
    'total_girls_func_toilet': lambda x: (x > 0).sum() / len(x) * 100,
    'pseudocode': 'count'
}).reset_index()

# Rename columns
state_infrastructure.rename(columns={
    'electricity_availability': 'pct_electricity_functional',
    'internet': 'pct_internet',
    'library_availability': 'pct_library',
    'total_girls_func_toilet': 'pct_girls_toilets',
    'pseudocode': 'num_schools'
}, inplace=True)

# Create composite infrastructure index
state_infrastructure['infrastructure_index'] = state_infrastructure[[
    'pct_electricity_functional', 'pct_internet', 'pct_library', 'pct_girls_toilets'
]].mean(axis=1)

# Clean state names
name_mapping = {
    "ANDAMAN & NICOBAR ISLANDS": "Andaman and Nicobar",
    "ANDHRA PRADESH": "Andhra Pradesh",
    "ARUNACHAL PRADESH": "Arunachal Pradesh",
    "ASSAM": "Assam",
    "BIHAR": "Bihar",
    "CHANDIGARH": "Chandigarh",
    "CHHATTISGARH": "Chhattisgarh",
    "DADRA & NAGAR HAVELI AND DAMAN & DIU": "Dadra and Nagar Haveli and Daman and Diu",
    "DELHI": "Delhi",
    "GOA": "Goa",
    "GUJARAT": "Gujarat",
    "HARYANA": "Haryana",
    "HIMACHAL PRADESH": "Himachal Pradesh",
    "JAMMU & KASHMIR": "Jammu and Kashmir",
    "JHARKHAND": "Jharkhand",
    "KARNATAKA": "Karnataka",
    "KERALA": "Kerala",
    "LADAKH": "Ladakh",
    "LAKSHADWEEP": "Lakshadweep",
    "MADHYA PRADESH": "Madhya Pradesh",
    "MAHARASHTRA": "Maharashtra",
    "MANIPUR": "Manipur",
    "MEGHALAYA": "Meghalaya",
    "MIZORAM": "Mizoram",
    "NAGALAND": "Nagaland",
    "ODISHA": "Odisha",
    "PUDUCHERRY": "Puducherry",
    "PUNJAB": "Punjab",
    "RAJASTHAN": "Rajasthan",
    "SIKKIM": "Sikkim",
    "TAMIL NADU": "Tamil Nadu",
    "TELANGANA": "Telangana",
    "TRIPURA": "Tripura",
    "UTTAR PRADESH": "Uttar Pradesh",
    "UTTARAKHAND": "Uttarakhand",
    "WEST BENGAL": "West Bengal"
}

state_infrastructure['state'] = state_infrastructure['state'].map(name_mapping)

# ===================================================================
# LOAD AND PREPARE ASER DATA (2024)
# ===================================================================

# Load your ASER data
df = aser_2024[aser_2024['State'] != 'All India'].copy()

# Get 2024 columns for reading across all standards
std3_reading_col = [c for c in df.columns if '2024' in c and 'Std III' in c and 'read' in c][0]
std5_reading_col = [c for c in df.columns if '2024' in c and 'Std V' in c and 'read' in c][0]
std8_reading_col = [c for c in df.columns if '2024' in c and 'Std VIII' in c and 'read' in c][0]

# Get 2024 columns for math across all standards
std3_math_col = [c for c in df.columns if '2024' in c and 'Std III' in c and 'subtraction' in c][0]
std5_math_col = [c for c in df.columns if '2024' in c and 'Std V' in c and 'division' in c][0]
std8_math_col = [c for c in df.columns if '2024' in c and 'Std VIII' in c and 'division' in c][0]

# Calculate composite scores (average across standards)
df['reading_score'] = df[[std3_reading_col, std5_reading_col, std8_reading_col]].mean(axis=1)
df['math_score'] = df[[std3_math_col, std5_math_col, std8_math_col]].mean(axis=1)

# Keep State and both scores
learning_data = df[['State', 'reading_score', 'math_score']].copy()

# Add region mapping
region_map = {
    'Andhra Pradesh': 'South', 'Arunachal Pradesh': 'Northeast',
    'Assam': 'Northeast', 'Bihar': 'North', 'Chhattisgarh': 'Central',
    'Gujarat': 'West', 'Haryana': 'North', 'Himachal Pradesh': 'North',
    'Jammu and Kashmir': 'North', 'Jharkhand': 'East', 'Karnataka': 'South',
    'Kerala': 'South', 'Madhya Pradesh': 'Central', 'Maharashtra': 'West',
    'Meghalaya': 'Northeast', 'Mizoram': 'Northeast', 'Nagaland': 'Northeast',
    'Odisha': 'East', 'Punjab': 'North', 'Rajasthan': 'North',
    'Sikkim': 'Northeast', 'Tamil Nadu': 'South', 'Telangana': 'South',
    'Tripura': 'Northeast', 'Uttar Pradesh': 'North', 'Uttarakhand': 'North',
    'West Bengal': 'East'
}

learning_data['Region'] = learning_data['State'].map(region_map)

# ===================================================================
# CALCULATE STUDENT-TEACHER RATIO BY STATE
# ===================================================================

# Load enrollment data
df_enr_2024 = pd.read_csv('/Users/trishapunamiya/Desktop/LSE/Data Viz/Project/Raw Data/2024/100_enr1.csv')

# Merge enrollment with state info
df_enr_merged = df_enr_2024.merge(
    df_profile_2024[['pseudocode', 'state']], 
    on='pseudocode', 
    how='left'
)

# Class columns for boys and girls
boy_cols = ['cpp_b', 'c1_b', 'c2_b', 'c3_b', 'c4_b', 'c5_b', 'c6_b', 
            'c7_b', 'c8_b', 'c9_b', 'c10_b', 'c11_b', 'c12_b']
girl_cols = ['cpp_g', 'c1_g', 'c2_g', 'c3_g', 'c4_g', 'c5_g', 'c6_g', 
             'c7_g', 'c8_g', 'c9_g', 'c10_g', 'c11_g', 'c12_g']

# Calculate total students for each school
df_enr_merged['total_boys'] = df_enr_merged[boy_cols].sum(axis=1)
df_enr_merged['total_girls'] = df_enr_merged[girl_cols].sum(axis=1)
df_enr_merged['total_students'] = df_enr_merged['total_boys'] + df_enr_merged['total_girls']

# Aggregate students by state
state_students = df_enr_merged.groupby('state').agg({
    'total_students': 'sum'
}).reset_index()

# Load teacher data
df_teacher_2024 = pd.read_csv('/Users/trishapunamiya/Desktop/LSE/Data Viz/Project/Raw Data/2024/100_tch.csv')

# Merge teacher data with state info
df_teacher_merged = df_teacher_2024.merge(
    df_profile_2024[['pseudocode', 'state']], 
    on='pseudocode', 
    how='left'
)

# Aggregate teachers by state
state_teachers = df_teacher_merged.groupby('state').agg({
    'total_tch': 'sum'
}).reset_index()

# Merge teachers and students
state_str = state_teachers.merge(state_students, on='state', how='inner')

# Calculate student-teacher ratio
state_str['student_teacher_ratio'] = state_str['total_students'] / state_str['total_tch']

# Clean state names using same mapping
state_str['state'] = state_str['state'].map(name_mapping)
state_students['state'] = state_students['state'].map(name_mapping)

print(f"\nStudent-Teacher Ratio data: {len(state_str)} states")

# ===================================================================
# MERGE INFRASTRUCTURE AND LEARNING DATA
# ===================================================================

merged_data = state_infrastructure.merge(
    learning_data,
    left_on='state',
    right_on='State',
    how='inner'
)

# Add total enrollment from student data for bubble size
merged_data = merged_data.merge(
    state_students[['state', 'total_students']],
    left_on='state',
    right_on='state',
    how='left'
)

# Scale enrollment for bubble size (in millions)
merged_data['enrollment_millions'] = merged_data['total_students'] / 1000000

print(f"\nMerged infrastructure data: {len(merged_data)} states")

# ===================================================================
# MERGE STR DATA WITH LEARNING DATA
# ===================================================================

merged_str_data = state_str.merge(
    learning_data,
    left_on='state',
    right_on='State',
    how='inner'
)

# Scale enrollment for bubble size (in millions)
merged_str_data['enrollment_millions'] = merged_str_data['total_students'] / 1000000

print(f"\nMerged STR data: {len(merged_str_data)} states")

# ===================================================================
# FUNCTION TO CREATE ALTAIR CORRELATION PLOT
# ===================================================================

def create_altair_plot(data, x_col, y_col, x_title, y_title, chart_title, x_min=None, x_max=None, y_min=None, y_max=None, show_legend=False):
    """Create an interactive Altair scatter plot with region colors and trend line"""
    
    # Calculate correlation
    correlation = data[x_col].corr(data[y_col])
    r_squared = correlation ** 2
    slope, intercept, r_value, p_value, std_err = stats.linregress(data[x_col], data[y_col])
    
    # Use provided min/max or calculate from data
    if x_min is None:
        x_min = data[x_col].min() - 5
    if x_max is None:
        x_max = data[x_col].max() + 5
    if y_min is None:
        y_min = data[y_col].min() - 5
    if y_max is None:
        y_max = data[y_col].max() + 5
    
    # Create trend line data
    trend_df = pd.DataFrame({
        x_col: [x_min, x_max],
        y_col: [slope * x_min + intercept, slope * x_max + intercept]
    })
    
    # Define region colors
    region_colors = alt.Scale(
        domain=['North', 'South', 'East', 'West', 'Central', 'Northeast'],
        range=['#ef4444', '#3b82f6', '#10b981', '#f59e0b', '#8b5cf6', '#ec4899']
    )
    
    # Create interactive selection
    region_selection = alt.selection_point(fields=['Region'], bind='legend')
    
    # Determine legend configuration
    if show_legend:
        legend_config = alt.Legend(
            title='Region',
            symbolSize=100,
            orient='right',
            direction='vertical'
        )
    else:
        legend_config = None  # Hide legend
    
    # Base scatter plot
    scatter = alt.Chart(data).mark_circle(
        opacity=0.7,
        stroke='white',
        strokeWidth=1
    ).encode(
        x=alt.X(x_col, 
                scale=alt.Scale(domain=[x_min, x_max]),
                title=x_title),
        y=alt.Y(y_col, 
                scale=alt.Scale(domain=[y_min, y_max]),
                title=y_title),
        size=alt.Size('enrollment_millions:Q', 
                      scale=alt.Scale(range=[50, 2000]),
                      legend=None),
        color=alt.Color('Region:N', 
                       scale=region_colors,
                       legend=legend_config),
        opacity=alt.condition(region_selection, alt.value(0.7), alt.value(0.1)),
        tooltip=[
            alt.Tooltip('state:N', title='State'),
            alt.Tooltip(f'{x_col}:Q', title=x_title, format='.1f'),
            alt.Tooltip(f'{y_col}:Q', title=y_title, format='.1f'),
            alt.Tooltip('total_students:Q', title='Total Students', format=',')
        ]
    ).add_params(
        region_selection
    )
    
    # Trend line
    trend_line = alt.Chart(trend_df).mark_line(
        color='#374151',
        strokeDash=[5, 5],
        size=2
    ).encode(
        x=alt.X(x_col),
        y=alt.Y(y_col)
    )
    
    # Combine layers - SIMPLIFIED TITLE
    chart = (scatter + trend_line).properties(
        width=450,
        height=400,
        title={
            "text": chart_title,
            "fontSize": 14,
            "anchor": "start",
            "offset": 10
        }
    ).configure_axis(
        grid=False,
        domainColor='#d1d5db'
    ).configure_view(
        strokeWidth=0
    )
    
    return chart, correlation, r_squared, p_value

# ===================================================================
# CREATE INFRASTRUCTURE PLOTS
# ===================================================================

print("\n" + "="*70)
print("INFRASTRUCTURE vs READING")
print("="*70)

chart_infra_reading, corr_reading, r2_reading, p_reading = create_altair_plot(
    merged_data,
    'infrastructure_index',
    'reading_score',
    'Infrastructure Index (%)',
    'Reading Score (%)',
    'Infrastructure vs Reading',
    x_min=0,
    x_max=100,
    y_min=0,
    y_max=100,
    show_legend=True  # Show legend only on first chart
)

chart_infra_reading.display()

print(f"Correlation: r = {corr_reading:.4f}, R² = {r2_reading:.4f}, p = {p_reading:.4f}")

print("\n" + "="*70)
print("INFRASTRUCTURE vs MATH")
print("="*70)

chart_infra_math, corr_math, r2_math, p_math = create_altair_plot(
    merged_data,
    'infrastructure_index',
    'math_score',
    'Infrastructure Index (%)',
    'Math Score (%)',
    'Infrastructure vs Math',
    x_min=0,
    x_max=100,
    y_min=0,
    y_max=100,
    show_legend=True  # Hide legend
)

chart_infra_math.display()

print(f"Correlation: r = {corr_math:.4f}, R² = {r2_math:.4f}, p = {p_math:.4f}")

# ===================================================================
# CREATE STR PLOTS
# ===================================================================

print("\n" + "="*70)
print("STUDENT-TEACHER RATIO vs READING")
print("="*70)

chart_str_reading, corr_str_reading, r2_str_reading, p_str_reading = create_altair_plot(
    merged_str_data,
    'student_teacher_ratio',
    'reading_score',
    'Student-Teacher Ratio',
    'Reading Score (%)',
    'Class Size vs Reading',
    x_min=0,
    x_max=None,
    y_min=0,
    y_max=100,
    show_legend=True  # Hide legend
)

chart_str_reading.display()

print(f"Correlation: r = {corr_str_reading:.4f}, R² = {r2_str_reading:.4f}, p = {p_str_reading:.4f}")

print("\n" + "="*70)
print("STUDENT-TEACHER RATIO vs MATH")
print("="*70)

chart_str_math, corr_str_math, r2_str_math, p_str_math = create_altair_plot(
    merged_str_data,
    'student_teacher_ratio',
    'math_score',
    'Student-Teacher Ratio',
    'Math Score (%)',
    'Class Size vs Math',
    x_min=0,
    x_max=None,
    y_min=0,
    y_max=100,
    show_legend=True  # Hide legend
)

chart_str_math.display()

print(f"Correlation: r = {corr_str_math:.4f}, R² = {r2_str_math:.4f}, p = {p_str_math:.4f}")

# ===================================================================
# SUMMARY
# ===================================================================

print("\n" + "="*70)
print("SUMMARY: ALL CORRELATIONS")
print("="*70)
print("\nINFRASTRUCTURE CORRELATIONS:")
print(f"  Reading: r = {corr_reading:.4f}, R² = {r2_reading:.4f}, p = {p_reading:.4f}")
print(f"  Math:    r = {corr_math:.4f}, R² = {r2_math:.4f}, p = {p_math:.4f}")

print("\nSTUDENT-TEACHER RATIO CORRELATIONS:")
print(f"  Reading: r = {corr_str_reading:.4f}, R² = {r2_str_reading:.4f}, p = {p_str_reading:.4f}")
print(f"  Math:    r = {corr_str_math:.4f}, R² = {r2_str_math:.4f}, p = {p_str_math:.4f}")

# ===================================================================
# OPTIONAL: SAVE CHARTS
# ===================================================================

# Save charts as JSON for web
chart_infra_reading.save('charts/infrastructure_vs_reading_2024.json')
chart_infra_math.save('charts/infrastructure_vs_math_2024.json')
chart_str_reading.save('charts/str_vs_reading_2024.json')
chart_str_math.save('charts/str_vs_math_2024.json')


Student-Teacher Ratio data: 36 states

Merged infrastructure data: 27 states

Merged STR data: 27 states

INFRASTRUCTURE vs READING


Correlation: r = 0.1897, R² = 0.0360, p = 0.3433

INFRASTRUCTURE vs MATH


Correlation: r = 0.2150, R² = 0.0462, p = 0.2814

STUDENT-TEACHER RATIO vs READING


Correlation: r = -0.3008, R² = 0.0905, p = 0.1274

STUDENT-TEACHER RATIO vs MATH


Correlation: r = -0.1099, R² = 0.0121, p = 0.5854

SUMMARY: ALL CORRELATIONS

INFRASTRUCTURE CORRELATIONS:
  Reading: r = 0.1897, R² = 0.0360, p = 0.3433
  Math:    r = 0.2150, R² = 0.0462, p = 0.2814

STUDENT-TEACHER RATIO CORRELATIONS:
  Reading: r = -0.3008, R² = 0.0905, p = 0.1274
  Math:    r = -0.1099, R² = 0.0121, p = 0.5854
