In [104]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [105]:
from transformers import AutoImageProcessor
import torch
from torch.utils.data import DataLoader

from uav_active_sensing.pytorch_datasets import TinyImageNetDataset, tiny_imagenet_collate_fn
from uav_active_sensing.modeling.img_env.img_exploration_env import RewardFunction, ImageExplorationEnv, ImageExplorationEnvConfig
from uav_active_sensing.modeling.mae.act_vit_mae import ActViTMAEForPreTraining
# from uav_active_sensing.modeling.ppo import train_ppo, make_env, PPOConfig, PPOAgent
from uav_active_sensing.plots import visualize_tensor

from stable_baselines3 import PPO, SAC
from uav_active_sensing.modeling.agents.rl_agent_feature_extractor import CustomResNetFeatureExtractor

In [153]:
rl_num_envs = 1
rl_batch_size = 4
rl_num_steps = rl_batch_size * rl_num_envs * 4

image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base", use_fast=True)
tiny_imagenet_train_dataset = TinyImageNetDataset(split="train", transform=image_processor)
tiny_imagenet_train_loader = DataLoader(tiny_imagenet_train_dataset, batch_size=10, collate_fn=tiny_imagenet_collate_fn)

# Pretrained model and reward function
model = ActViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
reward_function = RewardFunction(model)

# Create a dummy environment to initialize the model
dummy_batch = next(iter(tiny_imagenet_train_loader))  # Take one image as a dummy input for env initialization
env_config = ImageExplorationEnvConfig(img=dummy_batch, reward_function=reward_function)
env = ImageExplorationEnv(env_config)

### Batch env

In [154]:
sample_action = env.action_space.sample()
sample_action = torch.from_numpy(sample_action)
sample_action = env._denormalize_action(sample_action)

In [155]:
sample_action

tensor([[ 15,   6,  12],
        [  9,  10,  10],
        [  6,  -9,  13],
        [ -3,   4,  -5],
        [ -3,   7,  10],
        [ -1,  15,  -2],
        [ -9,   3,  -6],
        [ -6,  11,   1],
        [-11,  -7,  -6],
        [ -8,  13,   2]], dtype=torch.int32)

In [165]:
env.move(sample_action)

torch.Size([10])


TypeError: only integer tensors of a single element can be converted to an index

In [166]:
fov_bbox = env.fov_bbox

top = fov_bbox[:, 0]
left = fov_bbox[:, 1]
bottom = fov_bbox[:, 2]
right = fov_bbox[:, 3]

  # # top, bottom, left, right = env.fov_bbox
  # obs = env.img[:, :, top:bottom, left:right].clone()

In [174]:
B, C, H, W = env.img.shape  # Get batch size, channels, height, width

# Generate row and column indices for each image
row_indices = torch.arange(H, device=env.img.device).view(1, -1, 1)  # Shape (1, H, 1)
col_indices = torch.arange(W, device=env.img.device).view(1, 1, -1)  # Shape (1, 1, W)

# Create masks to keep only the pixels within the [top, bottom) and [left, right) range
row_mask = (row_indices >= top.view(B, 1, 1)) & (row_indices < bottom.view(B, 1, 1))  # (B, H, 1)
col_mask = (col_indices >= left.view(B, 1, 1)) & (col_indices < right.view(B, 1, 1))  # (B, 1, W)

# Get final mask for each image
mask = row_mask & col_mask  # (B, H, W)

# Use mask to index the images
cropped_images = env.img * mask[:, None, :, :]  # (B, C, H, W), setting out-of-bounds pixels to 0


In [177]:
cropped_images

tensor([[[[-0.0000, -0.0000, -0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000, -0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000, -0.0000,  ...,  0.0000,  0.0000,  0.0000],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ..., -0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000,  ..., -0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000,  ..., -0.0000, -0.0000, -0.0000]],

         [[-0.0000, -0.0000, -0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000, -0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ..., -0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000,  ..., -0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000,  ..., -0.0000, -0.0000, -0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  

In [127]:
bottom

tensor([ 62., 150.,  90., 131., 162., 136.,  44., 218., 220.,  20.])

### Agent training

In [None]:

policy_kwargs = dict(
    features_extractor_class=CustomResNetFeatureExtractor,
    features_extractor_kwargs=dict(features_dim=512),
)

resnet_rl_model = PPO(
    "CnnPolicy",
    env,
    policy_kwargs=policy_kwargs,
    verbose=1,
      n_steps=rl_num_steps,
    batch_size=rl_batch_size
)

for i, batch in enumerate(tiny_imagenet_train_loader):
    # Set image with batch change: https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html
    vec_env = resnet_rl_model.get_env()  
    vec_env.env_method("set_img", batch)
    resnet_rl_model.learn(total_timesteps=2 * rl_num_steps, progress_bar=False)
    
    if i == 1:
        break

In [None]:
# test_env = ImageExplorationEnv(env_config)
test_env = resnet_rl_model.env.envs[0].env


# Reset environment and get initial observation
obs, _ = test_env.reset()

# Number of steps to run the evaluation
num_eval_steps = 30

# Disable training mode
resnet_rl_model.policy.eval()

for step in range(num_eval_steps):

    obs_tensor = torch.tensor(obs, dtype=torch.float32)
    action, _ = resnet_rl_model.predict(obs_tensor, deterministic=True)
    obs, reward, terminated, truncated, _ = test_env.step(action)

    # Print results
    print(f"Step {step + 1}: Action={action}, Reward={reward}")

    # if terminated or truncated:
    #     print("Episode ended, resetting environment.")
    #     obs, _ = test_env.reset()

In [None]:
visualize_tensor(test_env.img)
visualize_tensor(test_env.sampled_img)