In [None]:
import pandas as pd
import numpy as np
import altair as alt
import pymc as pm
import causalpy as cp
import pymc_bart as pmb
import arviz as az

from utils import *

az.style.use("arviz-doc")  # apply ArviZ's "arviz-doc" plot style to every figure below

import warnings
# Silence FutureWarnings so they don't clutter the cell outputs
warnings.simplefilter(category=FutureWarning, action='ignore')

# Aggregate groups and countries/territories with too many missing values;
# none of these should be candidates for the synthetic control.
excluded_columns = [
    'Central Europe and the Baltics',
    'Europe & Central Asia (excluding high income)',
    'Europe & Central Asia',
    'Africa Eastern and Southern',
    'Africa Western and Central',
    'European Union',
    'Euro area',
    'Post-demographic dividend',
    'Channel Islands',
    'Palau',
    'Andorra',
    'San Marino',
    'Liechtenstein',
    'Seychelles',
    'West Bank and Gaza',
    'Greenland',
    'Faroe Islands',
    'Europe & Central Asia (IDA & IBRD countries)',
]

tfr_raw = pd.read_csv("./data/tfr2022.csv", index_col=0, sep=';').drop('Country Code', axis=1)

# Transpose to one row per year and one column per country. The last two rows after the
# transpose are dropped (presumably trailing non-year columns of the CSV -- TODO confirm),
# then the exclusions are removed and fully empty columns discarded.
tfr_data = (
    tfr_raw.T
    .iloc[:-2]
    .drop(excluded_columns, axis=1)
    .rename(columns={'United Kingdom': 'UK'})
    .dropna(axis=1, how='all')
)

tfr_data


In [None]:
target_country = 'Estonia'
treatment_year = 2017
cut_off = 1995
# Minimum pre-treatment correlation with the target country required to enter the donor pool
corr_threshold = 0.8

# Clear the columns-axis name left over from the transpose (presumably the CSV's original
# index name). Reassign instead of mutating in place so the cell stays re-runnable.
tfr_data = tfr_data.rename_axis(None, axis=1)

# Years are stored as strings in the index (see the str(treatment_year) lookup below);
# convert them once and reuse the integer version for every comparison.
year_index = tfr_data.index.astype('int')
pre_treatment = (year_index > cut_off) & (year_index < int(treatment_year))

# Pairwise country correlations over the pre-treatment window only
cdf = tfr_data[pre_treatment].corr()

# Select every country whose pre-treatment correlation with the target exceeds the threshold
# (the target correlates 1 with itself, so it is kept in all_countries)
all_countries = cdf[cdf[target_country] > corr_threshold].index.tolist()
other_countries = [c for c in all_countries if c != target_country]

tfr_data = tfr_data[year_index > cut_off][all_countries]

# Express every series relative to its own value in the treatment year
tfr_at_intervention = tfr_data.loc[str(treatment_year), :]
tfr_normalised = tfr_data.astype('float') / tfr_at_intervention

tfr_normalised

In [None]:

# Long-format (year, country, tfr) frame for Altair. Naming the axes before stacking makes
# the column names explicit instead of relying on pandas' default level_0/level_1/0 labels.
source = (
    tfr_normalised[all_countries]
    .rename_axis(index='year', columns='country')
    .stack()
    .rename('tfr')
    .reset_index()
)
# The index years are strings; cast them so the quantitative x axis receives real numbers.
source['year'] = source['year'].astype(int)

# Hovering highlights the nearest country's line
highlight = alt.selection_point(
    on='mouseover',
    fields=['country'],
    nearest=True,
    empty=False,
)

base = alt.Chart(
    source,
    width=600,
    height=400,
).encode(
    x=alt.X('year:Q').axis(format='.0f').scale(domainMin=cut_off + 1).title('Year'),
    y=alt.Y('tfr:Q', title='TFR (normalised)').scale(domain=[0.4, 1.5]),
    detail='country:N',
)

# Invisible points: they carry the hover selection and the tooltip
points = base.mark_circle().encode(
    opacity=alt.value(0),
    tooltip='country',
).add_params(
    highlight
)

lines = base.mark_line(
    interpolate='natural'
).encode(
    size=alt.when(highlight).then(alt.value(2)).otherwise(alt.value(1)),
    color=alt.when(highlight).then(alt.value('red')).otherwise(alt.value('lightgray')),
    tooltip='country',
)

# Dashed vertical rule marking the treatment year; shares the 'year' field with the base layer
rule = alt.Chart(pd.DataFrame({
    'year': [treatment_year],
    'color': ['red'],
})).mark_rule(strokeDash=[1, 1]).encode(
    x=alt.X('year:Q').title('Year'),
    color=alt.Color('color:N', scale=None),
)

(points + lines + rule).configure(font='SF Compact Display')

In [None]:
def generate_data(start_year=2005, end_year=2022, treatment_year=2017, seed=1997):
    """
    Generate a synthetic TFR panel with a known treatment effect.

    The panel has one treated unit ('Estonia', the first column) and four control
    countries. Every series follows a shared base trend (1.8 falling to 1.4 up to
    2012, then rising to 1.6 by end_year), plus a random per-country level shift and
    cross-sectionally correlated noise. From treatment_year on, Estonia additionally
    gets a treatment effect that shrinks linearly from 0.3 to 0.1. The true average
    treatment effect is therefore known and can be compared with a model's estimate.

    Parameters:
    -----------
    start_year : int
        First year in the dataset
    end_year : int
        Last year in the dataset (inclusive)
    treatment_year : int
        First treated year; must satisfy start_year < treatment_year <= end_year
        so that there is at least one pre- and one post-treatment year
    seed : int, optional
        Seed for NumPy's global RNG; the default reproduces the original data

    Returns:
    --------
    df : DataFrame
        TFR values indexed by year, with one column per country
    treatment_year : int
        The treatment year, echoed back for convenience
    ate : float
        True average treatment effect on Estonia over the post-treatment years

    Raises:
    -------
    ValueError
        If treatment_year is not in (start_year, end_year]
    """
    # Without this check a treatment_year after end_year made `ate` the mean of an
    # empty array (NaN plus a RuntimeWarning) instead of failing loudly.
    if not start_year < treatment_year <= end_year:
        raise ValueError(
            "treatment_year must satisfy start_year < treatment_year <= end_year, got "
            f"start_year={start_year}, treatment_year={treatment_year}, end_year={end_year}"
        )

    years = np.arange(start_year, end_year + 1)
    n_years = len(years)

    # Country names (Estonia first, followed by the control countries)
    countries = ['Estonia', 'Latvia', 'Lithuania', 'Finland', 'Czech Republic']
    n_countries = len(countries)

    # Seeds NumPy's *global* RNG, as the original did, so the generated data is reproducible
    np.random.seed(seed)

    # Base trend: declining until the 2012 turning point, then slightly increasing.
    # years starts at start_year, so the first segment has `turning_point` years.
    turning_point = years.searchsorted(2012)
    base_trend = np.concatenate([
        np.linspace(1.8, 1.4, turning_point),
        np.linspace(1.4, 1.6, n_years - turning_point),
    ])

    # Target correlation structure of the noise between countries (symmetric, positive definite)
    corr_matrix = np.array([
        [1.0, 0.8, 0.7, 0.6, 0.7],
        [0.8, 1.0, 0.7, 0.8, 0.4],
        [0.7, 0.7, 1.0, 0.6, 0.5],
        [0.6, 0.8, 0.6, 1.0, 0.4],
        [0.7, 0.4, 0.5, 0.4, 1.0]
    ])

    # Right-multiplying iid rows by L.T gives rows with covariance 0.1**2 * corr_matrix
    L = np.linalg.cholesky(corr_matrix)
    uncorrelated_noise = np.random.normal(0, 0.1, (n_years, n_countries))
    correlated_noise = np.dot(uncorrelated_noise, L.T)

    # Each country: shared trend + its own level shift + correlated noise.
    # The draw order (noise first, then one level shift per country) is kept fixed so
    # a given seed always reproduces the same data.
    data = np.zeros((n_years, n_countries))
    for i in range(n_countries):
        country_variation = np.random.normal(0, 0.1)
        data[:, i] = base_trend + country_variation + correlated_noise[:, i]

    # Treatment effect for Estonia from treatment_year on, decreasing linearly over time
    treatment_idx = years.searchsorted(treatment_year)
    treatment_effect = np.zeros(n_years)
    treatment_effect[treatment_idx:] = np.linspace(0.3, 0.1, n_years - treatment_idx)
    data[:, 0] += treatment_effect

    ate = treatment_effect[treatment_idx:].mean()

    df = pd.DataFrame(data, index=years, columns=countries)

    return df, treatment_year, ate

# Synthetic panel; keep the true ATE so the model's estimate can be checked against it later
df, _, ate = generate_data(
    start_year=1995,
    end_year=2023,
    treatment_year=2017,
)
print('True average treatment effect:', ate)

df.plot(backend='hvplot')


In [6]:

def create_synthetic_control_model(X_control, y_treated_pre, y_treated_post=None):
    """
    Set up a PyMC model for synthetic control analysis.

    The treated unit's pre-treatment outcome is modelled as a weighted sum of the
    control units plus Gaussian noise. The weights ``beta`` have a flat Dirichlet
    prior, so they are non-negative and sum to one.

    If X_control has post-treatment rows, the same weights give the counterfactual
    ``mu_post``. If the treated unit's post-treatment outcome is also supplied, the
    per-period ``treatment_effect`` (observed minus counterfactual) and its mean
    ``avg_treatment_effect`` are added as deterministic variables.

    Parameters:
    -----------
    X_control : numpy.ndarray
        2-D array of shape (n_time_periods_pre + n_time_periods_post, n_control_units),
        pre-treatment rows first; outcomes of the control units
    y_treated_pre : numpy.ndarray
        Array of shape (n_time_periods_pre,); treated unit's pre-treatment outcome
    y_treated_post : numpy.ndarray, optional
        Array of shape (n_time_periods_post,); treated unit's post-treatment outcome

    Returns:
    --------
    PyMC model

    Raises:
    -------
    ValueError
        If X_control is not 2-D, has fewer rows than y_treated_pre, or its
        post-treatment rows do not match the length of y_treated_post
    """
    # Validate shapes up front: mismatches otherwise surface as opaque pytensor errors
    # or, worse, a model that silently lacks the treatment-effect variables.
    if np.ndim(X_control) != 2:
        raise ValueError(
            f"X_control must be 2-D (time, control units), got shape {np.shape(X_control)}"
        )

    n_control_units = X_control.shape[1]
    n_time_total = X_control.shape[0]
    n_time_pre = len(y_treated_pre)

    if n_time_total < n_time_pre:
        raise ValueError(
            f"X_control has {n_time_total} rows but y_treated_pre has {n_time_pre} periods"
        )

    # Split control data into pre and post periods
    X_control_pre = X_control[:n_time_pre]
    X_control_post = X_control[n_time_pre:] if n_time_total > n_time_pre else None
    n_time_post = 0 if X_control_post is None else X_control_post.shape[0]

    if y_treated_post is not None and len(y_treated_post) != n_time_post:
        raise ValueError(
            f"X_control has {n_time_post} post-treatment rows but y_treated_post "
            f"has {len(y_treated_post)} periods"
        )

    with pm.Model() as model:

        # Control weights: Dirichlet keeps them non-negative and summing to 1
        beta = pm.Dirichlet('beta', a=np.ones(n_control_units))

        # Prior for observation noise
        sigma = pm.HalfNormal('sigma', sigma=0.1)

        # Synthetic (weighted) control for the pre-treatment period
        mu_pre = pm.Deterministic('mu_pre', pm.math.dot(X_control_pre, beta))

        # Likelihood: the weights are learned only from the pre-treatment fit
        pm.Normal('y_pre', mu=mu_pre, sigma=sigma, observed=y_treated_pre)

        if X_control_post is not None:
            # Counterfactual: what the treated unit would look like without treatment
            mu_post = pm.Deterministic('mu_post', pm.math.dot(X_control_post, beta))

            if y_treated_post is not None:
                treatment_effect = pm.Deterministic('treatment_effect',
                                                    y_treated_post - mu_post)
                pm.Deterministic('avg_treatment_effect', pm.math.mean(treatment_effect))

    return model

def fit_model(model, samples=2000, tune=1000, chains=4, target_accept=0.9, **sample_kwargs):
    """
    Draw posterior samples from a PyMC model.

    Parameters:
    -----------
    model : pymc.Model
        Model to sample from, e.g. one built by create_synthetic_control_model
    samples : int
        Number of posterior draws per chain (passed as pm.sample's draws)
    tune : int
        Number of tuning steps per chain, discarded from the trace
    chains : int
        Number of chains
    target_accept : float
        Target acceptance rate for the step-size adaptation
    **sample_kwargs
        Extra keyword arguments forwarded to pm.sample, e.g. random_seed for
        reproducible draws

    Returns:
    --------
    The trace returned by pm.sample (an InferenceData with a .posterior group)
    """
    with model:
        trace = pm.sample(samples, tune=tune, chains=chains,
                          target_accept=target_accept, **sample_kwargs)
    return trace



In [None]:
df, treatment_year, _ = generate_data(start_year=1995, end_year=2023, treatment_year=2017)

print("Example Data (TFR by country and year):")
print(df.head())

# Prepare data for model
X_control, estonia_pre, estonia_post, control_countries, pre_years, post_years = prepare_data(
    df, 'Estonia', treatment_year
)

print("\nControl countries:", control_countries)
print(f"Pre-treatment years: {pre_years[0]}-{pre_years[-1]}")
print(f"Post-treatment years: {post_years[0]}-{post_years[-1]}")


# Create and fit model
model = create_synthetic_control_model(X_control, estonia_pre, estonia_post)
trace = fit_model(model, samples=1000, tune=500)  # Reduced samples for demonstration

# Print summary statistics
print("\nModel Summary:")
summary = pm.summary(trace, var_names=['beta', 'sigma', 'avg_treatment_effect'])
print(summary)

# pm.summary rounds to 2 decimals by default, so printing its 'mean' column with 3 decimals
# shows a spurious trailing digit. Take the posterior means at full precision instead.
posterior = trace.posterior

print("\nControl Country Weights:")
beta_means = posterior['beta'].mean(dim=('chain', 'draw')).values
for country, weight in zip(control_countries, beta_means):
    print(f"{country}: {weight:.3f}")

ate_estimate = posterior['avg_treatment_effect'].mean().item()
print(f"\nEstimated Average Treatment Effect: {ate_estimate:.3f}")


In [None]:
# Posterior of the estimated ATE, with the true effect from generate_data as the reference line
pm.plot_posterior(
    trace,
    var_names=['avg_treatment_effect'],
    ref_val=ate,
    textsize=10,
)

In [None]:
# Per-year treatment effect averaged over draws: one line per chain
# (posterior dims are (chain, draw, ...), so dim='draw' is the original axis=1)
per_chain_effect = trace.posterior.treatment_effect.mean(dim='draw')
pd.DataFrame(per_chain_effect, columns=post_years).T.plot()

In [None]:
# Create and display the treatment-effect plot for the synthetic-data model
plot_treatment_effect(
    df,                 # synthetic panel from generate_data
    trace,              # posterior returned by fit_model
    'Estonia',
    control_countries,
    treatment_year,
)

In [None]:
# OK, let's work with the real data now

# OK, let's work with the real data now.
# tfr_data's year index holds strings; convert to ints like the synthetic panel's index.
df = tfr_data.copy()
df.index = df.index.astype(int)

# Use target_country (the country the donor pool was selected for) rather than a
# hard-coded name, so changing the target updates the whole analysis consistently.
X_control, estonia_pre, estonia_post, control_countries, pre_years, post_years = prepare_data(
    df, target_country, treatment_year
)

print("\nControl countries:", control_countries)
print(f"Pre-treatment years: {pre_years[0]}-{pre_years[-1]}")
print(f"Post-treatment years: {post_years[0]}-{post_years[-1]}")


# Create and fit model
model = create_synthetic_control_model(X_control, estonia_pre, estonia_post)
trace = fit_model(model, samples=1000, tune=500)  # Reduced samples for demonstration

# Print summary statistics
print("\nModel Summary:")
summary = pm.summary(trace, var_names=['beta', 'sigma', 'avg_treatment_effect'])
print(summary)

# pm.summary rounds to 2 decimals by default, so printing its 'mean' column with 3 decimals
# shows a spurious trailing digit. Take the posterior means at full precision instead.
posterior = trace.posterior

print("\nControl Country Weights:")
beta_means = posterior['beta'].mean(dim=('chain', 'draw')).values
for country, weight in zip(control_countries, beta_means):
    print(f"{country}: {weight:.3f}")

ate_estimate = posterior['avg_treatment_effect'].mean().item()
print(f"\nEstimated Average Treatment Effect: {ate_estimate:.3f}")



In [None]:
# Treatment-effect plot for the real data. Use target_country, the country the donor pool
# was selected for, instead of a hard-coded name.
plot_treatment_effect(df, trace, target_country, control_countries, treatment_year)

In [None]:

# Posterior of the ATE on the real data; the reference value 0 marks "no effect"
pm.plot_posterior(
    trace,
    var_names=['avg_treatment_effect'],
    ref_val=0,
    textsize=10,
)

### Cross-check with CausalPy

<img src="./life.png">


In [None]:
# causalpy is already imported as `cp` in the first cell.

def _formula_term(name):
    """Return a column name as a formula term, quoting it with patsy's Q() when needed.

    Country names with spaces or punctuation (e.g. 'Czech Republic') are not valid
    formula identifiers and would make the formula fail to parse. They are wrapped
    as Q('...'); plain identifiers are left untouched.
    """
    return name if name.isidentifier() else f"Q({name!r})"


formula = (
    _formula_term(target_country)
    + " ~ 0 + "
    + " + ".join(_formula_term(country) for country in other_countries)
)

formula

In [None]:
# Derive the treatment date from treatment_year so this analysis cannot drift from the PyMC one
treatment_time = pd.to_datetime(f"{treatment_year}-01-01")

# CausalPy requires a DatetimeIndex: map each integer year to 1 January of that year
data = df.copy()
data.index = pd.to_datetime(data.index.astype(str) + "-01-01")

# nuts_sampler='nutpie' needs the nutpie package installed
sample_kwargs = {"tune": 4000, "target_accept": 0.95, 'nuts_sampler': 'nutpie'}

result = cp.SyntheticControl(
    data,
    treatment_time,
    formula=formula,
    model=cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs),
)

In [None]:
# CausalPy's built-in visualisation of the SyntheticControl result fitted above
result.plot()

In [None]:
# Summary table of CausalPy's posterior (pm.summary is ArviZ's summary re-exported by PyMC)
az.summary(result.idata)