In [16]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from trainers.torch.networks import NetworkBody
from mlagents_envs.base_env import ObservationSpec, DimensionProperty, ObservationType
from mlagents.trainers.settings import NetworkSettings, TrainerSettings
from mlagents_envs.base_env import ActionSpec
from trainers.policy.torch_policy import TorchPolicy
from mlagents_envs.base_env import BehaviorSpec
import wandb

In [17]:
# One flat vector observation with 12 entries. The training cell feeds
# 3 position coordinates followed by 9 random values, i.e. 12 numbers.
observation_specs = [
    ObservationSpec(
        shape=(12,),
        dimension_property=(DimensionProperty.NONE,),  # must be a tuple
        observation_type=ObservationType.DEFAULT,
        name="position_observation",
    )
]

# Settings for the importance network body: two layers of two hidden units,
# no memory module, non-deterministic.
# NOTE(review): hidden_units=2 presumably makes the body's output 2 values
# wide, matching the 2-element training target -- confirm against NetworkBody.
network_settings_importance = NetworkSettings(
    hidden_units=2,
    num_layers=2,
    memory=None,
    deterministic=False,
)

importance_network = NetworkBody(observation_specs, network_settings_importance)

In [18]:
# 20 x 20 grid of cell centres: x and y each run over -9..10, shifted by
# -0.5, with the middle coordinate fixed at 0.5. Rows are ordered with x
# varying fastest. The grid is wrapped in a single-element list holding one
# (400, 3) tensor.
positions = [
    torch.tensor(
        [
            [x - 0.5, 0.5, y - 0.5]
            for y in range(-9, 11)
            for x in range(-9, 11)
        ]
    )
]

In [19]:
# Start the Weights & Biases run under the 'bias_importance' project; the
# training cell logs its per-epoch loss to this run.
wandb.init(project='bias_importance')

In [None]:
# Train `importance_network` so that its output matches a fixed 2-value
# target for every grid position. Each training input is a 3-value position
# followed by 9 fresh uniform random values (12 values total).
target_tensor = torch.tensor([0.0, 1.0])

# Define the optimizer
optimizer = optim.Adam(importance_network.parameters(), lr=0.001)

# Define the loss function
mse_loss = nn.MSELoss()

# Create positions: a 20 x 20 grid of [x - 0.5, 0.5, y - 0.5] lists with
# x and y each running over -9..10 (x varies fastest).
positions = [
    [x - 0.5, 0.5, y - 0.5]
    for y in range(-9, 11)
    for x in range(-9, 11)
]
num_positions = len(positions)

# Effectively "run until interrupted": the loop is expected to be stopped
# manually long before this many epochs complete.
num_epochs = 100000000

# Training loop. The try/finally guarantees the wandb run is closed even when
# the loop is interrupted (e.g. Kernel -> Interrupt) or raises; previously
# wandb.finish() was only reached after all epochs had completed.
try:
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for pos in positions:
            # Append 9 random values drawn uniformly from [0, 1) to the position
            random_values = np.random.rand(9).tolist()
            input_data = pos + random_values

            # Convert to a (1, 12) float tensor and wrap it in a list, since
            # the network is called with a list of observation tensors
            input_tensor = torch.tensor(input_data, dtype=torch.float32).unsqueeze(0)
            input_tensor_list = [input_tensor]

            # Forward pass: take the first returned element and drop the batch dim
            output = importance_network(input_tensor_list)[0].squeeze(0)

            # Compute the loss
            loss = mse_loss(output, target_tensor)

            # Backward pass and optimization (one optimizer step per sample)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate loss
            epoch_loss += loss.item()

        # Calculate average loss for the epoch
        avg_loss = epoch_loss / num_positions

        # Log the average loss to wandb
        wandb.log({"epoch": epoch + 1, "loss": avg_loss})

        # Print the average loss for the epoch
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss}')
finally:
    # Finish the wandb run, whether training completed, failed or was interrupted
    wandb.finish()

Epoch 1/100000000, Loss: 2.618413077145815
Epoch 2/100000000, Loss: 1.0066130000934936
Epoch 3/100000000, Loss: 0.3463236474932637
Epoch 4/100000000, Loss: 0.13324075818120037
Epoch 5/100000000, Loss: 0.05994128439517226
Epoch 6/100000000, Loss: 0.03005171871453058
Epoch 7/100000000, Loss: 0.017207142108527477
Epoch 8/100000000, Loss: 0.013849900011191494
Epoch 9/100000000, Loss: 0.012019930292954086
Epoch 10/100000000, Loss: 0.011454102947029697
Epoch 11/100000000, Loss: 0.011156467818266265
Epoch 12/100000000, Loss: 0.009397874138294
Epoch 13/100000000, Loss: 0.009401539477360714
Epoch 14/100000000, Loss: 0.009402829001579746
Epoch 15/100000000, Loss: 0.008089678172404344
Epoch 16/100000000, Loss: 0.007077647880632012
Epoch 17/100000000, Loss: 0.006387554850293568
Epoch 18/100000000, Loss: 0.006471185615278685
Epoch 19/100000000, Loss: 0.0056113213393709695
Epoch 20/100000000, Loss: 0.0059889689259936315
Epoch 21/100000000, Loss: 0.004965538153433613
Epoch 22/100000000, Loss: 0.00502

In [None]:
# Save the trained importance network's weights as a state-dict checkpoint.
# NOTE(review): absolute, machine-specific path -- consider making this
# configurable before sharing the notebook.
model_path = "/home/rmarr/Projects/visibility-game-env/informed_init"

# torch.save does not create missing parent directories, so make sure the
# target directory exists first (no-op when it is already there).
os.makedirs(model_path, exist_ok=True)

checkpoint_path = os.path.join(model_path, "bias_importance.pt")
state_dict = importance_network.state_dict()
torch.save(state_dict, checkpoint_path)