# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Run on the CUDA device when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class A2CNetwork(nn.Module):
    def __init__(self, state_shape, action_size):
        super(A2CNetwork, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * state_shape[0] * state_shape[1], 128)
        self.dropout = nn.Dropout(0.2)
        self.actor = nn.Linear(128, action_size)
        self.critic = nn.Linear(128, 1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        actor_output = F.softmax(self.actor(x), dim=1)
        critic_output = self.critic(x)
        return actor_output, critic_output

class A2CAgent:
    """Advantage actor-critic (A2C) agent built around ``A2CNetwork``.

    ``act`` samples an action from the policy head, optionally restricted to
    a set of valid actions. ``train`` performs one update from a rollout using
    bootstrapped discounted returns, an advantage-weighted policy-gradient
    loss and an MSE value loss.
    """

    def __init__(self, state_shape, action_size, lr=0.001, gamma=0.99):
        """Create the network on the module-level ``device`` and an Adam optimizer.

        Args:
            state_shape: Observation shape; ``A2CNetwork`` reads
                ``state_shape[0]`` and ``state_shape[1]``.
            action_size: Number of discrete actions.
            lr: Adam learning rate.
            gamma: Discount factor for the returns computed in ``train``.
        """
        self.state_shape = state_shape
        self.action_size = action_size
        self.gamma = gamma
        self.device = device
        self.model = A2CNetwork(state_shape, action_size).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

    def act(self, state, valid_actions=None):
        """Sample an action from the policy, restricted to ``valid_actions``.

        Args:
            state: One observation. A batch axis and a channel axis are
                prepended, so this is presumably a 2-D ``(H, W)`` array.
            valid_actions: Allowed action indices; ``None`` allows every
                action in ``range(action_size)``.

        Returns:
            ``(action, prob)`` where ``action`` is an ``int`` and ``prob`` is
            its probability under the renormalised masked policy. Returns
            ``(0, 0.0)`` when ``valid_actions`` is empty, and a uniformly
            random valid action with ``prob == 0.0`` when the masked
            probabilities are all zero or NaN.
        """
        if valid_actions is None:
            valid_actions = list(range(self.action_size))
        if len(valid_actions) == 0:
            return 0, 0.0
        state_tensor = torch.as_tensor(
            np.asarray(state, dtype=np.float32), device=self.device
        ).unsqueeze(0).unsqueeze(0)
        # Acting needs no gradients; skipping the autograd graph saves memory
        # and time on every environment step.
        # NOTE(review): the model's train/eval mode is not changed here, so the
        # network's dropout is active unless the caller switched to eval().
        with torch.no_grad():
            probs, _ = self.model(state_tensor)
        # True where an action is NOT allowed. Entries of valid_actions outside
        # range(action_size) are simply ignored.
        invalid = torch.as_tensor(
            ~np.isin(np.arange(self.action_size), list(valid_actions)),
            device=probs.device,
        )
        probs_masked = probs.masked_fill(invalid.unsqueeze(0), 0.0)
        prob_sum = probs_masked.sum().item()
        if prob_sum == 0 or np.isnan(prob_sum):
            return int(np.random.choice(valid_actions)), 0.0
        probs_masked = probs_masked / prob_sum
        if torch.isnan(probs_masked).any():
            return int(np.random.choice(valid_actions)), 0.0
        action = torch.multinomial(probs_masked, 1).item()
        return action, probs_masked[0, action].item()

    def train(self, states, actions, rewards, next_state, done):
        """Run one actor-critic gradient step on a rollout.

        Args:
            states: One observation per step; stacked into a float32 array
                and given a channel axis at position 1.
            actions: Action index taken at each step.
            rewards: Scalar reward received at each step.
            next_state: Observation after the last step. It is only used,
                and only needs to be valid, when ``done`` is false, to
                bootstrap the return from the critic.
            done: True if the rollout ended in a terminal state.

        Does nothing if ``rewards`` is empty. Any exception raised while
        computing the loss or stepping the optimizer is printed together with
        the input sizes, and the update is skipped.
        """
        if len(rewards) < 1:
            return
        states = torch.as_tensor(
            np.asarray(states, dtype=np.float32), device=self.device
        ).unsqueeze(1)
        actions = torch.as_tensor(
            np.asarray(actions, dtype=np.int64), device=self.device
        )

        if done:
            running_return = 0.0
        else:
            next_tensor = torch.as_tensor(
                np.asarray(next_state, dtype=np.float32), device=self.device
            ).unsqueeze(0).unsqueeze(0)
            # The bootstrap target is a constant; no graph is needed.
            with torch.no_grad():
                running_return = self.model(next_tensor)[1].item()

        # Discounted returns computed on plain Python floats. This avoids
        # per-element device reads and O(n^2) list.insert(0, ...).
        returns = []
        for reward in reversed(rewards):
            running_return = float(reward) + self.gamma * running_return
            returns.append(running_return)
        returns.reverse()
        discounted_rewards = torch.tensor(
            returns, dtype=torch.float32, device=self.device
        )

        try:
            # A single forward pass gives both heads. Because the network uses
            # dropout, two separate passes would draw different masks.
            action_probs, values = self.model(states)
            # squeeze(1), not squeeze(): a one-step rollout must stay shape (1,)
            # to match discounted_rewards in mse_loss (no broadcasting).
            values = values.squeeze(1)
            advantages = discounted_rewards - values
            chosen_probs = action_probs.gather(1, actions.unsqueeze(1)).squeeze(1)
            log_probs = torch.log(chosen_probs + 1e-9)
            # Detached so the policy loss does not backpropagate into the critic.
            actor_loss = -(log_probs * advantages.detach()).mean()
            critic_loss = F.mse_loss(values, discounted_rewards)
            loss = actor_loss + critic_loss
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
        except Exception as e:
            # Deliberate best-effort: report and skip this update.
            print(f"Training error: {e}")
            print(f"States shape: {states.shape}, Actions: {actions.shape}, Rewards: {len(rewards)}")