# In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score

# Fourth code segment: Implementing IRM + Bayesian Uncertainty on FastShip data
class BayesianIRM(nn.Module):
    """
    Bayesian IRM Model

    Input:
        input_dim: input feature dimension
        hidden_dim: hidden layer dimension
        output_dim: output dimension

    Output:
        Prediction results and uncertainty estimates

    Function:
        Regression model combining causal invariance and Bayesian uncertainty.
        A shared feature extractor feeds two linear heads: one predicts the
        mean and the other the log-variance. Dropout on the shared features
        is sampled several times per forward pass and the heads' outputs are
        averaged (Monte Carlo dropout).
    """

    # Bounds applied to the predicted log-variance before exponentiation.
    # exp(10) ~ 2.2e4 and exp(-10) ~ 4.5e-5, which keeps the variance finite
    # and strictly positive so a Gaussian NLL built on it cannot hit
    # log(0), division by zero or inf.
    LOG_VAR_MIN = -10.0
    LOG_VAR_MAX = 10.0

    def __init__(self, input_dim=3, hidden_dim=32, output_dim=1):
        super().__init__()

        # Feature extraction layer
        self.feature_extractor = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # Prediction layer (mean)
        self.predictor = nn.Linear(hidden_dim, output_dim)

        # Uncertainty layer (log variance)
        self.uncertainty = nn.Linear(hidden_dim, output_dim)

        # Dropout for Bayesian approximation
        self.dropout = nn.Dropout(0.1)

    def forward(self, x, num_samples=5):
        """
        Run `num_samples` dropout samples and average them.

        Input:
            x: input tensor whose last dimension is `input_dim`
            num_samples: number of dropout samples to average (must be >= 1)

        Output:
            pred_mean: mean of the sampled predictions
            pred_var: mean of the sampled variances (always finite and > 0)

        Raises:
            ValueError: if `num_samples` is smaller than 1

        Note:
            In eval mode nn.Dropout is the identity, so every sample is
            identical and the result equals a single deterministic pass.
        """
        if num_samples < 1:
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")

        # Features are computed once; dropout is applied only inside the
        # sampling loop so each sample gets one independent mask at the
        # configured rate (no extra mask shared by every sample).
        features = self.feature_extractor(x)

        predictions = []
        variances = []

        for _ in range(num_samples):
            # Apply dropout for Bayesian approximation
            sampled_features = self.dropout(features)

            predictions.append(self.predictor(sampled_features))

            # Clamp before exp() so the variance cannot overflow or vanish
            log_var = self.uncertainty(sampled_features).clamp(
                self.LOG_VAR_MIN, self.LOG_VAR_MAX
            )
            variances.append(torch.exp(log_var))

        # Average prediction and uncertainty over the samples
        pred_mean = torch.stack(predictions).mean(dim=0)
        pred_var = torch.stack(variances).mean(dim=0)

        return pred_mean, pred_var

def bayesian_irm_loss(pred_mean, pred_var, targets, env_id, penalty_weight=1e2, eps=1e-6):
    """
    Bayesian IRM loss function

    Input:
        pred_mean: predicted mean
        pred_var: predicted variance
        targets: true target values
        env_id: environment ID of every sample (indexes the first dimension)
        penalty_weight: weight of the invariance penalty
        eps: lower bound applied to pred_var for numerical stability

    Output:
        total_loss: scalar tensor, nll + penalty_weight * penalty

    Raises:
        ValueError: if pred_mean and targets hold a different number of elements

    Function:
        Loss function combining Bayesian uncertainty and causal invariance.
        The first term is the Gaussian negative log likelihood (constant
        dropped). The second term is the IRMv1 penalty averaged over the
        environments: the squared gradient of each environment's risk with
        respect to a dummy scale w (evaluated at w = 1) that multiplies the
        predictions.
    """
    # A column vector of predictions against flat targets would silently
    # broadcast to an (N, N) matrix; align the shapes instead.
    if pred_mean.shape != targets.shape:
        if pred_mean.numel() != targets.numel():
            raise ValueError(
                f"pred_mean shape {tuple(pred_mean.shape)} is incompatible "
                f"with targets shape {tuple(targets.shape)}"
            )
        pred_mean = pred_mean.reshape(targets.shape)
    if pred_var.shape != targets.shape and pred_var.numel() == targets.numel():
        pred_var = pred_var.reshape(targets.shape)

    # Floor the variance so log() and the division below stay finite
    pred_var = pred_var.clamp_min(eps)

    residual = pred_mean - targets

    # Bayesian negative log likelihood (Gaussian likelihood)
    nll = (0.5 * torch.log(pred_var) + 0.5 * residual ** 2 / pred_var).mean()

    # IRM penalty term.
    # For the per-environment risk R_e(w) = mean(0.5 * (y - w * pred)^2),
    # dR_e/dw at w = 1 equals mean((pred - y) * pred). Using this closed form
    # keeps the penalty differentiable through ordinary autograd and also
    # works under torch.no_grad().
    penalty = pred_mean.new_zeros(())
    envs = torch.unique(env_id)

    for e in envs:
        env_mask = env_id == e
        scale_grad = torch.mean(residual[env_mask] * pred_mean[env_mask])
        penalty = penalty + scale_grad ** 2

    # envs is empty only for an empty batch; the penalty then stays at zero
    if len(envs) > 0:
        penalty = penalty / len(envs)

    total_loss = nll + penalty_weight * penalty

    return total_loss

def _masked_mse(pred, target, mask):
    """Return the MSE over the entries selected by `mask`, or NaN if none are selected."""
    if not mask.any():
        return float('nan')
    return float(np.mean((pred[mask] - target[mask]) ** 2))


def train_bayesian_irm_fastship(csv_file='fastship_data.csv', epochs=100, lr=0.01,
                                hidden_dim=16, penalty_weight=1e2, log_every=20,
                                seed=None):
    """
    Train Bayesian IRM model on FastShip dataset

    Input:
        csv_file: dataset file path; the file must provide the columns
            'economic_index', 'trend', 'spurious_power', 'orders',
            'environment' and 'split' (with values 'train' / 'test')
        epochs: number of full-batch training epochs (must be >= 1)
        lr: Adam learning rate
        hidden_dim: hidden layer width of the model
        penalty_weight: weight of the IRM penalty in the loss
        log_every: print progress every `log_every` epochs (0 disables it)
        seed: optional seed passed to torch.manual_seed for reproducibility

    Output:
        model: trained model
        results: dict with 'train_losses', 'test_mses', 'weekday_mse',
            'holiday_mse' and 'uncertainty_correlation'. All MSEs are in
            standardized target units. A per-environment MSE is NaN when the
            test split contains no sample of that environment.

    Raises:
        ValueError: if `epochs` < 1 or the train / test split is empty

    Function:
        Load FastShip dataset and train order prediction model using Bayesian IRM algorithm
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")
    if seed is not None:
        torch.manual_seed(seed)

    # Load data
    df = pd.read_csv(csv_file)

    # Prepare features and labels
    feature_cols = ['economic_index', 'trend', 'spurious_power']
    X = df[feature_cols].values
    y = df['orders'].values.astype(np.float32)

    # Environment ID
    env_id = df['environment'].values.astype(np.int64)

    # Split data (plain boolean arrays so NumPy indexing is unambiguous)
    train_mask = (df['split'] == 'train').to_numpy()
    test_mask = (df['split'] == 'test').to_numpy()
    if not train_mask.any():
        raise ValueError("dataset contains no rows with split == 'train'")
    if not test_mask.any():
        raise ValueError("dataset contains no rows with split == 'test'")

    X_train, X_test = X[train_mask], X[test_mask]
    y_train, y_test = y[train_mask], y[test_mask]
    env_train, env_test = env_id[train_mask], env_id[test_mask]

    # Standardize features and targets (statistics come from the train split only)
    X_scaler = StandardScaler()
    y_scaler = StandardScaler()

    X_train = X_scaler.fit_transform(X_train)
    X_test = X_scaler.transform(X_test)

    y_train = y_scaler.fit_transform(y_train.reshape(-1, 1)).flatten()
    y_test = y_scaler.transform(y_test.reshape(-1, 1)).flatten()

    # Convert to PyTorch tensors
    X_train = torch.as_tensor(X_train, dtype=torch.float32)
    y_train = torch.as_tensor(y_train, dtype=torch.float32)
    env_train = torch.as_tensor(env_train, dtype=torch.long)
    X_test = torch.as_tensor(X_test, dtype=torch.float32)
    y_test = torch.as_tensor(y_test, dtype=torch.float32)
    env_test = torch.as_tensor(env_test, dtype=torch.long)

    # Define model
    model = BayesianIRM(input_dim=len(feature_cols), hidden_dim=hidden_dim, output_dim=1)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Train model
    train_losses = []
    test_mses = []

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        pred_mean, pred_var = model(X_train)
        # squeeze(-1) drops only the output dimension, so a single-row batch
        # keeps its batch dimension
        loss = bayesian_irm_loss(
            pred_mean.squeeze(-1),
            pred_var.squeeze(-1),
            y_train,
            env_train,
            penalty_weight=penalty_weight,
        )

        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())

        # Evaluate
        model.eval()
        with torch.no_grad():
            test_pred_mean, _ = model(X_test)
            test_mse = torch.mean((test_pred_mean.squeeze(-1) - y_test) ** 2)
            test_mses.append(test_mse.item())

        if log_every and epoch % log_every == 0:
            print(f'Epoch {epoch}, Loss: {loss.item():.4f}, Test MSE: {test_mse.item():.4f}')

    print(f"Final test MSE: {test_mses[-1]:.4f}")

    # Analyze performance across different environments
    model.eval()
    with torch.no_grad():
        test_pred_mean, test_pred_var = model(X_test)
        test_pred = test_pred_mean.reshape(-1).numpy()
        y_test_np = y_test.numpy()
        env_test_np = env_test.numpy()

        # Weekday environment
        weekday_mse = _masked_mse(test_pred, y_test_np, env_test_np == 0)

        # Holiday environment (OoD)
        holiday_mse = _masked_mse(test_pred, y_test_np, env_test_np == 1)

        print(f"Weekday test MSE: {weekday_mse:.4f}")
        print(f"Holiday test MSE: {holiday_mse:.4f}")
        print(f"OoD performance gap: {holiday_mse - weekday_mse:.4f}")

        # Calculate uncertainty calibration
        residuals = np.abs(test_pred - y_test_np)
        uncertainties = np.sqrt(test_pred_var.reshape(-1).numpy())
        # Pearson correlation is undefined for fewer than two points or a
        # constant series; report NaN without triggering NumPy warnings
        if residuals.size < 2 or np.std(residuals) == 0 or np.std(uncertainties) == 0:
            uncertainty_correlation = float('nan')
        else:
            uncertainty_correlation = float(np.corrcoef(residuals, uncertainties)[0, 1])
        print(f"Uncertainty-residual correlation: {uncertainty_correlation:.4f}")

    return model, {
        'train_losses': train_losses,
        'test_mses': test_mses,
        'weekday_mse': weekday_mse,
        'holiday_mse': holiday_mse,
        'uncertainty_correlation': uncertainty_correlation
    }

# Script entry point: train the model on the default FastShip CSV file
if __name__ == "__main__":
    print("\n=== Training Bayesian IRM model on FastShip data ===")
    bayesian_irm_model, bayesian_results = train_bayesian_irm_fastship(
        csv_file='fastship_data.csv'
    )