# Causal-Embeddings Training Notebook

This notebook implements a pipeline for training causal embeddings on a dataset of (cause, effect) pairs using PyTorch and e-CARE.

In [None]:
# Import necessary libraries
import os
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2Model, GPT2Config
import numpy as np
import pandas as pd
from datasets import load_dataset
from sklearn.model_selection import train_test_split
import random
import sklearn.metrics as M
import wandb
import torch.nn.functional as F # Often needed
import matplotlib.pyplot as plt

In [None]:
# ---------------------------------------------------------------------------
# Training configuration (shared by both training phases)
# The margin-based contrastive loss and the StepLR settings used previously
# were dropped in favour of InfoNCE and cosine annealing.
# ---------------------------------------------------------------------------

# Batching / schedule length
BATCH_SIZE = 128              # examples per optimisation step
NUM_EPOCHS = 20               # upper bound; early stopping usually ends training sooner
EARLY_STOPPING_PATIENCE = 2   # epochs without eval improvement before stopping

# Optimiser
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 0.01           # L2 regularisation strength

# Model / loss
DROPOUT_PROB = 0.2            # dropout applied to the pooled embedding
TEMPERATURE = 0.05            # similarity scale used by the InfoNCE loss

In [None]:
# Flag read by the model-initialisation cell: True loads the pre-trained GPT-2 weights,
# False builds the backbone from a default GPT2Config (random initialisation).
use_pretrained_embeds = True # set to False for random init

In [None]:
# Load the "train" split of the e-CARE dataset (12ml/e-CARE) via HuggingFace `datasets`
print("Loading dataset 12ml/e-CARE...")
ec_dataset = load_dataset("12ml/e-CARE", split="train")
# Show one raw record; later cells read its 'premise', 'choice1', 'choice2' and 'label' fields
print(f"Example record: {ec_dataset[0]}")

In [None]:
# Build the GPT-2 tokenizer and the backbone encoder
model_name = 'gpt2'
print(f"Model being used: {model_name}")
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# Pre-trained weights when the flag is set, otherwise a default-config GPT-2 with random weights
base_model = (
  GPT2Model.from_pretrained(model_name)
  if use_pretrained_embeds
  else GPT2Model(GPT2Config())
)

import torch.nn.functional as F
class CausalEmbeddingModel(torch.nn.Module):
  """Text encoder that mean-pools the backbone's last hidden states.

  Padding positions are excluded from the mean via the attention mask. The
  pooled vector is passed through dropout and then L2-normalized along dim 1.
  """

  def __init__(self, base_model, dropout_prob=DROPOUT_PROB):
    """Wrap `base_model` as the encoder and create the dropout layer."""
    super().__init__()
    self.encoder = base_model
    self.dropout = torch.nn.Dropout(dropout_prob)

  def forward(self, input_ids, attention_mask):
    """Return one L2-normalized embedding per input sequence.

    Args:
      input_ids: token ids, forwarded unchanged to the encoder.
      attention_mask: mask with one entry per token position; positions
        holding 0 do not contribute to the pooled mean.
    """
    hidden_states = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
    # Broadcast the mask over the hidden dimension so masked tokens contribute zero
    token_weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    pooled_sum = (hidden_states * token_weights).sum(dim=1)
    # Clamp guards against division by zero for an all-masked row
    token_counts = token_weights.sum(dim=1).clamp(min=1e-9)
    pooled = self.dropout(pooled_sum / token_counts)
    return F.normalize(pooled, p=2, dim=1)

# Pick the best available accelerator: CUDA GPU first, then Apple-silicon MPS, then CPU.
# torch.backends.mps.is_available() is the documented availability check; the
# torch.mps.is_available() shortcut is missing on many PyTorch versions.
if torch.cuda.is_available():
  device = torch.device("cuda")
elif torch.backends.mps.is_available():
  device = torch.device("mps")
else:
  device = torch.device("cpu")
model = CausalEmbeddingModel(base_model, dropout_prob=DROPOUT_PROB).to(device)

In [None]:
# InfoNCE Loss Function
class InfoNCELoss(torch.nn.Module):
  """In-batch InfoNCE loss over paired (cause, effect) embeddings.

  Row i of the similarity matrix holds the temperature-scaled cosine similarity
  between emb1[i] and every emb2[j] of the batch. For rows labelled 1, column i
  is the target class of a cross-entropy loss and every other column acts as a
  negative. Rows labelled 0 are never anchors; their emb2 only shows up as a
  negative column for the positive rows.
  """

  def __init__(self, temperature=TEMPERATURE):
    super().__init__()
    self.temperature = temperature
    self.criterion = torch.nn.CrossEntropyLoss()

  def forward(self, emb1, emb2, labels):
    """Compute the loss for one batch.

    Args:
      emb1: [batch_size, embed_dim] anchor embeddings (e.g. causes).
      emb2: [batch_size, embed_dim] paired embeddings (e.g. effects).
      labels: [batch_size], 1 for a positive pair and 0 for a negative pair.

    Returns:
      Scalar mean cross-entropy over the positive rows. If the batch has no
      positive row, a zero tensor with requires_grad=True is returned so that
      calling backward() on it is still valid.
    """
    positive_mask = (labels == 1)
    if not positive_mask.any():
      return torch.tensor(0.0, device=emb1.device, requires_grad=True)

    # Re-normalising is cheap and keeps the loss correct for un-normalised inputs
    emb1 = F.normalize(emb1, p=2, dim=1)
    emb2 = F.normalize(emb2, p=2, dim=1)

    # logits[i, j] = cos(emb1_i, emb2_j) / temperature
    logits = torch.matmul(emb1, emb2.T) / self.temperature

    # The true partner of row i sits in column i
    targets = torch.arange(emb1.size(0), device=emb1.device)

    # Only rows that are real positive pairs contribute to the loss
    return self.criterion(logits[positive_mask], targets[positive_mask])

criterion = InfoNCELoss(temperature=TEMPERATURE) # Use the new loss

In [None]:
# Creating a list of triples (cause, effect, label) from ec_dataset.
# Every record yields one positive pair, the record's own wrong choice as a
# negative, and one extra negative drawn at random from all effects.
# NOTE(review): every premise is treated as the cause; if the dataset marks some
# records as asking for the cause rather than the effect, the direction would be
# reversed for those records — confirm against the dataset schema.

def _draw_other_effect(effect_pool, excluded, rng, max_tries=32):
  """Return a uniformly random element of `effect_pool` that differs from `excluded`.

  Rejection sampling avoids rebuilding a filtered copy of the whole pool for
  every record, which made the original loop quadratic in the dataset size.
  If `max_tries` draws all hit `excluded`, fall back to the exact filtered list.

  Raises:
    IndexError: if the pool holds no element other than `excluded`.
  """
  for _ in range(max_tries):
    candidate = rng.choice(effect_pool)
    if candidate != excluded:
      return candidate
  return rng.choice([eff for eff in effect_pool if eff != excluded])

pairs = []
# Pool of all candidate effects used for the random negatives
all_effects = [item['choice1'] for item in ec_dataset] + [item['choice2'] for item in ec_dataset]
# Local seeded RNG keeps the sampled negatives reproducible without touching the
# global `random` state (same seed value as the train/eval split's random_state)
pair_rng = random.Random(42)
for item in ec_dataset:
  cause = item['premise']
  # label == 0 means choice1 is the correct effect, otherwise choice2
  if item['label'] == 0:
    effect_pos, effect_neg = item['choice1'], item['choice2']
  else:
    effect_pos, effect_neg = item['choice2'], item['choice1']
  # Positive pair
  pairs.append((cause, effect_pos, 1))
  # Negative pair (the record's own incorrect choice)
  pairs.append((cause, effect_neg, 0))
  # Random negative: any effect in the pool that is not the correct one
  pairs.append((cause, _draw_other_effect(all_effects, effect_pos, pair_rng), 0))

print(f"Total pairs: {len(pairs)}")
print('Example of positive pair:', pairs[0])
print('Example of negative pair:', pairs[1])
print('Number of positive pairs:', sum(1 for x in pairs if x[2] == 1))
print('Number of negative pairs:', sum(1 for x in pairs if x[2] == 0))

In [None]:
# Start the Weights & Biases run for Phase 1 and record the hyperparameters with it
wandb.init(
  project="causal-embeddings-ecare",
  config=dict(
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    dropout_prob=DROPOUT_PROB,
    temperature=TEMPERATURE,
    epochs=NUM_EPOCHS,
    batch_size=BATCH_SIZE,
    model_name=model_name,
    use_pretrained=use_pretrained_embeds,
    scheduler="CosineAnnealingLR",
    loss="InfoNCE",
    early_stopping_patience=EARLY_STOPPING_PATIENCE,
  ),
)
# Attach the notebook file to the run so the code can be recovered later
wandb.save('causal_embeddings.ipynb')

In [None]:
# Split into train/eval: 80/20 random split over individual (cause, effect, label) triples, fixed seed
# NOTE(review): the split is per triple, not per premise, so the positive and negative triples
# built from one premise can end up on both sides of the split — confirm this is acceptable.
train_data, eval_data = train_test_split(pairs, test_size=0.2, random_state=42)

class LabeledCausalDataset(Dataset):
  """Dataset of (cause, effect, label) triples tokenized to fixed-length tensors.

  Args:
    data: sequence of (cause_text, effect_text, label) triples.
    tokenizer: callable tokenizer with `pad_token` / `eos_token` attributes that
      returns 'input_ids' and 'attention_mask' tensors with a leading batch dim.
    max_length: every text is truncated / padded to this many tokens.
  """

  def __init__(self, data, tokenizer, max_length=64):
    self.data = data
    self.tokenizer = tokenizer
    self.max_length = max_length
    # padding='max_length' needs a pad token; reuse EOS once here instead of
    # re-checking on every item access
    if self.tokenizer.pad_token is None:
      self.tokenizer.pad_token = self.tokenizer.eos_token

  def __len__(self):
    return len(self.data)

  def _encode(self, text):
    """Tokenize one text and return (input_ids, attention_mask) without the batch dim."""
    enc = self.tokenizer(text, truncation=True, padding='max_length', max_length=self.max_length, return_tensors='pt')
    # squeeze(0) removes only the batch dimension; a bare squeeze() would also
    # collapse the sequence dimension when max_length == 1
    return enc['input_ids'].squeeze(0), enc['attention_mask'].squeeze(0)

  def __getitem__(self, idx):
    """Return tensors for both texts, the float label and the raw strings."""
    cause, effect, label = self.data[idx]
    input_ids1, attention_mask1 = self._encode(cause)
    input_ids2, attention_mask2 = self._encode(effect)
    return {
      'input_ids1': input_ids1,
      'attention_mask1': attention_mask1,
      'input_ids2': input_ids2,
      'attention_mask2': attention_mask2,
      'label': torch.tensor(label, dtype=torch.float),
      'cause_text': cause,
      'effect_text': effect
    }

# Wrap both splits as datasets and loaders; only the training loader reshuffles every epoch
train_dataset, eval_dataset = (
  LabeledCausalDataset(split, tokenizer) for split in (train_data, eval_data)
)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
eval_loader = DataLoader(eval_dataset, shuffle=False, batch_size=BATCH_SIZE)

# AdamW over all model parameters, with a cosine learning-rate schedule
# whose period is the maximum number of epochs
optimizer = torch.optim.AdamW(
  model.parameters(),
  lr=LEARNING_RATE,
  weight_decay=WEIGHT_DECAY,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
  optimizer,
  T_max=NUM_EPOCHS,
  eta_min=1e-7,
)

In [None]:
# Phase 1 training loop: InfoNCE loss, per-epoch AUC evaluation, early stopping, W&B logging
# NOTE(review): the best checkpoint is saved but never loaded back, so after early stopping
# `model` holds the weights of the last epoch, not the best one — confirm this is intended.
train_losses = [] # Store epoch average train losses
eval_aucs = []  # Store epoch evaluation AUCs
best_eval_auc = 0.0
epochs_no_improve = 0

print("Starting training...")
for epoch in range(NUM_EPOCHS):
  model.train()
  total_loss = 0
  batch_losses = [] # Track losses within an epoch

  for batch_idx, batch in enumerate(train_loader):
    labels = batch['label'].to(device)

    # InfoNCE needs at least one positive anchor. Check before running the
    # encoder so a skipped batch does not pay for two forward passes.
    if (labels == 1).sum() == 0:
      print(f"Skipping batch {batch_idx+1} due to no positive samples.")
      continue # Skip batch if no positive pairs exist

    optimizer.zero_grad()
    emb1 = model(batch['input_ids1'].to(device), batch['attention_mask1'].to(device))
    emb2 = model(batch['input_ids2'].to(device), batch['attention_mask2'].to(device))

    loss = criterion(emb1, emb2, labels)

    # Check for NaN loss
    if torch.isnan(loss):
      print(f"NaN loss detected at Epoch {epoch+1}, Batch {batch_idx+1}. Skipping batch.")
      # wandb.Table expects one row per example plus matching column names;
      # passing three column-wise lists without `columns` fails validation.
      problem_rows = [
        list(row)
        for row in zip(batch['cause_text'], batch['effect_text'], batch['label'].tolist())
      ]
      wandb.log({"problem_batch_data": wandb.Table(columns=["cause_text", "effect_text", "label"], data=problem_rows)})
      continue # Skip optimizer step and loss accumulation

    loss.backward()
    # Gradient clipping keeps a single bad batch from blowing up the update
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    optimizer.step()

    current_loss = loss.item()
    total_loss += current_loss
    batch_losses.append(current_loss)

    # Log batch loss and learning rate to W&B
    current_lr = optimizer.param_groups[0]['lr']
    wandb.log({
      "batch_loss": current_loss,
      "learning_rate": current_lr,
      "epoch": epoch + (batch_idx + 1) / len(train_loader) # Log fractional epoch
    })

    if (batch_idx + 1) % 50 == 0: # Print progress every 50 batches
      print(f"Epoch {epoch+1} Batch {batch_idx+1}/{len(train_loader)} - Loss {current_loss:.4f}, LR: {current_lr:.2e}")

  # Average over the batches that actually contributed a loss
  avg_loss = total_loss / len(batch_losses) if batch_losses else 0 # Avoid division by zero
  train_losses.append(avg_loss) # Append average loss for the epoch

  # Evaluation step: score every eval pair by cosine similarity
  model.eval()
  y_true, y_scores = [], []
  with torch.no_grad():
    for batch in eval_loader:
      emb1 = model(batch['input_ids1'].to(device), batch['attention_mask1'].to(device))
      emb2 = model(batch['input_ids2'].to(device), batch['attention_mask2'].to(device))
      sim = F.cosine_similarity(emb1, emb2, dim=1).cpu().numpy()
      y_scores.extend(sim.tolist())
      y_true.extend(batch['label'].numpy().tolist())

  # ROC AUC is undefined when only one class is present
  if len(np.unique(y_true)) > 1:
    auc = M.roc_auc_score(y_true, y_scores)
  else:
    print(f"Warning: Only one class present in evaluation labels for Epoch {epoch+1}. AUC cannot be calculated.")
    auc = 0.0 # Or handle as appropriate, e.g., np.nan

  eval_aucs.append(auc) # Append AUC for the epoch

  # Log epoch metrics to W&B
  wandb.log({
    "epoch": epoch + 1, # Log integer epoch for epoch-level metrics
    "train_loss_epoch": avg_loss,
    "eval_auc_epoch": auc
  })

  print(f"Epoch {epoch+1}, Avg Train Loss: {avg_loss:.4f}, Eval AUC: {auc:.4f}")

  # Early Stopping Check
  if auc > best_eval_auc:
    best_eval_auc = auc
    epochs_no_improve = 0
    # Save the best model checkpoint inside the W&B run directory
    best_model_path = os.path.join(wandb.run.dir, 'best_model.pth')
    torch.save(model.state_dict(), best_model_path)
    print(f"New best model saved with AUC: {best_eval_auc:.4f}")
    wandb.save(best_model_path) # Save best model to W&B
  else:
    epochs_no_improve += 1
    print(f"No improvement in Eval AUC for {epochs_no_improve} epoch(s).")

  if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
    print(f"Early stopping triggered after {epoch + 1} epochs.")
    break # Exit training loop

  # Step the scheduler (CosineAnnealingLR steps per epoch)
  scheduler.step()

print("Training finished.")
# Log final best AUC
wandb.summary["best_eval_auc"] = best_eval_auc

In [None]:
# Score the Phase 1 eval split once more with the current weights and record the AUC
model.eval()
y_true, y_scores = [], []
with torch.no_grad():
  for batch in eval_loader:
    emb1 = model(batch['input_ids1'].to(device), batch['attention_mask1'].to(device))
    emb2 = model(batch['input_ids2'].to(device), batch['attention_mask2'].to(device))
    # Cosine similarity of each (cause, effect) pair is the ranking score
    sim = torch.nn.functional.cosine_similarity(emb1, emb2, dim=1).cpu().numpy()
    y_scores += sim.tolist()
    y_true += batch['label'].numpy().tolist()
auc = M.roc_auc_score(y_true, y_scores)
print(f"Eval AUC after e-CARE training: {auc:.4f}")
# Stored in the summary of the currently active W&B run
wandb.summary["final_eval_auc_ecare"] = auc

In [None]:
# Persist the encoder's token-embedding matrix (wte) under ../models.
# Only this weight tensor is written, not the full model state.
models_dir = os.path.join(os.getcwd(), '..', 'models')
os.makedirs(models_dir, exist_ok=True)
embeds = model.encoder.wte.weight.data.cpu()
torch.save(embeds, os.path.join(models_dir, 'causal_embeds.pth'))
print("Causal embeddings saved to ../models/causal_embeds.pth")

#### Phase 2: Fine-tuning on Atomic and CNC

Now we will continue training the model using data from Atomic and CNC datasets. We will use the same model weights that were obtained after training on e-CARE.

In [None]:
# Build labelled (cause, effect, label) triples from the Atomic CSV:
# every observed (cause, effect) row is a positive, and each cause gets two
# randomly sampled non-effects per positive as negatives.
print("Loading and processing Atomic dataset...")
atomic_df = pd.read_csv('../datasets/atomic_causal_pairs.csv')
# Reassign instead of mutating in place so the cell stays idempotent
atomic_df = atomic_df.dropna(subset=['cause', 'effect'])

cause_to_effects = atomic_df.groupby('cause')['effect'].apply(set).to_dict()
all_atomic_effects_set = set(atomic_df['effect'].unique())
print(f"Found unique causes in Atomic: {len(cause_to_effects)}")
print(f"Found unique effects in Atomic: {len(all_atomic_effects_set)}")

atomic_pairs = []
for cause, effects in cause_to_effects.items():
    # Positive pairs
    atomic_pairs.extend((cause, effect_pos, 1) for effect_pos in effects)

    # Negative candidates: every effect that is NOT an effect of this cause
    total_negs = len(effects) * 2
    neg_cands = list(all_atomic_effects_set - effects)
    if len(neg_cands) >= total_negs:
        # Enough candidates: sample without replacement
        neg_samples = random.sample(neg_cands, total_negs)
    elif neg_cands:
        # Too few candidates: allow repeats
        neg_samples = random.choices(neg_cands, k=total_negs)
    else:
        # No candidate at all means every known effect is a true effect of this
        # cause; sampling from the full set would label true pairs as negatives.
        neg_samples = []

    # Two negatives per positive pair
    atomic_pairs.extend((cause, neg_effect, 0) for neg_effect in neg_samples)

print(f"Created pairs from Atomic: {len(atomic_pairs)}")
print('Example of positive pair (Atomic):', next((p for p in atomic_pairs if p[2] == 1), None))
print('Example of negative pair (Atomic):', next((p for p in atomic_pairs if p[2] == 0), None))
print('Number of positive pairs (Atomic):', sum(1 for x in atomic_pairs if x[2] == 1))
print('Number of negative pairs (Atomic):', sum(1 for x in atomic_pairs if x[2] == 0))

In [None]:
# Build labelled (cause, effect, label) triples from the CNC CSV.
# Only rows marked is_causal == 1 are kept as positives; negatives are sampled
# at random from the effects of other causes (two per positive).
# NOTE(review): rows with is_causal == 0 are discarded rather than used as
# negatives — confirm this is intended.
print("\nLoading and processing CNC dataset...")
cnc_df = pd.read_csv('../datasets/cnc_causal_pairs.csv')
# Reassign instead of mutating in place so the cell stays idempotent
cnc_df = cnc_df.dropna(subset=['cause', 'effect', 'is_causal'])
cnc_df = cnc_df[cnc_df['is_causal'] == 1]

cause_to_effects_cnc = cnc_df.groupby('cause')['effect'].apply(set).to_dict()
all_cnc_effects_set = set(cnc_df['effect'].unique())
print(f"Found unique causes in CNC: {len(cause_to_effects_cnc)}")
print(f"Found unique effects in CNC: {len(all_cnc_effects_set)}")

cnc_pairs = []
for cause, effects in cause_to_effects_cnc.items():
    # Positive pairs
    cnc_pairs.extend((cause, effect_pos, 1) for effect_pos in effects)

    # Negative candidates: every effect that is NOT an effect of this cause
    total_negs = len(effects) * 2
    neg_cands = list(all_cnc_effects_set - effects)
    if len(neg_cands) >= total_negs:
        # Enough candidates: sample without replacement
        neg_samples = random.sample(neg_cands, total_negs)
    elif neg_cands:
        # Too few candidates: allow repeats
        neg_samples = random.choices(neg_cands, k=total_negs)
    else:
        # No candidate at all means every known effect is a true effect of this
        # cause; sampling from the full set would label true pairs as negatives.
        neg_samples = []

    # Two negatives per positive pair
    cnc_pairs.extend((cause, neg_effect, 0) for neg_effect in neg_samples)

print(f"Created pairs from CNC: {len(cnc_pairs)}")
print('Example of positive pair (CNC):', next((p for p in cnc_pairs if p[2] == 1), None))
print('Example of negative pair (CNC):', next((p for p in cnc_pairs if p[2] == 0), None))
print('Number of positive pairs (CNC):', sum(1 for x in cnc_pairs if x[2] == 1))
print('Number of negative pairs (CNC):', sum(1 for x in cnc_pairs if x[2] == 0))

In [None]:
# Merge the Atomic and CNC triples, split them, and wrap the splits in DataLoaders
print("\nCombining data and preparing DataLoaders for Phase 2...")
combined_pairs_phase2 = [*atomic_pairs, *cnc_pairs]
random.shuffle(combined_pairs_phase2) # In-place shuffle of the merged list

print(f"Total pairs for Phase 2: {len(combined_pairs_phase2)}")
print('Number of positive pairs (Phase 2):', sum(1 for triple in combined_pairs_phase2 if triple[2] == 1))
print('Number of negative pairs (Phase 2):', sum(1 for triple in combined_pairs_phase2 if triple[2] == 0))

# 85/15 split with its own seed (Phase 1 used 80/20 with seed 42)
train_data_phase2, eval_data_phase2 = train_test_split(
    combined_pairs_phase2,
    test_size=0.15,
    random_state=43,
)

print(f"Training set size (Phase 2): {len(train_data_phase2)}")
print(f"Validation set size (Phase 2): {len(eval_data_phase2)}")

# Same dataset class and tokenizer as Phase 1; only the training loader shuffles
train_dataset_phase2, eval_dataset_phase2 = (
    LabeledCausalDataset(split, tokenizer)
    for split in (train_data_phase2, eval_data_phase2)
)
train_loader_phase2 = DataLoader(train_dataset_phase2, shuffle=True, batch_size=BATCH_SIZE)
eval_loader_phase2 = DataLoader(eval_dataset_phase2, shuffle=False, batch_size=BATCH_SIZE)

print("DataLoaders for Phase 2 are ready.")

In [None]:
# Phase 2 setup: fresh W&B run, fresh optimizer/scheduler, fresh early-stopping state

# Close a run that is still open so Phase 2 metrics go to their own run
if wandb.run is not None:
    print("Finishing previous W&B run...")
    try:
        wandb.finish()
    except Exception as e:
        print(f"Error finishing previous W&B run: {e}. Continuing...")

# Fine-tuning uses a fifth of the Phase 1 learning rate
phase2_lr = LEARNING_RATE / 5

print("\nInitializing W&B for Phase 2 training (Atomic + CNC)...")
wandb.init(
  project="causal-embeddings-atomic-cnc",
  config=dict(
    learning_rate=phase2_lr,
    weight_decay=WEIGHT_DECAY,
    dropout_prob=DROPOUT_PROB,
    temperature=TEMPERATURE,
    epochs=NUM_EPOCHS,                     # same epoch cap as Phase 1
    batch_size=BATCH_SIZE,
    model_name=model_name,
    use_pretrained=use_pretrained_embeds,  # how the backbone was initialised originally
    training_phase=2,
    scheduler="CosineAnnealingLR",
    loss="InfoNCE",
    early_stopping_patience=EARLY_STOPPING_PATIENCE,
  ),
)

# Attach the notebook file to the run so the code can be recovered later
wandb.save('causal_embeddings.ipynb')

# New optimizer and scheduler over the current (Phase 1-trained) model parameters
optimizer_phase2 = torch.optim.AdamW(
  model.parameters(),
  lr=phase2_lr,
  weight_decay=WEIGHT_DECAY,
)
scheduler_phase2 = torch.optim.lr_scheduler.CosineAnnealingLR(
  optimizer_phase2,
  T_max=NUM_EPOCHS,
  eta_min=1e-8,
)

# Metric history and early-stopping counters start from scratch for this phase
train_losses_phase2, eval_aucs_phase2 = [], []
best_eval_auc_phase2, epochs_no_improve_phase2 = 0.0, 0

# The `criterion` (InfoNCE) created for Phase 1 is reused unchanged

print("Optimizer, scheduler and variables for Phase 2 initialized.")

In [None]:
# Training loop for Phase 2 (Atomic + CNC)
# Same structure as Phase 1: InfoNCE loss, per-epoch AUC evaluation, early stopping.
# The backward pass, optimizer step and batch logging run inside the batch loop,
# i.e. once per batch.
print("\nStarting training Phase 2 (Atomic + CNC)...")
model.to(device) # Make sure the model lives on the training device
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss_phase2 = 0
    batch_losses_phase2 = []

    for batch_idx, batch in enumerate(train_loader_phase2):
        labels = batch['label'].to(device)

        # InfoNCE needs at least one positive anchor. Check before running the
        # encoder so a skipped batch does not pay for two forward passes.
        if (labels == 1).sum() == 0:
            print(f"Phase 2 - Skipping batch {batch_idx+1} due to no positive samples.")
            continue # Skip batch if no positive pairs exist

        optimizer_phase2.zero_grad()

        # Move batch data to device
        input_ids1 = batch['input_ids1'].to(device)
        attention_mask1 = batch['attention_mask1'].to(device)
        input_ids2 = batch['input_ids2'].to(device)
        attention_mask2 = batch['attention_mask2'].to(device)

        emb1 = model(input_ids1, attention_mask1)
        emb2 = model(input_ids2, attention_mask2)

        loss = criterion(emb1, emb2, labels)

        # Check for NaN loss
        if torch.isnan(loss):
            print(f"Phase 2 - NaN loss detected at Epoch {epoch+1}, Batch {batch_idx+1}. Skipping batch.")
            # Log the problematic batch as one table row per example
            try:
                problem_rows = [
                    list(row)
                    for row in zip(batch['cause_text'], batch['effect_text'], batch['label'].tolist())
                ]
                wandb.log({"phase2_problem_batch_data": wandb.Table(columns=["cause_text", "effect_text", "label"], data=problem_rows)})
            except Exception as e:
                print(f"Could not log problematic batch to W&B: {e}")
            continue

        loss.backward()
        # Gradient clipping keeps a single bad batch from blowing up the update
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer_phase2.step()

        current_loss = loss.item()
        total_loss_phase2 += current_loss
        batch_losses_phase2.append(current_loss)

        # Log batch loss and learning rate to W&B (Phase 2)
        current_lr_phase2 = optimizer_phase2.param_groups[0]['lr']
        wandb.log({
          "phase2_batch_loss": current_loss,
          "phase2_learning_rate": current_lr_phase2,
          "phase2_epoch_frac": epoch + (batch_idx + 1) / len(train_loader_phase2) # Log fractional epoch for phase 2
        })

        if (batch_idx + 1) % 100 == 0: # Print progress every 100 batches
            print(f"Phase 2 - Epoch {epoch+1} Batch {batch_idx+1}/{len(train_loader_phase2)} - Loss {current_loss:.4f}, LR: {current_lr_phase2:.2e}")

    # Average over the batches that actually contributed a loss
    avg_loss_phase2 = total_loss_phase2 / len(batch_losses_phase2) if batch_losses_phase2 else 0
    train_losses_phase2.append(avg_loss_phase2)

    # Evaluation step (Phase 2): score every eval pair by cosine similarity
    model.eval()
    y_true_phase2, y_scores_phase2 = [], []
    with torch.no_grad():
        for batch in eval_loader_phase2:
            # Move batch data to device
            input_ids1 = batch['input_ids1'].to(device)
            attention_mask1 = batch['attention_mask1'].to(device)
            input_ids2 = batch['input_ids2'].to(device)
            attention_mask2 = batch['attention_mask2'].to(device)

            emb1 = model(input_ids1, attention_mask1)
            emb2 = model(input_ids2, attention_mask2)
            sim = F.cosine_similarity(emb1, emb2, dim=1).cpu().numpy()
            y_scores_phase2.extend(sim.tolist())
            y_true_phase2.extend(batch['label'].numpy().tolist())

    # ROC AUC is undefined when only one class is present
    if len(np.unique(y_true_phase2)) > 1:
        auc_phase2 = M.roc_auc_score(y_true_phase2, y_scores_phase2)
    else:
        print(f"Phase 2 - Warning: Only one class present in evaluation labels for Epoch {epoch+1}. AUC cannot be calculated.")
        auc_phase2 = 0.0

    eval_aucs_phase2.append(auc_phase2)

    # Log epoch metrics to W&B (Phase 2)
    wandb.log({
        "phase2_epoch": epoch + 1,
        "phase2_train_loss_epoch": avg_loss_phase2,
        "phase2_eval_auc_epoch": auc_phase2
    })

    print(f"Phase 2 - Epoch {epoch+1}, Avg Train Loss: {avg_loss_phase2:.4f}, Eval AUC: {auc_phase2:.4f}")

    # Early Stopping Check (Phase 2)
    if auc_phase2 > best_eval_auc_phase2:
        best_eval_auc_phase2 = auc_phase2
        epochs_no_improve_phase2 = 0
        # Save the best model checkpoint for phase 2 inside the W&B run directory
        best_model_path_phase2 = os.path.join(wandb.run.dir, 'best_model_phase2.pth')
        torch.save(model.state_dict(), best_model_path_phase2)
        print(f"Phase 2 - New best model saved with AUC: {best_eval_auc_phase2:.4f}")
        wandb.save(best_model_path_phase2) # Save best model to W&B
    else:
        epochs_no_improve_phase2 += 1
        print(f"Phase 2 - No improvement in Eval AUC for {epochs_no_improve_phase2} epoch(s).")

    if epochs_no_improve_phase2 >= EARLY_STOPPING_PATIENCE:
        print(f"Phase 2 - Early stopping triggered after {epoch + 1} epochs.")
        break # Exit training loop

    # Step the scheduler (Phase 2)
    scheduler_phase2.step()

print("Training Phase 2 finished.")
# Log final best AUC for Phase 2
wandb.summary["phase2_best_eval_auc"] = best_eval_auc_phase2

In [None]:
# Score the Phase 2 eval split with the current weights and store the AUC in the run summary
print("\nFinal Evaluation on Phase 2 Eval Set...")
model.eval()
y_true_final_p2, y_scores_final_p2 = [], []
with torch.no_grad():
  for batch in eval_loader_phase2:
    # Only the encoder inputs go to the device; labels are read on the CPU
    input_ids1, attention_mask1, input_ids2, attention_mask2 = (
      batch[key].to(device)
      for key in ('input_ids1', 'attention_mask1', 'input_ids2', 'attention_mask2')
    )

    emb1 = model(input_ids1, attention_mask1)
    emb2 = model(input_ids2, attention_mask2)
    sim = torch.nn.functional.cosine_similarity(emb1, emb2, dim=1).cpu().numpy()
    y_scores_final_p2 += sim.tolist()
    y_true_final_p2 += batch['label'].numpy().tolist()

# ROC AUC needs both classes in the labels; otherwise record None
if len(np.unique(y_true_final_p2)) <= 1:
    print("Final Eval AUC (Phase 2): Could not be calculated (only one class in eval set).")
    wandb.summary["final_eval_auc_phase2"] = None
else:
    final_auc_p2 = M.roc_auc_score(y_true_final_p2, y_scores_final_p2)
    print(f"Final Eval AUC (Phase 2): {final_auc_p2:.4f}")
    wandb.summary["final_eval_auc_phase2"] = final_auc_p2

In [None]:
# Persist the encoder's token-embedding matrix (wte) after Phase 2, locally and on W&B
print("\nSaving final causal embeddings after Phase 2...")
final_models_dir = os.path.join(os.getcwd(), '..', 'models')
final_embeds_path = os.path.join(final_models_dir, 'causal_embeds_final_atomic_cnc.pth')
os.makedirs(final_models_dir, exist_ok=True)
embeds_final = model.encoder.wte.weight.data.cpu()
torch.save(embeds_final, final_embeds_path)
print(f"Final causal embeddings saved to {final_embeds_path}")
wandb.save(final_embeds_path) # Upload the saved tensor file to the active W&B run

In [None]:

# Visualization of Phase 2 results: train loss and eval AUC per epoch
print("\nPlotting Phase 2 training curves...")
# Early stopping may end training before NUM_EPOCHS, so use the recorded history length
actual_epochs_phase2 = len(train_losses_phase2)
epochs_axis_phase2 = range(1, actual_epochs_phase2 + 1)

fig_phase2, (ax_loss_phase2, ax_auc_phase2) = plt.subplots(1, 2, figsize=(12, 5))

if actual_epochs_phase2 > 0:
    ax_loss_phase2.plot(epochs_axis_phase2, train_losses_phase2, '-o')
ax_loss_phase2.set_title('Phase 2: Average Train Loss per Epoch')
ax_loss_phase2.set_xlabel('Epoch')
ax_loss_phase2.set_ylabel('Loss')
ax_loss_phase2.grid(True)

if actual_epochs_phase2 > 0:
    ax_auc_phase2.plot(epochs_axis_phase2, eval_aucs_phase2, '-o')
ax_auc_phase2.set_title('Phase 2: Evaluation AUC per Epoch')
ax_auc_phase2.set_xlabel('Epoch')
ax_auc_phase2.set_ylabel('AUC')
ax_auc_phase2.grid(True)

fig_phase2.tight_layout()

# Log the figure BEFORE plt.show(): once shown, the current pyplot figure may
# already be closed, in which case logging `plt` afterwards uploads an empty plot.
try:
    wandb.log({"phase2_training_curves": fig_phase2})
except Exception as e:
    print(f"Could not log plot to W&B: {e}")

plt.show()

# This cell closes the Phase 2 run (the Phase 1 run was finished before Phase 2 started)
print("Finishing W&B run for Phase 2...")
try:
    wandb.finish()
    print("Phase 2 W&B run finished.")
except Exception as e:
    print(f"Error finishing Phase 2 W&B run: {e}. Continuing...")

In [None]:
# Save the encoder's current token-embedding matrix (wte) to ../models/causal_embeds.pth.
# NOTE(review): this repeats the Phase 1 save cell with the same target path. When the
# notebook is run top to bottom this cell executes after Phase 2, so it overwrites the
# Phase 1 file with the Phase 2 weights — confirm that is intended, or remove the cell.
os.makedirs(os.path.join(os.getcwd(), '..', 'models'), exist_ok=True)
embeds = model.encoder.wte.weight.data.cpu()
torch.save(embeds, os.path.join(os.getcwd(), '..', 'models', 'causal_embeds.pth'))
print("Causal embeddings saved to ../models/causal_embeds.pth")


## How GPT-2 Embeddings Fine-Tuning Works

In this notebook, GPT-2 embeddings are fine-tuned using contrastive learning:

- For each pair (cause, effect, label), texts are tokenized using GPT2Tokenizer and passed through GPT2Model.
- The resulting embeddings (averaged across tokens) are used as representations of the cause and effect.
- The InfoNCE loss (`InfoNCELoss`) takes each positive pair (a real causal relationship) as the target and uses the other effects in the same batch as negatives: it raises the similarity of positive pairs and lowers the similarity to the in-batch alternatives.
- Gradients propagate only through GPT2Model parameters (including the embedding layer), which leads to fine-tuning of embeddings for the causal distinction task.
- After training, GPT-2 embeddings become more sensitive to causal relationships, which can be used for integration into downstream tasks or attention modification.

# Next: Improving Convergence
- Add a projection head (MLP) on top of embeddings: GPT2 embedding → Linear → ReLU → Linear, then L2-normalize.

In [None]:
# Plot the Phase 1 training curves (train loss and eval AUC per epoch).
# NOTE(review): this cell sits after the Phase 2 cells. When the notebook is run top to
# bottom, both W&B runs are already finished here, so the upload below is skipped;
# consider moving the cell directly after the Phase 1 training loop.

# Early stopping may end training before NUM_EPOCHS, so use the recorded history length
actual_epochs = len(train_losses)
epochs_axis = range(1, actual_epochs + 1)

fig_phase1, (ax_loss, ax_auc) = plt.subplots(1, 2, figsize=(12, 5))

ax_loss.plot(epochs_axis, train_losses, '-o')
ax_loss.set_title('Average Train Loss per Epoch')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.grid(True)

ax_auc.plot(epochs_axis, eval_aucs, '-o')
ax_auc.set_title('Evaluation AUC per Epoch')
ax_auc.set_xlabel('Epoch')
ax_auc.set_ylabel('AUC')
ax_auc.grid(True)

fig_phase1.tight_layout()

# Log before plt.show() (the figure may be closed afterwards), and only when a
# run is active: wandb.log() raises if no run has been initialised.
if wandb.run is not None:
    wandb.log({"training_curves": fig_phase1})
else:
    print("No active W&B run; skipping upload of the Phase 1 training curves.")

plt.show()

# Finish the W&B run for Phase 1 here
print("Finishing W&B run for Phase 1...")
try:
    wandb.finish()
    print("Phase 1 W&B run finished.")
except Exception as e:
    print(f"Error finishing Phase 1 W&B run: {e}. Continuing...")