In [81]:
from accelerate import Accelerator
import json 
import os 
from typing import Dict, Any

config = {
        "pretrained_name": "google/vit-base-patch16-224",
        "model_path": "./checkpoints/woven-dust-5/maskgen_epoch_0",
        "results_path": "/scratch365/dpan/new_results/maskgen_final",
        "csv_path": "./new_results/maskgen_final",
        "max_samples": 100,
        "dataset_split": "tiny",
        "num_samples": 1000,
        "batch_size":1,
        # "auc_method": "prob", # 'acc'
        "auc_method": "acc", 
        # dummy trainer
        "num_steps": 5,
        "mini_batch_size": 256,
        "ppo_epochs": 1,
        "epsilon": 0.0,
        "lr": 1e-4,
        "clip_param": 0.2,
        "l_kl": 1,
        "l_actor": 1.0,
        "l_entropy": 0.00001,
        "gamma": 0.50,
        "tau": 0.95,
        "max_epochs": 1,
        "save_interval": 50,
        "save_path": "./checkpoints"
}

accelerator = Accelerator()
device = accelerator.device

print('result_path', config['results_path'])
print("csv_path", config['csv_path'])

result_path /scratch365/dpan/new_results/maskgen_final
csv_path ./new_results/maskgen_final


In [82]:
from maskgen.utils.data_utils import get_imagenet_dataloader
from maskgen.utils.model_utils import load_exp_and_target_model
from maskgen.trainer import PPOTrainer
# get models 
target_model, maskgen_model, processor = load_exp_and_target_model(config, device)

# get dummy trainer
trainer = PPOTrainer(maskgen_model=maskgen_model, target_model=target_model, config=config)

# get dataloader
dataloader = get_imagenet_dataloader(split='tiny', 
                                    batch_size=config['batch_size'], 
                                    processor=processor, 
                                    shuffle=False,
                                    num_samples=config['num_samples'])


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  components = torch.load(f"{save_path}/other_components.pt")


Converted base model layers to PEFT: ['encoder.layer.0.attention.attention.query', 'encoder.layer.0.attention.attention.key', 'encoder.layer.0.attention.attention.value', 'encoder.layer.0.attention.output.dense', 'encoder.layer.0.intermediate.dense', 'encoder.layer.0.output.dense', 'encoder.layer.1.attention.attention.query', 'encoder.layer.1.attention.attention.key', 'encoder.layer.1.attention.attention.value', 'encoder.layer.1.attention.output.dense', 'encoder.layer.1.intermediate.dense', 'encoder.layer.1.output.dense', 'encoder.layer.2.attention.attention.query', 'encoder.layer.2.attention.attention.key', 'encoder.layer.2.attention.attention.value', 'encoder.layer.2.attention.output.dense', 'encoder.layer.2.intermediate.dense', 'encoder.layer.2.output.dense', 'encoder.layer.3.attention.attention.query', 'encoder.layer.3.attention.attention.key', 'encoder.layer.3.attention.attention.value', 'encoder.layer.3.attention.output.dense', 'encoder.layer.3.intermediate.dense', 'encoder.layer

Repo card metadata block was not found. Setting CardData to empty.


In [94]:
from tqdm import tqdm
import torch
import numpy as np


def reward_t_test(reward_np, value_np):
    r_mean = reward_np.mean()
    r_std = reward_np.std()

    A = r_mean - value_np

    stat = A / (r_std / np.sqrt(len(reward_np)))
    # print(f"stat: {stat}, r_mean: {r_mean}, r_std: {r_std}, value_np: {value_np}")
    return stat

def get_heatmap(dist):
    prob = dist.probs
    heatmap = prob.view(1, 14, 14)  # Shape: [N, grid_size, grid_size]
    return heatmap

all_inputs = []
all_heatmaps = []
accept_inputs = []
accept_heatmaps = []
reject_inputs = []
reject_heatmaps = []


for idx, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc="Processing batches"):
    inputs = batch['pixel_values'].to(device)
    with torch.no_grad():
        predicted_class_idx = target_model(inputs).logits.argmax(-1)
        dist, value, mu_sum_logprob = maskgen_model.get_dist_critic(inputs, predicted_class_idx.unsqueeze(1))
        heatmap = get_heatmap(dist)
        reward_list = []
        for i in range(5):
            action = trainer.get_epsilon_greedy_action(dist, 0.0)
            _, reward = trainer.get_action_reward(inputs, action, predicted_class_idx.unsqueeze(1))
            reward_list.append(reward)
        reward_np = np.array([x.item() for x in reward_list])
        value_np = value.item()
        stat = reward_t_test(reward_np, value_np)
        # print(stat)

        # save inputs and heatmaps
        inputs_np = inputs.cpu().numpy()
        heatmap_np = heatmap.cpu().numpy()
        all_inputs.append(inputs_np)
        all_heatmaps.append(heatmap_np)
        if np.abs(stat) < 100:
            accept_inputs.append(inputs_np)
            accept_heatmaps.append(heatmap_np)
        else:
            reject_inputs.append(inputs_np)
            reject_heatmaps.append(heatmap_np)


  stat = A / (r_std / np.sqrt(len(reward_np)))
Processing batches: 100%|██████████| 1000/1000 [00:44<00:00, 22.50it/s]


In [95]:
from maskgen.utils.save_utils import save_pixel_heatmap_pairs

all_inputs_np = np.concatenate(all_inputs, axis=0)
all_heatmaps_np = np.concatenate(all_heatmaps, axis=0)
accept_inputs_np = np.concatenate(accept_inputs, axis=0)
accept_heatmaps_np = np.concatenate(accept_heatmaps, axis=0)
reject_inputs_np = np.concatenate(reject_inputs, axis=0)
reject_heatmaps_np = np.concatenate(reject_heatmaps, axis=0)

save_path = config['results_path']

def save_to_file(inputs, heatmaps, save_path):
    # ensure save path exists
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    new_save_path = os.path.join(save_path, 'pixel_heatmap_pairs.npz')
    print("Saving pixel-heatmap pairs to:", new_save_path)
    save_pixel_heatmap_pairs(inputs, heatmaps, new_save_path)

save_to_file(all_inputs_np, all_heatmaps_np, save_path)
save_to_file(accept_inputs_np, accept_heatmaps_np, os.path.join(save_path, 'accept'))
save_to_file(reject_inputs_np, reject_heatmaps_np, os.path.join(save_path, 'reject'))





Saving pixel-heatmap pairs to: /scratch365/dpan/new_results/maskgen_final/pixel_heatmap_pairs.npz
Saving pixel-heatmap pairs to: /scratch365/dpan/new_results/maskgen_final/accept/pixel_heatmap_pairs.npz
Saving pixel-heatmap pairs to: /scratch365/dpan/new_results/maskgen_final/reject/pixel_heatmap_pairs.npz


In [96]:
reject_inputs_np.shape

(312, 3, 224, 224)