In [144]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as  plt
import seaborn as sns
import torch

In [146]:
X_train= pd.read_csv("data/preprocessed/X_train.csv")
X_test = pd.read_csv("data/preprocessed/X_test.csv")
y_train = pd.read_csv("data/preprocessed/y_train.csv")
y_test = pd.read_csv("data/preprocessed/y_test.csv")

In [147]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [148]:
X_train.shape, y_train.shape, X_test.shape, y_test.shape

((80000, 9), (80000, 2), (20000, 9), (20000, 2))

In [149]:
class BalancedEmailClickModel(nn.Module):
    def __init__(self, input_dims=9, hidden_dims=(32, 16), dropout_rate=0.1):
        super().__init__()

        # Simpler shared network with balanced capacity
        self.shared_network = nn.Sequential(
            nn.Linear(input_dims, hidden_dims[0]),
            nn.BatchNorm1d(hidden_dims[0]),
            nn.LeakyReLU(0.1),  # LeakyReLU to prevent dying neurons
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dims[0], hidden_dims[1]),
            nn.BatchNorm1d(hidden_dims[1]),
            nn.LeakyReLU(0.1),
            nn.Dropout(dropout_rate)
        )

        # Simple head for open prediction
        self.opened_head = nn.Linear(hidden_dims[1], 1)

        # Simple head for click prediction that takes opening probability into account
        self.clicked_head = nn.Linear(hidden_dims[1] + 1, 1)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)  # Xavier for better initialization
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x, y_opened=None):
        h = self.shared_network(x)
        opened_logits = self.opened_head(h)

        # Conditional input for clicked head
        if y_opened is not None:
            # During training when we have ground truth opened labels
            cond = y_opened
        else:
            # During inference or validation
            cond = torch.sigmoid(opened_logits).detach()

        clicked_logits = self.clicked_head(torch.cat([h, cond], dim=1))

        return opened_logits, clicked_logits

    def loss(self, opened_logits, clicked_logits, y_opened, y_clicked, pos_weight=None):
        """Calculate balanced loss"""
        # Basic losses with class balancing
        loss_open = F.binary_cross_entropy_with_logits(opened_logits, y_opened)

        # Handle click prediction with pos_weight (class imbalance)
        if pos_weight is not None:
            loss_click = F.binary_cross_entropy_with_logits(
                clicked_logits, y_clicked, pos_weight=pos_weight)
        else:
            loss_click = F.binary_cross_entropy_with_logits(clicked_logits, y_clicked)

        # Combined loss with fixed weighting
        total_loss = 0.4 * loss_open + 0.6 * loss_click

        return loss_open, loss_click, total_loss

In [150]:
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import precision_score, recall_score, roc_auc_score, accuracy_score, f1_score
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import random_split
# from tqdm.notebook import tqdm

In [151]:
class EmailDataset(Dataset):
    def __init__(self, X, opened, clicked):
        self.X = torch.tensor(X.values, dtype=torch.float32)
        self.y_opened = torch.tensor(opened.values, dtype=torch.float32).unsqueeze(1)
        self.y_clicked = torch.tensor(clicked.values, dtype=torch.float32).unsqueeze(1)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], {"opened": self.y_opened[idx], "clicked": self.y_clicked[idx]}

In [152]:
def compute_metrics(y_true, y_score, threshold=0.5):
    """Compute classification metrics given scores and labels"""
    y_pred = [1.0 if p > threshold else 0.0 for p in y_score]

    metrics = {
        'precision': precision_score(y_true, y_pred, zero_division=0),
        'recall': recall_score(y_true, y_pred, zero_division=0),
        'f1': f1_score(y_true, y_pred, zero_division=0),
    }

    # Only compute AUC if we have both positive and negative samples
    if len(set(y_true)) > 1:
        metrics['auc'] = roc_auc_score(y_true, y_score)
    else:
        metrics['auc'] = float('nan')

    return metrics

In [153]:
def create_train_val_datasets(dataset, val_ratio=0.2, seed=42):
    total_size = len(dataset)
    val_size = int(total_size * val_ratio)
    train_size = total_size - val_size

    torch.manual_seed(seed)
    train_ds, val_ds = random_split(dataset, [train_size, val_size])

    return train_ds, val_ds

In [154]:
ds = EmailDataset(X_train, y_train["opened"], y_train["clicked"])
train_ds, val_ds = create_train_val_datasets(ds, val_ratio=0.2)

In [155]:
ds[0]

(tensor([ 0.9659,  0.0000, -1.8102,  0.2588,  0.0000,  1.0000,  0.0000,  0.0000,
          0.0000]),
 {'opened': tensor([0.]), 'clicked': tensor([0.])})

In [156]:
def train_model(model, train_ds, val_ds,
                epochs=50, batch_size=64, lr=1e-3,
                warmup_epochs=5, patience=10,
                class_weight_factor=2.0):
    """
    Fixed training function that addresses overfitting issues:
    - Simplified learning rate schedule
    - Better handling of class imbalance
    - Balanced batch sampling
    - Better monitoring of metrics
    - Lower learning rate
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Data loaders
    # Compute class weights for balanced batches
    loader_tr = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    loader_val = DataLoader(val_ds, batch_size=batch_size)

    # Compute positive class weight for handling imbalance - with a safeguard against extreme values
    pos, neg = 0, 0
    for _, y in loader_tr:
        pos += y['clicked'].sum().item()
        neg += (y['clicked'].size(0) - y['clicked'].sum().item())

    # Limit the pos_weight to avoid extreme values that can destabilize training
    raw_weight = neg / max(pos, 1)
    pos_weight = torch.tensor(min(raw_weight, class_weight_factor * 10), device=device)
    print(f"Class imbalance ratio (neg/pos): {raw_weight:.2f}, using pos_weight: {pos_weight.item():.2f}")

    # Simple optimizer with reduced learning rate to prevent overfitting
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)

    # Simple step scheduler with warm restarts to avoid getting stuck
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True
    )

    # History tracking for both training and validation
    history = {
        'train': {k: [] for k in ['loss', 'opened_loss', 'clicked_loss',
                                  'precision', 'recall', 'f1', 'opened_accuracy',
                                  'clicked_accuracy', 'auc']},
        'val': {k: [] for k in ['loss', 'opened_loss', 'clicked_loss',
                               'precision', 'recall', 'f1', 'opened_accuracy',
                               'clicked_accuracy', 'auc']}
    }

    # Early stopping variables
    best_val_f1 = -float('inf')  # Track F1 score instead of just loss
    best_model_state = None
    patience_counter = 0

    for epoch in range(1, epochs+1):
        # Training phase
        model.train()
        train_losses = []
        train_opened_losses = []
        train_clicked_losses = []
        train_true_opened, train_pred_opened = [], []
        train_true_clicked, train_score_clicked = [], []

        for X, y in loader_tr:
            X = X.to(device)
            y_open = y['opened'].to(device)
            y_click = y['clicked'].to(device)

            # Teacher forcing during warmup
            is_warm = epoch <= warmup_epochs
            opened_logits, clicked_logits = model(X, y_open if is_warm else None)

            # Compute losses
            l_open, l_click, loss = model.loss(
                opened_logits, clicked_logits, y_open, y_click, pos_weight
            )

            # Backpropagation with gradient clipping
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Track metrics
            train_losses.append(loss.item())
            train_opened_losses.append(l_open.item())
            train_clicked_losses.append(l_click.item())

            # Collect predictions for opened emails
            train_true_opened.extend(y_open.cpu().numpy().flatten().tolist())
            train_pred_opened.extend((torch.sigmoid(opened_logits) > 0.5).float().cpu().numpy().flatten().tolist())

            # Collect predictions for clicked emails
            train_true_clicked.extend(y_click.cpu().numpy().flatten().tolist())
            train_score_clicked.extend(torch.sigmoid(clicked_logits).detach().cpu().numpy().flatten().tolist())

        # Calculate training metrics
        train_opened_acc = sum(1 for a, b in zip(train_true_opened, train_pred_opened) if a == b) / len(train_true_opened)
        train_clicked_metrics = compute_metrics(train_true_clicked, train_score_clicked)
        train_clicked_acc = sum(1 for a, b in zip(train_true_clicked, [1 if s > 0.5 else 0 for s in train_score_clicked])) / len(train_true_clicked)

        # Update training history
        history['train']['loss'].append(sum(train_losses) / len(train_losses))
        history['train']['opened_loss'].append(sum(train_opened_losses) / len(train_opened_losses))
        history['train']['clicked_loss'].append(sum(train_clicked_losses) / len(train_clicked_losses))
        history['train']['precision'].append(train_clicked_metrics['precision'])
        history['train']['recall'].append(train_clicked_metrics['recall'])
        history['train']['f1'].append(train_clicked_metrics['f1'])
        history['train']['auc'].append(train_clicked_metrics['auc'])
        history['train']['opened_accuracy'].append(train_opened_acc)
        history['train']['clicked_accuracy'].append(train_clicked_acc)

        # Validation phase
        model.eval()
        val_losses = []
        val_opened_losses = []
        val_clicked_losses = []
        val_true_opened, val_pred_opened = [], []
        val_true_clicked, val_score_clicked = [], []

        with torch.no_grad():
            for X, y in loader_val:
                X = X.to(device)
                y_open = y['opened'].to(device)
                y_click = y['clicked'].to(device)

                # Forward pass - no teacher forcing during validation
                opened_logits, clicked_logits = model(X)

                # Calculate losses
                l_open, l_click, loss = model.loss(
                    opened_logits, clicked_logits, y_open, y_click, pos_weight
                )

                # Track metrics
                val_losses.append(loss.item())
                val_opened_losses.append(l_open.item())
                val_clicked_losses.append(l_click.item())

                # Collect predictions for opened emails
                val_true_opened.extend(y_open.cpu().numpy().flatten().tolist())
                val_pred_opened.extend((torch.sigmoid(opened_logits) > 0.5).float().cpu().numpy().flatten().tolist())

                # Collect predictions for clicked emails
                val_true_clicked.extend(y_click.cpu().numpy().flatten().tolist())
                val_score_clicked.extend(torch.sigmoid(clicked_logits).cpu().numpy().flatten().tolist())

        # Calculate validation metrics
        val_opened_acc = sum(1 for a, b in zip(val_true_opened, val_pred_opened) if a == b) / len(val_true_opened)
        val_clicked_metrics = compute_metrics(val_true_clicked, val_score_clicked)
        val_clicked_acc = sum(1 for a, b in zip(val_true_clicked, [1 if s > 0.5 else 0 for s in val_score_clicked])) / len(val_true_clicked)

        # Update validation history
        avg_val_loss = sum(val_losses) / len(val_losses)
        history['val']['loss'].append(avg_val_loss)
        history['val']['opened_loss'].append(sum(val_opened_losses) / len(val_opened_losses))
        history['val']['clicked_loss'].append(sum(val_clicked_losses) / len(val_clicked_losses))
        history['val']['precision'].append(val_clicked_metrics['precision'])
        history['val']['recall'].append(val_clicked_metrics['recall'])
        history['val']['f1'].append(val_clicked_metrics['f1'])
        history['val']['auc'].append(val_clicked_metrics['auc'])
        history['val']['opened_accuracy'].append(val_opened_acc)
        history['val']['clicked_accuracy'].append(val_clicked_acc)

        # Update learning rate based on validation performance
        scheduler.step(avg_val_loss)

        # Print metrics
        print(f"Epoch {epoch}/{epochs}")
        print(f"  Train | loss: {history['train']['loss'][-1]:.4f} | opened_acc: {train_opened_acc:.4f} | clicked_acc: {train_clicked_acc:.4f}")
        print(f"        | precision: {train_clicked_metrics['precision']:.4f} | recall: {train_clicked_metrics['recall']:.4f} | f1: {train_clicked_metrics['f1']:.4f} | auc: {train_clicked_metrics['auc']:.4f}")
        print(f"  Val   | loss: {avg_val_loss:.4f} | opened_acc: {val_opened_acc:.4f} | clicked_acc: {val_clicked_acc:.4f}")
        print(f"        | precision: {val_clicked_metrics['precision']:.4f} | recall: {val_clicked_metrics['recall']:.4f} | f1: {val_clicked_metrics['f1']:.4f} | auc: {val_clicked_metrics['auc']:.4f}")

        # Check for NaN or error values in validation metrics
        if (val_clicked_metrics['precision'] == 0 and val_clicked_metrics['recall'] == 0) or \
           torch.isnan(torch.tensor(avg_val_loss)):
            print("Warning: Detected zero precision/recall or NaN loss, adjusting model...")
            # Reduce learning rate and continue
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.5

            # If we've had multiple issues, break early
            if patience_counter > patience // 2:
                print("Training unstable, breaking early")
                break

        # Early stopping logic - based on F1 score, not just loss
        val_f1 = val_clicked_metrics['f1']
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered after {epoch} epochs")
                break

    # Load best model if early stopping was triggered
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"Restored best model with F1 score: {best_val_f1:.4f}")

    return model, history

In [157]:
def visualize_metrics(history):
    """Visualize training and validation metrics"""
    import matplotlib.pyplot as plt

    metrics = ['loss', 'precision', 'recall', 'f1', 'auc', 'opened_accuracy', 'clicked_accuracy']
    fig, axes = plt.subplots(len(metrics), 1, figsize=(12, 4*len(metrics)))

    for i, metric in enumerate(metrics):
        ax = axes[i]
        ax.plot(history['train'][metric], label=f'Train {metric}')
        ax.plot(history['val'][metric], label=f'Val {metric}')
        ax.set_title(f'{metric.capitalize()} vs. Epoch')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(metric.capitalize())
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    return fig

In [158]:
model = BalancedEmailClickModel()

In [159]:
# Implementation to find optimal threshold
def find_optimal_threshold(model, val_loader, device):
    """Find optimal threshold for classification by maximizing F1 score"""
    model.eval()
    all_true = []
    all_scores = []

    with torch.no_grad():
        for X, y in val_loader:
            X = X.to(device)
            y_click = y['clicked'].to(device)

            # Forward pass
            _, clicked_logits = model(X)

            # Collect predictions
            all_true.extend(y_click.cpu().numpy().flatten().tolist())
            all_scores.extend(torch.sigmoid(clicked_logits).cpu().numpy().flatten().tolist())

    # Try different thresholds
    best_f1 = 0
    best_threshold = 0.5

    for threshold in [i/100 for i in range(1, 100)]:
        y_pred = [1 if score > threshold else 0 for score in all_scores]
        f1 = f1_score(all_true, y_pred, zero_division=0)

        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold

    print(f"Optimal threshold: {best_threshold:.2f} with F1 score: {best_f1:.4f}")
    return best_threshold

In [162]:
model, history = train_model(model, train_ds, val_ds, epochs=150, batch_size=128, lr=3e-4, warmup_epochs=50, patience=20)

Class imbalance ratio (neg/pos): 46.83, using pos_weight: 20.00




Epoch 1/150
  Train | loss: 0.5058 | opened_acc: 0.8946 | clicked_acc: 1.0000
        | precision: 0.1488 | recall: 0.5247 | f1: 0.2318 | auc: 0.8564
  Val   | loss: 0.5574 | opened_acc: 0.8986 | clicked_acc: 1.0000
        | precision: 0.0687 | recall: 0.2012 | f1: 0.1024 | auc: 0.7506
Epoch 2/150
  Train | loss: 0.4920 | opened_acc: 0.8956 | clicked_acc: 1.0000
        | precision: 0.1823 | recall: 0.5531 | f1: 0.2742 | auc: 0.8746
  Val   | loss: 0.5560 | opened_acc: 0.8986 | clicked_acc: 1.0000
        | precision: 0.0720 | recall: 0.1921 | f1: 0.1047 | auc: 0.7529
Epoch 3/150
  Train | loss: 0.4796 | opened_acc: 0.8955 | clicked_acc: 1.0000
        | precision: 0.1940 | recall: 0.5987 | f1: 0.2930 | auc: 0.8886
  Val   | loss: 0.5557 | opened_acc: 0.8986 | clicked_acc: 1.0000
        | precision: 0.0758 | recall: 0.1738 | f1: 0.1056 | auc: 0.7539
Epoch 4/150
  Train | loss: 0.4669 | opened_acc: 0.8959 | clicked_acc: 1.0000
        | precision: 0.2237 | recall: 0.6248 | f1: 0.3295 