In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.utils.data import DataLoader
import gymnasium as gym
from transformers import AutoImageProcessor

In [3]:
from uav_active_sensing.config import PROCESSED_DATA_DIR
from uav_active_sensing.pytorch_datasets import TinyImageNetDataset
from uav_active_sensing.modeling.ppo import make_env, PPOConfig
from uav_active_sensing.modeling.img_exploration_env import RewardFunction, ActViTMAEForPreTraining
from uav_active_sensing.modeling.train import tiny_imagenet_collate_fn

[32m2025-02-13 14:52:34.497[0m | [1mINFO    [0m | [36muav_active_sensing.config[0m:[36m<module>[0m:[36m13[0m - [1mPROJ_ROOT path is: /home/tcouso/uav_active_sensing[0m
[32m2025-02-13 14:52:34.500[0m | [1mINFO    [0m | [36muav_active_sensing.config[0m:[36m<module>[0m:[36m28[0m - [1mUsing device: cpu[0m


In [4]:

TINY_IMAGENET_PROCESSED_DIR = PROCESSED_DATA_DIR / "tiny_imagenet/tiny-imagenet-200"
image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base", use_fast=True)
tiny_imagenet_train_dataset = TinyImageNetDataset(root_dir=TINY_IMAGENET_PROCESSED_DIR, split="train", transform=image_processor)
tiny_imagenet_train_loader = DataLoader(tiny_imagenet_train_dataset, batch_size=16, collate_fn=tiny_imagenet_collate_fn)

In [5]:
model = ActViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
reward_function = RewardFunction(model)
ppo_config = PPOConfig()

In [6]:
for batch in tiny_imagenet_train_loader:
  envs = gym.vector.SyncVectorEnv(
      [make_env(img.unsqueeze(0), reward_function, ppo_config.gamma) for img in batch]
  )
  break


In [7]:
envs.single_observation_space.shape

(150528,)

In [8]:
obs = torch.zeros((ppo_config.num_steps, ppo_config.num_envs) + envs.single_observation_space.shape)

In [9]:
obs.shape

torch.Size([128, 4, 150528])

In [10]:
obs[0].shape

torch.Size([4, 150528])

In [11]:
next_obs, info = envs.reset()
next_obs = torch.Tensor(next_obs)

In [12]:
next_obs.shape

torch.Size([16, 150528])