In [1]:
# Set GPU runtime first: Runtime > Change Runtime Type > GPU

# Check GPU
!nvidia-smi

Sat Jul  5 16:44:15 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   48C    P8             10W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [3]:
# Install PyTorch (keep existing CUDA version)
!pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0

# Install compatible gymnasium and ALE
!pip install gymnasium>=1.0.0
!pip install ale-py>=0.10.1

# Install AutoROM
!pip install autorom[accept-rom-license]

# Install other dependencies
!pip install stable-baselines3==2.0.0
!pip install tensorboard
!pip install moviepy



In [11]:
# Accept license and install ROMs
# Fix the ROM installation
!AutoROM --accept-license --install-dir /usr/local/lib/python3.11/dist-packages/ale_py/roms

import ale_py
import os

print('ALE version:', ale_py.__version__)

# Correct way to get ROM path for ale-py 0.8.1
try:
    # Method 1: Check AutoROM installation path
    autorom_path = "/usr/local/lib/python3.11/dist-packages/AutoROM/roms"
    if os.path.exists(autorom_path):
        print(f"AutoROM path exists: {autorom_path}")
        roms = os.listdir(autorom_path)
        print(f"ROMs in AutoROM: {roms[:5]}...")  # Show first 5

    # Method 2: Check ale_py roms path
    ale_py_path = "/usr/local/lib/python3.11/dist-packages/ale_py/roms"
    if os.path.exists(ale_py_path):
        print(f"ALE-py path exists: {ale_py_path}")
        roms = os.listdir(ale_py_path)
        print(f"ROMs in ale_py: {roms[:5]}...")  # Show first 5

    # Method 3: Try to find asteroids ROM specifically
    for search_path in [autorom_path, ale_py_path]:
        if os.path.exists(search_path):
            for file in os.listdir(search_path):
                if 'asteroids' in file.lower():
                    print(f"✅ Found Asteroids ROM: {file} in {search_path}")

except Exception as e:
    print(f"ROM path check failed: {e}")

# Method 4: Try direct ALE interface to see if ROMs are accessible
try:
    ale = ale_py.ALEInterface()
    available_roms = ale.getAvailableRoms()
    print(f"Available ROMs via ALE: {available_roms}")

    if 'asteroids' in available_roms:
        print("✅ Asteroids is available via ALE!")
    else:
        print("❌ Asteroids not found in available ROMs")

except Exception as e:
    print(f"ALE interface check failed: {e}")

AutoROM will download the Atari 2600 ROMs.
They will be installed to:
	/usr/local/lib/python3.11/dist-packages/ale_py/roms

Existing ROMs will be overwritten.
ALE version: 0.8.1
AutoROM path exists: /usr/local/lib/python3.11/dist-packages/AutoROM/roms
ROMs in AutoROM: ['hangman.bin', 'assault.bin', 'gopher.bin', 'beam_rider.bin', 'demon_attack.bin']...
ALE-py path exists: /usr/local/lib/python3.11/dist-packages/ale_py/roms
ROMs in ale_py: ['hangman.bin', 'plugins.py', 'assault.bin', 'gopher.bin', 'beam_rider.bin']...
✅ Found Asteroids ROM: asteroids.bin in /usr/local/lib/python3.11/dist-packages/AutoROM/roms
✅ Found Asteroids ROM: asteroids.bin in /usr/local/lib/python3.11/dist-packages/ale_py/roms
ALE interface check failed: 'ale_py._ale_py.ALEInterface' object has no attribute 'getAvailableRoms'


## Environment check first

In [13]:
import gymnasium as gym
import ale_py
import warnings
import os
warnings.filterwarnings('ignore')

print("🔍 Testing Asteroids environment with corrected API...")

# First, let's see what environments are actually registered (fixed API)
print("Registered environments:")
try:
    # Correct way to access registry in gymnasium 1.0+
    all_env_ids = list(gym.envs.registry.keys())
    env_ids = [env_id for env_id in all_env_ids if 'asteroid' in env_id.lower()]
    print(f"Asteroid-related envs: {env_ids}")
    print(f"Total registered environments: {len(all_env_ids)}")
except Exception as e:
    print(f"Registry access failed: {e}")
    env_ids = []

# Method 1: Try to register ALE environments
try:
    print("\n--- Method 1: gym.register_envs(ale_py) ---")
    gym.register_envs(ale_py)
    print("✅ ALE environments registered")

    # Check what got registered after ALE registration
    all_env_ids = list(gym.envs.registry.keys())
    ale_envs = [env_id for env_id in all_env_ids if 'ALE' in env_id and 'asteroid' in env_id.lower()]
    print(f"ALE Asteroid environments: {ale_envs}")

    # Also check for any asteroid environments
    asteroid_envs = [env_id for env_id in all_env_ids if 'asteroid' in env_id.lower()]
    print(f"All Asteroid environments: {asteroid_envs}")

    if ale_envs:
        env_name = ale_envs[0]
        print(f"Trying to create: {env_name}")
        env = gym.make(env_name)
        print(f"✅ Successfully created: {env_name}")
        obs, info = env.reset()
        print(f"✅ Observation shape: {obs.shape}")
        print(f"✅ Action space: {env.action_space}")
        env.close()
        working_env = env_name
    elif asteroid_envs:
        env_name = asteroid_envs[0]
        print(f"Trying alternative: {env_name}")
        env = gym.make(env_name)
        print(f"✅ Successfully created: {env_name}")
        obs, info = env.reset()
        env.close()
        working_env = env_name
    else:
        raise Exception("No Asteroid environments found after registration")

except Exception as e:
    print(f"❌ Method 1 failed: {e}")

    # Method 2: Try direct ALE interface
    try:
        print("\n--- Method 2: Direct ALE interface ---")

        ale = ale_py.ALEInterface()
        print("✅ ALE interface created")

        # Check available ROMs
        available_roms = ale.getAvailableRoms()
        print(f"Available ROMs: {available_roms}")

        if 'asteroids' in available_roms:
            print("✅ Asteroids ROM is available!")

            # Try to load it
            ale.loadROM('asteroids')
            print("✅ Successfully loaded Asteroids ROM")

            # Get basic info
            actions = ale.getAvailableActions()
            print(f"✅ Available actions: {actions}")

            # Get screen
            screen = ale.getScreen()
            print(f"✅ Screen shape: {screen.shape}")

            working_env = 'direct_ale_asteroids'
        else:
            print(f"❌ Asteroids not in available ROMs: {available_roms}")
            working_env = None

    except Exception as e2:
        print(f"❌ Method 2 failed: {e2}")

        # Method 3: ROM file system check
        try:
            print("\n--- Method 3: File system ROM check ---")

            possible_paths = [
                "/usr/local/lib/python3.11/dist-packages/AutoROM/roms",
                "/usr/local/lib/python3.11/dist-packages/ale_py/roms",
                "/usr/local/lib/python3.11/dist-packages/atari_py/atari_roms",
                "/root/.local/share/ale-py/roms"
            ]

            rom_found = False
            for path in possible_paths:
                print(f"Checking: {path}")
                if os.path.exists(path):
                    files = os.listdir(path)
                    print(f"  Files found: {len(files)}")

                    # Look for asteroid-related files
                    asteroid_files = [f for f in files if 'asteroid' in f.lower()]
                    if asteroid_files:
                        print(f"  ✅ Asteroid ROM files: {asteroid_files}")
                        rom_found = True

                    # Show first few files as sample
                    if files:
                        print(f"  Sample files: {files[:3]}...")
                else:
                    print(f"  ❌ Path doesn't exist")

            if rom_found:
                working_env = 'roms_found_filesystem'
            else:
                working_env = None

        except Exception as e3:
            print(f"❌ Method 3 failed: {e3}")
            working_env = None

print(f"\n🎯 Final result: {working_env}")

# Additional debugging info
print("\n📋 Debug Summary:")
try:
    import ale_py
    print(f"ALE-py version: {ale_py.__version__}")

    import gymnasium
    print(f"Gymnasium version: {gymnasium.__version__}")

    # Try to create any ALE environment to test basic functionality
    print("\nTesting basic ALE functionality...")
    ale = ale_py.ALEInterface()
    roms = ale.getAvailableRoms()
    print(f"Total ROMs available: {len(roms)}")

    if roms:
        # Try loading any available ROM
        test_rom = roms[0]
        print(f"Testing with ROM: {test_rom}")
        ale.loadROM(test_rom)
        print(f"✅ Successfully loaded {test_rom}")

except Exception as e:
    print(f"Debug info failed: {e}")

if working_env:
    print("\n🚀 Environment setup successful! Ready for CleanRL.")
else:
    print("\n⚠️ Need to investigate ROM installation further.")

🔍 Testing Asteroids environment with corrected API...
Registered environments:
Asteroid-related envs: []
Total registered environments: 50

--- Method 1: gym.register_envs(ale_py) ---
❌ Method 1 failed: module 'gymnasium' has no attribute 'register_envs'

--- Method 2: Direct ALE interface ---
✅ ALE interface created
❌ Method 2 failed: 'ale_py._ale_py.ALEInterface' object has no attribute 'getAvailableRoms'

--- Method 3: File system ROM check ---
Checking: /usr/local/lib/python3.11/dist-packages/AutoROM/roms
  Files found: 110
  ✅ Asteroid ROM files: ['asteroids.bin']
  Sample files: ['hangman.bin', 'assault.bin', 'gopher.bin']...
Checking: /usr/local/lib/python3.11/dist-packages/ale_py/roms
  Files found: 113
  ✅ Asteroid ROM files: ['asteroids.bin']
  Sample files: ['hangman.bin', 'plugins.py', 'assault.bin']...
Checking: /usr/local/lib/python3.11/dist-packages/atari_py/atari_roms
  ❌ Path doesn't exist
Checking: /root/.local/share/ale-py/roms
  ❌ Path doesn't exist

🎯 Final result:

In [2]:
import os
import random
import time
from dataclasses import dataclass
from typing import Optional

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter


@dataclass
class Args:
    """Hyperparameters for Asteroids DQN"""
    exp_name: str = "asteroids_dqn"
    seed: int = 1
    torch_deterministic: bool = True
    cuda: bool = True

    # Environment
    env_id: str = "AsteroidsNoFrameskip-v4"

    # Training
    total_timesteps: int = 1000000
    learning_rate: float = 1e-4
    buffer_size: int = 100000
    gamma: float = 0.99
    target_network_frequency: int = 1000
    batch_size: int = 32
    start_e: float = 1.0
    end_e: float = 0.01
    exploration_fraction: float = 0.10
    learning_starts: int = 100000
    train_frequency: int = 4


class CustomQNetwork(nn.Module):
    """Custom CNN for Asteroids - you can modify this architecture!"""

    def __init__(self, env, hidden_size=512):
        super().__init__()
        self.env = env

        # CNN Feature Extractor - CUSTOMIZE THIS!
        self.network = nn.Sequential(
            # First conv layer: detect basic shapes
            nn.Conv2d(4, 32, 8, stride=4),
            nn.ReLU(),

            # Second conv layer: detect movement patterns
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),

            # Third conv layer: complex spatial relationships
            nn.Conv2d(64, 64, 3, stride=1),
            nn.ReLU(),

            nn.Flatten(),
        )

        # Calculate conv output size
        with torch.no_grad():
            sample_input = torch.zeros(1, 4, 84, 84)
            conv_output_size = self.network(sample_input).shape[1]

        # Value head - CUSTOMIZE THIS!
        self.value_head = nn.Sequential(
            nn.Linear(conv_output_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, env.single_action_space.n)
        )

    def forward(self, x):
        features = self.network(x / 255.0)  # Normalize pixels
        return self.value_head(features)


def make_env(env_id, seed, idx, capture_video, run_name):
    """Create Asteroids environment with proper wrappers"""
    def thunk():
        if capture_video and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        else:
            env = gym.make(env_id)

        env = gym.wrappers.RecordEpisodeStatistics(env)
        env = NoopResetEnv(env, noop_max=30)
        env = MaxAndSkipEnv(env, skip=4)
        env = EpisodicLifeEnv(env)

        if "FIRE" in env.unwrapped.get_action_meanings():
            env = FireResetEnv(env)

        env = ClipRewardEnv(env)
        env = gym.wrappers.ResizeObservation(env, (84, 84))
        env = gym.wrappers.GrayScaleObservation(env)
        env = gym.wrappers.FrameStack(env, 4)

        env.action_space.seed(seed)
        return env
    return thunk


def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
    """Linear decay for exploration"""
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)


def train_asteroids_dqn():
    """Main training function"""
    args = Args()

    # Setup
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    writer = SummaryWriter(f"runs/{run_name}")

    # Seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # Environment
    envs = gym.vector.SyncVectorEnv([
        make_env(args.env_id, args.seed, 0, True, run_name)
    ])

    # Networks
    q_network = CustomQNetwork(envs, hidden_size=512).to(device)
    optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
    target_network = CustomQNetwork(envs, hidden_size=512).to(device)
    target_network.load_state_dict(q_network.state_dict())

    # Experience Replay
    rb = ReplayBuffer(
        args.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        device,
        handle_timeout_termination=False,
    )

    # Training Loop
    obs, _ = envs.reset(seed=args.seed)
    episode_rewards = []
    episode_lengths = []

    for global_step in range(args.total_timesteps):
        # Exploration rate
        epsilon = linear_schedule(
            args.start_e, args.end_e,
            args.exploration_fraction * args.total_timesteps,
            global_step
        )

        # Action selection
        if random.random() < epsilon:
            actions = np.array([envs.single_action_space.sample()])
        else:
            q_values = q_network(torch.Tensor(obs).to(device))
            actions = torch.argmax(q_values, dim=1).cpu().numpy()

        # Environment step
        next_obs, rewards, terminations, truncations, infos = envs.step(actions)

        # Track metrics
        if "final_info" in infos:
            for info in infos["final_info"]:
                if info and "episode" in info:
                    episode_rewards.append(info["episode"]["r"])
                    episode_lengths.append(info["episode"]["l"])
                    writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                    writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
                    print(f"Step {global_step}, Episode reward: {info['episode']['r']}")

        # Store experience
        real_next_obs = next_obs.copy()
        for idx, d in enumerate(terminations):
            if d:
                real_next_obs[idx] = infos["final_observation"][idx]

        rb.add(obs, real_next_obs, actions, rewards, terminations, infos)
        obs = next_obs

        # Training
        if global_step > args.learning_starts:
            if global_step % args.train_frequency == 0:
                data = rb.sample(args.batch_size)

                with torch.no_grad():
                    target_max, _ = target_network(data.next_observations).max(dim=1)
                    td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())

                old_val = q_network(data.observations).gather(1, data.actions).squeeze()
                loss = F.mse_loss(td_target, old_val)

                if global_step % 1000 == 0:
                    writer.add_scalar("losses/td_loss", loss, global_step)
                    writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
                    writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

                # Optimize
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # Update target network
        if global_step % args.target_network_frequency == 0:
            target_network.load_state_dict(q_network.state_dict())

    # Save model
    model_path = f"models/{run_name}.pt"
    os.makedirs("models", exist_ok=True)
    torch.save(q_network.state_dict(), model_path)

    envs.close()
    writer.close()

    return q_network, episode_rewards, model_path


def evaluate_model(model_path, num_episodes=10):
    """Evaluate trained model"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model
    env = gym.vector.SyncVectorEnv([make_env("AsteroidsNoFrameskip-v4", 42, 0, False, "eval")])
    q_network = CustomQNetwork(env).to(device)
    q_network.load_state_dict(torch.load(model_path))
    q_network.eval()

    # Evaluate
    obs, _ = env.reset()
    episode_rewards = []
    episode_lengths = []
    current_reward = 0
    current_length = 0

    for step in range(50000):  # Max steps for evaluation
        with torch.no_grad():
            q_values = q_network(torch.Tensor(obs).to(device))
            actions = torch.argmax(q_values, dim=1).cpu().numpy()

        obs, rewards, terminations, truncations, infos = env.step(actions)
        current_reward += rewards[0]
        current_length += 1

        if terminations[0] or truncations[0]:
            episode_rewards.append(current_reward)
            episode_lengths.append(current_length)
            print(f"Episode {len(episode_rewards)}: {current_reward} points, {current_length} steps")

            current_reward = 0
            current_length = 0
            obs, _ = env.reset()

            if len(episode_rewards) >= num_episodes:
                break

    env.close()

    # Results
    avg_reward = np.mean(episode_rewards)
    avg_length = np.mean(episode_lengths)

    print(f"\n📊 Evaluation Results:")
    print(f"Average Score: {avg_reward:.1f}")
    print(f"Average Length: {avg_length:.0f}")
    print(f"Best Score: {max(episode_rewards)}")
    print(f"Episodes: {len(episode_rewards)}")

    return episode_rewards, episode_lengths


if __name__ == "__main__":
    start_time = time.time()
    print("🚀 Starting Custom Asteroids DQN Training...")

    # Train the model
    model, rewards, model_path = train_asteroids_dqn()

    print(f"✅ Training completed! Model saved to: {model_path}")
    print(f"Training time: {(time.time() - start_time)/60:.1f} minutes")

    # Evaluate the model
    print("\n🎯 Evaluating trained model...")
    eval_rewards, eval_lengths = evaluate_model(model_path)

🚀 Starting Custom Asteroids DQN Training...


NameNotFound: Environment `AsteroidsNoFrameskip` doesn't exist.