## Counterfactual prediction using `CRN`

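This notebook trains a `CRNClassifier` on a small synthetic dataset. It then uses the model to forecast future outcomes, and finally predicts outcomes under alternative (counterfactual) treatment scenarios.

**TODO:** Add more detail.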

### 1. Get and preprocess some data

In [1]:
from clairvoyance2.datasets import dummy_dataset

# Build a small synthetic dataset of 10 samples with no missing values:
#   - temporal covariates (5 features) and static covariates (4 features),
#   - one categorical temporal target with 3 categories,
#   - one categorical temporal treatment with 2 categories.
data = dummy_dataset(
    n_samples=10,
    # Temporal covariates: variable-length series with a maximum length of 30.
    temporal_covariates_n_features=5,
    temporal_covariates_max_len=30,
    temporal_covariates_missing_prob=0.0,
    # Static covariates.
    static_covariates_n_features=4,
    static_covariates_missing_prob=0.0,
    # Categorical target and treatment (one feature each).
    temporal_targets_n_features=1,
    temporal_targets_n_categories=3,
    temporal_treatments_n_features=1,
    temporal_treatments_n_categories=2,
)

data

Dataset(
    temporal_covariates=TimeSeriesSamples([10,*,5]),
    static_covariates=StaticSamples([10,4]),
    temporal_targets=TimeSeriesSamples([10,*,1]),
    temporal_treatments=TimeSeriesSamples([10,*,1]),
)

In [2]:
from clairvoyance2.preprocessing import TemporalDataOneHotEncoder

# One-hot encode categorical feature `0` of the targets, then of the treatments.
# The encoded columns are named `OneHot_0_<category>`, as shown in the displays below.
for container_name in ("temporal_targets", "temporal_treatments"):
    one_hot_encoder = TemporalDataOneHotEncoder(params=dict(apply_to=container_name, feature_name=0))
    data = one_hot_encoder.fit_transform(data)

In [3]:
from clairvoyance2.preprocessing import TemporalDataMinMaxScaler

# Min-max scale the temporal data using the scaler's default settings.
# NOTE(review): the scaler's defaults decide which containers get scaled — confirm.
min_max_scaler = TemporalDataMinMaxScaler()
data = min_max_scaler.fit_transform(data)

data

Dataset(
    temporal_covariates=TimeSeriesSamples([10,*,5]),
    static_covariates=StaticSamples([10,4]),
    temporal_targets=TimeSeriesSamples([10,*,3]),
    temporal_treatments=TimeSeriesSamples([10,*,2]),
)

In [4]:
# Show the temporal covariates after preprocessing, indexed by sample (`s_idx`) and time step (`t_idx`).
data.temporal_covariates

Unnamed: 0_level_0,Unnamed: 1_level_0,0,1,2,3,4
s_idx,t_idx,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
0,0,0.141829,0.169919,0.006331,0.169088,0.053567
0,1,0.278062,0.137602,0.05735,0.072325,0.0592
0,2,0.231557,0.10777,0.08297,0.296445,0.14139
0,3,0.282318,0.256971,0.109456,0.157041,0.103329
0,4,0.287415,0.28098,0.145454,0.186354,0.165971

Unnamed: 0_level_0,Unnamed: 1_level_0,0,1,2,3,4
s_idx,t_idx,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
9,0,0.116612,0.126294,0.020038,0.150076,0.007844
9,1,0.276629,0.155165,0.052604,0.165504,0.045183
9,2,0.214851,0.159745,0.076639,0.202486,0.107997
9,3,0.266512,0.220159,0.1121,0.206907,0.125498
9,4,0.127055,0.282298,0.1513,0.182784,0.187168


In [5]:
# Targets are now one-hot encoded: one column per category (`OneHot_0_0` .. `OneHot_0_2`).
data.temporal_targets

Unnamed: 0_level_0,Unnamed: 1_level_0,OneHot_0_0,OneHot_0_1,OneHot_0_2
s_idx,t_idx,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,0,0.0,0.0,1.0
0,1,1.0,0.0,0.0
0,2,0.0,0.0,1.0
0,3,1.0,0.0,0.0
0,4,1.0,0.0,0.0

Unnamed: 0_level_0,Unnamed: 1_level_0,OneHot_0_0,OneHot_0_1,OneHot_0_2
s_idx,t_idx,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
9,0,1.0,0.0,0.0
9,1,0.0,1.0,0.0
9,2,1.0,0.0,0.0
9,3,0.0,1.0,0.0
9,4,0.0,1.0,0.0


In [6]:
# Treatments are now one-hot encoded into two columns (`OneHot_0_0`, `OneHot_0_1`).
data.temporal_treatments

Unnamed: 0_level_0,Unnamed: 1_level_0,OneHot_0_0,OneHot_0_1
s_idx,t_idx,Unnamed: 2_level_1,Unnamed: 3_level_1
0,0,0.0,1.0
0,1,1.0,0.0
0,2,0.0,1.0
0,3,1.0,0.0
0,4,1.0,0.0

Unnamed: 0_level_0,Unnamed: 1_level_0,OneHot_0_0,OneHot_0_1
s_idx,t_idx,Unnamed: 2_level_1,Unnamed: 3_level_1
9,0,1.0,0.0
9,1,1.0,0.0
9,2,1.0,0.0
9,3,1.0,0.0
9,4,1.0,0.0


### 2. Initialize and train the `CRN` model

In [7]:
from clairvoyance2.treatment_effects import CRNClassifier

# Hyperparameters overridden for this demo. Every other parameter keeps its
# default value (the full resolved set is shown in the repr below).
crn_params = {
    "encoder_rnn_type": "LSTM",
    "decoder_rnn_type": "GRU",
    "epochs": 10,
    "batch_size": 16,
    "encoder_num_layers": 2,
    "decoder_num_layers": 2,
}
crn = CRNClassifier(params=crn_params)

crn

CRNClassifier(
    params:
    {
        "encoder_rnn_type": "LSTM",
        "encoder_hidden_size": 100,
        "encoder_num_layers": 2,
        "encoder_bias": True,
        "encoder_dropout": 0.0,
        "encoder_bidirectional": False,
        "encoder_nonlinearity": null,
        "encoder_proj_size": null,
        "decoder_rnn_type": "GRU",
        "decoder_hidden_size": 100,
        "decoder_num_layers": 2,
        "decoder_bias": True,
        "decoder_dropout": 0.0,
        "decoder_bidirectional": False,
        "decoder_nonlinearity": null,
        "decoder_proj_size": null,
        "adapter_hidden_dims": [
            50
        ],
        "adapter_out_activation": "Tanh",
        "predictor_hidden_dims": [],
        "predictor_out_activation": null,
        "treat_net_hidden_dims": [],
        "treat_net_out_activation": null,
        "max_len": null,
        "optimizer_str": "Adam",
        "optimizer_kwargs": {
            "lr": 0.01,
            "weight_decay": 1e-05
   

In [8]:
# NOTE: The example data is random, so training here is for illustration only.
# The trailing `;` suppresses the notebook's display of `fit()`'s return value;
# training progress is printed as a log instead.

crn.fit(data);

Preparing data for decoder training...
Preparing data for decoder training DONE.
=== Training stage: 1. Train encoder ===
Epoch: 0, Prediction Loss: 1.097, Lambda: 1.000, Treatment BR Loss: 0.693, Loss: 1.790
Epoch: 1, Prediction Loss: 1.095, Lambda: 1.000, Treatment BR Loss: 0.703, Loss: 1.798
Epoch: 2, Prediction Loss: 1.095, Lambda: 1.000, Treatment BR Loss: 0.764, Loss: 1.858
Epoch: 3, Prediction Loss: 1.097, Lambda: 1.000, Treatment BR Loss: 0.803, Loss: 1.900
Epoch: 4, Prediction Loss: 1.094, Lambda: 1.000, Treatment BR Loss: 0.807, Loss: 1.901
Epoch: 5, Prediction Loss: 1.094, Lambda: 1.000, Treatment BR Loss: 0.790, Loss: 1.884
Epoch: 6, Prediction Loss: 1.093, Lambda: 1.000, Treatment BR Loss: 0.747, Loss: 1.840
Epoch: 7, Prediction Loss: 1.093, Lambda: 1.000, Treatment BR Loss: 0.700, Loss: 1.793
Epoch: 8, Prediction Loss: 1.092, Lambda: 1.000, Treatment BR Loss: 0.695, Loss: 1.787
Epoch: 9, Prediction Loss: 1.092, Lambda: 1.000, Treatment BR Loss: 0.724, Loss: 1.815
=== Trai

### 3. Use `CRN` to make predictions (it is both a predictor and an individual treatment effect (ITE) model)

In [9]:
from clairvoyance2.data.utils import time_index_utils
from clairvoyance2.interface import TimeIndexHorizon
import pandas as pd

# Forecast every sample at the same time indexes [5, 6, 7, 8, 9].
# The horizon takes one time index per sample in `data`.
forecast_steps = pd.Index([5, 6, 7, 8, 9])
horizon = TimeIndexHorizon(time_index_sequence=[forecast_steps for _ in range(len(data))])

predicted = crn.predict(data, horizon)

print("Predictions:")
predicted

Predictions:


Unnamed: 0_level_0,Unnamed: 1_level_0,OneHot_0_0,OneHot_0_1,OneHot_0_2
s_idx,t_idx,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,5,0.256463,0.135223,0.608314
0,6,0.253809,0.134143,0.612048
0,7,0.254158,0.13426,0.611582
0,8,0.25401,0.134042,0.611949
0,9,0.547625,0.280022,0.172354

Unnamed: 0_level_0,Unnamed: 1_level_0,OneHot_0_0,OneHot_0_1,OneHot_0_2
s_idx,t_idx,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
9,5,0.749591,0.125845,0.124564
9,6,0.409305,0.071024,0.519671
9,7,0.408936,0.071082,0.519982
9,8,0.408464,0.071113,0.520423
9,9,0.749566,0.126005,0.124429


### 4. Use `CRN` to predict counterfactuals

In [10]:
# The outcome horizon and both treatment scenarios all cover time steps 5-9.
counterfactual_steps = [5, 6, 7, 8, 9]

# Outcome prediction horizon: a single time index, for the one sample passed
# to `predict_counterfactuals` below.
outcome_horizon = TimeIndexHorizon(time_index_sequence=[pd.Index(counterfactual_steps)])

# Alternative treatment plans, one-hot encoded with the same column names as
# `data.temporal_treatments`:
#   - scenario 1: treatment category 1 at t=6 only, category 0 otherwise;
#   - scenario 2: category 0 for t=5..7, then category 1 for t=8..9.
treatment_scenarios = [
    pd.DataFrame(
        {"OneHot_0_0": [1, 0, 1, 1, 1], "OneHot_0_1": [0, 1, 0, 0, 0]},
        index=pd.Index(counterfactual_steps),
    ),
    pd.DataFrame(
        {"OneHot_0_0": [1, 1, 1, 0, 0], "OneHot_0_1": [0, 0, 0, 1, 1]},
        index=pd.Index(counterfactual_steps),
    ),
]

In [11]:
# Predict sample 0's outcomes over `outcome_horizon` under each treatment scenario.
# The result is a list with one TimeSeries per scenario, in the same order as
# `treatment_scenarios`.
crn.predict_counterfactuals(
    data,
    sample_index=0,
    treatment_scenarios=treatment_scenarios,
    horizon=outcome_horizon,
)

[TimeSeries() with data:
        OneHot_0_0  OneHot_0_1  OneHot_0_2
 t_idx                                    
 5        0.220458    0.612089    0.167453
 6        0.106353    0.299463    0.594184
 7        0.224979    0.606961    0.168060
 8        0.224591    0.611645    0.163764
 9        0.224594    0.611661    0.163745,
 TimeSeries() with data:
        OneHot_0_0  OneHot_0_1  OneHot_0_2
 t_idx                                    
 5        0.220458    0.612089    0.167453
 6        0.224649    0.611579    0.163772
 7        0.224593    0.611661    0.163746
 8        0.106336    0.299527    0.594137
 9        0.105096    0.293257    0.601647]