In [None]:
import sys
sys.path.append("../")

import os
import numpy as np
import torch
import torch.nn as nn
import pandas as pd
from os.path import join as oj
def add_path(path):
    """Put ``path`` at the front of ``sys.path`` unless it is already listed."""
    if path in sys.path:
        # Already importable from this location; keep sys.path unchanged.
        return
    sys.path.insert(0, path)

add_path(os.path.abspath('..'))
from pycls.datasets.sampler import IndexedSequentialSampler
from torchvision import transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from collections import Counter

from utils.analysis import get_attention_new, attention_annotation_new, faiss_k_means, get_time, sample_from_group, compute_optimal_clusters, plot_clusters
from utils.utils import compute_sampling_probability
from utils.custom_data import CustomDataset

# ANSI escape codes used to colour the console output of the reporting cells
RED = "\033[31m"
BLUE = "\033[34m"
RESET = "\033[0m"

random_seed = 0  # fixed seed for reproducibility; passed to sample_from_group when sampling

### Load the dataset and reference model

In [None]:
# ---- Experiment configuration ------------------------------------------------
dataset_name = 'waterbirds'
# sol_1: use the training set (compare slim_train)
# sol_2: use the validation set (slim_val)
cur_exp = 'sol_1'

base_dir = "[waterbird-data-dir]"
model_path = '[trained-model-path].pt'
out_dir = "../outputs/waterbirds"

data_flag = 'train'  # which split is analysed: 'train' or 'val'
img_dir = f'{base_dir}/waterbird_complete95_forest2water2'

total_budget = 120  # multiplied by the sampling weights to obtain per-cluster sample counts
att_budget_dict = {0: 20, 1: 20}  # forwarded to attention_annotation_new as att_budget_dict

# ---- Reference model ---------------------------------------------------------
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# NOTE(review): `pretrained=True` is deprecated in recent torchvision releases in
# favour of `weights=...`; the weights are replaced by load_state_dict in the
# next cell anyway -- confirm the installed version before changing it.
model = getattr(models, 'resnet50')(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 2)  # two-class output head
data_transformer = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize(mean=mean, std=std),
                                        transforms.Resize((224, 224))
                                        ])

# ---- Data --------------------------------------------------------------------
trainset = CustomDataset(basedir=base_dir, split="train", transform=data_transformer, attention=True)
valset = CustomDataset(basedir=base_dir, split="val", transform=data_transformer, attention=True)

# Select the split named by data_flag, so that the files written under the
# data_flag prefix are really computed from that split (previously the training
# set was always used, whatever data_flag said).
split_to_dataset = {'train': trainset, 'val': valset}
if data_flag not in split_to_dataset:
    raise ValueError(f"data_flag must be 'train' or 'val', got {data_flag!r}")
dataset = split_to_dataset[data_flag]
# NOTE(review): attribute name kept exactly as spelled in the original code
# ('medadata'); presumably CustomDataset's metadata table -- verify.
total_df = dataset.medadata

In [None]:
# Load the trained reference weights on the CPU first, then move the model to the GPU.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
DEVICE="cuda:0"
model.cuda(DEVICE)
# The model is only used for inference / attribution from here on. Switch
# BatchNorm and Dropout layers to evaluation behaviour so that predictions do
# not depend on the (size-1) batch and running statistics are not modified.
# Harmless if get_attention_new also does this internally.
model.eval()

# One index per sample.
# NOTE(review): IndexedSequentialSampler presumably yields them in the given order -- confirm.
all_true_idx_set = np.arange(len(dataset))
subsetSampler = IndexedSequentialSampler(all_true_idx_set)
loader = DataLoader(dataset=dataset, batch_size=1, sampler=subsetSampler, shuffle=False)
print("data size: ", len(dataset))

### Generate GradCAM and required feature representations
##### The GradCAM and feature-vector generation steps can be skipped if their outputs already exist

In [None]:
# File paths of every sample, and the same entries reduced to their last path component.
img_file_name_list = dataset.filename_array.tolist()
img_name_list = [file_name.split('/')[-1] for file_name in img_file_name_list]

# Run the reference model over the loader and unpack the ten values that
# get_attention_new returns.
(
    auto_att_scores,
    all_entropy_scores,
    total_prediction,
    total_label,
    img_names,
    img_paths,
    att_rep_embeddings,
    all_cam_masks,
    wfv,
    op_wfv,
) = get_attention_new(model, loader, img_dir, out_dir, data_flag, dataset_name, device=DEVICE)

In [None]:
# Labels and place ids are taken from the dataset itself (this rebinds the
# total_label returned by get_attention_new).
total_label = dataset.y_array
total_place = dataset.p_array

# One "label_<y>.place_<p>" string per sample, plus the integer id 2 * label + place.
meta_info_str = np.array([f"label_{lbl}.place_{plc}" for lbl, plc in zip(total_label, total_place)])
meta_subgroup_array = (total_label * 2 + total_place).astype('int')

# Confusion-matrix cell of every prediction, treating class 1 as the positive class.
# Key: (prediction matches label, prediction is the positive class).
_outcome_names = {
    (True, True): "TP",    # True Positive
    (True, False): "TN",   # True Negative
    (False, True): "FP",   # False Positive
    (False, False): "FN",  # False Negative
}
confusion_matrix_str = np.array([
    _outcome_names[(bool(pred == lbl), bool(pred == 1))]
    for pred, lbl in zip(total_prediction, total_label)
])

# Colour lookups: the i-th group in sorted order gets the i-th colour of its palette.
meta_colors = ['steelblue', 'gold', 'crimson', 'forestgreen']
unique_meta_groups = sorted(set(meta_info_str))
meta_to_color = {group_name: color for group_name, color in zip(unique_meta_groups, meta_colors)}

cm_colors = ['deepskyblue', 'tomato', 'mediumorchid', 'darkkhaki']
unique_cm_groups = sorted(set(confusion_matrix_str))
cm_to_color = {group_name: color for group_name, color in zip(unique_cm_groups, cm_colors)}

In [None]:
# Scatter the first two columns of the attention-representation embeddings.
fig, ax = plt.subplots(figsize=(15, 7))
ax.scatter(att_rep_embeddings[:, 0], att_rep_embeddings[:, 1], s=6)
ax.set_title(f"att_rep_embeddings ({len(att_rep_embeddings)})")
plt.show()
plt.close(fig)

### A simple (less visually appealing) way to collect human annotations
##### Data are sampled, and a human provides binary annotations through text input
Alternative: an interactive interface that shows one image at a time with a binary choice

In [None]:
# Start from every sample index; a later cell narrows the selection down.
current_true_index = all_true_idx_set
# Tag describing the current index set; used to build att_annotation_flag.
current_true_index_name = 'all'

#### The cell below samples data for annotation and prints the following prompts:
* "Image grid xxx that require attention annotation is saved at: xxx."
* "Annotate the attention correctness of each image in order (0: wrong; 1: correct; -1: not sure)."
* "Please separate by ', '"

A human can then provide annotations via text input

In [None]:
# Tag for this annotation round: "<experiment id>_<index-set name>".
att_annotation_flag = f'{cur_exp}_{current_true_index_name}'
print("att_budget_dict: ", att_budget_dict)
print("att_annotation_flag: ", att_annotation_flag)
if att_annotation_flag is not None:
    # Every per-sample array is restricted to the samples currently selected.
    annotated_att_scores = attention_annotation_new(
        auto_att_scores[current_true_index],
        img_names[current_true_index],
        total_label[current_true_index],
        total_prediction[current_true_index],
        img_paths[current_true_index],
        att_rep_embeddings[current_true_index],
        out_dir,
        att_annotation_flag,
        att_budget_dict=att_budget_dict,
        gamma=[10, 5],
    )

In [None]:
### Use automatic/annotated attention scores
all_att_scores = annotated_att_scores


def _load_saved_array(stem):
    """Load ``<out_dir>/<data_flag>_<stem>.npy``."""
    return np.load(oj(out_dir, f'{data_flag}_{stem}.npy'))


# Embeddings and feature vectors saved under out_dir for the current split.
total_wfv_embeddings = _load_saved_array('total_weighted_fv_embedding')
total_op_wfv_embeddings = _load_saved_array('opposite_weighted_fv_embedding')
wfv = _load_saved_array('total_weighted_feature_vector')
op_wfv = _load_saved_array('opposite_weighted_feature_vector')

# Value range of the first two columns of each embedding.
for emb_name, emb_values in (
    ("total_wfv_embeddings", total_wfv_embeddings),
    ("total_op_wfv_embeddings", total_op_wfv_embeddings),
):
    for axis_tag, column in (("x", 0), ("y", 1)):
        print(f"{emb_name}_{axis_tag}: ", np.min(emb_values[:, column]), np.max(emb_values[:, column]))

In [None]:
# Keep only the samples whose attention score is strictly above the median score.
temp_index = np.where(all_att_scores > np.median(all_att_scores))[0]
print(all_att_scores.shape, all_entropy_scores.shape, temp_index.shape)
# Map positions within the current selection back to dataset indices.
true_index = current_true_index[temp_index]

# Add the suffix only once: a plain `+=` made the name grow by another
# '_higher_attention' every time this cell was re-run.
higher_attention_suffix = '_higher_attention'
if not current_true_index_name.endswith(higher_attention_suffix):
    current_true_index_name += higher_attention_suffix
print(true_index.shape)
att_annotation_flag = f'{cur_exp}_' + current_true_index_name
print(att_annotation_flag)

#### Some additional visualizations

In [None]:
print(wfv[true_index].shape, total_wfv_embeddings[true_index].shape)

# Same scatter drawn twice: coloured by meta subgroup, then by confusion-matrix cell.
selected_wfv_embeddings = total_wfv_embeddings[true_index]
for color_lookup, sample_groups, legend_groups, legend_title in (
    (meta_to_color, meta_info_str[true_index], unique_meta_groups, 'Meta Subgroups'),
    (cm_to_color, confusion_matrix_str[true_index], unique_cm_groups, 'Confusion Matrix'),
):
    group_colors = np.array([color_lookup[group] for group in sample_groups])
    plt.figure(figsize=(15, 7))
    plt.scatter(selected_wfv_embeddings[:, 0], selected_wfv_embeddings[:, 1], c=group_colors, s=6)
    # Empty scatters exist only to create one legend entry per group.
    for group in legend_groups:
        plt.scatter([], [], color=color_lookup[group], label=group)
    plt.legend(title=legend_title)
    plt.show()

In [None]:
# Report the shapes of the *opposite* weighted features / embeddings plotted in
# this cell (the previous version printed the non-opposite arrays again).
print(op_wfv[true_index].shape, total_op_wfv_embeddings[true_index].shape)

# Same scatter drawn twice: coloured by meta subgroup, then by confusion-matrix cell.
selected_op_wfv_embeddings = total_op_wfv_embeddings[true_index]
for color_lookup, sample_groups, legend_groups, legend_title in (
    (meta_to_color, meta_info_str[true_index], unique_meta_groups, 'Meta Subgroups'),
    (cm_to_color, confusion_matrix_str[true_index], unique_cm_groups, 'Confusion Matrix'),
):
    group_colors = np.array([color_lookup[group] for group in sample_groups])
    plt.figure(figsize=(15, 7))
    plt.scatter(selected_op_wfv_embeddings[:, 0], selected_op_wfv_embeddings[:, 1], c=group_colors, s=6)
    # Empty scatters exist only to create one legend entry per group.
    for group in legend_groups:
        plt.scatter([], [], color=color_lookup[group], label=group)
    # Title so these figures can be told apart from the total_wfv_embeddings plots.
    plt.title(f"total_op_wfv_embeddings ({len(selected_op_wfv_embeddings)})")
    plt.legend(title=legend_title)
    plt.show()

### Data subset curation
#### Based on spuriousness score propagated according to human annotation

In [None]:
# Output directory for this sampling run, named from the experiment id and the
# string returned by get_time() (presumably a timestamp -- see utils.analysis).
current_time_str = get_time()
current_out_dir = oj(out_dir, f"{cur_exp}_{current_time_str}_sampling")

# Base name of the curated subset; the sampling cell later appends the number of sampled items.
constructed_set_name = 'constructed_set'
print(constructed_set_name)

In [None]:
# Create the run directory; no error if it already exists.
os.makedirs(current_out_dir, exist_ok=True)
print(f"save results in {current_out_dir}")

# Restrict embeddings, feature vectors and labels to the selected samples (true_index).
current_wfv_embeddings = total_wfv_embeddings[true_index]
current_op_wfv_embeddings = total_op_wfv_embeddings[true_index]

current_wfv = wfv[true_index]
current_op_wfv = op_wfv[true_index]
current_label = total_label[true_index]
# Shape printout to eyeball that every array has one row per selected sample.
print(true_index.shape)
print(current_wfv.shape, current_wfv_embeddings.shape)
print(current_op_wfv.shape, current_op_wfv_embeddings.shape)

#### Find the optimal n_clusters

In [None]:
# Candidate cluster counts and the criterion used to choose among them.
cluster_options = [2,3,4,5]
eval_metric = 'elbow' ### ['silhouette', 'elbow', 'davies_bouldin']
# NOTE(review): compute_optimal_clusters presumably returns (chosen number of
# clusters, its score, per-sample labels for that choice) -- confirm in utils.analysis.
optimal_wfv_clusters, best_wfv_score, optimal_wfv_cluster_labels = compute_optimal_clusters(current_wfv_embeddings, cluster_options, eval_metric)
print(f"Org - Optimal number of clusters: {optimal_wfv_clusters} with a {eval_metric} score of {best_wfv_score}")
plot_clusters(current_wfv_embeddings, optimal_wfv_cluster_labels)

In [None]:
# Same search on the opposite ("Rev") embeddings; eval_metric is the value set in the previous cell.
cluster_options = [2,3,4,5]
optimal_op_wfv_clusters, best_op_wfv_score, optimal_op_wfv_cluster_labels = compute_optimal_clusters(current_op_wfv_embeddings, cluster_options, eval_metric)
print(f"Rev - Optimal number of clusters: {optimal_op_wfv_clusters} with a {eval_metric} score of {best_op_wfv_score}")
plot_clusters(current_op_wfv_embeddings, optimal_op_wfv_cluster_labels)

### Apply clustering on feature representation and reverse feature representation spaces, respectively
##### Use the optimal n_clusters for each

In [None]:
# Clustering 1: original embedding space, using its selected number of clusters.
cluster_1_kmeans_model, cluster_1_labels = faiss_k_means(current_wfv_embeddings, optimal_wfv_clusters)
# Clustering 2: opposite embedding space, using its own selected number of clusters.
cluster_2_kmeans_model, cluster_2_labels = faiss_k_means(current_op_wfv_embeddings, optimal_op_wfv_clusters)
plot_clusters(current_wfv_embeddings, cluster_1_labels)
plot_clusters(current_op_wfv_embeddings, cluster_2_labels)

In [None]:
# Joint cluster id: cluster_1 label * (number of cluster_2 clusters) + cluster_2 label.
final_cluster_labels = (cluster_1_labels * optimal_op_wfv_clusters + cluster_2_labels).astype('int')
final_cluster_label_count = len(np.unique(final_cluster_labels))
final_cluster_label_counter = Counter(final_cluster_labels)

# Shapes and per-cluster counts.
print("true_index", true_index.shape)
for labels_name, labels_value in (
    ("cluster_1_labels", cluster_1_labels),
    ("cluster_2_labels", cluster_2_labels),
    ("final_cluster_labels", final_cluster_labels),
):
    print(labels_name, labels_value.shape, Counter(labels_value))
for embedding_name, embedding_value in (
    ("current_wfv_embeddings", current_wfv_embeddings),
    ("current_op_wfv_embeddings", current_op_wfv_embeddings),
):
    print(embedding_name, embedding_value.shape)

# Save each array as <current_out_dir>/<name>.npy.
for array_name, array_value in (
    ("true_index", true_index),
    ("cluster_1_labels", cluster_1_labels),
    ("cluster_2_labels", cluster_2_labels),
    ("final_cluster_labels", final_cluster_labels),
    ("current_wfv_embeddings", current_wfv_embeddings),
    ("current_op_wfv_embeddings", current_op_wfv_embeddings),
):
    np.save(oj(current_out_dir, f'{array_name}.npy'), array_value)

In [None]:
# Number of clusters in each space. The joint label is
# cluster_1 * n_op_wfv_clusters + cluster_2, so there are
# n_wfv_clusters * n_op_wfv_clusters possible joint labels. The previous version
# used the cluster_2 count for both factors and for the outer loop, which is only
# correct when both clusterings happen to have the same number of clusters.
n_wfv_clusters = int(optimal_wfv_clusters)
n_op_wfv_clusters = int(optimal_op_wfv_clusters)
n_joint_clusters = n_wfv_clusters * n_op_wfv_clusters

# Size of every possible joint cluster indexed by joint label (0 if empty), so
# org_size lines up position-by-position with sampling_weights / sampling_size.
org_size = [final_cluster_label_counter[i] for i in range(n_joint_clusters)]
sampling_weights = np.zeros(n_joint_clusters)

cluster_1_prob = compute_sampling_probability(current_wfv_embeddings, cluster_1_labels)
print(f"Cluster 1 sampling probability: {cluster_1_prob}")
for i in range(n_wfv_clusters):
    # Members of cluster_1 == i; the second-level probability is computed within them.
    temp_index = np.where(cluster_1_labels == i)[0]
    # NOTE(review): assumes compute_sampling_probability returns something indexable
    # by every cluster_2 label, including labels absent from this subset -- confirm.
    cluster_2_prob = compute_sampling_probability(current_op_wfv_embeddings[temp_index], cluster_2_labels[temp_index])
    print(f"Cluster 1 label: {i} | Corresponding cluster 2 sampling probability: {cluster_2_prob}")
    for j in range(n_op_wfv_clusters):
        final_label = i * n_op_wfv_clusters + j
        final_label_instance_count = final_cluster_label_counter[final_label]
        final_probability = cluster_1_prob[i] * cluster_2_prob[j]
        # Count before normalisation, shown for information only; the count that is
        # actually used is computed from the normalised weights below.
        count = int(final_probability * total_budget)
        print(f"final cluster: {final_label} (clustering_1 label: {i} | clustering_2 label: {j}) | count: {final_label_instance_count} | sampling power({final_probability}) * budget({total_budget}) = {count}")
        sampling_weights[final_label] = final_probability

# Normalise the weights to sum to 1, then turn them into integer sample counts.
# int() truncates, so the counts can add up to slightly less than total_budget.
sampling_weights_sum = np.sum(sampling_weights)
if not sampling_weights_sum > 0:
    raise ValueError(f"sampling weights must have a positive sum, got {sampling_weights_sum}")
sampling_weights = [w / sampling_weights_sum for w in sampling_weights]
sampling_size = [int(total_budget * w) for w in sampling_weights]
print("sampling_weights: ", sampling_weights, np.sum(sampling_weights))

print("org_size: ", org_size)
print("sampling_size: ", sampling_size)
print("total sampling: ", np.sum(sampling_size))

In [None]:
n_final_clusters = np.max(final_cluster_labels) + 1 
print("n_final_clusters: ", n_final_clusters)
# sampling_size is indexed by final cluster label, so it must cover every label.
if n_final_clusters > len(sampling_size):
    raise ValueError(
        f"sampling_size has {len(sampling_size)} entries but final cluster labels go up to {n_final_clusters - 1}"
    )

cluster_group_idx = []
balanced_index = []
for cur_final_cluster_label in range(n_final_clusters):
    group_idx = np.where(final_cluster_labels == cur_final_cluster_label)[0]
    if len(group_idx) == 0:
        # Empty joint cluster: nothing to sample from.
        continue
    cluster_group_idx.append(group_idx)
    # Look the budget up by label. Zipping the non-empty groups with sampling_size
    # (as before) shifted every budget after the first empty cluster onto the wrong group.
    t_sampling_size = int(sampling_size[cur_final_cluster_label])
    sampled_group_idx = sample_from_group(group_idx, t_sampling_size, group_vector=current_op_wfv_embeddings[group_idx], seed=random_seed)
    if len(sampled_group_idx) > len(group_idx):
        print(f"{BLUE}sample {len(sampled_group_idx)} from {len(group_idx)}{RESET}. Duplicated samples: {len(sampled_group_idx) - len(group_idx)}")
    else:
        print(f"{BLUE}sample {len(sampled_group_idx)} from {len(group_idx)}{RESET}")
    balanced_index.append(sampled_group_idx)
balanced_index = np.concatenate(balanced_index)

# Rebuild the name from its base so that re-running this cell replaces the size
# suffix instead of appending another one (e.g. constructed_set_120_120).
name_head, name_sep, name_tail = constructed_set_name.rpartition('_')
constructed_set_base_name = name_head if (name_sep and name_tail.isdigit()) else constructed_set_name
constructed_set_name = f"{constructed_set_base_name}_{len(balanced_index)}"
print("constructed_set_name: ", constructed_set_name)
print(f"In total, sampled {len(balanced_index)} from {len(true_index)}: ")

# Positions within true_index -> dataset indices.
constructed_set = true_index[balanced_index]
print(f'{RED}', len(constructed_set), Counter(meta_info_str[constructed_set]), f'{RESET}')
remove_dup_constructed_set = np.array(list(set(constructed_set)))
n_duplicated_samples = len(constructed_set) - len(remove_dup_constructed_set)
print(f'{RED} Duplicated samples: ', n_duplicated_samples, f'{RESET}')

# Append the run parameters to the info file of this constructed set.
with open(os.path.join(current_out_dir, f"{constructed_set_name}_info.txt"), 'a') as f:
    f.write(f"random_seed: {random_seed}\n")
    f.write(f"sampling_weights: {sampling_weights}\n")
    f.write(f"sampling_size: {sampling_size}\n")
    f.write(f"Duplicated samples: {n_duplicated_samples}\n")
    f.write(f"true_index: {true_index.shape}\n")
    f.write(f"total_budget: {total_budget}\n")
    # Log both cluster counts (the op_wfv count used to be written twice).
    f.write(f"optimal_wfv_clusters: {optimal_wfv_clusters}\n")
    f.write(f"optimal_op_wfv_clusters: {optimal_op_wfv_clusters}\n\n\n")

In [None]:
def _report_cluster_groups(cluster_labels, console_header, file_header, info_file):
    """Print and log, per cluster id, the cluster's composition and what was sampled from it.

    Reads the notebook globals ``true_index``, ``constructed_set``,
    ``meta_info_str``, ``confusion_matrix_str``, ``BLUE`` and ``RESET``.

    Args:
        cluster_labels: array with one cluster id per entry of ``true_index``.
        console_header: text printed before the per-cluster lines.
        file_header: text written to ``info_file`` before the per-cluster lines.
        info_file: open, writable text file that receives the log lines.
    """
    print(console_header)
    info_file.write(file_header)
    sampled_ids = np.asarray(constructed_set)
    for group in np.unique(cluster_labels):
        index = np.where(cluster_labels == group)[0]
        t_index = true_index[index]  # dataset indices of this cluster's members
        print(group, index.shape, '\t', Counter(meta_info_str[t_index]), '\t', Counter(confusion_matrix_str[t_index]))
        # Entries of the constructed set that belong to this cluster; order and
        # duplicates are kept, exactly like the element-by-element membership loop.
        sampled_in_current_subgroup = sampled_ids[np.isin(sampled_ids, t_index)]
        if len(sampled_in_current_subgroup) != 0:
            print(BLUE + '  sampled: ', sampled_in_current_subgroup.shape, Counter(meta_info_str[sampled_in_current_subgroup]), f'{RESET}')
            print(BLUE + '\t\t\t', Counter(confusion_matrix_str[sampled_in_current_subgroup]), f'{RESET}')
            info_file.write(f"{group}, {index.shape}\n")
            info_file.write(f"\t{Counter(meta_info_str[t_index])}\t{Counter(confusion_matrix_str[t_index])}\n")
            info_file.write(f"\t{Counter(meta_info_str[sampled_in_current_subgroup])}\t{Counter(confusion_matrix_str[sampled_in_current_subgroup])}\n")
        else:
            print(BLUE + '  sampled: ', len(sampled_in_current_subgroup), f'{RESET}')
            info_file.write(f"{group}, {index.shape} | sampled: {len(sampled_in_current_subgroup)}\n")


info_file_path = os.path.join(current_out_dir, f"{constructed_set_name}_info.txt")

# Same report for clustering 1, clustering 2 and the joint clustering.
with open(info_file_path, 'a') as f:
    _report_cluster_groups(
        cluster_1_labels,
        f"\n\ncluster_1_labels {cluster_1_labels.shape}",
        f"\n\ncluster_1_labels {cluster_1_labels.shape}\n",
        f,
    )
    _report_cluster_groups(
        cluster_2_labels,
        f"\n\ncluster_2_labels {cluster_2_labels.shape}",
        f"\n\ncluster_2_labels {cluster_2_labels.shape}\n",
        f,
    )
    _report_cluster_groups(
        final_cluster_labels,
        f"\n\nfinal {len(final_cluster_labels)}",
        "\n\nfinal\n",
        f,
    )

# Persist the dataset indices of the constructed set.
constructed_set_path = os.path.join(current_out_dir, f"{constructed_set_name}.npy")
np.save(constructed_set_path, constructed_set)

# Class weights relative to label 0: weight(l) = count(0) / count(l), computed
# as 1 / count(l) / (1 / count(0)).
label_counter = Counter(total_label[constructed_set])
class_weight = {}
reference_count = label_counter[0]
for l in np.unique(total_label):
    label_count = label_counter[l]
    if label_count == 0 or reference_count == 0:
        # A label missing from the constructed set used to raise ZeroDivisionError
        # here; record an undefined weight and make the problem visible instead.
        w = float('nan')
        print(f"{RED}label: {l} | weight undefined: label {l} or reference label 0 is absent from the constructed set{RESET}")
    else:
        w = 1 / label_count / (1 / reference_count)
    class_weight[l] = w
    print(f'label: {l} | count: {label_count} | weight: {w}')
print("\n", "class weight: ",  class_weight)

print(f"Index of length {len(constructed_set)} saved at {constructed_set_path}")
with open(info_file_path, 'a') as f:
    f.write(f"Index of length {len(constructed_set)} saved at {constructed_set_path}\n")
    f.write(f"{constructed_set.shape}\n\n")
    f.write(f"{Counter(meta_info_str[constructed_set])}\n\n")
    f.write(f"{label_counter}\n\n")
    f.write(f"{class_weight}\n\n")
    
constructed_set_path, constructed_set.shape, Counter(meta_info_str[constructed_set]), class_weight