In [None]:
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from google.cloud import bigquery
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from scipy.optimize import fmin_slsqp
#import seaborn as sns

# fintrans_toolbox is imported from the project root, so that root has to be on
# sys.path before the two imports below run.
project_path = "/home/jupyter"
sys.path.append(project_path)
from fintrans_toolbox.src import bq_utils as bq
from fintrans_toolbox.src import table_utils as t

In [None]:
# Report the NumPy version the kernel is running.
print(f"{np.__version__}")

In [None]:
#import folium


import plotly.express as px

#from geopy.geocoders import Nominatim
from sklearn.linear_model import LinearRegression
#from scikit-learn.linear_model import LinearRegression
#from toolz import partial
#from scipy.optimize import fmin_slsqp

In [None]:
# Monthly rows at POSTAL_AREA level, restricted to the 'All' aggregates of
# cardholder_issuing_level and mcg.
sql = (
    "SELECT * FROM `ons-fintrans-data-prod.fintrans_visa.spend_merchant_location` "
    "WHERE merchant_location_level = 'POSTAL_AREA' "
    "and cardholder_issuing_level = 'All' "
    "and mcg = 'All' "
    "and time_period = 'Month'"
)

client = bigquery.Client()
df = bq.read_bq_table_sql(client, sql)

In [None]:
# Keep only the period, location and spend columns.
my_results2 = df.loc[:, ["time_period_value", "merchant_location", "spend"]]

In [None]:
# Drop the 'EC' postal area and renumber the rows.
# NOTE(review): the reason for excluding 'EC' is not recorded in the notebook - TODO document.
is_not_ec = my_results2['merchant_location'].ne('EC')
my_results2 = my_results2[is_not_ec].reset_index(drop=True)


In [None]:
# Order rows by location, then by period, so that the first row of every
# location is its earliest period (the indexing step relies on this order).
mr2 = my_results2.sort_values(
    ['merchant_location', 'time_period_value'],
    ascending=True,
)
mr2.head(5)

In [None]:
# Index spend per location: every value is divided by that location's first row
# in the sorted frame (per the original note, that first row is 2019 Jan).
mr2["index"] = mr2["spend"] / (
    mr2.groupby("merchant_location")["spend"].transform(lambda s: s.iloc[0])
)


In [None]:
# Wide table: one row per period, one column per location, cells hold the index.
# Sanity check when reading the output: the first period (2019 Jan) should be 1 everywhere.
df = mr2.pivot(
    index='time_period_value',
    columns='merchant_location',
    values='index',
).reset_index()
df.head(5)

In [None]:
# Keep only the columns (locations) that have a value for every month.
df = df.loc[:, df.notna().all(axis=0)]
df.head(5)

In [None]:
# Treated postal area and the first-treatment period. Despite its name,
# `treatment_year` is compared against `time_period_value`, so it is a period
# label of the form used in that column.
treated_city = "RG"
treatment_year = "202205"

# Location columns only. Columns this notebook derives later ('Other Cities',
# 'Synthetic <city>') are excluded explicitly so that re-running this cell does
# not pull them into the donor pool.
cities = [
    c for c in df.columns
    if c not in ['year', 'time_period_value', 'Other Cities']
    and not str(c).startswith('Synthetic ')
]

# Unweighted mean of every location except the treated one, as a naive comparison series.
df['Other Cities'] = df[[c for c in cities if c != treated_city]].mean(axis=1)

In [None]:
def plot_lines(df, line1, line2, year, hline=True,
               title="Indexed spend by month", event_label="Treatment start"):
    """Plot two columns of ``df`` against ``time_period_value`` on the current axes.

    Uses matplotlib only: seaborn is not imported in this notebook, so the
    previous ``sns.lineplot`` calls raised ``NameError``.

    Parameters
    ----------
    df : DataFrame with a ``time_period_value`` column plus ``line1`` and ``line2``.
    line1, line2 : column names to draw; each is also used as its legend label.
    year : x position of the dotted vertical marker.
    hline : accepted for backward compatibility; not used.
    title : figure title.
    event_label : legend label of the vertical marker.
    """
    ax = plt.gca()
    for column in (line1, line2):
        ax.plot(df['time_period_value'], df[column].values, label=column)
    # zorder=1 keeps the marker underneath the data lines.
    ax.axvline(x=year, ls=":", color='C2', label=event_label, zorder=1)
    ax.set_xlabel('time_period_value')
    ax.legend()
    ax.set_title(title)
    
# Three-colour default cycle (C0, C1, C2) for all following plots. Set through
# matplotlib directly because seaborn is not imported in this notebook.
plt.rcParams['axes.prop_cycle'] = plt.cycler(color=['#f14db3', '#0dc3e2', '#443a84'])
plot_lines(df, treated_city, 'Other Cities', treatment_year)

In [None]:
def synth_predict(df, model, city, year):
    """Fit ``model`` on the rows up to ``year`` and store its prediction for all rows.

    The target is column ``city``; the regressors are every entry of the
    module-level ``cities`` list other than ``city``. Only rows with
    ``time_period_value <= year`` are used for fitting. The prediction over
    every row is written to ``df`` in place as ``'Synthetic <city>'``.

    Returns the ``model`` object that was passed in.
    """
    donors = [name for name in cities if name != 'time_period_value' and name != city]
    in_fit_window = df['time_period_value'] <= year
    target = df.loc[in_fit_window, city]
    features = df.loc[in_fit_window, donors]
    fitted = model.fit(features, target)
    df[f'Synthetic {city}'] = fitted.predict(df[donors])
    return model

# The naive average is no longer needed. errors='ignore' keeps this cell
# re-runnable after the column has already been dropped.
df = df.drop(columns='Other Cities', errors='ignore')

# Unconstrained linear regression of the treated series on all other locations.
coef = synth_predict(df, LinearRegression(), treated_city, treatment_year).coef_

plot_lines(df, treated_city, f'Synthetic {treated_city}', treatment_year)

In [None]:
# One row per donor location with its linear-regression coefficient
# (same order as the regressors used in synth_predict).
df_states = pd.DataFrame({'city': [c for c in cities if c not in ["time_period_value",treated_city]], 'ols_coef': coef})

# Horizontal bar chart drawn with matplotlib (seaborn is not imported here).
plt.figure(figsize=(10, 9))
plt.barh(df_states['city'], df_states['ols_coef'])
plt.gca().invert_yaxis()  # first row at the top
plt.xlabel('ols_coef')
plt.ylabel('city');

In [None]:
class SyntheticControl():
    """Weights that are each in [0, 1] and sum to 1, fitted by constrained minimisation.

    Exposes ``fit`` / ``predict`` / ``coef_`` so it can be passed wherever a
    scikit-learn style regressor is used in this notebook.

    Attributes set by ``fit``:
      coef_ : fitted weight vector, one entry per column of ``X``.
      mse   : value of ``loss`` at ``coef_``. Note that ``loss`` takes a square
              root, so this is a root-mean-squared error despite the name.
    """

    # Loss function
    def loss(self, W, X, y) -> float:
        """Root-mean-squared difference between ``y`` and ``X.dot(W)``."""
        return np.sqrt(np.mean((y - X.dot(W))**2))

    # Fit model
    def fit(self, X, y):
        """Find the weights minimising ``loss`` and return ``self``.

        The search starts from equal weights. The objective is passed as a
        closure (not via ``args=``) because ``fmin_slsqp`` forwards ``args`` to
        the constraint function as well, and the constraint takes only the weights.
        """
        n_donors = X.shape[1]
        w_start = np.full(n_donors, 1 / n_donors)
        self.coef_ = fmin_slsqp(lambda w: self.loss(w, X, y),
                                w_start,
                                f_eqcons=lambda w: np.sum(w) - 1,
                                bounds=[(0.0, 1.0)] * n_donors,
                                disp=False)
        self.mse = self.loss(W=self.coef_, X=X, y=y)
        return self

    # Predict
    def predict(self, X):
        """Return ``X.dot(coef_)``; ``fit`` must have been called first."""
        return X.dot(self.coef_)

In [None]:
# Refit the treated series with the constrained model and keep its weights
# next to the linear-regression coefficients.
synthetic_fit = synth_predict(df, SyntheticControl(), treated_city, treatment_year)
df_states['coef_synth'] = synthetic_fit.coef_

plot_lines(df, treated_city, f'Synthetic {treated_city}', treatment_year)

In [None]:
def plot_difference(df, city, year, vline=True, hline=True,
                    event_label="Treatment start",
                    title="Estimated treatment effect (actual - synthetic)",
                    **kwargs):
    """Plot ``df[city] - df['Synthetic <city>']`` against ``time_period_value``.

    Draws on the current axes with matplotlib only: seaborn is not imported in
    this notebook, so the previous ``sns.lineplot`` calls raised ``NameError``.

    Parameters
    ----------
    df : DataFrame holding ``time_period_value``, ``city`` and ``'Synthetic <city>'``.
    city : name of the series whose gap to its synthetic counterpart is drawn.
    year : x position of the dotted vertical marker.
    vline : draw the vertical marker (and the legend) when true.
    hline : draw a horizontal reference line at zero when true.
    event_label : legend label of the vertical marker.
    title : figure title.
    **kwargs : forwarded to the gap line (e.g. ``alpha``, ``color``, ``lw``).
    """
    ax = plt.gca()
    gap = df[city] - df[f'Synthetic {city}']
    ax.plot(df['time_period_value'], gap, **kwargs)
    if vline:
        ax.axvline(x=year, ls=":", color='C2', label=event_label, lw=3, zorder=100)
        ax.legend()
    if hline:
        ax.axhline(y=0, lw=3, color='k', zorder=1)
    ax.set_xlabel('time_period_value')
    ax.set_title(title)

In [None]:
# Gap between the treated series and its synthetic counterpart.
plot_difference(df, city=treated_city, year=treatment_year)

In [None]:
# Synthetic-control weight per donor location, drawn with matplotlib
# (seaborn is not imported in this notebook).
plt.figure(figsize=(10, 9))
plt.barh(df_states['city'], df_states['coef_synth'])
plt.gca().invert_yaxis()  # first row at the top
plt.xlabel('coef_synth')
plt.ylabel('city');

In [None]:
# Postal areas whose synthetic-control weight exceeds 0.01, i.e. the ones that
# contribute noticeably to the synthetic series.
df_states[df_states['coef_synth'].gt(0.01)]

In [None]:
# Fit a synthetic control for every postal area in turn and overlay the gaps
# as faint lines; the treated area is drawn last with default styling.
# This is rather slow.
fig, ax = plt.subplots()
placebo_style = {'vline': False, 'alpha': 0.2, 'color': 'C1', 'lw': 3}
for city in cities:
    synth_predict(df, SyntheticControl(), city, treatment_year)
    plot_difference(df, city, treatment_year, **placebo_style)

plot_difference(df, treated_city, treatment_year)

In [None]:
# Fit error of the treated area, used as the reference. `.mse` is whatever
# SyntheticControl.fit stored, i.e. its loss at the fitted weights.
mse_treated = synth_predict(df, SyntheticControl(), treated_city, treatment_year).mse

# Same overlay as before, but only for areas whose own fit error is below
# twice the reference.
fig, ax = plt.subplots()
for city in cities:
    mse = synth_predict(df, SyntheticControl(), city, treatment_year).mse
    if not mse < 2 * mse_treated:
        continue
    plot_difference(df, city, treatment_year, vline=False, alpha=0.2, color='C1', lw=3)

plot_difference(df, treated_city, treatment_year)

In [None]:
# NOTE: the original author reported that this cell and the ones below crash,
# and suggested running them on a random sample of postal areas instead of all
# of them. The cause is not recorded in the notebook.

# Placebo statistic per postal area: MSE after the fit window divided by MSE
# inside it. Both are taken from the same squared gaps so the units match;
# SyntheticControl.mse is a root-mean-squared error and must not be mixed with
# a plain MSE.
pre_period = df['time_period_value'] <= treatment_year  # the window synth_predict fits on

lambdas = {}
for city in cities:
    synth_predict(df, SyntheticControl(), city, treatment_year)
    squared_gap = (df[f'Synthetic {city}'] - df[city]) ** 2
    mse_pre = squared_gap[pre_period].mean()
    mse_post = squared_gap[~pre_period].mean()
    lambdas[city] = mse_post / mse_pre

# Share of postal areas whose ratio is strictly larger than the treated area's.
print(f"p-value: {np.mean(np.fromiter(lambdas.values(), dtype='float') > lambdas[treated_city]):.4}")

In [None]:
# Distribution of the placebo ratios, with the treated area's ratio drawn on
# the same bins in the next cycle colour.
fig, ax = plt.subplots()
_, bins, _ = plt.hist(list(lambdas.values()), bins=20, color="C1");
plt.hist([lambdas[treated_city]], bins=bins)
plt.title('Ratio of $MSE_{post}$ and $MSE_{pre}$ across cities');

# Optional decorative image. A missing file must not abort the cell.
# NOTE(review): the file name and the (2.7, 1.7) anchor look like tutorial
# leftovers - confirm they are still wanted for this data.
try:
    marker_image = plt.imread('fig/miami.png')
except FileNotFoundError:
    marker_image = None
if marker_image is not None:
    ax.add_artist(AnnotationBbox(OffsetImage(marker_image, zoom=0.25), (2.7, 1.7), frameon=False));