In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score

def irm_loss(logits, labels, env_id, penalty_weight=1e2):
    """
    Calculate IRM loss (IRMv1 objective)

    Input:
        logits: model output logits, shape (n_samples, n_classes)
        labels: true class indices, shape (n_samples,)
        env_id: environment ID per sample (tensor or array-like of ints),
            shape (n_samples,)
        penalty_weight: weight of the invariance penalty

    Output:
        total_loss: scalar tensor, pooled cross-entropy + penalty_weight * penalty

    Function:
        Calculate invariant risk minimization loss, including empirical risk and
        invariance penalty. The penalty is the IRMv1 term: for every environment
        e, the squared gradient of that environment's risk with respect to a
        dummy classifier scale w evaluated at w = 1.0, averaged over
        environments. The graph is kept (create_graph=True) so the penalty is
        differentiable with respect to the model parameters that produced
        `logits`.
    """
    # Empirical risk (cross entropy loss over all samples, pooled)
    criterion = nn.CrossEntropyLoss()
    empirical_risk = criterion(logits, labels)

    # Accept numpy arrays / lists too, and keep the mask on the logits' device.
    env_id = torch.as_tensor(env_id, device=logits.device)

    # Dummy scale w = 1.0: d risk_e(w * logits) / dw measures how far the
    # current classifier is from being simultaneously optimal in environment e.
    scale = torch.ones((), dtype=logits.dtype, device=logits.device,
                       requires_grad=True)

    penalties = []
    for e in torch.unique(env_id):
        # Every id returned by unique() occurs at least once, so the mask is non-empty.
        env_mask = env_id == e
        env_loss = criterion(logits[env_mask] * scale, labels[env_mask])
        grad = torch.autograd.grad(env_loss, scale, create_graph=True)[0]
        penalties.append(grad.pow(2))

    if penalties:
        penalty = torch.stack(penalties).mean()
    else:
        # No samples / environments: nothing to penalize (avoids 0 / 0).
        penalty = logits.new_zeros(())

    total_loss = empirical_risk + penalty_weight * penalty
    return total_loss

def train_irm_wilds(csv_file='wilds_reviews.csv', epochs=100, lr=0.001,
                    penalty_weight=1e2, seed=None):
    """
    Train IRM model on Wilds dataset

    Input:
        csv_file: dataset file path; the CSV must provide the columns
            text_feature_0 .. text_feature_9, rating (1-5), reviewer_id and
            split ('train' / 'test')
        epochs: number of full-batch training epochs (default 100)
        lr: Adam learning rate (default 0.001)
        penalty_weight: IRM penalty weight passed to irm_loss (default 1e2)
        seed: optional seed for torch's RNG, making weight init reproducible

    Output:
        model: trained model
        results: dict with per-epoch 'train_losses' and 'test_accuracies'

    Raises:
        ValueError: if any rating lies outside 1-5

    Function:
        Load Wilds dataset and train a 5-class rating classifier with the IRM
        objective, treating each reviewer as one environment. Features are
        standardized with statistics fitted on the training split only, and the
        test accuracy is evaluated after every epoch.
    """
    num_classes = 5

    # Load data
    df = pd.read_csv(csv_file)

    # Prepare features and labels
    feature_cols = [f'text_feature_{i}' for i in range(10)]
    X = df[feature_cols].to_numpy()
    y = df['rating'].to_numpy() - 1  # Convert 1-5 ratings to 0-4 class indices
    if ((y < 0) | (y >= num_classes)).any():
        raise ValueError(f"'rating' must lie in 1..{num_classes}")

    # Environment ID (reviewer ID). factorize maps every distinct reviewer to a
    # dense integer code, so non-numeric IDs also convert to a LongTensor.
    env_id, _ = pd.factorize(df['reviewer_id'])

    # Split data (plain boolean arrays for NumPy indexing)
    train_mask = (df['split'] == 'train').to_numpy()
    test_mask = (df['split'] == 'test').to_numpy()

    X_train, X_test = X[train_mask], X[test_mask]
    y_train, y_test = y[train_mask], y[test_mask]
    env_train = env_id[train_mask]

    # Standardize features (fit on train only so no test statistics leak in)
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    # Convert to PyTorch tensors
    X_train = torch.as_tensor(X_train, dtype=torch.float32)
    y_train = torch.as_tensor(y_train, dtype=torch.long)
    env_train = torch.as_tensor(env_train, dtype=torch.long)
    X_test = torch.as_tensor(X_test, dtype=torch.float32)
    y_test = torch.as_tensor(y_test, dtype=torch.long)

    # Define model
    class SimpleClassifier(nn.Module):
        def __init__(self, input_dim=10, hidden_dim=64, output_dim=5):
            super().__init__()
            self.network = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, output_dim)
            )

        def forward(self, x):
            return self.network(x)

    if seed is not None:
        torch.manual_seed(seed)
    model = SimpleClassifier(input_dim=len(feature_cols), output_dim=num_classes)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Train model (full-batch: every epoch uses the whole training split)
    train_losses = []
    test_accuracies = []

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        logits = model(X_train)
        loss = irm_loss(logits, y_train, env_train, penalty_weight=penalty_weight)

        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())

        # Evaluate
        model.eval()
        with torch.no_grad():
            test_logits = model(X_test)
            test_preds = torch.argmax(test_logits, dim=1)
            test_acc = accuracy_score(y_test.numpy(), test_preds.numpy())
            test_accuracies.append(test_acc)

        if epoch % 20 == 0:
            print(f'Epoch {epoch}, Loss: {loss.item():.4f}, Test Acc: {test_acc:.4f}')

    # Guard against epochs=0, where no accuracy was recorded.
    if test_accuracies:
        print(f"Final test accuracy: {test_accuracies[-1]:.4f}")

    return model, {'train_losses': train_losses, 'test_accuracies': test_accuracies}

# Script entry point: train only when run directly, never on import.
if __name__ == "__main__":
    print("\n=== Training IRM model on Wilds data ===")
    (irm_model, irm_results) = train_irm_wilds()