[![Test In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vanderschaarlab/temporai/blob/main/tutorials/usage/tutorial06_treatments.ipynb)

# User Guide Tutorial 06: Treatment Effects

This tutorial shows how to use TemporAI `treatments` plugins.

## All `treatments` plugins

> ⚠️ The `treatments` API is preliminary and likely to change.

In the treatment effects estimation task, the goal is to predict counterfactual outcomes, i.e. the outcomes that would result from applying an alternative treatment instead of the one that was observed.
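To make "counterfactual" concrete, the toy sketch below hard-codes a made-up outcome model and rolls it out under two treatment sequences from the same starting point. `simulate_outcomes`, its drift and its effect size are hypothetical and have nothing to do with TemporAI; in the real task the outcome model is unknown, and a `treatments` plugin has to learn it from observational data.

```python
# Toy illustration only: `simulate_outcomes` is a hypothetical, hand-written
# outcome model, NOT a TemporAI plugin or API.


def simulate_outcomes(start, treatments, drift=1.0, effect=-0.5):
    """Roll out a toy outcome trajectory for one treatment sequence.

    Assumption: at each step the outcome grows by `drift`, and a treatment
    value of 1 adds `effect` on top of that (0 means untreated).
    """
    outcomes = []
    y = start
    for a in treatments:
        y = y + drift + effect * a
        outcomes.append(y)
    return outcomes


# Same starting point, two different treatment sequences.
factual = simulate_outcomes(0.0, [0, 1, 1, 0])         # the observed plan
counterfactual = simulate_outcomes(0.0, [1, 1, 1, 1])  # an alternative plan

print("factual:       ", factual)         # [1.0, 1.5, 2.0, 3.0]
print("counterfactual:", counterfactual)  # [0.5, 1.0, 1.5, 2.0]
```

For a real patient only one of these trajectories is ever observed; estimating the other is the job of the plugins in this tutorial.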

To list all available `treatments` plugins:

```python
from tempor.methods import plugin_loader
from rich.pretty import pprint

all_treatments_plugins = plugin_loader.list()["treatments"]

pprint(all_treatments_plugins, indent_guides=False)
```
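Plugins are addressed by a dotted name made of the category path plus the plugin name: the CRN regressor used below is loaded as `"treatments.temporal.regression.crn_regressor"` and reports `category='treatments.temporal.regression'` and `name='crn_regressor'`. The sketch below shows how such full names could be assembled from a nested listing. It rests on an assumption: `example_listing` is a hand-written stand-in that supposes `plugin_loader.list()` returns nested dicts with lists of plugin names as leaves, so inspect the real return value before relying on this helper.

```python
# Hypothetical stand-in for the structure returned by plugin_loader.list().
# Assumption: nested dicts keyed by category level, lists of plugin names as
# leaves. The real structure and contents may differ.
example_listing = {
    "treatments": {
        "temporal": {
            "regression": ["crn_regressor"],
        },
    },
}


def full_plugin_names(tree, prefix=""):
    """Flatten a nested category dict into dotted 'category.name' strings."""
    names = []
    for key, value in tree.items():
        path = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            names.extend(full_plugin_names(value, path))
        else:
            names.extend(f"{path}.{name}" for name in value)
    return names


print(full_plugin_names(example_listing))
# ['treatments.temporal.regression.crn_regressor']
```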

## Using a temporal treatment effects plugin

In this setting, both the treatments and the outcomes (targets) are time series.

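```python
from tempor.data.datasources import DummyTemporalTreatmentEffectsDataSource
from tempor.methods import plugin_loader

# Load a small synthetic dataset with a single binary (2-category) temporal
# treatment and no missing values in the temporal covariates.
dataset = DummyTemporalTreatmentEffectsDataSource(
    random_state=42,
    temporal_covariates_missing_prob=0.0,
    temporal_treatments_n_features=1,
    temporal_treatments_n_categories=2,
).load()
print(dataset)

# Create a CRN (Counterfactual Recurrent Network) regressor, trained for 20 epochs.
model = plugin_loader.get("treatments.temporal.regression.crn_regressor", epochs=20)
print(model)
```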

```
TemporalTreatmentEffectsDataset(
    time_series=TimeSeriesSamples([100, *, 5]),
    static=StaticSamples([100, 3]),
    predictive=TemporalTreatmentEffectsTaskData(
        targets=TimeSeriesSamples([100, *, 2]),
        treatments=TimeSeriesSamples([100, *, 1])
    )
)
CRNTreatmentsRegressor(
    name='crn_regressor',
    category='treatments.temporal.regression',
    plugin_type='method',
    params={
        'encoder_rnn_type': 'LSTM',
        'encoder_hidden_size': 100,
        'encoder_num_layers': 1,
        'encoder_bias': True,
        'encoder_dropout': 0.0,
        'encoder_bidirectional': False,
        'encoder_nonlinearity': None,
        'encoder_proj_size': None,
        'decoder_rnn_type': 'LSTM',
        'decoder_hidden_size': 100,
        'decoder_num_layers': 1,
        'decoder_bias': True,
        'decoder_dropout': 0.0,
        'decoder_bidirectional': False,
        'decoder_nonlinearity': None,
        'decoder_proj_size': None,
        'adapter_hidden_dims': [
```
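In the printout, shapes such as `TimeSeriesSamples([100, *, 5])` read as [samples, time steps, features], where we take `*` to mark a time dimension that is not fixed across samples. Covariates (5 features), treatments (1 feature) and targets (2 features) each have their own feature width. The toy sketch below mimics that layout with plain Python lists; it illustrates the shape convention only and is not how TemporAI stores data. It also assumes, as the tables below suggest, that covariates, treatments and targets share one time index within a sample.

```python
import random

# Toy stand-in for the layout printed above: a list of samples, each holding
# covariates, treatments and targets as [time step][feature] nested lists.
# Feature widths follow the printout (5 covariates, 1 treatment, 2 targets);
# the per-sample sequence lengths are made up.
N_COVARIATES, N_TREATMENTS, N_TARGETS = 5, 1, 2

random.seed(0)
samples = []
for _ in range(3):
    n_steps = random.randint(8, 12)  # length may differ from sample to sample
    samples.append(
        {
            "covariates": [[0.0] * N_COVARIATES for _ in range(n_steps)],
            "treatments": [[0] * N_TREATMENTS for _ in range(n_steps)],
            "targets": [[0.0] * N_TARGETS for _ in range(n_steps)],
        }
    )

# Within a sample, all three share one time axis; across samples its length
# can vary, which is what the `*` in the shape stands for.
for i, sample in enumerate(samples):
    print(i, {name: (len(rows), len(rows[0])) for name, rows in sample.items()})
```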

```python
# Targets:
dataset.predictive.targets
```

| sample_idx | time_idx | 0 | 1 |
|---|---|---|---|
| 0 | 0 | -3.110475 | -3.566948 |
| 0 | 1 | 1.528495 | -0.653673 |
| 0 | 2 | 2.275307 | -0.695371 |
| 0 | 3 | 4.844060 | 3.469371 |
| 0 | 4 | 4.420301 | 5.147500 |
| ... | ... | ... | ... |
| 99 | 7 | 5.994185 | 6.225290 |
| 99 | 8 | 10.913662 | 5.346697 |
| 99 | 9 | 9.558824 | 7.585175 |
| 99 | 10 | 10.194430 | 5.795619 |
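The table is in long format: each row is indexed by the pair (`sample_idx`, `time_idx`), and columns `0` and `1` are the two target features (the display suggests a pandas MultiIndex). Grouping on the first index level splits it into one series per sample. The stdlib sketch below does that regrouping by hand on just the rows shown above, so it runs without TemporAI or pandas.

```python
from collections import defaultdict

# Rows visible in the targets table above, as
# (sample_idx, time_idx, target_0, target_1); the elided "..." rows are left out.
rows = [
    (0, 0, -3.110475, -3.566948),
    (0, 1, 1.528495, -0.653673),
    (0, 2, 2.275307, -0.695371),
    (0, 3, 4.844060, 3.469371),
    (0, 4, 4.420301, 5.147500),
    (99, 7, 5.994185, 6.225290),
    (99, 8, 10.913662, 5.346697),
    (99, 9, 9.558824, 7.585175),
    (99, 10, 10.194430, 5.795619),
]

# Regroup the long-format rows into one series per sample.
per_sample = defaultdict(list)
for sample_idx, time_idx, target_0, target_1 in rows:
    per_sample[sample_idx].append((time_idx, (target_0, target_1)))

for sample_idx, series in per_sample.items():
    series.sort()  # make sure each series is in time order
    print(sample_idx, [time_idx for time_idx, _ in series])
# 0 [0, 1, 2, 3, 4]
# 99 [7, 8, 9, 10]
```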


```python
# Treatments:
dataset.predictive.treatments
```

| sample_idx | time_idx | 0 |
|---|---|---|
| 0 | 0 | 0 |
| 0 | 1 | 1 |
| 0 | 2 | 1 |
| 0 | 3 | 0 |
| 0 | 4 | 0 |
| ... | ... | ... |
| 99 | 7 | 1 |
| 99 | 8 | 1 |
| 99 | 9 | 0 |
| 99 | 10 | 0 |


```python
# Train.
model.fit(dataset);
```

```
Preparing data for decoder training...
Preparing data for decoder training DONE.
=== Training stage: 1. Train encoder ===
Epoch: 0, Prediction Loss: 73.011, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 73.011
Epoch: 1, Prediction Loss: 36.266, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 36.266
Epoch: 2, Prediction Loss: 22.669, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 22.669
Epoch: 3, Prediction Loss: 18.582, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 18.582
Epoch: 4, Prediction Loss: 18.174, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 18.174
Epoch: 5, Prediction Loss: 14.973, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 14.973
Epoch: 6, Prediction Loss: 10.067, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 10.067
Epoch: 7, Prediction Loss: 7.859, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 7.859
Epoch: 8, Prediction Loss: 6.596, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 6.596
Epoch: 9, Prediction Loss: 5.638, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 
```
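Each training-log line reports the epoch, the prediction loss, a `Lambda` weight, a `Treatment BR Loss` term (presumably the balancing-representation loss from CRN) and the total `Loss`. In every complete line shown, `Treatment BR Loss` is 0.000, `Loss` equals `Prediction Loss`, and the loss falls at every epoch from 0 to 8. If you capture the log as text, a small regex is enough to check such trends. The sketch below is our own helper, not part of TemporAI, and parses the complete lines copied from the output above.

```python
import re

# Complete lines copied from the training log above (epoch 9 is cut off there).
LOG = """\
Epoch: 0, Prediction Loss: 73.011, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 73.011
Epoch: 1, Prediction Loss: 36.266, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 36.266
Epoch: 2, Prediction Loss: 22.669, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 22.669
Epoch: 3, Prediction Loss: 18.582, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 18.582
Epoch: 4, Prediction Loss: 18.174, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 18.174
Epoch: 5, Prediction Loss: 14.973, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 14.973
Epoch: 6, Prediction Loss: 10.067, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 10.067
Epoch: 7, Prediction Loss: 7.859, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 7.859
Epoch: 8, Prediction Loss: 6.596, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 6.596
"""

LINE = re.compile(
    r"Epoch: (?P<epoch>\d+), Prediction Loss: (?P<prediction>[\d.]+), "
    r"Lambda: (?P<lam>[\d.]+), Treatment BR Loss: (?P<br>[\d.]+), Loss: (?P<loss>[\d.]+)"
)

# One dict per epoch: epoch as int, every loss/weight as float.
history = []
for match in LINE.finditer(LOG):
    fields = match.groupdict()
    epoch = int(fields.pop("epoch"))
    history.append({"epoch": epoch, **{k: float(v) for k, v in fields.items()}})

losses = [h["loss"] for h in history]
print(losses)
print("strictly decreasing:", all(a > b for a, b in zip(losses, losses[1:])))
# strictly decreasing: True
```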

```python
# Predict counterfactuals:

import numpy as np

dataset = dataset[:5]

# Define horizons for each sample.
horizons = [tc.time_indexes()[0][len(tc.time_indexes()[0]) // 2 :] for tc in dataset.time_series]
print("Horizons for sample 0:\n", horizons[0], end="\n\n")

# Define treatment scenarios for each sample.
treatment_scenarios = [[np.asarray([1] * len(h)), np.asarray([0] * len(h))] for h in horizons]
print("Alternative treatment scenarios for sample 0:\n", treatment_scenarios[0], end="\n\n")

# Call predict_counterfactuals.
counterfactuals = model.predict_counterfactuals(dataset, horizons=horizons, treatment_scenarios=treatment_scenarios)
print("Counterfactual outcomes for sample 0, given the alternative treatment scenarios:\n")
for idx, c in enumerate(counterfactuals[0]):
    print(f"Treatment scenario {idx}, {treatment_scenarios[0][idx]}")
    print(c, end="\n\n")
```

```
Horizons for sample 0:
 [5, 6, 7, 8, 9, 10]

Alternative treatment scenarios for sample 0:
 [array([1, 1, 1, 1, 1, 1]), array([0, 0, 0, 0, 0, 0])]

Counterfactual outcomes for sample 0, given the alternative treatment scenarios:

Treatment scenario 0, [1 1 1 1 1 1]
TimeSeries() with data:
                 0         1
time_idx                    
5         6.428976  5.560657
6         6.545256  5.668885
7         6.552768  5.677176
8         6.553290  5.677758
9         6.553326  5.677799
10        6.553329  5.677802

Treatment scenario 1, [0 0 0 0 0 0]
TimeSeries() with data:
                 0         1
time_idx                    
5         6.781982  5.572202
6         6.910295  5.694051
7         6.917727  5.701763
8         6.918169  5.702227
9         6.918196  5.702255
10        6.918198  5.702257
```
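The horizon is the second half of each sample's time index: assuming sample 0 has the contiguous index 0 to 10 (11 steps), `11 // 2 == 5`, which matches the printed horizon `[5, 6, 7, 8, 9, 10]`. Subtracting one scenario's counterfactual outcomes from the other's gives the model's estimated effect of always treating versus never treating over that horizon. The sketch below does this arithmetic in plain Python on the target-feature-0 values copied from the output above; the numbers come from a 20-epoch run on dummy data, so they illustrate the calculation rather than any real effect.

```python
# Assumption: sample 0's time index is the contiguous range 0..10.
time_index = list(range(11))
horizon = time_index[len(time_index) // 2 :]  # same slice as the cell above
print(horizon)  # [5, 6, 7, 8, 9, 10]

# Counterfactual predictions for target feature 0 of sample 0, copied from the
# output above, one value per horizon step.
always_treat = [6.428976, 6.545256, 6.552768, 6.553290, 6.553326, 6.553329]
never_treat = [6.781982, 6.910295, 6.917727, 6.918169, 6.918196, 6.918198]

# Estimated effect of scenario 0 ("always treat") relative to scenario 1
# ("never treat") at each step: negative means a lower predicted target.
effect = {
    t: round(treated - untreated, 6)
    for t, treated, untreated in zip(horizon, always_treat, never_treat)
}
print(effect)
```

For feature 0 the estimated difference stays between about -0.35 and -0.37 at every step; the feature-1 column can be differenced the same way.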

