In [None]:
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import cv2
import random
import string
from collections import deque
from IPython.display import HTML, clear_output

# Utility functions
def preprocess(img):
    """Crop a frame and convert it to grayscale scaled by 1/255.

    Keeps rows 0..83 and columns 6..89 of ``img``, converts that window
    with ``cv2.COLOR_RGB2GRAY`` and divides the result by 255.0.
    """
    window = img[0:84, 6:90]
    gray = cv2.cvtColor(window, cv2.COLOR_RGB2GRAY)
    return gray / 255.0

def animate(imgs, video_name=None, _return=True):
    """Write a sequence of frames to a VP9-encoded video file at 10 FPS.

    Args:
        imgs: Non-empty sequence of 3-dimensional image arrays
            (height, width, channels). The first frame fixes the video size.
            Presumably RGB frames from ``env.render()`` -- confirm at callers.
        video_name: Output path. When ``None`` a random 18-letter name with a
            ``.webm`` suffix is generated.
        _return: When true, return an ``IPython.display.Video`` for the file.

    Returns:
        An ``IPython.display.Video`` when ``_return`` is true, otherwise ``None``.
    """
    if video_name is None:
        video_name = ''.join(random.choice(string.ascii_letters) for _ in range(18)) + '.webm'
    height, width, _ = imgs[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'VP90')
    video = cv2.VideoWriter(video_name, fourcc, 10, (width, height))
    try:
        for img in imgs:
            # Swap the first and third channels before handing the frame to OpenCV.
            video.write(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    finally:
        # Always finalize the container, even if a frame fails to encode.
        video.release()
    if _return:
        from IPython.display import Video
        return Video(video_name)

def evaluate(agent, n_evals=5):
    """Run ``n_evals`` episodes with ``training=False`` and average the returns.

    A fresh CarRacing environment wrapped in ``ImageEnv`` is created for the
    evaluation and closed again before returning.

    Args:
        agent: Object exposing ``act(state, training=False)``.
        n_evals: Number of episodes to run; must be positive.

    Returns:
        The mean episode return, rounded to 4 decimals.
    """
    eval_env = gym.make('CarRacing-v2', continuous=False, render_mode='rgb_array')
    eval_env = ImageEnv(eval_env)
    scores = 0
    try:
        for _ in range(n_evals):
            (s, _), done, ret = eval_env.reset(), False, 0
            while not done:
                a = agent.act(s, training=False)
                s_prime, r, terminated, truncated, _ = eval_env.step(a)
                s = s_prime
                ret += r
                done = terminated or truncated
            scores += ret
    finally:
        # Release the environment's resources; this function is called repeatedly.
        eval_env.close()
    return np.round(scores / n_evals, 4)

# ImageEnv Wrapper
class ImageEnv(gym.Wrapper):
    """Wrapper that returns a stack of preprocessed frames as the observation.

    Each ``step`` repeats the chosen action for up to ``skip_frames`` env
    steps and each observation is the last ``stack_frames`` preprocessed
    frames stacked along a new leading axis.
    """

    def __init__(self, env, skip_frames=4, stack_frames=4, initial_no_op=50, **kwargs):
        """Store the wrapper settings.

        Args:
            env: Environment to wrap.
            skip_frames: Maximum number of env steps per wrapper ``step``.
            stack_frames: Number of frames kept in the stacked observation.
            initial_no_op: Number of action-0 steps taken right after a reset.
            **kwargs: Forwarded to ``gym.Wrapper.__init__``.
        """
        super(ImageEnv, self).__init__(env, **kwargs)
        self.initial_no_op = initial_no_op
        self.skip_frames = skip_frames
        self.stack_frames = stack_frames

    def reset(self, *, seed=None, options=None):
        """Reset the env, run the initial action-0 steps and build the stack.

        Args:
            seed: Optional seed forwarded to the wrapped env's first reset.
            options: Optional reset options forwarded to the wrapped env.

        Returns:
            ``(stacked_state, info)`` where ``stacked_state`` repeats the
            preprocessed frame ``stack_frames`` times.
        """
        s, info = self.env.reset(seed=seed, options=options)
        # Action 0 is presumably the env's no-op -- confirm against its action table.
        for _ in range(self.initial_no_op):
            s, _, terminated, truncated, _ = self.env.step(0)
            if terminated or truncated:
                # Do not pass the seed again, so this reset is not re-seeded.
                s, info = self.env.reset(options=options)
        s = preprocess(s)
        self.stacked_state = np.tile(s, (self.stack_frames, 1, 1))
        return self.stacked_state, info

    def step(self, action):
        """Repeat ``action`` for up to ``skip_frames`` steps and update the stack.

        Rewards are summed over the repeated steps; repetition stops early
        when the env reports ``terminated`` or ``truncated``.

        Returns:
            ``(stacked_state, reward, terminated, truncated, info)`` using the
            flags and info of the last env step taken.
        """
        reward = 0
        for _ in range(self.skip_frames):
            s, r, terminated, truncated, info = self.env.step(action)
            reward += r
            if terminated or truncated:
                break
        s = preprocess(s)
        # Drop the oldest frame and append the newest one.
        self.stacked_state = np.concatenate((self.stacked_state[1:], s[np.newaxis]), axis=0)
        return self.stacked_state, reward, terminated, truncated, info

# CNNActionValue class
class CNNActionValue(nn.Module):
    """Convolutional network producing one output per action.

    Three convolutions feed two linear layers. The first linear layer takes
    64 * 7 * 7 features, which is what the convolution stack yields for
    84 x 84 inputs.
    """

    def __init__(self, state_dim, action_dim, activation=F.relu):
        """Build the layers; ``state_dim`` is the number of input channels."""
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=state_dim, out_channels=32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(in_features=64 * 7 * 7, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=action_dim)
        self.activation = activation

    def forward(self, x):
        """Return the per-action outputs for a batch ``x``."""
        act = self.activation
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = act(conv(features))
        flat = features.view(features.shape[0], -1)
        hidden = act(self.fc1(flat))
        return self.fc2(hidden)

# ReplayBuffer class
class ReplayBuffer:
    """Fixed-capacity circular store of transitions held in NumPy arrays."""

    def __init__(self, state_dim, action_dim, max_size=int(1e5)):
        """Allocate zeroed storage for ``max_size`` transitions."""
        state_shape = (max_size, *state_dim)
        scalar_shape = (max_size, 1)
        self.s = np.zeros(state_shape, dtype=np.float32)
        self.a = np.zeros((max_size, *action_dim), dtype=np.int64)
        self.r = np.zeros(scalar_shape, dtype=np.float32)
        self.s_prime = np.zeros(state_shape, dtype=np.float32)
        self.terminated = np.zeros(scalar_shape, dtype=np.float32)
        self.ptr = 0
        self.size = 0
        self.max_size = max_size

    def update(self, s, a, r, s_prime, terminated):
        """Write one transition at the cursor, overwriting the oldest when full."""
        slot = self.ptr
        self.s[slot] = s
        self.a[slot] = a
        self.r[slot] = r
        self.s_prime[slot] = s_prime
        self.terminated[slot] = terminated
        self.ptr = (slot + 1) % self.max_size
        if self.size < self.max_size:
            self.size += 1

    def sample(self, batch_size):
        """Draw ``batch_size`` stored transitions uniformly, with replacement.

        Returns:
            A tuple ``(s, a, r, s_prime, terminated)`` of float tensors.
        """
        ind = np.random.randint(0, self.size, batch_size)
        columns = (self.s, self.a, self.r, self.s_prime, self.terminated)
        return tuple(torch.FloatTensor(column[ind]) for column in columns)

# DQN class
class DQN:
    """DQN agent with an online network, a target network and a replay buffer."""

    def __init__(self, state_dim, action_dim, lr=0.00025, epsilon=1.0, epsilon_min=0.1, gamma=0.99, batch_size=32,
                 warmup_steps=5000, buffer_size=int(1e5), target_update_interval=10000):
        """Create the networks, optimizer and buffer and store the hyperparameters."""
        # Exploration schedule: linear decay from epsilon to epsilon_min over 1e6 steps.
        self.action_dim = action_dim
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = (epsilon - epsilon_min) / 1e6

        # Learning hyperparameters and step counter.
        self.gamma = gamma
        self.batch_size = batch_size
        self.warmup_steps = warmup_steps
        self.target_update_interval = target_update_interval
        self.total_steps = 0

        # The target network starts as an exact copy of the online network.
        n_channels = state_dim[0]
        self.network = CNNActionValue(n_channels, action_dim)
        self.target_network = CNNActionValue(n_channels, action_dim)
        self.target_network.load_state_dict(self.network.state_dict())
        self.optimizer = optim.RMSprop(self.network.parameters(), lr)
        self.buffer = ReplayBuffer(state_dim, (1,), buffer_size)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        for module in (self.network, self.target_network):
            module.to(self.device)

    @torch.no_grad()
    def act(self, x, training=True):
        """Pick an action: random while exploring or warming up, else argmax."""
        self.network.train(training)
        if training:
            explore = np.random.rand() < self.epsilon or self.total_steps < self.warmup_steps
            if explore:
                return np.random.randint(0, self.action_dim)
        batch = torch.from_numpy(x).float().unsqueeze(0).to(self.device)
        q_values = self.network(batch)
        return torch.argmax(q_values).item()

    def learn(self):
        """Run one gradient step on a sampled batch and report the loss."""
        batch = self.buffer.sample(self.batch_size)
        s, a, r, s_prime, terminated = (tensor.to(self.device) for tensor in batch)
        next_q = self.target_network(s_prime).detach()
        best_next = next_q.max(dim=1, keepdim=True).values
        # Terminated transitions contribute only their immediate reward.
        td_target = r + (1. - terminated) * self.gamma * best_next
        predicted = self.network(s).gather(1, a.long())
        loss = F.mse_loss(predicted, td_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {'total_steps': self.total_steps, 'value_loss': loss.item()}

    def process(self, transition):
        """Store a transition, learn once past warmup, and update schedules.

        Returns:
            The dict from ``learn`` once past warmup, otherwise an empty dict.
        """
        self.total_steps += 1
        self.buffer.update(*transition)
        past_warmup = self.total_steps > self.warmup_steps
        result = self.learn() if past_warmup else {}
        if self.total_steps % self.target_update_interval == 0:
            self.target_network.load_state_dict(self.network.state_dict())
        self.epsilon = max(self.epsilon_min, self.epsilon - self.epsilon_decay)
        return result

# Initialize environment and agent
env = gym.make('CarRacing-v2', continuous=False)
env = ImageEnv(env)

# Shape of one stacked observation; the action count comes from the environment.
state_dim = (4, 84, 84)
action_dim = env.action_space.n
agent = DQN(state_dim, action_dim)

# Training budget and evaluation cadence, both counted in agent steps.
max_steps = 1e6
eval_interval = 10000

# Evaluation curve: agent step count vs. average evaluation return.
history = {'Step': [], 'AvgReturn': []}

(s, _) = env.reset()

# Training: act, store the transition (learning happens inside process), and
# start a new episode whenever the current one terminates or is truncated.
while agent.total_steps <= max_steps:
    a = agent.act(s)
    s_prime, r, terminated, truncated, _ = env.step(a)
    # Only `terminated` is stored, so truncation does not cut off bootstrapping.
    result = agent.process((s, a, r, s_prime, terminated))
    s = s_prime
    if terminated or truncated:
        s, _ = env.reset()
    # Every eval_interval steps: evaluate, redraw the curve, checkpoint the weights.
    if agent.total_steps % eval_interval == 0:
        ret = evaluate(agent)
        history['Step'].append(agent.total_steps)
        history['AvgReturn'].append(ret)
        clear_output()
        plt.figure(figsize=(8, 5))
        plt.plot(history['Step'], history['AvgReturn'], 'r-')
        plt.xlabel('Step', fontsize=16)
        plt.ylabel('AvgReturn', fontsize=16)
        plt.xticks(fontsize=14)
        plt.yticks(fontsize=14)
        plt.grid(axis='y')
        plt.show()
        torch.save(agent.network.state_dict(), 'dqn.pt')

# Final evaluation and animation: one episode with training=False, recording
# the rendered frame before each action.
eval_env = gym.make('CarRacing-v2', continuous=False, render_mode='rgb_array')
eval_env = ImageEnv(eval_env)

frames = []
(s, _), done, ret = eval_env.reset(), False, 0
while not done:
    frames.append(eval_env.render())
    a = agent.act(s, training=False)
    s_prime, r, terminated, truncated, _ = eval_env.step(a)
    s = s_prime
    ret += r
    done = terminated or truncated

# Last expression of the cell: the returned Video object is what gets displayed.
animate(frames, 'animation.webm')


In [None]:
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import cv2
import random
import string
from collections import deque
from IPython.display import HTML, clear_output

# Utility functions
def preprocess(img):
    """Crop a frame and convert it to grayscale scaled by 1/255.

    Keeps rows 0..83 and columns 6..89 of ``img``, converts that window
    with ``cv2.COLOR_RGB2GRAY`` and divides the result by 255.0.
    """
    window = img[0:84, 6:90]
    gray = cv2.cvtColor(window, cv2.COLOR_RGB2GRAY)
    return gray / 255.0

def animate(imgs, video_name=None, _return=True):
    """Write a sequence of frames to a VP9-encoded video file at 10 FPS.

    Args:
        imgs: Non-empty sequence of 3-dimensional image arrays
            (height, width, channels). The first frame fixes the video size.
            Presumably RGB frames from ``env.render()`` -- confirm at callers.
        video_name: Output path. When ``None`` a random 18-letter name with a
            ``.webm`` suffix is generated.
        _return: When true, return an ``IPython.display.Video`` for the file.

    Returns:
        An ``IPython.display.Video`` when ``_return`` is true, otherwise ``None``.
    """
    if video_name is None:
        video_name = ''.join(random.choice(string.ascii_letters) for _ in range(18)) + '.webm'
    height, width, _ = imgs[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'VP90')
    video = cv2.VideoWriter(video_name, fourcc, 10, (width, height))
    try:
        for img in imgs:
            # Swap the first and third channels before handing the frame to OpenCV.
            video.write(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    finally:
        # Always finalize the container, even if a frame fails to encode.
        video.release()
    if _return:
        from IPython.display import Video
        return Video(video_name)

def evaluate(agent, n_evals=5):
    """Run ``n_evals`` episodes with ``training=False`` and average the returns.

    A fresh CarRacing environment wrapped in ``ImageEnv`` is created for the
    evaluation and closed again before returning.

    Args:
        agent: Object exposing ``act(state, training=False)``.
        n_evals: Number of episodes to run; must be positive.

    Returns:
        The mean episode return, rounded to 4 decimals.
    """
    eval_env = gym.make('CarRacing-v2', continuous=False, render_mode='rgb_array')
    eval_env = ImageEnv(eval_env)
    scores = 0
    try:
        for _ in range(n_evals):
            (s, _), done, ret = eval_env.reset(), False, 0
            while not done:
                a = agent.act(s, training=False)
                s_prime, r, terminated, truncated, _ = eval_env.step(a)
                s = s_prime
                ret += r
                done = terminated or truncated
            scores += ret
    finally:
        # Release the environment's resources; this function may be called repeatedly.
        eval_env.close()
    return np.round(scores / n_evals, 4)

# ImageEnv Wrapper
class ImageEnv(gym.Wrapper):
    """Wrapper that returns a stack of preprocessed frames as the observation.

    Each ``step`` repeats the chosen action for up to ``skip_frames`` env
    steps and each observation is the last ``stack_frames`` preprocessed
    frames stacked along a new leading axis.
    """

    def __init__(self, env, skip_frames=4, stack_frames=4, initial_no_op=50, **kwargs):
        """Store the wrapper settings.

        Args:
            env: Environment to wrap.
            skip_frames: Maximum number of env steps per wrapper ``step``.
            stack_frames: Number of frames kept in the stacked observation.
            initial_no_op: Number of action-0 steps taken right after a reset.
            **kwargs: Forwarded to ``gym.Wrapper.__init__``.
        """
        super(ImageEnv, self).__init__(env, **kwargs)
        self.initial_no_op = initial_no_op
        self.skip_frames = skip_frames
        self.stack_frames = stack_frames

    def reset(self, *, seed=None, options=None):
        """Reset the env, run the initial action-0 steps and build the stack.

        Args:
            seed: Optional seed forwarded to the wrapped env's first reset.
            options: Optional reset options forwarded to the wrapped env.

        Returns:
            ``(stacked_state, info)`` where ``stacked_state`` repeats the
            preprocessed frame ``stack_frames`` times.
        """
        s, info = self.env.reset(seed=seed, options=options)
        # Action 0 is presumably the env's no-op -- confirm against its action table.
        for _ in range(self.initial_no_op):
            s, _, terminated, truncated, _ = self.env.step(0)
            if terminated or truncated:
                # Do not pass the seed again, so this reset is not re-seeded.
                s, info = self.env.reset(options=options)
        s = preprocess(s)
        self.stacked_state = np.tile(s, (self.stack_frames, 1, 1))
        return self.stacked_state, info

    def step(self, action):
        """Repeat ``action`` for up to ``skip_frames`` steps and update the stack.

        Rewards are summed over the repeated steps; repetition stops early
        when the env reports ``terminated`` or ``truncated``.

        Returns:
            ``(stacked_state, reward, terminated, truncated, info)`` using the
            flags and info of the last env step taken.
        """
        reward = 0
        for _ in range(self.skip_frames):
            s, r, terminated, truncated, info = self.env.step(action)
            reward += r
            if terminated or truncated:
                break
        s = preprocess(s)
        # Drop the oldest frame and append the newest one.
        self.stacked_state = np.concatenate((self.stacked_state[1:], s[np.newaxis]), axis=0)
        return self.stacked_state, reward, terminated, truncated, info

# CNNActionValue class
class CNNActionValue(nn.Module):
    """Convolutional network producing one output per action.

    Three convolutions feed two linear layers. The first linear layer takes
    64 * 7 * 7 features, which is what the convolution stack yields for
    84 x 84 inputs.
    """

    def __init__(self, state_dim, action_dim, activation=F.relu):
        """Build the layers; ``state_dim`` is the number of input channels."""
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=state_dim, out_channels=32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(in_features=64 * 7 * 7, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=action_dim)
        self.activation = activation

    def forward(self, x):
        """Return the per-action outputs for a batch ``x``."""
        act = self.activation
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = act(conv(features))
        flat = features.view(features.shape[0], -1)
        hidden = act(self.fc1(flat))
        return self.fc2(hidden)

# ReplayBuffer class
class ReplayBuffer:
    """Fixed-capacity circular store of transitions held in NumPy arrays."""

    def __init__(self, state_dim, action_dim, max_size=int(1e5)):
        """Allocate zeroed storage for ``max_size`` transitions."""
        state_shape = (max_size, *state_dim)
        scalar_shape = (max_size, 1)
        self.s = np.zeros(state_shape, dtype=np.float32)
        self.a = np.zeros((max_size, *action_dim), dtype=np.int64)
        self.r = np.zeros(scalar_shape, dtype=np.float32)
        self.s_prime = np.zeros(state_shape, dtype=np.float32)
        self.terminated = np.zeros(scalar_shape, dtype=np.float32)
        self.ptr = 0
        self.size = 0
        self.max_size = max_size

    def update(self, s, a, r, s_prime, terminated):
        """Write one transition at the cursor, overwriting the oldest when full."""
        slot = self.ptr
        self.s[slot] = s
        self.a[slot] = a
        self.r[slot] = r
        self.s_prime[slot] = s_prime
        self.terminated[slot] = terminated
        self.ptr = (slot + 1) % self.max_size
        if self.size < self.max_size:
            self.size += 1

    def sample(self, batch_size):
        """Draw ``batch_size`` stored transitions uniformly, with replacement.

        Returns:
            A tuple ``(s, a, r, s_prime, terminated)`` of float tensors.
        """
        ind = np.random.randint(0, self.size, batch_size)
        columns = (self.s, self.a, self.r, self.s_prime, self.terminated)
        return tuple(torch.FloatTensor(column[ind]) for column in columns)

# DQN class
class DQN:
    """DQN agent with an online network, a target network and a replay buffer."""

    def __init__(self, state_dim, action_dim, lr=0.00025, epsilon=1.0, epsilon_min=0.1, gamma=0.99, batch_size=32,
                 warmup_steps=5000, buffer_size=int(1e5), target_update_interval=10000):
        """Create the networks, optimizer and buffer and store the hyperparameters.

        Args:
            state_dim: Observation shape; its first entry is the channel count.
            action_dim: Number of discrete actions.
            lr: Adam learning rate.
            epsilon: Initial exploration probability.
            epsilon_min: Lower bound for the exploration probability.
            gamma: Discount factor.
            batch_size: Number of transitions sampled per learning step.
            warmup_steps: Calls to ``update`` that only count steps before
                learning begins.
            buffer_size: Replay buffer capacity.
            target_update_interval: Steps between target-network syncs.
        """
        self.action_dim = action_dim
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        # Linear decay from epsilon to epsilon_min over 1e6 decay steps.
        self.epsilon_decay = (epsilon - epsilon_min) / 1e6
        self.gamma = gamma
        self.batch_size = batch_size
        self.warmup_steps = warmup_steps
        self.target_update_interval = target_update_interval
        self.buffer = ReplayBuffer(state_dim, [1], buffer_size)
        self.network = CNNActionValue(state_dim[0], action_dim)
        self.target_network = CNNActionValue(state_dim[0], action_dim)
        # The target network starts as an exact copy of the online network.
        self.target_network.load_state_dict(self.network.state_dict())
        self.optimizer = optim.Adam(self.network.parameters(), lr=lr)
        self.loss = nn.MSELoss()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.network.to(self.device)
        self.target_network.to(self.device)
        self.steps = 0

    def act(self, s, training=True):
        """Return a random action with probability epsilon when training, else argmax."""
        if training and np.random.rand() < self.epsilon:
            return np.random.randint(self.action_dim)
        s = torch.FloatTensor(s).unsqueeze(0).to(self.device)
        with torch.no_grad():
            return self.network(s).argmax().item()

    def update(self):
        """Count one step and, once warmup is over, run one learning step.

        The counter is advanced before the warmup check. Incrementing it only
        after a learning step would leave it at 0 forever, so the warmup would
        never end and the agent would never learn.
        """
        self.steps += 1
        if self.steps < self.warmup_steps:
            return
        s, a, r, s_prime, terminated = self.buffer.sample(self.batch_size)
        s, a, r, s_prime, terminated = s.to(self.device), a.to(self.device), r.to(self.device), s_prime.to(self.device), terminated.to(self.device)
        q_values = self.network(s).gather(1, a.long())
        with torch.no_grad():
            next_q_values = self.target_network(s_prime).max(1, keepdim=True)[0]
            # Terminated transitions contribute only their immediate reward.
            target = r + self.gamma * next_q_values * (1 - terminated)
        loss = self.loss(q_values, target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        if self.steps % self.target_update_interval == 0:
            self.target_network.load_state_dict(self.network.state_dict())
        self.epsilon = max(self.epsilon_min, self.epsilon - self.epsilon_decay)

# Training loop
env = gym.make('CarRacing-v2', continuous=False, render_mode='rgb_array')
env = ImageEnv(env)

state_dim = [4, 84, 84]
# Read the action count from the environment instead of hard-coding it, so the
# network's output size always matches the action space.
action_dim = int(env.action_space.n)
lr = 0.00025
epsilon = 1.0
epsilon_min = 0.1
gamma = 0.99
batch_size = 32
warmup_steps = 5000
buffer_size = int(1e5)
target_update_interval = 10000
n_episodes = 1000
# Upper bound on wrapper steps per training episode.
max_steps = 1000

agent = DQN(state_dim, action_dim, lr, epsilon, epsilon_min, gamma, batch_size, warmup_steps, buffer_size, target_update_interval)

# Training curve: agent step count vs. mean of the most recent training returns.
history = {'Step': [], 'AvgReturn': []}
eval_interval = 10
# Sliding window over the last 100 episode returns.
returns = deque(maxlen=100)

for episode in range(1, n_episodes + 1):
    (s, _), done, ret = env.reset(), False, 0
    for step in range(max_steps):
        a = agent.act(s)
        s_prime, r, terminated, truncated, _ = env.step(a)
        done = terminated or truncated
        # Only `terminated` is stored, so truncation does not cut off bootstrapping.
        agent.buffer.update(s, a, r, s_prime, terminated)
        agent.update()
        s = s_prime
        ret += r
        if done:
            break
    returns.append(ret)
    # Every eval_interval episodes: record the running mean, redraw, checkpoint.
    if episode % eval_interval == 0:
        avg_return = np.mean(returns)
        history['Step'].append(agent.steps)
        history['AvgReturn'].append(avg_return)
        clear_output(wait=True)
        plt.figure(figsize=(12, 8))
        plt.plot(history['Step'], history['AvgReturn'], 'r-')
        plt.xlabel('Step', fontsize=16)
        plt.ylabel('AvgReturn', fontsize=16)
        plt.xticks(fontsize=14)
        plt.yticks(fontsize=14)
        plt.grid(axis='y')
        plt.show()
        torch.save(agent.network.state_dict(), 'dqn.pt')

# Final evaluation and animation: one episode with training=False, recording
# the rendered frame before each action.
eval_env = gym.make('CarRacing-v2', continuous=False, render_mode='rgb_array')
eval_env = ImageEnv(eval_env)

frames = []
(s, _), done, ret = eval_env.reset(), False, 0
while not done:
    frames.append(eval_env.render())
    a = agent.act(s, training=False)
    s_prime, r, terminated, truncated, _ = eval_env.step(a)
    s = s_prime
    ret += r
    done = terminated or truncated

# Last expression of the cell: the returned Video object is what gets displayed.
animate(frames, 'animation.webm')


In [None]:
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import cv2
import random
import string
from collections import deque
from IPython.display import HTML, clear_output

# Utility functions
def preprocess(img):
    """Crop a frame and convert it to grayscale scaled by 1/255.

    Keeps rows 0..83 and columns 6..89 of ``img``, converts that window
    with ``cv2.COLOR_RGB2GRAY`` and divides the result by 255.0.
    """
    window = img[0:84, 6:90]
    gray = cv2.cvtColor(window, cv2.COLOR_RGB2GRAY)
    return gray / 255.0

def animate(imgs, video_name=None, _return=True):
    """Write a sequence of frames to a VP9-encoded video file at 10 FPS.

    Args:
        imgs: Non-empty sequence of 3-dimensional image arrays
            (height, width, channels). The first frame fixes the video size.
            Presumably RGB frames from ``env.render()`` -- confirm at callers.
        video_name: Output path. When ``None`` a random 18-letter name with a
            ``.webm`` suffix is generated.
        _return: When true, return an ``IPython.display.Video`` for the file.

    Returns:
        An ``IPython.display.Video`` when ``_return`` is true, otherwise ``None``.
    """
    if video_name is None:
        video_name = ''.join(random.choice(string.ascii_letters) for _ in range(18)) + '.webm'
    height, width, _ = imgs[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'VP90')
    video = cv2.VideoWriter(video_name, fourcc, 10, (width, height))
    try:
        for img in imgs:
            # Swap the first and third channels before handing the frame to OpenCV.
            video.write(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    finally:
        # Always finalize the container, even if a frame fails to encode.
        video.release()
    if _return:
        from IPython.display import Video
        return Video(video_name)

def evaluate(agent, n_evals=5):
    """Run ``n_evals`` episodes with ``training=False`` and record their frames.

    A fresh CarRacing environment wrapped in ``ImageEnv`` is created for the
    evaluation and closed again before returning.

    Args:
        agent: Object exposing ``act(state, training=False)``.
        n_evals: Number of episodes to run; must be positive.

    Returns:
        ``(avg_score, all_frames)``: the mean episode return rounded to 4
        decimals, and one list of rendered frames per episode (one frame is
        captured after every wrapper step).
    """
    eval_env = gym.make('CarRacing-v2', continuous=False, render_mode='rgb_array')
    eval_env = ImageEnv(eval_env)
    scores = 0
    all_frames = []
    try:
        for _ in range(n_evals):
            (s, _), done, ret = eval_env.reset(), False, 0
            frames = []
            while not done:
                a = agent.act(s, training=False)
                s_prime, r, terminated, truncated, _ = eval_env.step(a)
                frames.append(eval_env.render())
                s = s_prime
                ret += r
                done = terminated or truncated
            scores += ret
            all_frames.append(frames)
    finally:
        # Release the environment's resources; this function is called repeatedly.
        eval_env.close()
    avg_score = np.round(scores / n_evals, 4)
    return avg_score, all_frames

# ImageEnv Wrapper
class ImageEnv(gym.Wrapper):
    """Wrapper that returns a stack of preprocessed frames as the observation.

    Each ``step`` repeats the chosen action for up to ``skip_frames`` env
    steps and each observation is the last ``stack_frames`` preprocessed
    frames stacked along a new leading axis.
    """

    def __init__(self, env, skip_frames=4, stack_frames=4, initial_no_op=50, **kwargs):
        """Store the wrapper settings.

        Args:
            env: Environment to wrap.
            skip_frames: Maximum number of env steps per wrapper ``step``.
            stack_frames: Number of frames kept in the stacked observation.
            initial_no_op: Number of action-0 steps taken right after a reset.
            **kwargs: Forwarded to ``gym.Wrapper.__init__``.
        """
        super(ImageEnv, self).__init__(env, **kwargs)
        self.initial_no_op = initial_no_op
        self.skip_frames = skip_frames
        self.stack_frames = stack_frames

    def reset(self, *, seed=None, options=None):
        """Reset the env, run the initial action-0 steps and build the stack.

        Args:
            seed: Optional seed forwarded to the wrapped env's first reset.
            options: Optional reset options forwarded to the wrapped env.

        Returns:
            ``(stacked_state, info)`` where ``stacked_state`` repeats the
            preprocessed frame ``stack_frames`` times.
        """
        s, info = self.env.reset(seed=seed, options=options)
        # Action 0 is presumably the env's no-op -- confirm against its action table.
        for _ in range(self.initial_no_op):
            s, _, terminated, truncated, _ = self.env.step(0)
            if terminated or truncated:
                # Do not pass the seed again, so this reset is not re-seeded.
                s, info = self.env.reset(options=options)
        s = preprocess(s)
        self.stacked_state = np.tile(s, (self.stack_frames, 1, 1))
        return self.stacked_state, info

    def step(self, action):
        """Repeat ``action`` for up to ``skip_frames`` steps and update the stack.

        Rewards are summed over the repeated steps; repetition stops early
        when the env reports ``terminated`` or ``truncated``.

        Returns:
            ``(stacked_state, reward, terminated, truncated, info)`` using the
            flags and info of the last env step taken.
        """
        reward = 0
        for _ in range(self.skip_frames):
            s, r, terminated, truncated, info = self.env.step(action)
            reward += r
            if terminated or truncated:
                break
        s = preprocess(s)
        # Drop the oldest frame and append the newest one.
        self.stacked_state = np.concatenate((self.stacked_state[1:], s[np.newaxis]), axis=0)
        return self.stacked_state, reward, terminated, truncated, info

# CNNActionValue class
class CNNActionValue(nn.Module):
    """Convolutional network producing one output per action.

    Three convolutions feed two linear layers. The first linear layer takes
    64 * 7 * 7 features, which is what the convolution stack yields for
    84 x 84 inputs.
    """

    def __init__(self, state_dim, action_dim, activation=F.relu):
        """Build the layers; ``state_dim`` is the number of input channels."""
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=state_dim, out_channels=32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(in_features=64 * 7 * 7, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=action_dim)
        self.activation = activation

    def forward(self, x):
        """Return the per-action outputs for a batch ``x``."""
        act = self.activation
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = act(conv(features))
        flat = features.view(features.shape[0], -1)
        hidden = act(self.fc1(flat))
        return self.fc2(hidden)

# ReplayBuffer class
class ReplayBuffer:
    """Fixed-capacity circular store of transitions held in NumPy arrays."""

    def __init__(self, state_dim, action_dim, max_size=int(1e5)):
        """Allocate zeroed storage for ``max_size`` transitions."""
        state_shape = (max_size, *state_dim)
        scalar_shape = (max_size, 1)
        self.s = np.zeros(state_shape, dtype=np.float32)
        self.a = np.zeros((max_size, *action_dim), dtype=np.int64)
        self.r = np.zeros(scalar_shape, dtype=np.float32)
        self.s_prime = np.zeros(state_shape, dtype=np.float32)
        self.terminated = np.zeros(scalar_shape, dtype=np.float32)
        self.ptr = 0
        self.size = 0
        self.max_size = max_size

    def update(self, s, a, r, s_prime, terminated):
        """Write one transition at the cursor, overwriting the oldest when full."""
        slot = self.ptr
        self.s[slot] = s
        self.a[slot] = a
        self.r[slot] = r
        self.s_prime[slot] = s_prime
        self.terminated[slot] = terminated
        self.ptr = (slot + 1) % self.max_size
        if self.size < self.max_size:
            self.size += 1

    def sample(self, batch_size):
        """Draw ``batch_size`` stored transitions uniformly, with replacement.

        Returns:
            A tuple ``(s, a, r, s_prime, terminated)`` of float tensors.
        """
        ind = np.random.randint(0, self.size, batch_size)
        columns = (self.s, self.a, self.r, self.s_prime, self.terminated)
        return tuple(torch.FloatTensor(column[ind]) for column in columns)

# DQN class
class DQN:
    """DQN agent with an online network, a target network, a replay buffer
    and checkpoint save/load support."""

    def __init__(self, state_dim, action_dim, lr=0.00025, epsilon=1.0, epsilon_min=0.1, gamma=0.99, batch_size=32,
                 warmup_steps=5000, buffer_size=int(1e5), target_update_interval=10000):
        """Create the networks, optimizer and buffer and store the hyperparameters.

        Args:
            state_dim: Observation shape; its first entry is the channel count.
            action_dim: Number of discrete actions.
            lr: Adam learning rate.
            epsilon: Initial exploration probability.
            epsilon_min: Lower bound for the exploration probability.
            gamma: Discount factor.
            batch_size: Number of transitions sampled per learning step.
            warmup_steps: Calls to ``update`` that only count steps before
                learning begins.
            buffer_size: Replay buffer capacity.
            target_update_interval: Steps between target-network syncs.
        """
        self.action_dim = action_dim
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        # Linear decay from epsilon to epsilon_min over 1e6 decay steps.
        self.epsilon_decay = (epsilon - epsilon_min) / 1e6
        self.gamma = gamma
        self.batch_size = batch_size
        self.warmup_steps = warmup_steps
        self.target_update_interval = target_update_interval
        self.buffer = ReplayBuffer(state_dim, [1], buffer_size)
        self.network = CNNActionValue(state_dim[0], action_dim)
        self.target_network = CNNActionValue(state_dim[0], action_dim)
        # The target network starts as an exact copy of the online network.
        self.target_network.load_state_dict(self.network.state_dict())
        self.optimizer = optim.Adam(self.network.parameters(), lr=lr)
        self.loss = nn.MSELoss()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.network.to(self.device)
        self.target_network.to(self.device)
        self.steps = 0

    def act(self, s, training=True):
        """Return a random action with probability epsilon when training, else argmax."""
        if training and np.random.rand() < self.epsilon:
            return np.random.randint(self.action_dim)
        s = torch.FloatTensor(s).unsqueeze(0).to(self.device)
        with torch.no_grad():
            return self.network(s).argmax().item()

    def update(self):
        """Count one step and, once warmup is over, run one learning step.

        The counter is advanced before the warmup check. Incrementing it only
        after a learning step would leave it at 0 forever, so the warmup would
        never end and the agent would never learn.
        """
        self.steps += 1
        if self.steps < self.warmup_steps:
            return
        s, a, r, s_prime, terminated = self.buffer.sample(self.batch_size)
        s, a, r, s_prime, terminated = s.to(self.device), a.to(self.device), r.to(self.device), s_prime.to(self.device), terminated.to(self.device)
        q_values = self.network(s).gather(1, a.long())
        with torch.no_grad():
            next_q_values = self.target_network(s_prime).max(1, keepdim=True)[0]
            # Terminated transitions contribute only their immediate reward.
            target = r + self.gamma * next_q_values * (1 - terminated)
        loss = self.loss(q_values, target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        if self.steps % self.target_update_interval == 0:
            self.target_network.load_state_dict(self.network.state_dict())
        self.epsilon = max(self.epsilon_min, self.epsilon - self.epsilon_decay)

    def save(self, filepath):
        """Write networks, optimizer, counters and the replay buffer to ``filepath``.

        NOTE(review): the whole ``ReplayBuffer`` object is pickled into the
        checkpoint, so the file grows with the buffer's allocated arrays.
        """
        torch.save({
            'network_state_dict': self.network.state_dict(),
            'target_network_state_dict': self.target_network.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'steps': self.steps,
            'epsilon': self.epsilon,
            'buffer': self.buffer,
        }, filepath)

    def load(self, filepath):
        """Restore the state written by ``save`` from ``filepath``.

        Raises:
            FileNotFoundError: If ``filepath`` does not exist.
        """
        # SECURITY: the checkpoint holds a pickled ReplayBuffer object, so it
        # must be loaded with full unpickling (weights_only=False). Only load
        # checkpoints from a trusted source. map_location lets a checkpoint
        # saved on GPU be restored on a CPU-only machine.
        checkpoint = torch.load(filepath, map_location=self.device, weights_only=False)
        self.network.load_state_dict(checkpoint['network_state_dict'])
        self.target_network.load_state_dict(checkpoint['target_network_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.steps = checkpoint['steps']
        self.epsilon = checkpoint['epsilon']
        self.buffer = checkpoint['buffer']

# Training loop
env = gym.make('CarRacing-v2', continuous=False, render_mode='rgb_array')
env = ImageEnv(env)

# Shape of one stacked observation; the action count comes from the environment.
state_dim = [4, 84, 84]
action_dim = env.action_space.n

agent = DQN(state_dim, action_dim)

# Resume from an existing checkpoint when one is present; a missing file just
# means training starts from scratch.
load_model = True
model_path = 'dqn_car_racing.pth'
if load_model:
    try:
        agent.load(model_path)
        print(f"Loaded model from {model_path}")
    except FileNotFoundError:
        print(f"No model found at {model_path}, starting from scratch")

# Schedule, counted in episodes.
n_episodes = 1000
save_interval = 10
eval_interval = 50

# Per-episode training returns for this run.
# NOTE(review): starts empty even when a checkpoint was loaded.
episode_rewards = []

for episode in range(n_episodes):
    (s, _), done, ret = env.reset(), False, 0
    while not done:
        a = agent.act(s)
        s_prime, r, terminated, truncated, _ = env.step(a)
        # Only `terminated` is stored, so truncation does not cut off bootstrapping.
        agent.buffer.update(s, a, r, s_prime, terminated)
        agent.update()
        s = s_prime
        ret += r
        done = terminated or truncated
    episode_rewards.append(ret)
    print(f"Episode {episode + 1}: {ret}")

    # Save the model at intervals
    if (episode + 1) % save_interval == 0:
        agent.save(model_path)
        print(f"Model saved at episode {episode + 1}")

    # Evaluate the model at intervals and save one video per evaluation episode
    if (episode + 1) % eval_interval == 0:
        avg_score, all_frames = evaluate(agent)
        print(f"Evaluation at episode {episode + 1}: Avg. score: {avg_score}")
        for idx, frames in enumerate(all_frames):
            animate(frames, f'eval_{episode + 1}_{idx + 1}.webm', _return=False)

        # Plot training chart; clear_output also removes the text printed so far.
        clear_output(wait=True)
        plt.figure(figsize=(12, 8))
        plt.plot(episode_rewards, label='Episode Reward')
        plt.xlabel('Episodes')
        plt.ylabel('Reward')
        plt.title('Training Progress')
        plt.legend()
        plt.show()

