In [4]:
import os
import re
from pathlib import Path

import altair as alt
import pandas as pd

In [2]:
# Folder holding the Data For India CSV export. Set the DATA_DIR environment
# variable to run this notebook from another machine or checkout; the default
# is the original author's local path.
DATA_DIR = Path(os.environ.get('DATA_DIR', '/Users/trishapunamiya/tpunamiya.github.io-4/data'))
df = pd.read_csv(DATA_DIR / 'Enrolment in higher education  - Data For India.csv')

In [53]:
# Tidy the Data For India table on a copy: `df` itself is left untouched, so
# this cell stays safe to re-run.
enrollment = df.copy()
enrollment.columns = ['Country_code', 'Year', 'Enrollment']

# Remove rows with missing enrollment data
enrollment = enrollment.dropna(subset=['Enrollment'])

# "World" series: unweighted mean of every entity in the file, per year.
# NOTE(review): if the CSV also contains aggregate rows (regions, a 'World'
# row, ...) they are averaged in as if they were countries — confirm.
world_avg = enrollment.groupby('Year', as_index=False)['Enrollment'].mean()
world_avg['Country_code'] = 'World'

# Countries drawn next to the world average. Extend as needed — the colour
# scale below is derived from this list.
countries_to_plot = ['India', 'China']
df_countries = enrollment[enrollment['Country_code'].isin(countries_to_plot)].copy()

# Combine country data with the world average, ordered for line drawing
df_combined = (
    pd.concat([df_countries, world_avg], ignore_index=True)
      .sort_values(['Country_code', 'Year'])
)

# Get year range
min_year = int(df_combined['Year'].min())
max_year = int(df_combined['Year'].max())

# Two range sliders bound to Vega-Lite params: one for start year, one for end year
year_start = alt.param(
    name='year_start',
    value=min_year,
    bind=alt.binding_range(
        min=min_year,
        max=max_year,
        step=1,
        name='Start Year: '
    )
)

year_end = alt.param(
    name='year_end',
    value=max_year,
    bind=alt.binding_range(
        min=min_year,
        max=max_year,
        step=1,
        name='End Year: '
    )
)

# One colour per plotted country (cycled if the list grows past the palette),
# World always orange. With the default countries this reproduces
# India / China / World = '#003f5c' / '#bc5090' / '#ffa600'.
COUNTRY_COLORS = ['#003f5c', '#bc5090', '#58508d', '#ff6361', '#2f4b7c', '#665191']
WORLD_COLOR = '#ffa600'
color_scale = alt.Scale(
    domain=countries_to_plot + ['World'],
    range=[COUNTRY_COLORS[i % len(COUNTRY_COLORS)] for i in range(len(countries_to_plot))]
          + [WORLD_COLOR],
)

# Years present in the data, used as explicit x-axis ticks
unique_years = [int(year) for year in sorted(df_combined['Year'].unique())]

# Shared encodings for every layer
base = alt.Chart(df_combined).encode(
    x=alt.X('Year:Q',
            title='Date',
            scale=alt.Scale(domain=[min_year, max_year]),
            axis=alt.Axis(format='d', grid=False, domain=True, ticks=True, values=unique_years)),
    y=alt.Y('Enrollment:Q',
            title='Enrollment (%)',
            scale=alt.Scale(domain=[0, 80]),
            axis=alt.Axis(format='d', labelExpr='datum.value + "%"', grid=False, domain=True, ticks=True)),
    color=alt.Color('Country_code:N',
                    scale=color_scale,
                    legend=None)
)

# Vega expression for "inside the selected window". min/max keep the window
# valid even when the start slider is dragged past the end slider.
in_range = (
    'datum.Year >= min(year_start, year_end) && '
    'datum.Year <= max(year_start, year_end)'
)

# Lines restricted to the selected window
line = base.mark_line(size=3).transform_filter(in_range)

# Visible rows annotated with each series' first and last visible year, so the
# end markers and labels sit on real data points even when a series has gaps
# or ends before the end slider.
visible = base.transform_filter(in_range).transform_joinaggregate(
    first_year='min(Year)',
    last_year='max(Year)',
    groupby=['Country_code'],
)

# Points at the start and end of each visible line
points_start = visible.mark_circle(size=100).transform_filter(
    alt.datum.Year == alt.datum.first_year
)

points_end = visible.mark_circle(size=100).transform_filter(
    alt.datum.Year == alt.datum.last_year
)

# Text labels at the right-hand end of each visible line
labels = visible.mark_text(
    align='left',
    dx=7,
    fontSize=13,
    fontWeight=400
).encode(
    text='Country_code:N'
).transform_filter(
    alt.datum.Year == alt.datum.last_year
)

# Combine the chart
chart = (line + points_start + points_end + labels).add_params(
    year_start,
    year_end
).properties(
    width=900,
    height=500,
    title={
        'text': 'Enrolment in higher education',
        'subtitle': 'More young Indians are enrolling in college, but the proportion lags behind the world',
        'fontSize': 20,
        'subtitleFontSize': 14,
        'subtitleColor': '#666666',
    }
).configure_view(
    strokeWidth=0
).configure_axis(
    grid=False,
    domain=False,
    tickSize=5
)
chart

In [52]:
# Save the Vega-Lite spec to JSON. Create the output folder first: without it,
# writing the file raises FileNotFoundError on a fresh checkout.
Path('charts').mkdir(parents=True, exist_ok=True)
chart.save('charts/CC9_1.json')

In [74]:
import altair as alt
import pandas as pd
import requests

def get_world_bank_data(indicator, year=2022, session=None, timeout=30):
    """Fetch one World Bank indicator for all economies for a single year.

    Queries ``/v2/country/all/indicator/{indicator}`` and follows the
    ``pages`` count in the response metadata, so results larger than one page
    are not silently truncated. Rows whose ``value`` is null are skipped.

    Args:
        indicator: World Bank indicator id, e.g. ``'NY.GDP.PCAP.CD'``.
        year: Value sent as the ``date`` query parameter.
        session: Object with a ``requests``-style
            ``get(url, params=..., timeout=...)`` method (e.g. a
            ``requests.Session`` to reuse one connection across calls).
            Defaults to the ``requests`` module.
        timeout: Seconds to wait per HTTP request, so a stalled connection
            cannot hang the notebook.

    Returns:
        DataFrame with columns ``country``, ``country_code``, ``year`` (taken
        verbatim from the API's ``date`` field) and ``value``. On any failed
        request (non-200 status or an unexpected payload) an empty DataFrame
        with those same columns is returned, never a partial result.
    """
    columns = ['country', 'country_code', 'year', 'value']
    http = session if session is not None else requests
    url = f"https://api.worldbank.org/v2/country/all/indicator/{indicator}"

    records = []
    page, total_pages = 1, 1
    while page <= total_pages:
        response = http.get(
            url,
            params={'date': year, 'format': 'json', 'per_page': 500, 'page': page},
            timeout=timeout,
        )
        if response.status_code != 200:
            return pd.DataFrame(columns=columns)

        payload = response.json()
        if not isinstance(payload, list) or len(payload) < 2:
            # Not the usual [metadata, rows] pair — presumably an error message.
            return pd.DataFrame(columns=columns)

        meta, rows = payload[0], payload[1]
        if not rows:
            # `rows` can be null when the indicator has no data for this year;
            # iterating over it would raise TypeError.
            break

        records.extend(
            {
                'country': item['country']['value'],
                'country_code': item['countryiso3code'],
                'year': item['date'],
                'value': item['value'],
            }
            for item in rows
            if item['value'] is not None
        )
        total_pages = int(meta.get('pages') or 1) if isinstance(meta, dict) else 1
        page += 1

    return pd.DataFrame(records, columns=columns)


# Fetch the three indicators used in the chart. The chart title names the
# year explicitly — keep it in sync if WB_YEAR changes.
WB_YEAR = 2022
gdp_data = get_world_bank_data('NY.GDP.PCAP.CD', WB_YEAR).rename(columns={'value': 'gdp_per_capita'})
inflation_data = get_world_bank_data('FP.CPI.TOTL.ZG', WB_YEAR).rename(columns={'value': 'inflation_rate'})
population_data = get_world_bank_data('SP.POP.TOTL', WB_YEAR).rename(columns={'value': 'population'})


def _with_country_code(frame):
    """Return only the rows of `frame` that have a non-blank ``country_code``."""
    # A blank/None key would match every other blank key in the inner merges
    # below, multiplying rows. NOTE(review): unclear whether the API ever
    # returns blank codes — this is a cheap guard.
    return frame[frame['country_code'].fillna('') != '']


# Merge the datasets on the ISO3 code (inner: keep countries present in all three)
df = (
    _with_country_code(gdp_data)[['country', 'country_code', 'gdp_per_capita']]
    .merge(_with_country_code(inflation_data)[['country_code', 'inflation_rate']],
           on='country_code', how='inner')
    .merge(_with_country_code(population_data)[['country_code', 'population']],
           on='country_code', how='inner')
)

# Remove rows with missing values
df = df.dropna()

# Filter out aggregates by name: drop a row when any pattern occurs
# (case-insensitively) in its country name. One escaped regex replaces a
# pass per pattern.
aggregate_patterns = [
    'income', 'dividend', 'IDA', 'IBRD', 'blend', 'only',
    'member', 'developing', 'developed', 'situations',
    'Europe and', 'Africa Western', 'Africa Eastern',
    'Central Europe', 'demographic', 'Fragile',
    'conflict', 'Arab World', 'Caribbean', 'Pacific island',
    'Small states', 'OECD', 'Euro area', 'European Union'
]
aggregate_regex = '|'.join(re.escape(pattern) for pattern in aggregate_patterns)
df = df[~df['country'].str.contains(aggregate_regex, case=False, na=False)]

# Exclude aggregate codes that the name patterns above do not catch
exclude_codes = ['WLD', 'EAS', 'ECS', 'LCN', 'MEA', 'NAC', 'SAS', 'SSF',
                 'HIC', 'LIC', 'LMC', 'MIC', 'UMC', 'EUU', 'ARB', 'CSS',
                 'EAP', 'EMU', 'ECA', 'LAC', 'MNA', 'SSA', 'IBD', 'IBT',
                 'IDX', 'IDB', 'IDA', 'TEA', 'TEC', 'TLA', 'TMN', 'TSA',
                 'TSS', 'OED', 'OSS', 'PSS', 'PST', 'SST', 'INX',
                 'AFE', 'AFW', 'CEB', 'EAR', 'FCS', 'LTE', 'PRE', 'IDD',
                 'HPC']

df = df[~df['country_code'].isin(exclude_codes)]
print(f"After excluding aggregate codes: {len(df)} countries")

# Assign regions.
# Region membership by ISO3 code, in priority order: if a code were listed
# under two regions, the earlier region wins (as with the original if/elif
# chain). Built once instead of on every call.
_REGION_MEMBERS = (
    ('Asia & Pacific', (
        'CHN', 'IND', 'IDN', 'JPN', 'KOR', 'THA', 'VNM', 'PHL', 'MYS', 'SGP',
        'PAK', 'BGD', 'MMR', 'KHM', 'LAO', 'NPL', 'LKA', 'AFG', 'BTN', 'MDV',
        'AUS', 'NZL', 'PNG', 'FJI', 'MNG', 'BRN', 'TLS', 'KAZ', 'UZB', 'KGZ',
        'TJK', 'TKM', 'WSM', 'TON', 'VUT', 'SLB', 'KIR', 'FSM', 'MHL', 'PLW',
        'TUV', 'NRU', 'HKG', 'MAC',
        'PRK', 'PYF', 'NCL', 'GUM', 'ASM', 'MNP',
    )),
    ('Europe', (
        'DEU', 'GBR', 'FRA', 'ITA', 'ESP', 'POL', 'ROU', 'NLD', 'BEL', 'GRC',
        'CZE', 'PRT', 'SWE', 'HUN', 'AUT', 'CHE', 'BGR', 'DNK', 'FIN', 'SVK',
        'NOR', 'IRL', 'HRV', 'SVN', 'LTU', 'LVA', 'EST', 'LUX', 'ISL', 'MLT',
        'CYP', 'UKR', 'BLR', 'RUS', 'MDA', 'ALB', 'SRB', 'MKD', 'BIH', 'MNE',
        'ARM', 'GEO', 'AZE',
        'AND', 'MCO', 'SMR', 'LIE', 'XKX', 'GIB', 'IMN', 'CHI', 'FRO',
    )),
    ('Americas', (
        'USA', 'BRA', 'MEX', 'CAN', 'ARG', 'COL', 'PER', 'VEN', 'CHL', 'ECU',
        'GTM', 'CUB', 'BOL', 'HTI', 'DOM', 'HND', 'PRY', 'NIC', 'SLV', 'CRI',
        'PAN', 'URY', 'JAM', 'TTO', 'GUY', 'SUR', 'BHS', 'BLZ', 'BRB', 'GRD',
        'ATG', 'DMA', 'KNA', 'LCA', 'VCT', 'PRI', 'ABW', 'CUW', 'SXM', 'MAF',
        'TCA', 'VGB', 'VIR', 'CYM', 'BMU',
    )),
    ('Africa', (
        'NGA', 'ETH', 'EGY', 'COD', 'TZA', 'ZAF', 'KEN', 'UGA', 'SDN', 'DZA',
        'MAR', 'AGO', 'GHA', 'MOZ', 'MDG', 'CMR', 'CIV', 'NER', 'BFA', 'MLI',
        'MWI', 'ZMB', 'SOM', 'SEN', 'TCD', 'ZWE', 'GIN', 'RWA', 'BEN', 'TUN',
        'BDI', 'SSD', 'TGO', 'SLE', 'LBY', 'LBR', 'MRT', 'CAF', 'ERI', 'GMB',
        'BWA', 'NAM', 'GAB', 'LSO', 'GNB', 'GNQ', 'MUS', 'SWZ', 'COM', 'CPV',
        'STP', 'SYC', 'DJI', 'COG',
    )),
    ('Middle East', (
        'SAU', 'IRN', 'IRQ', 'YEM', 'SYR', 'JOR', 'ARE', 'ISR', 'LBN', 'OMN',
        'KWT', 'PSE', 'QAT', 'BHR', 'TUR',
    )),
)

# Flattened code -> region lookup; iterating in reverse lets earlier
# (higher-priority) regions overwrite later ones.
_REGION_BY_CODE = {
    code: region
    for region, codes in reversed(_REGION_MEMBERS)
    for code in codes
}


def assign_region(country_code):
    """Map an ISO3 country code to a coarse region label.

    Args:
        country_code: Three-letter code as stored in the ``country_code``
            column (e.g. ``'IND'``).

    Returns:
        'Asia & Pacific', 'Europe', 'Americas', 'Africa', 'Middle East', or
        'Other' for any code not listed in ``_REGION_MEMBERS``.
    """
    return _REGION_BY_CODE.get(country_code, 'Other')

# Tag each country with a coarse region. `assign` builds a new frame instead of
# writing into the boolean-filtered `df`, which would raise pandas'
# SettingWithCopyWarning.
df = df.assign(region=df['country_code'].apply(assign_region))

# Remove extreme outliers so a handful of countries do not compress the axes.
MAX_ABS_INFLATION = 100      # inflation rate, %
MIN_GDP_PER_CAPITA = 100     # US$
MAX_GDP_PER_CAPITA = 120000  # US$
df = df[
    (df['inflation_rate'].abs() < MAX_ABS_INFLATION)
    & (df['gdp_per_capita'] > MIN_GDP_PER_CAPITA)
    & (df['gdp_per_capita'] < MAX_GDP_PER_CAPITA)
]

# Quick sanity-check view of the same data (not displayed by this cell).
simple_chart = alt.Chart(df).mark_circle(size=100, opacity=0.7).encode(
    x=alt.X('gdp_per_capita:Q', scale=alt.Scale(type='log'), title='GDP per Capita'),
    y=alt.Y('inflation_rate:Q', title='Inflation Rate'),
    color='region:N',
    tooltip=['country:N', 'gdp_per_capita:Q', 'inflation_rate:Q', 'region:N']
).properties(
    width=600,
    height=400,
    title='GDP per Capita vs Inflation Rate (Simple)'
)

# Highlight the country nearest the pointer
hover = alt.selection_point(
    fields=['country'],
    on='mouseover',
    nearest=True,
    empty=False
)

# Drag/scroll to pan and zoom both axes
zoom = alt.selection_interval(bind='scales')

# Legend selection; an empty selection counts as "all selected", so every
# region is fully opaque until the user clicks the legend.
legend_selection = alt.selection_point(
    fields=['region'],
    bind='legend',
    on='click',
    clear='dblclick'
)

# Hover is shown with an outline rather than by overriding the size: a fixed
# hover size would shrink the most populous countries' circles.
full_chart = alt.Chart(df).mark_circle(stroke='black').encode(
    x=alt.X('gdp_per_capita:Q',
            title='GDP per Capita (US$)',
            scale=alt.Scale(type='log'),
            axis=alt.Axis(format='$,.0f', grid=False)),
    y=alt.Y('inflation_rate:Q',
            title='Inflation Rate (%)',
            scale=alt.Scale(zero=False),
            axis=alt.Axis(grid=False)),
    color=alt.Color('region:N',
                    title='Region',
                    scale=alt.Scale(scheme='category10'),
                    legend=alt.Legend(orient='right')),
    size=alt.Size('population:Q',
                  legend=None,
                  scale=alt.Scale(type='sqrt', range=[100, 2000])),
    strokeWidth=alt.condition(
        hover,
        alt.value(2),     # Hovered country
        alt.value(0)      # Everything else: no outline
    ),
    opacity=alt.condition(
        legend_selection,
        alt.value(0.8),   # Selected regions
        alt.value(0.1)    # Unselected regions
    ),
    tooltip=[
        alt.Tooltip('country:N', title='Country'),
        alt.Tooltip('region:N', title='Region'),
        alt.Tooltip('gdp_per_capita:Q', title='GDP per Capita', format='$,.0f'),
        alt.Tooltip('inflation_rate:Q', title='Inflation Rate', format='.2f'),
        alt.Tooltip('population:Q', title='Population', format=',.0f')
    ]
).add_params(
    hover,
    zoom,
    legend_selection
).properties(
    width=700,
    height=550,
    title={
        'text': 'Country Economics: GDP per Capita vs Inflation Rate (2022)',
        'subtitle': ['Hover to highlight, click legend to filter, drag to zoom. Circle size = population'],
        'subtitleFontStyle': 'italic',
        'subtitleFontSize': 11
    }
).configure_view(
    strokeWidth=0
)

full_chart

After excluding aggregate codes: 176 countries


In [76]:
# Save the Vega-Lite spec to JSON. Create the output folder first: without it,
# writing the file raises FileNotFoundError on a fresh checkout.
Path('charts').mkdir(parents=True, exist_ok=True)
full_chart.save('charts/CC9_2.json')