# Race Car Behavior Optimization using Soft Actor Critic with Chain of Thoughts as Memory Carrier

In [7]:
import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
import gym
from gym import RewardWrapper
from gym.wrappers import GrayScaleObservation, ResizeObservation, FrameStack


np.bool8 = np.bool_

## Experiments

### Experiment-1: SAC with Replay Buffer

Got a **Reward** of 605.15 after convergence. But, Can we do better? -- **This question inspired us to read about Prioritized Replay Buffer.**

In [None]:
# ########### Replay buffer ###########
# class ReplayBuffer:
#     def __init__(self, capacity):
#         self.capacity = capacity
#         self.buffer = []
#         self.position = 0

#     def push(self, state, action, reward, next_state, done):
#         if len(self.buffer) < self.capacity:
#             self.buffer.append(None)
#         self.buffer[self.position] = (state, action, reward, next_state, done)
#         self.position = (self.position + 1) % self.capacity

#     def sample(self, batch_size):
#         batch = random.sample(self.buffer, batch_size)
#         states, actions, rewards, next_states, dones = zip(*batch)
#         return (
#             np.array(states),
#             np.array(actions),
#             np.array(rewards, dtype=np.float32),
#             np.array(next_states),
#             np.array(dones, dtype=np.float32),
#         )

#     def __len__(self):
#         return len(self.buffer)

# ########### Networks ###########
# class Actor(nn.Module):
#     def __init__(self, state_shape, action_dim):
#         super().__init__()
#         c, h, w = state_shape
#         self.conv = nn.Sequential(
#             nn.Conv2d(c, 32, 8, 4), nn.ReLU(),
#             nn.Conv2d(32, 64, 4, 2), nn.ReLU(),
#             nn.Conv2d(64, 64, 3, 1), nn.ReLU(),
#         )
#         def conv_out(sz, k, s): return (sz - (k-1) - 1)//s + 1
#         convw = conv_out(conv_out(conv_out(w, 8, 4), 4, 2), 3, 1)
#         convh = conv_out(conv_out(conv_out(h, 8, 4), 4, 2), 3, 1)
#         lin = convw * convh * 64

#         self.fc = nn.Sequential(
#             nn.Linear(lin, 512), nn.ReLU(),
#             nn.Linear(512, 512), nn.ReLU()
#         )
#         self.mean    = nn.Linear(512, action_dim)
#         self.log_std = nn.Linear(512, action_dim)

#     def forward(self, x):
#         x = self.conv(x).flatten(1)
#         x = self.fc(x)
#         return self.mean(x), self.log_std(x).clamp(-20, 2)

#     def sample(self, x):
#         mean, log_std = self.forward(x)
#         std = log_std.exp()
#         dist = Normal(mean, std)
#         z = dist.rsample()
#         raw_action = torch.tanh(z)  # in [-1,1]

#         # carve out steer, gas, brake:
#         steer = raw_action[:, 0:1]
#         gas   = (raw_action[:, 1:2] + 1.0) / 2.0   # map [-1,1] → [0,1]
#         brake = (raw_action[:, 2:3] + 1.0) / 2.0   # map [-1,1] → [0,1]
#         action = torch.cat([steer, gas, brake], dim=1)

#         logp = (
#             dist.log_prob(z)
#             - torch.log(1 - raw_action.pow(2) + 1e-6)
#         ).sum(1, keepdim=True)
#         return action, logp

# class Critic(nn.Module):
#     def __init__(self, state_shape, action_dim):
#         super().__init__()
#         c, h, w = state_shape
#         self.conv = nn.Sequential(
#             nn.Conv2d(c, 32, 8, 4), nn.ReLU(),
#             nn.Conv2d(32, 64, 4, 2), nn.ReLU(),
#             nn.Conv2d(64, 64, 3, 1), nn.ReLU(),
#         )
#         def conv_out(sz, k, s): return (sz - (k-1) - 1)//s + 1
#         convw = conv_out(conv_out(conv_out(w, 8, 4), 4, 2), 3, 1)
#         convh = conv_out(conv_out(conv_out(h, 8, 4), 4, 2), 3, 1)
#         lin = convw * convh * 64

#         self.fc = nn.Sequential(
#             nn.Linear(lin + action_dim, 512), nn.ReLU(),
#             nn.Linear(512, 512), nn.ReLU(),
#             nn.Linear(512, 1)
#         )

#     def forward(self, x, a):
#         x = self.conv(x).flatten(1)
#         return self.fc(torch.cat([x, a], dim=1))

# ########### Utils ###########
# def soft_update(tgt, src, tau):
#     for t, s in zip(tgt.parameters(), src.parameters()):
#         t.data.copy_(tau * s.data + (1 - tau) * t.data)

# def hard_update(tgt, src):
#     tgt.load_state_dict(src.state_dict())


### Experiment-2: SAC with Prioritized Experience Replay/ Replay Buffer

Got a **Reward** of 791.11 after convergence of 2000 episodes. So, we analyzed that this is a very good improvement from the previous experimentation. Still, we weren't satisfied with the results. A question was still there. Can we do better?

In [None]:
# class PrioritizedReplayBuffer:
#     def __init__(self, capacity, state_shape, action_shape, device,
#                  alpha=0.6, beta_start=0.4, beta_frames=100000, n_step=3, gamma=0.99):
#         self.capacity = capacity
#         self.device = device
#         self.alpha = alpha
#         self.beta_start = beta_start
#         self.beta_frames = beta_frames
#         self.frame_idx = 1
#         self.n_step = n_step
#         self.gamma = gamma

#         # Circular buffer storage
#         self.states = torch.zeros((capacity, *state_shape), dtype=torch.float32, device=device)
#         self.actions = torch.zeros((capacity, *action_shape), dtype=torch.float32, device=device)
#         self.rewards = torch.zeros((capacity,), dtype=torch.float32, device=device)
#         self.next_states = torch.zeros((capacity, *state_shape), dtype=torch.float32, device=device)
#         self.dones = torch.zeros((capacity,), dtype=torch.bool, device=device)
#         self.priorities = np.zeros((capacity,), dtype=np.float32)

#         # N-step buffer
#         self.n_step_buffer = []
#         self.position = 0
#         self.size = 0

#     def beta_by_frame(self):
#         return min(1.0, self.beta_start + (1.0 - self.beta_start) * self.frame_idx / self.beta_frames)

#     def add(self, state, action, reward, next_state, done):
#         # Append to n-step buffer
#         self.n_step_buffer.append((state, action, reward, next_state, done))

#         if len(self.n_step_buffer) < self.n_step and not done:
#             return

#         # Compute multi-step return
#         cum_reward, next_s, done_flag = 0.0, None, False
#         for idx, (_, _, r, s_next, d) in enumerate(self.n_step_buffer):
#             cum_reward += (self.gamma**idx) * r
#             next_s, done_flag = s_next, d
#             if d:
#                 break

#         state0, action0, _, _, _ = self.n_step_buffer[0]

#         # Store transition
#         self.states[self.position] = torch.tensor(state0, dtype=torch.float32, device=self.device)
#         self.actions[self.position] = torch.tensor(action0, dtype=torch.float32, device=self.device)
#         self.rewards[self.position] = cum_reward
#         self.next_states[self.position] = torch.tensor(next_s, dtype=torch.float32, device=self.device)
#         self.dones[self.position] = done_flag
#         self.priorities[self.position] = self.priorities.max() if self.size > 0 else 1.0

#         self.position = (self.position + 1) % self.capacity
#         self.size = min(self.size + 1, self.capacity)

#         # Remove oldest
#         self.n_step_buffer.pop(0)
#         if done:
#             self.n_step_buffer.clear()

#     def sample(self, batch_size):
#         assert self.size > 0, "Empty buffer"
#         self.frame_idx += 1

#         # Compute sampling distribution
#         prios = self.priorities[:self.size] + 1e-6
#         probs = prios ** self.alpha
#         probs /= probs.sum()

#         # Sample indices
#         indices = np.random.choice(self.size, batch_size, p=probs)

#         # Importance sampling weights
#         beta = self.beta_by_frame()
#         weights = (self.size * probs[indices]) ** (-beta)
#         weights /= weights.max()
#         weights = torch.tensor(weights, dtype=torch.float32, device=self.device)

#         # Gather samples
#         states = self.states[indices]
#         actions = self.actions[indices]
#         rewards = self.rewards[indices].unsqueeze(1)
#         next_states = self.next_states[indices]
#         dones = self.dones[indices].unsqueeze(1).float()

#         return states, actions, rewards, next_states, dones, weights, indices

#     def update_priorities(self, indices, td_errors):
#         for idx, td in zip(indices, td_errors):
#             self.priorities[idx] = abs(td) + 1e-6

# ########### Network Definitions ###########
# # Actor network: Gaussian policy with tanh and action mapping
# class Actor(nn.Module):
#     def __init__(self, state_shape, action_dim):
#         super().__init__()
#         c, h, w = state_shape
#         # Conv encoder
#         self.conv = nn.Sequential(
#             nn.Conv2d(c, 32, 8, 4), nn.ReLU(),
#             nn.Conv2d(32, 64, 4, 2), nn.ReLU(),
#             nn.Conv2d(64, 64, 3, 1), nn.ReLU(),
#             nn.Flatten())
#         # Compute conv output size
#         def conv_out(sz, k, s): return (sz - (k - 1) - 1) // s + 1
#         convw = conv_out(conv_out(conv_out(w, 8, 4), 4, 2), 3, 1)
#         convh = conv_out(conv_out(conv_out(h, 8, 4), 4, 2), 3, 1)
#         lin = convw * convh * 64

#         self.fc = nn.Sequential(
#             nn.Linear(lin, 512), nn.ReLU(),
#             nn.Linear(512, 512), nn.ReLU())
#         self.mean = nn.Linear(512, action_dim)
#         self.log_std = nn.Linear(512, action_dim)

#     def forward(self, x):
#         x = self.conv(x / 255.0)
#         x = self.fc(x)
#         mean = self.mean(x)
#         log_std = self.log_std(x).clamp(-20, 2)
#         return mean, log_std

#     def sample(self, x):
#         mean, log_std = self.forward(x)
#         std = torch.exp(log_std)
#         dist = Normal(mean, std)
#         z = dist.rsample()
#         raw = torch.tanh(z)

#         # Map actions: steer in [-1,1], gas/brake in [0,1]
#         steer = raw[:, 0:1]
#         gas   = (raw[:, 1:2] + 1.0) / 2.0
#         brake = (raw[:, 2:3] + 1.0) / 2.0
#         action = torch.cat([steer, gas, brake], dim=1)

#         # Log-prob correction
#         logp = dist.log_prob(z) - torch.log(1 - raw.pow(2) + 1e-6)
#         logp = logp.sum(1, keepdim=True)
#         return action, logp

# # Critic network
# class Critic(nn.Module):
#     def __init__(self, state_shape, action_dim):
#         super().__init__()
#         c, h, w = state_shape
#         self.conv = nn.Sequential(
#             nn.Conv2d(c, 32, 8, 4), nn.ReLU(),
#             nn.Conv2d(32, 64, 4, 2), nn.ReLU(),
#             nn.Conv2d(64, 64, 3, 1), nn.ReLU(),
#             nn.Flatten())
#         def conv_out(sz, k, s): return (sz - (k - 1) - 1) // s + 1
#         convw = conv_out(conv_out(conv_out(w, 8, 4), 4, 2), 3, 1)
#         convh = conv_out(conv_out(conv_out(h, 8, 4), 4, 2), 3, 1)
#         lin = convw * convh * 64

#         self.fc = nn.Sequential(
#             nn.Linear(lin + action_dim, 512), nn.ReLU(),
#             nn.Linear(512, 512), nn.ReLU(),
#             nn.Linear(512, 1))

#     def forward(self, x, a):
#         x = self.conv(x / 255.0)
#         return self.fc(torch.cat([x, a], dim=1))

# ########### Soft Updates ###########
# def soft_update(target, source, tau):
#     for t, s in zip(target.parameters(), source.parameters()):
#         t.data.copy_(t.data * (1.0 - tau) + s.data * tau)

# def hard_update(target, source):
#     target.load_state_dict(source.state_dict())

## Our Final Approach to the Problem (Race Car Behavior Optimization)

#### We implemented a SAC agent augmented with a "chain-of-thought" LSTM (long short-term memory model), both in actor and critic, to carry out hidden memory across time and trained on a reward model ``Shaped Rewards`` from the environment. (BTW, we have kept the reward model same across all the experiments). Experience is stored in a prioritized, N-step replay buffer (our learning from experiment-2. We keep the good things.), that also keeps the LSTM's hidden states. During training, we update actor, twin critics and the temperature (how creatively the car is runninng on the road) via soft updates. The ``Reward`` we got was a significant jump from ``791.11`` to ``1514.86``. This reward satisfied us and we stopped asking the question to ourselves. CAN WE DO BETTER?

### Chain-of-Thought Module

The `CoTModule` is a neural network module that implements an LSTM cell to maintain a hidden memory state across time steps. This module is used to augment both the actor and critic networks, enabling them to carry temporal context (chain-of-thought) during decision making.

In [10]:
# ---------- Chain-of-Thought Module ----------
class CoTModule(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.lstm = nn.LSTMCell(input_dim, hidden_dim)

    def forward(self, x, hx, cx):
        # x: [B, input_dim], hx, cx: [B, hidden_dim]
        return self.lstm(x, (hx, cx))

### Prioritized Replay Buffer Module

The **Prioritized N-Step Replay Buffer with CoT States** maintains a fixed-capacity circular buffer of transitions augmented with LSTM-style hidden states (`hx`, `cx`) and supports n-step return computation by temporarily storing the last _n_ steps, then aggregating discounted rewards and final next-state information when enough steps accumulate or on terminal; it samples according to priorities^α, applies importance-sampling weights annealed from β₀ to 1.0 over a schedule, and lets you update priorities based on absolute TD errors plus a small ε to ensure nonzero probability.

In [4]:
# ---------- Prioritized N-Step Replay Buffer with CoT States ----------
class PrioritizedReplayBuffer:
    def __init__(self, capacity, state_shape, action_shape, cot_dim, device,
                 alpha=0.6, beta_start=0.4, beta_frames=100000,
                 n_step=3, gamma=0.99):
        self.capacity = capacity
        self.device = device
        self.alpha = alpha
        self.beta_start = beta_start
        self.beta_frames = beta_frames
        self.frame_idx = 1
        self.n_step = n_step
        self.gamma = gamma
        self.position = 0
        self.size = 0
        self.n_step_buffer = []

        # storage
        self.states = torch.zeros((capacity, *state_shape), device=device)
        self.actions = torch.zeros((capacity, *action_shape), device=device)
        self.rewards = torch.zeros((capacity,), device=device)
        self.next_states = torch.zeros((capacity, *state_shape), device=device)
        self.dones = torch.zeros((capacity,), device=device)
        # CoT hidden states hx, cx
        self.hxs = torch.zeros((capacity, cot_dim), device=device)
        self.cxs = torch.zeros((capacity, cot_dim), device=device)
        self.next_hxs = torch.zeros((capacity, cot_dim), device=device)
        self.next_cxs = torch.zeros((capacity, cot_dim), device=device)
        self.priorities = np.zeros((capacity,), dtype=np.float32)

    def beta_by_frame(self):
        return min(1.0, self.beta_start + (1.0 - self.beta_start) * self.frame_idx / self.beta_frames)

    def add(self, state, action, reward, next_state, done, hx, cx, next_hx, next_cx):
        self.n_step_buffer.append((state, action, reward, next_state, done, hx, cx, next_hx, next_cx))
        if len(self.n_step_buffer) < self.n_step and not done:
            return
        # compute N-step return
        cum_reward, f_next_state, f_done, f_next_hx, f_next_cx = 0, None, False, None, None
        for idx, (_, _, r, ns, d, _, _, nhx, ncx) in enumerate(self.n_step_buffer):
            cum_reward += (self.gamma**idx) * r
            f_next_state, f_done, f_next_hx, f_next_cx = ns, d, nhx, ncx
            if d:
                break
        s, a, _, _, _, h0, c0, _, _ = self.n_step_buffer[0]
        pos = self.position
        self.states[pos] = torch.tensor(s, device=self.device)
        self.actions[pos] = torch.tensor(a, device=self.device)
        self.rewards[pos] = cum_reward
        self.next_states[pos] = torch.tensor(f_next_state, device=self.device)
        self.dones[pos] = f_done
        self.hxs[pos] = hx
        self.cxs[pos] = cx
        self.next_hxs[pos] = f_next_hx
        self.next_cxs[pos] = f_next_cx
        self.priorities[pos] = self.priorities.max() if self.size>0 else 1.0
        self.position = (pos+1) % self.capacity
        self.size = min(self.size+1, self.capacity)
        self.n_step_buffer.pop(0)
        if done:
            self.n_step_buffer.clear()

    def sample(self, batch_size):
        assert self.size>0
        self.frame_idx +=1
        prios = self.priorities[:self.size] + 1e-6
        probs = prios**self.alpha
        probs /= probs.sum()
        idxs = np.random.choice(self.size, batch_size, p=probs)
        beta = self.beta_by_frame()
        weights = (self.size * probs[idxs])**(-beta)
        weights /= weights.max()
        weights = torch.tensor(weights, device=self.device)

        batch = dict(
            s=self.states[idxs],
            a=self.actions[idxs],
            r=self.rewards[idxs].unsqueeze(1),
            s2=self.next_states[idxs],
            d=self.dones[idxs].unsqueeze(1).float(),
            hx=self.hxs[idxs],
            cx=self.cxs[idxs],
            hx2=self.next_hxs[idxs],
            cx2=self.next_cxs[idxs],
            w=weights.unsqueeze(1),
            idxs=idxs
        )
        return batch

    def update_priorities(self, idxs, td_errors):
        for i, td in zip(idxs, td_errors):
            self.priorities[i] = abs(td) + 1e-6

### Soft Actor Critic

The following code defines the neural network architectures for the Actor and Critic used in the SAC agent with Chain-of-Thought (CoT) LSTM modules. Both networks use convolutional layers for feature extraction, followed by fully connected layers and an LSTM cell to maintain temporal context. The Actor outputs the mean and log standard deviation for a Gaussian policy, while the Critic estimates the Q-value for state-action pairs.

- **Actor**: Processes the state through convolutional and fully connected layers, then passes the result through a CoT LSTM. The output is used to parameterize a Gaussian distribution, from which actions are sampled and mapped to the environment's action space.
- **Critic**: Processes the state and action together through convolutional and fully connected layers, then passes the result through a CoT LSTM to estimate the Q-value.

Both classes provide methods to initialize the LSTM hidden states and to perform forward passes with temporal memory.

In [5]:
# ---------- Networks ----------
class Actor(nn.Module):
    def __init__(self, state_shape, action_dim, cot_dim=256):
        super().__init__()
        c,h,w = state_shape
        self.conv = nn.Sequential(
            nn.Conv2d(c,32,8,4), nn.ReLU(),
            nn.Conv2d(32,64,4,2), nn.ReLU(),
            nn.Conv2d(64,64,3,1), nn.ReLU(),
            nn.Flatten())
        def conv_out(sz,k,s): return (sz-(k-1)-1)//s+1
        convw = conv_out(conv_out(conv_out(w,8,4),4,2),3,1)
        convh = conv_out(conv_out(conv_out(h,8,4),4,2),3,1)
        lin = convw*convh*64
        self.fc_enc = nn.Sequential(nn.Linear(lin,512), nn.ReLU())
        self.cot = CoTModule(512, cot_dim)
        self.fc_head = nn.Sequential(nn.Linear(cot_dim,512), nn.ReLU())
        self.mean = nn.Linear(512, action_dim)
        self.log_std = nn.Linear(512, action_dim)

    def init_hidden(self, bsz, device):
        return (torch.zeros(bsz, self.cot.lstm.hidden_size, device=device),
                torch.zeros(bsz, self.cot.lstm.hidden_size, device=device))

    def forward(self, x, hx, cx):
        feat = self.conv(x/255.0)
        enc = self.fc_enc(feat)
        hx, cx = self.cot(enc, hx, cx)
        h = self.fc_head(hx)
        mean = self.mean(h)
        log_std = self.log_std(h).clamp(-20,2)
        return mean, log_std, hx, cx

    def sample(self, x, hx, cx):
        mean, log_std, hx, cx = self.forward(x, hx, cx)
        std = torch.exp(log_std)
        dist = Normal(mean, std)
        z = dist.rsample()
        raw = torch.tanh(z)
        steer = raw[:,0:1]
        gas   = (raw[:,1:2]+1)/2
        brake = (raw[:,2:3]+1)/2
        action = torch.cat([steer,gas,brake],1)
        logp = dist.log_prob(z) - torch.log(1-raw.pow(2)+1e-6)
        return action, logp.sum(1,True), hx, cx

class Critic(nn.Module):
    def __init__(self, state_shape, action_dim, cot_dim=256):
        super().__init__()
        c,h,w = state_shape
        self.conv = nn.Sequential(
            nn.Conv2d(c,32,8,4), nn.ReLU(),
            nn.Conv2d(32,64,4,2), nn.ReLU(),
            nn.Conv2d(64,64,3,1), nn.ReLU(),
            nn.Flatten())
        def conv_out(sz,k,s): return (sz-(k-1)-1)//s+1
        convw = conv_out(conv_out(conv_out(w,8,4),4,2),3,1)
        convh = conv_out(conv_out(conv_out(h,8,4),4,2),3,1)
        lin = convw*convh*64 + action_dim
        self.fc_enc = nn.Sequential(nn.Linear(lin,512), nn.ReLU())
        self.cot = CoTModule(512, cot_dim)
        self.fc_out = nn.Linear(cot_dim,1)

    def init_hidden(self, bsz, device):
        return (torch.zeros(bsz, self.cot.lstm.hidden_size, device=device),
                torch.zeros(bsz, self.cot.lstm.hidden_size, device=device))

    def forward(self, x, a, hx, cx):
        feat = self.conv(x/255.0)
        enc = torch.cat([feat, a],1)
        enc = self.fc_enc(enc)
        hx, cx = self.cot(enc, hx, cx)
        q = self.fc_out(hx)
        return q, hx, cx

### Reward Shaping and Soft Updates

The `ShapedReward` class is a custom Gym `RewardWrapper` that augments the environment's reward signal with additional terms to encourage forward velocity, penalize driving on grass, and discourage excessive rotation. This helps guide the agent toward more desirable behaviors.

The `soft_update` and `hard_update` functions are utility methods for updating the target networks in soft actor-critic (SAC) algorithms. `soft_update` performs a weighted update of the target network parameters, while `hard_update` copies the parameters directly.

In [6]:
# ---------- Reward Shaping ----------
class ShapedReward(RewardWrapper):
    def __init__(self, env): super().__init__(env)
    def reward(self, r):
        un = self.unwrapped
        vel = un.car.hull.linearVelocity
        ang_v = un.car.hull.angularVelocity
        angle = un.car.hull.angle
        forward = vel.x*math.cos(angle) + vel.y*math.sin(angle)
        frame = un.render()
        h,w,_=frame.shape; cx,cy=w//2,h//2
        r_pix,g_pix,b_pix=frame[cy,cx]
        grass_penalty = -0.2 if (g_pix>150 and g_pix>r_pix+30 and g_pix>b_pix+30) else 0.0
        rot_penalty = -0.05*abs(ang_v)
        shaped = 0.1*max(0,forward) + grass_penalty + rot_penalty
        return r + shaped

# ---------- Soft Updates ----------
def soft_update(tgt, src, tau):
    for t,s in zip(tgt.parameters(), src.parameters()):
        t.data.copy_(t.data*(1-tau) + s.data*tau)
def hard_update(tgt, src): tgt.load_state_dict(src.state_dict())


### Our Training Regime

Here we are checking if the reward has improved after every 10 episodes. If the reward has improved, we are saving the model. If not, then we continue traning the model for 2000 episodes.

In [None]:
# ---------- Training ----------
def train():
    # Hyperparams
    env_name="CarRacing-v2"; max_episodes=2000; max_steps=1000
    batch_size=64; gamma=0.99; tau=0.005; lr=2e-4
    buffer_size=200000; n_step=3; cot_dim=256

    env = gym.make(env_name, render_mode="rgb_array")
    env = ShapedReward(env)
    env = GrayScaleObservation(env)
    env = ResizeObservation(env,(84,84))
    env = FrameStack(env,4)
    state_shape=(4,84,84); action_dim=env.action_space.shape[0]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # models
    actor = Actor(state_shape, action_dim, cot_dim).to(device)
    critic1 = Critic(state_shape, action_dim, cot_dim).to(device)
    critic2 = Critic(state_shape, action_dim, cot_dim).to(device)
    c1_tgt = Critic(state_shape, action_dim, cot_dim).to(device)
    c2_tgt = Critic(state_shape, action_dim, cot_dim).to(device)
    hard_update(c1_tgt, critic1); hard_update(c2_tgt, critic2)

    actor_opt = optim.Adam(actor.parameters(), lr=lr)
    c1_opt = optim.Adam(critic1.parameters(), lr=lr)
    c2_opt = optim.Adam(critic2.parameters(), lr=lr)
    log_alpha = torch.zeros(1, requires_grad=True, device=device)
    alpha_opt = optim.Adam([log_alpha], lr=lr)
    target_entropy = -action_dim

    buffer = PrioritizedReplayBuffer(buffer_size, state_shape, (action_dim,), cot_dim, device,
                                     alpha=0.6, n_step=n_step, gamma=gamma)

    best_avg=-1e9; history=[]
    for ep in range(1, max_episodes+1):
        obs = env.reset() if not isinstance(env.reset(), tuple) else env.reset()[0]
        state = np.array(obs)
        if state.ndim==4 and state.shape[-1]==1: state=state.squeeze(-1)
        # init CoT
        hx, cx = actor.init_hidden(1, device)
        ep_r=0
        for t in range(max_steps):
            st = torch.tensor(state, device=device).unsqueeze(0)
            with torch.no_grad():
                a, lp, nhx, ncx = actor.sample(st, hx, cx)
            action = a.cpu().numpy()[0]
            # step
            total_r=0; done=False
            for _ in range(4):
                obs2, r, done_flag, trunc, _=env.step(action.tolist())
                total_r+=r
                if done_flag or trunc: done=True; break
            nxt = np.array(obs2)
            if nxt.ndim==4 and nxt.shape[-1]==1: nxt=nxt.squeeze(-1)

            # store with CoT states
            buffer.add(state, action, total_r, nxt, done,
                       hx.squeeze(0), cx.squeeze(0), nhx.squeeze(0), ncx.squeeze(0))
            state, hx, cx = nxt, nhx, ncx
            ep_r+=total_r

            if buffer.size>=batch_size:
                batch = buffer.sample(batch_size)
                # critic update
                with torch.no_grad():
                    a2, lp2, hx2, cx2 = actor.sample(batch['s2'], batch['hx2'], batch['cx2'])
                    q1_t,_,_ = c1_tgt(batch['s2'], a2, batch['hx2'], batch['cx2'])
                    q2_t,_,_ = c2_tgt(batch['s2'], a2, batch['hx2'], batch['cx2'])
                    q_min = torch.min(q1_t, q2_t) - log_alpha.exp()*lp2
                    target_q = batch['r'] + (1-batch['d'])*(gamma**n_step)*q_min

                q1, _, _ = critic1(batch['s'], batch['a'], batch['hx'], batch['cx'])
                q2, _, _ = critic2(batch['s'], batch['a'], batch['hx'], batch['cx'])
                td1 = q1 - target_q; td2 = q2 - target_q
                loss_c1 = (batch['w'] * td1.pow(2)).mean()
                loss_c2 = (batch['w'] * td2.pow(2)).mean()
                c1_opt.zero_grad(); loss_c1.backward(); c1_opt.step()
                c2_opt.zero_grad(); loss_c2.backward(); c2_opt.step()
                buffer.update_priorities(batch['idxs'], ((td1+td2)/2).abs().detach().cpu().numpy())
                # actor update
                a_new, lp_new, _, _ = actor.sample(batch['s'], batch['hx'], batch['cx'])
                q1_pi,_,_ = critic1(batch['s'], a_new, batch['hx'], batch['cx'])
                q2_pi,_,_ = critic2(batch['s'], a_new, batch['hx'], batch['cx'])
                q_pi = torch.min(q1_pi, q2_pi)
                alpha = log_alpha.exp()
                loss_pi = (alpha * lp_new - q_pi).mean()
                actor_opt.zero_grad(); loss_pi.backward(); actor_opt.step()
                # alpha update
                loss_a = -(log_alpha * (lp_new + target_entropy).detach()).mean()
                alpha_opt.zero_grad(); loss_a.backward(); alpha_opt.step()
                soft_update(c1_tgt, critic1, tau); soft_update(c2_tgt, critic2, tau)

            if done: break

        history.append(ep_r)
        if ep%10==0:
            avg10 = sum(history[-10:])/len(history[-10:])
            print(f"Ep {ep:4d} avg10 {avg10:7.2f} best {best_avg:7.2f}")
            if avg10>best_avg:
                best_avg=avg10; os.makedirs("ckpt", exist_ok=True)
                torch.save(actor.state_dict(), "ckpt/actor.pth")
                print("Saved best model")
    env.close()

if __name__=='__main__':
    train()


  state = np.array(obs)
  nxt = np.array(obs2)
  self.priorities[i] = abs(td) + 1e-6


Ep   10 avg10  -68.86 best -1000000000.00
Saved best model
Ep   20 avg10  -75.76 best  -68.86


  grass_penalty = -0.2 if (g_pix>150 and g_pix>r_pix+30 and g_pix>b_pix+30) else 0.0


Ep   30 avg10  -78.20 best  -68.86
Ep   40 avg10  -69.11 best  -68.86
Ep   50 avg10  -68.77 best  -68.86
Saved best model
Ep   60 avg10  -70.17 best  -68.77
Ep   70 avg10  -65.08 best  -68.77
Saved best model
Ep   80 avg10  -85.28 best  -65.08
Ep   90 avg10  -72.62 best  -65.08
Ep  100 avg10  -51.50 best  -65.08
Saved best model
Ep  110 avg10  -61.70 best  -51.50
Ep  120 avg10  -75.45 best  -51.50
Ep  130 avg10 -111.79 best  -51.50
Ep  140 avg10 -132.76 best  -51.50
Ep  150 avg10  -66.81 best  -51.50
Ep  160 avg10  -17.26 best  -51.50
Saved best model
Ep  170 avg10   72.90 best  -17.26
Saved best model
Ep  180 avg10  193.02 best   72.90
Saved best model
Ep  190 avg10  298.05 best  193.02
Saved best model
Ep  200 avg10  297.15 best  298.05
Ep  210 avg10  271.07 best  298.05
Ep  220 avg10  273.97 best  298.05
Ep  230 avg10  247.29 best  298.05
Ep  240 avg10  244.15 best  298.05
Ep  250 avg10  244.99 best  298.05
Ep  260 avg10  241.78 best  298.05
Ep  270 avg10  232.42 best  298.05
Ep  28