In [1]:
import numpy as np
import gym
from gym import spaces

class GridEnv(gym.Env):
    def __init__(self):
        super(GridEnv, self).__init__()
        self.grid_size = 10
        self.action_space = spaces.Discrete(4)  # Actions: 0=Up, 1=Right, 2=Down, 3=Left
        self.observation_space = spaces.Box(low=0, high=self.grid_size - 1, shape=(2,), dtype=np.int32)
        self.state = np.zeros(2, dtype=np.int32)
        self.goal = np.array([self.grid_size - 1, self.grid_size - 1], dtype=np.int32)
        
    def reset(self):
        self.state = np.array([0, 0], dtype=np.int32)
        return self.state
    
    def step(self, action):
        if action == 0 and self.state[0] > 0:  # Up
            self.state[0] -= 1
        elif action == 1 and self.state[1] < self.grid_size - 1:  # Right
            self.state[1] += 1
        elif action == 2 and self.state[0] < self.grid_size - 1:  # Down
            self.state[0] += 1
        elif action == 3 and self.state[1] > 0:  # Left
            self.state[1] -= 1
        
        done = np.array_equal(self.state, self.goal)
        reward = 1.0 if done else -0.1
        
        return self.state, reward, done, {}


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PolicyNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc_mean = nn.Linear(256, action_dim)
        self.fc_log_std = nn.Linear(256, action_dim)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        mean = self.fc_mean(x)
        log_std = self.fc_log_std(x)
        return mean, log_std


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc_q = nn.Linear(256, 1)

    def forward(self, state, action):
        print(f'state {state}')
        print(f'action {action}')
        x = torch.cat([state, action.unsqueeze(0)], dim=-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        q_value = self.fc_q(x)
        return q_value


In [4]:
class ValueNetwork(nn.Module):
    def __init__(self, state_dim):
        super(ValueNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc_value = nn.Linear(256, 1)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        value = self.fc_value(x)
        return value


In [5]:
class SACAgent:
    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99, tau=0.005, alpha=0.2):
        self.actor = PolicyNetwork(state_dim, action_dim)
        self.actor_target = PolicyNetwork(state_dim, action_dim)
        self.q1 = QNetwork(state_dim, action_dim)
        self.q2 = QNetwork(state_dim, action_dim)
        self.q1_target = QNetwork(state_dim, action_dim)
        self.q2_target = QNetwork(state_dim, action_dim)
        self.value = ValueNetwork(state_dim)
        self.value_target = ValueNetwork(state_dim)
        
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.q1_optimizer = torch.optim.Adam(self.q1.parameters(), lr=lr)
        self.q2_optimizer = torch.optim.Adam(self.q2.parameters(), lr=lr)
        self.value_optimizer = torch.optim.Adam(self.value.parameters(), lr=lr)
        
        self.gamma = gamma
        self.tau = tau
        self.alpha = alpha
        
        self._update_target_networks()

    def _update_target_networks(self):
        for target_param, param in zip(self.actor_target.parameters(), self.actor.parameters()):
            target_param.data.copy_(param.data)
        for target_param, param in zip(self.q1_target.parameters(), self.q1.parameters()):
            target_param.data.copy_(param.data)
        for target_param, param in zip(self.q2_target.parameters(), self.q2.parameters()):
            target_param.data.copy_(param.data)
        for target_param, param in zip(self.value_target.parameters(), self.value.parameters()):
            target_param.data.copy_(param.data)
            
    def select_action(self, state):
        state = torch.FloatTensor(state).unsqueeze(0)
        mean, log_std = self.actor(state)
        std = torch.exp(log_std)
        dist = torch.distributions.Normal(mean, std)
        action = dist.sample()
        action = torch.argmax(action)
        return action.squeeze(0).detach().numpy()
    
    def train(self, replay_buffer, batch_size=64):
        states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)
        
        states = torch.FloatTensor(states)
        actions = torch.FloatTensor(actions)
        rewards = torch.FloatTensor(rewards).unsqueeze(1)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones).unsqueeze(1)
        
        # Update Q networks
        with torch.no_grad():
            next_actions, next_log_std = self.actor_target(next_states)
            next_std = torch.exp(next_log_std)
            next_dist = torch.distributions.Normal(next_actions, next_std)
            next_actions = next_dist.rsample()
            target_q1 = self.q1_target(next_states, next_actions)
            target_q2 = self.q2_target(next_states, next_actions)
            target_value = torch.min(target_q1, target_q2) - self.alpha * next_log_std
            target = rewards + (1 - dones) * self.gamma * target_value
        
        q1_value = self.q1(states, actions)
        q2_value = self.q2(states, actions)
        
        q1_loss = F.mse_loss(q1_value, target)
        q2_loss = F.mse_loss(q2_value, target)
        
        self.q1_optimizer.zero_grad()
        q1_loss.backward()
        self.q1_optimizer.step()
        
        self.q2_optimizer.zero_grad()
        q2_loss.backward()
        self.q2_optimizer.step()
        
        # Update Value Network
        value = self.value(states)
        value_loss = F.mse_loss(value, target_value.detach())
        
        self.value_optimizer.zero_grad()
        value_loss.backward()
        self.value_optimizer.step()
        
        # Update Policy Network
        actions, log_std = self.actor(states)
        dist = torch.distributions.Normal(actions, torch.exp(log_std))
        log_prob = dist.log_prob(actions).sum(dim=-1, keepdim=True)
        q1_value = self.q1(states, actions)
        q2_value = self.q2(states, actions)
        q_value = torch.min(q1_value, q2_value)
        policy_loss = (self.alpha * log_prob - q_value).mean()
        
        self.actor_optimizer.zero_grad()
        policy_loss.backward()
        self.actor_optimizer.step()
        
        # Update Target Networks
        self._update_target_networks()



In [7]:
import random
from collections import deque

class ReplayBuffer:
    def __init__(self, max_size=100000):
        self.buffer = deque(maxlen=max_size)
        
    def add(self, transition):
        self.buffer.append(transition)
        
    def sample(self, batch_size):
        transitions = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*transitions)
        return np.array(states), np.array(actions), np.array(rewards), np.array(next_states), np.array(dones)
        
env = GridEnv()
agent = SACAgent(state_dim=2, action_dim=4)
replay_buffer = ReplayBuffer()

num_episodes = 1000
for episode in range(num_episodes):
    state = env.reset()
    done = False
    total_reward = 0
    while not done:
        action = agent.select_action(state)
        next_state, reward, done, _ = env.step(action)
        replay_buffer.add((state, action, reward, next_state, done))
        state = next_state
        total_reward += reward
        
        if len(replay_buffer.buffer) > 1000:
            agent.train(replay_buffer)
    
    print(f"Episode {episode+1}/{num_episodes}, Total Reward: {total_reward}")


Episode 1/1000, Total Reward: -18.0
Episode 2/1000, Total Reward: -3.4000000000000004
Episode 3/1000, Total Reward: -25.600000000000108
Episode 4/1000, Total Reward: -14.399999999999961
Episode 5/1000, Total Reward: -4.4999999999999964
Episode 6/1000, Total Reward: -11.199999999999973
Episode 7/1000, Total Reward: -10.099999999999977
tensor([[9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [8., 2.],
        [8., 2.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [8., 2.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [8., 2.],
        [9., 9.]

RuntimeError: Tensors must have same number of dimensions: got 2 and 3