In [None]:
# CA8: Causal Reasoning and Multi-Modal Reinforcement Learning

This notebook explores advanced topics in deep reinforcement learning, focusing on:
1. Causal discovery and reasoning
2. Multi-modal environments
3. Causal RL agents
4. Interventional and counterfactual reasoning (illustrated in Section 2 by intervening on the agent's state)

## Setup and Imports

In [None]:
# Import required libraries
import sys
import os
# Put the current working directory on sys.path, presumably so the custom
# modules imported further down resolve. '__file__' is a string literal here,
# not the module attribute, so abspath() resolves it relative to the working
# directory and dirname() strips it back off again.
sys.path.append(os.path.dirname(os.path.abspath('__file__')))

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from tqdm import tqdm
import warnings
# NOTE(review): this silences every warning category for the rest of the
# session, including deprecation and numerical warnings; consider narrowing it.
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility (NumPy and PyTorch global generators)
np.random.seed(42)
torch.manual_seed(42)

# Import our custom modules
from causal_rl_utils import device
from causal_discovery import CausalGraph, CausalDiscovery
from causal_rl_agent import CausalRLAgent, CounterfactualRLAgent, CausalReasoningNetwork
from multi_modal_env import MultiModalGridWorld, MultiModalWrapper

print("Setup complete!")
print(f"Using device: {device}")

## Section 1: Causal Discovery

In this section, we explore methods for learning causal structure from observational data.

In [None]:
# Demonstrate Causal Graph functionality

def demonstrate_causal_graph():
    """Build a four-node example graph, print its structural queries and plot it.

    Returns:
        The CausalGraph that was built, with edges A->B, A->C, B->D and C->D.
    """
    print("=== Causal Graph Demonstration ===")

    # Diamond-shaped example: A feeds B and C, which both feed D
    node_names = ['A', 'B', 'C', 'D']
    graph = CausalGraph(node_names)
    for cause, effect in (('A', 'B'), ('A', 'C'), ('B', 'D'), ('C', 'D')):
        graph.add_edge(cause, effect)

    print(f"Variables: {graph.variables}")
    print(f"Graph structure: {graph}")

    # Whole-graph properties
    print(f"Is DAG: {graph.is_dag()}")
    print(f"Topological order: {graph.get_topological_order()}")

    # Per-node relationship queries
    print(f"Parents of D: {graph.get_parents('D')}")
    print(f"Children of A: {graph.get_children('A')}")
    print(f"Ancestors of D: {graph.get_ancestors('D')}")
    print(f"Descendants of A: {graph.get_descendants('A')}")

    # Drawing is best-effort: the whole path stays inside the try so that any
    # ImportError raised along the way falls back to the text message.
    try:
        import networkx as nx
        nx_graph = graph.to_networkx()
        layout = nx.spring_layout(nx_graph)
        draw_options = dict(with_labels=True, node_color='lightblue',
                            node_size=2000, font_size=16, arrows=True,
                            arrowsize=20)
        plt.figure(figsize=(8, 6))
        nx.draw(nx_graph, layout, **draw_options)
        plt.title("Causal Graph Visualization")
        plt.show()
    except ImportError:
        print("NetworkX not available for visualization")

    return graph

# Run demonstration
causal_graph = demonstrate_causal_graph()

In [None]:
# Demonstrate causal discovery algorithms

def demonstrate_causal_discovery():
    """Run each causal-discovery algorithm on synthetic data with a known graph.

    The data follows A -> B, A -> C, B -> D, C -> D: A is standard normal and
    every other variable is the sum of its parents plus Gaussian noise with
    standard deviation 0.5. The global NumPy generator is re-seeded with 42
    first, so the sample is the same on every call.

    Returns:
        dict: algorithm display name -> the graph that algorithm returned.
        Algorithms that raised are reported on stdout and left out.
    """
    print("=== Causal Discovery Demonstration ===")

    # Generate synthetic data with known causal structure
    np.random.seed(42)
    n_samples = 1000

    # True causal structure: A -> B -> D <- C <- A
    # (C is drawn before B; keep this order, it fixes the random stream.)
    A = np.random.normal(0, 1, n_samples)
    C = A + np.random.normal(0, 0.5, n_samples)
    B = A + np.random.normal(0, 0.5, n_samples)
    D = B + C + np.random.normal(0, 0.5, n_samples)

    # One column per variable, in the same order as var_names
    data = np.column_stack([A, B, C, D])
    var_names = ['A', 'B', 'C', 'D']

    print("Generated data with true causal structure: A -> B, A -> C, B -> D, C -> D")

    # Apply different discovery algorithms
    algorithms = {
        'PC Algorithm': CausalDiscovery.pc_algorithm,
        'GES Algorithm': CausalDiscovery.ges_algorithm,
        'LiNGAM': CausalDiscovery.lingam_algorithm
    }

    discovered_graphs = {}

    for name, algorithm in algorithms.items():
        # Best-effort: one failing algorithm must not abort the others.
        try:
            graph = algorithm(data, var_names)
            discovered_graphs[name] = graph
            print(f"\n{name} discovered structure:")
            print(graph)
        except Exception as e:
            print(f"\n{name} failed: {e}")

    return discovered_graphs

# Run causal discovery
discovered_graphs = demonstrate_causal_discovery()

## Section 2: Causal Reinforcement Learning

Now we implement RL agents that leverage causal structure for improved learning.

In [None]:
# Demonstrate Causal RL Agent

def demonstrate_causal_rl():
    """Train a CausalRLAgent on a toy grid world, then probe it with an intervention.

    Returns:
        dict with keys 'agent', 'environment', 'rewards' (one total per
        training episode) and 'causal_graph'.
    """
    print("=== Causal RL Agent Demonstration ===")

    class SimpleGridWorld:
        """Square grid whose reward is higher the closer the agent is to the centre."""

        def __init__(self, size=5):
            self.size = size
            self.state_dim = 2  # position only
            self.action_dim = 4  # up, down, left, right

        def reset(self):
            """Put the agent on a random cell and return (state, info)."""
            self.pos = np.random.randint(0, self.size, 2)
            return self.pos.astype(float), {}

        def step(self, action):
            """Apply one move; the episode never terminates on its own."""
            # Displacement per action index: up, down, left, right
            deltas = [(-1, 0), (1, 0), (0, -1), (0, 1)]
            # Moves that would leave the grid are clamped to the border
            self.pos = np.clip(self.pos + np.array(deltas[action]),
                               0, self.size - 1)

            # Negated distance to the centre, scaled by size * sqrt(2)
            half = self.size // 2
            offset = self.pos - np.array([half, half])
            reward = -np.linalg.norm(offset) / (self.size * np.sqrt(2))

            return self.pos.astype(float), reward, False, False, {}

    env = SimpleGridWorld()

    # Hand-written graph: both coordinates drive distance, distance drives reward
    graph_nodes = ['pos_x', 'pos_y', 'distance', 'reward']
    causal_graph = CausalGraph(graph_nodes)
    for cause, effect in (('pos_x', 'distance'),
                          ('pos_y', 'distance'),
                          ('distance', 'reward')):
        causal_graph.add_edge(cause, effect)

    print(f"Environment causal graph: {causal_graph}")

    agent = CausalRLAgent(
        state_dim=env.state_dim,
        action_dim=env.action_dim,
        causal_graph=causal_graph,
        lr=1e-3
    )

    # Training: 100 episodes of at most 20 steps, one update per transition
    print("\nTraining Causal RL Agent...")
    episode_returns = []

    for episode_idx in range(100):
        state, _ = env.reset()
        total = 0

        for _ in range(20):
            action, _ = agent.select_action(state)
            next_state, reward, done, _, _ = env.step(action)

            # Each update receives a single-transition batch
            agent.update([state], [action], [reward], [next_state], [done])

            total += reward
            state = next_state

            if done:
                break

        episode_returns.append(total)

        episodes_done = episode_idx + 1
        if episodes_done % 20 == 0:
            print(f"Episode {episodes_done:3d} | Avg Reward: {np.mean(episode_returns[-20:]):.3f}")

    # Compare the greedy action at the centre with the one after an intervention
    print("\nTesting causal interventions...")
    test_state = np.array([2.0, 2.0])  # Center position

    original_action, _ = agent.select_action(test_state, deterministic=True)
    print(f"Original state {test_state}: Action {original_action}")

    intervention = {'pos_x': 0.0, 'pos_y': 0.0}  # Move to corner
    intervened_state = agent.perform_intervention(test_state, intervention)
    intervened_action, _ = agent.select_action(intervened_state, deterministic=True)
    print(f"After intervention {intervention}: Action {intervened_action}")

    return {
        'agent': agent,
        'environment': env,
        'rewards': episode_returns,
        'causal_graph': causal_graph
    }

# Run demonstration
causal_rl_results = demonstrate_causal_rl()

## Section 3: Multi-Modal Environments

This section explores environments that provide multiple modalities of information.

In [None]:
# Demonstrate Multi-Modal Environment

def demonstrate_multi_modal_env():
    """Show the observation modalities of MultiModalGridWorld and the wrapper output.

    Resets an environment built with size=6, prints its three observation
    entries, renders up to six random steps in a 2x3 figure (stopping early
    if the episode ends) and finally runs the initial observation through
    MultiModalWrapper.

    Returns:
        tuple: (env, wrapper)
    """
    print("=== Multi-Modal Environment Demonstration ===")

    # Create multi-modal environment
    env = MultiModalGridWorld(size=6, render_size=84)

    # Reset and get observation
    obs, _ = env.reset()

    print("Observation modalities:")
    print(f"- Visual: {obs['visual'].shape} (RGB image)")
    print(f"- Text: {obs['text']['text']}")
    print(f"- State: {obs['state']} (agent position)")

    # Take some random actions and show observations
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    panels = list(axes.flat)  # row-major, same order as axes[i // 3, i % 3]

    for i, ax in enumerate(panels):
        action = np.random.randint(0, 4)
        next_obs, _, terminated, truncated, _ = env.step(action)

        ax.imshow(next_obs['visual'])
        ax.set_title(f"Step {i+1}: {next_obs['text']['text'][:30]}...")

        # Stop on either end-of-episode flag; the environment is not reset
        # here, so stepping any further would act on a finished episode.
        if terminated or truncated:
            break

    # Hide frame and ticks on every panel, including any left empty by an
    # early stop, so the figure never shows blank framed axes.
    for ax in panels:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    # Demonstrate multi-modal wrapper (on the initial observation)
    wrapper = MultiModalWrapper(env)
    processed_obs = wrapper.process_observation(obs)

    print(f"\nProcessed observation shape: {processed_obs.shape}")
    print(f"Feature breakdown:")
    print(f"- Visual features: {wrapper.visual_dim}")
    print(f"- Text features: {wrapper.text_dim}")
    print(f"- State features: {wrapper.state_dim}")

    return env, wrapper

# Run demonstration
mm_env, mm_wrapper = demonstrate_multi_modal_env()

## Section 4: Integrated Causal Multi-Modal RL

Combining causal reasoning with multi-modal perception for advanced RL.

In [None]:
# Integrated demonstration combining all components

def demonstrate_integrated_system():
    """Train a causal agent on wrapper-processed multi-modal observations and plot it.

    Returns:
        dict with keys 'agent', 'environment', 'wrapper', 'training_rewards'
        (one total per episode) and 'causal_graph'.
    """
    print("=== Integrated Causal Multi-Modal RL Demonstration ===")

    # Smaller grid and render size for faster training
    env = MultiModalGridWorld(size=4, render_size=64)
    wrapper = MultiModalWrapper(env)

    # Graph: every position variable feeds both feature nodes, and both
    # feature nodes feed the reward.
    variables = ['agent_x', 'agent_y', 'goal_x', 'goal_y', 'visual_features', 'text_features', 'reward']
    causal_graph = CausalGraph(variables)

    feature_nodes = ('visual_features', 'text_features')
    position_nodes = ('agent_x', 'agent_y', 'goal_x', 'goal_y')
    for feature in feature_nodes:
        for position in position_nodes:
            causal_graph.add_edge(position, feature)
    for feature in feature_nodes:
        causal_graph.add_edge(feature, 'reward')

    print(f"Causal graph for multi-modal RL: {causal_graph}")

    class MultiModalCausalRLAgent(CausalRLAgent):
        """CausalRLAgent that takes raw multi-modal observations as input."""

        def __init__(self, wrapper, causal_graph, lr=1e-3):
            self.wrapper = wrapper
            state_dim = wrapper.total_dim
            action_dim = 4  # grid world actions
            super().__init__(state_dim, action_dim, causal_graph, lr)

        def select_action(self, obs, deterministic=False):
            """Process the observation with the wrapper, then defer to the parent."""
            processed = self.wrapper.process_observation(obs)
            return super().select_action(processed, deterministic)

        def train_episode(self, env):
            """Roll out one episode, update once on all of it, return (reward, steps)."""
            obs, _ = env.reset()
            total_reward = 0
            n_steps = 0

            state_buf, action_buf, reward_buf, next_state_buf, done_buf = [], [], [], [], []

            while n_steps < env.max_steps:
                action, _ = self.select_action(obs)
                next_obs, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated

                # Transitions are stored in processed form
                state_buf.append(self.wrapper.process_observation(obs))
                action_buf.append(action)
                reward_buf.append(reward)
                next_state_buf.append(self.wrapper.process_observation(next_obs))
                done_buf.append(done)

                total_reward += reward
                n_steps += 1
                obs = next_obs

                if done:
                    break

            # Single update over the collected episode
            if state_buf:
                self.update(state_buf, action_buf, reward_buf, next_state_buf, done_buf)

            self.episode_rewards.append(total_reward)
            return total_reward, n_steps

    agent = MultiModalCausalRLAgent(wrapper, causal_graph, lr=1e-3)

    print("\nTraining Multi-Modal Causal RL Agent...")
    training_rewards = []

    for episode_idx in range(50):  # Shorter training for demo
        episode_total, steps = agent.train_episode(env)
        training_rewards.append(episode_total)

        episodes_done = episode_idx + 1
        if episodes_done % 10 == 0:
            print(f"Episode {episodes_done:2d} | Avg Reward: {np.mean(training_rewards[-10:]):.3f} | Steps: {steps}")

    # Left panel: learning curve; right panel: a freshly reset environment
    fig, (curve_ax, render_ax) = plt.subplots(1, 2, figsize=(12, 5))

    smoothed = pd.Series(training_rewards).rolling(5).mean()
    curve_ax.plot(training_rewards)
    curve_ax.plot(smoothed, color='red', label='Moving Average')
    curve_ax.set_title('Multi-Modal Causal RL Training')
    curve_ax.set_xlabel('Episode')
    curve_ax.set_ylabel('Episode Reward')
    curve_ax.legend()
    curve_ax.grid(True, alpha=0.3)

    obs, _ = env.reset()
    render_ax.imshow(obs['visual'])
    render_ax.set_title(f'Environment Render\n{obs["text"]["text"]}')
    render_ax.axis('off')

    plt.tight_layout()
    plt.show()

    return {
        'agent': agent,
        'environment': env,
        'wrapper': wrapper,
        'training_rewards': training_rewards,
        'causal_graph': causal_graph
    }

# Run integrated demonstration
integrated_results = demonstrate_integrated_system()

## Section 5: Comprehensive Experiments

Running comprehensive experiments to compare different approaches.

In [None]:
# Comprehensive experiments comparing different approaches

def run_comprehensive_experiments():
    """Train and compare four agent configurations on the multi-modal grid world.

    The configurations cross two switches:

    * ``use_causal``: when True the agent gets the hand-specified causal
      graph; when False it gets a graph over the same variables with no edges.
    * ``use_multi_modal``: when True the agent sees the wrapper's processed
      observation; when False it sees only the observation's ``'state'`` entry.

    Every configuration is trained by the same loop (same number of episodes,
    same step cap, one update per transition), so the observation encoding
    and the graph are the only things that differ between runs.

    Returns:
        dict: experiment name -> {'rewards': per-episode totals,
        'final_avg': mean of the last 10 totals, 'config': the switch dict}.
    """
    print("=== Comprehensive RL Experiments ===")

    n_episodes = 30  # Short training for demo
    max_steps = 10   # Step cap per episode, identical for every configuration
    action_dim = 4   # grid world actions

    # Create environments
    simple_env = MultiModalGridWorld(size=5, render_size=64)
    wrapper = MultiModalWrapper(simple_env)

    # Setup causal graph: positions -> each modality -> reward
    variables = ['agent_x', 'agent_y', 'goal_x', 'goal_y', 'visual', 'text', 'reward']
    modalities = ('visual', 'text')
    causal_graph = CausalGraph(variables)
    for modality in modalities:
        for source in ('agent_x', 'agent_y', 'goal_x', 'goal_y'):
            causal_graph.add_edge(source, modality)
    for modality in modalities:
        causal_graph.add_edge(modality, 'reward')

    # Baseline for the non-causal runs: same variables, no edges. Handing
    # these runs the full graph would make them identical to the causal ones.
    # NOTE(review): assumes CausalRLAgent accepts a graph without edges; confirm.
    empty_graph = CausalGraph(variables)

    def state_only(obs):
        """Return the observation's 'state' entry as a flat float vector."""
        return np.asarray(obs['state'], dtype=float).ravel()

    # The environment returns dict observations, so the state-only agents are
    # sized from a probe observation rather than from a hard-coded dimension.
    probe_obs, _ = simple_env.reset()
    state_dim = state_only(probe_obs).size

    def run_episode(agent, encode):
        """Run one training episode and return its total reward.

        ``encode`` maps a raw observation dict to the vector the agent sees.
        """
        obs, _ = simple_env.reset()
        state = encode(obs)
        episode_reward = 0.0
        for _ in range(max_steps):
            action, _ = agent.select_action(state)
            next_obs, reward, terminated, truncated, _ = simple_env.step(action)
            done = terminated or truncated
            next_state = encode(next_obs)
            agent.update([state], [action], [reward], [next_state], [done])
            episode_reward += reward
            state = next_state
            if done:
                break
        return episode_reward

    # Experiment configurations
    experiments = {
        'Standard RL': {'use_causal': False, 'use_multi_modal': False},
        'Multi-Modal RL': {'use_causal': False, 'use_multi_modal': True},
        'Causal RL': {'use_causal': True, 'use_multi_modal': False},
        'Causal Multi-Modal RL': {'use_causal': True, 'use_multi_modal': True}
    }

    results = {}

    for exp_name, config in experiments.items():
        print(f"\n--- Running {exp_name} ---")

        graph = causal_graph if config['use_causal'] else empty_graph
        if config['use_multi_modal']:
            encode, input_dim = wrapper.process_observation, wrapper.total_dim
        else:
            encode, input_dim = state_only, state_dim

        agent = CausalRLAgent(
            state_dim=input_dim, action_dim=action_dim, causal_graph=graph
        )

        # Train agent
        rewards = [run_episode(agent, encode) for _ in range(n_episodes)]

        results[exp_name] = {
            'rewards': rewards,
            'final_avg': np.mean(rewards[-10:]),
            'config': config
        }

        print(f"{exp_name}: Final Avg Reward = {results[exp_name]['final_avg']:.3f}")

    # Visualization
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))

    # Training curves
    for exp_name, result in results.items():
        axes[0].plot(result['rewards'], label=exp_name, linewidth=2)

    axes[0].set_title('Training Performance Comparison')
    axes[0].set_xlabel('Episode')
    axes[0].set_ylabel('Episode Reward')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Final performance bar chart
    exp_names = list(results.keys())
    final_scores = [results[name]['final_avg'] for name in exp_names]

    bars = axes[1].bar(exp_names, final_scores, color=['blue', 'green', 'red', 'purple'], alpha=0.7)
    axes[1].set_title('Final Performance Comparison')
    axes[1].set_ylabel('Average Reward (Last 10 Episodes)')
    axes[1].tick_params(axis='x', rotation=45)
    axes[1].grid(True, alpha=0.3)

    # Add value labels just past the end of each bar: above for non-negative
    # scores, below for negative ones, so the label never sits inside the bar.
    for bar, score in zip(bars, final_scores):
        is_up = score >= 0
        axes[1].text(bar.get_x() + bar.get_width() / 2,
                     bar.get_height() + (0.01 if is_up else -0.01),
                     f'{score:.3f}', ha='center',
                     va='bottom' if is_up else 'top')

    plt.tight_layout()
    plt.show()

    # Print summary
    print("\n=== Experiment Summary ===")
    for exp_name, result in results.items():
        config = result['config']
        causal_status = "✓" if config['use_causal'] else "✗"
        modal_status = "✓" if config['use_multi_modal'] else "✗"
        print(f"{exp_name:20s}: Causal={causal_status} Multi-Modal={modal_status} "
              f"Final Score={result['final_avg']:.3f}")

    return results

# Run comprehensive experiments
experiment_results = run_comprehensive_experiments()

## Conclusion

This notebook demonstrated:

1. **Causal Discovery**: Learning causal structure from data using PC, GES, and LiNGAM algorithms
2. **Causal RL Agents**: Agents that leverage causal reasoning for improved decision making
3. **Multi-Modal Environments**: Environments providing visual, textual, and state information
4. **Integrated Systems**: Combining causal reasoning with multi-modal perception

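Key insights:
- Causal reasoning can improve sample efficiency and interpretability
- Multi-modal information provides richer representations for learning
- Combining both approaches can lead to more robust and capable RL systems

Note that the training runs above are deliberately short (30–100 episodes with a single random seed), so the reward comparisons are illustrative rather than statistically conclusive.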

The modular design allows for easy extension and experimentation with different causal discovery methods, RL algorithms, and multi-modal architectures.