In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import statsmodels.formula.api as smf

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Global plot styling: seaborn's default theme plus a square default figure
# size (matplotlib measures figure.figsize in inches).
sns.set_theme()
plt.rcParams.update({"figure.figsize": [8, 8]})

In [None]:
enroll = pd.read_csv("../datasets/enrollment.csv")
enroll.head()

In [None]:
enroll.Births = enroll.Births / 1000000
enroll.Enrollment = enroll.Enrollment / 1000000
enroll.head()

In [None]:
# Observed enrollments for enrollment years strictly between 1980 and 2019.
trend = enroll.query("EnrollmentYear > 1980 & EnrollmentYear < 2019")
sns.lineplot(data=trend, x="EnrollmentYear", y="Enrollment")
sns.scatterplot(data=trend, x="EnrollmentYear", y="Enrollment")
# Take the title's year span from the plotted rows. A hard-coded "1985" misreports
# the chart whenever the data contains any of the 1981-1984 years the filter admits.
trend_start = int(trend.EnrollmentYear.min())
trend_end = int(trend.EnrollmentYear.max())
plt.title(f"Undergraduate enrollments, {trend_start} to {trend_end}")
plt.xlabel("Year")
plt.ylabel("Enrollment (millions)")
# plt.savefig("annual_enrollments.png")

In [None]:
# Training window: birth cohorts before 1999 (later cohorts are held out for prediction).
train = enroll[enroll.BirthYear < 1999]
ax = sns.scatterplot(data=train, x="Births", y="Enrollment")
ax.set_title("Enrollments versus Births 20 years earlier")
ax.set_xlabel("Births (millions)")
ax.set_ylabel("Enrollment (millions)")
# Uncomment to export: plt.savefig("enrollment_births.png")

In [None]:
# Tag each row with its birth decade (e.g. 1985 -> 1980). Making it categorical
# gives the scatter one discrete hue per decade instead of a continuous scale.
train = train.assign(
    BirthDecade=(train.BirthYear - train.BirthYear % 10).astype("category")
)
sns.scatterplot(data=train, x="Births", y="Enrollment", hue="BirthDecade")
plt.title("Enrollments versus Births 20 years earlier, by birth decade")
plt.xlabel("Births (millions)")
plt.ylabel("Enrollment (millions)")
# plt.savefig("enrollment_births_decade.png")

In [None]:
# Restrict the training data to cohorts born after 1978.
train = train.query("BirthYear > 1978")
# Drop decade levels that no longer have rows. Otherwise the categorical hue
# legend would still list the now-empty decades.
train = train.assign(BirthDecade=train.BirthDecade.cat.remove_unused_categories())
sns.scatterplot(data=train, x="Births", y="Enrollment", hue="BirthDecade")
plt.title("Enrollments versus Births 20 years earlier, births after 1978")
plt.xlabel("Births (millions)")
plt.ylabel("Enrollment (millions)")
# plt.savefig("enrollment_births_80s.png")

In [None]:
# Pearson correlation between births and enrollment on the training rows
# (its square equals the R^2 of the single-predictor OLS fit below).
train["Births"].corr(train["Enrollment"])

In [None]:
# Simple linear regression of Enrollment on Births; the formula interface adds
# an intercept term automatically.
ols_formula = "Enrollment ~ Births"
train_model = smf.ols(ols_formula, data=train)
train_fit = train_model.fit()
train_fit.params

In [None]:
# Share of Enrollment variance in the training rows explained by Births.
r_squared = train_fit.rsquared
r_squared

In [None]:
# Residual standard error: square root of the fit's estimated error variance
# (`scale`), in the same units as Enrollment (millions).
residual_se = np.sqrt(train_fit.scale)
residual_se

In [None]:
# Cohorts born after 1998 are the ones excluded from training (train uses
# BirthYear < 1999); predict their enrollment from the fitted line.
predictions = enroll.query("BirthYear > 1998").filter(["BirthYear", "Births", "EnrollmentYear"])
# Pass the whole frame so the formula's design info selects `Births` by name.
# A bare Series forces statsmodels to guess whether it is a column or a row.
predictions = predictions.assign(Enrollment=train_fit.predict(predictions))
predictions

In [None]:
# Training points with their OLS line (no confidence band), overlaid with the
# predicted points for the held-out cohorts.
ax = sns.regplot(
    data=train,
    x="Births",
    y="Enrollment",
    ci=None,
    line_kws={"color": "black"},
)
sns.scatterplot(data=predictions, x="Births", y="Enrollment", ax=ax)
# Fixed births window; NOTE(review): presumably chosen to frame the data — revisit if the data change.
ax.set_xlim(3.4, 4.4)
ax.set_xlabel("Births (millions)")
ax.set_ylabel("Enrollment (millions)")
ax.set_title("Observed and predicted enrollments versus births")
# Uncomment to export: plt.savefig("predicted_enrollments.png")

In [None]:
# Predicted cohorts whose enrollment exceeds 19 million.
predictions[predictions["Enrollment"] > 19]

In [None]:
# Observed enrollment series followed by the model's predicted series on one timeline.
sns.lineplot(data=trend, x="EnrollmentYear", y="Enrollment")
sns.scatterplot(data=trend, x="EnrollmentYear", y="Enrollment")
sns.lineplot(data=predictions, x="EnrollmentYear", y="Enrollment")
sns.scatterplot(data=predictions, x="EnrollmentYear", y="Enrollment")
# Dotted reference line at 19 million enrollments.
plt.axhline(19, linestyle=":", color="black")
# Derive the title's year span from the plotted data instead of hard-coding it,
# so it stays correct if the `trend` filter or the prediction horizon changes.
plotted_years = pd.concat([trend.EnrollmentYear, predictions.EnrollmentYear])
span_start = int(plotted_years.min())
span_end = int(plotted_years.max())
plt.title(f"Observed and predicted enrollments, {span_start} to {span_end}")
plt.xlabel("Year")
plt.ylabel("Enrollment (millions)")
# plt.savefig("combined_enrollments.png")