### Plotting Difference-in-Differences

In [2]:
import pandas as pd
import numpy as np
import altair as alt


In [23]:
# Load the merged enrollment panel.
# NOTE(review): absolute, machine-specific path — consider a configurable DATA_DIR.
BIG_MERGE_PATH = '/Users/preetkhowaja/Documents/midssp2022/unifying/uds-2022-ids-701-team-3/20_analysis/big_merge.csv'
data_merge = (
    pd.read_csv(BIG_MERGE_PATH)
    # discard the 'Unnamed: *' columns (presumably index columns saved by earlier to_csv calls)
    .drop(columns=['Unnamed: 0', 'Unnamed: 0.1'])
)
data_merge.head()

Unnamed: 0,sex,subprovince,region,sample_population,enrolled_total,rate_enrollment,year
0,male,Abbotabad,urban,0,0,,2004
1,male,Abbotabad,rural,60,55,0.916667,2004
2,male,Attock,urban,0,0,,2004
3,male,Attock,rural,64,60,0.9375,2004
4,male,Awaran,urban,0,0,,2004


In [46]:
# Keep only complete rows: any row with a missing value in any column is removed
# (e.g. the head() output shows rate_enrollment is NaN where sample_population is 0).
data_merge = data_merge.dropna(axis="index", how="any")


In [66]:
# Split the panel around the 2007 policy change (the year of the separating line).
# Pre-period: years before 2007. Post-period: 2010 onward. Years 2007-2009 are left
# out of both fits — presumably as a transition window; TODO confirm intent.
data_pre = data_merge[data_merge.year < 2007]
data_post = data_merge[data_merge.year >= 2010]

In [55]:
# Mean enrollment rate for every (sex, year) pair, returned as a flat table
(
    data_merge
    .groupby(['sex', 'year'], as_index=False)['rate_enrollment']
    .mean()
)

Unnamed: 0,sex,year,rate_enrollment
0,female,2004,0.461533
1,female,2005,0.510109
2,female,2006,0.569063
3,female,2007,0.335922
4,female,2008,0.563482
5,female,2010,0.565069
6,female,2011,0.564087
7,female,2012,0.622112
8,female,2013,0.607155
9,female,2014,0.642135


In [8]:
## Nick's code for confidence bands
def get_reg_fit(data, yvar, xvar, alpha=0.05, col="blue",
                x_title="Years from Policy Change", n_points=101):
    """Fit OLS ``yvar ~ xvar`` and return its predictions plus an Altair chart.

    The chart layers a (1 - alpha) confidence band for the fitted mean under
    the fitted regression line, both drawn in ``col``.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain the columns ``yvar`` and ``xvar``.
    yvar, xvar : str
        Column names. They are interpolated into a statsmodels formula, so
        they must be valid formula identifiers.
    alpha : float, default 0.05
        Significance level of the confidence band.
    col : str, default "blue"
        Colour of the line and of the band.
    x_title : str, default "Years from Policy Change"
        x-axis title of the fitted line.
    n_points : int, default 101
        Number of evenly spaced x values at which the fit is evaluated.

    Returns
    -------
    predictions : pandas.DataFrame
        Columns ``xvar``, ``yvar`` (fitted mean), ``ci_low`` and ``ci_high``.
    chart : altair.LayerChart
        Confidence band layered under the regression line.
    """
    import statsmodels.formula.api as smf

    # Prediction grid over the x-range of rows that have a y value. linspace
    # ends exactly at xmax (arange with a float step can overshoot it) and,
    # unlike a zero arange step, does not fail when all x values are equal.
    x = data.loc[pd.notnull(data[yvar]), xvar]
    predictions = pd.DataFrame({xvar: np.linspace(x.min(), x.max(), n_points)})

    # Fit the model; pass a DataFrame so the formula can look xvar up by name.
    model = smf.ols(f"{yvar} ~ {xvar}", data=data).fit()
    summary = model.get_prediction(predictions[[xvar]]).summary_frame(alpha=alpha)
    predictions[yvar] = summary["mean"].to_numpy()
    predictions["ci_low"] = summary["mean_ci_lower"].to_numpy()
    predictions["ci_high"] = summary["mean_ci_upper"].to_numpy()

    # Build chart: y must be an alt.Y channel (was mistakenly alt.X)
    reg = alt.Chart(predictions).mark_line(color=col).encode(
        x=alt.X(xvar, axis=alt.Axis(title=x_title)),
        y=alt.Y(yvar, axis=alt.Axis(title="")),
    )
    ci = (
        alt.Chart(predictions)
        .mark_errorband(color=col)
        .encode(
            x=xvar,
            y=alt.Y("ci_low", title=""),
            y2="ci_high",
        )
    )
    return predictions, ci + reg

In [64]:
# Separating line: a dashed vertical rule at 2007, dividing the
# pre- and post-policy periods on the trend chart.
POLICY_YEAR = 2007
data = pd.DataFrame({"a": [POLICY_YEAR]})
sep_line = (
    alt.Chart(data)
    .mark_rule(color="black", strokeDash=[10, 10])
    .encode(x=alt.X("a:Q", title=""))
)

In [67]:
# Enrollment by gender: for each sex, a full-period regression line plus separate
# pre- and post-policy OLS fits with 95% confidence bands, layered with the
# policy separator line.

# Lift Altair's default row limit so the full data_merge can be embedded (global setting).
alt.data_transformers.disable_max_rows()

# sex value -> (legend label, colour). Both base charts write their label into the
# same `sex_label` field, so they share one colour scale and one legend.
SEX_STYLE = {"female": ("Female", "red"), "male": ("Male", "blue")}
scale = alt.Scale(
    domain=[label for label, _ in SEX_STYLE.values()],
    range=[color for _, color in SEX_STYLE.values()],
)


def sex_trend_layers(sex):
    """Build the three trend layers for one value of the `sex` column.

    Returns (full-period regression line, pre-policy fit chart, post-policy fit chart).
    """
    label, color = SEX_STYLE[sex]
    base = (
        alt.Chart(data_merge[data_merge["sex"] == sex], title="Enrollment Trends Pakistan")
        # Constant label column. It must also be a regression groupby key, because
        # transform_regression keeps only the on/regression/groupby fields.
        .transform_calculate(sex_label=f"'{label}'")
        .transform_regression("year", "rate_enrollment", groupby=["sex_label"])
        .mark_line()
        .encode(
            x="year",
            y=alt.Y("rate_enrollment", title="Rate of Enrollment", scale=alt.Scale(zero=False)),
            color=alt.Color("sex_label:N", scale=scale, title=""),
        )
    )
    _, pre_line = get_reg_fit(
        data_pre[data_pre["sex"] == sex],
        yvar="rate_enrollment",
        xvar="year",
        alpha=0.05,
        col=color,
    )
    _, post_line = get_reg_fit(
        data_post[data_post["sex"] == sex],
        yvar="rate_enrollment",
        xvar="year",
        alpha=0.05,
        col=color,
    )
    return base, pre_line, post_line


## Female Trends
base_female, female_pre_line, female_post_line = sex_trend_layers("female")

## Male Trends
base_male, male_pre_line, male_post_line = sex_trend_layers("male")

plots = (
    base_female + female_pre_line + female_post_line
    + base_male + male_pre_line + male_post_line
    + sep_line
)
plots