In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import random
from sklift.metrics import qini_auc_score, uplift_auc_score
from causalml.dataset import make_uplift_classification_logistic

import matplotlib.pyplot as plt
import seaborn as sns

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html


In [2]:

class SampleComparisonNet(nn.Module):
    def __init__(self, input_dim, shared_hidden=128, embedding_dim=64):
        super(SampleComparisonNet, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(in_features=input_dim, out_features=shared_hidden),
            nn.ReLU(),
            nn.Linear(in_features=shared_hidden, out_features=shared_hidden),
            nn.ReLU(),
            nn.Linear(in_features=shared_hidden, out_features=embedding_dim),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Linear(in_features=embedding_dim, out_features=shared_hidden),
            nn.ReLU(),
            nn.Linear(in_features=shared_hidden, out_features=shared_hidden),
            nn.ReLU(),
            nn.Linear(in_features=shared_hidden, out_features=1)
        )

    def forward(self, inputs):
        encoded = self.encoder(inputs)
        decoded = self.decoder(encoded)
        return decoded
    
    def get_embeddings(self, inputs):
        # Return embeddings. No gradient needed.
        with torch.no_grad():
            embeddings = self.encoder(inputs)
        return embeddings

    def compute_similarity(self, embeddings, tags, top_k=5):
        # For each sample, get topK indices of other-group samples by cosine similarity.
        normalized_embeddings = F.normalize(embeddings, p=2, dim=1)
        cosine_similarity = torch.matmul(normalized_embeddings, normalized_embeddings.T)
        tag = tags.float().view(-1, 1)
        mask = (tag != tag.t())
        masked_similarity = cosine_similarity.masked_fill(~mask, float('-inf'))
        _, topk_indices = torch.topk(masked_similarity, k=top_k, dim=1)
        return topk_indices

    def generate_comparison_labels(self, labels, topk_indices, tags, top_k=5):
        # For each sample, generate uplift label by group.
        label_vector = torch.squeeze(labels, dim=1)
        tags_vector = torch.squeeze(tags, dim=1)
        topk_labels = label_vector[topk_indices]              # [N, top_k]
        matched_y = topk_labels.float().mean(dim=1)           # [N]
        uplift_label = torch.where(
            tags_vector > 0,
            label_vector - matched_y,                         # Treatment: Y(1) - matched_Y(0)
            matched_y - label_vector                          # Control: matched_Y(1) - Y(0)
        )
        return uplift_label.unsqueeze(1)

    def compute_loss(self, predictions, comparison_labels, tags):
        assert predictions.shape == comparison_labels.shape
        loss = F.mse_loss(predictions, comparison_labels, reduction='mean')
        return loss


In [3]:

class EarlyStopper:
    def __init__(self, patience=15, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf
    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False


In [4]:

def set_seed(seed):
    # Set all random seeds for reproducibility.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


In [5]:

class SampleComparisonNetModel:
    def __init__(
        self,
        input_dim,
        shared_hidden=64,
        embedding_dim=16,
        epochs=5,
        patience=1,
        batch_size=2048,
        learning_rate=1e-4,
        data_loader_num_workers=10,
        random_seed=30,
        scheduler_patience=2,
        scheduler_factor=0.5,
        scheduler_min_lr=1e-7,
        top_k=5,
    ):
        set_seed(random_seed)
        self.model = SampleComparisonNet(input_dim, shared_hidden, embedding_dim).to(device)
        self.epochs = epochs
        self.patience = patience
        self.batch_size = batch_size
        self.num_workers = data_loader_num_workers
        self.top_k = top_k
        self.optim = torch.optim.AdamW(self.model.parameters(), lr=learning_rate, weight_decay=1e-4)
        self.train_dataloader = None
        self.valid_dataloader = None
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optim,
            mode='min',
            patience=scheduler_patience,
            factor=scheduler_factor,
            min_lr=scheduler_min_lr,
            verbose=True
        )

    def create_dataloaders(self, x, y, tags, valid_perc=None):
        # Split/train-validation loader.
        if valid_perc:
            x_train, x_test, y_train, y_test, tags_train, tags_test = train_test_split(
                x, y, tags, test_size=valid_perc, random_state=42
            )
            x_train = torch.Tensor(x_train)
            x_test = torch.Tensor(x_test)
            y_train = torch.Tensor(y_train).reshape(-1, 1)
            y_test = torch.Tensor(y_test).reshape(-1, 1)
            tags_train = torch.Tensor(tags_train).reshape(-1, 1)
            tags_test = torch.Tensor(tags_test).reshape(-1, 1)
            train_dataset = TensorDataset(x_train, y_train, tags_train)
            valid_dataset = TensorDataset(x_test, y_test, tags_test)
            self.train_dataloader = DataLoader(train_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
            self.valid_dataloader = DataLoader(valid_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
        else:
            x = torch.Tensor(x)
            y = torch.Tensor(y).reshape(-1, 1)
            tags = torch.Tensor(tags).reshape(-1, 1)
            train_dataset = TensorDataset(x, y, tags)
            self.train_dataloader = DataLoader(
                train_dataset, batch_size=self.batch_size, num_workers=self.num_workers
            )

    def fit(self, x, y, tags, valid_perc=None):
        self.create_dataloaders(x, y, tags, valid_perc)
        early_stopper = EarlyStopper(patience=self.patience, min_delta=0)
        for epoch in range(self.epochs):
            train_loss_sum = 0.0
            num_batches = 0
            for batch, (X, y_batch, tags_batch) in enumerate(self.train_dataloader):
                self.model.train()
                X, y_batch, tags_batch = X.to(device), y_batch.to(device), tags_batch.to(device)
                embeddings = self.model.get_embeddings(X)
                topk_indices = self.model.compute_similarity(embeddings, tags_batch, top_k=self.top_k)
                comparison_labels = self.model.generate_comparison_labels(y_batch, topk_indices, tags_batch, top_k=self.top_k)
                predictions = self.model(X)
                loss = self.model.compute_loss(predictions, comparison_labels, tags_batch)
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
                train_loss_sum += loss.item()
                num_batches += 1
            train_loss_mean = train_loss_sum / num_batches

            if self.valid_dataloader:
                self.model.eval()
                valid_loss = self.validate_step()
                print(f"epoch: {epoch} | train_loss: {train_loss_mean:.8f} | valid_loss: {valid_loss:.8f} | lr: {self.optim.param_groups[0]['lr']:.2e}")
                self.scheduler.step(valid_loss)
                if early_stopper.early_stop(valid_loss):
                    break
            else:
                print(f"epoch: {epoch} | train_loss: {train_loss_mean:.8f} | lr: {self.optim.param_groups[0]['lr']:.2e}")

    def validate_step(self):
        valid_loss_sum = 0.0
        num_batches = 0
        with torch.no_grad():
            for batch, (X, y_batch, tags_batch) in enumerate(self.valid_dataloader):
                embeddings = self.model.get_embeddings(X.to(device))
                topk_indices = self.model.compute_similarity(embeddings, tags_batch.to(device), top_k=self.top_k)
                comparison_labels = self.model.generate_comparison_labels(y_batch.to(device), topk_indices, tags_batch.to(device), top_k=self.top_k)
                predictions = self.model(X.to(device))
                loss = self.model.compute_loss(predictions, comparison_labels, tags_batch.to(device))
                valid_loss_sum += loss.item()
                num_batches += 1
        return valid_loss_sum / num_batches

    def predict(self, x, batch_size=1024):
        # Batch prediction on test set.
        preds = []
        self.model.eval()
        with torch.no_grad():
            for i in range(0, len(x), batch_size):
                x_batch = torch.Tensor(x[i:i+batch_size]).to(device)
                pred = self.model(x_batch)
                preds.append(pred.cpu())
        return torch.cat(preds, dim=0)


In [6]:
df = pd.read_csv("synthetic_uplift_data.csv")
x_names = ['x1_informative', 'x2_informative', 'x3_informative', 'x4_informative', 'x5_informative', 'x6_informative', 'x7_informative', 'x8_informative', 'x9_informative', 'x10_informative', 'x11_redundant_linear_x9_x6_x7_x3_x5_x1_x4_x8', 'x12_redundant_linear_x7_x1_x3_x9_x5_x4_x8_x10', 'x13_redundant_linear_x8_x7_x2_x9_x3_x1_x5_x10', 'x14_redundant_linear_x4_x2_x9_x5', 'x15_redundant_linear_x1_x3_x10_x2_x4_x7_x5_x8_x9', 'x16_redundant_linear_x7_x6', 'x17_redundant_linear_x1', 'x18_redundant_linear_x2_x5_x7_x3_x1_x10_x9', 'x19_redundant_linear_x8_x6_x10_x3', 'x20_redundant_linear_x4_x7_x5_x10_x6_x9_x2', 'x21_redundant_linear_x7', 'x22_redundant_linear_x3_x10_x7', 'x23_redundant_linear_x2_x1_x10_x4_x5_x3_x8_x9_x7_x6', 'x24_redundant_linear_x8_x1', 'x25_redundant_linear_x7_x2_x1_x5_x6', 'x26_redundant_linear_x5_x3_x9_x6_x4_x8_x7_x10_x1_x2', 'x27_redundant_linear_x3_x1_x5', 'x28_redundant_linear_x6_x2_x1_x5_x3_x8_x4_x9', 'x29_redundant_linear_x10_x3', 'x30_redundant_linear_x7_x1_x10_x5_x6', 'x31_redundant_linear_x3_x6_x7', 'x32_redundant_linear_x1_x3_x10_x4_x7', 'x33_redundant_linear_x5_x4_x10_x9_x7_x3_x1_x2', 'x34_redundant_linear_x6_x7_x3_x4_x8_x2_x10_x9_x5_x1', 'x35_redundant_linear_x2_x1_x5_x6_x10_x9_x8_x7_x4_x3', 'x36_redundant_linear_x1_x2_x8_x7_x3_x9_x6_x4_x5_x10', 'x37_redundant_linear_x4_x2_x6_x3_x5_x9_x7_x10', 'x38_redundant_linear_x2_x10_x3_x1_x9', 'x39_redundant_linear_x5', 'x40_redundant_linear_x3_x10_x2_x6_x1_x5', 'x41_redundant_linear_x7_x4_x10_x2_x3_x9', 'x42_redundant_linear_x10_x2_x9', 'x43_redundant_linear_x6_x5_x1_x10_x9_x8_x2_x3', 'x44_redundant_linear_x7_x3_x5_x10_x1_x2_x6_x4_x9_x8', 'x45_redundant_linear_x3_x7_x2_x9_x6', 'x46_redundant_linear_x5_x3_x1_x9_x4', 'x47_redundant_linear_x6_x10_x9_x8_x1_x4_x3_x2_x7_x5', 'x48_redundant_linear_x7', 'x49_redundant_linear_x2_x9_x10', 'x50_redundant_linear_x5_x6_x9_x8_x7_x10_x1_x2_x3_x4', 'x51_redundant_linear_x4_x2', 'x52_redundant_linear_x8_x4_x9_x7_x2_x10', 'x53_redundant_linear_x8_x9_x3_x5_x1_x7_x6_x4_x10', 'x54_redundant_linear_x4_x5_x8_x7_x10_x1_x2_x3_x9_x6', 'x55_redundant_linear_x7_x8_x5_x2_x9_x1_x6', 'x56_redundant_linear_x5_x10_x8_x9_x1', 'x57_redundant_linear_x8', 'x58_redundant_linear_x10_x9_x2_x8_x4_x3_x6_x7_x5_x1', 'x59_redundant_linear_x2_x7_x1_x9_x3_x6_x8_x5_x4', 'x60_redundant_linear_x6_x9_x5_x3_x2_x8_x1_x10', 'x61_redundant_linear_x10_x7_x3_x8_x9_x2', 'x62_redundant_linear_x7_x1_x2_x10_x9_x8_x4_x3_x6', 'x63_redundant_linear_x3_x5_x6_x4_x2', 'x64_redundant_linear_x4_x6_x5', 'x65_redundant_linear_x5', 'x66_redundant_linear_x1_x5', 'x67_redundant_linear_x8_x10_x5_x2_x1_x3_x7_x9_x4', 'x68_redundant_linear_x4_x3_x8', 'x69_redundant_linear_x4_x3', 'x70_redundant_linear_x7_x6_x9_x8_x2_x1_x3_x10_x5_x4', 'x71_redundant_linear_x8_x7', 'x72_redundant_linear_x5_x1_x9', 'x73_redundant_linear_x2_x8_x1_x9', 'x74_redundant_linear_x7_x9_x10', 'x75_redundant_linear_x3_x10', 'x76_redundant_linear_x9_x7_x5', 'x77_redundant_linear_x1_x5_x2_x7_x6', 'x78_redundant_linear_x5_x10_x4_x6', 'x79_redundant_linear_x3_x2_x5_x8', 'x80_redundant_linear_x4_x3_x9_x10_x8_x5_x6_x2_x7_x1', 'x81_repeated_x10', 'x82_repeated_x8', 'x83_repeated_x6', 'x84_repeated_x1', 'x85_repeated_x7', 'x86_repeated_x5', 'x87_repeated_x3', 'x88_repeated_x6', 'x89_repeated_x1', 'x90_repeated_x1', 'x91_repeated_x5', 'x92_repeated_x7', 'x93_repeated_x9', 'x94_repeated_x5', 'x95_repeated_x3', 'x96_repeated_x9', 'x97_repeated_x4', 'x98_repeated_x5', 'x99_repeated_x8', 'x100_repeated_x5', 'x101_repeated_x6', 'x102_repeated_x9', 'x103_repeated_x2', 'x104_repeated_x7', 'x105_repeated_x2', 'x106_repeated_x3', 'x107_repeated_x6', 'x108_repeated_x2', 'x109_repeated_x6', 'x110_repeated_x6', 'x111_repeated_x4', 'x112_repeated_x6', 'x113_repeated_x3', 'x114_repeated_x6', 'x115_repeated_x7', 'x116_repeated_x7', 'x117_repeated_x2', 'x118_repeated_x4', 'x119_repeated_x9', 'x120_repeated_x7', 'x121_repeated_x7', 'x122_repeated_x8', 'x123_repeated_x2', 'x124_repeated_x8', 'x125_repeated_x4', 'x126_repeated_x2', 'x127_repeated_x5', 'x128_repeated_x1', 'x129_repeated_x1', 'x130_repeated_x10', 'x131_irrelevant', 'x132_irrelevant', 'x133_irrelevant', 'x134_irrelevant', 'x135_irrelevant', 'x136_irrelevant', 'x137_irrelevant', 'x138_irrelevant', 'x139_irrelevant', 'x140_irrelevant', 'x141_irrelevant', 'x142_irrelevant', 'x143_irrelevant', 'x144_irrelevant', 'x145_irrelevant', 'x146_irrelevant', 'x147_irrelevant', 'x148_irrelevant', 'x149_irrelevant', 'x150_irrelevant', 'x151_uplift', 'x152_uplift', 'x153_uplift', 'x154_uplift', 'x155_uplift', 'x156_uplift', 'x157_uplift', 'x158_uplift', 'x159_uplift', 'x160_uplift', 'x161_uplift', 'x162_uplift', 'x163_uplift', 'x164_uplift', 'x165_uplift', 'x166_uplift', 'x167_uplift', 'x168_uplift', 'x169_uplift', 'x170_uplift', 'x171_uplift', 'x172_uplift', 'x173_uplift', 'x174_uplift', 'x175_uplift', 'x176_uplift', 'x177_uplift', 'x178_uplift', 'x179_uplift', 'x180_uplift', 'x181_uplift', 'x182_uplift', 'x183_uplift', 'x184_uplift', 'x185_uplift', 'x186_uplift', 'x187_uplift', 'x188_uplift', 'x189_uplift', 'x190_uplift', 'x191_uplift', 'x192_uplift', 'x193_uplift', 'x194_uplift', 'x195_uplift', 'x196_uplift', 'x197_uplift', 'x198_uplift', 'x199_uplift', 'x200_uplift', 'x201_uplift', 'x202_uplift', 'x203_uplift', 'x204_uplift', 'x205_uplift', 'x206_uplift', 'x207_uplift', 'x208_uplift', 'x209_uplift', 'x210_uplift', 'x211_uplift', 'x212_uplift', 'x213_uplift', 'x214_uplift', 'x215_uplift', 'x216_uplift', 'x217_uplift', 'x218_uplift', 'x219_uplift', 'x220_uplift', 'x221_uplift', 'x222_uplift', 'x223_uplift', 'x224_uplift', 'x225_uplift', 'x226_uplift', 'x227_uplift', 'x228_uplift', 'x229_uplift', 'x230_uplift', 'x231_mix', 'x232_mix', 'x233_mix', 'x234_mix', 'x235_mix', 'x236_mix', 'x237_mix', 'x238_mix', 'x239_mix', 'x240_mix', 'x241_mix', 'x242_mix', 'x243_mix', 'x244_mix', 'x245_mix', 'x246_mix', 'x247_mix', 'x248_mix', 'x249_mix', 'x250_mix', 'x251_mix', 'x252_mix', 'x253_mix', 'x254_mix', 'x255_mix', 'x256_mix', 'x257_mix', 'x258_mix', 'x259_mix', 'x260_mix', 'x261_mix', 'x262_mix', 'x263_mix', 'x264_mix', 'x265_mix', 'x266_mix', 'x267_mix', 'x268_mix', 'x269_mix', 'x270_mix', 'x271_mix', 'x272_mix', 'x273_mix', 'x274_mix', 'x275_mix', 'x276_mix', 'x277_mix', 'x278_mix', 'x279_mix', 'x280_mix', 'x281_mix', 'x282_mix', 'x283_mix', 'x284_mix', 'x285_mix', 'x286_mix', 'x287_mix', 'x288_mix', 'x289_mix', 'x290_mix', 'x291_mix', 'x292_mix', 'x293_mix', 'x294_mix', 'x295_mix', 'x296_mix', 'x297_mix', 'x298_mix', 'x299_mix', 'x300_mix']

df['treatment'] = (df['treatment_group_key'] == 'treatment').astype(int)
df['true_uplift'] = df['treatment_true_effect']

train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
in_features = x_names
label_feature = 'outcome'
treatment_feature = 'treatment'

scaler = StandardScaler()
X_train = scaler.fit_transform(train_df[in_features])
X_test = scaler.transform(test_df[in_features])
y_train = train_df[label_feature].values
y_test = test_df[label_feature].values
tags_train = train_df[treatment_feature].values
tags_test = test_df[treatment_feature].values


In [7]:

# ==== Model Training & Evaluation ====
seed = 42
set_seed(seed)
model = SampleComparisonNetModel(
    input_dim=X_train.shape[1],
    shared_hidden=512,
    embedding_dim=32,
    random_seed=seed,
    batch_size=20000,
    learning_rate=5e-4,
    epochs=40,
    patience=10,
    scheduler_patience=5,
    scheduler_factor=0.5,
    scheduler_min_lr=1e-10,
    top_k=8
)
model.fit(X_train, y_train, tags_train, valid_perc=0.1)
pred = model.predict(X_test).cpu().detach().numpy().reshape(-1)

analysis_df = pd.DataFrame({
    'treatment': test_df['treatment'].values,
    'label': test_df['outcome'].values,
    'uplift_pred': pred,
    'true_uplift': test_df['true_uplift'].values
})

y_true = analysis_df['label'].values
treatment = analysis_df['treatment'].values
uplift_score = analysis_df['uplift_pred'].values
true_uplift = analysis_df['true_uplift'].values

auuc = uplift_auc_score(y_true, uplift_score, treatment)
qini = qini_auc_score(y_true, uplift_score, treatment)
sqrt_pehe = np.sqrt(np.mean((uplift_score - true_uplift) ** 2))
pred_ate = np.mean(uplift_score)
true_ate = np.mean(true_uplift)
e_ate = np.abs(pred_ate - true_ate)
print(f"\n==== Results ====")
print(f"Seed {seed} | AUUC={auuc:.6f} | Qini={qini:.6f} | sqrt_PEHE={sqrt_pehe:.6f} | Îµ_ATE={e_ate:.6f}")

# ==== Optional: Distribution plot ====
plt.figure(figsize=(8,4))
sns.kdeplot(true_uplift, fill=True, label='True Uplift', color='g')
sns.kdeplot(uplift_score, fill=True, label='Model Pseudo Uplift', color='orange')
plt.xlabel("Uplift")
plt.title("True vs Model Pseudo Uplift Distribution")
plt.legend()
plt.show()


epoch: 0 | train_loss: 0.10495055 | valid_loss: 0.09927499 | lr: 5.00e-04
epoch: 1 | train_loss: 0.09920587 | valid_loss: 0.10008876 | lr: 5.00e-04
epoch: 2 | train_loss: 0.10041006 | valid_loss: 0.09995381 | lr: 5.00e-04
epoch: 3 | train_loss: 0.10084260 | valid_loss: 0.09786592 | lr: 5.00e-04
epoch: 4 | train_loss: 0.09915914 | valid_loss: 0.09592140 | lr: 5.00e-04
