In [None]:
import torch
import torch
import numpy as np
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import os
import pickle
from tqdm import tqdm
import sys
sys.path.append("../")
from model_utils import get_model_parts
from argparse import Namespace 
from concept_utils import get_concept_scores_mv_valid, ConceptBank, EasyDict
from skimage import io
import glob


# Log the library versions so that results can be matched to an environment.
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load the autoreload extension and enable mode 2 so that edits to imported
# modules (e.g. model_utils / concept_utils) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# Root folder of the MetaDataset image subsets; get_drift_dataset() globs for
# "*.jpg" files underneath it.
data_root = "/path/dataset/metadataset/MetaDataset/subsets"

# The path and the unpickled object get separate names: `im2node_path` is the
# file location, `im2node` is whatever object the pickle contains
# (presumably an image -> node mapping; verify against the metadata generator).
im2node_path = "/path/metadataset-concepts/MetaDataset-Distribution-Shift/generate_dataset/meta_data/img_to_node.pkl"
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(im2node_path, "rb") as f:
    im2node = pickle.load(f)

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
args = Namespace(
    model_name="resnet",
    input_size=224,
    batch_size=8,
    SEED=4,
    num_workers=4,
    # Use the first GPU when one is available, otherwise fall back to the CPU.
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    bank_path='/path/conceptual-explanations/banks/concept_resnet_170.pkl',
)

# Per-channel mean / std passed to transforms.Normalize below
# (these are the standard ImageNet normalisation values).
mean_pxs = np.array([0.485, 0.456, 0.406])
std_pxs = np.array([0.229, 0.224, 0.225])

# Preprocessing applied to every image: array -> PIL image -> resize ->
# centre crop to a square of side `input_size` -> tensor -> normalise.
data_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(args.input_size),
    transforms.CenterCrop(args.input_size),
    transforms.ToTensor(),
    transforms.Normalize(mean_pxs, std_pxs)
])

# Folder that contains one sub-folder per experiment (plus "control") for the chosen model.
experiment_root = f'/path/outputs/conceptualexplanations/metadataset/{args.model_name}'

In [None]:
# Folder of the control experiment and the list of all experiment folders
# found under `experiment_root`.
control_folder = os.path.join(experiment_root, "control")
experiments = os.listdir(experiment_root)

# Concepts that are accepted as the "true" concept of an experiment.
# NOTE: "paper" is listed twice; this is harmless because the list is only
# used for membership tests in this notebook.
exp_concepts = ["sand", "paper", "water", "snow", "bed", "keyboard", "cabinet",
               "carpet", "horse", "door", "paper", "fence", "tree",
                "computer", "grass", "branch", "car", "building", "plate",
                "bush", "book"]

# Load the pickled concept bank. The context manager closes the file handle
# deterministically (the previous `pickle.load(open(...))` left it open).
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(args.bank_path, 'rb') as f:
    all_concepts = pickle.load(f)
concept_bank = ConceptBank(all_concepts, args.device)

In [None]:
class DriftDataset(torch.utils.data.Dataset):
    """Map-style dataset that lazily reads images from a list of file paths.

    The dataset is unlabeled: indexing returns only the (optionally
    transformed) image, never an ``(image, label)`` pair.

    Args:
        images: Sequence of image file paths.
        transform: Optional callable applied to the array returned by
            ``io.imread`` before it is handed back to the caller.
    """

    def __init__(self, images, transform=None):
        self.images = images
        self.transform = transform

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load and return the image at position ``idx``.

        ``idx`` may be a plain index or a tensor index; tensors are converted
        with ``tolist()`` before being used to index ``self.images``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = self.images[idx]
        # The image is read from disk on every access; nothing is cached.
        image = io.imread(img_path)
        if self.transform:
            image = self.transform(image)
        return image
    
def get_drift_dataset(experiment_folder, data_root, drift_dist, exclude_concepts, class_name, transforms, take_diff=True,
                     seed=1, n_max_images=50):
    """Build a :class:`DriftDataset` of evaluation images for ``drift_dist``.

    Candidate images are the ``*.jpg`` files of the drift distribution, minus
    images whose file name also appears in any of the ``exclude_concepts``
    folders and (when ``take_diff`` is true) minus images whose file name
    appears in the experiment's training folder. The remaining paths are put
    into a deterministic order, shuffled with ``seed``, truncated to
    ``n_max_images`` and finally filtered to images with more than two array
    dimensions (i.e. images that carry a channel axis).

    Args:
        experiment_folder: Experiment directory containing ``train/<class_name>``.
        data_root: Root directory of the image subsets.
        drift_dist: Either ``"<class>"`` (all ``.jpg`` files below
            ``data_root/<class>`` are searched recursively) or
            ``"<class>(<concept>)"`` (only ``data_root/<class>/<class>(<concept>)``).
        exclude_concepts: Names of the form ``"<class>(<concept>)"`` whose
            images must not appear in the result. An entry equal to
            ``drift_dist`` is ignored.
        class_name: Class sub-folder inside ``experiment_folder/train``.
        transforms: Transform callable forwarded to ``DriftDataset``.
            (The name shadows the ``torchvision.transforms`` module inside
            this function; it is kept for interface compatibility.)
        take_diff: If true, drop images that were also used for training.
        seed: Seed for the shuffle. NOTE: this seeds NumPy's *global* RNG,
            as the original code did, so later ``np.random`` calls are affected.
        n_max_images: Maximum number of images kept after shuffling. Because
            the channel filter runs after truncation, the final dataset may
            hold fewer images.

    Returns:
        A ``DriftDataset`` over the selected image paths.

    Raises:
        FileNotFoundError: If the training folder or an exclude-concept folder
            does not exist (raised by ``os.listdir``).
    """
    drift_dist_args = drift_dist.split("(")
    print(drift_dist_args)

    if len(drift_dist_args) > 1:
        # "<class>(<concept>)": look only inside that specific sub-folder.
        drift_folder = os.path.join(data_root, drift_dist_args[0], drift_dist)
        pattern, recursive = "*.jpg", False
    else:
        # Bare "<class>": collect images from every sub-folder of the class.
        drift_folder = os.path.join(data_root, drift_dist_args[0])
        pattern, recursive = "**/*.jpg", True
    # Escape the folder part so glob metacharacters in directory names are taken literally.
    drift_ims = glob.glob(os.path.join(glob.escape(drift_folder), pattern), recursive=recursive)

    # File names used for training; a set keeps the membership tests O(1).
    train_folder = os.path.join(experiment_folder, "train", class_name)
    train_names = set(os.listdir(train_folder))

    # File names belonging to any concept that must be excluded.
    exclude_names = set()
    print(f"Excluding images for : {exclude_concepts}")
    for c in exclude_concepts:
        if c == drift_dist:
            continue
        concept_folder = os.path.join(data_root, c.split("(")[0], c)
        concept_ims = os.listdir(concept_folder)
        exclude_names.update(concept_ims)
        print(len(concept_ims))

    kept = set()
    for im_path in drift_ims:
        im_name = os.path.basename(im_path)
        if im_name in exclude_names:
            continue
        if take_diff and im_name in train_names:
            continue
        kept.add(im_path)

    # Sort before shuffling: set iteration order of strings varies between
    # Python processes (hash randomisation), so without this step the same
    # seed would select different images on every kernel restart.
    drift_ims = sorted(kept)
    np.random.seed(seed)
    np.random.shuffle(drift_ims)
    drift_ims = drift_ims[:n_max_images]
    # Keep only images whose array has a channel axis (drops 2-D grayscale images).
    drift_ims = [path for path in drift_ims if len(io.imread(path).shape) > 2]
    drift_ds = DriftDataset(drift_ims, transforms)
    print(f"Final image count: {len(drift_ims)}")
    return drift_ds


In [None]:
# Evaluation settings for this cell.
# eval_modes: "mistakes" restricts the analysis to drift images that the
#             confounded model misclassifies.
# batch_sizes: number of images passed to one concept-score call.
# experiments: overrides the directory listing computed earlier with the
#              subset of experiments that is evaluated here.
eval_modes = ["mistakes"]
batch_sizes = [1]
experiments = ["bird(sand)", "dog(snow)", "dog(bed)", "dog(water)", "dog(horse)", "cat(cabinet)", "cat(bed)"]


all_out = []
# Control model
control_folder = os.path.join(experiment_root, "control")
# Passing the path lets torch.load open and close the file itself
# (the previous `torch.load(open(...))` leaked the file handle).
model_control = torch.load(os.path.join(control_folder, "result", "confounded-model.pt"))
model_bottom_control, model_top_control = get_model_parts(model_control, args.model_name)
model_bottom_control.eval()
model_top_control.eval()
all_concept_names = concept_bank.concept_names.copy()

# Expected top-3 hit rate of a uniformly random ranking over the bank's
# concepts. The original hard-coded 3/150, which is only right for a
# 150-concept bank; deriving it from the bank keeps the baseline consistent.
random_top3_expected = 3 / len(concept_bank.concept_names)

# Make sure the figure output folder exists before the first savefig call.
os.makedirs("./paper_figures/figure3", exist_ok=True)

for experiment in experiments:

    # Experiment names containing "-" are skipped.
    if len(experiment.split("-")) > 1:
        continue
    experiment_folder = os.path.join(experiment_root, experiment)
    try:
        with open(os.path.join(experiment_folder, "result", "concept_config.pkl"), "rb") as f:
            concept_config = pickle.load(f)
    except Exception as e:
        # Best effort: report the problem and move on to the next experiment.
        print(e, experiment)
        continue

    # Parse "<class>(<concept>)" entries of the training distributions.
    true_concepts = []
    class_name = None
    for cl in concept_config["in_distributions"]:
        if "(" in cl:
            class_name = cl.split("(")[0]
            true_concept = cl.split("(")[-1][:-1]
            # "branch" is evaluated under the "tree" concept.
            if true_concept == 'branch':
                true_concept = 'tree'
            true_concepts.append(true_concept)

    # Only experiments with exactly one concept-tagged distribution are evaluated.
    if len(true_concepts) != 1:
        continue
    target_concept = true_concepts[0]

    are_invalid_concepts = np.array([((c not in exp_concepts) or (c not in concept_bank.concept_names)) for c in true_concepts])
    if np.any(are_invalid_concepts):
        continue

    dists_to_remove = [c for c in concept_config["in_distributions"] if "(" in c]
    print(f"Animal: {class_name}, True Concept: {true_concepts}")


    # Get models
    model_ft = torch.load(os.path.join(experiment_folder, "result", "confounded-model.pt"))
    # Model to test
    model_bottom, model_top = get_model_parts(model_ft, args.model_name)
    model_bottom.eval()
    model_top.eval()

    # The training ImageFolder is only used to look up the label index of the class.
    train_dataset = datasets.ImageFolder(os.path.join(experiment_folder, "train"), data_transforms)
    class_idx = train_dataset.class_to_idx[class_name]
    labels_orig = torch.tensor(class_idx).long().view(1).to(args.device)
    for batch_size in batch_sizes:
        # Images of the class that are neither in the confounded distribution(s)
        # nor in the training set.
        drift_ds = get_drift_dataset(experiment_folder, data_root, class_name, dists_to_remove,
                                     class_name, data_transforms, take_diff=True, seed=3)

        dataloaders_drift = torch.utils.data.DataLoader(drift_ds, batch_size=50, shuffle=False, num_workers=args.num_workers)
        for eval_mode in eval_modes:
            batch_inputs = None
            batch_labels = None
            for inputs in dataloaders_drift:
                inputs = inputs.to(args.device)
                # Every drift image belongs to `class_name`, so all labels are `class_idx`.
                labels = labels_orig.repeat(inputs.shape[0])
                label_mask = (labels == class_idx)
                if label_mask.float().sum() < 1:
                    continue
                if eval_mode == "mistakes":
                    # Only the predicted class is needed here, so no autograd graph is built.
                    with torch.no_grad():
                        preds = model_top(model_bottom(inputs)).argmax(dim=1)
                    label_mask = (label_mask & (preds.squeeze() != labels.squeeze()))
                    if label_mask.float().sum() < 1:
                        continue

                if batch_inputs is None:
                    batch_inputs, batch_labels = inputs[label_mask], labels[label_mask]
                else:
                    batch_inputs = torch.cat([batch_inputs, inputs[label_mask]], dim=0)
                    batch_labels = torch.cat([batch_labels, labels[label_mask]], dim=0)

            # Nothing was selected (empty drift set, or no mistakes in "mistakes"
            # mode): there is nothing to explain, so skip instead of crashing on
            # `None.shape` below.
            if batch_inputs is None:
                print(f"No samples selected for {experiment} (eval_mode={eval_mode}); skipping.")
                continue

            # For every concept, the list of positions it received in each
            # returned `concept_scores_list` (index 0 = first entry).
            score_dict = {name: [] for name in concept_bank.concept_names}
            control_score_dict = {name: [] for name in concept_bank.concept_names}
            random_score_dict = {name: [] for name in concept_bank.concept_names}
            # Largest multiple of batch_size that fits; if fewer than batch_size
            # samples exist, all of them are used as a single (smaller) batch.
            sample_size = max(batch_inputs.shape[0] - (batch_inputs.shape[0] % batch_size), (batch_inputs.shape[0] % batch_size))
            tqdm_iterator = tqdm(range(max(1, sample_size // batch_size)))

            for k in tqdm_iterator:
                batch_X = batch_inputs[k*batch_size : (k+1)*batch_size]
                batch_Y = batch_labels[k*batch_size : (k+1)*batch_size]

                # Concept scores for the confounded model ...
                opt_result = get_concept_scores_mv_valid(batch_X, batch_Y,
                                                         concept_bank,
                                                         model_bottom, model_top,
                                                         alpha=1e-2, beta=1e-1, lr=1e-2,
                                                         enforce_validity=True, momentum=0.9)

                # ... and, with identical settings, for the control model.
                control_opt_result = get_concept_scores_mv_valid(batch_X, batch_Y,
                                                         concept_bank,
                                                         model_bottom_control, model_top_control,
                                                         alpha=1e-2, beta=1e-1, lr=1e-2,
                                                         enforce_validity=True, momentum=0.9)
                # Random baseline: a freshly shuffled ordering of all concept names.
                np.random.shuffle(all_concept_names)
                random_result = EasyDict({
                    "concept_scores_list": all_concept_names

                })
                for i, name in enumerate(opt_result.concept_scores_list):
                    score_dict[name].append(i)
                for i, name in enumerate(control_opt_result.concept_scores_list):
                    control_score_dict[name].append(i)
                for i, name in enumerate(random_result.concept_scores_list):
                    random_score_dict[name].append(i)
                # Running statistics of the target concept, shown in the progress bar.
                score_arr = np.array(score_dict[target_concept])
                desc = f"Mean: {np.mean(score_arr):.2f}, Top3:{(score_arr < 3).mean()}"
                tqdm_iterator.set_description(desc)

            print(score_dict[target_concept])
            print(f"Top3: {(np.array(score_dict[target_concept])<3).mean()}")
            mean_score_dict = {}
            for k in score_dict.keys():
                mean_score_dict[k] = np.mean(score_dict[k])

            # Concepts ordered by their mean position across all evaluated batches.
            ordered = sorted(concept_bank.concept_names, key=mean_score_dict.get)
            # Fraction of batches in which the target concept landed in the first three positions.
            control_score = (np.array(control_score_dict[target_concept])<3).mean()
            random_score = (np.array(random_score_dict[target_concept])<3).mean()
            ces_top3_score =  (np.array(score_dict[target_concept])<3).mean()

            all_out.append({
                "experiment": experiment,
                "target": true_concepts,
                "batch_size": batch_size,
                "eval_mode": eval_mode,
                "sample_size": sample_size,
                "top5_concepts": ordered[:5],
                # Use the validated target concept rather than a variable leaked from the parsing loop.
                "overall_rank": ordered.index(target_concept)+1,
                "Top3": ces_top3_score,
                "Top3-Control": control_score,
                "Top3-Random": random_score
            })
            plt.figure(figsize=(7, 5))
            x_names = ['Random', 'CCE (Control)', 'CCE']
            y_vals = [random_top3_expected, control_score, ces_top3_score]
            plt.bar(x_names, y_vals, color=['silver','gray', 'black'])
            plt.yticks(fontname='DejaVu Sans', fontsize=20)
            plt.xticks(fontname='DejaVu Sans', fontsize=20)
            plt.ylim([0, 1])
            plt.ylabel('Fraction of samples\n concept is in top 3', fontname='DejaVu Sans', fontsize=20)
            plt.title('"{}" correlated with "{}"'.format(class_name.title(), target_concept.title()), fontsize=20, fontname='Arial')
            plt.tight_layout()
            plt.savefig(f"./paper_figures/figure3/fig3_{experiment}_{args.model_name}.png")
            plt.savefig(f"./paper_figures/figure3/fig3_{experiment}_{args.model_name}.pdf")
            plt.show()
            # Release the figure so figures do not accumulate across experiments.
            plt.close()
            