In [1]:
import os

In [2]:
os.chdir("../../scripts")

In [3]:
import constants
import torch
from experiment_utils import (
    load_data,
    load_encoder,
)
from torch.utils.data import DataLoader, Subset

In [4]:
from cl_explain.explanations.contrastive_weighted_score import ContrastiveWeightedScore
from cl_explain.explanations.corpus_similarity import CorpusSimilarity
from cl_explain.explanations.contrastive_corpus_similarity import (
    ContrastiveCorpusSimilarity,
)

In [5]:
val_dataset, _, _ = load_data("imagenet", "val", 32)
val_labels = [sample[0].split("/")[-2] for sample in val_dataset.samples]
unique_labels = constants.IMAGENETTE_SYNSETS
target = unique_labels[0]

In [6]:
val_target_idx = (
    torch.Tensor([label == target for label in val_labels])
    .nonzero()
    .flatten()
)
explicand_idx = val_target_idx[:25]
corpus_idx = val_target_idx[25:]

In [7]:
foil_idx = (
    torch.Tensor(
        [label != target and label in unique_labels for label in val_labels]
    )
    .nonzero()
    .flatten()
)
foil_idx = foil_idx[:100]

In [8]:
explicand_dataloader = DataLoader(
    Subset(val_dataset, indices=explicand_idx),
    batch_size=10,
    shuffle=False,
)
corpus_dataloader = DataLoader(
    Subset(val_dataset, indices=corpus_idx),
    batch_size=10,
    shuffle=False,
)
foil_dataloader = DataLoader(
    Subset(val_dataset, indices=foil_idx),
    batch_size=32,
    shuffle=False,
)

In [9]:
device = 7
encoder = load_encoder("simclr_x1")
encoder.eval()
encoder.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn

In [10]:
contrastive_weighted_score = ContrastiveWeightedScore(
    encoder=encoder,
    foil_dataloader=foil_dataloader,
    normalize=True,
    batch_size=32,
)

fast_outs = []
slow_outs = []
for explicand, _ in explicand_dataloader:
    explicand = explicand.to(device)
    contrastive_weighted_score.generate_weight(explicand.detach().clone())
    fast_out = contrastive_weighted_score.forward(explicand, implementation="mean").detach().cpu()
    slow_out = contrastive_weighted_score.forward(explicand, implementation="pairwise").detach().cpu()
    fast_outs.append(fast_out)
    slow_outs.append(slow_out)
fast_outs = torch.cat(fast_outs)
slow_outs = torch.cat(slow_outs)
print(torch.all(torch.isclose(fast_outs, slow_outs)))
print(fast_outs)

tensor(True)
tensor([0.6939, 0.7007, 0.7347, 0.7576, 0.7563, 0.6887, 0.7761, 0.7717, 0.7657,
        0.7613, 0.7140, 0.7742, 0.7496, 0.7881, 0.7606, 0.7860, 0.7703, 0.7824,
        0.8042, 0.7562, 0.7138, 0.7650, 0.7633, 0.8058, 0.7943])


In [11]:
contrastive_weighted_score = ContrastiveWeightedScore(
    encoder=encoder,
    foil_dataloader=foil_dataloader,
    normalize=False,
    batch_size=32,
)

fast_outs = []
slow_outs = []
for explicand, _ in explicand_dataloader:
    explicand = explicand.to(device)
    contrastive_weighted_score.generate_weight(explicand.detach().clone())
    fast_out = contrastive_weighted_score.forward(explicand, implementation="mean").detach().cpu()
    slow_out = contrastive_weighted_score.forward(explicand, implementation="pairwise").detach().cpu()
    fast_outs.append(fast_out)
    slow_outs.append(slow_out)
fast_outs = torch.cat(fast_outs)
slow_outs = torch.cat(slow_outs)
print(torch.all(torch.isclose(fast_outs, slow_outs)))
print(fast_outs)

tensor(True)
tensor([ 53.5182, 153.3758, 183.1795, 151.6195, 166.7433, 131.0765, 224.8011,
        377.9021, 168.5114, 372.6056, 175.8909, 275.7900, 208.8300, 320.5536,
        190.5505, 266.5364, 204.7652, 201.4692, 434.9282, 151.1830, 108.3063,
        174.5839, 174.9215, 243.8135, 498.4396])


In [12]:
corpus_similarity = CorpusSimilarity(
    encoder=encoder,
    corpus_dataloader=corpus_dataloader,
    normalize=True,
    batch_size=32,
)

fast_outs = []
slow_outs = []
for explicand, _ in explicand_dataloader:
    explicand = explicand.to(device)
    fast_out = corpus_similarity.forward(explicand, implementation="mean").detach().cpu()
    slow_out = corpus_similarity.forward(explicand, implementation="pairwise").detach().cpu()
    fast_outs.append(fast_out)
    slow_outs.append(slow_out)
fast_outs = torch.cat(fast_outs)
slow_outs = torch.cat(slow_outs)
print(torch.all(torch.isclose(fast_outs, slow_outs)))
print(fast_outs)

tensor(True)
tensor([0.4137, 0.5685, 0.4762, 0.5629, 0.5704, 0.5290, 0.6032, 0.6264, 0.6177,
        0.5928, 0.3106, 0.6303, 0.6210, 0.5773, 0.5660, 0.5552, 0.5797, 0.5833,
        0.6014, 0.5433, 0.5157, 0.4920, 0.5831, 0.5509, 0.6220])


In [13]:
corpus_similarity = CorpusSimilarity(
    encoder=encoder,
    corpus_dataloader=corpus_dataloader,
    normalize=False,
    batch_size=32,
)

fast_outs = []
slow_outs = []
for explicand, _ in explicand_dataloader:
    explicand = explicand.to(device)
    fast_out = corpus_similarity.forward(explicand, implementation="mean").detach().cpu()
    slow_out = corpus_similarity.forward(explicand, implementation="pairwise").detach().cpu()
    fast_outs.append(fast_out)
    slow_outs.append(slow_out)
fast_outs = torch.cat(fast_outs)
slow_outs = torch.cat(slow_outs)
print(torch.all(torch.isclose(fast_outs, slow_outs)))
print(fast_outs)

tensor(True)
tensor([ 62.6374, 132.6735, 117.5995, 127.2166, 134.4800, 115.7148, 162.2486,
        213.9382, 146.4273, 202.1586,  74.8569, 186.7480, 163.8298, 182.3440,
        141.1676, 160.6954, 149.4769, 149.0140, 218.4368, 123.0029, 102.6436,
        117.1184, 140.2859, 152.8642, 241.2948])


In [14]:
contrastive_corpus_similarity = ContrastiveCorpusSimilarity(
    encoder=encoder,
    corpus_dataloader=corpus_dataloader,
    foil_dataloader=foil_dataloader,
    normalize=True,
    batch_size=32,
)

fast_outs = []
slow_outs = []
for explicand, _ in explicand_dataloader:
    explicand = explicand.to(device)
    fast_out = contrastive_corpus_similarity.forward(explicand, implementation="mean").detach().cpu()
    slow_out = contrastive_corpus_similarity.forward(explicand, implementation="pairwise").detach().cpu()
    fast_outs.append(fast_out)
    slow_outs.append(slow_out)
fast_outs = torch.cat(fast_outs)
slow_outs = torch.cat(slow_outs)
print(torch.all(torch.isclose(fast_outs, slow_outs)))
print(fast_outs)

tensor(True)
tensor([0.1076, 0.2692, 0.2109, 0.3205, 0.3267, 0.2177, 0.3793, 0.3980, 0.3833,
        0.3541, 0.0246, 0.4046, 0.3706, 0.3655, 0.3267, 0.3412, 0.3500, 0.3657,
        0.4056, 0.2995, 0.2295, 0.2570, 0.3464, 0.3568, 0.4163])


In [15]:
contrastive_corpus_similarity = ContrastiveCorpusSimilarity(
    encoder=encoder,
    corpus_dataloader=corpus_dataloader,
    foil_dataloader=foil_dataloader,
    normalize=False,
    batch_size=32,
)

fast_outs = []
slow_outs = []
for explicand, _ in explicand_dataloader:
    explicand = explicand.to(device)
    fast_out = contrastive_corpus_similarity.forward(explicand, implementation="mean").detach().cpu()
    slow_out = contrastive_corpus_similarity.forward(explicand, implementation="pairwise").detach().cpu()
    fast_outs.append(fast_out)
    slow_outs.append(slow_out)
fast_outs = torch.cat(fast_outs)
slow_outs = torch.cat(slow_outs)
print(torch.all(torch.isclose(fast_outs, slow_outs)))
print(fast_outs)

tensor(True)
tensor([ 22.2385,  72.8735,  61.1973,  80.4654,  85.6751,  57.1075, 111.2328,
        147.7859,  99.2391, 132.9886,  14.4126, 130.1294, 107.9386, 125.8127,
         90.2288, 108.3405,  99.2070, 102.0712, 158.4294,  75.6810,  53.5171,
         69.2056,  92.0039, 107.4272, 173.7889])
