# User Guide Tutorial 06: Treatment Effects

This tutorial shows how to use TemporAI `treatments` plugins.

*Skip the below cell if you are not on Google Colab / already have TemporAI installed:*

In [None]:
%pip install temporai

# Or from the repo, for the latest version:
# %pip install git+https://github.com/vanderschaarlab/temporai.git

## All `treatments` plugins

> ⚠️ The `treatments` API is preliminary and likely to change.

In the treatment effects estimation task, the goal is to predict counterfactual outcomes: the outcomes that would be observed if an alternative treatment were applied.
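To make "counterfactual outcome" concrete, here is a toy illustration with an entirely made-up outcome model. The recursion and the `decay`, `effect` and `intercept` parameters are hypothetical and unrelated to how TemporAI's plugins work. The same model runs once with the observed treatment sequence and once with an alternative one, and the difference between the two outcome trajectories is the treatment effect. In real data only one of these trajectories is ever observed, which is why the other has to be estimated.

```python
import numpy as np


def simulate_outcomes(treatments, y0=0.0, decay=0.8, effect=-1.5, intercept=1.0):
    """Toy model (hypothetical): y[t] = decay * y[t-1] + effect * a[t] + intercept."""
    outcomes = np.empty(len(treatments), dtype=float)
    previous = y0
    for t, a in enumerate(treatments):
        previous = decay * previous + effect * a + intercept
        outcomes[t] = previous
    return outcomes


observed_treatments = np.array([0, 1, 1, 0, 0])  # the treatments actually given
alternative_treatments = np.zeros(5, dtype=int)  # "what if we had never treated?"

factual = simulate_outcomes(observed_treatments)
counterfactual = simulate_outcomes(alternative_treatments)

# Per-time-step effect of the observed treatments relative to the alternative.
treatment_effect = factual - counterfactual
print(treatment_effect)  # ≈ [0, -1.5, -2.7, -2.16, -1.728]
```

Because the toy model is known exactly, both trajectories can be computed here; a treatment effects plugin instead has to learn the outcome model from data.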

To see all the relevant plugins:

In [None]:
from tempor import plugin_loader
from rich.pretty import pprint

all_treatments_plugins = plugin_loader.list()["treatments"]

pprint(all_treatments_plugins, indent_guides=False)
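The listing is grouped by category. Assuming `plugin_loader.list()` returns nested dicts whose leaves are lists of plugin names (inferred from the dotted name `treatments.temporal.regression.crn_regressor` used later in this tutorial, not from TemporAI's documentation), a small hypothetical helper can flatten it into the fully qualified names passed to `plugin_loader.get`:

```python
from typing import Any, List


def flatten_plugin_names(tree: Any, prefix: str = "") -> List[str]:
    """Flatten nested {category: {...: [names]}} into dotted plugin names (assumed structure)."""

    def join(name: Any) -> str:
        return f"{prefix}.{name}" if prefix else str(name)

    if isinstance(tree, dict):
        names: List[str] = []
        for key, subtree in tree.items():
            names.extend(flatten_plugin_names(subtree, join(key)))
        return names
    if isinstance(tree, str):
        # A single plugin name stored directly as a leaf.
        return [join(tree)]
    # Otherwise assume an iterable (e.g. a list) of plugin names.
    return [join(name) for name in tree]


# Hypothetical listing, not the actual output of plugin_loader.list().
example_listing = {"temporal": {"regression": ["crn_regressor"]}}
print(flatten_plugin_names(example_listing, prefix="treatments"))
# prints ['treatments.temporal.regression.crn_regressor']
```

With the real listing, the call would be `flatten_plugin_names(all_treatments_plugins, prefix="treatments")`, provided the assumed structure holds.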

Now load the data source we will use:

In [None]:
DummyTemporalTreatmentEffectsDataSource = plugin_loader.get_class(
    "treatments.temporal.dummy_treatments", plugin_type="datasource"
)

## Using a temporal treatment effects plugin

In this setting, the treatments are time series, and the outcomes are also time series.

In [None]:
from tempor import plugin_loader

dataset = DummyTemporalTreatmentEffectsDataSource(
    random_state=42,
    temporal_covariates_missing_prob=0.0,
    temporal_treatments_n_features=1,
    temporal_treatments_n_categories=2,
).load()
print(dataset)

model = plugin_loader.get("treatments.temporal.regression.crn_regressor", epochs=20)
print(model)

TemporalTreatmentEffectsDataset(
    time_series=TimeSeriesSamples([100, *, 5]),
    static=StaticSamples([100, 3]),
    predictive=TemporalTreatmentEffectsTaskData(
        targets=TimeSeriesSamples([100, *, 2]),
        treatments=TimeSeriesSamples([100, *, 1])
    )
)
CRNTreatmentsRegressor(
    name='crn_regressor',
    category='treatments.temporal.regression',
    plugin_type='method',
    params={
        'encoder_rnn_type': 'LSTM',
        'encoder_hidden_size': 100,
        'encoder_num_layers': 1,
        'encoder_bias': True,
        'encoder_dropout': 0.0,
        'encoder_bidirectional': False,
        'encoder_nonlinearity': None,
        'encoder_proj_size': None,
        'decoder_rnn_type': 'LSTM',
        'decoder_hidden_size': 100,
        'decoder_num_layers': 1,
        'decoder_bias': True,
        'decoder_dropout': 0.0,
        'decoder_bidirectional': False,
        'decoder_nonlinearity': None,
        'decoder_proj_size': None,
        'adapter_hidden_dims': [

In [None]:
# Targets:
dataset.predictive.targets

| sample_idx | time_idx | 0 | 1 |
|---|---|---|---|
| 0 | 0 | -3.110475 | -3.566948 |
| 0 | 1 | 1.528495 | -0.653673 |
| 0 | 2 | 2.275307 | -0.695371 |
| 0 | 3 | 4.844060 | 3.469371 |
| 0 | 4 | 4.420301 | 5.147500 |
| ... | ... | ... | ... |
| 99 | 7 | 5.994185 | 6.225290 |
| 99 | 8 | 10.913662 | 5.346697 |
| 99 | 9 | 9.558824 | 7.585175 |
| 99 | 10 | 10.194430 | 5.795619 |


In [None]:
# Treatments:
dataset.predictive.treatments

| sample_idx | time_idx | 0 |
|---|---|---|
| 0 | 0 | 0 |
| 0 | 1 | 1 |
| 0 | 2 | 1 |
| 0 | 3 | 0 |
| 0 | 4 | 0 |
| ... | ... | ... |
| 99 | 7 | 1 |
| 99 | 8 | 1 |
| 99 | 9 | 0 |
| 99 | 10 | 0 |
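Both tables are indexed by `(sample_idx, time_idx)`, and the `*` in shapes such as `TimeSeriesSamples([100, *, 1])` marks the time dimension, whose length can vary between samples. The sketch below uses plain pandas on a small hand-built frame in that layout rather than TemporAI objects: sample 0 reuses the first five treatment values shown above, and sample 1 is made up. It selects one sample's sequence and summarizes each sample's length and fraction of treated time steps.

```python
import pandas as pd

# Hypothetical toy frame in the same (sample_idx, time_idx) layout as the tables above.
index = pd.MultiIndex.from_tuples(
    [(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (1, 0), (1, 1), (1, 2)],
    names=["sample_idx", "time_idx"],
)
treatments_df = pd.DataFrame({0: [0, 1, 1, 0, 0, 1, 1, 0]}, index=index)

# One sample's sequence, now indexed by time_idx only.
sample_0 = treatments_df.xs(0, level="sample_idx")

# Group treatment column 0 by sample.
grouped = treatments_df.groupby(level="sample_idx")[0]
lengths = grouped.size()            # number of time steps per sample (the "*" dimension)
fraction_treated = grouped.mean()   # share of time steps with treatment 1

print(lengths.tolist())            # prints [5, 3]
print(sample_0[0].tolist())        # prints [0, 1, 1, 0, 0]
```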


In [None]:
# Train.
model.fit(dataset);

Preparing data for decoder training...
Preparing data for decoder training DONE.

=== Training stage: 1. Train encoder ===
Epoch: 0, Prediction Loss: 68.610, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 68.610
Epoch: 1, Prediction Loss: 31.291, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 31.291
Epoch: 2, Prediction Loss: 19.947, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 19.947
Epoch: 3, Prediction Loss: 18.710, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 18.710
Epoch: 4, Prediction Loss: 16.495, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 16.495
Epoch: 5, Prediction Loss: 11.959, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 11.959
Epoch: 6, Prediction Loss: 8.338, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 8.338
Epoch: 7, Prediction Loss: 6.868, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 6.868
Epoch: 8, Prediction Loss: 5.521, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 5.521
Epoch: 9, Prediction Loss: 4.948, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 4.948
Epoch: 10, Prediction Loss: 4.676, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 4.676
Epoch: 11, Prediction Loss: 4.263, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 4.263
Epoch: 12, Prediction Loss: 4.067, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 4.067
Epoch: 13, Prediction Loss: 4.225, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 4.225
Epoch: 14, Prediction Loss: 3.905, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.905
Epoch: 15, Prediction Loss: 3.924, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.924
Epoch: 16, Prediction Loss: 3.815, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.815
Epoch: 17, Prediction Loss: 3.761, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.761
Epoch: 18, Prediction Loss: 3.768, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.768
Epoch: 19, Prediction Loss: 3.823, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.823

=== Training stage: 2. Train decoder ===
Epoch: 0, Prediction Loss: 33.920, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 33.920
Epoch: 1, Prediction Loss: 4.568, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 4.568
Epoch: 2, Prediction Loss: 3.814, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.814
Epoch: 3, Prediction Loss: 3.769, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.769
Epoch: 4, Prediction Loss: 3.752, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.752
Epoch: 5, Prediction Loss: 3.735, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.735
Epoch: 6, Prediction Loss: 3.714, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.714
Epoch: 7, Prediction Loss: 3.733, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.733
Epoch: 8, Prediction Loss: 3.697, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.697
Epoch: 9, Prediction Loss: 3.762, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.762
Epoch: 10, Prediction Loss: 3.675, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.675
Epoch: 11, Prediction Loss: 3.726, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.726
Epoch: 12, Prediction Loss: 3.763, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.763
Epoch: 13, Prediction Loss: 3.689, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.689
Epoch: 14, Prediction Loss: 3.686, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.686
Epoch: 15, Prediction Loss: 3.659, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.659
Epoch: 16, Prediction Loss: 3.645, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.645
Epoch: 17, Prediction Loss: 3.742, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.742
Epoch: 18, Prediction Loss: 3.685, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.685
Epoch: 19, Prediction Loss: 3.652, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.652
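Training runs in two stages (encoder, then decoder), and each epoch line reports Prediction Loss, Lambda, Treatment BR Loss and the total Loss. In this run Treatment BR Loss stays at 0.000, so Loss equals Prediction Loss throughout. To compare runs or plot learning curves, a hypothetical helper like the one below groups the total loss by training stage. It assumes the log has been copied or captured as a string, and its regex is written against the line format printed above, which may change since the `treatments` API is preliminary.

```python
import re
from typing import Dict, List

# Regex written against the epoch lines printed above (assumed format).
LINE_RE = re.compile(
    r"Epoch: (?P<epoch>\d+), Prediction Loss: (?P<pred>-?[\d.]+), "
    r"Lambda: (?P<lam>-?[\d.]+), Treatment BR Loss: (?P<br>-?[\d.]+), Loss: (?P<loss>-?[\d.]+)"
)


def losses_by_stage(log_text: str) -> Dict[str, List[float]]:
    """Group total `Loss` values under the most recent '=== Training stage ... ===' header."""
    stages: Dict[str, List[float]] = {}
    stage = "unknown"
    for line in log_text.splitlines():
        if line.startswith("=== Training stage:"):
            stage = line.strip("= ")  # e.g. "Training stage: 1. Train encoder"
            stages.setdefault(stage, [])
            continue
        match = LINE_RE.search(line)
        if match:
            stages.setdefault(stage, []).append(float(match.group("loss")))
    return stages


# Excerpt copied from the training log above.
log_excerpt = """\
=== Training stage: 1. Train encoder ===
Epoch: 0, Prediction Loss: 68.610, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 68.610
Epoch: 19, Prediction Loss: 3.823, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.823
=== Training stage: 2. Train decoder ===
Epoch: 0, Prediction Loss: 33.920, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 33.920
Epoch: 19, Prediction Loss: 3.652, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.652
"""
print(losses_by_stage(log_excerpt))
```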


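The next cell predicts counterfactuals over the second half of each sample's time steps, under two constant scenarios (always treat, never treat). Scenarios need not be constant. The sketch below uses only NumPy and a hypothetical 11-step time index like sample 0's. It applies the same horizon rule and builds a few other binary policies (the scenario names are illustrative). Each array has one entry per horizon step, matching the 1-D arrays the tutorial passes for its single binary treatment feature. Shapes for multi-feature treatments are not covered here.

```python
import numpy as np

# Hypothetical 11-step time index (0..10), like sample 0's in this tutorial.
time_index = list(range(11))
# Same rule as the tutorial: predict over the second half of the time steps.
horizon = time_index[len(time_index) // 2 :]
n_steps = len(horizon)

# Illustrative binary treatment policies over the horizon (names are made up).
scenarios = {
    "always_treat": np.ones(n_steps, dtype=int),
    "never_treat": np.zeros(n_steps, dtype=int),
    "alternate": np.arange(n_steps) % 2,                       # 0, 1, 0, 1, ...
    "treat_then_stop": (np.arange(n_steps) < 2).astype(int),   # treat for 2 steps, then stop
}

for name, scenario in scenarios.items():
    print(f"{name:>15}: {scenario}")

# One sample's entry in `treatment_scenarios` would then be this list.
scenarios_for_sample = list(scenarios.values())
```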
In [None]:
# Predict counterfactuals:

import numpy as np

dataset = dataset[:5]

# Define horizons for each sample.
horizons = [tc.time_indexes()[0][len(tc.time_indexes()[0]) // 2 :] for tc in dataset.time_series]
print("Horizons for sample 0:\n", horizons[0], end="\n\n")

# Define treatment scenarios for each sample.
treatment_scenarios = [[np.asarray([1] * len(h)), np.asarray([0] * len(h))] for h in horizons]
print("Alternative treatment scenarios for sample 0:\n", treatment_scenarios[0], end="\n\n")

# Call predict_counterfactuals.
counterfactuals = model.predict_counterfactuals(dataset, horizons=horizons, treatment_scenarios=treatment_scenarios)
print("Counterfactual outcomes for sample 0, given the alternative treatment scenarios:\n")
for idx, c in enumerate(counterfactuals[0]):
    print(f"Treatment scenario {idx}, {treatment_scenarios[0][idx]}")
    print(c, end="\n\n")

Horizons for sample 0:
 [5, 6, 7, 8, 9, 10]

Alternative treatment scenarios for sample 0:
 [array([1, 1, 1, 1, 1, 1]), array([0, 0, 0, 0, 0, 0])]

Counterfactual outcomes for sample 0, given the alternative treatment scenarios:

Treatment scenario 0, [1 1 1 1 1 1]
TimeSeries() with data:
                 0         1
time_idx                    
5         6.476208  4.946198
6         6.451749  4.914784
7         6.451779  4.914852
8         6.451779  4.914852
9         6.451779  4.914852
10        6.451779  4.914852

Treatment scenario 1, [0 0 0 0 0 0]
TimeSeries() with data:
                 0         1
time_idx                    
5         6.835220  4.965306
6         6.832456  4.960526
7         6.832483  4.960564
8         6.832483  4.960563
9         6.832483  4.960563
10        6.832483  4.960563
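The two predicted trajectories differ mainly in target `0`. To quantify that, the printed values can be subtracted directly. The sketch below copies them into NumPy arrays by hand; extracting arrays from the returned `TimeSeries` objects programmatically is not shown here. Keep in mind this is the difference between two model predictions on dummy data after 20 training epochs: an estimated effect, not a ground-truth one.

```python
import numpy as np

# Values copied from the printed counterfactuals for sample 0 (time_idx 5..10, targets 0 and 1).
always_treat = np.array([
    [6.476208, 4.946198],
    [6.451749, 4.914784],
    [6.451779, 4.914852],
    [6.451779, 4.914852],
    [6.451779, 4.914852],
    [6.451779, 4.914852],
])
never_treat = np.array([
    [6.835220, 4.965306],
    [6.832456, 4.960526],
    [6.832483, 4.960564],
    [6.832483, 4.960563],
    [6.832483, 4.960563],
    [6.832483, 4.960563],
])

# Estimated effect of "always treat" vs. "never treat", per time step and target column.
effect = always_treat - never_treat
# Average estimated effect over the horizon, one value per target.
mean_effect_per_target = effect.mean(axis=0)
print(mean_effect_per_target)
```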



## 🎉 Congratulations!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement towards *Machine learning and AI for Medicine*, you can do so in the following ways!



### ⭐ Star [TemporAI](https://github.com/vanderschaarlab/temporai) on GitHub

- The easiest way to help our community is by just starring the repos! This helps raise awareness of the tools we're building.



### Check out other projects from [vanderschaarlab](https://github.com/vanderschaarlab)
- 📝 [HyperImpute](https://github.com/vanderschaarlab/hyperimpute)
- 📊 [AutoPrognosis](https://github.com/vanderschaarlab/autoprognosis)
- 🤖 [SynthCity](https://github.com/vanderschaarlab/synthcity)
 