In [1]:
import math
import os
from collections import defaultdict

# Generate plots
import matplotlib.pyplot as plt
import pandas as pd
# Predict trends using linear regression
from sklearn.linear_model import LinearRegression

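# Inputs:
# 1) raw CPI data (inputs/raw_data/CPI/CPIAUCSL.csv)
# 2) wage growth data (temp/wfh_wage_growth_by_quartile.csv and temp/wfh_wage_growth_pooled.csv),
#    produced by running make_data (in the wfh directory) and wage_growth_wf_occ;
#    the latter generates the wage growth measure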

  from pandas.core import (
Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [2]:
# Project root. Set the LABOR_MARKET_PT_DIR environment variable to run the
# notebook on another machine without editing this cell; otherwise the
# original hard-coded path is used.
proj_dir = os.environ.get(
    "LABOR_MARKET_PT_DIR",
    "C:/Users/singhy/Dropbox/Labor_Market_PT/replication/empirical",
)

# wage growth by quartile
df = pd.read_csv(f"{proj_dir}/temp/wfh_wage_growth_by_quartile.csv")

# wage growth pooled
df_pol = pd.read_csv(f"{proj_dir}/temp/wfh_wage_growth_pooled.csv")

# raw CPI series
cpi = pd.read_csv(f"{proj_dir}/inputs/raw_data/CPI/CPIAUCSL.csv")


In [3]:
# Attach the pooled wage-growth columns to the by-quartile data
# (inner join on the 'date_monthly' key).
df = pd.merge(df, df_pol, on='date_monthly')

In [4]:

# Build monthly nominal, price and real wage indices.

# Convert 'date_monthly' from the '2016m1' format to a month-start datetime.
# Vectorized (no row-wise apply); a value that does not match the pattern now
# fails with an explicit message instead of an opaque int(NaN) error.
year_month = df['date_monthly'].str.extract(r'(\d{4})m(\d{1,2})')
assert year_month.notna().all().all(), "date_monthly values must look like '2016m1'"
df['date'] = pd.to_datetime(year_month[0] + '-' + year_month[1].str.zfill(2))

# CPI: 12-month % change converted to a monthly factor 1 + (rate / 100) / 12,
# the same conversion applied to wage growth below.
cpi['date'] = pd.to_datetime(cpi['observation_date'])
cpi = cpi.rename(columns={'CPIAUCSL': 'P'}).sort_values('date')
cpi['P'] = pd.to_numeric(cpi['P'], errors='coerce')
# pct_change(periods=12) is positional: assumes one CPI row per consecutive month.
cpi['P_12m_change'] = cpi['P'].pct_change(periods=12) * 100
cpi['P_1m_change'] = 1 + (cpi['P_12m_change'] / 100) / 12
cpi = cpi[['date', 'P_1m_change']]

# Merge CPI with wage growth data
df = df.merge(cpi, on='date', how='left')

# The cumulative products below treat each row as the next calendar month,
# so enforce date order, one row per month, and no missing months.
df = df.sort_values('date').reset_index(drop=True)
assert df['date'].is_unique, "duplicate months would be double-counted in the indices"
month_number = df['date'].dt.year * 12 + df['date'].dt.month
assert (month_number.diff().dropna() == 1).all(), "wage data must cover consecutive months"

# Wage growth (%, treated here as an annual rate) -> monthly factor 1 + rate/12.
wage_columns = [col for col in df.columns if col.startswith('smwg')]
for col in wage_columns:
    df[f'{col}_mom_grth'] = 1 + (df[col] / 100) / 12

# Nominal wage indices: cumulative product of the monthly factors.
for col in wage_columns:
    df[f'nom_index_{col}'] = df[f'{col}_mom_grth'].cumprod()

# Price index: cumulative product of the monthly CPI factors.
df['price_index'] = df['P_1m_change'].cumprod()

# Real wage index = nominal index deflated by the price index.
for col in wage_columns:
    df[f'real_index_{col}'] = df[f'nom_index_{col}'] / df['price_index']

# Select final columns
result_cols = ['date', 'price_index'] + [f'real_index_{col}' for col in wage_columns]
result_df = df[result_cols]




In [5]:
# Fit a linear trend over 2016-01..2019-12 to every real wage index and to the
# price index, extrapolate it over the whole sample, and store the fitted values
# as 'predicted_<column>' in df.
start_date = pd.to_datetime("2016-01-01")
end_date = pd.to_datetime("2019-12-31")

in_trend_window = (df['date'] >= start_date) & (df['date'] <= end_date)
trend_df = df[in_trend_window]

# Regressor: months elapsed since the first month of the trend window.
first_month = trend_df['date'].min()


def _months_since_first(dates):
    """Return months elapsed since `first_month` as an (n, 1) array."""
    elapsed = (dates.dt.year - first_month.year) * 12 + (dates.dt.month - first_month.month)
    return elapsed.values.reshape(-1, 1)


X_trend = _months_since_first(trend_df['date'])
X_all = _months_since_first(df['date'])

# Every real wage index column plus the price index gets its own trend.
real_index_cols = [name for name in df.columns if name.startswith('real_index_')]
series_to_fit = real_index_cols + ['price_index']

predicted = {}
for series in series_to_fit:
    trend_model = LinearRegression().fit(X_trend, trend_df[series].values)
    predicted[series] = trend_model.predict(X_all)
    df[f'predicted_{series}'] = predicted[series]

# Gap between trend and actual at the last row of df.
gaps = {
    series: df[f'predicted_{series}'].iloc[-1] - df[series].iloc[-1]
    for series in series_to_fit
}


In [7]:
# Map output label -> (actual real wage index column, fitted trend column).
gap_columns = {
    'WFH_1st_Quartile': ('real_index_smwg1st_high_wfh', 'predicted_real_index_smwg1st_high_wfh'),
    'No_WFH_1st_Quartile': ('real_index_smwg1st_no_wfh', 'predicted_real_index_smwg1st_no_wfh'),
    'WFH_4th_Quartile': ('real_index_smwg4th_high_wfh', 'predicted_real_index_smwg4th_high_wfh'),
    'No_WFH_4th_Quartile': ('real_index_smwg4th_no_wfh', 'predicted_real_index_smwg4th_no_wfh'),
    'WFH_Pooled': ('real_index_smwghigh_wfh', 'predicted_real_index_smwghigh_wfh'),
    'No_WFH_Pooled': ('real_index_smwgno_wfh', 'predicted_real_index_smwgno_wfh')
}

# Filter from Jan 2020 onward
plot_start_date = pd.to_datetime("2020-01-01")
mask = df['date'] >= plot_start_date
df_filtered = df.loc[mask].copy()

# Initialize output DataFrame
gap_df = df_filtered[['date']].copy()

# Anchor by date, not position: the first filtered row is 2020-01 only if df is
# sorted and actually contains that month.
is_anchor = df_filtered['date'] == plot_start_date
assert is_anchor.sum() == 1, "expected exactly one 2020-01 row to anchor the gaps"

# Gap = (actual - trend) * 100. Only the 2020-01 value is set to 0; the other
# months are not shifted.
for label, (actual_col, trend_col) in gap_columns.items():
    gap = (df_filtered[actual_col] - df_filtered[trend_col]) * 100
    gap.loc[is_anchor] = 0
    gap_df[label] = gap.values



In [None]:
gap_df.to_csv(f"{proj_dir}/outputs/processed_data/wfh_wage_plot_data.csv", index= False)


# Plot label -> real wage index column, derived from `gap_columns` (defined in
# the gap-calculation cell) so both cells read the same columns. The names
# previously hard-coded here ('real_index_smwgwfh_1st_high', ...,
# 'real_index_smwglow_wfh') did not match the ones that cell reads.
# 'WFH_1st_Quartile' -> 'WFH 1st Quartile', etc.
cols = {
    label.replace('_', ' '): actual_col
    for label, (actual_col, _trend_col) in gap_columns.items()
}

# Rebuild gap_df as deviations from trend in index units (NOT the x100 scale
# of the CSV written above), with only the 2020-01 value set to 0, from 2020 on.
# NOTE: this replaces the in-memory gap_df after it has been saved.
gap_df = pd.DataFrame({'date': df['date']})
plot_start_date = pd.to_datetime("2020-01-01")
gap_df = gap_df[gap_df['date'] >= plot_start_date].copy()

for label, col in cols.items():
    deviation = df[col] - df[f'predicted_{col}']
    deviation.loc[df['date'] == plot_start_date] = 0
    gap_df[label] = deviation.loc[df['date'] >= plot_start_date].values

gap_df = gap_df.reset_index(drop=True)


# --- Grid of per-quartile panels: real wage deviation from pre-2020 trend ---
plot_start_date = pd.to_datetime("2020-01-01")
df_plot = df[df['date'] >= plot_start_date]

# Group the quartile-level real wage indices by quartile.
# NOTE(review): assumes names like 'real_index_smwg1st_high_wfh' and
# 'real_index_smwg1st_no_wfh' (the names the gap-calculation cell reads), so
# parts[2] is the quartile token (e.g. 'smwg1st') and parts[3] the WFH status.
# Pooled series such as 'real_index_smwghigh_wfh' split into only four parts
# and are skipped; otherwise each became an extra one-line "quartile" panel,
# with the high-WFH pooled series labelled 'Low WFH'.
quartile_groups = defaultdict(dict)

for col in real_index_cols:
    parts = col.split("_")
    if len(parts) != 5:
        continue
    quartile = parts[2]
    wfh_status = parts[3]  # 'high' or, e.g., 'no'
    label = 'High WFH' if wfh_status == 'high' else 'Low WFH'
    quartile_groups[quartile][label] = col

assert quartile_groups, "no quartile-level real_index_ columns found"

# Loop-invariant masks: the 2020-01 anchor row and the plotted window.
is_anchor = df['date'] == plot_start_date
in_window = df['date'] >= plot_start_date

ncols = 2
nrows = math.ceil(len(quartile_groups) / ncols)
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(16, 4 * nrows))
axes = axes.flatten()

for ax, (quartile, wfh_group_cols) in zip(axes, sorted(quartile_groups.items())):
    for wfh_label, col in wfh_group_cols.items():
        # Deviation from trend; only the 2020-01 value is set to 0 (the rest
        # of the series is not shifted).
        deviation = df[col] - df[f'predicted_{col}']
        deviation.loc[is_anchor] = 0
        ax.plot(df.loc[in_window, 'date'], deviation.loc[in_window], label=wfh_label)

    ax.axhline(0, color='black', linewidth=1, linestyle='--')
    ax.set_title(f"Quartile {quartile.upper()}", fontsize=14)
    ax.set_ylim(-0.1, 0.05)
    ax.legend(fontsize=12)
    ax.tick_params(labelsize=12)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

# Hide unused subplots (when the number of quartiles is odd).
for unused_ax in axes[len(quartile_groups):]:
    fig.delaxes(unused_ax)

fig.tight_layout()
# The path used to be 'wfh_wage_figurespdf' (missing '.'); with no extension
# matplotlib falls back to rcParams['savefig.format'] instead of writing a PDF.
fig.savefig(f"{proj_dir}/outputs/figures/wfh_wage_figures.pdf")
plt.show()


# --- One displayed figure per quartile ---
# Rebuild the quartile grouping so this section does not depend on earlier state.
# NOTE(review): assumes names like 'real_index_smwg1st_high_wfh'; pooled series
# (four '_'-separated parts, e.g. 'real_index_smwghigh_wfh') are skipped instead
# of becoming one-line "quartile" figures.
quartile_groups = defaultdict(dict)
for col in real_index_cols:
    parts = col.split("_")
    if len(parts) != 5:
        continue
    quartile = parts[2]    # e.g. 'smwg1st'
    wfh_status = parts[3]  # 'high' or, e.g., 'no'
    label = 'High WFH' if wfh_status == 'high' else 'Low WFH'
    quartile_groups[quartile][label] = col

# Loop over each quartile — but create separate figures
for quartile, wfh_group_cols in sorted(quartile_groups.items()):
    fig, ax = plt.subplots(figsize=(8, 5))

    for wfh_label, col in wfh_group_cols.items():
        # Deviation from trend, set to 0 at 2020-01 only; plotted from 2020 on.
        deviation = df[col] - df[f'predicted_{col}']
        deviation.loc[df['date'] == plot_start_date] = 0
        post_2020 = df['date'] >= plot_start_date
        ax.plot(df.loc[post_2020, 'date'], deviation.loc[post_2020], label=wfh_label)

    ax.axhline(0, color='black', linewidth=1, linestyle='--')
    ax.set_title(f"Quartile {quartile.upper()}", fontsize=14)
    ax.set_ylim(-0.1, 0.05)
    ax.legend(fontsize=12)
    ax.tick_params(labelsize=12)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    fig.tight_layout()
    # Not saved: the next section writes the same
    # wfh_wage_quartile_{quartile}.pdf paths, so a file written here was
    # always overwritten.
    plt.show()
    plt.close(fig)  # Close to avoid overlap in next figure


# --- Saved figures: one per quartile, plus a separate pooled figure ---
# (defaultdict and plt are already imported in the first cell.)

# Group quartile-specific series.
# NOTE(review): assumes names like 'real_index_smwg1st_high_wfh'. Pooled series
# ('real_index_smwghigh_wfh', 'real_index_smwgno_wfh') split into four parts
# and are plotted separately below. The previous check looked for
# 'smwglow_wfh', which does not match the pooled no-WFH column the
# gap-calculation cell reads, so that series leaked in as a fake quartile.
quartile_groups = defaultdict(dict)
for col in real_index_cols:
    parts = col.split("_")
    if len(parts) != 5:
        continue
    quartile = parts[2]    # e.g. 'smwg1st'
    wfh_status = parts[3]  # 'high' or, e.g., 'no'
    label = 'WFH' if wfh_status == 'high' else 'No WFH'
    quartile_groups[quartile][label] = col

plot_start_date = pd.to_datetime("2020-01-01")
is_anchor = df['date'] == plot_start_date
mask = df['date'] >= plot_start_date


def _plot_trend_deviations(ax, series_by_label):
    """Plot each series' deviation from its fitted trend, from 2020-01 on.

    series_by_label maps legend label -> real wage index column in df; the
    column f'predicted_{col}' must also exist. Only the 2020-01 value is set
    to 0; later months are not shifted.
    """
    for series_label, col in series_by_label.items():
        deviation = df[col] - df[f'predicted_{col}']
        deviation.loc[is_anchor] = 0
        ax.plot(df.loc[mask, 'date'], deviation.loc[mask], label=series_label)

    ax.axhline(0, color='black', linewidth=1, linestyle='--')
    ax.set_ylim(-0.1, 0.05)
    ax.legend(fontsize=11)
    ax.tick_params(labelsize=12)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)


# --- Plot per quartile ---
for quartile, wfh_group_cols in sorted(quartile_groups.items()):
    fig, ax = plt.subplots(figsize=(8, 5))
    _plot_trend_deviations(ax, wfh_group_cols)
    fig.tight_layout()
    # Save before showing, matching the other figure sections.
    fig.savefig(f"{proj_dir}/outputs/figures/wfh_wage_quartile_{quartile}.pdf")
    plt.show()
    plt.close(fig)

# --- Separate plot for pooled series ---
# Column names as read by the gap-calculation cell.
pooled_series = {
    'WFH': 'real_index_smwghigh_wfh',
    'No WFH': 'real_index_smwgno_wfh'
}

fig, ax = plt.subplots(figsize=(8, 5))
_plot_trend_deviations(ax, pooled_series)
fig.tight_layout()
fig.savefig(f"{proj_dir}/outputs/figures/wfh_wage_pooled.pdf")
plt.show()
plt.close(fig)
plt.close(fig)
