In [None]:
# ------------------------------------------------------------------------
# Modified from PROB: Probabilistic Objectness for Open World Object Detection
# Orr Zohar, Jackson Wang, Serena Yeung
# ------------------------------------------------------------------------
import os
import matplotlib.pyplot as plt
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import argparse
import random
import torch
from torch.utils.data import DataLoader
import util.misc as utils
import datasets.samplers as samplers
from datasets import build_dataset
from models import build_model

import numpy as np
import pandas as pd
import seaborn as sns

In [None]:
def get_args_parser():
    """Build the argument parser for the RWD - FOMO detector notebook.

    The parser is created with ``add_help=False``. Options are registered in a
    fixed order from the table below (runtime, RWOD data/task, model, logging).

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    option_table = [
        # ---- runtime ----
        ('--batch_size', dict(default=10, type=int)),
        ('--output_dir', dict(default='tmp/rwod',
                              help='path where to save, empty for no saving')),
        ('--device', dict(default='cuda',
                          help='device to use for training / testing')),
        ('--seed', dict(default=42, type=int)),
        ('--eval', dict(action='store_true')),
        ('--viz', dict(action='store_true')),
        ('--num_workers', dict(default=2, type=int)),
        # ---- RWOD data / task ----
        ('--test_set', dict(default='test.txt', help='testing txt files')),
        ('--train_set', dict(default='train.txt', help='training txt files')),
        ('--dataset', dict(default='?', help='defines which dataset is used.')),
        ('--data_root', dict(default='./data', type=str)),
        ('--data_task', dict(default='RWOD', type=str)),
        ('--unknown_classnames_file', dict(default='', type=str)),
        ('--classnames_file', dict(default='known_classnames.txt', type=str)),
        ('--prev_classnames_file', dict(default='known_classnames.txt', type=str)),
        ('--templates_file', dict(default='best_templates.txt', type=str)),
        ('--attributes_file', dict(default='attributes1.json', type=str)),
        ('--use_attributes', dict(action='store_true')),
        ('--att_selection', dict(action='store_true')),
        ('--image_conditioned_file', dict(default='few_shot_data.json', type=str)),
        ('--image_conditioned', dict(action='store_true')),
        ('--att_refinement', dict(action='store_true')),
        ('--att_adapt', dict(action='store_true')),
        ('--unk_proposal', dict(action='store_true')),
        ('--num_few_shot', dict(default=100, type=int)),
        ('--num_att_per_class', dict(default=25, type=int)),
        # ---- model ----
        # --unk_methods may hold a comma-separated list (e.g. "sigmoid-max-mcm,sigmoid-max").
        ('--unk_methods', dict(default='sigmoid-max-mcm', type=str)),
        ('--unk_method', dict(default='sigmoid-max-mcm', type=str)),
        ('--model_type', dict(default='owl_vit', type=str)),
        ('--model_name', dict(default='google/owlvit-base-patch16', type=str)),
        ('--unk_LLM', dict(action='store_true')),
        ('--image_resize', dict(default=768, type=int,
                                help='image resize 768 for owlvit-base models, 840 for owlvit-large models')),
        ('--pred_per_im', dict(default=100, type=int)),
        ('--PREV_INTRODUCED_CLS', dict(default=0, type=int)),
        ('--CUR_INTRODUCED_CLS', dict(default=30, type=int)),
        ('--prev_output_file', dict(default='', type=str)),
        ('--output_file', dict(default='', type=str)),
        # ---- logging ----
        ('--TCP', dict(default='295499', type=str)),
    ]

    arg_parser = argparse.ArgumentParser('RWD - FOMO Detector', add_help=False)
    for flag, options in option_table:
        arg_parser.add_argument(flag, **options)
    return arg_parser


In [None]:
def get_dataloader(args, dataset, train=True):
    """Wrap ``dataset`` in a DataLoader that uses the project collate function.

    Args:
        args: parsed options; ``distributed``, ``batch_size`` and
            ``num_workers`` are read from it.
        dataset: the dataset to iterate over.
        train: when True, use a random (or shuffled distributed) sampler and
            drop the final incomplete batch; when False, use an unshuffled
            sampler and keep every sample.

    Returns:
        torch.utils.data.DataLoader with ``pin_memory=True``.
    """
    if args.distributed:
        sampler = samplers.DistributedSampler(dataset, shuffle=train)
    elif train:
        sampler = torch.utils.data.RandomSampler(dataset)
    else:
        sampler = torch.utils.data.SequentialSampler(dataset)

    shared_kwargs = dict(collate_fn=utils.collate_fn,
                         num_workers=args.num_workers,
                         pin_memory=True)

    if not train:
        return DataLoader(dataset, args.batch_size, sampler=sampler,
                          drop_last=False, **shared_kwargs)

    # Training: a BatchSampler takes over batching and discards a ragged last batch.
    batch_sampler = torch.utils.data.BatchSampler(sampler, args.batch_size, drop_last=True)
    return DataLoader(dataset, batch_sampler=batch_sampler, **shared_kwargs)


def match_name_keywords(n, name_keywords):
    """Return True if any entry of ``name_keywords`` is contained in ``n`` (``keyword in n``)."""
    return any(keyword in n for keyword in name_keywords)


In [None]:
# Dataset to analyse. Every supported dataset shares the same command line and
# differs only in --CUR_INTRODUCED_CLS, the number of classes introduced in the
# current task. Those counts are listed below.
dataset = "DIOR_FIN"

CUR_INTRODUCED_CLS_BY_DATASET = {
    "AQUA": 4,
    "DIOR_FIN": 10,
    "NEUROSURGICAL_TOOLS_FIN": 6,
    "XRAY": 6,
    "SYNTH": 30,
}
# Fail loudly instead of leaving args_str undefined for an unsupported name.
if dataset not in CUR_INTRODUCED_CLS_BY_DATASET:
    raise ValueError(f"Unknown dataset {dataset!r}; expected one of "
                     f"{sorted(CUR_INTRODUCED_CLS_BY_DATASET)}")

args_str = ('--model_name google/owlvit-base-patch16 --num_few_shot 100 --batch_size 6 '
            f'--PREV_INTRODUCED_CLS 0 --CUR_INTRODUCED_CLS {CUR_INTRODUCED_CLS_BY_DATASET[dataset]} --TCP 29550 '
            f'--dataset {dataset} --image_conditioned --image_resize 768 '
            '--att_adapt --att_selection --att_refinement ')

# Whitespace-split into an argv-style list so argparse sees one token per item.
args_list = args_str.split()
parser = get_args_parser()
args = parser.parse_args(args=args_list)

In [None]:
# Echo the parsed configuration before initialising distributed mode.
print(args)
# NOTE(review): presumably sets args.distributed (read by get_dataloader) — confirm in util.misc.
utils.init_distributed_mode(args)
# Print the code revision reported by utils.get_sha().
git_sha = utils.get_sha()
print(f"git:\n  {git_sha}\n")

In [None]:
device = torch.device(args.device)

# Seed torch, numpy and Python's `random` with the same value for reproducibility;
# the per-rank offset gives each distributed process its own seed.
seed = args.seed + utils.get_rank()
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(seed)

# ---- datasets (train=False -> unshuffled loaders that keep every sample) ----
if len(args.train_set) > 0:
    dataset_train = build_dataset(args, args.train_set)
    data_loader_train = get_dataloader(args, dataset_train, train=False)

if len(args.test_set) > 0:
    dataset_val = build_dataset(args, args.test_set)
    data_loader_val = get_dataloader(args, dataset_val, train=False)

# NOTE(review): extra options not defined by the parser; presumably read by build_model — confirm.
args.neg_sup_ep = 1
args.neg_sup_lr = 5e-05

# ---- model (evaluation mode only) ----
model, postprocessors = build_model(args)
model.to(device)
model.eval()

In [None]:
# --unk_methods may be comma-separated; only its first entry configures the unknown head.
unk_method = args.unk_methods.partition(",")[0]
model.unk_head.method = unk_method

In [None]:
from datasets.open_world_eval_attributes import OWEvaluator
from collections import defaultdict

# The accumulator length comes from att_W's first dimension; each detection's
# attribute vector is assumed to have that length.
att_W = model.unk_head.att_W
num_att = att_W.shape[0]
# Per predicted label: running sum of attribute scores and number of detections.
attribute_sums = defaultdict(lambda: np.zeros(num_att))
class_counts = defaultdict(int)

with torch.no_grad():
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'
    # NOTE(review): coco_evaluator is constructed but never updated or queried here.
    coco_evaluator = OWEvaluator(dataset_val, args=args)
    for samples, targets in metric_logger.log_every(data_loader_val, 10, header):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        outputs = model(samples.tensors)
        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors(outputs, orig_target_sizes)
        for result in results:
            # Move each image's labels/attributes to the host once, instead of one
            # .item()/.cpu() synchronisation per detection.
            # Assumes both entries are tensors whose first dim indexes detections.
            labels = result['labels'].tolist()
            attributes = result['attributes'].detach().cpu().numpy()
            for lab, att in zip(labels, attributes):
                attribute_sums[lab] += att
                class_counts[lab] += 1  # counts detections (not images) per predicted label

        

In [None]:
# Calculate average attributes per class ensuring tensors are moved to CPU
average_attributes_per_class = {cls: (attribute_sums[cls] / count) for cls, count in class_counts.items() if count > 0}

# Now create the DataFrame
average_attributes_df = pd.DataFrame({k: v for k, v in average_attributes_per_class.items()}).T.sort_index()


In [None]:
# Convert to a DataFrame
average_attributes_df.columns = model.attributes_texts + [0]  # Rename columns

# Add class names if you have them
class_names = {i:model.all_classnames[i] for i in range(len(model.all_classnames))}
average_attributes_df.rename(index=class_names, inplace=True)
if False:
    att_W[:,-1] = 1
    average_attributes_df = average_attributes_df * att_W.detach().cpu().numpy().T
average_attributes_df = average_attributes_df.iloc[:, :-1]  # This selects all rows and all columns except the last one
average_attributes_df = average_attributes_df.drop_duplicates()


In [None]:
# Highest-scoring attribute of every class row except the last one (the plotting
# cells label the last row "unknown"), de-duplicated.
# dict.fromkeys keeps first-seen order. list(set(...)) would order the heatmap
# columns by string hash, which changes between interpreter runs (PYTHONHASHSEED).
top_att = list(dict.fromkeys(
    average_attributes_df.iloc[row].idxmax()
    for row in range(len(average_attributes_df) - 1)
))

In [None]:
# Scores of the selected top attributes (rows = attributes, columns = classes).
top_att_scores = average_attributes_df[top_att].T
# Two randomly drawn unused attributes, appended as all-zero rows.
selected_unused_attributes = np.random.choice(model.unused_attributes_texts, size=2, replace=False)
zero_rows = pd.DataFrame(0, index=selected_unused_attributes, columns=top_att_scores.columns)
df_with_unused = pd.concat([top_att_scores, zero_rows])
# Flip back so rows are classes and columns are attributes.
df = df_with_unused.T


In [None]:
# Heatmap with attributes as rows and classes as columns, saved to <dataset>_att.png.
categories = df.columns
classnames = list(df.index)
classnames[-1] = "unknown"  # the last class row is presented as the unknown class

# Size the figure from the number of classes / attributes. Without the clamp the
# height 4*n/6 - 3 is negative for n <= 4 attribute rows (matplotlib raises) and
# under an inch for n == 5.
fig_width = 20 * len(classnames) / 31 + 2
fig_height = max(4 * len(categories) / 6 - 3, 2)
fig, ax = plt.subplots(figsize=(fig_width, fig_height))
sns.heatmap(df.T, annot=True, fmt=".2f", linewidths=.7, cmap='coolwarm', ax=ax)
# Strip the shared prompt prefix so tick labels show only the attribute text.
edited_categories = [cat.replace("object which (is/has/etc) ", "") for cat in categories]
ax.set_yticklabels(edited_categories)
ax.set_xticklabels(classnames, rotation=90)
fig.tight_layout()
fig.savefig(f"{dataset}_att.png")

In [None]:
# Compact view of the same matrix: classes as rows, attributes as columns.
categories = df.columns
classnames = list(df.index)
classnames[-1] = "unknown"  # the last class row is presented as the unknown class
attr_prefix = "object which (is/has/etc) "
edited_categories = [label.replace(attr_prefix, "") for label in categories]

fig, ax = plt.subplots(figsize=(5, 2))
sns.heatmap(df, annot=True, fmt=".2f", linewidths=.5, cmap='coolwarm', ax=ax)
ax.set_xticklabels(edited_categories, rotation=45, ha='right')
ax.set_yticklabels(classnames)
plt.show()