In [20]:
import os
import time
import json
from collections import defaultdict

import fasttext
import pandas as pd
import numpy as np
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import clip
import pickle
from PIL import Image
from tqdm import tqdm
from sklearn import svm
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix

device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
def softmax(inputs):
    res = torch.tensor(inputs).float()
    res = res.softmax(dim=-1)
    return res.numpy()

def normalize(inputs):
    res = torch.tensor(inputs).float()
    res /= res.norm(dim=-1, keepdim=True)
    return res.numpy()

def get_precomputed_features(feature, args, is_softmax=False):
    """ Get precomputed CLIP image/pair/attr/obj features """
    data_root = args.precomputed_data_root
    feature_name = "image_features" if feature=="image" else f"{feature}_activations"
    feature_train = np.load(os.path.join(data_root, f"{feature_name}_train.npy"))
    feature_valid = np.load(os.path.join(data_root, f"{feature_name}_valid.npy"))
    feature_test = np.load(os.path.join(data_root, f"{feature_name}_test.npy"))
    if is_softmax:
        feature_train = softmax(feature_train)
        feature_valid = softmax(feature_valid)
        feature_test = softmax(feature_test)
    print(f"{feature} \t| train {feature_train.shape} \t| valid {feature_valid.shape} \t| test {feature_test.shape}")
    return feature_train, feature_valid, feature_test

def get_seen_unseen_indices(split, data):
    if split == "train":
        split_data = data.train_data
    elif split == "valid":
        split_data = data.valid_data
    elif split == "test":
        split_data = data.test_data
    else:
        raise ValueError(f"No split found: {split}")

    pairs = [(sample["attr"], sample["obj"]) for sample in split_data]
    seen_indices = [
        i for i in range(len(pairs))
        if pairs[i] in data.train_pairs
    ]
    unseen_indices = [
        i for i in range(len(pairs))
        if pairs[i] not in data.train_pairs
    ]
    print(f"seen_indices: {len(seen_indices)} | unseen_indices: {len(unseen_indices)}")
    return seen_indices, unseen_indices

def evaluate(results):
    """ Evaluate predictions and Return metrics. """
    all_preds, seen_preds, unseen_preds = results["all_preds"], results["seen_preds"], results["unseen_preds"]
    all_acc, seen_acc, unseen_acc = np.mean(all_preds), np.mean(seen_preds), np.mean(unseen_preds)    
    return {
        "all_acc": all_acc,
        "seen_acc": seen_acc,
        "unseen_acc": unseen_acc,
        "harmonic_mean": (seen_acc * unseen_acc)**0.5,
        "macro_average_acc": (seen_acc + unseen_acc)*0.5,
    }

def generate_predictions(scores, labels, seen_ids, unseen_ids, data, topk, bias=0.0):
    """ Apply bias and Generate predictions for. """
    def get_predictions(_scores):
        # Get predictions
        _, pair_preds = _scores.topk(topk, dim=1)
        pair_preds = pair_preds[:, :topk].contiguous().view(-1)
        attr_preds = all_pairs[pair_preds][:,0].view(-1, topk)
        obj_preds = all_pairs[pair_preds][:,1].view(-1, topk)
        pair_preds = pair_preds.view(-1, topk)
        return pair_preds, attr_preds, obj_preds
    
    # Get predictions with biases applied
    all_pairs = torch.LongTensor([
        (data.attr2idx[attr], data.obj2idx[obj]) 
        for attr, obj in data.pairs
    ])
    scores = scores.clone()
    mask = data.seen_mask.repeat(scores.shape[0], 1)
    scores[~mask] += bias
    pair_preds, attr_preds, obj_preds = get_predictions(scores)
    
    # Get predictions for seen/unseen pairs
    all_preds = np.array([label in pair_preds[row_id,:topk] for row_id, label in enumerate(labels)])
    seen_preds = all_preds[seen_ids]
    unseen_preds = all_preds[unseen_ids]
    return {
        "pair_preds": pair_preds,
        "attr_preds": attr_preds,
        "obj_preds": obj_preds,
        "all_preds": all_preds,
        "seen_preds": seen_preds,
        "unseen_preds": unseen_preds,
    }

In [7]:
clip.available_models()

['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']

In [4]:
# Load the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, preprocess = clip.load("ViT-B/32", device=device)

In [3]:
from torch.utils.data import Dataset

class MITStatesDataset(Dataset):
    def __init__(self, root, split):
        self.root = root
        self.split = split
        
        # Load metadata
        self.metadata = torch.load(os.path.join(root, "metadata_compositional-split-natural.t7"))
        
        # Load attribute-noun pairs for each split
        all_info, split_info = self.parse_split()
        self.attrs, self.objs, self.pairs = all_info
        self.train_pairs, self.valid_pairs, self.test_pairs = split_info
        
        # Get obj/attr/pair to indices mappings
        self.obj2idx = {obj: idx for idx, obj in enumerate(self.objs)}
        self.attr2idx = {attr: idx for idx, attr in enumerate(self.attrs)}
        self.pair2idx = {pair: idx for idx, pair in enumerate(self.pairs)}
        self.idx2obj = {idx: obj for obj, idx in self.obj2idx.items()}
        self.idx2attr = {idx: attr for attr, idx in self.attr2idx.items()}
        self.idx2pair = {idx: pair for pair, idx in self.pair2idx.items()}
        
        # Get all data
        self.train_data, self.valid_data, self.test_data = self.get_split_info()
        if self.split == "train":
            self.data = self.train_data
        elif self.split == "valid":
            self.data = self.valid_data
        else:
            self.data = self.test_data
        
        self.sample_indices = list(range(len(self.data)))
        self.sample_pairs = self.train_pairs
        print(f"train pairs: {len(self.train_pairs)} | valid pairs: {len(self.valid_pairs)} | test pairs: {len(self.test_pairs)}")
        print(f"train images: {len(self.train_data)} | valid images: {len(self.valid_data)} | test images: {len(self.test_data)}")
    
    def parse_split(self):
        def parse_pairs(pair_path):
            with open(pair_path, "r") as f:
                pairs = f.read().strip().split("\n")
                pairs = [t.split() for t in pairs]
                pairs = list(map(tuple, pairs))
            attrs, objs = zip(*pairs)
            return attrs, objs, pairs
        
        tr_attrs, tr_objs, tr_pairs = parse_pairs(os.path.join(self.root, "compositional-split-natural", "train_pairs.txt"))
        vl_attrs, vl_objs, vl_pairs = parse_pairs(os.path.join(self.root, "compositional-split-natural", "val_pairs.txt"))
        ts_attrs, ts_objs, ts_pairs = parse_pairs(os.path.join(self.root, "compositional-split-natural", "test_pairs.txt"))

        all_attrs = sorted(list(set(tr_attrs + vl_attrs + ts_attrs)))
        all_objs = sorted(list(set(tr_objs + vl_objs + ts_objs)))
        all_pairs = sorted(list(set(tr_pairs + vl_pairs + ts_pairs)))
        
        return (all_attrs, all_objs, all_pairs), (tr_pairs, vl_pairs, ts_pairs)
    
    def get_split_info(self):
        train_data, val_data, test_data = [], [], []
        for instance in self.metadata:
            image, attr, obj, settype = instance["image"], instance["attr"], instance["obj"], instance["set"]
            image = image.split("/")[1]  # Get the image name without (attr, obj) folder
            image = os.path.join(" ".join([attr, obj]), image)
            
            if (
                (attr == "NA") or 
                ((attr, obj) not in self.pairs) or 
                (settype == "NA")
            ):
                # ignore instances with unlabeled attributes
                # ignore instances that are not in current split
                continue

            data_i = {
                "image_path": image, 
                "attr": attr, 
                "obj": obj,
                "pair": (attr, obj),
                "attr_id": self.attr2idx[attr],
                "obj_id": self.obj2idx[obj],
                "pair_id": self.pair2idx[(attr, obj)],
            }
            if settype == "train":
                train_data.append(data_i)
            elif settype == "val":
                val_data.append(data_i)
            else:
                test_data.append(data_i)
                
        return train_data, val_data, test_data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        index = self.sample_indices[index]
        return self.data[index]

In [4]:
root = "../../data/mit_states"
split = "train"
data = MITStatesDataset(root=root, split=split)
print(f"split size: {len(data)}")

train pairs: 1262 | valid pairs: 600 | test pairs: 800
train images: 30338 | valid images: 10420 | test images: 12995
split size: 30338


# Retrieval Model

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F


# OLD!!
class RetrievalModel(nn.Module):
    def __init__(self, data, input_dim, args):
        super(RetrievalModel, self).__init__()
        self.input_dim = input_dim
        self.args = args
        
        self.obj2idx, self.attr2idx, self.pair2idx = data.obj2idx, data.attr2idx, data.pair2idx
        self.train_pairs, self.valid_pairs, self.test_pairs = data.train_pairs, data.valid_pairs, data.test_pairs
        self.limit_pairs = set(self.train_pairs)
        
        self.text_input_dim = args.text_input_dim
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / args.logit_scale))
        self.attr_encoder = nn.Embedding(len(data.attrs), self.text_input_dim)
        self.obj_encoder = nn.Embedding(len(data.objs), self.text_input_dim)
        self.text_projection = nn.Linear(self.text_input_dim*2, input_dim)
        self.image_projection = nn.Linear(input_dim, input_dim)
    
    def get_limit_pair_inputs(self):
        attr_inputs, obj_inputs = [],[]
        for attr, obj in self.limit_pairs:
            attr_id = self.attr2idx[attr]
            obj_id = self.obj2idx[obj]
            attr_inputs.append(attr_id)
            obj_inputs.append(obj_id)
        attr_inputs = torch.LongTensor(attr_inputs).to(device)
        obj_inputs = torch.LongTensor(obj_inputs).to(device)
        return attr_inputs, obj_inputs

    def get_all_pair_inputs(self):
        attr_inputs, obj_inputs = [],[]
        for attr, obj in self.pair2idx.keys():
            attr_id = self.attr2idx[attr]
            obj_id = self.obj2idx[obj]
            attr_inputs.append(attr_id)
            obj_inputs.append(obj_id)
        attr_inputs = torch.LongTensor(attr_inputs).to(device)
        obj_inputs = torch.LongTensor(obj_inputs).to(device)
        return attr_inputs, obj_inputs
    
    def forward(self, image_embs, pair_labels=None, is_train=True):
        # image_embs refers to image embeddings or concept representations
        attrs, objs = self.get_all_pair_inputs()
        #attrs, objs = self.get_limit_pair_inputs() if is_train else self.get_all_pair_inputs()
        attr_embs = self.attr_encoder(attrs)
        obj_embs = self.obj_encoder(objs)
        pair_embs = torch.cat([attr_embs, obj_embs], dim=1)
        pair_embs = self.text_projection(pair_embs)
        pair_embs = F.normalize(pair_embs, dim=1)
        
        image_embs = self.image_projection(image_embs)
        image_embs = F.normalize(image_embs, dim=1)
        
        logit_scale = self.logit_scale.exp()
        logit_scale = logit_scale if logit_scale<=100.0 else 100.0
        logits = logit_scale * image_embs @ pair_embs.t()
        
        loss = None
        if is_train:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, pair_labels)
        return logits, loss
        

In [5]:
# NEW!!
class RetrievalModel(nn.Module):
    def __init__(self, data, image_dim, limit_pairs, args):
        super(RetrievalModel, self).__init__()
        self.image_dim = image_dim
        self.input_dim = args.input_dim
        self.args = args
        
        self.obj2idx, self.attr2idx, self.pair2idx = data.obj2idx, data.attr2idx, data.pair2idx
        self.train_pairs, self.valid_pairs, self.test_pairs = data.train_pairs, data.valid_pairs, data.test_pairs
        self.limit_pairs = limit_pairs
        
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / args.logit_scale))
        self.attr_encoder = nn.Embedding(len(data.attrs), self.input_dim)
        self.obj_encoder = nn.Embedding(len(data.objs), self.input_dim)
        self.text_projection = nn.Linear(self.input_dim*2, self.image_dim, bias=args.is_bias)
        
        if self.args.is_image_projection:
            self.image_projection = nn.Linear(self.image_dim, self.image_dim, bias=args.is_bias)
    
    def get_limit_pair_inputs(self):
        attr_inputs, obj_inputs = [],[]
        for attr, obj in self.limit_pairs:
            attr_id = self.attr2idx[attr]
            obj_id = self.obj2idx[obj]
            attr_inputs.append(attr_id)
            obj_inputs.append(obj_id)
        attr_inputs = torch.LongTensor(attr_inputs).to(device)
        obj_inputs = torch.LongTensor(obj_inputs).to(device)
        return attr_inputs, obj_inputs

    def get_all_pair_inputs(self):
        attr_inputs, obj_inputs = [],[]
        for attr, obj in self.pair2idx.keys():
            attr_id = self.attr2idx[attr]
            obj_id = self.obj2idx[obj]
            attr_inputs.append(attr_id)
            obj_inputs.append(obj_id)
        attr_inputs = torch.LongTensor(attr_inputs).to(device)
        obj_inputs = torch.LongTensor(obj_inputs).to(device)
        return attr_inputs, obj_inputs
    
    def forward(self, image_embs, pair_labels=None, is_train=True):
        # image_embs refers to image embeddings or concept representations
        #attrs, objs = self.get_all_pair_inputs()
        attrs, objs = self.get_limit_pair_inputs() if is_train else self.get_all_pair_inputs()
        attr_embs = self.attr_encoder(attrs)
        obj_embs = self.obj_encoder(objs)
        pair_embs = torch.cat([attr_embs, obj_embs], dim=1)
        pair_embs = self.text_projection(pair_embs)
        pair_embs = F.normalize(pair_embs, dim=1)
        
        if self.args.is_image_projection:
            image_embs = self.image_projection(image_embs.float())
        image_embs = F.normalize(image_embs, dim=1).float()
        
        logit_scale = self.logit_scale.exp()
        logit_scale = logit_scale if logit_scale<=100.0 else 100.0
        logits = logit_scale * image_embs @ pair_embs.t()
        
        loss = None
        if is_train:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, pair_labels)
        return logits, loss

In [6]:
def get_split_labels(split, data, limit_pair2idx, is_limit=True):
    if split == "train":
        split_data = data.train_data
    elif split == "valid":
        split_data = data.valid_data
    elif split == "test":
        split_data = data.test_data
    lables_attr = [sample["attr_id"] for sample in split_data]
    lables_obj = [sample["obj_id"] for sample in split_data]
    
    pair2idx = limit_pair2idx if is_limit else data.pair2idx
    labels = []
    for attr_id, obj_id in zip(lables_attr, lables_obj):
        attr = data.idx2attr[attr_id]
        obj = data.idx2obj[obj_id]
        labels.append(pair2idx[(attr, obj)])
    labels = torch.LongTensor(labels_train)
    return labels

In [7]:
# Run this!
labels_text = [sample["pair"] for sample in data.train_data]
limit_pair2idx = {}
for label in labels_text:
    if label not in limit_pair2idx:
        limit_pair2idx[label] = len(limit_pair2idx)

labels = [limit_pair2idx[(attr, obj)] for attr, obj in labels_text]

In [None]:
# Get labels on train split 
labels_train_text = [sample["pair"] for sample in data.train_data]
limit_pair2idx = {}
for label in labels_train_text:
    if label not in limit_pair2idx:
        limit_pair2idx[label] = len(limit_pair2idx)

labels = [limit_pair2idx[(attr, obj)] for attr, obj in labels_text]
labels_train = torch.LongTensor(labels_train)

In [None]:
torch.HalfTensor([10]) @ torch.FloatTensor([9])

In [219]:
labels_train = get_split_labels("train", data, limit_pair2idx, is_limit=True)

## 1. Pair Actv. (Limit to Seen)

In [8]:
exp_name = "retrieval_limit_pair_actv"
split = "valid"
seed = 0

seen_mask = torch.BoolTensor([1 if pair in data.train_pairs else 0 for pair in data.pairs])
seen_ids, unseen_ids = get_seen_unseen_indices(split, data)
pair_actvs = get_precomputed_features("pair", is_softmax=False)

# Limit to seen pair activations
pair_actvs_limit = tuple(t[:, seen_mask] for t in pair_actvs)
pair_actv_train, pair_actv_valid, pair_actv_test = tuple(normalize(t) for t in pair_actvs_limit)

if split == "train":
    features = pair_actv_train
    split_data, split_pairs = data.train_data, data.train_pairs
elif split == "valid":
    features = pair_actv_valid
    split_data, split_pairs = data.valid_data, data.valid_pairs
elif split == "test":
    features = pair_actv_test
    split_data, split_pairs = data.test_data, data.test_pairs

labels_attr_train = [sample["attr_id"] for sample in data.train_data]
labels_obj_train = [sample["obj_id"] for sample in data.train_data]
labels_train = torch.LongTensor([sample["pair_id"] for sample in data.train_data])

labels_attr = [sample["attr_id"] for sample in split_data]
labels_obj = [sample["obj_id"] for sample in split_data]
labels = [sample["pair_id"] for sample in split_data]

print(f"features: {features.shape}")

seen_indices: 1844 | unseen_indices: 8576
pair 	| train (30338, 1962) 	| valid (10420, 1962) 	| test (12995, 1962)
features: (10420, 1262)


In [9]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam

device = "cuda" if torch.cuda.is_available() else "cpu"

In [24]:
class parseArguments:
    def __init__(self):
        self.feature = "primitive"
        self.model = "vilt"
        self.train_warmup = 0
        self.is_image_projection = False
        self.is_bias = False
        
        self.data_root = "../../data/mit_states"
        self.emb_root = "../../data"
        self.precomputed_data_root = f"../../outputs/mit_states/{self.model}_precompute_features"
        
        self.emb_init = True
        self.input_dim = 600
        self.num_epochs = 400
        self.batch_size = 16
        self.learning_rate = 5e-5
        self.weight_decay = 5e-5
        self.logit_scale = 0.07

args = parseArguments()

In [11]:
def get_overall_metrics(features, labels, seen_ids, unseen_ids, data, topk_list=[1]):
    overall_metrics = {}
    for topk in topk_list:
        # Get model's performance (accuracy) on seen/unseen pairs
        bias = 1e3
        results = generate_predictions(features, labels, seen_ids, unseen_ids, data, topk=topk, bias=bias)
        full_unseen_metrics = evaluate(results)
        all_preds, seen_preds, unseen_preds = results["all_preds"], results["seen_preds"], results["unseen_preds"]

        # Get predicted probability distribution of unseen pairs,
        # and the top K scores of seen pairs in the predicted prob. distribution of unseen pairs
        correct_scores = features[np.arange(len(features)), labels][unseen_ids]
        max_seen_scores = features[unseen_ids][:, data.seen_mask].topk(topk, dim=1)[0][:,topk-1]

        # Compute biases
        unseen_score_diff = max_seen_scores - correct_scores
        correct_unseen_score_diff = unseen_score_diff[unseen_preds] - 1e-4
        correct_unseen_score_diff = torch.sort(correct_unseen_score_diff)[0]
        magic_binsize = 20
        bias_skip = max(len(correct_unseen_score_diff) // magic_binsize, 1)
        bias_list = correct_unseen_score_diff[::bias_skip]

        # Get biased predictions and metrics with different biases
        all_metrics = []
        for bias in bias_list:
            results = generate_predictions(features, labels, seen_ids, unseen_ids, data, topk=topk, bias=bias)
            metrics = evaluate(results)
            all_metrics.append(metrics)
        all_metrics.append(full_unseen_metrics)

        # Compute overall metrics
        seen_accs = np.array([metric_dict["seen_acc"] for metric_dict in all_metrics])
        unseen_accs = np.array([metric_dict["unseen_acc"] for metric_dict in all_metrics])
        best_seen_acc = max([metric_dict["seen_acc"] for metric_dict in all_metrics])
        best_unseen_acc = max([metric_dict["unseen_acc"] for metric_dict in all_metrics])
        best_harmonic_mean = max([metric_dict["harmonic_mean"] for metric_dict in all_metrics])
        auc = np.trapz(seen_accs, unseen_accs)
        #print(f"best_seen_acc: {best_seen_acc:6.4f}")
        #print(f"best_unseen_acc: {best_unseen_acc:6.4f}")
        #print(f"best_harmonic_mean: {best_harmonic_mean:6.4f}")
        #print(f"auc: {auc:6.4f}")

        overall_metrics[topk] = {
            "seen_accs": seen_accs.tolist(),
            "unseen_accs": unseen_accs.tolist(),
            "best_seen_acc": best_seen_acc,
            "best_unseen_acc": best_unseen_acc,
            "best_harmonic_mean": best_harmonic_mean,
            "auc": auc,
        }
    return overall_metrics

In [19]:
# NEW!!
from torch.utils.data import Dataset

class Precomputed_MITStatesDataset(Dataset):
    def __init__(self, split, feature, data, args, is_limit=True):
        self.seen_mask = torch.BoolTensor([1 if pair in data.train_pairs else 0 for pair in data.pairs])
        if feature == "primitive":
            attr_actvs_tuple = get_precomputed_features("attr", args, is_softmax=False)
            obj_actvs_tuple = get_precomputed_features("obj", args, is_softmax=False)
            attr_actvs_train, attr_actvs_valid, attr_actvs_test = attr_actvs_tuple
            obj_actvs_train, obj_actvs_valid, obj_actvs_test = obj_actvs_tuple
            image_embs_train = np.concatenate([attr_actvs_train, obj_actvs_train], axis=-1)
            image_embs_valid = np.concatenate([attr_actvs_valid, obj_actvs_valid], axis=-1)
            image_embs_test = np.concatenate([attr_actvs_test, obj_actvs_test], axis=-1)
        elif feature == "pair":
            image_embs_tuple = get_precomputed_features("pair", args, is_softmax=False)
            if (is_limit) and (args.model == "tmn"):
                image_embs_tuple = tuple(t[:, self.seen_mask] for t in image_embs_tuple)
            #image_embs_train, image_embs_valid, image_embs_test = tuple(normalize(t) for t in image_embs_tuple)
            image_embs_train, image_embs_valid, image_embs_test = image_embs_tuple
        elif feature == "all":
            attr_actvs_tuple = get_precomputed_features("attr", args, is_softmax=False)
            obj_actvs_tuple = get_precomputed_features("obj", args, is_softmax=False)
            pair_actvs_tuple = get_precomputed_features("pair", args, is_softmax=False)
            if (is_limit) and (args.model == "tmn"):
                pair_actvs_tuple = tuple(t[:, self.seen_mask] for t in pair_actvs_tuple)
            attr_actvs_train, attr_actvs_valid, attr_actvs_test = attr_actvs_tuple
            obj_actvs_train, obj_actvs_valid, obj_actvs_test = obj_actvs_tuple
            pair_actvs_train, pair_actvs_valid, pair_actvs_test = pair_actvs_tuple
            image_embs_train = np.concatenate([attr_actvs_train, obj_actvs_train, pair_actvs_train], axis=-1)
            image_embs_valid = np.concatenate([attr_actvs_valid, obj_actvs_valid, pair_actvs_valid], axis=-1)
            image_embs_test = np.concatenate([attr_actvs_test, obj_actvs_test, pair_actvs_test], axis=-1)

        labels = None
        if split == "train":
            image_embs = image_embs_train
            if is_limit:
                labels_text = [sample["pair"] for sample in data.train_data]
                limit_pair2idx = {}
                for label in labels_text:
                    if label not in limit_pair2idx:
                        limit_pair2idx[label] = len(limit_pair2idx)
                labels = [limit_pair2idx[(attr, obj)] for attr, obj in labels_text]
                self.limit_pair2idx = limit_pair2idx
            else:
                labels = [sample["pair_id"] for sample in data.train_data]
        elif split == "valid":
            image_embs = image_embs_valid
            labels = [sample["pair_id"] for sample in data.valid_data]
        elif split == "test":
            image_embs = image_embs_test
            labels = [sample["pair_id"] for sample in data.test_data]
        self.image_embs = image_embs
        self.labels = np.array(labels)
        self.image_dim = self.image_embs.shape[-1]
        print(f"image_dim: {self.image_dim:6d}")

    def __len__(self):
        return len(self.image_embs)
    
    def __getitem__(self, index):
        image_embs = self.image_embs[index]
        labels = self.labels[index]
        return image_embs, labels

In [20]:
# Load dataset
data = MITStatesDataset(root=args.data_root, split="train")  # split can be ignored here
seen_ids_valid, unseen_ids_valid = get_seen_unseen_indices("valid", data)
seen_ids_test, unseen_ids_test = get_seen_unseen_indices("test", data)

train_dataset = Precomputed_MITStatesDataset(split="train", feature="pair", data=data, args=args, is_limit=True)
valid_dataset = Precomputed_MITStatesDataset(split="valid", feature="pair", data=data, args=args, is_limit=True)
test_dataset = Precomputed_MITStatesDataset(split="test", feature="pair", data=data, args=args, is_limit=True)


train pairs: 1262 | valid pairs: 600 | test pairs: 800
train images: 30338 | valid images: 10420 | test images: 12995
seen_indices: 1844 | unseen_indices: 8576
seen_indices: 2380 | unseen_indices: 10615
pair 	| train (30338, 1262) 	| valid (10420, 1262) 	| test (12995, 1262)
image_dim:   1262
pair 	| train (30338, 1262) 	| valid (10420, 1262) 	| test (12995, 1262)
image_dim:   1262
pair 	| train (30338, 1262) 	| valid (10420, 1262) 	| test (12995, 1262)
image_dim:   1262


In [21]:
def load_fasttext_embeddings(vocab, args):
    custom_map = {
        'Faux.Fur': 'fake fur',
        'Faux.Leather': 'fake leather',
        'Full.grain.leather': 'thick leather',
        'Hair.Calf': 'hairy leather',
        'Patent.Leather': 'shiny leather',
        'Boots.Ankle': 'ankle boots',
        'Boots.Knee.High': 'kneehigh boots',
        'Boots.Mid-Calf': 'midcalf boots',
        'Shoes.Boat.Shoes': 'boatshoes',
        'Shoes.Clogs.and.Mules': 'clogs shoes',
        'Shoes.Flats': 'flats shoes',
        'Shoes.Heels': 'heels',
        'Shoes.Loafers': 'loafers',
        'Shoes.Oxfords': 'oxford shoes',
        'Shoes.Sneakers.and.Athletic.Shoes': 'sneakers',
        'traffic_light': 'traficlight',
        'trash_can': 'trashcan',
        'dry-erase_board' : 'dry_erase_board',
        'black_and_white' : 'black_white',
        'eiffel_tower' : 'tower'
    }
    vocab_lower = [v.lower() for v in vocab]
    vocab = []
    for current in vocab_lower:
        if current in custom_map:
            vocab.append(custom_map[current])
        else:
            vocab.append(current)

    import fasttext.util
    ft = fasttext.load_model(args.emb_root+'/fasttext/cc.en.300.bin')
    embeds = []
    for k in vocab:
        if '_' in k:
            ks = k.split('_')
            emb = np.stack([ft.get_word_vector(it) for it in ks]).mean(axis=0)
        else:
            emb = ft.get_word_vector(k)
        embeds.append(emb)

    embeds = torch.Tensor(np.stack(embeds))
    print('Fasttext Embeddings loaded, total embeddings: {}'.format(embeds.size()))
    return embeds

def load_word2vec_embeddings(vocab, args):
    # vocab = [v.lower() for v in vocab]

    from gensim import models
    model = models.KeyedVectors.load_word2vec_format(
        args.emb_root+'/word2vec/GoogleNews-vectors-negative300.bin', binary=True
    )

    custom_map = {
        'Faux.Fur': 'fake_fur',
        'Faux.Leather': 'fake_leather',
        'Full.grain.leather': 'thick_leather',
        'Hair.Calf': 'hair_leather',
        'Patent.Leather': 'shiny_leather',
        'Boots.Ankle': 'ankle_boots',
        'Boots.Knee.High': 'knee_high_boots',
        'Boots.Mid-Calf': 'midcalf_boots',
        'Shoes.Boat.Shoes': 'boat_shoes',
        'Shoes.Clogs.and.Mules': 'clogs_shoes',
        'Shoes.Flats': 'flats_shoes',
        'Shoes.Heels': 'heels',
        'Shoes.Loafers': 'loafers',
        'Shoes.Oxfords': 'oxford_shoes',
        'Shoes.Sneakers.and.Athletic.Shoes': 'sneakers',
        'traffic_light': 'traffic_light',
        'trash_can': 'trashcan',
        'dry-erase_board' : 'dry_erase_board',
        'black_and_white' : 'black_white',
        "eiffel_tower" : "tower"
    }

    embeds = []
    for k in vocab:
        if k in custom_map:
            k = custom_map[k]
        if "_" in k and k not in model:
            ks = k.split("_")
            emb = np.stack([model[it] for it in ks]).mean(axis=0)
        else:
            emb = model[k]
        embeds.append(emb)
    embeds = torch.Tensor(np.stack(embeds))
    print("Word2Vec Embeddings loaded, total embeddings: {}".format(embeds.size()))
    return embeds

def get_pretrained_weights(vocab, args):
    embeds1 = load_fasttext_embeddings(vocab, args)
    embeds2 = load_word2vec_embeddings(vocab, args)
    embeds = torch.cat([embeds1, embeds2], dim = 1)
    print("Combined embeddings are ",embeds.shape)
    return embeds

In [None]:
# Get retrieval model and optimizer
limit_pairs = list(train_dataset.limit_pair2idx.keys())
model = RetrievalModel(data, train_dataset.image_dim, limit_pairs, args)
#model = RetrievalModel(data, train_dataset.input_dim, args)
model.to(device)
optimizer = Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)

In [48]:
model

RetrievalModel(
  (attr_encoder): Embedding(115, 600)
  (obj_encoder): Embedding(245, 600)
  (text_projection): Linear(in_features=1200, out_features=1262, bias=True)
  (image_projection): Linear(in_features=1262, out_features=1262, bias=True)
)

In [37]:
pretrained_weight = get_pretrained_weights(data.attrs, args)
model.attr_encoder.weight.data.copy_(pretrained_weight)
pretrained_weight = get_pretrained_weights(data.objs, args)
model.obj_encoder.weight.data.copy_(pretrained_weight)


Fasttext Embeddings loaded, total embeddings: torch.Size([245, 300])
Word2Vec Embeddings loaded, total embeddings: torch.Size([245, 300])


In [214]:
best_auc = float("-inf")
best_epoch_id = 0
best_overall_metrics = {}
#for epoch_id in range(args.num_epochs):
for epoch_id in range(1):
    # Train model
    model.train()
    loader = DataLoader(train_dataset, shuffle=True, pin_memory=True, batch_size=args.batch_size, drop_last=True)
    for iter_id, batch in enumerate(loader):
        image_embs, labels = tuple(t.to(device) for t in batch)
        logits, loss = model(image_embs, labels, is_train=True)
        loss.backward()
        optimizer.step()
        model.zero_grad()
        break
        #if args.verbose and (iter_id % args.report_step == 0):
        #    logger.info(f"")

    # Early stop and Evaluation
    model.eval()
    loader = DataLoader(valid_dataset, shuffle=False, pin_memory=True, batch_size=args.batch_size, drop_last=False)
    features = np.zeros([len(valid_dataset), len(data.pairs)])
    for iter_id, batch in enumerate(loader):
        image_embs, labels = tuple(t.to(device) for t in batch)
        logits, loss = model(image_embs, labels, is_train=False)
        prob = logits.softmax(dim=-1).log()
        features[iter_id*args.batch_size:(iter_id+1)*args.batch_size] = prob.detach().cpu().numpy().copy()
    features = torch.tensor(features)

    labels = valid_dataset.labels
    overall_metrics = get_overall_metrics(
        features, labels, seen_ids_valid, unseen_ids_valid, data, topk_list=[1],
    )
    auc = overall_metrics[1]["auc"]
    if auc > best_auc:
        best_auc = auc
        best_epoch_id = epoch_id
        best_overall_metrics
        #torch.save(model.state_dict(), args.ckpt_path)

# Load best checkpoint and test
#model = RetrievalModel(data, train_dataset.input_dim, args)
#model.load_state_dict(torch.load(args.ckpt_path))
#model.to(device)

model.eval()
loader = DataLoader(test_dataset, shuffle=False, pin_memory=True, batch_size=args.batch_size, drop_last=False)
features = np.zeros([len(test_dataset), len(data.pairs)])
for iter_id, batch in enumerate(loader):
    image_embs, labels = tuple(t.to(device) for t in batch)
    logits, loss = model(image_embs, labels, is_train=False)
    prob = logits.softmax(dim=-1).log()
    features[iter_id*args.batch_size:(iter_id+1)*args.batch_size] = prob.detach().cpu().numpy().copy()
features = torch.tensor(features)

labels = test_dataset.labels
overall_metrics = get_overall_metrics(
    features, labels, seen_ids_test, unseen_ids_test, data, 
    topk_list=[1,2,3],
)

In [209]:
labels = valid_dataset.labels
overall_metrics_valid = get_overall_metrics(
    features, labels, seen_ids_valid, unseen_ids_valid, data, topk_list=[1],
)
auc_valid = overall_metrics_valid[1]["auc"]


# 2. Intervention

In [8]:
def get_gt_primitives(split, data, ):
    """ Get groundtruth primtiive concepts. """
    data_dict = {
        "train": data.train_data,
        "valid": data.valid_data,
        "test": data.test_data,
    }
    split_data = data_dict[split]
    labels_attr = [sample["attr_id"] for sample in split_data]
    labels_obj = [sample["obj_id"] for sample in split_data]
    gt_features_attr = np.zeros((len(split_data), len(data.attrs)))
    gt_features_obj = np.zeros((len(split_data), len(data.objs)))
    gt_features_attr[np.arange(len(labels_attr)), labels_attr] = 1
    gt_features_obj[np.arange(len(labels_obj)), labels_obj] = 1
    gt_features_concat = np.concatenate([gt_features_attr, gt_features_obj], axis=-1)
    return gt_features_concat

def evaluate(results):
    """ Evaluate predictions and Return metrics. """
    all_preds, seen_preds, unseen_preds = results["all_preds"], results["seen_preds"], results["unseen_preds"]
    all_acc, seen_acc, unseen_acc = np.mean(all_preds), np.mean(seen_preds), np.mean(unseen_preds)    
    return {
        "all_acc": all_acc,
        "seen_acc": seen_acc,
        "unseen_acc": unseen_acc,
        "harmonic_mean": (seen_acc * unseen_acc)**0.5,
        "macro_average_acc": (seen_acc + unseen_acc)*0.5,
    }

def generate_predictions(scores, labels, seen_ids, unseen_ids, seen_mask, data, topk, bias=0.0):
    """ Apply bias and Generate predictions for. """
    def get_predictions(_scores):
        # Get predictions
        _, pair_preds = _scores.topk(topk, dim=1)
        pair_preds = pair_preds[:, :topk].contiguous().view(-1)
        attr_preds = all_pairs[pair_preds][:,0].view(-1, topk)
        obj_preds = all_pairs[pair_preds][:,1].view(-1, topk)
        pair_preds = pair_preds.view(-1, topk)
        return pair_preds, attr_preds, obj_preds
    
    # Get predictions with biases applied
    all_pairs = torch.LongTensor([
        (data.attr2idx[attr], data.obj2idx[obj]) 
        for attr, obj in data.pairs
    ])
    scores = scores.clone()
    mask = seen_mask.repeat(scores.shape[0], 1)
    scores[~mask] += bias
    pair_preds, attr_preds, obj_preds = get_predictions(scores)
    
    # Get predictions for seen/unseen pairs
    all_preds = np.array([label in pair_preds[row_id,:topk] for row_id, label in enumerate(labels)])
    seen_preds = all_preds[seen_ids]
    unseen_preds = all_preds[unseen_ids]
    return {
        "pair_preds": pair_preds,
        "attr_preds": attr_preds,
        "obj_preds": obj_preds,
        "all_preds": all_preds,
        "seen_preds": seen_preds,
        "unseen_preds": unseen_preds,
    }

def get_overall_metrics(features, labels, seen_ids, unseen_ids, seen_mask, data, topk_list=[1], is_open_world=False):
    overall_metrics = {}
    for topk in topk_list:
        # Get model"s performance (accuracy) on seen/unseen pairs
        bias = 1e3
        results = generate_predictions(features, labels, seen_ids, unseen_ids, seen_mask, data, topk=topk, bias=bias)
        full_unseen_metrics = evaluate(results)
        all_preds, seen_preds, unseen_preds = results["all_preds"], results["seen_preds"], results["unseen_preds"]

        # Get predicted probability distribution of unseen pairs,
        # and the top K scores of seen pairs in the predicted prob. distribution of unseen pairs
        correct_scores = features[np.arange(len(features)), labels][unseen_ids]
        max_seen_scores = features[unseen_ids][:, seen_mask].topk(topk, dim=1)[0][:,topk-1]

        # Compute biases
        unseen_score_diff = max_seen_scores - correct_scores
        correct_unseen_score_diff = unseen_score_diff[unseen_preds] - 1e-4
        correct_unseen_score_diff = torch.sort(correct_unseen_score_diff)[0]
        magic_binsize = 20
        bias_skip = max(len(correct_unseen_score_diff) // magic_binsize, 1)
        bias_list = correct_unseen_score_diff[::bias_skip]

        # Get biased predictions and metrics with different biases
        all_metrics = []
        for bias in bias_list:
            results = generate_predictions(features, labels, seen_ids, unseen_ids, seen_mask, data, topk=topk, bias=bias)
            metrics = evaluate(results)
            all_metrics.append(metrics)
        all_metrics.append(full_unseen_metrics)

        # Compute overall metrics
        seen_accs = np.array([metric_dict["seen_acc"] for metric_dict in all_metrics])
        unseen_accs = np.array([metric_dict["unseen_acc"] for metric_dict in all_metrics])
        best_seen_acc = max([metric_dict["seen_acc"] for metric_dict in all_metrics])
        best_unseen_acc = max([metric_dict["unseen_acc"] for metric_dict in all_metrics])
        best_harmonic_mean = max([metric_dict["harmonic_mean"] for metric_dict in all_metrics])
        auc = np.trapz(seen_accs, unseen_accs)
        print(f"topk: {topk}")
        print(f"best_seen_acc: {best_seen_acc:6.4f}")
        print(f"best_unseen_acc: {best_unseen_acc:6.4f}")
        print(f"best_harmonic_mean: {best_harmonic_mean:6.4f}")
        print(f"auc: {auc:6.4f}")

        overall_metrics[topk] = {
            "seen_accs": seen_accs.tolist(),
            "unseen_accs": unseen_accs.tolist(),
            "best_seen_acc": best_seen_acc,
            "best_unseen_acc": best_unseen_acc,
            "best_harmonic_mean": best_harmonic_mean,
            "auc": auc,
        }
    return overall_metrics

In [9]:
class Precomputed_MITStatesDataset(Dataset):
    def __init__(self, split, feature, data, args, is_limit=True):
        # Load precomputed features with temporary seen_mask
        self.seen_mask = torch.BoolTensor([1 if pair in data.train_pairs else 0 for pair in data.pairs])
        if feature == "primitive":
            attr_actvs_tuple = get_precomputed_features("attr", args, is_softmax=False)
            obj_actvs_tuple = get_precomputed_features("obj", args, is_softmax=False)
            attr_actvs_train, attr_actvs_valid, attr_actvs_test = attr_actvs_tuple
            obj_actvs_train, obj_actvs_valid, obj_actvs_test = obj_actvs_tuple
            image_embs_train = np.concatenate([attr_actvs_train, obj_actvs_train], axis=-1)
            image_embs_valid = np.concatenate([attr_actvs_valid, obj_actvs_valid], axis=-1)
            image_embs_test = np.concatenate([attr_actvs_test, obj_actvs_test], axis=-1)
        elif feature == "pair":
            image_embs_tuple = get_precomputed_features("pair", args, is_softmax=False)
            if is_limit:
                image_embs_tuple = tuple(t[:, self.seen_mask] for t in image_embs_tuple)
            #image_embs_train, image_embs_valid, image_embs_test = tuple(normalize(t) for t in image_embs_tuple)
            image_embs_train, image_embs_valid, image_embs_test = image_embs_tuple
        elif feature == "all":
            attr_actvs_tuple = get_precomputed_features("attr", args, is_softmax=False)
            obj_actvs_tuple = get_precomputed_features("obj", args, is_softmax=False)
            pair_actvs_tuple = get_precomputed_features("pair", args, is_softmax=False)
            if is_limit:
                pair_actvs_tuple = tuple(t[:, self.seen_mask] for t in pair_actvs_tuple)
            attr_actvs_train, attr_actvs_valid, attr_actvs_test = attr_actvs_tuple
            obj_actvs_train, obj_actvs_valid, obj_actvs_test = obj_actvs_tuple
            pair_actvs_train, pair_actvs_valid, pair_actvs_test = pair_actvs_tuple
            image_embs_train = np.concatenate([attr_actvs_train, obj_actvs_train, pair_actvs_train], axis=-1)
            image_embs_valid = np.concatenate([attr_actvs_valid, obj_actvs_valid, pair_actvs_valid], axis=-1)
            image_embs_test = np.concatenate([attr_actvs_test, obj_actvs_test, pair_actvs_test], axis=-1)
        elif feature == "gt_primitive":
            image_embs_train = get_gt_primitives("train", data)
            image_embs_valid = get_gt_primitives("valid", data)
            image_embs_test = get_gt_primitives("test", data)
        
        # Prepare labels
        self.limit_pair2idx = self.get_limit_pair2idx(data)
        self.open_world_pair2idx = self.get_open_world_pair2idx(split, data)
        
        labels = None
        if split == "train":
            image_embs = image_embs_train
            if args.is_open_world:
                labels_text = [sample["pair"] for sample in data.valid_data]
                labels = [self.open_world_pair2idx[(attr, obj)] for attr, obj in labels_text]
            else:
                if is_limit:
                    labels_text = [sample["pair"] for sample in data.train_data]
                    labels = [self.limit_pair2idx[(attr, obj)] for attr, obj in labels_text]
                else:
                    labels = [sample["pair_id"] for sample in data.train_data]
        elif split == "valid":
            image_embs = image_embs_valid
            if args.is_open_world:
                labels_text = [sample["pair"] for sample in data.valid_data]
                labels = [self.open_world_pair2idx[(attr, obj)] for attr, obj in labels_text]
            else:
                labels = [sample["pair_id"] for sample in data.valid_data]
        elif split == "test":
            image_embs = image_embs_test
            if args.is_open_world:
                labels_text = [sample["pair"] for sample in data.test_data]
                labels = [self.open_world_pair2idx[(attr, obj)] for attr, obj in labels_text]
            else:
                labels = [sample["pair_id"] for sample in data.test_data]
        
        # Compute seen mask, 1 for training pair, 0 for the others
        if args.is_open_world:
            self.seen_mask = torch.BoolTensor([1 if pair in data.train_pairs else 0 for pair in self.open_world_pair2idx.keys()])
        else:
            self.seen_mask = torch.BoolTensor([1 if pair in data.train_pairs else 0 for pair in data.pairs])
        
        self.image_embs = image_embs
        self.labels = np.array(labels)
        self.image_dim = self.image_embs.shape[-1]
        print(f"image_dim: {self.image_dim:6d}")
    
    def get_limit_pair2idx(self, data):
        labels_text = [sample["pair"] for sample in data.train_data]
        limit_pair2idx = {}
        for label in labels_text:
            if label not in limit_pair2idx:
                limit_pair2idx[label] = len(limit_pair2idx)
        return limit_pair2idx
    
    def get_open_world_pair2idx(self, split, data):
        open_world_pair2idx = {}
        for attr in data.attrs:
            for obj in data.objs:
                if (attr, obj) not in open_world_pair2idx:
                    open_world_pair2idx[(attr, obj)] = len(open_world_pair2idx)
        return open_world_pair2idx
    
    def __len__(self):
        return len(self.image_embs)
    
    def __getitem__(self, index):
        image_embs = self.image_embs[index]
        labels = self.labels[index]
        return image_embs, labels

class RetrievalModel(nn.Module):
    def __init__(self, data, image_dim, limit_pairs, open_world_pairs, args):
        super(RetrievalModel, self).__init__()
        self.image_dim = image_dim
        self.input_dim = args.input_dim
        self.args = args
        
        self.obj2idx, self.attr2idx, self.pair2idx = data.obj2idx, data.attr2idx, data.pair2idx
        self.train_pairs, self.valid_pairs, self.test_pairs = data.train_pairs, data.valid_pairs, data.test_pairs
        self.limit_pairs = limit_pairs
        self.open_world_pairs = open_world_pairs
        
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / args.logit_scale))
        self.attr_encoder = nn.Embedding(len(data.attrs), self.input_dim)
        self.obj_encoder = nn.Embedding(len(data.objs), self.input_dim)
        self.text_projection = nn.Linear(self.input_dim*2, self.image_dim, bias=args.is_bias)
        
        if self.args.is_image_projection:
            self.image_projection = nn.Linear(self.image_dim, self.image_dim, bias=args.is_bias)
    
    def get_limit_pair_inputs(self):
        attr_inputs, obj_inputs = [],[]
        for attr, obj in self.limit_pairs:
            attr_id = self.attr2idx[attr]
            obj_id = self.obj2idx[obj]
            attr_inputs.append(attr_id)
            obj_inputs.append(obj_id)
        attr_inputs = torch.LongTensor(attr_inputs).to(device)
        obj_inputs = torch.LongTensor(obj_inputs).to(device)
        return attr_inputs, obj_inputs

    def get_all_pair_inputs(self):
        attr_inputs, obj_inputs = [],[]
        for attr, obj in self.pair2idx.keys():
            attr_id = self.attr2idx[attr]
            obj_id = self.obj2idx[obj]
            attr_inputs.append(attr_id)
            obj_inputs.append(obj_id)
        attr_inputs = torch.LongTensor(attr_inputs).to(device)
        obj_inputs = torch.LongTensor(obj_inputs).to(device)
        return attr_inputs, obj_inputs
    
    def get_open_world_pair_inputs(self):
        attr_inputs, obj_inputs = [],[]
        for attr, obj in self.open_world_pairs:
            attr_id = self.attr2idx[attr]
            obj_id = self.obj2idx[obj]
            attr_inputs.append(attr_id)
            obj_inputs.append(obj_id)
        attr_inputs = torch.LongTensor(attr_inputs).to(device)
        obj_inputs = torch.LongTensor(obj_inputs).to(device)
        return attr_inputs, obj_inputs
    
    def forward(self, image_embs, pair_labels=None, is_train=True, is_intervene=False, is_open_world=False):
        # image_embs refers to image embeddings or concept representations
        #attrs, objs = self.get_all_pair_inputs()
        if is_open_world:
            # Open world setting, get all the possible pairs
            attrs, objs = self.get_open_world_pair_inputs()
        else:
            # Close world setting, get all the pairs in the data
            attrs, objs = self.get_limit_pair_inputs() if is_train else self.get_all_pair_inputs()
        attr_embs = self.attr_encoder(attrs)
        obj_embs = self.obj_encoder(objs)
        pair_embs = torch.cat([attr_embs, obj_embs], dim=1)
        pair_embs = self.text_projection(pair_embs)
        pair_embs = F.normalize(pair_embs, dim=1)
        
        if self.args.is_image_projection:
            image_embs = self.image_projection(image_embs.float())
        if not is_intervene:
            image_embs = F.normalize(image_embs, dim=1).float()
        image_embs = image_embs.float()
        
        logit_scale = self.logit_scale.exp()
        logit_scale = logit_scale if logit_scale<=100.0 else 100.0
        logits = logit_scale * image_embs @ pair_embs.t()
        
        loss = None
        if is_train:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, pair_labels)
        return logits, loss

## 2.0. GT -> GT

In [159]:
class parseArguments:
    def __init__(self):
        self.split = "valid"
        self.is_open_world = False
        
        self.feature = "gt_primitive"
        self.model = "gt"
        self.train_warmup = 0
        self.is_image_projection = True
        self.is_bias = False
        self.is_limit = True
        
        self.data_root = "../../data/mit_states"
        self.emb_root = "../../data"
        self.precomputed_data_root = f"../../outputs/mit_states/{self.model}_precompute_features"
        
        self.emb_init = True
        self.input_dim = 600
        self.num_epochs = 100
        self.batch_size = 128
        self.learning_rate = 5e-5
        self.weight_decay = 5e-5
        self.logit_scale = 0.07

args = parseArguments()

In [160]:
# Load dataset
data = MITStatesDataset(root=args.data_root, split="train")  # split can be ignored here
seen_ids_valid, unseen_ids_valid = get_seen_unseen_indices("valid", data)
seen_ids_test, unseen_ids_test = get_seen_unseen_indices("test", data)

train_dataset = Precomputed_MITStatesDataset(split="train", feature="gt_primitive", data=data, args=args, is_limit=True)
valid_dataset = Precomputed_MITStatesDataset(split="valid", feature="gt_primitive", data=data, args=args, is_limit=True)
test_dataset = Precomputed_MITStatesDataset(split="test", feature="gt_primitive", data=data, args=args, is_limit=True)
seen_mask = train_dataset.seen_mask


train pairs: 1262 | valid pairs: 600 | test pairs: 800
train images: 30338 | valid images: 10420 | test images: 12995
seen_indices: 1844 | unseen_indices: 8576
seen_indices: 2380 | unseen_indices: 10615
image_dim:    360
image_dim:    360
image_dim:    360


In [161]:
# Load trained model
emb_init = 1 if args.emb_init else 0
imgproj = args.is_image_projection
print("emb_init: ", emb_init)
print("imgproj: ", imgproj)

ckpt_dir = f"../../outputs/mit_states/{args.model}_retrieval_model"
ckpt_name = "{}_retrieval_model_init{}_imgproj{}_bias{}_{}_Lim{}_N{}_TW{}_B{}_LR{}_WD{}_L{}".format(
    args.model,
    1 if args.emb_init else 0,
    1 if args.is_image_projection else 0,
    1 if args.is_bias else 0,
    args.feature,
    1 if args.is_limit else 0, 
    args.num_epochs,
    args.train_warmup,
    args.batch_size,
    args.learning_rate,
    args.weight_decay,
    args.logit_scale,
)
ckpt_path = os.path.join(ckpt_dir, ckpt_name, "retrieval_model.ckpt")

limit_pairs = list(train_dataset.limit_pair2idx.keys())
open_world_pairs = list(valid_dataset.open_world_pair2idx.keys())
model = RetrievalModel(data, train_dataset.image_dim, limit_pairs, open_world_pairs, args)
model.load_state_dict(torch.load(ckpt_path))
model.eval()
model.to(device)

emb_init:  1
imgproj:  True


RetrievalModel(
  (attr_encoder): Embedding(115, 600)
  (obj_encoder): Embedding(245, 600)
  (text_projection): Linear(in_features=1200, out_features=360, bias=False)
  (image_projection): Linear(in_features=360, out_features=360, bias=False)
)

In [162]:
# Prepare groundtruth labels
seen_ids, unseen_ids = get_seen_unseen_indices(args.split, data)

if args.split == "valid":
    labels_attr = [sample["attr_id"] for sample in data.valid_data]
    labels_obj = [sample["obj_id"] for sample in data.valid_data]
    labels = [sample["pair_id"] for sample in data.valid_data]
    gt_features_attr = np.zeros((valid_dataset.image_embs.shape[0], 115))
    gt_features_obj = np.zeros((valid_dataset.image_embs.shape[0], 245))
elif args.split == "test":
    labels_attr = [sample["attr_id"] for sample in data.test_data]
    labels_obj = [sample["obj_id"] for sample in data.test_data]
    labels = [sample["pair_id"] for sample in data.test_data]
    gt_features_attr = np.zeros((test_dataset.image_embs.shape[0], 115))
    gt_features_obj = np.zeros((test_dataset.image_embs.shape[0], 245))
    
gt_features_attr[np.arange(len(labels_attr)), labels_attr] = 1
gt_features_obj[np.arange(len(labels_obj)), labels_obj] = 1
gt_features_concat = np.concatenate([gt_features_attr, gt_features_obj], axis=-1)
gt_features_concat = torch.tensor(gt_features_concat).to(device).double()
print(f"Intervention features: {gt_features_concat.shape}")

seen_indices: 1844 | unseen_indices: 8576
Intervention features: torch.Size([10420, 360])


In [163]:
# Intervene!!!
with torch.no_grad():
    logits, loss = model(gt_features_concat, pair_labels=None, is_train=False, is_intervene=True)
features = logits.softmax(dim=-1).log()
print("features: ", features.shape)

features:  torch.Size([10420, 1962])


In [None]:
overall_metrics = get_overall_metrics(features, labels, seen_ids, unseen_ids, seen_mask, data, topk_list=[1,2,3])

## 2.1. Primitive -> GT

In [302]:
class parseArguments:
    def __init__(self):
        self.split = "valid"
        self.is_open_world = False
        
        self.feature = "primitive"
        self.model = "vilt"
        self.train_warmup = 0
        self.is_image_projection = True
        self.is_bias = False
        self.is_limit = True
        
        self.data_root = "../../data/mit_states"
        self.emb_root = "../../data"
        self.precomputed_data_root = f"../../outputs/mit_states/{self.model}_precompute_features"
        
        self.emb_init = True
        self.input_dim = 600
        self.num_epochs = 400
        self.batch_size = 128
        self.learning_rate = 5e-5
        self.weight_decay = 5e-5
        self.logit_scale = 0.07

args = parseArguments()

In [303]:
# Load dataset
data = MITStatesDataset(root=args.data_root, split="train")  # split can be ignored here
seen_ids_valid, unseen_ids_valid = get_seen_unseen_indices("valid", data)
seen_ids_test, unseen_ids_test = get_seen_unseen_indices("test", data)

train_dataset = Precomputed_MITStatesDataset(split="train", feature="primitive", data=data, args=args, is_limit=True)
valid_dataset = Precomputed_MITStatesDataset(split="valid", feature="primitive", data=data, args=args, is_limit=True)
test_dataset = Precomputed_MITStatesDataset(split="test", feature="primitive", data=data, args=args, is_limit=True)
seen_mask = train_dataset.seen_mask


train pairs: 1262 | valid pairs: 600 | test pairs: 800
train images: 30338 | valid images: 10420 | test images: 12995
seen_indices: 1844 | unseen_indices: 8576
seen_indices: 2380 | unseen_indices: 10615
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360


In [304]:
# Load trained model
emb_init = 1 if args.emb_init else 0
imgproj = args.is_image_projection
print("emb_init: ", emb_init)
print("imgproj: ", imgproj)

ckpt_dir = f"../../outputs/mit_states/{args.model}_retrieval_model"
ckpt_name = "{}_retrieval_model_init{}_imgproj{}_bias{}_{}_Lim{}_N{}_TW{}_B{}_LR{}_WD{}_L{}".format(
    args.model,
    1 if args.emb_init else 0,
    1 if args.is_image_projection else 0,
    1 if args.is_bias else 0,
    args.feature,
    1 if args.is_limit else 0, 
    args.num_epochs,
    args.train_warmup,
    args.batch_size,
    args.learning_rate,
    args.weight_decay,
    args.logit_scale,
)
ckpt_path = os.path.join(ckpt_dir, ckpt_name, "retrieval_model.ckpt")

limit_pairs = list(train_dataset.limit_pair2idx.keys())
open_world_pairs = list(valid_dataset.open_world_pair2idx.keys())
model = RetrievalModel(data, train_dataset.image_dim, limit_pairs, open_world_pairs, args)
model.load_state_dict(torch.load(ckpt_path))
model.eval()
model.to(device)

emb_init:  1
imgproj:  True


RetrievalModel(
  (attr_encoder): Embedding(115, 600)
  (obj_encoder): Embedding(245, 600)
  (text_projection): Linear(in_features=1200, out_features=360, bias=False)
  (image_projection): Linear(in_features=360, out_features=360, bias=False)
)

In [305]:
# Prepare groundtruth labels
seen_ids, unseen_ids = get_seen_unseen_indices(args.split, data)

if args.split == "valid":
    labels_attr = [sample["attr_id"] for sample in data.valid_data]
    labels_obj = [sample["obj_id"] for sample in data.valid_data]
    labels = [sample["pair_id"] for sample in data.valid_data]
    gt_features_attr = np.zeros((valid_dataset.image_embs.shape[0], 115))
    gt_features_obj = np.zeros((valid_dataset.image_embs.shape[0], 245))
elif args.split == "test":
    labels_attr = [sample["attr_id"] for sample in data.test_data]
    labels_obj = [sample["obj_id"] for sample in data.test_data]
    labels = [sample["pair_id"] for sample in data.test_data]
    gt_features_attr = np.zeros((test_dataset.image_embs.shape[0], 115))
    gt_features_obj = np.zeros((test_dataset.image_embs.shape[0], 245))
    
gt_features_attr[np.arange(len(labels_attr)), labels_attr] = 1
gt_features_obj[np.arange(len(labels_obj)), labels_obj] = 1
gt_features_concat = np.concatenate([gt_features_attr, gt_features_obj], axis=-1)
gt_features_concat = torch.tensor(gt_features_concat).to(device).double()
print(f"Intervention features: {gt_features_concat.shape}")

seen_indices: 1844 | unseen_indices: 8576
Intervention features: torch.Size([10420, 360])


In [306]:
# Intervene!!!
with torch.no_grad():
    logits, loss = model(gt_features_concat, pair_labels=None, is_train=False, is_intervene=True)
features = logits.softmax(dim=-1).log()
print("features: ", features.shape)

features:  torch.Size([10420, 1962])


In [307]:
overall_metrics = get_overall_metrics(features, labels, seen_ids, unseen_ids, seen_mask, data, topk_list=[1,2,3])

topk: 1
best_seen_acc: 0.2912
best_unseen_acc: 0.3853
best_harmonic_mean: 0.2283
auc: 0.0856
topk: 2
best_seen_acc: 0.4084
best_unseen_acc: 0.5146
best_harmonic_mean: 0.3298
auc: 0.1711
topk: 3
best_seen_acc: 0.4572
best_unseen_acc: 0.6027
best_harmonic_mean: 0.4060
auc: 0.2363


## 2.2. GT -> Primitive (argmax)

In [260]:
class parseArguments:
    def __init__(self):
        self.split = "valid"
        self.is_open_world = False
        
        self.feature = "gt_primitive"
        self.model = "albef"  # determines where the precomputed activations come from
        self.train_warmup = 0
        self.is_image_projection = True
        self.is_bias = False
        self.is_limit = True
        
        self.data_root = "../../data/mit_states"
        self.emb_root = "../../data"
        self.precomputed_data_root = f"../../outputs/mit_states/{self.model}_precompute_features"
        
        self.emb_init = True
        self.input_dim = 600
        self.num_epochs = 100
        self.batch_size = 128
        self.learning_rate = 5e-5
        self.weight_decay = 5e-5
        self.logit_scale = 0.07

args = parseArguments()

In [261]:
# Load dataset
data = MITStatesDataset(root=args.data_root, split="train")  # split can be ignored here
seen_ids_valid, unseen_ids_valid = get_seen_unseen_indices("valid", data)
seen_ids_test, unseen_ids_test = get_seen_unseen_indices("test", data)

train_dataset = Precomputed_MITStatesDataset(split="train", feature="primitive", data=data, args=args, is_limit=True)
valid_dataset = Precomputed_MITStatesDataset(split="valid", feature="primitive", data=data, args=args, is_limit=True)
test_dataset = Precomputed_MITStatesDataset(split="test", feature="primitive", data=data, args=args, is_limit=True)
seen_mask = train_dataset.seen_mask


train pairs: 1262 | valid pairs: 600 | test pairs: 800
train images: 30338 | valid images: 10420 | test images: 12995
seen_indices: 1844 | unseen_indices: 8576
seen_indices: 2380 | unseen_indices: 10615
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360


In [262]:
# Load trained model
emb_init = 1 if args.emb_init else 0
imgproj = args.is_image_projection
print("emb_init: ", emb_init)
print("imgproj: ", imgproj)

# Fixed to gt_retrieval_model
ckpt_dir = f"../../outputs/mit_states/gt_retrieval_model"
ckpt_name = "gt_retrieval_model_init{}_imgproj{}_bias{}_{}_Lim{}_N{}_TW{}_B{}_LR{}_WD{}_L{}".format(
    1 if args.emb_init else 0,
    1 if args.is_image_projection else 0,
    1 if args.is_bias else 0,
    args.feature,
    1 if args.is_limit else 0, 
    args.num_epochs,
    args.train_warmup,
    args.batch_size,
    args.learning_rate,
    args.weight_decay,
    args.logit_scale,
)
ckpt_path = os.path.join(ckpt_dir, ckpt_name, "retrieval_model.ckpt")

limit_pairs = list(train_dataset.limit_pair2idx.keys())
open_world_pairs = list(valid_dataset.open_world_pair2idx.keys())
model = RetrievalModel(data, train_dataset.image_dim, limit_pairs, open_world_pairs, args)
model.load_state_dict(torch.load(ckpt_path))
model.eval()
model.to(device)

emb_init:  1
imgproj:  True


RetrievalModel(
  (attr_encoder): Embedding(115, 600)
  (obj_encoder): Embedding(245, 600)
  (text_projection): Linear(in_features=1200, out_features=360, bias=False)
  (image_projection): Linear(in_features=360, out_features=360, bias=False)
)

In [263]:
# Prepare primitive concepts
seen_ids, unseen_ids = get_seen_unseen_indices(args.split, data)

if args.split == "valid":
    labels = [sample["pair_id"] for sample in data.valid_data]
    features_interv = valid_dataset.image_embs
elif args.split == "test":
    labels = [sample["pair_id"] for sample in data.test_data]
    features_interv = test_dataset.image_embs

features_interv = torch.tensor(features_interv).to(device).double()
features_interv = F.normalize(features_interv, dim=1).float()

# Compute argmax
pred_attr_ids = features_interv[:,:len(data.attrs)].argmax(dim=1)
pred_obj_ids = features_interv[:,len(data.attrs):].argmax(dim=1) + len(data.attrs)

features_interv = torch.zeros(features_interv.shape).to(device)
features_interv[range(features_interv.shape[0]), pred_attr_ids] = 1
features_interv[range(features_interv.shape[0]), pred_obj_ids] = 1
print(f"Intervention features: {features_interv.shape}")

seen_indices: 1844 | unseen_indices: 8576
Intervention features: torch.Size([10420, 360])


In [264]:
# Intervene!!!
with torch.no_grad():
    logits, loss = model(features_interv, pair_labels=None, is_train=False, is_intervene=True)
features = logits.softmax(dim=-1).log()
print("features: ", features.shape)

features:  torch.Size([10420, 1962])


In [265]:
overall_metrics = get_overall_metrics(features, labels, seen_ids, unseen_ids, seen_mask, data, topk_list=[1,2,3])

topk: 1
best_seen_acc: 0.0461
best_unseen_acc: 0.0394
best_harmonic_mean: 0.0282
auc: 0.0012
topk: 2
best_seen_acc: 0.0895
best_unseen_acc: 0.0849
best_harmonic_mean: 0.0508
auc: 0.0044
topk: 3
best_seen_acc: 0.1171
best_unseen_acc: 0.1193
best_harmonic_mean: 0.0674
auc: 0.0083


## 2.3. Primitive -> ?

In [322]:
class parseArguments:
    def __init__(self):
        self.split = "valid"
        self.is_open_world = False
        
        self.feature = "primitive"
        self.model = "albef"
        self.train_warmup = 0
        self.is_image_projection = True
        self.is_bias = False
        self.is_limit = True
        
        self.data_root = "../../data/mit_states"
        self.emb_root = "../../data"
        self.precomputed_data_root = f"../../outputs/mit_states/{self.model}_precompute_features"
        
        self.emb_init = True
        self.input_dim = 600
        self.num_epochs = 400
        self.batch_size = 128
        self.learning_rate = 5e-5
        self.weight_decay = 5e-5
        self.logit_scale = 0.07

args = parseArguments()

In [323]:
# Load dataset
data = MITStatesDataset(root=args.data_root, split="train")  # split can be ignored here
seen_ids_valid, unseen_ids_valid = get_seen_unseen_indices("valid", data)
seen_ids_test, unseen_ids_test = get_seen_unseen_indices("test", data)

train_dataset = Precomputed_MITStatesDataset(split="train", feature="primitive", data=data, args=args, is_limit=True)
valid_dataset = Precomputed_MITStatesDataset(split="valid", feature="primitive", data=data, args=args, is_limit=True)
test_dataset = Precomputed_MITStatesDataset(split="test", feature="primitive", data=data, args=args, is_limit=True)
seen_mask = train_dataset.seen_mask


train pairs: 1262 | valid pairs: 600 | test pairs: 800
train images: 30338 | valid images: 10420 | test images: 12995
seen_indices: 1844 | unseen_indices: 8576
seen_indices: 2380 | unseen_indices: 10615
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360


In [324]:
# Load trained model
emb_init = 1 if args.emb_init else 0
imgproj = args.is_image_projection
print("emb_init: ", emb_init)
print("imgproj: ", imgproj)

# Open to arbitrary models
ckpt_dir = f"../../outputs/mit_states/{args.model}_retrieval_model"
ckpt_name = "{}_retrieval_model_init{}_imgproj{}_bias{}_{}_Lim{}_N{}_TW{}_B{}_LR{}_WD{}_L{}".format(
    args.model,
    1 if args.emb_init else 0,
    1 if args.is_image_projection else 0,
    1 if args.is_bias else 0,
    args.feature,
    1 if args.is_limit else 0, 
    args.num_epochs,
    args.train_warmup,
    args.batch_size,
    args.learning_rate,
    args.weight_decay,
    args.logit_scale,
)
ckpt_path = os.path.join(ckpt_dir, ckpt_name, "retrieval_model.ckpt")

limit_pairs = list(train_dataset.limit_pair2idx.keys())
open_world_pairs = list(valid_dataset.open_world_pair2idx.keys())
model = RetrievalModel(data, train_dataset.image_dim, limit_pairs, open_world_pairs, args)
model.load_state_dict(torch.load(ckpt_path))
model.eval()
model.to(device)

emb_init:  1
imgproj:  True


RetrievalModel(
  (attr_encoder): Embedding(115, 600)
  (obj_encoder): Embedding(245, 600)
  (text_projection): Linear(in_features=1200, out_features=360, bias=False)
  (image_projection): Linear(in_features=360, out_features=360, bias=False)
)

In [325]:
# Prepare groundtruth labels
seen_ids, unseen_ids = get_seen_unseen_indices(args.split, data)

if args.split == "valid":
    labels_attr = [sample["attr_id"] for sample in data.valid_data]
    labels_obj = [sample["obj_id"] for sample in data.valid_data]
    labels = [sample["pair_id"] for sample in data.valid_data]
    gt_features_attr = np.zeros((valid_dataset.image_embs.shape[0], 115))
    gt_features_obj = np.zeros((valid_dataset.image_embs.shape[0], 245))
elif args.split == "test":
    labels_attr = [sample["attr_id"] for sample in data.test_data]
    labels_obj = [sample["obj_id"] for sample in data.test_data]
    labels = [sample["pair_id"] for sample in data.test_data]
    gt_features_attr = np.zeros((test_dataset.image_embs.shape[0], 115))
    gt_features_obj = np.zeros((test_dataset.image_embs.shape[0], 245))
    
gt_features_attr[np.arange(len(labels_attr)), labels_attr] = 1
gt_features_obj[np.arange(len(labels_obj)), labels_obj] = 1
gt_features_concat = np.concatenate([gt_features_attr, gt_features_obj], axis=-1)
gt_features_concat = torch.tensor(gt_features_concat).to(device).double()
print(f"GT features: {gt_features_concat.shape}")

seen_indices: 1844 | unseen_indices: 8576
GT features: torch.Size([10420, 360])


In [326]:
# Prepare "binary" primitive concept activations
# Set GT concepts to 1
seen_ids, unseen_ids = get_seen_unseen_indices(args.split, data)

if args.split == "valid":
    labels_attr = [sample["attr_id"] for sample in data.valid_data]
    labels_obj = [sample["obj_id"] for sample in data.valid_data]
    labels = [sample["pair_id"] for sample in data.valid_data]
    features_interv = valid_dataset.image_embs
    features_interv = torch.tensor(features_interv).to(device).double()
    features_interv = F.normalize(features_interv, dim=1)
    
elif args.split == "test":
    labels_attr = [sample["attr_id"] for sample in data.test_data]
    labels_obj = [sample["obj_id"] for sample in data.test_data]
    labels = [sample["pair_id"] for sample in data.test_data]
    features_interv = test_dataset.image_embs
    features_interv = torch.tensor(features_interv).to(device).double()
    features_interv = F.normalize(features_interv, dim=1)

features_interv = torch.where(gt_features_concat==1, 1.0, features_interv)
print(f"Intervention features: {features_interv.shape}")

seen_indices: 1844 | unseen_indices: 8576
Intervention features: torch.Size([10420, 360])


In [327]:
# Intervene!!!
with torch.no_grad():
    logits, loss = model(features_interv, pair_labels=None, is_train=False, is_intervene=True)
features = logits.softmax(dim=-1).log()
print("features: ", features.shape)

features:  torch.Size([10420, 1962])


In [328]:
overall_metrics = get_overall_metrics(features, labels, seen_ids, unseen_ids, seen_mask, data, topk_list=[1,2,3])

topk: 1
best_seen_acc: 0.6014
best_unseen_acc: 0.6931
best_harmonic_mean: 0.5389
auc: 0.3867
topk: 2
best_seen_acc: 0.7402
best_unseen_acc: 0.8089
best_harmonic_mean: 0.6768
auc: 0.5650
topk: 3
best_seen_acc: 0.8259
best_unseen_acc: 0.8791
best_harmonic_mean: 0.7447
auc: 0.6831


# 3.  Open World MIT States

In [36]:
def get_gt_primitives(split, data):
    """ Get groundtruth primtiive concepts. """
    data_dict = {
        "train": data.train_data,
        "valid": data.valid_data,
        "test": data.test_data,
    }
    split_data = data_dict[split]
    labels_attr = [sample["attr_id"] for sample in split_data]
    labels_obj = [sample["obj_id"] for sample in split_data]
    gt_features_attr = np.zeros((len(split_data), len(data.attrs)))
    gt_features_obj = np.zeros((len(split_data), len(data.objs)))
    gt_features_attr[np.arange(len(labels_attr)), labels_attr] = 1
    gt_features_obj[np.arange(len(labels_obj)), labels_obj] = 1
    gt_features_concat = np.concatenate([gt_features_attr, gt_features_obj], axis=-1)
    return gt_features_concat

def evaluate(results):
    """ Evaluate predictions and Return metrics. """
    all_preds, seen_preds, unseen_preds = results["all_preds"], results["seen_preds"], results["unseen_preds"]
    all_acc, seen_acc, unseen_acc = np.mean(all_preds), np.mean(seen_preds), np.mean(unseen_preds)    
    return {
        "all_acc": all_acc,
        "seen_acc": seen_acc,
        "unseen_acc": unseen_acc,
        "harmonic_mean": (seen_acc * unseen_acc)**0.5,
        "macro_average_acc": (seen_acc + unseen_acc)*0.5,
    }

def generate_predictions(scores, labels, seen_ids, unseen_ids, seen_mask, data, topk, is_open_world, bias=0.0):
    """ Apply bias and Generate predictions for. """
    def get_predictions(_scores):
        # Get predictions
        _, pair_preds = _scores.topk(topk, dim=1)
        pair_preds = pair_preds[:, :topk].contiguous().view(-1)
        attr_preds = all_pairs[pair_preds][:,0].view(-1, topk)
        obj_preds = all_pairs[pair_preds][:,1].view(-1, topk)
        pair_preds = pair_preds.view(-1, topk)
        return pair_preds, attr_preds, obj_preds
    
    # Get predictions with biases applied
    if is_open_world:
        all_pairs = torch.LongTensor([
            (data.attr2idx[attr], data.obj2idx[obj])
            for attr in data.attrs
            for obj in data.objs
        ])
    else:
        all_pairs = torch.LongTensor([
            (data.attr2idx[attr], data.obj2idx[obj]) 
            for attr, obj in data.pairs
        ])
    scores = scores.clone()
    mask = seen_mask.repeat(scores.shape[0], 1)
    scores[~mask] += bias
    pair_preds, attr_preds, obj_preds = get_predictions(scores)
    
    # Get predictions for seen/unseen pairs
    all_preds = np.array([label in pair_preds[row_id,:topk] for row_id, label in enumerate(labels)])
    seen_preds = all_preds[seen_ids]
    unseen_preds = all_preds[unseen_ids]
    return {
        "pair_preds": pair_preds,
        "attr_preds": attr_preds,
        "obj_preds": obj_preds,
        "all_preds": all_preds,
        "seen_preds": seen_preds,
        "unseen_preds": unseen_preds,
    }

def get_overall_metrics(features, labels, seen_ids, unseen_ids, seen_mask, data, topk_list=[1], is_open_world=False):
    overall_metrics = {}
    for topk in topk_list:
        # Get model"s performance (accuracy) on seen/unseen pairs
        bias = 1e3
        results = generate_predictions(
            features, labels, seen_ids, unseen_ids, seen_mask, data, topk, is_open_world, bias=bias,
        )
        full_unseen_metrics = evaluate(results)
        all_preds, seen_preds, unseen_preds = results["all_preds"], results["seen_preds"], results["unseen_preds"]

        # Get predicted probability distribution of unseen pairs,
        # and the top K scores of seen pairs in the predicted prob. distribution of unseen pairs
        correct_scores = features[np.arange(len(features)), labels][unseen_ids]
        max_seen_scores = features[unseen_ids][:, seen_mask].topk(topk, dim=1)[0][:,topk-1]

        # Compute biases
        unseen_score_diff = max_seen_scores - correct_scores
        correct_unseen_score_diff = unseen_score_diff[unseen_preds] - 1e-4
        correct_unseen_score_diff = torch.sort(correct_unseen_score_diff)[0]
        magic_binsize = 20
        bias_skip = max(len(correct_unseen_score_diff) // magic_binsize, 1)
        bias_list = correct_unseen_score_diff[::bias_skip]

        # Get biased predictions and metrics with different biases
        all_metrics = []
        for bias in bias_list:
            results = generate_predictions(
                features, labels, seen_ids, unseen_ids, seen_mask, data, topk, is_open_world, bias=bias,
            )
            metrics = evaluate(results)
            all_metrics.append(metrics)
        all_metrics.append(full_unseen_metrics)

        # Compute overall metrics
        seen_accs = np.array([metric_dict["seen_acc"] for metric_dict in all_metrics])
        unseen_accs = np.array([metric_dict["unseen_acc"] for metric_dict in all_metrics])
        best_seen_acc = max([metric_dict["seen_acc"] for metric_dict in all_metrics])
        best_unseen_acc = max([metric_dict["unseen_acc"] for metric_dict in all_metrics])
        best_harmonic_mean = max([metric_dict["harmonic_mean"] for metric_dict in all_metrics])
        auc = np.trapz(seen_accs, unseen_accs)
        print(f"topk: {topk}")
        print(f"best_seen_acc: {best_seen_acc:6.4f}")
        print(f"best_unseen_acc: {best_unseen_acc:6.4f}")
        print(f"best_harmonic_mean: {best_harmonic_mean:6.4f}")
        print(f"auc: {auc:6.4f}")

        overall_metrics[topk] = {
            "seen_accs": seen_accs.tolist(),
            "unseen_accs": unseen_accs.tolist(),
            "best_seen_acc": best_seen_acc,
            "best_unseen_acc": best_unseen_acc,
            "best_harmonic_mean": best_harmonic_mean,
            "auc": auc,
        }
    return overall_metrics

In [41]:
class parseArguments:
    def __init__(self):
        self.split = "valid"
        self.is_open_world = True
        
        self.feature = "primitive"
        self.model = "vilt"
        self.train_warmup = 0
        self.is_image_projection = True
        self.is_bias = False
        self.is_limit = True
        
        self.data_root = "../../data/mit_states"
        self.emb_root = "../../data"
        self.precomputed_data_root = f"../../outputs/mit_states/{self.model}_precompute_features"
        
        self.emb_init = True
        self.input_dim = 600
        self.num_epochs = 400
        self.batch_size = 128
        self.learning_rate = 5e-5
        self.weight_decay = 5e-5
        self.logit_scale = 0.07

args = parseArguments()

In [42]:
# Load dataset
data = MITStatesDataset(root=args.data_root, split="train")  # split can be ignored here
seen_ids_valid, unseen_ids_valid = get_seen_unseen_indices("valid", data)
seen_ids_test, unseen_ids_test = get_seen_unseen_indices("test", data)

train_dataset = Precomputed_MITStatesDataset(split="train", feature=args.feature, data=data, args=args, is_limit=True)
valid_dataset = Precomputed_MITStatesDataset(split="valid", feature=args.feature, data=data, args=args, is_limit=True)
test_dataset = Precomputed_MITStatesDataset(split="test", feature=args.feature, data=data, args=args, is_limit=True)
seen_mask = train_dataset.seen_mask


train pairs: 1262 | valid pairs: 600 | test pairs: 800
train images: 30338 | valid images: 10420 | test images: 12995
seen_indices: 1844 | unseen_indices: 8576
seen_indices: 2380 | unseen_indices: 10615
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360


In [27]:
# Load trained model
emb_init = 1 if args.emb_init else 0
imgproj = args.is_image_projection
print("emb_init: ", emb_init)
print("imgproj: ", imgproj)

# Open to arbitrary models
ckpt_dir = f"../../outputs/mit_states/{args.model}_retrieval_model"
ckpt_name = "{}_retrieval_model_init{}_imgproj{}_bias{}_{}_Lim{}_N{}_TW{}_B{}_LR{}_WD{}_L{}".format(
    args.model,
    1 if args.emb_init else 0,
    1 if args.is_image_projection else 0,
    1 if args.is_bias else 0,
    args.feature,
    1 if args.is_limit else 0, 
    args.num_epochs,
    args.train_warmup,
    args.batch_size,
    args.learning_rate,
    args.weight_decay,
    args.logit_scale,
)
ckpt_path = os.path.join(ckpt_dir, ckpt_name, "retrieval_model.ckpt")

limit_pairs = list(train_dataset.limit_pair2idx.keys())
open_world_pairs = list(valid_dataset.open_world_pair2idx.keys())
model = RetrievalModel(data, train_dataset.image_dim, limit_pairs, open_world_pairs, args)
model.load_state_dict(torch.load(ckpt_path))
model.eval()
model.to(device)

emb_init:  1
imgproj:  True


RetrievalModel(
  (attr_encoder): Embedding(115, 600)
  (obj_encoder): Embedding(245, 600)
  (text_projection): Linear(in_features=1200, out_features=360, bias=False)
  (image_projection): Linear(in_features=360, out_features=360, bias=False)
)

In [28]:
# Prepare primitive concept activations
seen_ids, unseen_ids = get_seen_unseen_indices(args.split, data)

if args.split == "valid":
    labels = valid_dataset.labels
    features_interv = valid_dataset.image_embs
    
elif args.split == "test":
    labels = test_dataset.labels
    features_interv = test_dataset.image_embs
    
features_interv = torch.tensor(features_interv).to(device).double()
print(f"Intervention features: {features_interv.shape}")

seen_indices: 1844 | unseen_indices: 8576
Intervention features: torch.Size([10420, 360])


In [29]:
# Intervene!!!
with torch.no_grad():
    logits, loss = model(features_interv, pair_labels=None, is_train=False, is_intervene=False, is_open_world=True)
features = logits.softmax(dim=-1).log()
print("features: ", features.shape)

features:  torch.Size([10420, 28175])


In [38]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [39]:
overall_metrics = get_overall_metrics(
    features, labels, seen_ids, unseen_ids, seen_mask, data, topk_list=[1,2,3], is_open_world=True,
)

RuntimeError: CUDA out of memory. Tried to allocate 4.18 GiB (GPU 0; 10.92 GiB total capacity; 4.68 GiB already allocated; 3.17 GiB free; 6.97 GiB reserved in total by PyTorch)

# 4. Aggregate Results for Multiple Prompts

In [66]:
# Load trained model
def load_model(template_src, template_id):
    ckpt_dir = f"../../outputs/mit_states/clip_prompt_retrieval_model"
    ckpt_name = "clip_prompt_retrieval_model_TP{}_TID{}_open{}_init{}_imgproj{}_bias{}_{}_Lim{}_N{}_TW{}_B{}_LR{}_WD{}_L{}".format(
        template_src,
        template_id,
        1 if args.is_open_world else 0,
        1 if args.emb_init else 0,
        1 if args.is_image_projection else 0,
        1 if args.is_bias else 0,
        args.feature,
        1 if args.is_limit else 0, 
        args.num_epochs,
        args.train_warmup,
        args.batch_size,
        args.learning_rate,
        args.weight_decay,
        args.logit_scale,
    )
    ckpt_path = os.path.join(ckpt_dir, ckpt_name, "retrieval_model.ckpt")

    limit_pairs = list(train_dataset.limit_pair2idx.keys())
    open_world_pairs = list(valid_dataset.open_world_pair2idx.keys())
    model = RetrievalModel(data, train_dataset.image_dim, limit_pairs, open_world_pairs, args)
    model.load_state_dict(torch.load(ckpt_path))
    model.eval()
    model.to(device)
    return model

def get_output_path(template_src, template_id, args):
    ckpt_dir = f"../../outputs/mit_states/clip_prompt_retrieval_model"
    ckpt_name = "clip_prompt_retrieval_model_TP{}_TID{}_open{}_init{}_imgproj{}_bias{}_{}_Lim{}_N{}_TW{}_B{}_LR{}_WD{}_L{}".format(
        template_src,
        template_id,
        1 if args.is_open_world else 0,
        1 if args.emb_init else 0,
        1 if args.is_image_projection else 0,
        1 if args.is_bias else 0,
        args.feature,
        1 if args.is_limit else 0, 
        args.num_epochs,
        args.train_warmup,
        args.batch_size,
        args.learning_rate,
        args.weight_decay,
        args.logit_scale,
    )
    return os.path.join(ckpt_dir, ckpt_name)

In [13]:
!ls ../../outputs/mit_states/clip_prompt_retrieval_model/clip_prompt_retrieval_model_TPclip_TID0_open0_init1_imgproj1_bias0_pair_Lim1_N400_TW0_B128_LR5e-05_WD5e-05_L0.07


retrieval_model.ckpt  test_metrics.json  valid_metrics.json


In [39]:
class parseArguments:
    def __init__(self):
        self.data_root = "../../outputs/mit_states/clip_prompt_retrieval_model"
        self.feature = "pair"
        #self.template_src = "clip"
        #self.template_id = 0

args = parseArguments()

## 4.1. Usefulness

In [40]:
# Usefulness
template_count_dict = {"clip": 7, "compdl": 10}

for template_src, template_count in template_count_dict.items():
    # For each possible template source, intialize a new statistics dict
    all_statistics_dict = {
        "valid": {
            "1": {"auc": []},
            "2": {"auc": []},
            "3": {"auc": []},
        },
        "test": {
            "1": {
                "auc": [],
                "best_seen_acc": [],
                "best_unseen_acc": [],
                "best_harmonic_mean": [],
            },
            "2": {"auc": []},
            "3": {"auc": []},
        },
    }

    # Sum all the required statistics
    for template_id in range(template_count):
        exp_name = "clip_prompt_retrieval_model_TP{}_TID{}_open0_init1_imgproj1_bias0_{}_Lim1_N400_TW0_B128_LR5e-05_WD5e-05_L0.07".format(
            template_src,
            template_id,
            args.feature,
        )
        output_dir = os.path.join(args.data_root, exp_name)
        
        # Valid metrics
        with open(os.path.join(output_dir, "valid_metrics.json"), "rb") as f:
            data = json.load(f)
        all_statistics_dict["valid"]["1"]["auc"].append(data["1"]["auc"])
        all_statistics_dict["valid"]["2"]["auc"].append(data["2"]["auc"])
        all_statistics_dict["valid"]["3"]["auc"].append(data["3"]["auc"])
        
        # Test metrics
        with open(os.path.join(output_dir, "test_metrics.json"), "rb") as f:
            data = json.load(f)
        all_statistics_dict["test"]["1"]["auc"].append(data["1"]["auc"])
        all_statistics_dict["test"]["2"]["auc"].append(data["2"]["auc"])
        all_statistics_dict["test"]["3"]["auc"].append(data["3"]["auc"])
        all_statistics_dict["test"]["1"]["best_seen_acc"].append(data["1"]["best_seen_acc"])
        all_statistics_dict["test"]["1"]["best_unseen_acc"].append(data["1"]["best_unseen_acc"])
        all_statistics_dict["test"]["1"]["best_harmonic_mean"].append(data["1"]["best_harmonic_mean"])
        
    # Take the mean of the required statistics and output
    print(f"template_src: {template_src}")
    print("split: valid")
    for topk in all_statistics_dict["valid"].keys():
        for k, v in all_statistics_dict["valid"][topk].items():
            mean = 100 * np.mean(all_statistics_dict["valid"][topk][k])
            std = 100 * np.std(all_statistics_dict["valid"][topk][k])
            print(f"| topk {topk} | {k} | mean: {mean:8.6f} | std: {std:8.6f}")
    print()
    
    print("split: test")
    for topk in all_statistics_dict["test"].keys():
        for k, v in all_statistics_dict["test"][topk].items():
            mean = 100 * np.mean(all_statistics_dict["test"][topk][k])
            std = 100 * np.std(all_statistics_dict["test"][topk][k])
            print(f"| topk {topk} | {k} | mean: {mean:8.6f} | std: {std:8.6f}")
    print()

template_src: clip
split: valid
| topk 1 | auc | mean: 8.340957 | std: 0.342888
| topk 2 | auc | mean: 17.300506 | std: 0.453321
| topk 3 | auc | mean: 24.907180 | std: 0.547247

split: test
| topk 1 | auc | mean: 6.974296 | std: 0.238532
| topk 1 | best_seen_acc | mean: 34.255702 | std: 0.625115
| topk 1 | best_unseen_acc | mean: 27.981966 | std: 0.822270
| topk 1 | best_harmonic_mean | mean: 20.556065 | std: 0.439918
| topk 2 | auc | mean: 15.927997 | std: 0.390615
| topk 3 | auc | mean: 23.217200 | std: 0.519239

template_src: compdl
split: valid
| topk 1 | auc | mean: 7.913990 | std: 0.143389
| topk 2 | auc | mean: 16.757348 | std: 0.466115
| topk 3 | auc | mean: 24.071215 | std: 0.549441

split: test
| topk 1 | auc | mean: 6.633088 | std: 0.353728
| topk 1 | best_seen_acc | mean: 33.369748 | std: 1.019165
| topk 1 | best_unseen_acc | mean: 27.368818 | std: 0.541875
| topk 1 | best_harmonic_mean | mean: 20.056457 | std: 0.591366
| topk 2 | auc | mean: 15.167544 | std: 0.483960
| to

## 4.2. Interpretability for Interv (GT)

In [68]:
# Interpretability for Interv(GT)

class parseArguments:
    def __init__(self):
        self.split = "valid"
        self.is_open_world = False
        
        self.feature = "primitive"
        self.model = "clip"
        self.train_warmup = 0
        self.is_image_projection = True
        self.is_bias = False
        self.is_limit = True
        
        self.data_root = "../../data/mit_states"
        self.emb_root = "../../data"
        self.precomputed_data_dir = f"../../outputs/mit_states/clip_prompt_precompute_features"
        
        self.emb_init = True
        self.input_dim = 600
        self.num_epochs = 400
        self.batch_size = 128
        self.learning_rate = 5e-5
        self.weight_decay = 5e-5
        self.logit_scale = 0.07

args = parseArguments()

# Start!!
template_count_dict = {"clip": 7, "compdl": 10}

for split in ["valid", "test"]:
    args.split = split
    for template_src, template_count in template_count_dict.items():
        for template_id in range(template_count):
            print("=" * 70)
            print(f"| {args.split} | template {template_src} | template_id {template_id} |")
            args.precomputed_data_root = os.path.join(args.precomputed_data_dir, template_src, f"template_{template_id}")

            # Load dataset
            data = MITStatesDataset(root=args.data_root, split="train")  # split can be ignored here
            seen_ids_valid, unseen_ids_valid = get_seen_unseen_indices("valid", data)
            seen_ids_test, unseen_ids_test = get_seen_unseen_indices("test", data)

            train_dataset = Precomputed_MITStatesDataset(split="train", feature="primitive", data=data, args=args, is_limit=True)
            valid_dataset = Precomputed_MITStatesDataset(split="valid", feature="primitive", data=data, args=args, is_limit=True)
            test_dataset = Precomputed_MITStatesDataset(split="test", feature="primitive", data=data, args=args, is_limit=True)
            seen_mask = train_dataset.seen_mask

            # Load model
            model = load_model(template_src, template_id)

            # Prepare groundtruth labels
            seen_ids, unseen_ids = get_seen_unseen_indices(args.split, data)

            if args.split == "valid":
                labels_attr = [sample["attr_id"] for sample in data.valid_data]
                labels_obj = [sample["obj_id"] for sample in data.valid_data]
                labels = [sample["pair_id"] for sample in data.valid_data]
                gt_features_attr = np.zeros((valid_dataset.image_embs.shape[0], 115))
                gt_features_obj = np.zeros((valid_dataset.image_embs.shape[0], 245))
            elif args.split == "test":
                labels_attr = [sample["attr_id"] for sample in data.test_data]
                labels_obj = [sample["obj_id"] for sample in data.test_data]
                labels = [sample["pair_id"] for sample in data.test_data]
                gt_features_attr = np.zeros((test_dataset.image_embs.shape[0], 115))
                gt_features_obj = np.zeros((test_dataset.image_embs.shape[0], 245))

            gt_features_attr[np.arange(len(labels_attr)), labels_attr] = 1
            gt_features_obj[np.arange(len(labels_obj)), labels_obj] = 1
            gt_features_concat = np.concatenate([gt_features_attr, gt_features_obj], axis=-1)
            gt_features_concat = torch.tensor(gt_features_concat).to(device).double()
            print(f"Intervention features: {gt_features_concat.shape}")

            # Intervene!!!
            with torch.no_grad():
                logits, loss = model(gt_features_concat, pair_labels=None, is_train=False, is_intervene=True)
            features = logits.softmax(dim=-1).log()
            print("features: ", features.shape)

            # Compute metrics
            overall_metrics = get_overall_metrics(features, labels, seen_ids, unseen_ids, seen_mask, data, topk_list=[1,2,3])
            output_path = get_output_path(template_src, template_id, args)
            with open(os.path.join(output_path, f"{args.split}_metrics_interv_gt.json"), "w") as f:
                json.dump(overall_metrics, f)
            print("\n\n")



| valid | template clip | template_id 0 |
train pairs: 1262 | valid pairs: 600 | test pairs: 800
train images: 30338 | valid images: 10420 | test images: 12995
seen_indices: 1844 | unseen_indices: 8576
seen_indices: 2380 | unseen_indices: 10615
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
seen_indices: 1844 | unseen_indices: 8576
Intervention features: torch.Size([10420, 360])
features:  torch.Size([10420, 1962])
topk: 1
best_seen_acc: 0.6388
best_unseen_acc: 0.6941
best_harmonic_mean: 0.5666
auc: 0.4128
topk: 2
best_seen_acc: 0.7961
best_unseen_acc: 0.8328
best_harmo

topk: 3
best_seen_acc: 0.8297
best_unseen_acc: 0.8720
best_harmonic_mean: 0.7561
auc: 0.6913



| valid | template compdl | template_id 0 |
train pairs: 1262 | valid pairs: 600 | test pairs: 800
train images: 30338 | valid images: 10420 | test images: 12995
seen_indices: 1844 | unseen_indices: 8576
seen_indices: 2380 | unseen_indices: 10615
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
seen_indices: 1844 | unseen_indices: 8576
Intervention features: torch.Size([10420, 360])
features:  torch.Size([10420, 1962])
topk: 1
best_seen_acc: 0.5640
best_unseen_acc: 0.6247
best_

topk: 1
best_seen_acc: 0.6198
best_unseen_acc: 0.6974
best_harmonic_mean: 0.5528
auc: 0.3940
topk: 2
best_seen_acc: 0.7587
best_unseen_acc: 0.8336
best_harmonic_mean: 0.6924
auc: 0.5938
topk: 3
best_seen_acc: 0.8270
best_unseen_acc: 0.9144
best_harmonic_mean: 0.7745
auc: 0.7161



| valid | template compdl | template_id 7 |
train pairs: 1262 | valid pairs: 600 | test pairs: 800
train images: 30338 | valid images: 10420 | test images: 12995
seen_indices: 1844 | unseen_indices: 8576
seen_indices: 2380 | unseen_indices: 10615
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360


image_dim:    360
seen_indices: 2380 | unseen_indices: 10615
Intervention features: torch.Size([12995, 360])
features:  torch.Size([12995, 1962])
topk: 1
best_seen_acc: 0.5899
best_unseen_acc: 0.7511
best_harmonic_mean: 0.5631
auc: 0.4087
topk: 2
best_seen_acc: 0.7445
best_unseen_acc: 0.8539
best_harmonic_mean: 0.7120
auc: 0.5979
topk: 3
best_seen_acc: 0.8218
best_unseen_acc: 0.9048
best_harmonic_mean: 0.7758
auc: 0.7102



| test | template clip | template_id 4 |
train pairs: 1262 | valid pairs: 600 | test pairs: 800
train images: 30338 | valid images: 10420 | test images: 12995
seen_indices: 1844 | unseen_indices: 8576
seen_indices: 2380 | unseen_indices: 10615
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train 

image_dim:    360
seen_indices: 2380 | unseen_indices: 10615
Intervention features: torch.Size([12995, 360])
features:  torch.Size([12995, 1962])
topk: 1
best_seen_acc: 0.5332
best_unseen_acc: 0.6640
best_harmonic_mean: 0.4835
auc: 0.3073
topk: 2
best_seen_acc: 0.6777
best_unseen_acc: 0.8039
best_harmonic_mean: 0.6296
auc: 0.4983
topk: 3
best_seen_acc: 0.7298
best_unseen_acc: 0.8500
best_harmonic_mean: 0.7002
auc: 0.5911



| test | template compdl | template_id 4 |
train pairs: 1262 | valid pairs: 600 | test pairs: 800
train images: 30338 | valid images: 10420 | test images: 12995
seen_indices: 1844 | unseen_indices: 8576
seen_indices: 2380 | unseen_indices: 10615
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| trai

In [72]:
# Aggregate interpretability results for Interv(GT)
# Usefulness
template_count_dict = {"clip": 7, "compdl": 10}

for template_src, template_count in template_count_dict.items():
    # For each possible template source, intialize a new statistics dict
    all_statistics_dict = {
        "valid": {
            "1": {"auc": []},
            "2": {"auc": []},
            "3": {"auc": []},
        },
        "test": {
            "1": {
                "auc": [],
                "best_seen_acc": [],
                "best_unseen_acc": [],
                "best_harmonic_mean": [],
            },
            "2": {"auc": []},
            "3": {"auc": []},
        },
    }

    # Sum all the required statistics
    for template_id in range(template_count):
        exp_name = "clip_prompt_retrieval_model_TP{}_TID{}_open0_init1_imgproj1_bias0_{}_Lim1_N400_TW0_B128_LR5e-05_WD5e-05_L0.07".format(
            template_src,
            template_id,
            args.feature,
        )
        output_dir = get_output_path(template_src, template_id, args)
        
        # Valid metrics
        with open(os.path.join(output_dir, "valid_metrics_interv_gt.json"), "rb") as f:
            data = json.load(f)
        all_statistics_dict["valid"]["1"]["auc"].append(data["1"]["auc"])
        all_statistics_dict["valid"]["2"]["auc"].append(data["2"]["auc"])
        all_statistics_dict["valid"]["3"]["auc"].append(data["3"]["auc"])
        
        # Test metrics
        with open(os.path.join(output_dir, "test_metrics_interv_gt.json"), "rb") as f:
            data = json.load(f)
        all_statistics_dict["test"]["1"]["auc"].append(data["1"]["auc"])
        all_statistics_dict["test"]["2"]["auc"].append(data["2"]["auc"])
        all_statistics_dict["test"]["3"]["auc"].append(data["3"]["auc"])
        all_statistics_dict["test"]["1"]["best_seen_acc"].append(data["1"]["best_seen_acc"])
        all_statistics_dict["test"]["1"]["best_unseen_acc"].append(data["1"]["best_unseen_acc"])
        all_statistics_dict["test"]["1"]["best_harmonic_mean"].append(data["1"]["best_harmonic_mean"])
        
    # Take the mean of the required statistics and output
    print(f"template_src: {template_src}")
    print("split: valid")
    for topk in all_statistics_dict["valid"].keys():
        for k, v in all_statistics_dict["valid"][topk].items():
            mean = 100 * np.mean(all_statistics_dict["valid"][topk][k])
            std = 100 * np.std(all_statistics_dict["valid"][topk][k])
            print(f"| topk {topk} | {k} | mean: {mean:8.6f} | std: {std:8.6f}")
    print()
    
    print("split: test")
    for topk in all_statistics_dict["test"].keys():
        for k, v in all_statistics_dict["test"][topk].items():
            mean = 100 * np.mean(all_statistics_dict["test"][topk][k])
            std = 100 * np.std(all_statistics_dict["test"][topk][k])
            print(f"| topk {topk} | {k} | mean: {mean:8.6f} | std: {std:8.6f}")
    print()

template_src: clip
split: valid
| topk 1 | auc | mean: 39.267366 | std: 3.926191
| topk 2 | auc | mean: 58.646386 | std: 3.537431
| topk 3 | auc | mean: 68.550824 | std: 3.394782

split: test
| topk 1 | auc | mean: 39.861835 | std: 3.008578
| topk 1 | best_seen_acc | mean: 59.237695 | std: 2.428830
| topk 1 | best_unseen_acc | mean: 72.666712 | std: 2.387179
| topk 1 | best_harmonic_mean | mean: 55.703362 | std: 2.639456
| topk 2 | auc | mean: 58.146802 | std: 3.129845
| topk 3 | auc | mean: 68.448382 | std: 2.582705

template_src: compdl
split: valid
| topk 1 | auc | mean: 33.711541 | std: 3.976065
| topk 2 | auc | mean: 52.148164 | std: 4.991058
| topk 3 | auc | mean: 64.094234 | std: 5.566963

split: test
| topk 1 | auc | mean: 32.318739 | std: 3.805140
| topk 1 | best_seen_acc | mean: 54.138655 | std: 3.506340
| topk 1 | best_unseen_acc | mean: 67.075836 | std: 2.708888
| topk 1 | best_harmonic_mean | mean: 49.055195 | std: 2.941528
| topk 2 | auc | mean: 51.082362 | std: 3.604010


## 4.3. Interpretability for Interv(GTX)

In [75]:
# Interpretability for Interv(GTX)

class parseArguments:
    def __init__(self):
        self.split = "valid"
        self.is_open_world = False
        
        self.feature = "primitive"
        self.model = "clip"
        self.train_warmup = 0
        self.is_image_projection = True
        self.is_bias = False
        self.is_limit = True
        
        self.data_root = "../../data/mit_states"
        self.emb_root = "../../data"
        self.precomputed_data_dir = f"../../outputs/mit_states/clip_prompt_precompute_features"
        
        self.emb_init = True
        self.input_dim = 600
        self.num_epochs = 400
        self.batch_size = 128
        self.learning_rate = 5e-5
        self.weight_decay = 5e-5
        self.logit_scale = 0.07

args = parseArguments()

# Start!!
template_count_dict = {"clip": 7, "compdl": 10}

for split in ["valid", "test"]:
    args.split = split
    for template_src, template_count in template_count_dict.items():
        for template_id in range(template_count):
            print("=" * 70)
            print(f"| {args.split} | template {template_src} | template_id {template_id} |")
            args.precomputed_data_root = os.path.join(args.precomputed_data_dir, template_src, f"template_{template_id}")

            # Load dataset
            data = MITStatesDataset(root=args.data_root, split="train")  # split can be ignored here
            seen_ids_valid, unseen_ids_valid = get_seen_unseen_indices("valid", data)
            seen_ids_test, unseen_ids_test = get_seen_unseen_indices("test", data)

            train_dataset = Precomputed_MITStatesDataset(split="train", feature="primitive", data=data, args=args, is_limit=True)
            valid_dataset = Precomputed_MITStatesDataset(split="valid", feature="primitive", data=data, args=args, is_limit=True)
            test_dataset = Precomputed_MITStatesDataset(split="test", feature="primitive", data=data, args=args, is_limit=True)
            seen_mask = train_dataset.seen_mask

            # Load model
            model = load_model(template_src, template_id)

            # Prepare groundtruth labels
            seen_ids, unseen_ids = get_seen_unseen_indices(args.split, data)

            if args.split == "valid":
                labels_attr = [sample["attr_id"] for sample in data.valid_data]
                labels_obj = [sample["obj_id"] for sample in data.valid_data]
                labels = [sample["pair_id"] for sample in data.valid_data]
                gt_features_attr = np.zeros((valid_dataset.image_embs.shape[0], 115))
                gt_features_obj = np.zeros((valid_dataset.image_embs.shape[0], 245))
            elif args.split == "test":
                labels_attr = [sample["attr_id"] for sample in data.test_data]
                labels_obj = [sample["obj_id"] for sample in data.test_data]
                labels = [sample["pair_id"] for sample in data.test_data]
                gt_features_attr = np.zeros((test_dataset.image_embs.shape[0], 115))
                gt_features_obj = np.zeros((test_dataset.image_embs.shape[0], 245))

            gt_features_attr[np.arange(len(labels_attr)), labels_attr] = 1
            gt_features_obj[np.arange(len(labels_obj)), labels_obj] = 1
            gt_features_concat = np.concatenate([gt_features_attr, gt_features_obj], axis=-1)
            gt_features_concat = torch.tensor(gt_features_concat).to(device).double()
            print(f"Intervention features: {gt_features_concat.shape}")

            # Prepare "binary" primitive concept activations
            # Set GT concepts to 1
            seen_ids, unseen_ids = get_seen_unseen_indices(args.split, data)

            if args.split == "valid":
                labels_attr = [sample["attr_id"] for sample in data.valid_data]
                labels_obj = [sample["obj_id"] for sample in data.valid_data]
                labels = [sample["pair_id"] for sample in data.valid_data]
                features_interv = valid_dataset.image_embs
                features_interv = torch.tensor(features_interv).to(device).double()
                features_interv = F.normalize(features_interv, dim=1)

            elif args.split == "test":
                labels_attr = [sample["attr_id"] for sample in data.test_data]
                labels_obj = [sample["obj_id"] for sample in data.test_data]
                labels = [sample["pair_id"] for sample in data.test_data]
                features_interv = test_dataset.image_embs
                features_interv = torch.tensor(features_interv).to(device).double()
                features_interv = F.normalize(features_interv, dim=1)

            features_interv = torch.where(gt_features_concat==1, 1.0, features_interv)
            print(f"Intervention features: {features_interv.shape}")

            # Intervene!!!
            with torch.no_grad():
                logits, loss = model(features_interv, pair_labels=None, is_train=False, is_intervene=True)
            features = logits.softmax(dim=-1).log()
            print("features: ", features.shape)

            # Compute metrics
            overall_metrics = get_overall_metrics(features, labels, seen_ids, unseen_ids, seen_mask, data, topk_list=[1,2,3])
            output_path = get_output_path(template_src, template_id, args)
            with open(os.path.join(output_path, f"{args.split}_metrics_interv_gtx.json"), "w") as f:
                json.dump(overall_metrics, f)
            print("\n\n")



| valid | template clip | template_id 0 |
train pairs: 1262 | valid pairs: 600 | test pairs: 800
train images: 30338 | valid images: 10420 | test images: 12995
seen_indices: 1844 | unseen_indices: 8576
seen_indices: 2380 | unseen_indices: 10615
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
seen_indices: 1844 | unseen_indices: 8576
Intervention features: torch.Size([10420, 360])
seen_indices: 1844 | unseen_indices: 8576
Intervention features: torch.Size([10420, 360])
features:  torch.Size([10420, 1962])
topk: 1
best_seen_acc: 0.7240
best_unseen_acc: 0.7240
best_harmonic

image_dim:    360
seen_indices: 1844 | unseen_indices: 8576
Intervention features: torch.Size([10420, 360])
seen_indices: 1844 | unseen_indices: 8576
Intervention features: torch.Size([10420, 360])
features:  torch.Size([10420, 1962])
topk: 1
best_seen_acc: 0.6535
best_unseen_acc: 0.7062
best_harmonic_mean: 0.5891
auc: 0.4288
topk: 2
best_seen_acc: 0.8113
best_unseen_acc: 0.8455
best_harmonic_mean: 0.7176
auc: 0.6464
topk: 3
best_seen_acc: 0.8633
best_unseen_acc: 0.8904
best_harmonic_mean: 0.7954
auc: 0.7418



| valid | template compdl | template_id 0 |
train pairs: 1262 | valid pairs: 600 | test pairs: 800
train images: 30338 | valid images: 10420 | test images: 12995
seen_indices: 1844 | unseen_indices: 8576
seen_indices: 2380 | unseen_indices: 10615
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| tr

train pairs: 1262 | valid pairs: 600 | test pairs: 800
train images: 30338 | valid images: 10420 | test images: 12995
seen_indices: 1844 | unseen_indices: 8576
seen_indices: 2380 | unseen_indices: 10615
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
seen_indices: 1844 | unseen_indices: 8576
Intervention features: torch.Size([10420, 360])
seen_indices: 1844 | unseen_indices: 8576
Intervention features: torch.Size([10420, 360])
features:  torch.Size([10420, 1962])
topk: 1
best_seen_acc: 0.6367
best_unseen_acc: 0.7134
best_harmonic_mean: 0.5683
auc: 0.4174
topk: 2
best_see

image_dim:    360
seen_indices: 2380 | unseen_indices: 10615
Intervention features: torch.Size([12995, 360])
seen_indices: 2380 | unseen_indices: 10615
Intervention features: torch.Size([12995, 360])
features:  torch.Size([12995, 1962])
topk: 1
best_seen_acc: 0.5739
best_unseen_acc: 0.7502
best_harmonic_mean: 0.5640
auc: 0.4016
topk: 2
best_seen_acc: 0.7336
best_unseen_acc: 0.8643
best_harmonic_mean: 0.6959
auc: 0.5979
topk: 3
best_seen_acc: 0.8172
best_unseen_acc: 0.9068
best_harmonic_mean: 0.7753
auc: 0.7121



| test | template clip | template_id 3 |
train pairs: 1262 | valid pairs: 600 | test pairs: 800
train images: 30338 | valid images: 10420 | test images: 12995
seen_indices: 1844 | unseen_indices: 8576
seen_indices: 2380 | unseen_indices: 10615
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| tra

train pairs: 1262 | valid pairs: 600 | test pairs: 800
train images: 30338 | valid images: 10420 | test images: 12995
seen_indices: 1844 | unseen_indices: 8576
seen_indices: 2380 | unseen_indices: 10615
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
seen_indices: 2380 | unseen_indices: 10615
Intervention features: torch.Size([12995, 360])
seen_indices: 2380 | unseen_indices: 10615
Intervention features: torch.Size([12995, 360])
features:  torch.Size([12995, 1962])
topk: 1
best_seen_acc: 0.5597
best_unseen_acc: 0.6904
best_harmonic_mean: 0.5162
auc: 0.3499
topk: 2
best_s

image_dim:    360
seen_indices: 2380 | unseen_indices: 10615
Intervention features: torch.Size([12995, 360])
seen_indices: 2380 | unseen_indices: 10615
Intervention features: torch.Size([12995, 360])
features:  torch.Size([12995, 1962])
topk: 1
best_seen_acc: 0.4908
best_unseen_acc: 0.6229
best_harmonic_mean: 0.4404
auc: 0.2644
topk: 2
best_seen_acc: 0.6836
best_unseen_acc: 0.8100
best_harmonic_mean: 0.6036
auc: 0.4883
topk: 3
best_seen_acc: 0.7563
best_unseen_acc: 0.8642
best_harmonic_mean: 0.6938
auc: 0.6081



| test | template compdl | template_id 9 |
train pairs: 1262 | valid pairs: 600 | test pairs: 800
train images: 30338 | valid images: 10420 | test images: 12995
seen_indices: 1844 | unseen_indices: 8576
seen_indices: 2380 | unseen_indices: 10615
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| t

In [76]:
# Aggregate interpretability results for Interv(GT)
# 
template_count_dict = {"clip": 7, "compdl": 10}

for template_src, template_count in template_count_dict.items():
    # For each possible template source, intialize a new statistics dict
    all_statistics_dict = {
        "valid": {
            "1": {"auc": []},
            "2": {"auc": []},
            "3": {"auc": []},
        },
        "test": {
            "1": {
                "auc": [],
                "best_seen_acc": [],
                "best_unseen_acc": [],
                "best_harmonic_mean": [],
            },
            "2": {"auc": []},
            "3": {"auc": []},
        },
    }

    # Sum all the required statistics
    for template_id in range(template_count):
        exp_name = "clip_prompt_retrieval_model_TP{}_TID{}_open0_init1_imgproj1_bias0_{}_Lim1_N400_TW0_B128_LR5e-05_WD5e-05_L0.07".format(
            template_src,
            template_id,
            args.feature,
        )
        output_dir = get_output_path(template_src, template_id, args)
        
        # Valid metrics
        with open(os.path.join(output_dir, "valid_metrics_interv_gtx.json"), "rb") as f:
            data = json.load(f)
        all_statistics_dict["valid"]["1"]["auc"].append(data["1"]["auc"])
        all_statistics_dict["valid"]["2"]["auc"].append(data["2"]["auc"])
        all_statistics_dict["valid"]["3"]["auc"].append(data["3"]["auc"])
        
        # Test metrics
        with open(os.path.join(output_dir, "test_metrics_interv_gtx.json"), "rb") as f:
            data = json.load(f)
        all_statistics_dict["test"]["1"]["auc"].append(data["1"]["auc"])
        all_statistics_dict["test"]["2"]["auc"].append(data["2"]["auc"])
        all_statistics_dict["test"]["3"]["auc"].append(data["3"]["auc"])
        all_statistics_dict["test"]["1"]["best_seen_acc"].append(data["1"]["best_seen_acc"])
        all_statistics_dict["test"]["1"]["best_unseen_acc"].append(data["1"]["best_unseen_acc"])
        all_statistics_dict["test"]["1"]["best_harmonic_mean"].append(data["1"]["best_harmonic_mean"])
        
    # Take the mean of the required statistics and output
    print(f"template_src: {template_src}")
    print("split: valid")
    for topk in all_statistics_dict["valid"].keys():
        for k, v in all_statistics_dict["valid"][topk].items():
            mean = 100 * np.mean(all_statistics_dict["valid"][topk][k])
            std = 100 * np.std(all_statistics_dict["valid"][topk][k])
            print(f"| topk {topk} | {k} | mean: {mean:8.6f} | std: {std:8.6f}")
    print()
    
    print("split: test")
    for topk in all_statistics_dict["test"].keys():
        for k, v in all_statistics_dict["test"][topk].items():
            mean = 100 * np.mean(all_statistics_dict["test"][topk][k])
            std = 100 * np.std(all_statistics_dict["test"][topk][k])
            print(f"| topk {topk} | {k} | mean: {mean:8.6f} | std: {std:8.6f}")
    print()

template_src: clip
split: valid
| topk 1 | auc | mean: 43.616760 | std: 4.273825
| topk 2 | auc | mean: 63.981882 | std: 4.311276
| topk 3 | auc | mean: 73.299077 | std: 4.150668

split: test
| topk 1 | auc | mean: 44.037432 | std: 3.430984
| topk 1 | best_seen_acc | mean: 62.803121 | std: 3.177578
| topk 1 | best_unseen_acc | mean: 75.168562 | std: 2.123174
| topk 1 | best_harmonic_mean | mean: 59.028622 | std: 2.875708
| topk 2 | auc | mean: 63.670394 | std: 3.489931
| topk 3 | auc | mean: 73.324113 | std: 3.069886

template_src: compdl
split: valid
| topk 1 | auc | mean: 36.527693 | std: 4.332185
| topk 2 | auc | mean: 56.577485 | std: 5.268111
| topk 3 | auc | mean: 68.128082 | std: 5.697588

split: test
| topk 1 | auc | mean: 34.980297 | std: 4.061202
| topk 1 | best_seen_acc | mean: 56.605042 | std: 3.547861
| topk 1 | best_unseen_acc | mean: 68.974093 | std: 2.961350
| topk 1 | best_harmonic_mean | mean: 50.995273 | std: 3.191393
| topk 2 | auc | mean: 55.112662 | std: 3.670802


## 4.4. Usefulness - Open World

In [89]:
# Aggregate open-world usefulness
template_count_dict = {"clip": 7, "compdl": 10}
args.feature = "pair"

for template_src, template_count in template_count_dict.items():
    # For each possible template source, intialize a new statistics dict
    all_statistics_dict = {
        "valid": {
            "1": {"auc": []},
            "2": {"auc": []},
            "3": {"auc": []},
        },
        "test": {
            "1": {
                "auc": [],
                "best_seen_acc": [],
                "best_unseen_acc": [],
                "best_harmonic_mean": [],
            },
            "2": {"auc": []},
            "3": {"auc": []},
        },
    }

    # Sum all the required statistics
    for template_id in range(template_count):
        exp_name = "clip_prompt_retrieval_model_TP{}_TID{}_open1_init1_imgproj1_bias0_{}_Lim1_N400_TW0_B128_LR5e-05_WD5e-05_L0.07".format(
            template_src,
            template_id,
            args.feature,
        )
        output_dir = os.path.join("../../outputs/mit_states/clip_prompt_open_world_eval", exp_name)
        
        # Valid metrics
        with open(os.path.join(output_dir, "valid_metrics.json"), "rb") as f:
            data = json.load(f)
        all_statistics_dict["valid"]["1"]["auc"].append(data["1"]["auc"])
        all_statistics_dict["valid"]["2"]["auc"].append(data["2"]["auc"])
        all_statistics_dict["valid"]["3"]["auc"].append(data["3"]["auc"])
        
        # Test metrics
        with open(os.path.join(output_dir, "test_metrics.json"), "rb") as f:
            data = json.load(f)
        all_statistics_dict["test"]["1"]["auc"].append(data["1"]["auc"])
        all_statistics_dict["test"]["2"]["auc"].append(data["2"]["auc"])
        all_statistics_dict["test"]["3"]["auc"].append(data["3"]["auc"])
        all_statistics_dict["test"]["1"]["best_seen_acc"].append(data["1"]["best_seen_acc"])
        all_statistics_dict["test"]["1"]["best_unseen_acc"].append(data["1"]["best_unseen_acc"])
        all_statistics_dict["test"]["1"]["best_harmonic_mean"].append(data["1"]["best_harmonic_mean"])
        
    # Take the mean of the required statistics and output
    print(f"template_src: {template_src}")
    print("split: valid")
    for topk in all_statistics_dict["valid"].keys():
        for k, v in all_statistics_dict["valid"][topk].items():
            mean = 100 * np.mean(all_statistics_dict["valid"][topk][k])
            std = 100 * np.std(all_statistics_dict["valid"][topk][k])
            print(f"| topk {topk} | {k} | mean: {mean:8.6f} | std: {std:8.6f}")
    print()
    
    print("split: test")
    for topk in all_statistics_dict["test"].keys():
        for k, v in all_statistics_dict["test"][topk].items():
            mean = 100 * np.mean(all_statistics_dict["test"][topk][k])
            std = 100 * np.std(all_statistics_dict["test"][topk][k])
            print(f"| topk {topk} | {k} | mean: {mean:8.6f} | std: {std:8.6f}")
    print()

template_src: clip
split: valid
| topk 1 | auc | mean: 2.630426 | std: 0.240937
| topk 2 | auc | mean: 6.052778 | std: 0.412947
| topk 3 | auc | mean: 9.103339 | std: 0.452032

split: test
| topk 1 | auc | mean: 2.153398 | std: 0.120531
| topk 1 | best_seen_acc | mean: 32.803121 | std: 0.906104
| topk 1 | best_unseen_acc | mean: 10.055851 | std: 0.545023
| topk 1 | best_harmonic_mean | mean: 11.030260 | std: 0.324100
| topk 2 | auc | mean: 5.241396 | std: 0.255591
| topk 3 | auc | mean: 7.897845 | std: 0.411861

template_src: compdl
split: valid
| topk 1 | auc | mean: 2.516759 | std: 0.132692
| topk 2 | auc | mean: 5.833674 | std: 0.258408
| topk 3 | auc | mean: 8.817559 | std: 0.357837

split: test
| topk 1 | auc | mean: 2.057855 | std: 0.183082
| topk 1 | best_seen_acc | mean: 32.147059 | std: 1.081629
| topk 1 | best_unseen_acc | mean: 9.848328 | std: 0.582881
| topk 1 | best_harmonic_mean | mean: 10.809411 | std: 0.533516
| topk 2 | auc | mean: 4.984341 | std: 0.391096
| topk 3 | a

In [92]:
templates = {
    "clip": [
        "itap of a {}",
        "a bad photo of the {}",
        "a origami {}",
        "a photo of the large {}",
        "a {} in a video game",
        "art of the {}",
        "a photo of the small {}",
    ],
    "compdl": [
        "this is {}",
        "the object is {}",
        "the item is {}",
        "the item in the given picture is {}",
        "the thing in this bad photo is {}",
        "the item in the photo is {}",
        "the item in this cool photo is {}",
        "the main object in the photo is {}",
        "the item in the low resolution image is {}",
        "the object in the photo is {}",
    ],
}

# 5. Evaluation for (Random) Multiple Prompts

In [27]:
# Load trained model
def load_model(template_src, num_prompts):
    ckpt_dir = f"../../outputs/mit_states/clip_prompt_retrieval_model_random"
    ckpt_name = "clip_prompt_retrieval_model_TP{}_NP{}_open{}_init{}_imgproj{}_bias{}_{}_Lim{}_N{}_TW{}_B{}_LR{}_WD{}_L{}".format(
        template_src,
        num_prompts,
        1 if args.is_open_world else 0,
        1 if args.emb_init else 0,
        1 if args.is_image_projection else 0,
        1 if args.is_bias else 0,
        args.feature,
        1 if args.is_limit else 0, 
        args.num_epochs,
        args.train_warmup,
        args.batch_size,
        args.learning_rate,
        args.weight_decay,
        args.logit_scale,
    )
    ckpt_path = os.path.join(ckpt_dir, ckpt_name, "retrieval_model.ckpt")

    limit_pairs = list(train_dataset.limit_pair2idx.keys())
    open_world_pairs = list(valid_dataset.open_world_pair2idx.keys())
    model = RetrievalModel(data, train_dataset.image_dim, limit_pairs, open_world_pairs, args)
    model.load_state_dict(torch.load(ckpt_path))
    model.eval()
    model.to(device)
    return model

def get_output_path(template_src, num_prompts, args):
    ckpt_dir = f"../../outputs/mit_states/clip_prompt_retrieval_model_random"
    ckpt_name = "clip_prompt_retrieval_model_TP{}_NP{}_open{}_init{}_imgproj{}_bias{}_{}_Lim{}_N{}_TW{}_B{}_LR{}_WD{}_L{}".format(
        template_src,
        num_prompts,
        1 if args.is_open_world else 0,
        1 if args.emb_init else 0,
        1 if args.is_image_projection else 0,
        1 if args.is_bias else 0,
        args.feature,
        1 if args.is_limit else 0, 
        args.num_epochs,
        args.train_warmup,
        args.batch_size,
        args.learning_rate,
        args.weight_decay,
        args.logit_scale,
    )
    return os.path.join(ckpt_dir, ckpt_name)

In [16]:
!ls ../../outputs/mit_states/clip_prompt_retrieval_model_random

clip_prompt_retrieval_model_TPclip_NP4_open0_init1_imgproj1_bias0_pair_Lim1_N400_TW0_B128_LR5e-05_WD5e-05_L0.07
clip_prompt_retrieval_model_TPclip_NP4_open0_init1_imgproj1_bias0_primitive_Lim1_N400_TW0_B128_LR5e-05_WD5e-05_L0.07


## 5.1. Interpretabilty for Interv(GT)

In [18]:
!ls ../../outputs/mit_states/clip_prompt_retrieval_model_random

clip_prompt_retrieval_model_TPclip_NP4_open0_init1_imgproj1_bias0_pair_Lim1_N400_TW0_B128_LR5e-05_WD5e-05_L0.07
clip_prompt_retrieval_model_TPclip_NP4_open0_init1_imgproj1_bias0_primitive_Lim1_N400_TW0_B128_LR5e-05_WD5e-05_L0.07


In [33]:
# Interpretability for Interv(GT)

class parseArguments:
    def __init__(self):
        self.split = "valid"
        self.is_open_world = False
        
        self.feature = "primitive"
        self.model = "clip"
        self.num_prompts = 4
        self.train_warmup = 0
        self.is_image_projection = True
        self.is_bias = False
        self.is_limit = True
        
        self.data_root = "../../data/mit_states"
        self.emb_root = "../../data"
        self.precomputed_data_root = f"../../outputs/mit_states/clip_prompt_precompute_features_random/{self.model}/num_prompts_{self.num_prompts}"
        
        self.emb_init = True
        self.input_dim = 600
        self.num_epochs = 400
        self.batch_size = 128
        self.learning_rate = 5e-5
        self.weight_decay = 5e-5
        self.logit_scale = 0.07

args = parseArguments()

# Start!!
template_count_dict = {"clip": 7, "compdl": 10}

for split in ["valid", "test"]:
    args.split = split
    for template_src in ["clip"]:
        print("=" * 70)
        print(f"| {args.split} | template {template_src} | num_prompts {args.num_prompts} |")

        # Load dataset
        data = MITStatesDataset(root=args.data_root, split="train")  # split can be ignored here
        seen_ids_valid, unseen_ids_valid = get_seen_unseen_indices("valid", data)
        seen_ids_test, unseen_ids_test = get_seen_unseen_indices("test", data)

        train_dataset = Precomputed_MITStatesDataset(split="train", feature="primitive", data=data, args=args, is_limit=True)
        valid_dataset = Precomputed_MITStatesDataset(split="valid", feature="primitive", data=data, args=args, is_limit=True)
        test_dataset = Precomputed_MITStatesDataset(split="test", feature="primitive", data=data, args=args, is_limit=True)
        seen_mask = train_dataset.seen_mask

        # Load model
        model = load_model(template_src, args.num_prompts)

        # Prepare groundtruth labels
        seen_ids, unseen_ids = get_seen_unseen_indices(args.split, data)

        if args.split == "valid":
            labels_attr = [sample["attr_id"] for sample in data.valid_data]
            labels_obj = [sample["obj_id"] for sample in data.valid_data]
            labels = [sample["pair_id"] for sample in data.valid_data]
            gt_features_attr = np.zeros((valid_dataset.image_embs.shape[0], 115))
            gt_features_obj = np.zeros((valid_dataset.image_embs.shape[0], 245))
        elif args.split == "test":
            labels_attr = [sample["attr_id"] for sample in data.test_data]
            labels_obj = [sample["obj_id"] for sample in data.test_data]
            labels = [sample["pair_id"] for sample in data.test_data]
            gt_features_attr = np.zeros((test_dataset.image_embs.shape[0], 115))
            gt_features_obj = np.zeros((test_dataset.image_embs.shape[0], 245))

        gt_features_attr[np.arange(len(labels_attr)), labels_attr] = 1
        gt_features_obj[np.arange(len(labels_obj)), labels_obj] = 1
        gt_features_concat = np.concatenate([gt_features_attr, gt_features_obj], axis=-1)
        gt_features_concat = torch.tensor(gt_features_concat).to(device).double()
        print(f"Intervention features: {gt_features_concat.shape}")

        # Intervene!!!
        with torch.no_grad():
            logits, loss = model(gt_features_concat, pair_labels=None, is_train=False, is_intervene=True)
        features = logits.softmax(dim=-1).log()
        print("features: ", features.shape)

        # Compute metrics
        overall_metrics = get_overall_metrics(features, labels, seen_ids, unseen_ids, seen_mask, data, topk_list=[1,2,3])
        output_path = get_output_path(template_src, args.num_prompts, args)
        with open(os.path.join(output_path, f"{args.split}_metrics_interv_gt.json"), "w") as f:
            json.dump(overall_metrics, f)
        print("\n\n")



| valid | template clip | num_prompts 4 |
train pairs: 1262 | valid pairs: 600 | test pairs: 800
train images: 30338 | valid images: 10420 | test images: 12995
seen_indices: 1844 | unseen_indices: 8576
seen_indices: 2380 | unseen_indices: 10615
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
seen_indices: 1844 | unseen_indices: 8576
Intervention features: torch.Size([10420, 360])
features:  torch.Size([10420, 1962])
topk: 1
best_seen_acc: 0.6855
best_unseen_acc: 0.7164
best_harmonic_mean: 0.6094
auc: 0.4588
topk: 2
best_seen_acc: 0.8048
best_unseen_acc: 0.8443
best_harmo

## 5.2. Interpretability for Interv(GTX)

In [34]:
# Interpretability for Interv(GT)

class parseArguments:
    def __init__(self):
        self.split = "valid"
        self.is_open_world = False
        
        self.feature = "primitive"
        self.model = "clip"
        self.num_prompts = 4
        self.train_warmup = 0
        self.is_image_projection = True
        self.is_bias = False
        self.is_limit = True
        
        self.data_root = "../../data/mit_states"
        self.emb_root = "../../data"
        self.precomputed_data_root = f"../../outputs/mit_states/clip_prompt_precompute_features_random/{self.model}/num_prompts_{self.num_prompts}"
        
        self.emb_init = True
        self.input_dim = 600
        self.num_epochs = 400
        self.batch_size = 128
        self.learning_rate = 5e-5
        self.weight_decay = 5e-5
        self.logit_scale = 0.07

args = parseArguments()

# Start!!
template_count_dict = {"clip": 7, "compdl": 10}

for split in ["valid", "test"]:
    args.split = split
    for template_src in ["clip"]:
        print("=" * 70)
        print(f"| {args.split} | template {template_src} | num_prompts {args.num_prompts} |")

        # Load dataset
        data = MITStatesDataset(root=args.data_root, split="train")  # split can be ignored here
        seen_ids_valid, unseen_ids_valid = get_seen_unseen_indices("valid", data)
        seen_ids_test, unseen_ids_test = get_seen_unseen_indices("test", data)

        train_dataset = Precomputed_MITStatesDataset(split="train", feature="primitive", data=data, args=args, is_limit=True)
        valid_dataset = Precomputed_MITStatesDataset(split="valid", feature="primitive", data=data, args=args, is_limit=True)
        test_dataset = Precomputed_MITStatesDataset(split="test", feature="primitive", data=data, args=args, is_limit=True)
        seen_mask = train_dataset.seen_mask

        # Load model
        model = load_model(template_src, args.num_prompts)

        # Prepare groundtruth labels
        seen_ids, unseen_ids = get_seen_unseen_indices(args.split, data)

        if args.split == "valid":
            labels_attr = [sample["attr_id"] for sample in data.valid_data]
            labels_obj = [sample["obj_id"] for sample in data.valid_data]
            labels = [sample["pair_id"] for sample in data.valid_data]
            gt_features_attr = np.zeros((valid_dataset.image_embs.shape[0], 115))
            gt_features_obj = np.zeros((valid_dataset.image_embs.shape[0], 245))
        elif args.split == "test":
            labels_attr = [sample["attr_id"] for sample in data.test_data]
            labels_obj = [sample["obj_id"] for sample in data.test_data]
            labels = [sample["pair_id"] for sample in data.test_data]
            gt_features_attr = np.zeros((test_dataset.image_embs.shape[0], 115))
            gt_features_obj = np.zeros((test_dataset.image_embs.shape[0], 245))

        gt_features_attr[np.arange(len(labels_attr)), labels_attr] = 1
        gt_features_obj[np.arange(len(labels_obj)), labels_obj] = 1
        gt_features_concat = np.concatenate([gt_features_attr, gt_features_obj], axis=-1)
        gt_features_concat = torch.tensor(gt_features_concat).to(device).double()
        print(f"Intervention features: {gt_features_concat.shape}")
        
        # Prepare "binary" primitive concept activations
        # Set GT concepts to 1
        seen_ids, unseen_ids = get_seen_unseen_indices(args.split, data)

        if args.split == "valid":
            labels_attr = [sample["attr_id"] for sample in data.valid_data]
            labels_obj = [sample["obj_id"] for sample in data.valid_data]
            labels = [sample["pair_id"] for sample in data.valid_data]
            features_interv = valid_dataset.image_embs
            features_interv = torch.tensor(features_interv).to(device).double()
            features_interv = F.normalize(features_interv, dim=1)

        elif args.split == "test":
            labels_attr = [sample["attr_id"] for sample in data.test_data]
            labels_obj = [sample["obj_id"] for sample in data.test_data]
            labels = [sample["pair_id"] for sample in data.test_data]
            features_interv = test_dataset.image_embs
            features_interv = torch.tensor(features_interv).to(device).double()
            features_interv = F.normalize(features_interv, dim=1)

        features_interv = torch.where(gt_features_concat==1, 1.0, features_interv)
        print(f"Intervention features: {features_interv.shape}")

        # Intervene!!!
        with torch.no_grad():
            logits, loss = model(features_interv, pair_labels=None, is_train=False, is_intervene=True)
        features = logits.softmax(dim=-1).log()
        print("features: ", features.shape)

        # Compute metrics
        overall_metrics = get_overall_metrics(features, labels, seen_ids, unseen_ids, seen_mask, data, topk_list=[1,2,3])
        output_path = get_output_path(template_src, args.num_prompts, args)
        with open(os.path.join(output_path, f"{args.split}_metrics_interv_gt.json"), "w") as f:
            json.dump(overall_metrics, f)
        print("\n\n")



| valid | template clip | num_prompts 4 |
train pairs: 1262 | valid pairs: 600 | test pairs: 800
train images: 30338 | valid images: 10420 | test images: 12995
seen_indices: 1844 | unseen_indices: 8576
seen_indices: 2380 | unseen_indices: 10615
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
seen_indices: 1844 | unseen_indices: 8576
Intervention features: torch.Size([10420, 360])
seen_indices: 1844 | unseen_indices: 8576
Intervention features: torch.Size([10420, 360])
features:  torch.Size([10420, 1962])
topk: 1
best_seen_acc: 0.7164
best_unseen_acc: 0.7463
best_harmonic

# 6. Evaluation for All Prompts

In [53]:
# Load trained model
def load_model(template_src, template_id):
    ckpt_dir = f"../../outputs/mit_states/clip_prompt_retrieval_model"
    ckpt_name = "clip_prompt_retrieval_model_TP{}_TID{}_open{}_init{}_imgproj{}_bias{}_{}_Lim{}_N{}_TW{}_B{}_LR{}_WD{}_L{}".format(
        template_src,
        template_id,
        1 if args.is_open_world else 0,
        1 if args.emb_init else 0,
        1 if args.is_image_projection else 0,
        1 if args.is_bias else 0,
        args.feature,
        1 if args.is_limit else 0, 
        args.num_epochs,
        args.train_warmup,
        args.batch_size,
        args.learning_rate,
        args.weight_decay,
        args.logit_scale,
    )
    ckpt_path = os.path.join(ckpt_dir, ckpt_name, "retrieval_model.ckpt")
    print(ckpt_path)

    limit_pairs = list(train_dataset.limit_pair2idx.keys())
    open_world_pairs = list(valid_dataset.open_world_pair2idx.keys())
    model = RetrievalModel(data, train_dataset.image_dim, limit_pairs, open_world_pairs, args)
    model.load_state_dict(torch.load(ckpt_path))
    model.eval()
    model.to(device)
    return model

def get_output_path(template_src, template_id, args):
    ckpt_dir = f"../../outputs/mit_states/clip_prompt_retrieval_model"
    ckpt_name = "clip_prompt_retrieval_model_TP{}_TID{}_open{}_init{}_imgproj{}_bias{}_{}_Lim{}_N{}_TW{}_B{}_LR{}_WD{}_L{}".format(
        template_src,
        template_id,
        1 if args.is_open_world else 0,
        1 if args.emb_init else 0,
        1 if args.is_image_projection else 0,
        1 if args.is_bias else 0,
        args.feature,
        1 if args.is_limit else 0, 
        args.num_epochs,
        args.train_warmup,
        args.batch_size,
        args.learning_rate,
        args.weight_decay,
        args.logit_scale,
    )
    return os.path.join(ckpt_dir, ckpt_name)

## 6.1. Interpretabilty for Interv(GT)

In [57]:
# Interpretability for Interv(GT)

class parseArguments:
    def __init__(self):
        self.split = "valid"
        self.is_open_world = False
        
        self.feature = "primitive"
        self.model = "clip"
        self.train_warmup = 0
        self.is_image_projection = True
        self.is_bias = False
        self.is_limit = True
        
        self.data_root = "../../data/mit_states"
        self.emb_root = "../../data"
        self.precomputed_data_dir = f"../../outputs/mit_states/clip_prompt_precompute_features"
        
        self.emb_init = True
        self.input_dim = 600
        self.num_epochs = 400
        self.batch_size = 128
        self.learning_rate = 5e-5
        self.weight_decay = 5e-5
        self.logit_scale = 0.07

args = parseArguments()

# Start!!
template_count_dict = {"clip": 7, "compdl": 10}

for split in ["valid", "test"]:
    args.split = split
    template_src = "clip"
    template_id = "all"  # 3 or all
    
    print("=" * 70)
    print(f"| {args.split} | template {template_src} | template_id {template_id} |")
    if template_id == "all":
        args.precomputed_data_root = os.path.join(args.precomputed_data_dir, template_src, "combined")
    else:
        args.precomputed_data_root = os.path.join(args.precomputed_data_dir, template_src, f"template_{template_id}")

    # Load dataset
    data = MITStatesDataset(root=args.data_root, split="train")  # split can be ignored here
    seen_ids_valid, unseen_ids_valid = get_seen_unseen_indices("valid", data)
    seen_ids_test, unseen_ids_test = get_seen_unseen_indices("test", data)

    train_dataset = Precomputed_MITStatesDataset(split="train", feature="primitive", data=data, args=args, is_limit=True)
    valid_dataset = Precomputed_MITStatesDataset(split="valid", feature="primitive", data=data, args=args, is_limit=True)
    test_dataset = Precomputed_MITStatesDataset(split="test", feature="primitive", data=data, args=args, is_limit=True)
    seen_mask = train_dataset.seen_mask

    # Load model
    model = load_model(template_src, template_id)

    # Prepare groundtruth labels
    seen_ids, unseen_ids = get_seen_unseen_indices(args.split, data)

    if args.split == "valid":
        labels_attr = [sample["attr_id"] for sample in data.valid_data]
        labels_obj = [sample["obj_id"] for sample in data.valid_data]
        labels = [sample["pair_id"] for sample in data.valid_data]
        gt_features_attr = np.zeros((valid_dataset.image_embs.shape[0], 115))
        gt_features_obj = np.zeros((valid_dataset.image_embs.shape[0], 245))
    elif args.split == "test":
        labels_attr = [sample["attr_id"] for sample in data.test_data]
        labels_obj = [sample["obj_id"] for sample in data.test_data]
        labels = [sample["pair_id"] for sample in data.test_data]
        gt_features_attr = np.zeros((test_dataset.image_embs.shape[0], 115))
        gt_features_obj = np.zeros((test_dataset.image_embs.shape[0], 245))

    gt_features_attr[np.arange(len(labels_attr)), labels_attr] = 1
    gt_features_obj[np.arange(len(labels_obj)), labels_obj] = 1
    gt_features_concat = np.concatenate([gt_features_attr, gt_features_obj], axis=-1)
    gt_features_concat = torch.tensor(gt_features_concat).to(device).double()
    print(f"Intervention features: {gt_features_concat.shape}")

    # Intervene!!!
    with torch.no_grad():
        logits, loss = model(gt_features_concat, pair_labels=None, is_train=False, is_intervene=True)
    features = logits.softmax(dim=-1).log()
    print("features: ", features.shape)

    # Compute metrics
    overall_metrics = get_overall_metrics(features, labels, seen_ids, unseen_ids, seen_mask, data, topk_list=[1,2,3])
    output_path = get_output_path(template_src, template_id, args)
    print(output_path)
    with open(os.path.join(output_path, f"{args.split}_metrics_interv_gt.json"), "w") as f:
        json.dump(overall_metrics, f)
    print("\n\n")



| valid | template clip | template_id all |
train pairs: 1262 | valid pairs: 600 | test pairs: 800
train images: 30338 | valid images: 10420 | test images: 12995
seen_indices: 1844 | unseen_indices: 8576
seen_indices: 2380 | unseen_indices: 10615
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
../../outputs/mit_states/clip_prompt_retrieval_model/clip_prompt_retrieval_model_TPclip_TIDall_open0_init1_imgproj1_bias0_primitive_Lim1_N400_TW0_B128_LR5e-05_WD5e-05_L0.07/retrieval_model.ckpt
seen_indices: 1844 | unseen_indices: 8576
Intervention features: torch.Size([10420, 360]

## 6.2. Interpretability for Interv(GTX)

In [58]:
# Interpretability for Interv(GTX)

class parseArguments:
    def __init__(self):
        self.split = "valid"
        self.is_open_world = False
        
        self.feature = "primitive"
        self.model = "clip"
        self.train_warmup = 0
        self.is_image_projection = True
        self.is_bias = False
        self.is_limit = True
        
        self.data_root = "../../data/mit_states"
        self.emb_root = "../../data"
        self.precomputed_data_dir = f"../../outputs/mit_states/clip_prompt_precompute_features"
        
        self.emb_init = True
        self.input_dim = 600
        self.num_epochs = 400
        self.batch_size = 128
        self.learning_rate = 5e-5
        self.weight_decay = 5e-5
        self.logit_scale = 0.07

args = parseArguments()

# Start!!
template_count_dict = {"clip": 7, "compdl": 10}

for split in ["valid", "test"]:
    args.split = split
    template_src = "clip"
    template_id = "all"  # 3 or all
    
    print("=" * 70)
    print(f"| {args.split} | template {template_src} | template_id {template_id} |")
    if template_id == "all":
        args.precomputed_data_root = os.path.join(args.precomputed_data_dir, template_src, "combined")
    else:
        args.precomputed_data_root = os.path.join(args.precomputed_data_dir, template_src, f"template_{template_id}")

    # Load dataset
    data = MITStatesDataset(root=args.data_root, split="train")  # split can be ignored here
    seen_ids_valid, unseen_ids_valid = get_seen_unseen_indices("valid", data)
    seen_ids_test, unseen_ids_test = get_seen_unseen_indices("test", data)

    train_dataset = Precomputed_MITStatesDataset(split="train", feature="primitive", data=data, args=args, is_limit=True)
    valid_dataset = Precomputed_MITStatesDataset(split="valid", feature="primitive", data=data, args=args, is_limit=True)
    test_dataset = Precomputed_MITStatesDataset(split="test", feature="primitive", data=data, args=args, is_limit=True)
    seen_mask = train_dataset.seen_mask

    # Load model
    model = load_model(template_src, template_id)

    # Prepare groundtruth labels
    seen_ids, unseen_ids = get_seen_unseen_indices(args.split, data)

    if args.split == "valid":
        labels_attr = [sample["attr_id"] for sample in data.valid_data]
        labels_obj = [sample["obj_id"] for sample in data.valid_data]
        labels = [sample["pair_id"] for sample in data.valid_data]
        gt_features_attr = np.zeros((valid_dataset.image_embs.shape[0], 115))
        gt_features_obj = np.zeros((valid_dataset.image_embs.shape[0], 245))
    elif args.split == "test":
        labels_attr = [sample["attr_id"] for sample in data.test_data]
        labels_obj = [sample["obj_id"] for sample in data.test_data]
        labels = [sample["pair_id"] for sample in data.test_data]
        gt_features_attr = np.zeros((test_dataset.image_embs.shape[0], 115))
        gt_features_obj = np.zeros((test_dataset.image_embs.shape[0], 245))

    gt_features_attr[np.arange(len(labels_attr)), labels_attr] = 1
    gt_features_obj[np.arange(len(labels_obj)), labels_obj] = 1
    gt_features_concat = np.concatenate([gt_features_attr, gt_features_obj], axis=-1)
    gt_features_concat = torch.tensor(gt_features_concat).to(device).double()
    print(f"Intervention features: {gt_features_concat.shape}")

    # Prepare "binary" primitive concept activations
    # Set GT concepts to 1
    seen_ids, unseen_ids = get_seen_unseen_indices(args.split, data)

    if args.split == "valid":
        labels_attr = [sample["attr_id"] for sample in data.valid_data]
        labels_obj = [sample["obj_id"] for sample in data.valid_data]
        labels = [sample["pair_id"] for sample in data.valid_data]
        features_interv = valid_dataset.image_embs
        features_interv = torch.tensor(features_interv).to(device).double()
        features_interv = F.normalize(features_interv, dim=1)

    elif args.split == "test":
        labels_attr = [sample["attr_id"] for sample in data.test_data]
        labels_obj = [sample["obj_id"] for sample in data.test_data]
        labels = [sample["pair_id"] for sample in data.test_data]
        features_interv = test_dataset.image_embs
        features_interv = torch.tensor(features_interv).to(device).double()
        features_interv = F.normalize(features_interv, dim=1)

    features_interv = torch.where(gt_features_concat==1, 1.0, features_interv)
    print(f"Intervention features: {features_interv.shape}")

    # Intervene!!!
    with torch.no_grad():
        logits, loss = model(features_interv, pair_labels=None, is_train=False, is_intervene=True)
    features = logits.softmax(dim=-1).log()
    print("features: ", features.shape)

    # Compute metrics
    overall_metrics = get_overall_metrics(features, labels, seen_ids, unseen_ids, seen_mask, data, topk_list=[1,2,3])
    output_path = get_output_path(template_src, template_id, args)
    with open(os.path.join(output_path, f"{args.split}_metrics_interv_gt.json"), "w") as f:
        json.dump(overall_metrics, f)
    print("\n\n")



| valid | template clip | template_id all |
train pairs: 1262 | valid pairs: 600 | test pairs: 800
train images: 30338 | valid images: 10420 | test images: 12995
seen_indices: 1844 | unseen_indices: 8576
seen_indices: 2380 | unseen_indices: 10615
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
attr 	| train (30338, 115) 	| valid (10420, 115) 	| test (12995, 115)
obj 	| train (30338, 245) 	| valid (10420, 245) 	| test (12995, 245)
image_dim:    360
../../outputs/mit_states/clip_prompt_retrieval_model/clip_prompt_retrieval_model_TPclip_TIDall_open0_init1_imgproj1_bias0_primitive_Lim1_N400_TW0_B128_LR5e-05_WD5e-05_L0.07/retrieval_model.ckpt
seen_indices: 1844 | unseen_indices: 8576
Intervention features: torch.Size([10420, 360]