# In [None]:
# Multi-Agent Football RL - Interactive Demo
# This notebook demonstrates training, evaluation, and visualization

# Cell 1: Setup and Imports
# ========================

import sys
import os
sys.path.append('..')

import numpy as np
import torch
import matplotlib.pyplot as plt
from IPython.display import clear_output, HTML
import time

from env.football_env import FootballEnv
from models.ppo_agent import PPOAgent
from training.buffer import MultiAgentBuffer
from visualization.heatmap import HeatmapGenerator
from visualization.pass_network import PassNetworkAnalyzer

# Report that the imports above succeeded and which PyTorch/CUDA setup is active.
print("✓ Imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")


# Cell 2: Create Environment
# ===========================

# Build the environment from an explicit settings mapping so the demo's
# configuration is visible in one place.
env_settings = {
    'num_agents_per_team': 3,
    'grid_width': 12,
    'grid_height': 8,
    'max_steps': 200,
    'render_mode': 'ansi',
}
env = FootballEnv(**env_settings)

# Read the observation length and the number of discrete actions from the
# first agent's spaces.
# NOTE(review): assumes every agent shares the same spaces - confirm.
reference_agent = env.agents[0]
obs_dim = env.observation_space(reference_agent).shape[0]
action_dim = env.action_space(reference_agent).n

for summary_line in (
    "Environment created!",
    f"Observation dim: {obs_dim}",
    f"Action dim: {action_dim}",
    f"Agents: {env.agents}",
):
    print(summary_line)

# Reset once and print the starting layout as text.
observations, _ = env.reset()
print("\nInitial State:")
print(env._render_ansi())


# Cell 3: Create Agents
# ======================

# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Both teams are built from the same settings; sharing one kwargs dict keeps
# the two constructions from drifting apart.
ppo_kwargs = dict(
    obs_dim=obs_dim,
    action_dim=action_dim,
    device=device,
    lr=3e-4,
    gamma=0.99,
    gae_lambda=0.95,
)

# One PPO agent per team (each is a separate instance).
team_0_agent = PPOAgent(**ppo_kwargs)
team_1_agent = PPOAgent(**ppo_kwargs)

print("✓ Agents created successfully!")
# Parameter counts are reported for team 0's actor and critic networks.
print(f"Actor parameters: {sum(p.numel() for p in team_0_agent.actor.parameters()):,}")
print(f"Critic parameters: {sum(p.numel() for p in team_0_agent.critic.parameters()):,}")


# Cell 4: Random Policy Baseline
# ===============================

def run_episode(env, policy='random'):
    """Play a single episode and report team 0's accumulated reward.

    Args:
        env: Environment providing ``reset()``, ``agents``,
            ``action_space()``, ``agent_iter()``, ``step()`` and
            ``episode_stats``.
        policy: ``'random'`` samples each action from the agent's action
            space. Any other value queries the module-level
            ``team_0_agent`` / ``team_1_agent`` with ``deterministic=True``.

    Returns:
        ``(total_reward, steps, env.episode_stats)``. ``total_reward`` sums
        the rewards returned for agents whose name contains ``'team_0'``;
        ``steps`` counts completed passes of the outer loop.
    """
    observations, _ = env.reset()
    team_0_return = 0
    rounds = 0
    finished = False

    while not finished:
        # Choose every agent's action up front from its latest observation.
        chosen = {}
        for name in env.agents:
            current_obs = observations[name]
            if policy == 'random':
                chosen[name] = env.action_space(name).sample()
                continue
            # Agent names containing 'team_0' are driven by team 0's policy.
            controller = team_0_agent if 'team_0' in name else team_1_agent
            picked, _, _, _ = controller.get_action(current_obs, deterministic=True)
            chosen[name] = int(picked)

        # Apply the chosen actions in the order the environment yields agents.
        # NOTE(review): if agent_iter() keeps yielding agents until the
        # episode ends (PettingZoo AEC semantics), the same actions are
        # replayed every round - confirm it yields each agent once.
        for name in env.agent_iter():
            next_obs, reward, terminated, truncated, _info = env.step(chosen[name])
            observations[name] = next_obs

            if 'team_0' in name:
                team_0_return += reward

            if terminated or truncated:
                finished = True
                break

        rounds += 1

    return team_0_return, rounds, env.episode_stats

# Baseline: play 10 episodes with uniformly sampled actions and record
# team 0's total reward for each.
print("Testing random policy...")
rewards = []
for i in range(10):
    reward, steps, stats = run_episode(env, 'random')
    rewards.append(reward)
    if i == 0:
        # Show the first episode in detail as an example.
        print("\nExample episode:")
        print(f"  Total reward: {reward:.2f}")
        print(f"  Steps: {steps}")
        print(f"  Score: {stats['goals_team_0']}-{stats['goals_team_1']}")
        print(f"  Passes: {stats['successful_passes']}/{stats['passes']}")

print(f"\nRandom policy average reward: {np.mean(rewards):.2f} ± {np.std(rewards):.2f}")


# Cell 5: Training Loop (Mini Version)
# ====================================

def train_mini(num_episodes=100, update_interval=10):
    """Run a short two-team PPO training loop for demonstration.

    Uses the module-level ``env``, ``team_0_agent``, ``team_1_agent`` and
    ``obs_dim``. Every ``update_interval`` episodes both agents are updated
    from their team's buffer, the buffers are cleared, and team 0's win rate
    is measured over 10 episodes of ``run_episode(env, 'trained')``.

    Args:
        num_episodes: Number of training episodes to collect.
        update_interval: Number of episodes between updates/evaluations.

    Returns:
        ``(rewards_history, win_rate_history)``: the summed team-0 reward of
        every training episode, and one win rate per update.
    """
    # Imported once up front instead of on every update inside the loop.
    from training.buffer import DummyBuffer

    # One buffer per team. The 3 matches num_agents_per_team used when the
    # notebook builds env.
    # NOTE(review): 2048 is presumably the per-agent capacity - confirm
    # against MultiAgentBuffer before raising update_interval.
    buffer_0 = MultiAgentBuffer(3, 2048, obs_dim)
    buffer_1 = MultiAgentBuffer(3, 2048, obs_dim)

    def _agent_buffer(name):
        """Return the per-agent buffer for an agent name like '<team>_<idx>'."""
        team_buffer = buffer_0 if 'team_0' in name else buffer_1
        return team_buffer.buffers[int(name.split('_')[-1])]

    rewards_history = []
    win_rate_history = []
    eval_episodes = 10

    print("Starting mini training...")

    for episode in range(num_episodes):
        # Collect episode
        observations, _ = env.reset()
        done = False
        episode_reward = 0

        while not done:
            actions = {}
            # Remember who received a transition this round so the episode
            # end can be recorded for all of them.
            acting_agents = list(env.agents)

            for agent in acting_agents:
                obs = observations[agent]
                team_id = 0 if 'team_0' in agent else 1
                agent_obj = team_0_agent if team_id == 0 else team_1_agent

                action, log_prob, value, _ = agent_obj.get_action(obs)
                actions[agent] = action

                # Store the transition with a placeholder reward (0) and
                # done flag (False); both are patched after env.step().
                agent_idx = int(agent.split('_')[-1])
                buffer = buffer_0 if team_id == 0 else buffer_1
                buffer.add(agent_idx, obs, action, 0, value, log_prob, False)

            # Step
            # NOTE(review): if agent_iter() keeps yielding agents until the
            # episode ends (PettingZoo AEC semantics), the same actions are
            # replayed every round - confirm it yields each agent once.
            for agent in env.agent_iter():
                observation, reward, termination, truncation, info = env.step(actions[agent])
                observations[agent] = observation

                # Patch the reward into the transition stored just above.
                slot = _agent_buffer(agent)
                slot.rewards[slot.ptr - 1] = reward

                if 'team_0' in agent:
                    episode_reward += reward

                if termination or truncation:
                    # The episode is over for every agent, not only the one
                    # whose step ended it. Mark all of their last stored
                    # transitions as terminal; previously only the stepping
                    # agent was flagged, leaving the other trajectories
                    # without an episode boundary in the buffer.
                    for name in acting_agents:
                        last = _agent_buffer(name)
                        last.dones[last.ptr - 1] = 1
                    done = True
                    break

        rewards_history.append(episode_reward)

        # Update
        if (episode + 1) % update_interval == 0:
            # Compute returns
            # NOTE(review): the zeros are presumably per-agent bootstrap
            # values - confirm against MultiAgentBuffer.
            buffer_0.compute_returns_and_advantages([0, 0, 0])
            buffer_1.compute_returns_and_advantages([0, 0, 0])

            # Update agents
            data_0 = buffer_0.get_all_training_data()
            data_1 = buffer_1.get_all_training_data()

            team_0_agent.update(DummyBuffer(*data_0))
            team_1_agent.update(DummyBuffer(*data_1))

            # Clear buffers
            buffer_0.clear()
            buffer_1.clear()

            # Evaluate: count episodes in which team 0 outscores team 1.
            wins = 0
            for _ in range(eval_episodes):
                stats = run_episode(env, 'trained')[2]
                if stats['goals_team_0'] > stats['goals_team_1']:
                    wins += 1

            win_rate = wins / eval_episodes
            win_rate_history.append(win_rate)

            # Print progress. Clear *before* printing: with wait=True the
            # old output is replaced only when new output arrives, so the
            # latest progress line stays visible instead of being wiped by
            # whatever is printed next.
            avg_reward = np.mean(rewards_history[-update_interval:])
            clear_output(wait=True)
            print(f"Episode {episode+1}/{num_episodes} | "
                  f"Avg Reward: {avg_reward:.2f} | "
                  f"Win Rate: {win_rate:.1%}")

    return rewards_history, win_rate_history

# Run mini training: 100 episodes, updating and evaluating every 10.
rewards_hist, winrate_hist = train_mini(num_episodes=100, update_interval=10)

print("\n✓ Training complete!")


# Cell 6: Visualize Training Progress
# ====================================

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 4))

# Left panel: raw per-episode reward plus a moving average.
ax1.plot(rewards_hist, alpha=0.3, label='Raw')
window = 10
# np.convolve(mode='valid') yields len(rewards_hist) - window + 1 points only
# when the series is at least `window` long. For a shorter series the result
# no longer lines up with the x range below (and an empty series raises), so
# the smoothed curve is skipped in that case.
if len(rewards_hist) >= window:
    smoothed = np.convolve(rewards_hist, np.ones(window)/window, mode='valid')
    # The first full window ends at episode index window-1.
    ax1.plot(range(window-1, len(rewards_hist)), smoothed, label=f'Smoothed ({window})')
ax1.set_xlabel('Episode')
ax1.set_ylabel('Total Reward')
ax1.set_title('Training Reward')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Right panel: evaluation win rate. One value is recorded per update, so the
# x positions are multiples of 10, matching the update_interval=10 passed to
# train_mini in this notebook.
episodes = [(i+1)*10 for i in range(len(winrate_hist))]
ax2.plot(episodes, winrate_hist, marker='o')
ax2.axhline(y=0.5, color='r', linestyle='--', alpha=0.5, label='50% baseline')
ax2.set_xlabel('Episode')
ax2.set_ylabel('Win Rate')
ax2.set_title('Evaluation Win Rate')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


# Cell 7: Visualize Trained Agent
# =================================

print("Running trained agent...")

observations, _ = env.reset()
done = False
step = 0

# Per-step snapshots of player and ball positions.
states = []

while not done and step < 50:  # Limit to 50 steps for notebook
    print(f"\n{'='*50}")
    print(f"Step {step}")
    print(env._render_ansi())

    # Save state. Each position is copied individually (as the heatmap cell
    # does with pos.copy()): copying only the dict would keep references to
    # the environment's own position objects, so earlier snapshots could
    # change if the environment updates those objects in place.
    states.append({
        'positions': {name: pos.copy() for name, pos in env.agent_positions.items()},
        'ball': env.ball_position.copy()
    })

    # Pick every agent's action deterministically from its team's policy.
    actions = {}
    for agent in env.agents:
        obs = observations[agent]
        team_id = 0 if 'team_0' in agent else 1
        agent_obj = team_0_agent if team_id == 0 else team_1_agent
        action, _, _, _ = agent_obj.get_action(obs, deterministic=True)
        actions[agent] = int(action)

    # Step
    for agent in env.agent_iter():
        observation, reward, termination, truncation, info = env.step(actions[agent])
        observations[agent] = observation

        if termination or truncation:
            done = True
            break

    step += 1
    time.sleep(0.5)  # Slow down for visibility

print(f"\n{'='*50}")
print("Episode Complete!")
print(f"Final Score: {env.episode_stats['goals_team_0']}-{env.episode_stats['goals_team_1']}")
print(f"Passes: {env.episode_stats['successful_passes']}/{env.episode_stats['passes']}")


# Cell 8: Analyze Behavior
# =========================

print("Analyzing learned behavior...")

# Number of episodes played for each of the two policies being compared.
num_eval_episodes = 20

# Run multiple episodes with the trained agents and keep their stats.
episode_data = []
for i in range(num_eval_episodes):
    _, _, stats = run_episode(env, 'trained')
    episode_data.append(stats)

# Compute statistics
avg_passes = np.mean([e['passes'] for e in episode_data])
avg_successful = np.mean([e['successful_passes'] for e in episode_data])
avg_shots = np.mean([e['shots'] for e in episode_data])
# A win is an episode in which team 0 scored more goals than team 1.
wins = sum(1 for e in episode_data if e['goals_team_0'] > e['goals_team_1'])

print(f"\nBehavior Analysis ({num_eval_episodes} episodes):")
print(f"  Win rate: {wins/num_eval_episodes:.1%}")
print(f"  Avg passes: {avg_passes:.1f}")
print(f"  Pass success rate: {avg_successful/avg_passes:.1%}" if avg_passes > 0 else "  No passes")
print(f"  Avg shots: {avg_shots:.1f}")
print(f"  Avg goals: {np.mean([e['goals_team_0'] for e in episode_data]):.1f}")

# Compare to random
random_data = []
for i in range(num_eval_episodes):
    _, _, stats = run_episode(env, 'random')
    random_data.append(stats)

random_passes = np.mean([e['passes'] for e in random_data])
random_wins = sum(1 for e in random_data if e['goals_team_0'] > e['goals_team_1'])

print("\nRandom Policy Baseline:")
print(f"  Win rate: {random_wins/num_eval_episodes:.1%}")
print(f"  Avg passes: {random_passes:.1f}")

# The '+' format flag prints the real sign of each difference; a hard-coded
# "+" prefix rendered regressions as "+-x".
print("\nImprovement:")
print(f"  Win rate: {(wins-random_wins)/num_eval_episodes*100:+.1f}%")
print(f"  Passes: {avg_passes-random_passes:+.1f}")


# Cell 9: Visualize Position Heatmap
# ===================================

# Collect position data: one list of visited positions per agent,
# accumulated over 50 deterministic episodes of the trained agents.
position_data = {agent: [] for agent in env.agents}

for _ in range(50):
    observations, _ = env.reset()
    done = False

    while not done:
        # Record where every agent is before the next round of actions.
        # Each position is copied so later changes cannot alter the record.
        for agent, pos in env.agent_positions.items():
            position_data[agent].append(pos.copy())

        actions = {}
        for agent in env.agents:
            obs = observations[agent]
            team_id = 0 if 'team_0' in agent else 1
            agent_obj = team_0_agent if team_id == 0 else team_1_agent
            action, _, _, _ = agent_obj.get_action(obs, deterministic=True)
            actions[agent] = int(action)

        for agent in env.agent_iter():
            observation, reward, termination, truncation, info = env.step(actions[agent])
            observations[agent] = observation

            if termination or truncation:
                done = True
                break

# Plot heatmaps: a 2x3 grid gives one panel per agent for the 6 agents
# (3 per team) configured in this notebook.
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for idx, agent in enumerate(env.agents):
    positions = np.array(position_data[agent])

    # Create 2D histogram with one bin per grid cell. Column 1 is binned
    # along the grid height (image rows) and column 0 along the grid width.
    # NOTE(review): presumably positions are stored as (x, y) - confirm.
    heatmap, xedges, yedges = np.histogram2d(
        positions[:, 1], positions[:, 0],
        bins=[env.grid_height, env.grid_width],
        range=[[0, env.grid_height], [0, env.grid_width]]
    )

    axes[idx].imshow(heatmap, origin='lower', cmap='YlOrRd', aspect='auto')
    axes[idx].set_title(agent)
    axes[idx].set_xlabel('X')
    axes[idx].set_ylabel('Y')

plt.tight_layout()
plt.show()

print("✓ Heatmaps generated!")


# Cell 10: Summary and Next Steps
# ================================

# Closing banner, followed by a recap of the values computed in the training
# and behaviour-analysis cells.
print("="*60)
print("DEMO COMPLETE!")
print("="*60)

print("\n📊 Results Summary:")
print(f"  - Trained for {len(rewards_hist)} episodes")
print(f"  - Final win rate: {winrate_hist[-1]:.1%}")
print(f"  - Learned to pass: {avg_passes:.1f} passes per episode")
# The ratio is only formatted when passes were attempted, which avoids
# dividing by zero.
print(f"  - Pass success rate: {avg_successful/avg_passes:.1%}" if avg_passes > 0 else "")

print("\n🚀 Next Steps:")
print("  1. Train for more episodes (10,000+) for better results")
print("  2. Use curriculum learning: python training/train_ppo.py --curriculum")
print("  3. Analyze pass networks: python visualization/pass_network.py")
print("  4. Experiment with reward shaping in football_env.py")
print("  5. Try self-play training against archived checkpoints")

print("\n📚 Resources:")
print("  - Full training: training/train_ppo.py")
print("  - Visualization: visualization/")
print("  - Config files: configs/")
print("  - Documentation: README.md")

print("\n💡 Tips:")
print("  - Monitor training with TensorBoard: tensorboard --logdir runs/")
print("  - Save this trained model: torch.save(...)")
print("  - Load checkpoint: checkpoint = torch.load('model.pt')")

print("\nHappy training! ⚽🤖")