Drift Compression Experiment - See paper Section 4.1 for more details


Implementation outline for the experiment:
1. Read in the data.
2. Embed the data with a pretrained model.
3. Run NMF over the embeddings to get the NMF representation of each input.
4. Train the drift localizer (DL) on the NMF representations.
5. Estimate the global importance of the concepts for the drift localizer.
6. Estimate the local importance of the concepts for each input.
7. Run different $\tilde{h}$ models over the local importances and compare them to the DL.


In [1]:
from datasets import load_dataset
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from concept_helpers.DeepView_Craft import CraftTorchDV as Craft
from concept_helpers.DeepView_Craft import CraftTorchSupervised as CraftS
from concept_helpers.combined_crafts import CombinedCrafts

import urllib.request
import glob
import torch
import torch.nn as nn
from torchvision import transforms
import timm

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from scipy.sparse.linalg import eigs
from sklearn.ensemble import  RandomForestClassifier

import numpy as np
import matplotlib.pyplot as plt

from sklearn.decomposition import NMF
from sklearn.metrics import accuracy_score
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

import random

from xplique.concepts.craft import BaseCraft, DisplayImportancesOrder, Factorization, Sensitivity
from sklearn.decomposition import non_negative_factorization
from experiment_helpers.helper_function import *
from experiment_helpers.driftLocalizer import Localizer


# Run on the GPU when one is available; fall back to the CPU instead of crashing.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load a pretrained timm model. eval() switches it to inference mode so that layers whose
# behavior differs during training (e.g. dropout) act deterministically while we extract features.
model = timm.create_model('nf_resnet50.ra2_in1k', pretrained=True)
model = model.to(device).eval()

# Preprocessing pipeline matching the model's pretrained data config.
config = resolve_data_config({}, model=model)
transform = create_transform(**config)
to_pil = transforms.ToPILImage()

# Cut the model in two parts (as explained in the paper):
# g(.) is our 'input_to_latent' model, h(.) is our 'latent_to_logit' model.
# NOTE(review): split index 4 is specific to this backbone's children layout — re-check if the model changes.
g = nn.Sequential(*(list(model.children())[:4]))  # input to penultimate layer
h = nn.Sequential(*(list(model.children())[4:]))  # penultimate layer to logits

# ImageNet-1k class names, one per line; the row index is treated as the ImageNet class id.
with urllib.request.urlopen('https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt') as f:
    imagenet_class_names = np.array(f.read().decode('utf-8').split('\n'))

def gen_images(filelist, folder_names, folder_name2class_id):
    """Yield ``(image_array, class_id)`` for each 3-band image in a selected class folder.

    Args:
        filelist: Image file paths laid out as ``<root>/<folder_name>/<file>``.
        folder_names: Folder names to keep; files in any other folder are skipped.
        folder_name2class_id: Maps every kept folder name to its class id.

    Yields:
        The image resized to 224x224 as a NumPy array, together with its class id.
        Images that do not have exactly 3 bands (grayscale, RGBA, CMYK, ...) are skipped.
    """
    import os

    for path in filelist:
        # The parent directory names the class. os.path copes with the platform separator
        # (glob returns '\\'-separated components on Windows, which split('/') would miss).
        folder_name = os.path.basename(os.path.dirname(path))
        if folder_name not in folder_names:
            continue
        class_id = folder_name2class_id[folder_name]
        # Image.open is lazy and keeps the file handle open; the context manager closes it
        # so iterating over thousands of files does not exhaust file descriptors.
        with Image.open(path) as im:
            if len(im.getbands()) != 3:
                continue
            pixels = np.array(im.resize((224, 224)))
        yield pixels, class_id

import os

# Root of the in-distribution ImageNet subset: one sub-folder per class, each named
# exactly like the corresponding ImageNet class name.
idd_folder = 'path/to/subset/of/imagenet'

# Only directories are class folders. Sorting makes the class order independent of the
# filesystem's listing order, so seeded runs are reproducible across machines.
idd_folder_names = sorted(
    name for name in os.listdir(idd_folder)
    if os.path.isdir(os.path.join(idd_folder, name))
)
idd_class_names = idd_folder_names

unknown_class_names = sorted(set(idd_class_names) - set(imagenet_class_names))
assert not unknown_class_names, f"Folders are not ImageNet class names: {unknown_class_names}"

idd_class_ids = [np.where(imagenet_class_names == class_name)[0][0] for class_name in idd_class_names]
folder_name2class_id = dict(zip(idd_folder_names, idd_class_ids))
# Sorted for the same reproducibility reason: sample indices drawn later refer to this order.
filelist = sorted(glob.glob(f'{idd_folder}/*/*.jpg'))

samples = list(gen_images(filelist, idd_folder_names, folder_name2class_id))
assert samples, f"No 3-band .jpg images found under {idd_folder!r}"
images, labels = zip(*samples)
images, labels = np.array(images), np.array(labels)
preprocessed_images = torch.stack([transform(to_pil(img)) for img in images], 0)
print(preprocessed_images.shape)

2024-11-26 13:16:36.113385: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-11-26 13:16:36.145080: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-11-26 13:16:36.145110: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-11-26 13:16:36.146038: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-26 13:16:36.151519: I tensorflow/core/platform/cpu_feature_guar

torch.Size([2777, 3, 256, 256])


In [3]:
import random

# Number of independent random drift scenarios and images sampled per scenario.
N_RUNS = 50
N_SAMPLES = 500
# Seed the RNGs used below (Python's `random`, NumPy's global RNG -- which scikit-learn
# falls back to when random_state is None -- and torch) so Restart & Run All reproduces
# the same scenarios. NOTE(review): helpers that keep private RNG state are not covered.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Per-run results: each list receives one entry per run.
label_maps = []
drift_ratios = []
drift_localizer = []
# The following accumulators are not filled by this loop; kept so other cells that may
# reference them still find the names.
drift_comparison = []
drift_forest = []

one_local_one_global = []
one_local = []
two_local = []
three_local = []
one_global = []
two_global = []
three_global = []

one_local_one_global_preds = []
one_local_preds = []
two_local_preds = []
three_local_preds = []
one_global_preds = []
two_global_preds = []
three_global_preds = []

# Compressed-model scores against the true drift labels (y_test).
one_local_one_global_l = []
one_local_l = []
two_local_l = []
three_local_l = []
one_global_l = []
two_global_l = []
three_global_l = []

# Compressed-model scores against the drift localizer's own predictions.
one_local_one_global_preds_l = []
one_local_preds_l = []
two_local_preds_l = []
three_local_preds_l = []
one_global_preds_l = []
two_global_preds_l = []
three_global_preds_l = []

one_local_l_probs = []
two_local_l_probs = []
three_local_l_probs = []
one_local_preds_l_probs = []
two_local_preds_l_probs = []
three_local_preds_l_probs = []

# Localizer accuracy (vs. y_test) and agreement (vs. its own predictions) on inputs
# reconstructed from a limited number of concepts.
reconstructed_single_concepts = []
reconstructed_single_concepts_preds = []
reconstructed_2_concepts = []
reconstructed_2_concepts_preds = []
reconstructed_3_concepts = []
reconstructed_3_concepts_preds = []
reconstructed_all_concepts = []
reconstructed_all_concepts_preds = []


for j in range(N_RUNS):

    sample_ids = np.random.choice(len(preprocessed_images), N_SAMPLES, False)
    sample_images = preprocessed_images[sample_ids]

    # Give every class a drift role:
    #   0 - occurs only before the drift (BD)
    #   1 - occurs only after the drift (AD)
    #   2 - occurs both before and after the drift (Both)
    # Copy the ids: shuffling `idd_class_ids` itself would mutate global state every run.
    keys = list(idd_class_ids)
    random.shuffle(keys)

    # The first three (shuffled) classes receive roles 0, 1 and 2 in random order, so
    # every role is used at least once; the remaining classes get a uniformly random role.
    initial_labels = [0, 1, 2]
    random.shuffle(initial_labels)
    label_map = {keys[i]: initial_labels[i] for i in range(3)}
    for i in range(3, len(keys)):
        label_map[keys[i]] = random.randint(0, 2)

    label_maps.append(label_map)

    labels_mapped = np.array([label_map[class_id] for class_id in labels])
    drift_labels = labels_mapped[sample_ids]

    # Samples of "Both" classes (role 2) are randomly placed before (0) or after (1) the
    # drift, giving a binary before/after label per sample.
    label_2_idx = np.where(drift_labels == 2)[0]
    y_mixed = drift_labels.copy()
    y_mixed[label_2_idx] = np.random.choice([0, 1], size=len(label_2_idx))
    sample_labels = y_mixed

    drift_ratios.append({"BD": len(np.where(drift_labels == 0)[0]),
                         "AD": len(np.where(drift_labels == 1)[0]),
                         "Both": len(np.where(drift_labels == 2)[0])})

    # Patches spanning the whole (square) preprocessed image, plus smaller concept patches.
    full_size = sample_images.shape[-1]
    patch_size = 100

    # Supervised CRAFT: with patch_size == full image size, each "patch" is a full image.
    h_craftdv = CraftS(input_to_latent_model=g,
                       latent_to_logit_model=h,
                       number_of_concepts=5,
                       inputs=sample_images,
                       labels=sample_labels,
                       batch_size=64,
                       patch_size=full_size,
                       device=device)

    patches, patch_act, train_labels = h_craftdv._extract_patches(sample_images, sample_labels)

    # sample_labels is binary at this point, so these select the before/after-drift samples.
    bd_indices = np.where(sample_labels != 1)[0]
    ad_indices = np.where(sample_labels != 0)[0]

    # Learn a separate concept basis on the before-drift and the after-drift samples.
    bd_fit = Craft(input_to_latent_model=g,
                   latent_to_logit_model=h,
                   number_of_concepts=10,
                   patch_size=patch_size,
                   batch_size=64,
                   device=device)
    print("Fitting Unsupervised Craft....")
    bd_crops, bd_crops_u, bd_w = bd_fit.fit(sample_images[bd_indices])

    ad_fit = Craft(input_to_latent_model=g,
                   latent_to_logit_model=h,
                   number_of_concepts=10,
                   patch_size=patch_size,
                   batch_size=64,
                   device=device)
    print("Fitting Unsupervised Craft....")
    ad_crops, ad_crops_u, ad_w = ad_fit.fit(sample_images[ad_indices])

    # Joint basis: before-drift concepts stacked on top of after-drift concepts.
    drift_basis = np.vstack([bd_w, ad_w])

    drift_craft = CombinedCrafts(input_to_latent_model=g,
                                 latent_to_logit_model=h,
                                 number_of_concepts=len(drift_basis),
                                 inputs=sample_images,
                                 labels=sample_labels,
                                 basis=drift_basis,
                                 batch_size=64,
                                 patch_size=patch_size,
                                 device=device)
    print("Fitting Craft....")
    drift_craft.transform_all()

    X_clean = patch_act
    y_clean = train_labels

    # Drift localizer (DL); it selects its own parameters by cross validation when fitted.
    localizer_model = Localizer()

    # Fixed split seed so train/test membership depends only on the sampled scenario.
    X_train_clean, X_test_clean, y_train, y_test = \
        train_test_split(X_clean, y_clean, train_size=0.7, random_state=42)

    print('Fitting Random Forest classifier...')
    localizer_model.fit(X_train_clean, y_train)
    print('Fitting complete.')

    localizer_bin_preds = localizer_model.l_predict(X_test_clean)
    drift_localizer.append(accuracy_score(y_test, localizer_bin_preds))

    # Global concept importances for the DL (not stored per run; available for inspection).
    drift_imp = np.round(estimate_importance_l(localizer_model, drift_craft, drift_basis, X_train_clean), 3)

    # Local concept importances of each test input w.r.t. the class the DL predicted for it.
    image_drift_imp_l = [estimate_importance_helper_l(drift_craft, localizer_model, drift_basis,
                                                      image, class_of_interest=localizer_bin_preds[i])
                         for i, image in enumerate(X_test_clean)]

    # Compressed models scored against the true drift labels.
    one_local_one_global_l.append(local_one_imp_concept_globally_l(drift_craft, image_drift_imp_l, y_test))
    one_local_l.append(local_imp_concepts_globally_l(drift_craft, image_drift_imp_l, num=1, labels=y_test))
    two_local_l.append(local_imp_concepts_globally_l(drift_craft, image_drift_imp_l, num=2, labels=y_test))
    three_local_l.append(local_imp_concepts_globally_l(drift_craft, image_drift_imp_l, num=3, labels=y_test))

    one_global_l.append(global_imp_concepts_locally_l(drift_craft, image_drift_imp_l, num=1, labels=y_test))
    two_global_l.append(global_imp_concepts_locally_l(drift_craft, image_drift_imp_l, num=2, labels=y_test))
    three_global_l.append(global_imp_concepts_locally_l(drift_craft, image_drift_imp_l, num=3, labels=y_test))

    # The same compressed models scored against the DL's own predictions (fidelity).
    one_local_one_global_preds_l.append(local_one_imp_concept_globally_l(drift_craft, image_drift_imp_l, localizer_bin_preds))
    one_local_preds_l.append(local_imp_concepts_globally_l(drift_craft, image_drift_imp_l, num=1, labels=localizer_bin_preds))
    two_local_preds_l.append(local_imp_concepts_globally_l(drift_craft, image_drift_imp_l, num=2, labels=localizer_bin_preds))
    three_local_preds_l.append(local_imp_concepts_globally_l(drift_craft, image_drift_imp_l, num=3, labels=localizer_bin_preds))

    one_global_preds_l.append(global_imp_concepts_locally_l(drift_craft, image_drift_imp_l, num=1, labels=localizer_bin_preds))
    two_global_preds_l.append(global_imp_concepts_locally_l(drift_craft, image_drift_imp_l, num=2, labels=localizer_bin_preds))
    three_global_preds_l.append(global_imp_concepts_locally_l(drift_craft, image_drift_imp_l, num=3, labels=localizer_bin_preds))

    # Probabilistic variant: concept statistics are gathered on the training split only.
    localizer_bin_train_preds = localizer_model.l_predict(X_train_clean)
    image_drift_imp_l_train = [estimate_importance_helper_l(drift_craft, localizer_model, drift_basis,
                                                            image, class_of_interest=localizer_bin_train_preds[i])
                               for i, image in enumerate(X_train_clean)]
    concept_dist = concept_counter(image_drift_imp_l_train, localizer_bin_train_preds)

    one_local_l_probs.append(local_imp_concepts_probability(concept_dist, image_drift_imp_l, num=1, labels=y_test))
    two_local_l_probs.append(local_imp_concepts_probability(concept_dist, image_drift_imp_l, num=2, labels=y_test))
    three_local_l_probs.append(local_imp_concepts_probability(concept_dist, image_drift_imp_l, num=3, labels=y_test))

    one_local_preds_l_probs.append(local_imp_concepts_probability(concept_dist, image_drift_imp_l, num=1, labels=localizer_bin_preds))
    two_local_preds_l_probs.append(local_imp_concepts_probability(concept_dist, image_drift_imp_l, num=2, labels=localizer_bin_preds))
    three_local_preds_l_probs.append(local_imp_concepts_probability(concept_dist, image_drift_imp_l, num=3, labels=localizer_bin_preds))

    # Re-run the DL on test activations rebuilt from a limited number of concepts
    # (presumably the most locally important ones — see reconstruct_inputs). The last
    # level uses every concept in the basis instead of a hard-coded count.
    reconstruction_levels = [
        (1, reconstructed_single_concepts, reconstructed_single_concepts_preds),
        (2, reconstructed_2_concepts, reconstructed_2_concepts_preds),
        (3, reconstructed_3_concepts, reconstructed_3_concepts_preds),
        (len(drift_basis), reconstructed_all_concepts, reconstructed_all_concepts_preds),
    ]
    for num_concepts, accuracy_list, agreement_list in reconstruction_levels:
        reconstructed = reconstruct_inputs(X_test_clean, image_drift_imp_l, drift_basis, num_concepts=num_concepts)
        localizer_preds = localizer_model.l_predict(reconstructed)
        accuracy_list.append(accuracy_score(y_test, localizer_preds))
        agreement_list.append(accuracy_score(localizer_bin_preds, localizer_preds))

    print("Run:", j)
    
    

    

Fitting Unsupervised Craft....




Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.35 Mean:0.5085714285714286 High threshold:0.7, No. Leaves:20
Fitting complete.
Run: 0
Fitting Unsupervised Craft....




Fitting Unsupervised Craft....




Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.15 Mean:0.2857142857142857 High threshold:0.45, No. Leaves:20
Fitting complete.
Run: 1
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.1 Mean:0.2742857142857143 High threshold:0.45, No. Leaves:20
Fitting complete.
Run: 2
Fitting Unsupervised Craft....




Fitting Unsupervised Craft....




Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.25 Mean:0.4085714285714286 High threshold:0.6, No. Leaves:20
Fitting complete.
Run: 3
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....




Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.2 Mean:0.3514285714285714 High threshold:0.55, No. Leaves:20
Fitting complete.
Run: 4
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....




Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.3 Mean:0.4714285714285714 High threshold:0.65, No. Leaves:20
Fitting complete.
Run: 5
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.25 Mean:0.43714285714285717 High threshold:0.6, No. Leaves:20
Fitting complete.
Run: 6
Fitting Unsupervised Craft....




Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.4 Mean:0.6028571428571429 High threshold:0.8, No. Leaves:20
Fitting complete.
Run: 7
Fitting Unsupervised Craft....




Fitting Unsupervised Craft....




Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.35 Mean:0.5542857142857143 High threshold:0.75, No. Leaves:20
Fitting complete.
Run: 8
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....




Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.25 Mean:0.44 High threshold:0.6, No. Leaves:20
Fitting complete.
Run: 9
Fitting Unsupervised Craft....




Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.3 Mean:0.5028571428571429 High threshold:0.7, No. Leaves:20
Fitting complete.
Run: 10
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....




Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.2 Mean:0.3514285714285714 High threshold:0.55, No. Leaves:20
Fitting complete.
Run: 11
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.2 Mean:0.38571428571428573 High threshold:0.55, No. Leaves:20
Fitting complete.
Run: 12
Fitting Unsupervised Craft....




Fitting Unsupervised Craft....




Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.35 Mean:0.5257142857142857 High threshold:0.7, No. Leaves:20
Fitting complete.
Run: 13
Fitting Unsupervised Craft....




Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.3 Mean:0.45714285714285713 High threshold:0.65, No. Leaves:20
Fitting complete.
Run: 14
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....




Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.25 Mean:0.44857142857142857 High threshold:0.65, No. Leaves:20
Fitting complete.
Run: 15
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....




Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.15 Mean:0.2885714285714286 High threshold:0.45, No. Leaves:20
Fitting complete.
Run: 16
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.2 Mean:0.37142857142857144 High threshold:0.55, No. Leaves:20
Fitting complete.
Run: 17
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.2 Mean:0.38 High threshold:0.55, No. Leaves:20
Fitting complete.
Run: 18
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....




Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.3 Mean:0.4857142857142857 High threshold:0.65, No. Leaves:20
Fitting complete.
Run: 19
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.2 Mean:0.3628571428571429 High threshold:0.55, No. Leaves:20
Fitting complete.
Run: 20
Fitting Unsupervised Craft....




Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.35 Mean:0.5171428571428571 High threshold:0.7, No. Leaves:20
Fitting complete.
Run: 21
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.25 Mean:0.4114285714285714 High threshold:0.6, No. Leaves:20
Fitting complete.
Run: 22
Fitting Unsupervised Craft....




Fitting Unsupervised Craft....




Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.25 Mean:0.4542857142857143 High threshold:0.65, No. Leaves:20
Fitting complete.
Run: 23
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....




Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.55 Mean:0.7228571428571429 High threshold:0.9, No. Leaves:20
Fitting complete.
Run: 24
Fitting Unsupervised Craft....




Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.45 Mean:0.6228571428571429 High threshold:0.8, No. Leaves:20
Fitting complete.
Run: 25
Fitting Unsupervised Craft....




Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.2 Mean:0.4 High threshold:0.6, No. Leaves:20
Fitting complete.
Run: 26
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.2 Mean:0.3457142857142857 High threshold:0.5, No. Leaves:20
Fitting complete.
Run: 27
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.5 Mean:0.6942857142857143 High threshold:0.85, No. Leaves:20
Fitting complete.
Run: 28
Fitting Unsupervised Craft....




Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.2 Mean:0.3914285714285714 High threshold:0.55, No. Leaves:20
Fitting complete.
Run: 29
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....




Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.1 Mean:0.23714285714285716 High threshold:0.4, No. Leaves:20
Fitting complete.
Run: 30
Fitting Unsupervised Craft....




Fitting Unsupervised Craft....




Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.25 Mean:0.44 High threshold:0.6, No. Leaves:20
Fitting complete.
Run: 31
Fitting Unsupervised Craft....




Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.25 Mean:0.4114285714285714 High threshold:0.6, No. Leaves:20
Fitting complete.
Run: 32
Fitting Unsupervised Craft....




Fitting Unsupervised Craft....




Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.3 Mean:0.47714285714285715 High threshold:0.65, No. Leaves:20
Fitting complete.
Run: 33
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.55 Mean:0.7257142857142858 High threshold:0.9, No. Leaves:20
Fitting complete.
Run: 34
Fitting Unsupervised Craft....




Fitting Unsupervised Craft....




Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.2 Mean:0.37142857142857144 High threshold:0.55, No. Leaves:20
Fitting complete.
Run: 35
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.2 Mean:0.3914285714285714 High threshold:0.55, No. Leaves:20
Fitting complete.
Run: 36
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.4 Mean:0.6 High threshold:0.8, No. Leaves:20
Fitting complete.
Run: 37
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.15 Mean:0.33714285714285713 High threshold:0.5, No. Leaves:20
Fitting comple



Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.3 Mean:0.4685714285714286 High threshold:0.65, No. Leaves:20
Fitting complete.
Run: 40
Fitting Unsupervised Craft....




Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.05 Mean:0.19714285714285715 High threshold:0.35, No. Leaves:20
Fitting complete.
Run: 41
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....




Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.45 Mean:0.6342857142857142 High threshold:0.8, No. Leaves:20
Fitting complete.
Run: 42
Fitting Unsupervised Craft....




Fitting Unsupervised Craft....




Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.4 Mean:0.5628571428571428 High threshold:0.75, No. Leaves:20
Fitting complete.
Run: 43
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....




Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.4 Mean:0.58 High threshold:0.75, No. Leaves:20
Fitting complete.
Run: 44
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.15 Mean:0.31142857142857144 High threshold:0.5, No. Leaves:20
Fitting complete.
Run: 45
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....




Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.2 Mean:0.4 High threshold:0.6, No. Leaves:20
Fitting complete.
Run: 46
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.45 Mean:0.62 High threshold:0.8, No. Leaves:20
Fitting complete.
Run: 47
Fitting Unsupervised Craft....




Fitting Unsupervised Craft....
Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.35 Mean:0.5342857142857143 High threshold:0.7, No. Leaves:20
Fitting complete.
Run: 48
Fitting Unsupervised Craft....
Fitting Unsupervised Craft....




Fitting Craft....
Fitting Random Forest classifier...
Determine optimal parameters using cross validation
low threshold: 0.25 Mean:0.4342857142857143 High threshold:0.6, No. Leaves:20
Fitting complete.
Run: 49


In [4]:
import csv

# Ordered mapping from result name to its per-run values. A single mapping keeps each
# name next to its data, so the two can no longer drift out of alignment (zip over two
# parallel lists would silently truncate or mislabel rows).
results = {
    "drift_localizer": drift_localizer,
    "one_local_one_global_l": one_local_one_global_l,
    "one_local_l": one_local_l,
    "two_local_l": two_local_l,
    "three_local_l": three_local_l,
    "one_local_l_probs": one_local_l_probs,
    "two_local_l_probs": two_local_l_probs,
    "three_local_l_probs": three_local_l_probs,

    "one_global_l": one_global_l,
    "two_global_l": two_global_l,
    "three_global_l": three_global_l,

    "one_local_one_global_preds_l": one_local_one_global_preds_l,
    "one_local_preds_l": one_local_preds_l,
    "two_local_preds_l": two_local_preds_l,
    "three_local_preds_l": three_local_preds_l,
    "one_local_preds_l_probs": one_local_preds_l_probs,
    "two_local_preds_l_probs": two_local_preds_l_probs,
    "three_local_preds_l_probs": three_local_preds_l_probs,

    "one_global_preds_l": one_global_preds_l,
    "two_global_preds_l": two_global_preds_l,
    "three_global_preds_l": three_global_preds_l,

    "reconstructed_single_concepts": reconstructed_single_concepts,
    "reconstructed_single_concepts_preds": reconstructed_single_concepts_preds,
    "reconstructed_2_concepts": reconstructed_2_concepts,
    "reconstructed_2_concepts_preds": reconstructed_2_concepts_preds,
    "reconstructed_3_concepts": reconstructed_3_concepts,
    "reconstructed_3_concepts_preds": reconstructed_3_concepts_preds,
    "reconstructed_all_concepts": reconstructed_all_concepts,
    "reconstructed_all_concepts_preds": reconstructed_all_concepts_preds,

    # Per-run metadata rather than accuracies.
    "label_maps": label_maps,
    "drift_ratios": drift_ratios,
}
# Kept for any later cell that uses the parallel-list form.
method_names = list(results)
methods = list(results.values())

# Header width follows the data instead of a hard-coded run count.
n_runs = max((len(values) for values in methods), default=0)

# Write to CSV: one row per method, one column per run.
with open('paper_experiment_reproduced.csv', 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['Method'] + [f'Run_{i+1}' for i in range(n_runs)])  # Header row
    for method, accuracies in results.items():
        writer.writerow([method] + list(accuracies))


In [1]:
import pandas as pd
import numpy as np

# Results file to summarize: one row per method, a 'Method' column, then one column per run.
# NOTE(review): the experiment cell writes 'paper_experiment_reproduced.csv' but this reads
# 'paper_experiment_D2.csv' — confirm which results file the table should be built from.
RESULTS_CSV = 'paper_experiment_D2.csv'
# Rows holding per-run metadata (dicts), not accuracies. Filtering them by name rather than
# by position keeps the summary correct when methods are added or removed.
NON_ACCURACY_ROWS = ['label_maps', 'drift_ratios']

df = pd.read_csv(RESULTS_CSV)
df = df[~df['Method'].isin(NON_ACCURACY_ROWS)]

# Mean and population (ddof=0) standard deviation of each method's accuracy across runs.
stats = {}
for method in df['Method']:
    accuracies = df[df['Method'] == method].drop('Method', axis=1).values.flatten().astype(float)
    stats[method] = (np.mean(accuracies), np.std(accuracies))

print(stats)

{'drift_localizer': (0.7901333333333332, 0.07494282264944833), 'one_local_one_global_l': (0.7602666666666668, 0.06827343716426046), 'one_local_l': (0.7736000000000002, 0.06962707008692015), 'two_local_l': (0.7649333333333334, 0.06793212298561949), 'three_local_l': (0.7399999999999999, 0.07415299499458311), 'one_local_l_probs': (0.7716000000000001, 0.06690371190499574), 'two_local_l_probs': (0.7633333333333334, 0.06457037504408143), 'three_local_l_probs': (0.7450666666666667, 0.06589634958697552), 'one_global_l': (0.6362666666666668, 0.1081082584983939), 'two_global_l': (0.7154666666666668, 0.09329942241097865), 'three_global_l': (0.7382666666666668, 0.09535719980974461), 'one_local_one_global_preds_l': (0.8269333333333333, 0.0784058954699482), 'one_local_preds_l': (0.8445333333333335, 0.07685414757838382), 'two_local_preds_l': (0.8369333333333333, 0.07423937260271293), 'three_local_preds_l': (0.8064, 0.08811063752149592), 'one_local_preds_l_probs': (0.8483999999999999, 0.07493772229139

In [2]:
latex_table = """
\\begin{table}[h!]
\\centering
\\begin{tabular}{l|c}
\\hline
Method & Accuracy (Mean ± Std Dev) \\\\
\\hline
"""

for method, (mean, std) in stats.items():
    latex_table += f"{method} & {mean:.3f} ± {std:.3f} \\\\ \n"

latex_table += """
\\hline
\\end{tabular}
\\caption{Accuracy of different methods}
\\end{table}
"""

# Output the LaTeX table
print(latex_table)

##Model h tilde for paper is encompassed by "one_local_l_probs"

## We have other models here which use more concepts for possible future work


\begin{table}[h!]
\centering
\begin{tabular}{l|c}
\hline
Method & Accuracy (Mean ± Std Dev) \\
\hline
drift_localizer & 0.790 ± 0.075 \\ 
one_local_one_global_l & 0.760 ± 0.068 \\ 
one_local_l & 0.774 ± 0.070 \\ 
two_local_l & 0.765 ± 0.068 \\ 
three_local_l & 0.740 ± 0.074 \\ 
one_local_l_probs & 0.772 ± 0.067 \\ 
two_local_l_probs & 0.763 ± 0.065 \\ 
three_local_l_probs & 0.745 ± 0.066 \\ 
one_global_l & 0.636 ± 0.108 \\ 
two_global_l & 0.715 ± 0.093 \\ 
three_global_l & 0.738 ± 0.095 \\ 
one_local_one_global_preds_l & 0.827 ± 0.078 \\ 
one_local_preds_l & 0.845 ± 0.077 \\ 
two_local_preds_l & 0.837 ± 0.074 \\ 
three_local_preds_l & 0.806 ± 0.088 \\ 
one_local_preds_l_probs & 0.848 ± 0.075 \\ 
two_local_preds_l_probs & 0.836 ± 0.068 \\ 
three_local_preds_l_probs & 0.817 ± 0.068 \\ 
one_global_preds_l & 0.687 ± 0.122 \\ 
two_global_preds_l & 0.779 ± 0.098 \\ 
three_global_preds_l & 0.809 ± 0.097 \\ 
reconstructed_single_concepts & 0.751 ± 0.071 \\ 
reconstructed_single_concepts_preds