In [5]:
# Load IPython's autoreload extension so that edits to imported modules are
# picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all modules (except those excluded via %aimport) every time
# before executing the code typed in a cell.
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
from uav_active_sensing.config import PROCESSED_DATA_DIR
from uav_active_sensing.pytorch_datasets import TinyImageNetDataset
from uav_active_sensing.modeling.img_exploration_env import ImageExplorationEnv, RewardFunction
from uav_active_sensing.modeling.act_vit_mae import ActViTMAEForPreTraining
from uav_active_sensing.plots import visualize_reconstruction
from uav_active_sensing.modeling.ppo import train_ppo

import torch
from torch.utils.data import DataLoader

from transformers import AutoImageProcessor


In [7]:
# Location of the Tiny ImageNet data below PROCESSED_DATA_DIR; passed as
# `root_dir` to TinyImageNetDataset for both the train and val splits.
TINY_IMAGENET_PROCESSED_DIR = PROCESSED_DATA_DIR / "tiny_imagenet/tiny-imagenet-200"
# Use the GPU when torch reports CUDA as available, otherwise fall back to CPU.
# The collate function moves every sample to this device.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


def tiny_imagenet_collate_fn(batch):
    """Merge a list of dataset samples into one tensor located on ``DEVICE``.

    For every sample the first element is taken, its ``"pixel_values"`` entry
    is moved to ``DEVICE``, and the resulting tensors are concatenated along
    dimension 0.

    Args:
        batch: Sequence of samples as produced by the dataset. Presumably each
            sample is a tuple whose first element is the image-processor
            output -- TODO confirm against ``TinyImageNetDataset``.

    Returns:
        A single ``torch.Tensor`` built with ``torch.cat(..., dim=0)``.
    """
    # NOTE(review): the device transfer happens inside the collate function.
    # The DataLoader in this notebook does not set ``num_workers``; confirm this
    # still works as intended if worker processes are ever enabled.
    pixel_tensors = []
    for sample in batch:
        pixel_values = sample[0]["pixel_values"]
        pixel_tensors.append(pixel_values.to(DEVICE))

    return torch.cat(pixel_tensors, dim=0)


# Image processor of the "facebook/vit-mae-base" checkpoint; it is handed to
# both dataset splits as their `transform`.
image_processor = AutoImageProcessor.from_pretrained(
    "facebook/vit-mae-base",
    use_fast=True,
)

# Train and validation splits share the same root directory and transform.
tiny_imagenet_train_dataset = TinyImageNetDataset(
    root_dir=TINY_IMAGENET_PROCESSED_DIR,
    split="train",
    transform=image_processor,
)
tiny_imagenet_val_dataset = TinyImageNetDataset(
    root_dir=TINY_IMAGENET_PROCESSED_DIR,
    split="val",
    transform=image_processor,
)

# Currently only supports batch size of 1
train_dataloader = DataLoader(
    tiny_imagenet_train_dataset,
    batch_size=1,
    collate_fn=tiny_imagenet_collate_fn,
)

# The reward function is built around the MAE model loaded from the same
# checkpoint as the image processor.
# NOTE(review): batches are moved to DEVICE by the collate function, but the
# model is not moved here -- confirm that downstream code handles placement.
model = ActViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
reward_function = RewardFunction(model)

### RL training with epoch interface

In [9]:

# Repeat for many epochs until convergence (the outer epoch loop is not
# implemented in this cell yet).

# Cap on the number of batches the RL phase consumes per run of this cell.
# Kept at 1 while the pipeline is being brought up; raise it for real training.
MAX_RL_BATCHES = 1

# First half of the epoch. Train RL agent(s) in env instantiated with each batch. Freeze MAE
# NOTE(review): nothing in this cell freezes the MAE explicitly -- confirm that
# train_ppo / RewardFunction take care of it.
# zip() asks the range for its next value first, so once the cap is reached it
# stops without pulling (loading, collating, moving to DEVICE) another batch
# from the dataloader that would only be thrown away.
for _, batch in zip(range(MAX_RL_BATCHES), train_dataloader):
    train_ppo(batch, reward_function)

# Second half of the epoch (TODO). Train MAE here with pairs of batches of
# sampled and complete images. Freeze agent


Num iters:  10
Curr iteration 1
SPS: 7
Curr iteration 2
SPS: 3
Curr iteration 3
SPS: 3
Curr iteration 4
SPS: 3
Curr iteration 5
SPS: 3
Curr iteration 6
SPS: 3
Curr iteration 7
SPS: 3
Curr iteration 8
SPS: 3
Curr iteration 9
SPS: 3
Curr iteration 10
SPS: 3
