# CPT Quantile DDPG Agent Implementation

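This notebook implements a DDPG agent with a distributional (quantile-regression) critic, trained on rewards that are first passed through the Cumulative Prospect Theory (CPT) value function. It includes the actor and quantile critic networks, the quantile regression loss, the agent's update logic, and a minimal replay buffer. A training loop on random dummy data is provided for smoke-testing.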

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

## --- CPT Transformation Function ---
def cpt_transform_tensor(rewards, alpha=0.88, beta=0.88, lambda_=2.25):
    """Apply the prospect-theory value function element-wise.

    Non-negative rewards ``r`` map to ``r ** alpha``; negative rewards map to
    ``-lambda_ * (-r) ** beta``. The input is cast to float first.

    Args:
        rewards: Tensor of rewards (any shape, any numeric dtype).
        alpha: Curvature exponent applied to gains.
        beta: Curvature exponent applied to losses.
        lambda_: Loss-aversion multiplier applied to losses.

    Returns:
        Float tensor with the same shape as ``rewards``.
    """
    r = rewards.float()
    # Each branch sees only its own sign; the other side is clamped to zero.
    gains = r.clamp(min=0).pow(alpha)
    losses = r.clamp(max=0).neg().pow(beta).mul(lambda_)
    return gains - losses

## --- Actor Network ---
class Actor(nn.Module):
    """Deterministic policy network mapping states to actions.

    Two ReLU hidden layers followed by a ``tanh`` output layer, so every
    action component lies in ``[-1, 1]``.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        layers = [
            nn.Linear(state_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, action_dim), nn.Tanh(),
        ]
        # A single ``net`` Sequential keeps state_dict keys as net.0 / net.2 / net.4.
        self.net = nn.Sequential(*layers)

    def forward(self, state):
        """Return the action for ``state`` (last dimension must be ``state_dim``)."""
        action = self.net(state)
        return action

## --- Quantile Critic Network ---
class QuantileCritic(nn.Module):
    """Critic that outputs ``num_quantiles`` values per (state, action) pair.

    The concatenated state and action pass through two ReLU hidden layers,
    then a linear head with one output per quantile.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256, num_quantiles=50):
        super().__init__()
        self.num_quantiles = num_quantiles
        trunk = [
            nn.Linear(state_dim + action_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        ]
        self.net = nn.Sequential(*trunk)
        self.quantile_layer = nn.Linear(hidden_dim, num_quantiles)

    def forward(self, state, action):
        """Return quantile values of shape (..., num_quantiles)."""
        features = self.net(torch.cat((state, action), dim=-1))
        return self.quantile_layer(features)

## --- Quantile Regression Loss ---
def huber(x, kappa=1.0):
    """Element-wise Huber loss with threshold ``kappa``.

    Returns ``0.5 * x**2`` where ``|x| <= kappa`` and
    ``kappa * (|x| - 0.5 * kappa)`` elsewhere, with the same shape as ``x``.
    """
    abs_x = x.abs()
    # Float masks select the quadratic or the linear branch per element.
    cond = (abs_x <= kappa).float()
    loss = 0.5 * (x ** 2) * cond + kappa * (abs_x - 0.5 * kappa) * (1 - cond)
    return loss


def quantile_regression_loss(predicted, target, taus, kappa=1.0):
    """Quantile-Huber regression loss between predicted quantiles and target samples.

    Every predicted quantile ``predicted[:, i]`` is regressed toward every
    target sample ``target[:, j]``. Each pair uses the asymmetric weight
    ``|taus[i] - 1{target[:, j] - predicted[:, i] < 0}|`` multiplied by
    ``huber(target[:, j] - predicted[:, i], kappa)``.

    Args:
        predicted: (batch, num_pred) predicted quantile values.
        target: (batch, num_target) target samples. ``num_target`` may differ
            from ``num_pred``. The target is detached, so no gradient flows into it.
        taus: (num_pred,) quantile fractions, one per predicted column.
        kappa: Huber threshold.

    Returns:
        Scalar tensor: the mean weighted loss over the batch, target and
        predicted axes.
    """
    num_pred = predicted.shape[-1]
    # Targets are fixed regression labels.
    target = target.detach()
    # error[b, j, i] = target[b, j] - predicted[b, i] -> (batch, num_target, num_pred)
    error = target.unsqueeze(2) - predicted.unsqueeze(1)
    # Each tau belongs to a *predicted* quantile, so it must broadcast along the
    # last axis. Aligning it with the target axis would give every predicted
    # column the same objective, so the columns could not learn distinct quantiles.
    taus = taus.to(device=predicted.device, dtype=predicted.dtype).view(1, 1, num_pred)
    indicator = (error < 0).to(predicted.dtype)
    loss = torch.abs(taus - indicator) * huber(error, kappa)
    return loss.mean()

## --- CPT-Quantile-DDPG Agent ---
class CPTQuantileDDPG:
    """DDPG agent with a quantile critic trained on CPT-transformed rewards.

    Sampled rewards are passed through ``cpt_transform_tensor`` before the
    distributional Bellman backup. The critic is fit with
    ``quantile_regression_loss``, and the actor maximizes the mean of the
    critic's predicted quantiles.

    Args:
        state_dim: Dimension of the state vector.
        action_dim: Dimension of the action vector.
        actor_lr: Adam learning rate for the actor.
        critic_lr: Adam learning rate for the critic.
        gamma: Discount factor.
        tau: Polyak coefficient for the soft target-network updates.
        num_quantiles: Number of quantiles output by the critic.
    """

    def __init__(self, state_dim, action_dim,
                 actor_lr=1e-3, critic_lr=1e-3,
                 gamma=0.99, tau=0.005, num_quantiles=50):
        self.gamma = gamma
        self.tau = tau
        self.num_quantiles = num_quantiles

        self.actor = Actor(state_dim, action_dim)
        self.actor_target = Actor(state_dim, action_dim)
        self.actor_target.load_state_dict(self.actor.state_dict())

        self.critic = QuantileCritic(state_dim, action_dim, num_quantiles=num_quantiles)
        self.critic_target = QuantileCritic(state_dim, action_dim, num_quantiles=num_quantiles)
        self.critic_target.load_state_dict(self.critic.state_dict())

        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr)

        # Midpoint quantile fractions tau_i = (2i + 1) / (2N), i = 0..N-1.
        self.taus = torch.linspace(0.0, 1.0, num_quantiles + 1)[1:] - 0.5 / num_quantiles

    def select_action(self, state):
        """Return the actor's action for a single unbatched state as a NumPy array."""
        # Inference only: don't build an autograd graph.
        with torch.no_grad():
            action = self.actor(torch.FloatTensor(state).unsqueeze(0))
        return action.cpu().numpy()[0]

    @staticmethod
    def _bellman_target(reward, done, next_quantiles, gamma):
        """Return ``reward + (1 - done) * gamma * next_quantiles`` per quantile.

        ``reward`` and ``done`` may have shape (batch,), (batch, 1), or () for a
        batch of one. They are reshaped to a (batch, 1) column so they broadcast
        over the quantile axis of ``next_quantiles`` (batch, num_quantiles). A
        (batch,) ``done`` would otherwise broadcast against the quantile axis.
        """
        reward = reward.reshape(-1, 1).to(next_quantiles.dtype)
        not_done = 1.0 - done.reshape(-1, 1).to(next_quantiles.dtype)
        return reward + not_done * gamma * next_quantiles

    @staticmethod
    def _soft_update(source, target, tau):
        """Polyak-average ``source`` parameters into ``target`` in place."""
        with torch.no_grad():
            for param, target_param in zip(source.parameters(), target.parameters()):
                target_param.copy_(tau * param + (1 - tau) * target_param)

    def update(self, replay_buffer, batch_size=64, kappa=1.0):
        """Run one critic step and one actor step on a sampled minibatch, then soft-update targets.

        Args:
            replay_buffer: Object whose ``sample(batch_size)`` returns
                ``(state, action, reward, next_state, done)`` tensors.
            batch_size: Number of transitions to sample.
            kappa: Huber threshold for the quantile regression loss.

        Returns:
            ``(critic_loss, actor_loss)`` as Python floats.
        """
        state, action, reward, next_state, done = replay_buffer.sample(batch_size)
        transformed_reward = cpt_transform_tensor(reward)
        current_quantiles = self.critic(state, action)

        with torch.no_grad():
            next_action = self.actor_target(next_state)
            next_quantiles = self.critic_target(next_state, next_action)
            target_quantiles = self._bellman_target(
                transformed_reward, done, next_quantiles, self.gamma
            )

        critic_loss = quantile_regression_loss(current_quantiles, target_quantiles, self.taus, kappa=kappa)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # This backward also writes gradients into the critic's parameters.
        # They are cleared by critic_optimizer.zero_grad() before the next critic step.
        actor_loss = -self.critic(state, self.actor(state)).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self._soft_update(self.critic, self.critic_target, self.tau)
        self._soft_update(self.actor, self.actor_target, self.tau)

        return critic_loss.item(), actor_loss.item()

## --- Minimal Replay Buffer ---
class ReplayBuffer:
    """Fixed-capacity ring buffer of transitions, sampled uniformly with replacement.

    Stored fields are expected to be tensors. ``sample`` stacks them along a
    new leading batch dimension.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0  # next slot to write; wraps to 0 after reaching capacity

    def add(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest one once the buffer is full."""
        transition = (state, action, reward, next_state, done)
        if len(self.buffer) < self.capacity:
            # While filling, ``position`` always equals ``len(self.buffer)``.
            self.buffer.append(transition)
        else:
            self.buffer[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Sample ``batch_size`` transitions uniformly with replacement.

        Returns:
            ``(state, action, reward, next_state, done)``. States and actions
            are stacked to (batch, ...). ``reward`` and ``done`` are returned
            as (batch,) when each stored value is a scalar or 1-element tensor.
            This holds even for ``batch_size == 1``, where the old ``squeeze()``
            collapsed them to 0-d.
        """
        indices = torch.randint(0, len(self.buffer), (batch_size,)).tolist()
        batch = [self.buffer[i] for i in indices]
        state, action, reward, next_state, done = zip(*batch)
        return (
            torch.stack(state),
            torch.stack(action),
            self._stack_scalars(reward),
            torch.stack(next_state),
            self._stack_scalars(done),
        )

    @staticmethod
    def _stack_scalars(values):
        """Stack per-transition values, dropping a trailing size-1 axis but never the batch axis."""
        stacked = torch.stack(values)
        return stacked.reshape(len(values), -1).squeeze(-1)

    def size(self):
        """Return the number of stored transitions."""
        return len(self.buffer)


In [None]:
## --- Example Training Loop for Dummy Environment ---
if __name__ == '__main__':
    # Smoke test on random data: random states and N(0, 1) rewards, 10 x 100 steps.
    state_dim, action_dim = 10, 2
    agent = CPTQuantileDDPG(state_dim, action_dim, num_quantiles=50)
    replay_buffer = ReplayBuffer(capacity=100000)

    num_episodes = 10
    steps_per_episode = 100
    last_step = steps_per_episode - 1

    for episode in range(num_episodes):
        state = torch.randn(state_dim)
        for t in range(steps_per_episode):
            action = agent.select_action(state.numpy())
            next_state = torch.randn(state_dim)
            reward = torch.tensor([float(torch.randn(1))])
            # Only the final step of each episode is terminal.
            done = torch.tensor([1.0 if t == last_step else 0.0])

            replay_buffer.add(state, torch.tensor(action), reward, next_state, done)
            state = next_state

            # Wait until there are enough transitions for one full minibatch.
            if replay_buffer.size() < 64:
                continue
            critic_loss, actor_loss = agent.update(replay_buffer, batch_size=64)
            print(f"Episode {episode} Step {t}: Critic Loss = {critic_loss:.4f}, Actor Loss = {actor_loss:.4f}")
