In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
import random as rd
import numpy as np
import torch
from torch.utils.data import DataLoader
import gymnasium as gym
from transformers import AutoImageProcessor

In [5]:
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from transformers import AutoImageProcessor
from stable_baselines3 import PPO
from stable_baselines3.common import env_checker


from uav_active_sensing.pytorch_datasets import TinyImageNetDataset, tiny_imagenet_collate_fn
from uav_active_sensing.modeling.img_env.img_exploration_env import ImageExplorationEnv, ImageExplorationEnvConfig, RewardFunction
from uav_active_sensing.modeling.mae.act_vit_mae import ActViTMAEForPreTraining
from uav_active_sensing.modeling.agents.rl_agent_feature_extractor import CustomResNetFeatureExtractor
from uav_active_sensing.config import DEVICE
from uav_active_sensing.plots import visualize_tensor, visualize_act_mae_reconstruction

In [6]:
# Shared torch RNG, seeded once; it is handed to the reward function and used for
# random patch selection so that stochastic steps are reproducible.
generator = torch.Generator()
generator.manual_seed(1230)

In [7]:
class SingleImageDataset(Dataset):
    """Dataset exposing exactly one (image, label) item taken from another dataset.

    The item is read from ``original_dataset`` once, at construction time, and is
    then returned for index 0 (or -1).

    Args:
        original_dataset: Source dataset; ``original_dataset[index]`` must return an
            ``(image, label)`` pair.
        index: Position of the item to keep in ``original_dataset``.
    """

    def __init__(self, original_dataset: Dataset, index: int):
        self.image, self.label = original_dataset[index]

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        """Return the stored ``(image, label)`` pair; raise IndexError for any other index."""
        # torch's Dataset defines no __iter__, so plain iteration (`for x in ds`,
        # `list(ds)`) relies on IndexError to stop; without it iteration never ends.
        if idx not in (0, -1):
            raise IndexError(f"SingleImageDataset index out of range: {idx}")
        return self.image, self.label

In [8]:
image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base", use_fast=True)  # TODO: Download this in advance
tiny_imagenet_train_dataset = TinyImageNetDataset(split="train", transform=image_processor)
# Seeded RNG so the same image is picked after every kernel restart (the bare
# `random` module is never seeded in this notebook).
index_rng = rd.Random(1230)
random_index = index_rng.randint(0, len(tiny_imagenet_train_dataset) - 1)
single_image_dataset = SingleImageDataset(tiny_imagenet_train_dataset, random_index)
# The dataset holds exactly one item, so batch_size=1 (a larger value would still
# yield a single batch containing that one image).
tiny_imagenet_train_loader = DataLoader(single_image_dataset, batch_size=1, collate_fn=tiny_imagenet_collate_fn)

# Pretrained model and reward function
act_mae_model = ActViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")  # TODO: Download this in advance
reward_function = RewardFunction(act_mae_model,
                                 num_samples=1,
                                 reward_increase=False,
                                 patch_size=16,
                                 masking_ratio=0.8,
                                 generator=generator,
                                 )

# Image used to build the exploration environments. Element 0 of the collated batch;
# presumably the image tensor (labels second) — TODO confirm against tiny_imagenet_collate_fn.
img = next(iter(tiny_imagenet_train_loader))[0]

### Test random movement

In [9]:
# Environment for the random-movement test; the sensor size is expressed as two
# 16-px patch widths.
env_config = ImageExplorationEnvConfig(
    seed=45,
    sensor_size=2 * 16,
    reward_function=reward_function,
)
env = ImageExplorationEnv(img, env_config)

In [10]:
# Validate that the environment follows the Gymnasium API expected by Stable-Baselines3
# (warn=True is the default: also emit non-fatal warnings).
env_checker.check_env(env=env, warn=True)



In [118]:

# Take up to `num_random_steps` uniformly random actions from a fresh episode.
# Stop early if the episode ends: stepping past termination is undefined under the
# Gymnasium API.
num_random_steps = 128
env.reset()
for _ in range(num_random_steps):
    random_action = env.action_space.sample()
    _, _, terminated, truncated, _ = env.step(random_action)
    if terminated or truncated:
        break

In [None]:
# Randomly mask part of the explored (sampled) image using the env's reward function,
# then visualize the MAE reconstruction. `.unsqueeze(0)` prepends a batch dimension.
# NOTE(review): the masking cell further down unpacks env.sampled_img as 4-D
# (B, C, H, W); if it is already batched this yields a 5-D tensor — confirm.
masked_sampled_img = env._reward_function.sampled_img_random_masking(
    env.sampled_img
)
visualize_act_mae_reconstruction(
    env.img.unsqueeze(0),
    env.sampled_img.unsqueeze(0),
    masked_sampled_img.unsqueeze(0),
    act_mae_model,
)

### Random masking of sampled image

In [161]:
# Scratch version of the patch-masking logic: NaN out a random `masking_ratio`
# fraction of the fully observed patch_size x patch_size patches of the sampled image.
B, C, H, W = env.sampled_img.shape  # assumes a batched (B, C, H, W) tensor — TODO confirm
patch_size = 16
masking_ratio = 0.5
# Dedicated, freshly seeded generator: re-running this cell gives the same mask and
# does not advance the shared `generator` used by the reward function.
masking_generator = torch.Generator().manual_seed(0)

# (B, C, H, W) -> (B, H, W, C); the clone keeps env.sampled_img untouched.
x = torch.clone(env.sampled_img).permute(0, 2, 3, 1)

num_patches_H = H // patch_size
num_patches_W = W // patch_size

# Tile rows (dim 1), then columns (dim 2), into non-overlapping patches:
# (B, nH, nW, C, patch_size, patch_size). unfold returns views into `x`.
patches = x.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
# (B, nH, nW): True where a patch contains at least one NaN (not fully observed).
patch_nan_mask = torch.isnan(patches).flatten(3).any(dim=3)
valid_patches = ~patch_nan_mask
valid_indices = torch.nonzero(valid_patches, as_tuple=True)  # (batch, row, col) indices

num_valid = valid_indices[0].numel()  # number of fully observed patches
num_to_mask = int(masking_ratio * num_valid)  # number of patches to mask

mask_indices = torch.randperm(num_valid, generator=masking_generator)[:num_to_mask]
selected_patches = tuple(idx[mask_indices] for idx in valid_indices)  # selected patch indices

# Apply NaN masking; writes go through the unfold views into `x`.
patches[selected_patches] = float('nan')
# (B, nH, nW, C, p, p) -> (B, C, nH, p, nW, p) -> (B, C, nH*p, nW*p).
# reshape (not view) so a non-viewable layout falls back to a copy instead of raising.
reconstructed = patches.permute(0, 3, 1, 4, 2, 5).reshape(B, C, num_patches_H * patch_size, num_patches_W * patch_size)

In [None]:
# First the randomly NaN-masked image (`reconstructed`), then the unmasked sampled
# image it was derived from, for visual comparison.
visualize_tensor(reconstructed)
visualize_tensor(env.sampled_img)

In [103]:
def sampled_img_random_masking(
    sampled_img: torch.Tensor,
    masking_ratio: float,
    generator: torch.Generator,
    patch_size: int = 16,
) -> torch.Tensor:
    """Randomly NaN-mask a fraction of the fully observed patches of an image.

    The image is tiled into non-overlapping ``patch_size`` x ``patch_size`` patches.
    A patch is "valid" when it contains no NaN values; patches that already contain
    NaNs are never selected. ``int(masking_ratio * num_valid)`` valid patches (counted
    over the whole batch) are chosen uniformly at random with ``generator``, and every
    value inside them is set to NaN.

    Args:
        sampled_img: Image tensor of shape (B, C, H, W), or (C, H, W) for a single
            image. It is not modified.
        masking_ratio: Fraction of valid patches to mask, in [0, 1].
        generator: Torch generator used for the random patch selection (its state advances).
        patch_size: Side length of the square patches. Defaults to 16.

    Returns:
        A new tensor with the same number of dimensions as ``sampled_img`` and spatial
        size (H // patch_size * patch_size, W // patch_size * patch_size); trailing
        rows/columns that do not fill a whole patch are cropped.

    Raises:
        ValueError: If ``masking_ratio`` is outside [0, 1], ``patch_size`` is not
            positive, or ``sampled_img`` is neither 3-D nor 4-D.
    """
    if not 0.0 <= masking_ratio <= 1.0:
        raise ValueError(f"masking_ratio must be in [0, 1], got {masking_ratio}")
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")

    is_unbatched = sampled_img.dim() == 3
    if is_unbatched:
        sampled_img = sampled_img.unsqueeze(0)
    if sampled_img.dim() != 4:
        raise ValueError(
            f"sampled_img must be (C, H, W) or (B, C, H, W), got shape {tuple(sampled_img.shape)}"
        )

    B, C, H, W = sampled_img.shape
    num_patches_H = H // patch_size
    num_patches_W = W // patch_size

    # (B, C, H, W) -> (B, H, W, C), then tile rows and columns into
    # (B, nH, nW, C, patch_size, patch_size). unfold returns views, so the NaN writes
    # below land in this clone, never in the caller's tensor.
    x = torch.clone(sampled_img).permute(0, 2, 3, 1)
    patches = x.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)

    # (B, nH, nW): a patch is valid when none of its C * p * p values is NaN.
    valid_patches = ~torch.isnan(patches).flatten(3).any(dim=3)
    valid_indices = torch.nonzero(valid_patches, as_tuple=True)  # (batch, row, col) indices

    num_valid = valid_indices[0].numel()
    num_to_mask = int(masking_ratio * num_valid)

    mask_indices = torch.randperm(num_valid, generator=generator)[:num_to_mask]
    selected_patches = tuple(idx[mask_indices] for idx in valid_indices)

    # Apply NaN masking to every value of the selected patches.
    patches[selected_patches] = float("nan")

    # (B, nH, nW, C, p, p) -> (B, C, nH, p, nW, p) -> (B, C, nH*p, nW*p).
    # reshape (not view) so a non-viewable layout falls back to a copy instead of raising.
    reconstructed = patches.permute(0, 3, 1, 4, 2, 5).reshape(
        B, C, num_patches_H * patch_size, num_patches_W * patch_size
    )
    return reconstructed.squeeze(0) if is_unbatched else reconstructed

### Test kernel size increase

In [99]:
# Environment for the kernel-size test: no `sensor_size` override, fixed seed 0.
env_config = ImageExplorationEnvConfig(
    seed=0,
    reward_function=reward_function,
)
env = ImageExplorationEnv(img, env_config)

In [None]:

env.reset()

# Repeat the same action for every batch element. The third component is presumably
# the kernel-size control (see the section title) — TODO confirm the action layout.
num_grow_steps = env.img_h // env.sensor_h // 2 - 1
for _ in range(num_grow_steps):
    grow_kernel_action = np.array([[0, 0, 1]] * env.batch_size)
    _, _, terminated, truncated, _ = env.step(grow_kernel_action)
    if terminated or truncated:  # don't step past the end of the episode
        break

for k in range(env.batch_size):
    visualize_tensor(env.img[k])
    visualize_tensor(env.sampled_img[k])

### Test deterministic behaviour for a given seed

In [33]:
DETERMINISM_SEED = 0
NUM_DETERMINISM_STEPS = 5


def run_seeded_rollout(seed: int, num_steps: int):
    """Build a fresh env from `img`, take `num_steps` random actions, and record them.

    Every RNG this notebook controls is seeded with `seed`: the env config, the action
    space used by `action_space.sample()`, and the shared torch `generator` handed to
    the reward function. Two calls with the same seed should therefore behave identically.

    Returns:
        (env, actions, sampled_img): the env after the rollout, the list of sampled
        actions, and a detached copy of `env.sampled_img`.
    """
    # Presumably RewardFunction keeps a reference to `generator`; reseeding it makes
    # reward-side randomness repeat as well — TODO confirm.
    generator.manual_seed(seed)
    config = ImageExplorationEnvConfig(reward_function=reward_function, seed=seed)
    rollout_env = ImageExplorationEnv(img, config)
    rollout_env.reset()
    # Gymnasium spaces keep their own RNG; seed it explicitly so sample() repeats.
    rollout_env.action_space.seed(seed)

    actions = []
    for _ in range(num_steps):
        action = rollout_env.action_space.sample()
        _, _, terminated, truncated, _ = rollout_env.step(action)
        actions.append(action)
        if terminated or truncated:  # don't step past the end of the episode
            break
    return rollout_env, actions, rollout_env.sampled_img.detach().clone()


_, first_run_actions, first_run = run_seeded_rollout(DETERMINISM_SEED, NUM_DETERMINISM_STEPS)
env, second_run_actions, second_run = run_seeded_rollout(DETERMINISM_SEED, NUM_DETERMINISM_STEPS)

In [None]:
# Both runs used the same seed, so if sampling is deterministic the actions and the
# final sampled images should be identical.
print(f"Same number of steps: {len(first_run_actions) == len(second_run_actions)}")
for step, (first_action, second_action) in enumerate(zip(first_run_actions, second_run_actions)):
    print(f"step {step}: actions identical = {np.array_equal(first_action, second_action)}")

# sampled_img may contain NaNs (unobserved pixels) and NaN != NaN, so compare exactly
# with equal_nan=True instead of relying on visual inspection alone.
images_identical = torch.allclose(first_run, second_run, rtol=0.0, atol=0.0, equal_nan=True)
print(f"Sampled images identical: {images_identical}")

for k in range(env.batch_size):
    visualize_tensor(first_run[k])
    visualize_tensor(second_run[k])
    print("-" * 50)