In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import random
from sklift.metrics import qini_auc_score, uplift_auc_score
from causalml.dataset import make_uplift_classification_logistic

In [None]:
# Run on the current CUDA device when PyTorch can see a GPU; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:

class SampleComparisonNet(nn.Module):
    """Encoder/decoder MLP trained to regress an in-batch matched uplift target.

    ``encoder`` maps inputs to an ``embedding_dim`` representation. That
    representation is used both for prediction (through ``decoder``) and for
    nearest-neighbour matching between treatment groups. ``decoder`` emits a
    single value per sample.

    Args:
        input_dim: number of input features.
        shared_hidden: width of the hidden layers in encoder and decoder.
        embedding_dim: size of the encoder output.
    """

    def __init__(self, input_dim, shared_hidden=128, embedding_dim=64):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(in_features=input_dim, out_features=shared_hidden),
            nn.ReLU(),
            nn.Linear(in_features=shared_hidden, out_features=shared_hidden),
            nn.ReLU(),
            nn.Linear(in_features=shared_hidden, out_features=embedding_dim),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Linear(in_features=embedding_dim, out_features=shared_hidden),
            nn.ReLU(),
            nn.Linear(in_features=shared_hidden, out_features=shared_hidden),
            nn.ReLU(),
            nn.Linear(in_features=shared_hidden, out_features=1)
        )

    def forward(self, inputs):
        """Return one prediction per row: ``(N, input_dim)`` -> ``(N, 1)``."""
        return self.decoder(self.encoder(inputs))

    def get_embeddings(self, inputs):
        """Return encoder embeddings computed without autograd.

        Because no graph is recorded, the matching targets built from these
        embeddings act as constants during back-propagation.
        """
        with torch.no_grad():
            embeddings = self.encoder(inputs)
        return embeddings

    def compute_similarity(self, embeddings, tags, top_k=5):
        """For every row, return indices of its most cosine-similar opposite-group rows.

        Args:
            embeddings: ``(N, D)`` tensor.
            tags: treatment indicator with ``N`` elements (viewed as a column).
            top_k: requested number of neighbours.

        Returns:
            ``(N, min(top_k, N))`` index tensor, sorted by decreasing similarity.
            When a row has fewer than ``top_k`` opposite-group rows in the batch,
            the trailing indices point at same-group rows (their similarity is
            masked to ``-inf``). ``generate_comparison_labels`` filters those out.
        """
        normalized = F.normalize(embeddings, p=2, dim=1)
        cosine_similarity = normalized @ normalized.T
        tag = tags.float().view(-1, 1)
        opposite_group = tag != tag.t()
        masked_similarity = cosine_similarity.masked_fill(~opposite_group, float('-inf'))
        # torch.topk raises if k exceeds the number of candidates (batch smaller than top_k).
        k = min(top_k, masked_similarity.size(1))
        _, topk_indices = torch.topk(masked_similarity, k=k, dim=1)
        return topk_indices

    def generate_comparison_labels(self, labels, topk_indices, tags, top_k=5):
        """Build the matched uplift target for every sample.

        Treated rows (tag > 0) get ``y - mean(y of matched controls)``. Control
        rows get ``mean(y of matched treated) - y``. Only neighbours whose tag
        differs from the row's own tag enter the mean. A row with no such
        neighbour in the batch gets a target of 0.

        Args:
            labels: ``(N, 1)`` outcome tensor.
            topk_indices: ``(N, k)`` neighbour indices from ``compute_similarity``.
            tags: ``(N, 1)`` treatment indicator.
            top_k: unused; kept for interface compatibility.

        Returns:
            ``(N, 1)`` tensor of targets.
        """
        label_vector = torch.squeeze(labels, dim=1)
        tags_vector = torch.squeeze(tags, dim=1)
        topk_labels = label_vector[topk_indices].float()                # [N, k]
        # Same-group entries are -inf padding from compute_similarity, not real matches.
        valid = tags_vector[topk_indices] != tags_vector.unsqueeze(1)   # [N, k]
        n_valid = valid.sum(dim=1)                                      # [N]
        matched_y = (topk_labels * valid).sum(dim=1) / n_valid.clamp(min=1)
        uplift_label = torch.where(
            tags_vector > 0,
            label_vector - matched_y,                         # Treatment: Y(1) - matched_Y(0)
            matched_y - label_vector                          # Control: matched_Y(1) - Y(0)
        )
        # No counterpart in the batch means no evidence about uplift, so use a neutral target.
        uplift_label = torch.where(n_valid > 0, uplift_label, torch.zeros_like(uplift_label))
        return uplift_label.unsqueeze(1)

    def compute_loss(self, predictions, comparison_labels, tags):
        """Mean squared error between predictions and matched targets (``tags`` is unused)."""
        assert predictions.shape == comparison_labels.shape
        loss = F.mse_loss(predictions, comparison_labels, reduction='mean')
        return loss


In [None]:

class EarlyStopper:
    """Patience-based early stopping for a validation loss that should decrease.

    Args:
        patience: how many epochs with a loss above ``best + min_delta`` are
            tolerated before ``early_stop`` returns True.
        min_delta: tolerance above the best loss before an epoch counts as worse.
    """

    def __init__(self, patience=15, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        """Record ``validation_loss``; return True once patience is exhausted."""
        if validation_loss < self.min_validation_loss:
            # New best loss: remember it and forgive earlier bad epochs.
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False
        if validation_loss > self.min_validation_loss + self.min_delta:
            # Clearly worse than the best: spend one unit of patience.
            self.counter += 1
            return self.counter >= self.patience
        # Within tolerance of the best: neither reset nor penalize.
        return False


In [None]:

def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's RNGs with ``seed``.

    When CUDA is available, this also seeds every GPU and switches cuDNN to
    deterministic, non-benchmarking mode.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


In [None]:

class SampleComparisonNetModel:
    """Training / inference wrapper around ``SampleComparisonNet``.

    Each mini-batch builds its own regression targets. Every sample is matched
    to its ``top_k`` most cosine-similar samples from the opposite treatment
    group, with similarity measured on the current encoder embeddings. The
    network is then trained with MSE to predict the resulting matched outcome
    difference.

    Args:
        input_dim: number of input features.
        shared_hidden: width of the hidden layers of encoder and decoder.
        embedding_dim: size of the encoder output used for matching.
        epochs: maximum number of training epochs.
        patience: early-stopping patience in epochs (used only with a validation split).
        batch_size: mini-batch size. Matching happens within a batch, so this
            also bounds the pool of candidate counterparts.
        learning_rate: initial AdamW learning rate.
        data_loader_num_workers: ``num_workers`` passed to every DataLoader.
        random_seed: seed passed to ``set_seed`` before the network is built.
        scheduler_patience, scheduler_factor, scheduler_min_lr: ReduceLROnPlateau
            settings. The scheduler only steps when a validation split is used.
        top_k: number of opposite-group neighbours averaged per sample.
    """

    def __init__(
        self,
        input_dim,
        shared_hidden=64,
        embedding_dim=16,
        epochs=5,
        patience=1,
        batch_size=2048,
        learning_rate=1e-4,
        data_loader_num_workers=10,
        random_seed=30,
        scheduler_patience=2,
        scheduler_factor=0.5,
        scheduler_min_lr=1e-7,
        top_k=5,
    ):
        set_seed(random_seed)
        self.model = SampleComparisonNet(input_dim, shared_hidden, embedding_dim).to(device)
        self.epochs = epochs
        self.patience = patience
        self.batch_size = batch_size
        self.num_workers = data_loader_num_workers
        self.top_k = top_k
        self.optim = torch.optim.AdamW(self.model.parameters(), lr=learning_rate, weight_decay=1e-4)
        self.train_dataloader = None
        self.valid_dataloader = None
        # `verbose` is deprecated in recent PyTorch releases; the per-epoch log in
        # `fit` already reports the current learning rate.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optim,
            mode='min',
            patience=scheduler_patience,
            factor=scheduler_factor,
            min_lr=scheduler_min_lr,
        )

    @staticmethod
    def _as_tensors(x, y, tags):
        """Convert features to a float tensor and labels/tags to float column vectors."""
        return (
            torch.Tensor(x),
            torch.Tensor(y).reshape(-1, 1),
            torch.Tensor(tags).reshape(-1, 1),
        )

    def create_dataloaders(self, x, y, tags, valid_perc=None):
        """Build ``self.train_dataloader``, plus ``self.valid_dataloader`` if ``valid_perc`` is truthy.

        The training loader shuffles, so the in-batch matching pools change from
        epoch to epoch. The validation loader keeps a fixed order. Without a
        split, any validation loader left over from a previous call is cleared.
        """
        if valid_perc:
            x_train, x_valid, y_train, y_valid, tags_train, tags_valid = train_test_split(
                x, y, tags, test_size=valid_perc, random_state=42
            )
            train_dataset = TensorDataset(*self._as_tensors(x_train, y_train, tags_train))
            valid_dataset = TensorDataset(*self._as_tensors(x_valid, y_valid, tags_valid))
            self.valid_dataloader = DataLoader(
                valid_dataset, batch_size=self.batch_size, num_workers=self.num_workers
            )
        else:
            train_dataset = TensorDataset(*self._as_tensors(x, y, tags))
            self.valid_dataloader = None
        self.train_dataloader = DataLoader(
            train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers
        )

    def _batch_loss(self, X, y_batch, tags_batch):
        """Build in-batch matched targets and return the loss of the predictions against them."""
        X = X.to(device)
        y_batch = y_batch.to(device)
        tags_batch = tags_batch.to(device)
        embeddings = self.model.get_embeddings(X)
        topk_indices = self.model.compute_similarity(embeddings, tags_batch, top_k=self.top_k)
        comparison_labels = self.model.generate_comparison_labels(
            y_batch, topk_indices, tags_batch, top_k=self.top_k
        )
        predictions = self.model(X)
        return self.model.compute_loss(predictions, comparison_labels, tags_batch)

    def fit(self, x, y, tags, valid_perc=None):
        """Train the network, with LR scheduling and early stopping when ``valid_perc`` is set.

        Raises:
            ValueError: if there are no training samples.
        """
        self.create_dataloaders(x, y, tags, valid_perc)
        if len(self.train_dataloader.dataset) == 0:
            raise ValueError("fit() needs at least one training sample")
        early_stopper = EarlyStopper(patience=self.patience, min_delta=0)
        for epoch in range(self.epochs):
            # Validation switches to eval mode, so restore training mode every epoch.
            self.model.train()
            train_loss_sum = 0.0
            num_batches = 0
            for X, y_batch, tags_batch in self.train_dataloader:
                loss = self._batch_loss(X, y_batch, tags_batch)
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
                train_loss_sum += loss.item()
                num_batches += 1
            train_loss_mean = train_loss_sum / num_batches

            if self.valid_dataloader is not None:
                self.model.eval()
                valid_loss = self.validate_step()
                print(f"epoch: {epoch} | train_loss: {train_loss_mean:.8f} | valid_loss: {valid_loss:.8f} | lr: {self.optim.param_groups[0]['lr']:.2e}")
                self.scheduler.step(valid_loss)
                if early_stopper.early_stop(valid_loss):
                    break
            else:
                print(f"epoch: {epoch} | train_loss: {train_loss_mean:.8f} | lr: {self.optim.param_groups[0]['lr']:.2e}")

    def validate_step(self):
        """Return the mean per-batch loss over ``self.valid_dataloader`` without updating weights."""
        valid_loss_sum = 0.0
        num_batches = 0
        with torch.no_grad():
            for X, y_batch, tags_batch in self.valid_dataloader:
                valid_loss_sum += self._batch_loss(X, y_batch, tags_batch).item()
                num_batches += 1
        return valid_loss_sum / num_batches

    def predict(self, x, batch_size=1024):
        """Predict in batches of ``batch_size``.

        Returns:
            A CPU tensor of shape ``(len(x), 1)``. Empty input gives ``(0, 1)``.
        """
        self.model.eval()
        preds = []
        with torch.no_grad():
            for start in range(0, len(x), batch_size):
                x_batch = torch.Tensor(x[start:start + batch_size]).to(device)
                preds.append(self.model(x_batch).cpu())
        if not preds:
            return torch.empty(0, 1)
        return torch.cat(preds, dim=0)


In [None]:

# ==== Synthetic Data Generation ====
# causalml logistic uplift simulator with one treatment arm vs. control.
# All simulator settings live in one dict so they are easy to find and tune;
# see the causalml dataset docs for the meaning of each parameter.
SYNTH_DATA_CONFIG = dict(
    n_samples=50000,
    treatment_name=["control", "treatment"],
    y_name="outcome",
    # Classification feature block: informative / redundant / repeated split.
    n_classification_features=150,
    n_classification_informative=10,
    n_classification_redundant=70,
    n_classification_repeated=50,
    # Uplift-specific feature counts and effect size for the treatment arm.
    n_uplift_dict={"treatment": 80},
    n_mix_informative_uplift_dict={"treatment": 70},
    delta_uplift_dict={"treatment": 0.1},
    positive_class_proportion=0.05,
    random_seed=47,
    feature_association_list=['sin', 'cos', 'cubic', 'relu']*38,
)
df, x_names = make_uplift_classification_logistic(**SYNTH_DATA_CONFIG)


In [None]:

# Binary treatment indicator (1 for the 'treatment' arm, 0 otherwise) and the
# simulator's ground-truth per-row effect, used for evaluation below.
df['treatment'] = df['treatment_group_key'].eq('treatment').astype(int)
df['true_uplift'] = df['treatment_true_effect']

# Column roles.
in_features = x_names
label_feature = 'outcome'
treatment_feature = 'treatment'

# 80/20 train/test split; the scaler is fitted on the training rows only.
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
scaler = StandardScaler()
X_train = scaler.fit_transform(train_df[in_features])
X_test = scaler.transform(test_df[in_features])

y_train, y_test = (part[label_feature].values for part in (train_df, test_df))
tags_train, tags_test = (part[treatment_feature].values for part in (train_df, test_df))


In [None]:

# ==== Model Training & Evaluation ====
seed = 42
set_seed(seed)
model = SampleComparisonNetModel(
    input_dim=X_train.shape[1],
    shared_hidden=512,
    embedding_dim=32,
    random_seed=seed,
    batch_size=20000,
    learning_rate=5e-4,
    epochs=40,
    patience=10,
    scheduler_patience=5,
    scheduler_factor=0.5,
    scheduler_min_lr=1e-10,
    top_k=8
)
# `fit` holds out 10% of the training rows for LR scheduling and early stopping.
model.fit(X_train, y_train, tags_train, valid_perc=0.1)
pred = model.predict(X_test).cpu().detach().numpy().reshape(-1)

analysis_df = pd.DataFrame({
    'treatment': test_df['treatment'].values,
    'label': test_df['outcome'].values,
    'uplift_pred': pred,
    'true_uplift': test_df['true_uplift'].values
})

y_true = analysis_df['label'].values
treatment = analysis_df['treatment'].values
uplift_score = analysis_df['uplift_pred'].values
true_uplift = analysis_df['true_uplift'].values

# Metrics:
#   AUUC / Qini - area under the uplift / Qini curve (scikit-uplift), scored on observed outcomes
#   sqrt_PEHE   - RMSE between predicted and simulated true per-row uplift
#   ε_ATE       - absolute error of the predicted average treatment effect
auuc = uplift_auc_score(y_true, uplift_score, treatment)
qini = qini_auc_score(y_true, uplift_score, treatment)
sqrt_pehe = np.sqrt(np.mean((uplift_score - true_uplift) ** 2))
pred_ate = np.mean(uplift_score)
true_ate = np.mean(true_uplift)
e_ate = np.abs(pred_ate - true_ate)
print("\n==== Results ====")
# The label was mojibake ("Îµ"); it is the Greek epsilon of ε_ATE.
print(f"Seed {seed} | AUUC={auuc:.6f} | Qini={qini:.6f} | sqrt_PEHE={sqrt_pehe:.6f} | ε_ATE={e_ate:.6f}")