In [1]:
import pandas as pd
from pyaugsynth import dataprep, Synth

In [2]:
df = pd.read_csv('germany.csv')

X0, X1, Z0, Z1 = dataprep(
    foo=df,
    predictors=('gdp', 'trade', 'infrate'),
    predictors_op='mean',
    time_predictors_prior=range(1971, 1981),
    special_predictors=(
        ('industry', range(1971, 1981) ,'mean'),
        ('schooling', [1970, 1975], 'mean'),
        ('invest70', [1980], 'mean')
    ),
    dependent = 'gdp',
    unit_variable = 'country',
    time_variable = 'year',
    treatment_identifier = 'West Germany',
    controls_identifier = ('USA', 'UK','Austria', 'Belgium', 'Denmark', 'France',
        'Italy', 'Netherlands', 'Norway', 'Switzerland', 'Japan',
        'Greece', 'Portugal', 'Spain', 'Australia', 'New Zealand'),
    time_optimize_ssr = range(1981, 1991),
    time_plot = range(1960, 2004)
)

synth = Synth()
_, _, V_train, _ = synth.fit(X0, X1, Z0, Z1)

In [3]:
X0, X1, Z0, Z1 = dataprep(
    foo=df,
    predictors=('gdp', 'trade', 'infrate'),
    predictors_op='mean',
    time_predictors_prior=range(1981, 1991),
    special_predictors=(
        ('industry', range(1981, 1991) ,'mean'),
        ('schooling', [1980, 1985], 'mean'),
        ('invest80', [1980], 'mean')
    ),
    dependent='gdp',
    unit_variable='country',
    time_variable='year',
    treatment_identifier='West Germany',
    controls_identifier=('USA', 'UK','Austria', 'Belgium', 'Denmark', 'France',
        'Italy', 'Netherlands', 'Norway', 'Switzerland', 'Japan',
        'Greece', 'Portugal', 'Spain', 'Australia', 'New Zealand'),
    time_optimize_ssr = range(1960, 1990),
    time_plot = range(1960, 2004)
)

W, _, _, _ = synth.fit(X0, X1, Z0, Z1, custom_V=V_train)

In [4]:
# Donor weights at or below this value are treated as zero and not listed.
WEIGHT_DISPLAY_THRESHOLD = 1e-5

# zip() silently stops at the shorter input, so check that every donor column
# of X0 has exactly one weight in W before pairing them up.
assert len(W) == len(X0.columns), (
    f"{len(W)} weights for {len(X0.columns)} donor columns"
)

for country, weight in zip(X0.columns, W):
    if weight > WEIGHT_DISPLAY_THRESHOLD:
        print(country, round(weight, 3))

USA 0.216
Austria 0.415
Netherlands 0.098
Switzerland 0.108
Japan 0.162
