In [1]:
# imports from libraries
import gymnasium as gym
import torch
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import os
import csv
import time

# imports from modules
from agents.dqn_agent import Agent
from utils.scheduler import HyperparameterScheduler

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
# preprocess() and train() place their tensors on this device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
# helper function for preprocessing
def preprocess(obs, target_device=None):
    """Convert a channel-last image observation into a grayscale batch tensor.

    The values along axis 2 (the colour channels of an (H, W, C) frame) are
    averaged, then a channel axis and a batch axis are prepended, so an
    (H, W, C) input produces a tensor of shape (1, 1, H, W).

    Args:
        obs: Array-like image whose axis 2 holds the colour channels.
        target_device: Device the result is moved to. When omitted, the
            module-level ``device`` is used.

    Returns:
        torch.Tensor: float32 tensor of shape (1, 1, H, W) on the chosen device.
    """
    gray = np.mean(obs, axis=2)
    # Convert the ndarray directly. Wrapping it in a Python list first
    # (``torch.FloatTensor([obs])``) goes through the list-of-ndarrays path
    # that PyTorch itself warns is extremely slow, and this runs every frame.
    frame = torch.as_tensor(gray[np.newaxis, np.newaxis], dtype=torch.float32)
    return frame.to(device if target_device is None else target_device)

In [None]:
# helper function for validation
def validate(agent, env, num_episodes=5):
    """Play evaluation episodes and return the mean episode reward.

    Actions are chosen with ``agent.select_action(state, eval_mode=True)``
    under ``torch.no_grad()``; nothing is pushed to replay memory and no
    optimisation step is run here.

    Args:
        agent: Object exposing ``select_action(state, eval_mode=True)``.
        env: Environment whose ``reset()`` returns ``(obs, info)`` and whose
            ``step(action)`` returns a 5-tuple ending in the info dict.
        num_episodes: How many episodes to play.

    Returns:
        The ``np.mean`` of the per-episode reward sums.
    """
    episode_returns = []

    for _ in range(num_episodes):
        frame, _ = env.reset()
        state = preprocess(frame)
        episode_return = 0
        finished = False

        while not finished:
            with torch.no_grad():
                action = agent.select_action(state, eval_mode=True)
            frame, reward, terminated, truncated, _ = env.step(action)
            state = preprocess(frame)
            episode_return += reward
            # the episode ends on either termination or truncation
            finished = terminated or truncated

        episode_returns.append(episode_return)

    return np.mean(episode_returns)

In [4]:
# helper function for setting up experiment
def setup_experiment():
    """Create a timestamped experiment directory and start its CSV log.

    The directory is named ``experiment_<YYYYmmdd_HHMMSS>`` relative to the
    current working directory. A ``training_log.csv`` file is (re)created
    inside it containing only the header row.

    Returns:
        tuple: ``(exp_dir, log_file)`` -- the directory name and the path of
        the CSV log inside it.
    """
    run_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    exp_dir = 'experiment_' + run_stamp
    os.makedirs(exp_dir, exist_ok=True)

    header = ['Episode', 'Train_Reward', 'Val_Reward', 'Epsilon', 'Learning_Rate']
    log_file = os.path.join(exp_dir, 'training_log.csv')
    # newline='' lets the csv module control line endings itself
    with open(log_file, 'w', newline='') as handle:
        csv.writer(handle).writerow(header)

    return exp_dir, log_file

In [None]:
# training function and the helper that decodes the agent's actions

# Continuous commands paired with the discrete index recorded for each one in
# replay memory. Any command not listed here is recorded as index 0.
_COMMAND_INDICES = (
    ((-1.0, 0.0, 0.0), 1),
    ((1.0, 0.0, 0.0), 2),
    ((0.0, 1.0, 0.0), 3),
    ((0.0, 0.0, 0.8), 4),
)


def _action_to_index(action):
    """Return the discrete index for a continuous command (0 if unrecognized)."""
    command = np.asarray(action, dtype=np.float64)
    if command.shape != (3,):
        return 0
    for reference, index in _COMMAND_INDICES:
        # Tolerance comparison: a float32 command such as 0.8 is not exactly
        # equal to the float64 literal and would otherwise fall through to 0.
        if np.allclose(command, reference):
            return index
    return 0


def train(agent, train_env, val_env, episodes, exp_dir, log_file, eval_freq=100):
    """Train the agent, validating and checkpointing along the way.

    For every episode the agent acts in ``train_env``; each transition is
    pushed to ``agent.memory`` and followed by ``agent.optimize_model()``.
    Whenever ``episode % eval_freq == 0`` (so including episode 0) the agent
    is scored with ``validate`` on ``val_env`` and, if that score beats the
    best seen so far, a checkpoint is written to ``<exp_dir>/best_model.pth``.
    Whenever ``episode % agent.target_update == 0`` the policy network's
    weights are copied into the target network. After each episode
    ``agent.scheduler.update`` is called and one row is appended to
    ``log_file``.

    Args:
        agent: Object providing ``select_action``, ``memory.push``,
            ``optimize_model``, ``policy_net``, ``target_net``, ``optimizer``,
            ``scheduler``, ``epsilon`` and ``target_update``.
        train_env: Environment stepped during training.
        val_env: Environment passed to ``validate``.
        episodes: Number of training episodes to run.
        exp_dir: Directory that receives ``best_model.pth``.
        log_file: Path of the CSV file rows are appended to.
        eval_freq: Validate every ``eval_freq`` episodes.

    Returns:
        tuple: ``(train_rewards, val_rewards)`` -- the per-episode training
        reward sums and the list of validation scores.
    """
    train_rewards = []
    val_rewards = []
    best_val_reward = float('-inf')

    for episode in range(episodes):
        obs, _ = train_env.reset()
        state = preprocess(obs)
        total_reward = 0
        done = False
        truncated = False

        while not (done or truncated):
            action = agent.select_action(state)
            obs, reward, done, truncated, _ = train_env.step(action)

            next_state = preprocess(obs)
            # Pin the dtype: torch.tensor() otherwise infers it from the value,
            # so an int reward from the environment would become an int64 tensor.
            reward = torch.tensor([reward], dtype=torch.float32, device=device)

            # Replay memory stores the discrete index, not the continuous command.
            action_idx = _action_to_index(action)

            # Only `done` is stored; `truncated` ends the loop but is not recorded.
            agent.memory.push(state, action_idx, reward, next_state, done)
            state = next_state
            total_reward += reward.item()

            agent.optimize_model()

        train_rewards.append(total_reward)

        # Run validation episodes
        is_eval_episode = episode % eval_freq == 0
        if is_eval_episode:
            val_reward = validate(agent, val_env)
            val_rewards.append(val_reward)
            print(f'Episode {episode}/{episodes}: Train reward: {total_reward:.2f}, Val reward: {val_reward:.2f}')

            # Save if validation improved
            if val_reward > best_val_reward:
                best_val_reward = val_reward
                model_path = os.path.join(exp_dir, 'best_model.pth')
                torch.save({
                    'episode': episode,
                    'model_state_dict': agent.policy_net.state_dict(),
                    'optimizer_state_dict': agent.optimizer.state_dict(),
                    'val_reward': val_reward,
                    'train_reward': total_reward,
                    'epsilon': agent.epsilon,
                }, model_path)
                print(f'New best model saved with validation reward: {val_reward:.2f}')

        # Sync the target network with the policy network
        if episode % agent.target_update == 0:
            agent.target_net.load_state_dict(agent.policy_net.state_dict())

        # Let the scheduler react to this episode's reward; its return value
        # is not needed because the current rate is read from the optimizer.
        agent.scheduler.update(total_reward, agent)
        current_lr = agent.optimizer.param_groups[0]['lr']

        # Log to file; the validation column stays empty on non-eval episodes.
        # Reopening per episode keeps the log current if the run is cut short.
        logged_val_reward = val_reward if is_eval_episode else ''
        with open(log_file, 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([episode, total_reward, logged_val_reward,
                             agent.epsilon, current_lr])

    return train_rewards, val_rewards

In [None]:
# training budget
n_episodes = 2000  # takes about 16.6 hours on a 7900xtx

# Create the training and validation environments. Both use the continuous
# control interface (continuous=True); train() maps each continuous command
# back to a discrete index before storing it.
train_env, val_env = (gym.make('CarRacing-v3', continuous=True) for _ in range(2))

# Seed each environment with its own value.
train_env.reset(seed=42)
val_env.reset(seed=420)

In [None]:
# Create this run's output directory and CSV log.
exp_dir, log_file = setup_experiment()

# DQN agent: state shape (1, 96, 96), 5 discrete actions, and a
# hyperparameter scheduler.
agent = Agent(
    state_shape=(1, 96, 96),
    n_actions=5,
    scheduler=HyperparameterScheduler(),
)

In [None]:
# training: run the full loop, save the final weights, and always close the
# environments afterwards
start_time = time.time()
try:
    train_rewards, val_rewards = train(
        agent=agent,
        train_env=train_env,
        val_env=val_env,
        episodes=n_episodes,
        exp_dir=exp_dir,
        log_file=log_file,
    )

    # final model path
    final_path = os.path.join(exp_dir, 'final_model.pth')
    torch.save({
        'model_state_dict': agent.policy_net.state_dict(),
        'optimizer_state_dict': agent.optimizer.state_dict(),
        'train_rewards': train_rewards,
        'val_rewards': val_rewards,
        'total_time': time.time() - start_time,
    }, final_path)

    print(f"Training complete! Took {(time.time() - start_time) / 3600:.2f} hours")
    print(f"Check {exp_dir} for results")

except Exception as e:
    print(f"An error occurred: {e}")

    # emergency save in case of error
    try:
        emergency_path = os.path.join(exp_dir, 'emergency_model.pth')
        torch.save({
            'model_state_dict': agent.policy_net.state_dict(),
            'optimizer_state_dict': agent.optimizer.state_dict(),
        }, emergency_path)
    except Exception as save_error:
        # Report why the save itself failed (not the training error again),
        # and do not swallow KeyboardInterrupt/SystemExit as a bare except would.
        print(f"Failed to save emergency model: {save_error}")
    # Bare raise re-raises the training error with its original traceback.
    raise

finally:
    train_env.close()
    val_env.close()