Drift Compression Experiment - See paper Section 4.1 for more details


Implementation outline for Experiment
1. read in data
2. embed data with pretrained model
3. run NMF over embeddings and get the nmf representation of the input
4. train the drift localizer(DL) on the nmf reps
5. estimate global importance of drift localizer
6. estimate local importance of each input
7. run different h /tilde models over the local importances and compare to DL


In [1]:
from datasets import load_dataset
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from concept_helpers.DeepView_Craft import CraftTorchDV as Craft
from concept_helpers.DeepView_Craft import CraftTorchSupervised as CraftS
from concept_helpers.combined_crafts import CombinedCrafts

import urllib.request
import glob
import torch
import torch.nn as nn
from torchvision import transforms
import timm

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from scipy.sparse.linalg import eigs
from sklearn.ensemble import  RandomForestClassifier

import numpy as np
import matplotlib.pyplot as plt

from sklearn.decomposition import NMF
from sklearn.metrics import accuracy_score
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

import random

from xplique.concepts.craft import BaseCraft, DisplayImportancesOrder, Factorization, Sensitivity
from sklearn.decomposition import non_negative_factorization
from experiment_helpers.helper_function import *
from experiment_helpers.driftLocalizer import Localizer


device = 'cuda'

# loading any timm model
model = timm.create_model('nf_resnet50.ra2_in1k', pretrained=True)
model = model.to(device)

# processing
config = resolve_data_config({}, model=model)
transform = create_transform(**config)
to_pil = transforms.ToPILImage()

# cut the model in twop arts (as explained in the paper)
# first part is g(.) our 'input_to_latent' model, second part is h(.) our 'latent_to_logit' model
g = nn.Sequential(*(list(model.children())[:4]))  # input to penultimate layer
h = nn.Sequential(*(list(model.children())[4:]))  # penultimate layer to logits


with urllib.request.urlopen('https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt') as f:
        imagenet_class_names = np.array(f.read().decode('utf-8').split('\n'))

def gen_images(filelist,folder_names,folder_name2class_id):
        for f in filelist:
            # print(f)
            folder_name = f.split('/')[-2]
            if folder_name in folder_names:
                # print(folder_name)
                class_id = folder_name2class_id[folder_name]
                im = Image.open(f)
                if len(im.getbands()) == 3:
                    yield np.array(im.resize((224, 224))), class_id

ood_folder = 'data/ninco_data/NINCO/NINCO_OOD_classes'

ood_folder_names = ['french_fries','donuts','waffles','glass_of_milk','cup_cakes','chicken_quesadilla']#, 'donuts'] #'grey_fox', 'Arctic_fox']
ood_class_names = ['french_fries','donuts','waffles','glass_of_milk','cup_cakes','chicken_quesadilla']#, 'donuts'] #'grey fox', 'Arctic fox']


ood_class_ids = [1001 + i for i,class_name in enumerate(ood_class_names)]
ood_folder_name2class_id = dict(zip(ood_folder_names, ood_class_ids))
ood_filelist = glob.glob(f'{ood_folder}/*/*.jpg')
# print(ood_filelist)

ood_images, ood_labels = zip(*gen_images(ood_filelist,ood_folder_names,ood_folder_name2class_id))
ood_images, ood_labels = np.array(ood_images), np.array(ood_labels)
ood_preprocessed_images = torch.stack([transform(to_pil(img)) for img in ood_images], 0)



2024-11-27 13:00:55.636624: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-11-27 13:00:55.644981: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1732708855.655562   56208 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1732708855.658732   56208 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-27 13:00:55.669662: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [None]:


label_maps = []
drift_ratios = []
drift_localizer = []
drift_comparison = []
drift_forest = []


one_local_one_global = []
one_local = []
two_local = []
three_local = []
one_global = []
two_global = []
three_global = []

one_local_one_global_preds = []
one_local_preds = []
two_local_preds = []
three_local_preds = []
one_global_preds = []
two_global_preds = []
three_global_preds = []

one_local_one_global_l = []
one_local_l = []
two_local_l = []
three_local_l = []
one_global_l = []
two_global_l = []
three_global_l = []

one_local_one_global_preds_l = []
one_local_preds_l = []
two_local_preds_l = []
three_local_preds_l = []
one_global_preds_l = []
two_global_preds_l = []
three_global_preds_l = []

# all_local_l = []
# all_global_l = []
# all_local_preds_l = []
# all_global_preds_l = []

one_local_l_probs = []
two_local_l_probs = []
three_local_l_probs = []
one_local_preds_l_probs = []
two_local_preds_l_probs = []
three_local_preds_l_probs = []

reconstructed_single_concepts = []
reconstructed_single_concepts_preds = []
reconstructed_2_concepts = []
reconstructed_2_concepts_preds = []
reconstructed_3_concepts = []
reconstructed_3_concepts_preds = []
reconstructed_all_concepts = []
reconstructed_all_concepts_preds = []


for j in range(50):

    sample_ids = np.random.choice(len(ood_preprocessed_images),500, False)
    
    sample_images = ood_preprocessed_images[sample_ids]
    
    
    # Map each digit to a label indicating whether it occurs before or after the change point, or both, or neither
    #  0 - never, 1 - before, 2 - after, 3 - both
    
    # 0-before, 1-after
    import random

    # Initialize the label_map keys
    keys = ood_class_ids
    
    # Shuffle the keys for more randomness
    random.shuffle(keys)
    
    # Assign at least one of each label (0, 1, 2)
    initial_labels = [0, 1, 2]
    random.shuffle(initial_labels)
    
    # Ensure that the first three keys have 0, 1, and 2 respectively
    label_map = {keys[i]: initial_labels[i] for i in range(3)}
    
    # Randomly assign labels for the remaining keys
    for i in range(3, len(keys)):
        label_map[keys[i]] = random.randint(0, 2)
    
    label_maps.append(label_map)

    
    labels_mapped = np.array([label_map[class_id] for class_id in ood_labels])
    
    drift_labels = labels_mapped[sample_ids]
    
    # Randomly assign labels of 1 or 2 to samples with label 3
    #  (i.e., digits that occur both before and after the change point)
    label_2_idx = np.where(drift_labels == 2)[0]
    y_mixed = drift_labels.copy()
    y_mixed[label_2_idx] = np.random.choice([0, 1], size=len(label_2_idx))
    
    sample_labels = y_mixed
            
    drift_ratios.append({"BD": len(np.where(drift_labels == 0)[0]),
                         "AD": len(np.where(drift_labels == 1)[0]),
                         "Both": len(np.where(drift_labels == 2)[0])})
    
    full_size = 256
    patch_size= 100
    
    
    #Supervised CRAFT Training
    h_craftdv = CraftS(input_to_latent_model=g,
                        latent_to_logit_model=h,
                        number_of_concepts=5,
                        inputs=sample_images,
                        labels=sample_labels,
                        batch_size=64,
                        patch_size=full_size,
                        device=device)
    
    patches, patch_act, train_labels = h_craftdv._extract_patches(sample_images, sample_labels )

    bd_indices = np.where(sample_labels != 1)[0]
    ad_indices = np.where(sample_labels != 0)[0]

    bd_fit = Craft(input_to_latent_model=g,
                    latent_to_logit_model=h,
                    number_of_concepts=10,
                    # labels=h_y,
                    patch_size=patch_size,
                    batch_size=64,
                    device=device)
    print("Fitting Unsupervised Craft....")
    bd_crops, bd_crops_u, bd_w = bd_fit.fit(sample_images[bd_indices])
    
    
    ad_fit = Craft(input_to_latent_model=g,
                        latent_to_logit_model=h,
                        number_of_concepts=10,
                        # labels=h_y,
                        patch_size=patch_size,
                        batch_size=64,
                        device=device)
    print("Fitting Unsupervised Craft....")
    ad_crops, ad_crops_u, ad_w = ad_fit.fit(sample_images[ad_indices])

    drift_basis = np.vstack([bd_w, ad_w])

    drift_craft = CombinedCrafts(input_to_latent_model=g,
                    latent_to_logit_model=h,
                    number_of_concepts=len(drift_basis),
                    inputs=sample_images,
                    labels=sample_labels,
                    basis = drift_basis,
                    batch_size=64,
                    patch_size=patch_size,
                    device=device)
    print("Fitting Craft....")
    drift_craft.transform_all()

    
    X_clean = patch_act
    y_clean = train_labels

    # Initialize a random forest model with max_leaf_nodes=150
    localizer_model = Localizer()
    

    # Perform the train-test split on X_clean and sample_labels
    X_train_clean, X_test_clean, y_train, y_test = \
        train_test_split(X_clean, y_clean, train_size=0.7, random_state=42)

    # Fit the model to the mixed set (group 3 is randomly assigned to 1 or 2)
    print('Fitting Random Forest classifier...')
    localizer_model.fit(X_train_clean, y_train);
    print('Fitting complete.')

    localizer_bin_preds = localizer_model.l_predict(X_test_clean)
    drift_localizer.append(accuracy_score(localizer_bin_preds, y_test))

    drift_imp = np.round(estimate_importance_l(localizer_model, drift_craft, drift_basis, X_train_clean),3)


    
    # y_preds_l, _ = compute_predictions(localizer_model,X_test_clean)
    image_drift_imp_l = [estimate_importance_helper_l(drift_craft,localizer_model,drift_basis,
                                                  image,class_of_interest=localizer_bin_preds[i]) 
                               for i,image in enumerate(X_test_clean)]

    
    
    one_local_one_global_l.append(local_one_imp_concept_globally_l(drift_craft,image_drift_imp_l,y_test))
    one_local_l.append(local_imp_concepts_globally_l(drift_craft,image_drift_imp_l,num=1,labels=y_test))
    two_local_l.append(local_imp_concepts_globally_l(drift_craft,image_drift_imp_l,num=2,labels=y_test))
    three_local_l.append(local_imp_concepts_globally_l(drift_craft,image_drift_imp_l,num=3,labels=y_test))
    # all_local_l.append(local_imp_concepts_globally_l(drift_craft,image_drift_imp_l,num=20,labels=y_test))

    one_global_l.append(global_imp_concepts_locally_l(drift_craft,image_drift_imp_l,num=1,labels=y_test))
    two_global_l.append(global_imp_concepts_locally_l(drift_craft,image_drift_imp_l,num=2,labels=y_test))
    three_global_l.append(global_imp_concepts_locally_l(drift_craft,image_drift_imp_l,num=3,labels=y_test))
    # all_global_l.append(global_imp_concepts_locally_l(drift_craft,image_drift_imp_l,num=20,labels=y_test))

    # all_local_l.append(local_imp_concepts_globally_l(drift_craft,image_drift_imp_l,num=20,labels=y_test))

    
    one_local_one_global_preds_l.append(local_one_imp_concept_globally_l(drift_craft,image_drift_imp_l,localizer_bin_preds))
    one_local_preds_l.append(local_imp_concepts_globally_l(drift_craft,image_drift_imp_l,num=1,labels=localizer_bin_preds))
    two_local_preds_l.append(local_imp_concepts_globally_l(drift_craft,image_drift_imp_l,num=2,labels=localizer_bin_preds))
    three_local_preds_l.append(local_imp_concepts_globally_l(drift_craft,image_drift_imp_l,num=3,labels=localizer_bin_preds))


    # all_local_preds_l.append(local_imp_concepts_globally_l(drift_craft,image_drift_imp_l,num=20,labels=localizer_bin_preds))

    one_global_preds_l.append(global_imp_concepts_locally_l(drift_craft,image_drift_imp_l,num=1,labels=localizer_bin_preds))
    two_global_preds_l.append(global_imp_concepts_locally_l(drift_craft,image_drift_imp_l,num=2,labels=localizer_bin_preds))
    three_global_preds_l.append(global_imp_concepts_locally_l(drift_craft,image_drift_imp_l,num=3,labels=localizer_bin_preds))
    # all_global_preds_l.append(global_imp_concepts_locally_l(drift_craft,image_drift_imp_l,num=20,labels=localizer_bin_preds))


    localizer_bin_train_preds = localizer_model.l_predict(X_train_clean)
    image_drift_imp_l_train = [estimate_importance_helper_l(drift_craft,localizer_model,drift_basis,
                                                  image,class_of_interest=localizer_bin_train_preds[i]) 
                               for i,image in enumerate(X_train_clean)]
    concept_dist = concept_counter(image_drift_imp_l_train,localizer_bin_train_preds)

    one_local_l_probs.append(local_imp_concepts_probability(concept_dist,image_drift_imp_l,num=1,labels=y_test))
    two_local_l_probs.append(local_imp_concepts_probability(concept_dist,image_drift_imp_l,num=2,labels=y_test))
    three_local_l_probs.append(local_imp_concepts_probability(concept_dist,image_drift_imp_l,num=3,labels=y_test))

    one_local_preds_l_probs.append(local_imp_concepts_probability(concept_dist,image_drift_imp_l,num=1,labels=localizer_bin_preds))
    two_local_preds_l_probs.append(local_imp_concepts_probability(concept_dist,image_drift_imp_l,num=2,labels=localizer_bin_preds))
    three_local_preds_l_probs.append(local_imp_concepts_probability(concept_dist,image_drift_imp_l,num=3,labels=localizer_bin_preds))

    reconstructed_single_concept = reconstruct_inputs(X_test_clean, image_drift_imp_l, drift_basis, num_concepts=1)
    localizer_preds = localizer_model.l_predict(reconstructed_single_concept)
    reconstructed_single_concepts.append(accuracy_score(localizer_preds, y_test))
    reconstructed_single_concepts_preds.append(accuracy_score(localizer_preds, localizer_bin_preds))


    reconstructed_2_concept = reconstruct_inputs(X_test_clean, image_drift_imp_l, drift_basis, num_concepts=2)
    localizer_preds = localizer_model.l_predict(reconstructed_2_concept)
    reconstructed_2_concepts.append(accuracy_score(localizer_preds, y_test))
    reconstructed_2_concepts_preds.append(accuracy_score(localizer_preds, localizer_bin_preds))
    
    reconstructed_3_concept = reconstruct_inputs(X_test_clean, image_drift_imp_l, drift_basis, num_concepts=3)
    localizer_preds = localizer_model.l_predict(reconstructed_3_concept)
    reconstructed_3_concepts.append(accuracy_score(localizer_preds, y_test))
    reconstructed_3_concepts_preds.append(accuracy_score(localizer_preds, localizer_bin_preds))


    reconstructed_all_concept = reconstruct_inputs(X_test_clean, image_drift_imp_l, drift_basis, num_concepts=21)
    localizer_preds = localizer_model.l_predict(reconstructed_all_concept)
    reconstructed_all_concepts.append(accuracy_score(localizer_preds, y_test))
    reconstructed_all_concepts_preds.append(accuracy_score(localizer_preds, localizer_bin_preds))

    print("Run:",j)
    
    

    

In [5]:
import csv

# Example accuracy lists (replace with your actual data)
# method_1 = [0.85, 0.87, 0.86, ...]  # List of 50 accuracies for method 1
# method_2 = [0.82, 0.84, 0.83, ...]  # List of 50 accuracies for method 2
# Repeat for all 7 methods
# methods = [drift_localizer,one_local_one_global, one_local,two_local,
#            three_local,one_global,two_global,three_global,one_local_one_global_p,
#             one_local_p, two_local_p ,three_local_p ,one_global_p ,two_global_p,
#             three_global_p]  # List of lists

# methods = [drift_localizer,one_local_one_global_p,
#             one_local_p, two_local_p ,three_local_p ,one_global_p ,two_global_p,
#             three_global_p]

# one_local_l_probs = []
# two_local_l_probs = []
# three_local_l_probs = []
# one_local_preds_l_probs = []
# two_local_preds_l_probs = []
# three_local_preds_l_probs = []

methods = [ drift_localizer,
            one_local_one_global_l,
            one_local_l,
            two_local_l,
            three_local_l,
            one_local_l_probs,
            two_local_l_probs,
            three_local_l_probs,
        
            one_global_l,
            two_global_l,
            three_global_l,
       
            one_local_one_global_preds_l,
            one_local_preds_l,
            two_local_preds_l,
            three_local_preds_l,
           one_local_preds_l_probs,
            two_local_preds_l_probs,
            three_local_preds_l_probs,
           
            one_global_preds_l,
            two_global_preds_l,
            three_global_preds_l,

           reconstructed_single_concepts,
reconstructed_single_concepts_preds,
reconstructed_2_concepts,
reconstructed_2_concepts_preds,
reconstructed_3_concepts,
reconstructed_3_concepts_preds,
reconstructed_all_concepts,
reconstructed_all_concepts_preds,
           
          
            # drift_comparison,
            # drift_forest,
            # one_local_one_global,
            # one_local,
            # two_local,
            # three_local,
            # one_global,
            # two_global,
            # three_global,
            # one_local_one_global_preds,
            # one_local_preds,
            # two_local_preds,
            # three_local_preds,
            # one_global_preds,
            # two_global_preds,
            # three_global_preds,
            label_maps,
            drift_ratios]

# method_names = ['Drift Localizer', '1_Local_global','1_Local', '2_Local', '3_Local', '1_Global', '2_Global', '3_Global',
#                '1_Local_global_p','1_Local_p', '2_Local_p', '3_Local_p', '1_Global_p', '2_Global_p', '3_Global_p']
# method_names = ['Drift Localizer', 
#                '1_Local_global','1_Local', '2_Local', '3_Local', '1_Global', '2_Global', '3_Global',
#                '1_Local_global_preds','1_Local_preds', '2_Local_preds', '3_Local_preds', '1_Global_preds', '2_Global_preds', '3_Global_preds']

method_names = [ "drift_localizer",            
            "one_local_one_global_l",
            "one_local_l",
            "two_local_l",
            "three_local_l",
                "one_local_l_probs",
            "two_local_l_probs",
            "three_local_l_probs",
            
            "one_global_l",
            "two_global_l",
            "three_global_l",
        
            "one_local_one_global_preds_l",
            "one_local_preds_l",
            "two_local_preds_l",
            "three_local_preds_l",
            "one_local_preds_l_probs",
            "two_local_preds_l_probs",
            "three_local_preds_l_probs",
    
            "one_global_preds_l",
            "two_global_preds_l",
            "three_global_preds_l",
                "reconstructed_single_concepts",
            "reconstructed_single_concepts_preds",
            "reconstructed_2_concepts",
            "reconstructed_2_concepts_preds",
            "reconstructed_3_concepts",
            "reconstructed_3_concepts_preds",
            "reconstructed_all_concepts",
            "reconstructed_all_concepts_preds",
            # "drift_comparison",
            #     "drift_forest",
            # "one_local_one_global",
            # "one_local",
            # "two_local",
            # "three_local",
            # "one_global",
            # "two_global",
            # "three_global",
            # "one_local_one_global_preds",
            # "one_local_preds",
            # "two_local_preds",
            # "three_local_preds",
            # "one_global_preds",
            # "two_global_preds",
            # "three_global_preds",
               "label_maps",
               "drift_ratios"]

# Write to CSV
with open('paper_experiment_reproduce.csv', 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['Method'] + [f'Run_{i+1}' for i in range(50)])  # Header row
    for method, accuracies in zip(method_names, methods):
        writer.writerow([method] + accuracies)


In [3]:
import pandas as pd
import numpy as np

# Load CSV
# df = pd.read_csv('new_experiments_run5_only_bdad.csv')

df = pd.read_csv('paper_experiment_D1.csv')
df = df.iloc[:29]

# Calculate mean and standard deviation
stats = {}
for method in df['Method']:
    accuracies = df[df['Method'] == method].drop('Method', axis=1).values.flatten().astype(float)
    mean = np.mean(accuracies)
    # median = np.median(accuracies)
    
    std = np.std(accuracies)
    # stats[method] = (mean, std, median)
    stats[method] = (mean, std)
# Example output for stats
print(stats)

{'drift_localizer': (0.8148000000000001, 0.06316525064375894), 'one_local_one_global_l': (0.7786666666666667, 0.058575687030636786), 'one_local_l': (0.7885333333333335, 0.05674175417333048), 'two_local_l': (0.7698666666666667, 0.06158268154819719), 'three_local_l': (0.7452, 0.06062328118983848), 'one_local_l_probs': (0.7878666666666667, 0.05545272260623859), 'two_local_l_probs': (0.772, 0.05845416057808792), 'three_local_l_probs': (0.7484000000000001, 0.05791656834524028), 'one_global_l': (0.6836000000000001, 0.07112067991176062), 'two_global_l': (0.7294666666666666, 0.06584952543488828), 'three_global_l': (0.7446666666666666, 0.06060803027102378), 'one_local_one_global_preds_l': (0.7902666666666666, 0.10004630039247939), 'one_local_preds_l': (0.8206666666666667, 0.08527341646466122), 'two_local_preds_l': (0.7980000000000002, 0.0904433524367601), 'three_local_preds_l': (0.7786666666666666, 0.08652295777550732), 'one_local_preds_l_probs': (0.8290666666666668, 0.08494453092067396), 'two_

In [4]:
latex_table = """
\\begin{table}[h!]
\\centering
\\begin{tabular}{l|c}
\\hline
Method & Accuracy (Mean ± Std Dev) \\\\
\\hline
"""

for method, (mean, std) in stats.items():
    latex_table += f"{method} & {mean:.3f} ± {std:.3f} \\\\ \n"

latex_table += """
\\hline
\\end{tabular}
\\caption{Accuracy of different methods}
\\end{table}
"""


#include all and see what happens

# Output the LaTeX table
print(latex_table)

##Model h tilde for paper is encompassed by "one_local_l_probs"

## We have other models here which use more concepts for possible future work


\begin{table}[h!]
\centering
\begin{tabular}{l|c}
\hline
Method & Accuracy (Mean ± Std Dev) \\
\hline
drift_localizer & 0.815 ± 0.063 \\ 
one_local_one_global_l & 0.779 ± 0.059 \\ 
one_local_l & 0.789 ± 0.057 \\ 
two_local_l & 0.770 ± 0.062 \\ 
three_local_l & 0.745 ± 0.061 \\ 
one_local_l_probs & 0.788 ± 0.055 \\ 
two_local_l_probs & 0.772 ± 0.058 \\ 
three_local_l_probs & 0.748 ± 0.058 \\ 
one_global_l & 0.684 ± 0.071 \\ 
two_global_l & 0.729 ± 0.066 \\ 
three_global_l & 0.745 ± 0.061 \\ 
one_local_one_global_preds_l & 0.790 ± 0.100 \\ 
one_local_preds_l & 0.821 ± 0.085 \\ 
two_local_preds_l & 0.798 ± 0.090 \\ 
three_local_preds_l & 0.779 ± 0.087 \\ 
one_local_preds_l_probs & 0.829 ± 0.085 \\ 
two_local_preds_l_probs & 0.813 ± 0.080 \\ 
three_local_preds_l_probs & 0.791 ± 0.074 \\ 
one_global_preds_l & 0.703 ± 0.095 \\ 
two_global_preds_l & 0.761 ± 0.090 \\ 
three_global_preds_l & 0.771 ± 0.086 \\ 
reconstructed_single_concepts & 0.764 ± 0.062 \\ 
reconstructed_single_concepts_preds