In [1]:
import torch
import math

# Seed PyTorch's global RNG so every run draws the same synthetic data.
torch.manual_seed(42)

###############################################################################
# 1. Set up parameters for ads and their age distributions
###############################################################################

# Number of samples generated for each ad.
N_PER_AD = 200

# For each ad, a different normal distribution for the user's age, keyed by
# ad ID. (You can customize the means/stddevs, or add/remove ads, as desired.)
age_distributions = {
    0: {"mean": 20.0, "std": 4.0},
    1: {"mean": 25.0, "std": 5.0},
    2: {"mean": 30.0, "std": 6.0},
    3: {"mean": 45.0, "std": 5.0},
    4: {"mean": 50.0, "std": 6.0},
}

# Derive the ad IDs from the table above instead of maintaining a second,
# hand-written list that could drift out of sync with it. Dicts preserve
# insertion order, so this evaluates to [0, 1, 2, 3, 4].
ad_ids = list(age_distributions)

# Coefficients of the "true" logistic function that depends on ad_id and age:
#   logit = TRUE_BIAS + TRUE_WEIGHT_AD_ID * ad_id + TRUE_WEIGHT_AGE * age
TRUE_BIAS = -2.0
TRUE_WEIGHT_AD_ID = 0.5
TRUE_WEIGHT_AGE = 0.05


###############################################################################
# 2. Generate the synthetic samples for each ad
###############################################################################

# Per-ad batches, collected in the order the ads appear in ad_ids.
all_ad_ids, all_ages, all_labels = [], [], []

for ad_id in ad_ids:
    dist_params = age_distributions[ad_id]

    # Draw N_PER_AD ages from this ad's normal distribution, then pin any
    # draw outside [0, 100] to the nearest bound.
    ages = torch.normal(
        mean=dist_params["mean"],
        std=dist_params["std"],
        size=(N_PER_AD,),
    ).clamp(min=0.0, max=100.0)

    # Constant float tensor holding this ad's ID, one entry per sample.
    ad_id_tensor = torch.full((N_PER_AD,), float(ad_id))

    # "True" model: linear logit in (ad_id, age), squashed through a sigmoid
    # to obtain the per-sample probability of a positive label.
    logits = TRUE_BIAS + TRUE_WEIGHT_AD_ID * ad_id_tensor + TRUE_WEIGHT_AGE * ages
    prob = logits.sigmoid()

    # Bernoulli(prob) draw: a uniform sample that falls below prob yields 1.0,
    # otherwise 0.0.
    labels = torch.rand(N_PER_AD).lt(prob).float()

    all_ad_ids.append(ad_id_tensor)
    all_ages.append(ages)
    all_labels.append(labels)

# Join the per-ad batches end to end; each result is 1-D with
# len(ad_ids) * N_PER_AD entries.
ad_ids_combined, ages_combined, y_combined = (
    torch.cat(batches) for batches in (all_ad_ids, all_ages, all_labels)
)


###############################################################################
# 3. Build final feature matrix X = [ad_id, age], shuffle, and split
###############################################################################

# Pair each ad ID with its age: column 0 is the ad ID, column 1 is the age.
X_combined = torch.stack((ad_ids_combined, ages_combined), dim=-1)

# Labels as a single column, shape [N, 1].
y_combined = y_combined.reshape(-1, 1)

# One random permutation of the row indices, applied to features and labels
# alike so every row stays paired with its own label.
N = X_combined.shape[0]  # total number of samples
indices = torch.randperm(N)
X_shuffled, y_shuffled = X_combined[indices], y_combined[indices]

# Fraction of rows used for training (the rest is held out for evaluation);
# int() truncates, so any fractional row goes to the eval split.
train_ratio = 0.8
train_size = int(N * train_ratio)

# The first train_size shuffled rows are the training set, the rest are eval.
X_train, X_eval = X_shuffled[:train_size], X_shuffled[train_size:]
y_train, y_eval = y_shuffled[:train_size], y_shuffled[train_size:]

###############################################################################
# 4. Print summary
###############################################################################

# Report the dataset sizes produced by the split.
print("Total samples:", N)
print("Training samples:", X_train.size(0))
print("Evaluation samples:", X_eval.size(0))

# Preview up to five training rows. Slicing (rather than indexing rows 0..4
# directly) avoids an IndexError when the training split holds fewer than
# five rows; with five or more rows the output is unchanged.
print("\nExample training samples:")
for features, label in zip(X_train[:5], y_train[:5]):
    print(f"  X={features.tolist()}  y={label.item():.0f}")

Total samples: 1000
Training samples: 800
Evaluation samples: 200

Example training samples:
  X=[2.0, 34.427040100097656]  y=1
  X=[2.0, 27.55743980407715]  y=0
  X=[0.0, 16.924644470214844]  y=1
  X=[1.0, 22.798748016357422]  y=1
  X=[3.0, 42.41710662841797]  y=1
