In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import json
import os

class CustomDataset(Dataset):
    def __init__(self, data_path):
        self.data = []
        # 读取json文件
        for file in os.listdir(data_path):
            if file.endswith('.json'):
                with open(os.path.join(data_path, file), 'r') as f:
                    self.data.extend(json.load(f))
        
        # 提取需要的字段
        self.observations = np.array([item['obs'] for item in self.data])
        self.actions = np.array([item['actions'] for item in self.data])
        self.rewards = np.array([item['rewards'] for item in self.data])
        self.next_observations = np.array([item['new_obs'] for item in self.data])
        self.dones = np.array([item['dones'] for item in self.data])
        self.action_probs = np.array([item['action_prob'] for item in self.data])
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return {
            'obs': torch.FloatTensor(self.observations[idx]),
            'action': torch.FloatTensor(self.actions[idx]),
            'reward': torch.FloatTensor([self.rewards[idx]]),
            'next_obs': torch.FloatTensor(self.next_observations[idx]),
            'done': torch.FloatTensor([self.dones[idx]]),
            'action_prob': torch.FloatTensor([self.action_probs[idx]])
        }
    

CustomDataset('exam\data_collect')

JSONDecodeError: Extra data: line 2 column 1 (char 30981)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import numpy as np

class Actor(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, action_dim)
        
    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        return torch.tanh(self.fc3(x))

class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)
        
    def forward(self, state, action):
        x = torch.cat([state, action], dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

class CQL:
    def __init__(self, state_dim, action_dim, device='cuda'):
        self.actor = Actor(state_dim, action_dim).to(device)
        self.critic1 = Critic(state_dim, action_dim).to(device)
        self.critic2 = Critic(state_dim, action_dim).to(device)
        
        self.actor_optimizer = Adam(self.actor.parameters(), lr=0.0001)
        self.critic1_optimizer = Adam(self.critic1.parameters(), lr=0.0003)
        self.critic2_optimizer = Adam(self.critic2.parameters(), lr=0.0003)
        
        self.device = device
        self.temperature = 10.0
        self.min_q_weight = 0.01
        self.num_actions = 10
        
    def train_step(self, batch):
        obs = batch['obs'].to(self.device)
        actions = batch['action'].to(self.device)
        rewards = batch['reward'].to(self.device)
        next_obs = batch['next_obs'].to(self.device)
        dones = batch['done'].to(self.device)
        action_probs = batch['action_prob'].to(self.device)
        
        # 计算当前Q值
        current_q1 = self.critic1(obs, actions)
        current_q2 = self.critic2(obs, actions)
        current_q = torch.min(current_q1, current_q2)
        
        # 采样动作计算Q值
        with torch.no_grad():
            sampled_actions = torch.randn_like(actions).repeat(self.num_actions, 1)
            sampled_obs = obs.repeat(self.num_actions, 1)
            sampled_q1 = self.critic1(sampled_obs, sampled_actions)
            sampled_q2 = self.critic2(sampled_obs, sampled_actions)
            sampled_q = torch.min(sampled_q1, sampled_q2)
            sampled_q = sampled_q.view(self.num_actions, -1).mean(0, keepdim=True)
        
        # 计算保守Q值
        conservative_q = current_q - self.temperature * (current_q - sampled_q)
        
        # 计算目标Q值
        with torch.no_grad():
            next_actions = self.actor(next_obs)
            next_q1 = self.critic1(next_obs, next_actions)
            next_q2 = self.critic2(next_obs, next_actions)
            next_q = torch.min(next_q1, next_q2)
            target_q = rewards + (1 - dones) * 0.99 * next_q
        
        # 计算损失
        critic1_loss = F.mse_loss(current_q1, target_q)
        critic2_loss = F.mse_loss(current_q2, target_q)
        conservative_loss = F.mse_loss(conservative_q, target_q)
        
        # 更新critic
        self.critic1_optimizer.zero_grad()
        (critic1_loss + self.min_q_weight * conservative_loss).backward()
        self.critic1_optimizer.step()
        
        self.critic2_optimizer.zero_grad()
        (critic2_loss + self.min_q_weight * conservative_loss).backward()
        self.critic2_optimizer.step()
        
        # 更新actor
        actor_loss = -self.critic1(obs, self.actor(obs)).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        
        return {
            'critic1_loss': critic1_loss.item(),
            'critic2_loss': critic2_loss.item(),
            'actor_loss': actor_loss.item(),
            'conservative_loss': conservative_loss.item()
        }

In [None]:
def train():
    # 加载数据
    dataset = CustomDataset("collect/sample_save_folder/PPO/2025-04-01_09-45")
    dataloader = DataLoader(dataset, batch_size=256, shuffle=True)
    
    # 初始化CQL
    cql = CQL(state_dim=20, action_dim=4)
    
    # 训练循环
    for epoch in range(1000):
        epoch_losses = []
        for batch in dataloader:
            losses = cql.train_step(batch)
            epoch_losses.append(losses)
        
        # 打印训练信息
        avg_losses = {k: np.mean([l[k] for l in epoch_losses]) for k in epoch_losses[0].keys()}
        print(f"Epoch {epoch}: {avg_losses}")
        
        # 保存模型
        if epoch % 50 == 0:
            torch.save(cql.actor.state_dict(), f"save_model/{timestamp}/actor_{epoch}.pth")
            torch.save(cql.critic1.state_dict(), f"save_model/{timestamp}/critic1_{epoch}.pth")
            torch.save(cql.critic2.state_dict(), f"save_model/{timestamp}/critic2_{epoch}.pth")

if __name__ == "__main__":
    train()