In [1]:
import os
import torch
import pickle

In [2]:
def check_idx(fast_outputs, slow_outputs, targets=range(10)):
    """Assert that the fast and slow runs recorded identical sample indices.

    For every target, the set of keys containing ``"idx"`` must be the same in
    both outputs. Each such entry must be equal according to ``torch.equal``
    (same shape and same values).

    Args:
        fast_outputs: Results of the fast run, indexable by target. Each element
            is a mapping from result name to value (e.g. the dict loaded from
            the fast-run pickle).
        slow_outputs: Results of the slow run, with the same structure.
        targets: Targets to compare. Defaults to ``range(10)``, the value that
            was previously hard-coded.

    Raises:
        AssertionError: If the index keys differ for a target, or if any index
            entry differs between the two runs.
    """
    for target in targets:
        fast_output = fast_outputs[target]
        slow_output = slow_outputs[target]
        fast_idx_keys = [key for key in fast_output if "idx" in key]
        slow_idx_keys = [key for key in slow_output if "idx" in key]
        # Compare the key sets as well. Otherwise an index stored only by the
        # slow run would never be checked.
        assert set(fast_idx_keys) == set(slow_idx_keys), (
            f"Target {target}: index keys differ "
            f"(fast only: {set(fast_idx_keys) - set(slow_idx_keys)}, "
            f"slow only: {set(slow_idx_keys) - set(fast_idx_keys)})"
        )
        for idx_key in fast_idx_keys:
            assert torch.equal(fast_output[idx_key], slow_output[idx_key]), (
                f"Target {target}: '{idx_key}' differs between fast and slow runs"
            )
    print("Indices all match!")

            
def check_attribution(fast_outputs, slow_outputs, targets=range(10), rtol=1e-05, atol=1e-07):
    """Assert that fast and slow attribution tensors agree within tolerance.

    For every target, ``outputs[target]["attributions"]`` must have the same
    shape in both runs. The two tensors must also satisfy
    ``torch.allclose(fast, slow, rtol=rtol, atol=atol)``.

    Args:
        fast_outputs: Results of the fast run, indexable by target. Each element
            has an ``"attributions"`` tensor.
        slow_outputs: Results of the slow run, with the same structure.
        targets: Targets to compare. Defaults to ``range(10)``, the value that
            was previously hard-coded.
        rtol: Relative tolerance. The default matches ``torch.isclose``'s
            default, which the original check used implicitly.
        atol: Absolute tolerance. The default is the original value.

    Raises:
        AssertionError: If shapes differ or any element is outside tolerance.
    """
    for target in targets:
        fast_attribution = fast_outputs[target]["attributions"]
        slow_attribution = slow_outputs[target]["attributions"]

        # Check shape explicitly. Comparing flattened tensors would let
        # differently shaped tensors with the same flat contents pass. It would
        # also let a single-element tensor broadcast against the other one.
        assert fast_attribution.shape == slow_attribution.shape, (
            f"Target {target}: attribution shapes differ "
            f"({tuple(fast_attribution.shape)} vs {tuple(slow_attribution.shape)})"
        )
        assert torch.allclose(fast_attribution, slow_attribution, rtol=rtol, atol=atol), (
            f"Target {target}: attributions differ "
            f"(max abs diff {(fast_attribution - slow_attribution).abs().max().item():.3e})"
        )
    print("Attributions all match closely!")

In [3]:
# Root directory holding the cifar/simsiam_18 experiment results. Set the
# SIMSIAM_18_RESULTS_DIR environment variable to point this notebook at
# results stored elsewhere; if it is unset, the original project path is used.
overall_path = os.environ.get(
    "SIMSIAM_18_RESULTS_DIR",
    "/projects/leelab/cl-explainability/results/cifar/simsiam_18",
)

In [4]:
# Compare fast vs. slow normalized contrastive self-weighted integrated gradients.
# The run directory and results filename are defined once and shared by both
# paths, so the two loads always refer to the same experiment configuration.
label_free_run_dir = "123"  # per-run subdirectory (presumably the random seed; confirm)
label_free_filename = (
    "same_class_outputs_explicand_size=25_foil_size=1500"
    "_removal=blurring_blur_strength=5.0.pkl"
)
fast_contrastive_label_free_path = os.path.join(
    overall_path,
    "fast_normalized_contrastive_self_weighted_int_grad",
    label_free_run_dir,
    label_free_filename,
)
slow_contrastive_label_free_path = os.path.join(
    overall_path,
    "slow_normalized_contrastive_self_weighted_int_grad",
    label_free_run_dir,
    label_free_filename,
)

# NOTE: pickle.load can execute arbitrary code; only load result files produced by our own runs.
with open(fast_contrastive_label_free_path, "rb") as handle:
    fast_outputs = pickle.load(handle)
with open(slow_contrastive_label_free_path, "rb") as handle:
    slow_outputs = pickle.load(handle)

check_idx(fast_outputs, slow_outputs)
check_attribution(fast_outputs, slow_outputs)

Indices all match!
Attributions all match closely!


In [5]:
# Compare fast vs. slow normalized corpus integrated gradients.
# The run directory and results filename are defined once and shared by both
# paths, so the two loads always refer to the same experiment configuration.
corpus_run_dir = "123"  # per-run subdirectory (presumably the random seed; confirm)
corpus_filename = (
    "same_class_outputs_explicand_size=25_corpus_size=100"
    "_removal=blurring_blur_strength=5.0.pkl"
)
fast_corpus_path = os.path.join(
    overall_path,
    "fast_normalized_corpus_int_grad",
    corpus_run_dir,
    corpus_filename,
)
slow_corpus_path = os.path.join(
    overall_path,
    "slow_normalized_corpus_int_grad",
    corpus_run_dir,
    corpus_filename,
)

# NOTE: pickle.load can execute arbitrary code; only load result files produced by our own runs.
with open(fast_corpus_path, "rb") as handle:
    fast_outputs = pickle.load(handle)
with open(slow_corpus_path, "rb") as handle:
    slow_outputs = pickle.load(handle)

check_idx(fast_outputs, slow_outputs)
check_attribution(fast_outputs, slow_outputs)

Indices all match!
Attributions all match closely!


In [6]:
# Compare fast vs. slow normalized contrastive corpus integrated gradients.
# The run directory and results filename are defined once and shared by both
# paths, so the two loads always refer to the same experiment configuration.
contrastive_corpus_run_dir = "123"  # per-run subdirectory (presumably the random seed; confirm)
contrastive_corpus_filename = (
    "same_class_outputs_explicand_size=25_corpus_size=100_foil_size=1500"
    "_removal=blurring_blur_strength=5.0.pkl"
)
fast_contrastive_corpus_path = os.path.join(
    overall_path,
    "fast_normalized_contrastive_corpus_int_grad",
    contrastive_corpus_run_dir,
    contrastive_corpus_filename,
)
slow_contrastive_corpus_path = os.path.join(
    overall_path,
    "slow_normalized_contrastive_corpus_int_grad",
    contrastive_corpus_run_dir,
    contrastive_corpus_filename,
)

# NOTE: pickle.load can execute arbitrary code; only load result files produced by our own runs.
with open(fast_contrastive_corpus_path, "rb") as handle:
    fast_outputs = pickle.load(handle)
with open(slow_contrastive_corpus_path, "rb") as handle:
    slow_outputs = pickle.load(handle)

check_idx(fast_outputs, slow_outputs)
check_attribution(fast_outputs, slow_outputs)

Indices all match!
Attributions all match closely!
