In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class WorkerNetwork(nn.Module):
    """Actor-critic MLP: a shared trunk feeding a policy head and a value head.

    Args:
        obs_shape: Size of the flat observation vector (last input dimension).
        action_dim: Number of discrete actions; width of the policy-logit output.
    """

    def __init__(self, obs_shape=71, action_dim=5):
        super(WorkerNetwork, self).__init__()

        # Shared feature extractor: obs_shape -> 256 -> 128 -> 64
        self.fc1 = nn.Linear(obs_shape, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)

        # Actor head (policy): one unnormalized logit per action.
        # Apply softmax / Categorical(logits=...) downstream to get probabilities.
        self.actor = nn.Linear(64, action_dim)

        # Critic head (value): a single scalar value estimate for the state
        self.critic = nn.Linear(64, 1)

    def forward(self, x):
        """Run the shared trunk and both heads.

        Args:
            x: Observation(s) whose last dimension is ``obs_shape``. May be a
                tensor, a NumPy array, or a nested Python sequence, e.g. a
                (batch, obs_shape) batch or a single (obs_shape,) vector.

        Returns:
            Tuple ``(policy_logits, state_value)`` whose last dimensions are
            ``action_dim`` and 1 respectively; leading dimensions follow ``x``.
        """
        # Coerce to the parameters' dtype and device so that non-tensor inputs
        # and float64/int tensors (e.g. from NumPy) don't cause a dtype or
        # device mismatch inside nn.Linear. as_tensor returns the input
        # unchanged (no copy, autograd intact) when it already matches.
        weight = self.fc1.weight
        x = torch.as_tensor(x, dtype=weight.dtype, device=weight.device)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))

        # Policy output (logits)
        policy_logits = self.actor(x)

        # State value
        state_value = self.critic(x)

        return policy_logits, state_value

# Smoke-test the network with the expected observation size (71).
# Seed first so weight init and the sample are reproducible under Restart & Run All.
torch.manual_seed(0)
model = WorkerNetwork()
sample_obs = torch.randn(1, 71)  # simulating one agent's observation

with torch.no_grad():  # inference only; no autograd graph needed
    logits, value = model(sample_obs)

# Check the expected shapes instead of only stating them in comments
assert logits.shape == (1, 5), f"unexpected logits shape {tuple(logits.shape)}"
assert value.shape == (1, 1), f"unexpected value shape {tuple(value.shape)}"

print(f"Policy Logits: {logits.shape}")
print(f"State Value: {value.item():.4f}")

Policy Logits: torch.Size([1, 5])
State Value: -0.1101


In [None]:
import torch
from src.worker_mlp import WorkerNetwork

# Test the architecture without running the whole environment.
# NOTE(review): WorkerNetwork here comes from src.worker_mlp and shadows the
# class defined earlier in this notebook; keep the two definitions in sync
# (or delete the inline one).
torch.manual_seed(0)  # reproducible init and input under Restart & Run All
model = WorkerNetwork(obs_shape=71, action_dim=5)
dummy_input = torch.randn(1, 71)

with torch.no_grad():  # inference only; no autograd graph needed
    logits, value = model(dummy_input)

# Verify the output shapes before reporting success
assert logits.shape == (1, 5), f"unexpected logits shape {tuple(logits.shape)}"
assert value.shape == (1, 1), f"unexpected value shape {tuple(value.shape)}"

print("Architecture Test Successful!")
print(f"Action Logits Shape: {logits.shape}")
print(f"Value Output: {value.item():.4f}")