In [1]:
import numpy as np
import torch
from adapt_drones.networks.agents import  RMA_DATT
from adapt_drones.networks.adapt_net import  AdaptationNetwork
from dataclasses import  dataclass
from typing import  Union
from adapt_drones.cfgs.config import *
import gymnasium as gym

In [2]:
# Both checkpoints live in the same run directory, so spell the directory once
# and derive the two file paths from it.
run_dir = "../runs/adapt-ICRA/traj_v3-RMA_DATT/earthy-snowball-77"
model_path = f"{run_dir}/best_model.pt"
adapt_path = f"{run_dir}/adapt_network.pt"

In [3]:
@dataclass
class Args:
    """Settings describing the run inspected in this notebook.

    All fields except `idx` are forwarded to `Config` when the notebook
    builds its configuration object.
    """

    # Environment id; passed to Config and later used (via cfg.env_id) in gym.make.
    env_id: str = "traj_v3"
    # Name of the training run; matches the run directory used in the checkpoint paths.
    run_name: str = "earthy-snowball-77"
    # Seed handed to Config. No explicit torch/numpy seeding happens in the visible cells.
    seed: int = 15092024
    # Agent architecture name handed to Config.
    agent: str = "RMA_DATT"
    # Forwarded to Config as `scale`; exact meaning is defined by the project config — TODO confirm.
    scale: bool = True
    # Optional index; not referenced by any of the visible cells.
    idx: Union[int, None] = None
    # Forwarded to Config as `wind_bool`; presumably toggles wind disturbances — confirm.
    wind_bool: bool = True

# Prefer the GPU when one is visible to torch, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Default notebook arguments, mirrored into the project Config with eval=True.
args = Args()
cfg = Config(
    env_id=args.env_id,
    seed=args.seed,
    eval=True,
    run_name=args.run_name,
    agent=args.agent,
    scale=args.scale,
    wind_bool=args.wind_bool,
)

# Create the environment and wrap it inside-out: observations are flattened
# first, then episode statistics are recorded on top of that.
env = gym.wrappers.RecordEpisodeStatistics(
    gym.wrappers.FlattenObservation(gym.make(cfg.env_id, cfg=cfg))
)

In [4]:
# Build the policy using the input/output sizes reported by the environment
# (the custom shape attributes live on the unwrapped env), then move it to the
# selected device.
agent = RMA_DATT(
    priv_info_shape=env.unwrapped.priv_info_shape,
    state_shape=env.unwrapped.state_obs_shape,
    traj_shape=env.unwrapped.reference_traj_shape,
    action_shape=env.action_space.shape,
).to(device)
# map_location=device remaps the stored tensors onto the device chosen above,
# so a checkpoint saved on a CUDA machine still loads on a CPU-only host
# (without it torch.load tries to restore tensors to their original device).
agent.load_state_dict(
    torch.load(model_path, map_location=device, weights_only=True)
)
# Switch to evaluation mode; eval() returns the module, so as the last
# expression of the cell it also displays the network structure.
agent.eval()

RMA_DATT(
  (env_encoder): EnvironmentalEncoder(
    (model): Sequential(
      (0): Linear(in_features=10, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
      (4): Linear(in_features=64, out_features=8, bias=True)
    )
  )
  (traj_encoder): TrajectoryEncoder(
    (conv1): Conv1d(1, 32, kernel_size=(3,), stride=(1,))
    (conv2): Conv1d(32, 32, kernel_size=(3,), stride=(1,))
    (conv3): Conv1d(32, 32, kernel_size=(3,), stride=(1,))
    (linear): Linear(in_features=19008, out_features=32, bias=True)
  )
  (critic): Critic(
    (model): Sequential(
      (0): Linear(in_features=52, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
      (4): Linear(in_features=64, out_features=64, bias=True)
      (5): Tanh()
      (6): Linear(in_features=64, out_features=1, bias=True)
    )
  )
  (actor_mean): Actor(
    (model): Sequential(
      

In [5]:
# Sizes reported by the environment (custom attributes live on the unwrapped env).
priv_info_shape = env.unwrapped.priv_info_shape
state_shape = env.unwrapped.state_obs_shape
traj_shape = env.unwrapped.reference_traj_shape
action_shape = env.action_space.shape[0]

# One history entry is a state concatenated with an action; the adaptation
# network input is `time_horizon` such entries.
# NOTE(review): assumes state_obs_shape is a plain int so the sum is a size,
# not a tuple concatenation — confirm against the environment.
state_action_shape = state_shape + action_shape
time_horizon = cfg.network.adapt_time_horizon

adapt_input = time_horizon * state_action_shape
# Output width is taken from the configured environment-encoder output size.
adapt_output = cfg.network.env_encoder_output

adapt_net = AdaptationNetwork(adapt_input, adapt_output).to(device)
# map_location=device remaps the stored tensors onto the device chosen above,
# so a checkpoint saved on a CUDA machine still loads on a CPU-only host.
# Kept as the last expression so the cell still reports the key-matching result.
adapt_net.load_state_dict(
    torch.load(adapt_path, map_location=device, weights_only=True)
)

<All keys matched successfully>

In [6]:
def numel(m: torch.nn.Module, only_trainable: bool = False):
    """
    Count the scalar parameters held by module `m`.

    Parameters that share the same underlying memory (same `data_ptr()`)
    are counted a single time. When `only_trainable` is True, parameters
    with `requires_grad = False` are left out of the count.
    """
    # Keyed by the address of each parameter's data so that tied/shared
    # weights collapse into a single entry (a later parameter with the same
    # address replaces the earlier one).
    by_address = {}
    for param in m.parameters():
        if only_trainable and not param.requires_grad:
            continue
        by_address[param.data_ptr()] = param

    total = 0
    for param in by_address.values():
        total += param.numel()
    return total

In [7]:
# Count each network once and reuse the results for all three report lines.
agent_param_count = numel(agent)
adapt_param_count = numel(adapt_net)

print(f"Model has {agent_param_count} parameters")
print(f"Adaptation network has {adapt_param_count} parameters")

print("Total number of parameters:", agent_param_count + adapt_param_count)

Model has 643761 parameters
Adaptation network has 845896 parameters
Total number of parameters: 1489657


In [8]:
# Rough memory estimate for the parameters: every parameter is assumed to be
# a float32 (4 bytes), and the byte total is divided by 1024**2 to report MB.
bytes_per_param = 4
bytes_per_mb = 1024**2

agent_mb = numel(agent) * bytes_per_param / bytes_per_mb
adapt_mb = numel(adapt_net) * bytes_per_param / bytes_per_mb

print(f"Model uses {agent_mb:.2f} MB")
print(f"Adaptation network uses {adapt_mb:.2f} MB")

Model uses 2.46 MB
Adaptation network uses 3.23 MB


In [9]:
# Same float32-based estimate (4 bytes per parameter, reported in units of
# 1024**2 bytes), restricted to parameters that require gradients.
trainable_counts = {
    "Model": numel(agent, only_trainable=True),
    "Adaptation network": numel(adapt_net, only_trainable=True),
}
for net_label, n_trainable in trainable_counts.items():
    print(f"{net_label} uses {n_trainable * 4 / 1024**2:.2f} MB")
print(f"Total uses {sum(trainable_counts.values()) * 4 / 1024**2:.2f} MB")

Model uses 2.46 MB
Adaptation network uses 3.23 MB
Total uses 5.68 MB
