In [1]:
import torch

In [2]:
from glob import glob 
file_locs = "../datacrawl/crawled_data/*/*.json"
import json

def read_json_file(file_path):
    try:
        with open(file_path, 'r') as file:
            data = json.load(file)
        return data
    except FileNotFoundError:
        raise FileNotFoundError(f"The file {file_path} was not found")
    except json.JSONDecodeError as e:
        raise json.JSONDecodeError(f"Error parsing JSON file: {str(e)}", e.doc, e.pos)


In [3]:
data_1 = read_json_file('../datacrawl/crawled_data/61723086/data.json')


In [4]:
def get_json_schema(data):
    """
    Extract schema from JSON data
    Returns a dict representing the structure, with types as values
    """
    if isinstance(data, dict):
        return {
            key: get_json_schema(value)
            for key, value in data.items()
        }
    elif isinstance(data, list):
        # If list is empty, return empty list schema
        if not data:
            return ["array (empty)"]
        # If all items are same type, just return first item's schema
        if all(isinstance(x, type(data[0])) for x in data):
            return [get_json_schema(data[0])]
        # If mixed types, return schema for each unique type
        return [get_json_schema(x) for x in data]
    else:
        return type(data).__name__

def print_schema(schema, indent=0):
    """Pretty print the schema"""
    if isinstance(schema, dict):
        result = "{\n"
        for key, value in schema.items():
            result += " " * (indent + 2) + f'"{key}": '
            result += print_schema(value, indent + 2)
            result += ",\n"
        result += " " * indent + "}"
        return result
    elif isinstance(schema, list):
        if len(schema) == 1:
            return f"[{print_schema(schema[0], indent)}]"
        else:
            result = "[\n"
            for item in schema:
                result += " " * (indent + 2)
                result += print_schema(item, indent + 2)
                result += ",\n"
            result += " " * indent + "]"
            return result
    else:
        return f'"{schema}"'



schema = get_json_schema(data_1)
print(print_schema(schema))

{
  "configuration": {
    "actTimeout": "int",
    "env_cfg": {
      "map_height": "int",
      "map_width": "int",
      "match_count_per_episode": "int",
      "max_steps_in_match": "int",
      "max_units": "int",
      "num_teams": "int",
      "unit_move_cost": "int",
      "unit_sap_cost": "int",
      "unit_sap_range": "int",
      "unit_sensor_range": "int",
    },
    "episodeSteps": "int",
    "runTimeout": "int",
    "seed": "int",
  },
  "description": "str",
  "id": "str",
  "info": {
    "EpisodeId": "int",
    "LiveVideoPath": "NoneType",
    "TeamNames": ["str"],
  },
  "name": "str",
  "rewards": ["int"],
  "schema_version": "int",
  "specification": {
    "action": {
      "default": "int",
      "description": "str",
      "type": "str",
    },
    "agents": ["int"],
    "configuration": {
      "actTimeout": {
        "default": "int",
        "description": "str",
        "minimum": "int",
        "type": "str",
      },
      "env_cfg": {
        "description": 

In [5]:
import torch
def process_observation(replay_observation):
    """Convert replay observation into tensor format"""
    return {
        'map_features': {
            'energy': torch.FloatTensor(replay_observation['map_features']['energy']),
            'tile_type': torch.FloatTensor(replay_observation['map_features']['tile_type'])
        },
        'sensor_mask': torch.FloatTensor(replay_observation['sensor_mask']),
        'units': {
            'position': torch.FloatTensor(replay_observation['units']['position']),
            'energy': torch.FloatTensor(replay_observation['units']['energy'])
        },
        'units_mask': torch.FloatTensor(replay_observation['units_mask']),
        'relic_nodes': torch.FloatTensor(replay_observation['relic_nodes']),
        'energy_nodes': torch.FloatTensor(replay_observation['energy_nodes'])
    }

In [6]:
from dataclasses import dataclass
from typing import List, Dict, Any
@dataclass
class GameStats:
    round_won: bool
    relic_points: int
    total_energy: int
    relic_tiles_found: float
    area_explored: float


In [7]:
import torch
import numpy as np


def network_to_env_actions(network_outputs, num_units=16):
    """
    Convert network outputs back to environment action format
    
    Args:
        network_outputs: Dictionary containing network outputs
            - action_probs: Action type probabilities
            - unit_probs: Unit selection probabilities
            - offset_x: X coordinate offsets
            - offset_y: Y coordinate offsets
        num_units: Number of units in the environment
    
    Returns:
        numpy array of shape (N, 3) containing environment actions
    """
    # Initialize empty action array
    env_actions = np.zeros((num_units, 3), dtype=int)
    
    # Get selected actions and units
    actions = torch.argmax(network_outputs['action_probs'], dim=1)
    units = torch.argmax(network_outputs['unit_probs'], dim=1)
    
    # Get coordinate offsets
    offset_x = network_outputs['offset_x'].detach().numpy()
    offset_y = network_outputs['offset_y'].detach().numpy()
    
    for i, (action, unit) in enumerate(zip(actions, units)):
        if action == 0:  # Move action
            # Convert offset to cardinal direction
            dx = offset_x[i]
            dy = offset_y[i]
            
            # Convert to closest cardinal direction
            if abs(dx) > abs(dy):
                if dx > 0:
                    action_type = 2  # Right
                else:
                    action_type = 4  # Left
            else:
                if dy > 0:
                    action_type = 3  # Down
                else:
                    action_type = 1  # Up
                    
            env_actions[unit] = [action_type, 0, 0]
            
        elif action == 1:  # Sap action
            dx = round(offset_x[i])
            dy = round(offset_y[i])
            env_actions[unit] = [5, dx, dy]
    
    return env_actions


In [8]:
import torch
import numpy as np

def parse_observation(obs):
    """
    Convert raw observation dict into tensor format required by LuxAIObservationEncoder
    
    Args:
        raw_obs (dict): Raw observation dictionary from the environment
        
    Returns:
        dict: Processed observation with tensors
    """
    
    return {
        'map_features': {
            'energy': torch.FloatTensor(obs['map_features']['energy']).unsqueeze(0),
            'tile_type': torch.FloatTensor(obs['map_features']['tile_type']).unsqueeze(0)
        },
        'sensor_mask': torch.FloatTensor(obs['sensor_mask']).unsqueeze(0),
        'units': {
            'position': torch.FloatTensor(obs['units']['position']).unsqueeze(0),
            'energy': torch.FloatTensor(obs['units']['energy']).unsqueeze(0)
        },
        'units_mask': torch.FloatTensor(obs['units_mask']).unsqueeze(0),
        'relic_nodes': torch.FloatTensor(obs['relic_nodes']).unsqueeze(0),
        'relic_nodes_mask': torch.FloatTensor(obs['relic_nodes_mask']).unsqueeze(0)
    }

In [9]:
step = data_1["steps"][110]

for i, each_player in enumerate(step):
    obs = each_player["observation"]["obs"]
    obs = json.loads(obs)
    print(f"\nPlayer {i}:")
    print(parse_observation(obs))
    
    total_energy = obs['units']['energy'][i]
    total_energy = sum(num for num in total_energy if num != -1)
    
    
    print(GameStats(
        round_won=obs['team_wins'][i] ,
        relic_points=obs['team_points'][i],
        total_energy=total_energy,
        relic_tiles_found=np.mean(obs['relic_nodes_mask']),
        area_explored=np.mean(obs['sensor_mask'])
    ))


Player 0:
{'map_features': {'energy': tensor([[[ 0.,  2.,  4.,  5.,  6.,  7.,  7.,  7.,  7.,  7., -1., -1., -1., -1.,
          -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
         [ 5.,  7.,  7.,  7.,  6.,  6.,  6.,  6.,  6.,  7., -1., -1., -1., -1.,
          -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
         [ 7.,  5.,  3., -1., -1., -1., -1., -2., -1.,  0., -1., -1., -1., -1.,
          -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
         [ 2., -1., -1., -1., -1., -5., -5., -5., -6., -6., -1., -1., -1., -1.,
          -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
         [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
          -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
         [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
          -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
         [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
          -1., -1., -1., -1., -1., 

In [10]:

def env_to_network_actions(env_actions):
    
    """
    Convert environment actions to network format
    
    Args:
        env_actions: List/array of shape (N, 3) where each action is [action_type, dx, dy]
                    action_type: 0=nothing, 1=up, 2=right, 3=down, 4=left, 5=sap
    
    Returns:
        dict containing:
            actions_taken: Tensor of action indices
            units_selected: Tensor of unit indices for active units
            coords_taken: Tensor of coordinate offsets
    """
    
    
    
    # Convert to numpy if not already
    
    # Find active units (where action_type != 0)
    if  len(env_actions)==0:  # This checks if list is empty
        return {
        'actions_taken': torch.tensor([[0]]),  # No-op action
        'units_selected': torch.tensor([[0]]),  # First unit
        'coords_taken': torch.zeros(1, 2)  # No movement
        }
    env_actions = np.array(env_actions)
    active_units = np.where(env_actions[:, 0] != 0)[0]
    
    if len(active_units) == 0:
        # Handle case where no actions are taken
        return {
            'actions_taken': torch.tensor([[0]]),  # No-op action
            'units_selected': torch.tensor([[0]]),  # First unit
            'coords_taken': torch.zeros(1, 2)  # No movement
        }
    
    # For each active unit, convert their action
    action_indices = []
    coord_offsets = []
    
    for unit_idx in active_units:
        action = env_actions[unit_idx]
        action_type = action[0]
        
        if action_type == 5:  # Sap action
            action_idx = 2  # Assuming 1 represents "Zap" in your network
            dx, dy = action[1], action[2]
        else:  # Movement actions
            action_idx = 1  #  0 represents "Move"
            # Convert cardinal directions to dx, dy
            if action_type == 1:    # Up
                dx, dy = 0, -1
            elif action_type == 2:  # Right
                dx, dy = 1, 0
            elif action_type == 3:  # Down
                dx, dy = 0, 1
            elif action_type == 4:  # Left
                dx, dy = -1, 0
                
        action_indices.append(action_idx)
        coord_offsets.append([dx, dy])
    
    # Convert to tensors
    actions_taken = torch.tensor(action_indices).view(-1, 1)
    units_selected = torch.tensor(active_units).view(-1, 1)
    coords_taken = torch.tensor(coord_offsets, dtype=torch.float)
    
    return {
        'actions_taken': actions_taken,
        'units_selected': units_selected,
        'coords_taken': coords_taken
    }

In [11]:
action = each_player["action"]
# Example environment actions
env_actions = action

# Convert to network format
network_actions = env_to_network_actions(action)

print("Network format:")
for k, v in network_actions.items():
    print(f"{k}:\n{v}\n")

Network format:
actions_taken:
tensor([[1],
        [1],
        [1]])

units_selected:
tensor([[0],
        [1],
        [2]])

coords_taken:
tensor([[-1.,  0.],
        [-1.,  0.],
        [-1.,  0.]])



In [12]:
class LuxGameParser:
    def __init__(self, map_size=24, max_units=16):
        self.map_size = map_size
        self.max_units = max_units
    
    def parse_game_file(self, file_path: str) -> List[GameStats]:
        with open(file_path, 'r') as f:
            game_data = json.load(f)
        env_stats = [[],[]]
        actions_stats  = [[],[]]
        game_stats = [[],[]]
        for step in game_data['steps']:
            for i, each_player in enumerate(step):
                obs = each_player["observation"]["obs"]
                obs = json.loads(obs)
                env_stats[i].append(parse_observation(obs))
                
                env_actions = each_player["action"]
                network_actions = env_to_network_actions(env_actions)
                
                total_energy = obs['units']['energy'][i]
                total_energy = sum(num for num in total_energy if num != -1)
                
                actions_stats[i].append(network_actions)
                game_stats[i].append(GameStats(
                    round_won=obs['team_wins'][i] ,
                    relic_points=obs['team_points'][i],
                    total_energy=total_energy,
                    relic_tiles_found=np.mean(obs['relic_nodes_mask']),
                    area_explored=np.mean(obs['sensor_mask'])
                ))
                
        return env_stats, game_stats,actions_stats

In [13]:
luxParse = LuxGameParser()


In [14]:
%%time
env_stats, game_stats, actions_stats = luxParse.parse_game_file( glob(file_locs)[2])

CPU times: user 333 ms, sys: 65.5 ms, total: 399 ms
Wall time: 398 ms


In [18]:
env_stats[0][1]["units"]["position"].shape

torch.Size([1, 2, 16, 2])

In [16]:
len(actions_stats[0])

506

In [133]:
import torch
import torch.nn as nn

class SpatialEncoder(nn.Module):
    def __init__(self, width=24, height=24):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),  # 3 channels: energy, tile_type, sensor_mask
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.fc = nn.Linear(64 * (width//2) * (height//2), 256)

    def forward(self, x):
        x = self.conv(x)
        return self.fc(x.view(x.size(0), -1))

class UnitProcessor(nn.Module):
    def __init__(self, max_units=16):
        super().__init__()
        self.max_units = max_units
        self.fc1 = nn.Linear(3, 64)  # position (2) + energy (1)
        self.fc2 = nn.Linear(64, 128)
        
    def forward(self, positions, energy, mask):
        # Reshape it to add it in  the last deminsion
        energy = energy.unsqueeze(-1) 
        # Combine position and energy
        features = torch.cat([positions, energy], dim=-1)  # Shape: (batch, teams, units, 3)
        
        # Process each unit
        x = torch.relu(self.fc1(features))
        x = torch.relu(self.fc2(x))
        
        # Apply mask
        x = x * mask.unsqueeze(-1)
        
        # Pool over units dimension
        return torch.max(x, dim=2)[0]  # Shape: (batch, teams, 128)

class RelicProcessor(nn.Module):
    def __init__(self, max_relics=10):
        super().__init__()
        self.fc1 = nn.Linear(2, 64)  # position only
        self.fc2 = nn.Linear(64, 128)
        
    def forward(self, positions, mask):
        x = torch.relu(self.fc1(positions))
        x = torch.relu(self.fc2(x))
        
        # Apply mask
        x = x * mask.unsqueeze(-1)
        
        # Pool over relics
        return torch.max(x, dim=1)[0]

class LuxAIObservationEncoder(nn.Module):
    def __init__(self, width=24, height=24, max_units=10, max_relics=10,final_dim=512):
        super().__init__()
        self.spatial_encoder = SpatialEncoder(width, height)
        self.unit_processor = UnitProcessor(max_units)
        self.relic_processor = RelicProcessor(max_relics)

        # Calculate total embedding dimension
        spatial_dim = 256
        units_dim = 128 * 2  # For 2 teams
        relic_dim = 128
        total_dim = spatial_dim + units_dim + relic_dim
        
        self.output_layer = nn.Linear(total_dim, final_dim)

    def forward(self, obs):
        # Process spatial features
        spatial_features = torch.cat([
            obs['map_features']['energy'].unsqueeze(1),
            obs['map_features']['tile_type'].unsqueeze(1),
            obs['sensor_mask'].unsqueeze(1)
        ], dim=1)
        spatial_embed = self.spatial_encoder(spatial_features)

        # Process units
        unit_embed = self.unit_processor(
            obs['units']['position'],
            obs['units']['energy'],
            obs['units_mask']
        )
        unit_embed = unit_embed.reshape(unit_embed.size(0), -1)  # Flatten teams dimension

        # Process relics
        relic_embed = self.relic_processor(
            obs['relic_nodes'],
            obs['relic_nodes_mask']
        )

        # Combine all embeddings
        combined = torch.cat([
            spatial_embed,
            unit_embed,
            relic_embed
        ], dim=1)

        return self.output_layer(combined)

class GameStateProcessor:
    def process_observation(self, obs):
        """Convert raw observation dict into tensor format"""
        return {
            'map_features': {
                'energy': torch.FloatTensor(obs['obs']['map_features']['energy']),
                'tile_type': torch.FloatTensor(obs['obs']['map_features']['tile_type'])
            },
            'sensor_mask': torch.FloatTensor(obs['obs']['sensor_mask']),
            'units': {
                'position': torch.FloatTensor(obs['obs']['units']['position']),
                'energy': torch.FloatTensor(obs['obs']['units']['energy'])
            },
            'units_mask': torch.FloatTensor(obs['obs']['units_mask']),
            'relic_nodes': torch.FloatTensor(obs['obs']['relic_nodes']),
            'relic_nodes_mask': torch.FloatTensor(obs['obs']['relic_nodes_mask'])
        }

In [134]:
import torch
import torch.nn as nn

class SequentialLSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        """
        Initialize the LSTM model that processes sequences of observations.
        
        Args:
            input_size (int): Size of each input observation
            hidden_size (int): Number of features in the hidden state
            output_size (int): Size of output action space
            num_layers (int): Number of LSTM layers
        """
        super().__init__()
        
        self.hidden_size = hidden_size
        
        # Main LSTM layer to process sequences
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size, 
            num_layers=num_layers,
            batch_first=True  # Use (batch, seq, feature) format
        )
        
        # Linear layer to project LSTM output to action logits
        self.action_head = nn.Linear(hidden_size, output_size)
        
        # Initialize hidden state and cell state
        self.hidden = None
        self.cell = None
        
    def reset_states(self, batch_size=1, device='cuda'):
        """Reset the hidden and cell states"""
        self.hidden = torch.zeros(1, batch_size, self.hidden_size).to(device)
        self.cell = torch.zeros(1, batch_size, self.hidden_size).to(device)
        
    def forward(self, x):
        """
        Forward pass through the network.
        
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, input_size)
                            where seq_len=500 in your case
        Returns:
            outputs (torch.Tensor): Action logits for each timestep
        """
        # Process the sequence through LSTM
        lstm_out, (self.hidden, self.cell) = self.lstm(x, (self.hidden, self.cell))
        
        # Project LSTM outputs to action space for each timestep
        action_logits = self.action_head(lstm_out)
        
        return action_logits

In [135]:
l = LuxAIObservationEncoder()
obs = env_stats[0][0]

In [136]:
obs

{'map_features': {'energy': tensor([[[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
            -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
           [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
            -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
           [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
            -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
           [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
            -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
           [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
            -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
           [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
            -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
           [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
            -1., -1., 

In [137]:
l(obs).shape

torch.Size([1, 512])

In [138]:
spatial_features = torch.cat([
            obs['map_features']['energy'].unsqueeze(1),
            obs['map_features']['tile_type'].unsqueeze(1),
            obs['sensor_mask'].unsqueeze(1)
        ], dim=1)

In [139]:
spatial_features

tensor([[[[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          ...,
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]],

         [[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          ...,
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]],

         [[ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
          ...,
          [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  ...,  0.,  0.,  0.]]]])

In [149]:

import torch
import torch.nn as nn
import torch.nn.functional as F

class BaseNetwork(nn.Module):
    """Shared base network for processing state input"""
    def __init__(self, lstm_hidden_size=256, embedding_dim=512):
        super(BaseNetwork, self).__init__()
        self.lstm_hidden_size = lstm_hidden_size
        self.embedding_dim = embedding_dim
        
        # LSTM for processing state
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=lstm_hidden_size,
            batch_first=True
        )

    def forward(self, state):
        """
        Process state through LSTM
        Args:
            state: Input state tensor (batch_size, seq_len, input_dim)
        Returns:
            lstm_final: Final LSTM hidden state (batch_size, lstm_hidden_size)
        """
        lstm_out, _ = self.lstm(state)  # [batch_size, seq_len, lstm_hidden_size]
        lstm_final = lstm_out[:, -1, :]  # [batch_size, lstm_hidden_size]
        return lstm_final

class PolicyNetwork(nn.Module):
    """Policy network for action selection"""
    def __init__(self, base_network, action_dim=2, unit_dim=16):
        super(PolicyNetwork, self).__init__()
        self.base = base_network
        self.action_dim = action_dim
        self.unit_dim = unit_dim
        
        # Action ID path
        self.action_embedding = nn.Embedding(action_dim, base_network.embedding_dim)
        self.action_fc = nn.Linear(base_network.lstm_hidden_size, base_network.embedding_dim)
        
        # Target unit path
        self.unit_embedding = nn.Embedding(self.unit_dim, base_network.embedding_dim)
        self.unit_fc = nn.Linear(base_network.lstm_hidden_size, base_network.embedding_dim)
        
        # Coordinate prediction
        self.coord_fc = nn.Linear(base_network.lstm_hidden_size, self.unit_dim * 2)

    def forward(self, state, available_actions, available_units):
        """
        Args:
            state: Input state tensor (batch_size, seq_len, input_dim)
            available_actions: Boolean mask of available actions (batch_size, action_dim)
            available_units: Boolean mask of available units (batch_size, unit_dim)
        Returns:
            Dictionary containing action distributions for each unit
        """
        lstm_final = self.base(state)
        
        # Target unit selection - now returns probabilities for EACH unit
        unit_embeds = self.unit_embedding(torch.arange(self.unit_dim).to(state.device))
        unit_query = self.unit_fc(lstm_final)
        unit_scores = torch.matmul(unit_query, unit_embeds.t())
        unit_scores = unit_scores.masked_fill(~available_units, float('-inf'))
        # Instead of softmax, use sigmoid to allow multiple units to be selected
        unit_probs = torch.sigmoid(unit_scores)
        
        # Action ID selection for each available unit
        action_embeds = self.action_embedding(torch.arange(self.action_dim).to(state.device))
        action_query = self.action_fc(lstm_final)
        action_scores = torch.matmul(action_query, action_embeds.t())
        action_scores = action_scores.masked_fill(~available_actions, float('-inf'))
        
        # Get action probabilities for each unit
        action_probs = F.softmax(action_scores.unsqueeze(1).expand(-1, self.unit_dim, -1), dim=-1)
        
        # Coordinate prediction for each unit
        coords_all = self.coord_fc(lstm_final)
        coordinates = coords_all.view(-1, self.unit_dim, 2)
        
        return {
            'action_probs': action_probs,  # [batch_size, unit_dim, action_dim]
            'unit_probs': unit_probs,      # [batch_size, unit_dim]
            'coordinates': coordinates,     # [batch_size, unit_dim, 2]
        }

    def sample_actions(self, outputs):
        """Sample actions for each selected unit"""
        unit_mask = (outputs['unit_probs'] > 0.5)  # Binary selection of units
        
        if self.training:
            actions = torch.multinomial(outputs['action_probs'][unit_mask], 1)
        else:
            actions = torch.argmax(outputs['action_probs'][unit_mask], dim=-1, keepdim=True)
        
        return {
            'selected_units': unit_mask,
            'actions': actions,
            'coordinates': outputs['coordinates'][unit_mask]
        }
        
class ValueNetwork(nn.Module):
    """Value network for state value estimation"""
    def __init__(self, base_network):
        super(ValueNetwork, self).__init__()
        self.base = base_network
        
        # Value head
        self.value_fc1 = nn.Linear(base_network.lstm_hidden_size, 128)
        self.value_fc2 = nn.Linear(128, 1)
    
    def forward(self, state):
        """
        Args:
            state: Input state tensor (batch_size, seq_len, input_dim)
        Returns:
            value: Predicted state value (batch_size, 1)
        """
        lstm_final = self.base(state)
        hidden = F.relu(self.value_fc1(lstm_final))
        value = self.value_fc2(hidden)
        return value

class BehaviorCloningLoss(nn.Module):
    """Loss function for behavior cloning phase"""
    def __init__(self, action_weight=1.0, unit_weight=1.0, coord_weight=1.0):
        super(BehaviorCloningLoss, self).__init__()
        self.action_weight = action_weight
        self.unit_weight = unit_weight
        self.coord_weight = coord_weight

    def forward(self, outputs, target_units, target_actions, target_coords):
        """
        Args:
            outputs: Dictionary from policy network
            target_units: Binary mask of which units should be selected [batch_size, unit_dim]
            target_actions: Actions for selected units [num_selected_units]
            target_coords: Coordinates for selected units [num_selected_units, 2]
        """
        # Binary cross entropy for unit selection
        unit_loss = F.binary_cross_entropy_with_logits(
            outputs['unit_probs'], target_units.float()
        )
        
        # Cross entropy loss for actions of selected units
        selected_mask = target_units.bool()
        action_loss = -torch.mean(
            torch.log(outputs['action_probs'][selected_mask].gather(1, target_actions))
        )
        
        # MSE loss for coordinates of selected units
        coord_loss = F.mse_loss(
            outputs['coordinates'][selected_mask], target_coords
        )
        
        return {
            'total_loss': unit_loss + action_loss + coord_loss,
            'unit_loss': unit_loss,
            'action_loss': action_loss,
            'coord_loss': coord_loss
        }

def count_parameters(model):
    return sum(p.numel() for p in model.parameters())


# Example usage
def main():
    batch_size = 16
    seq_len = 100
    input_dim = 512
    
    # Create networks
    base_net = BaseNetwork()
    policy_net = PolicyNetwork(base_net)
    value_net = ValueNetwork(base_net)
    
    # Create loss function for behavior cloning
    bc_criterion = BehaviorCloningLoss(action_weight=1.0, unit_weight=0.8, coord_weight=0.5)
    
    # Sample inputs
    state = torch.randn(batch_size, seq_len, input_dim)
    available_actions = torch.ones(batch_size, 2, dtype=torch.bool)
    available_units = torch.ones(batch_size, 16, dtype=torch.bool)
    
    # Forward passes
    policy_outputs = policy_net(state, available_actions, available_units)
    state_values = value_net(state)
    
    # Create sample target data - now with multiple units per batch
    target_units = torch.zeros(batch_size, 16, dtype=torch.bool)
    # Randomly select 2-4 units per batch
    for i in range(batch_size):
        num_units = torch.randint(2, 5, (1,)).item()
        selected_units = torch.randperm(16)[:num_units]
        target_units[i, selected_units] = True
    
    # Generate actions and coordinates only for selected units
    num_selected = target_units.sum().item()
    target_actions = torch.randint(0, 2, (num_selected, 1))
    target_coords = torch.randn(num_selected, 2)
    
    # Calculate behavior cloning loss
    bc_loss_dict = bc_criterion(
        policy_outputs,
        target_units,
        target_actions,
        target_coords
    )
    
    # Sample actions for inference
    sampled_actions = policy_net.sample_actions(policy_outputs)
    
    print(f"Parameters in base network: {count_parameters(base_net)}")
    print(f"Parameters in policy network: {count_parameters(policy_net)}")
    print(f"Parameters in value network: {count_parameters(value_net)}")
    
    print(f"\nBehavior Cloning Losses:")
    print(f"Total loss: {bc_loss_dict['total_loss'].item():.4f}")
    print(f"Action loss: {bc_loss_dict['action_loss'].item():.4f}")
    print(f"Unit loss: {bc_loss_dict['unit_loss'].item():.4f}")
    print(f"Coordinate loss: {bc_loss_dict['coord_loss'].item():.4f}")
    
    print(f"\nNetwork Outputs:")
    print(f"Action probs shape: {policy_outputs['action_probs'].shape}")
    print(f"Unit probs shape: {policy_outputs['unit_probs'].shape}")
    print(f"Coordinates shape: {policy_outputs['coordinates'].shape}")
    
    print(f"\nSampled Actions:")
    print(f"Number of units selected: {sampled_actions['selected_units'].sum().item()}")
    print(f"Actions shape: {sampled_actions['actions'].shape}")
    print(f"Coordinates shape: {sampled_actions['coordinates'].shape}")
    
    print(f"\nState values shape: {state_values.shape}")
    
    
if __name__ == "__main__":
    main()

Parameters in base network: 788480
Parameters in policy network: 1069088
Parameters in value network: 821505

Behavior Cloning Losses:
Total loss: 3.7338
Action loss: 1.8542
Unit loss: 0.8929
Coordinate loss: 0.9867

Network Outputs:
Action probs shape: torch.Size([16, 16, 2])
Unit probs shape: torch.Size([16, 16])
Coordinates shape: torch.Size([16, 16, 2])

Sampled Actions:
Number of units selected: 128
Actions shape: torch.Size([128, 1])
Coordinates shape: torch.Size([128, 2])

State values shape: torch.Size([16, 1])


In [154]:
l(obs).shape

torch.Size([1, 512])

In [151]:
batch_size = 16
seq_len = 500
input_dim = 512

# Create networks
base_net = BaseNetwork()
policy_net = PolicyNetwork(base_net)
value_net = ValueNetwork(base_net)

# Create loss function for behavior cloning
bc_criterion = BehaviorCloningLoss(action_weight=1.0, unit_weight=0.8, coord_weight=0.5)

# Sample inputs
state = torch.randn(batch_size, seq_len, input_dim)
available_actions = torch.ones(batch_size, 2, dtype=torch.bool)
available_units = torch.ones(batch_size, 16, dtype=torch.bool)

# Forward passes
policy_outputs = policy_net(state, available_actions, available_units)
state_values = value_net(state)

# Create sample target data - now with multiple units per batch
target_units = torch.zeros(batch_size, 16, dtype=torch.bool)
# Randomly select 2-4 units per batch
for i in range(batch_size):
    num_units = torch.randint(2, 5, (1,)).item()
    selected_units = torch.randperm(16)[:num_units]
    target_units[i, selected_units] = True

# Generate actions and coordinates only for selected units
num_selected = target_units.sum().item()
target_actions = torch.randint(0, 2, (num_selected, 1))
target_coords = torch.randn(num_selected, 2)

# Calculate behavior cloning loss
bc_loss_dict = bc_criterion(
    policy_outputs,
    target_units,
    target_actions,
    target_coords
)

# Sample actions for inference
sampled_actions = policy_net.sample_actions(policy_outputs)

print(f"Parameters in base network: {count_parameters(base_net)}")
print(f"Parameters in policy network: {count_parameters(policy_net)}")
print(f"Parameters in value network: {count_parameters(value_net)}")

print(f"\nBehavior Cloning Losses:")
print(f"Total loss: {bc_loss_dict['total_loss'].item():.4f}")
print(f"Action loss: {bc_loss_dict['action_loss'].item():.4f}")
print(f"Unit loss: {bc_loss_dict['unit_loss'].item():.4f}")
print(f"Coordinate loss: {bc_loss_dict['coord_loss'].item():.4f}")

print(f"\nNetwork Outputs:")
print(f"Action probs shape: {policy_outputs['action_probs'].shape}")
print(f"Unit probs shape: {policy_outputs['unit_probs'].shape}")
print(f"Coordinates shape: {policy_outputs['coordinates'].shape}")

print(f"\nSampled Actions:")
print(f"Number of units selected: {sampled_actions['selected_units'].sum().item()}")
print(f"Actions shape: {sampled_actions['actions'].shape}")
print(f"Coordinates shape: {sampled_actions['coordinates'].shape}")

print(f"\nState values shape: {state_values.shape}")

Parameters in base network: 788480
Parameters in policy network: 1069088
Parameters in value network: 821505

Behavior Cloning Losses:
Total loss: 2.8850
Action loss: 0.8750
Unit loss: 0.8671
Coordinate loss: 1.1429

Network Outputs:
Action probs shape: torch.Size([16, 16, 2])
Unit probs shape: torch.Size([16, 16])
Coordinates shape: torch.Size([16, 16, 2])

Sampled Actions:
Number of units selected: 104
Actions shape: torch.Size([104, 1])
Coordinates shape: torch.Size([104, 2])

State values shape: torch.Size([16, 1])


In [166]:
state_values

tensor([[-0.0152],
        [-0.0281],
        [ 0.0524],
        [-0.0249],
        [ 0.0808],
        [ 0.0091],
        [-0.0017],
        [ 0.0086],
        [ 0.0012],
        [-0.0307],
        [-0.0414],
        [-0.0295],
        [ 0.0160],
        [-0.0524],
        [-0.0156],
        [ 0.0114]], grad_fn=<AddmmBackward0>)

In [150]:
base_net = BaseNetwork()
policy_net = PolicyNetwork(base_net)
value_net = ValueNetwork(base_net)

In [145]:
base_net(obs)

torch.Size([1, 256])


IndexError: too many indices for tensor of dimension 2

In [None]:
obs

{'map_features': {'energy': tensor([[[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
            -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
           [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
            -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
           [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
            -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
           [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
            -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
           [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
            -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
           [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
            -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
           [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
            -1., -1., 

In [125]:

# Example usage
def main():
    batch_size = 16
    seq_len = 100
    input_dim = 64
    
    # Create networks
    base_net = BaseNetwork()
    policy_net = PolicyNetwork(base_net)
    value_net = ValueNetwork(base_net)
    
    # Create loss function for behavior cloning
    bc_criterion = BehaviorCloningLoss(action_weight=1.0, unit_weight=0.8, coord_weight=0.5)
    
    # Sample inputs
    state = torch.randn(batch_size, seq_len, input_dim)
    available_actions = torch.ones(batch_size, 2, dtype=torch.bool)
    available_units = torch.ones(batch_size, 16, dtype=torch.bool)
    
    # Forward passes
    policy_outputs = policy_net(state, available_actions, available_units)
    state_values = value_net(state)
    
    # Create sample target data - now with multiple units per batch
    target_units = torch.zeros(batch_size, 16, dtype=torch.bool)
    # Randomly select 2-4 units per batch
    for i in range(batch_size):
        num_units = torch.randint(2, 5, (1,)).item()
        selected_units = torch.randperm(16)[:num_units]
        target_units[i, selected_units] = True
    
    # Generate actions and coordinates only for selected units
    num_selected = target_units.sum().item()
    target_actions = torch.randint(0, 2, (num_selected, 1))
    target_coords = torch.randn(num_selected, 2)
    
    # Calculate behavior cloning loss
    bc_loss_dict = bc_criterion(
        policy_outputs,
        target_units,
        target_actions,
        target_coords
    )
    
    # Sample actions for inference
    sampled_actions = policy_net.sample_actions(policy_outputs)
    
    print(f"Parameters in base network: {count_parameters(base_net)}")
    print(f"Parameters in policy network: {count_parameters(policy_net)}")
    print(f"Parameters in value network: {count_parameters(value_net)}")
    
    print(f"\nBehavior Cloning Losses:")
    print(f"Total loss: {bc_loss_dict['total_loss'].item():.4f}")
    print(f"Action loss: {bc_loss_dict['action_loss'].item():.4f}")
    print(f"Unit loss: {bc_loss_dict['unit_loss'].item():.4f}")
    print(f"Coordinate loss: {bc_loss_dict['coord_loss'].item():.4f}")
    
    print(f"\nNetwork Outputs:")
    print(f"Action probs shape: {policy_outputs['action_probs'].shape}")
    print(f"Unit probs shape: {policy_outputs['unit_probs'].shape}")
    print(f"Coordinates shape: {policy_outputs['coordinates'].shape}")
    
    print(f"\nSampled Actions:")
    print(f"Number of units selected: {sampled_actions['selected_units'].sum().item()}")
    print(f"Actions shape: {sampled_actions['actions'].shape}")
    print(f"Coordinates shape: {sampled_actions['coordinates'].shape}")
    
    print(f"\nState values shape: {state_values.shape}")
    
    
if __name__ == "__main__":
    main()

IndexError: too many indices for tensor of dimension 3

In [2]:
policy_outputs["coordinates"].shape

torch.Size([16, 16, 2])

In [3]:
policy_outputs["action_probs"].shape

torch.Size([16, 2])

In [4]:
policy_outputs["coordinates"].shape

torch.Size([16, 16, 2])

NameError: name 'network_to_env_actions' is not defined

In [40]:

# Sample inputs
state = torch.randn(batch_size, seq_len, input_dim)
available_actions = torch.ones(batch_size, 2, dtype=torch.bool)
available_units = torch.ones(batch_size, 16, dtype=torch.bool)