In [1]:
import kronfluence

In [2]:
from examples.uci.train import train, evaluate

In [43]:
from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.arguments import ScoreArguments
from kronfluence.task import Task
from typing import Tuple
import torch
from torch import nn
import math
import torch.nn.functional as F

In [4]:
from examples.uci.pipeline import construct_regression_mlp, get_regression_dataset

In [12]:
dataset_name = "concrete"
dataset_dir = "./data"
train_batch_size = 32
eval_batch_size = 128
num_train_epochs = 40
learning_rate = 0.03
weight_decay = 1e-05

In [13]:
train_dataset = get_regression_dataset(data_name=dataset_name, split="train", dataset_dir=dataset_dir)

model = train(
    dataset=train_dataset,
    batch_size=train_batch_size,
    num_train_epochs=num_train_epochs,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
)

Epoch 0: 100%|██████████| 28/28 [00:00<00:00, 545.57batch/s, loss=0.889]
Epoch 1: 100%|██████████| 28/28 [00:00<00:00, 582.99batch/s, loss=0.56]
Epoch 2: 100%|██████████| 28/28 [00:00<00:00, 576.00batch/s, loss=0.436]
Epoch 3: 100%|██████████| 28/28 [00:00<00:00, 574.70batch/s, loss=0.365]
Epoch 4: 100%|██████████| 28/28 [00:00<00:00, 554.65batch/s, loss=0.309]
Epoch 5: 100%|██████████| 28/28 [00:00<00:00, 539.66batch/s, loss=0.251]
Epoch 6: 100%|██████████| 28/28 [00:00<00:00, 531.06batch/s, loss=0.224]
Epoch 7: 100%|██████████| 28/28 [00:00<00:00, 547.77batch/s, loss=0.217]
Epoch 8: 100%|██████████| 28/28 [00:00<00:00, 574.51batch/s, loss=0.186]
Epoch 9: 100%|██████████| 28/28 [00:00<00:00, 546.78batch/s, loss=0.21]
Epoch 10: 100%|██████████| 28/28 [00:00<00:00, 540.93batch/s, loss=0.189]
Epoch 11: 100%|██████████| 28/28 [00:00<00:00, 534.31batch/s, loss=0.181]
Epoch 12: 100%|██████████| 28/28 [00:00<00:00, 551.12batch/s, loss=0.171]
Epoch 13: 100%|██████████| 28/28 [00:00<00:00, 548

In [14]:
query_dataset = get_regression_dataset(data_name=dataset_name, split="valid", dataset_dir=dataset_dir)

evaluate(model=model, dataset=query_dataset, batch_size=eval_batch_size)

0.16043664876697133

In [23]:
BATCH_DTYPE = Tuple[torch.Tensor, torch.Tensor]

class RegressionTask(Task):
    def compute_train_loss(
        self,
        batch: BATCH_DTYPE,
        model: nn.Module,
        sample: bool = False,
    ) -> torch.Tensor:
        inputs, targets = batch
        outputs = model(inputs)
        if not sample:
            return F.mse_loss(outputs, targets, reduction="sum")
        with torch.no_grad():
            sampled_targets = torch.normal(outputs, std=math.sqrt(0.5))
        return F.mse_loss(outputs, sampled_targets.detach(), reduction="sum")

    def compute_measurement(
        self,
        batch: BATCH_DTYPE,
        model: nn.Module,
    ) -> torch.Tensor:
        # The measurement function is set as a training loss.
        return self.compute_train_loss(batch, model, sample=False)

In [24]:
task = RegressionTask()
model = prepare_model(model, task)
analyzer = Analyzer(
    analysis_name="tutorial",
    model=model,
    task=task,
    cpu=True,
)

In [27]:
list(model.named_modules())

[('',
  Sequential(
    (0): TrackedLinear(
      (original_module): Linear(in_features=8, out_features=128, bias=True)
    )
    (1): ReLU()
    (2): TrackedLinear(
      (original_module): Linear(in_features=128, out_features=128, bias=True)
    )
    (3): ReLU()
    (4): TrackedLinear(
      (original_module): Linear(in_features=128, out_features=128, bias=True)
    )
    (5): ReLU()
    (6): TrackedLinear(
      (original_module): Linear(in_features=128, out_features=1, bias=True)
    )
  )),
 ('0',
  TrackedLinear(
    (original_module): Linear(in_features=8, out_features=128, bias=True)
  )),
 ('0.original_module', Linear(in_features=8, out_features=128, bias=True)),
 ('1', ReLU()),
 ('2',
  TrackedLinear(
    (original_module): Linear(in_features=128, out_features=128, bias=True)
  )),
 ('2.original_module', Linear(in_features=128, out_features=128, bias=True)),
 ('3', ReLU()),
 ('4',
  TrackedLinear(
    (original_module): Linear(in_features=128, out_features=128, bias=True)
  

In [25]:
covariance_matrices = analyzer.fit_covariance_matrices(
    factors_name="ekfac",
    dataset=train_dataset,
    per_device_batch_size=None,
    overwrite_output_dir=True,
)

Fitting covariance matrices [1/1] 100%|██████████ [time left: 00:00, time spent: 00:00]
Fitting covariance matrices [1/1] 100%|██████████ [time left: 00:00, time spent: 00:00]


In [33]:
covariance_matrices["activation_covariance"]["2"].shape

torch.Size([129, 129])

In [34]:
eigen_factors = analyzer.perform_eigendecomposition(
    factors_name="ekfac",
    overwrite_output_dir=True,
)

Performing Eigendecomposition [4/4] 100%|██████████ [time left: 00:00, time spent: 00:00]


In [36]:
eigen_factors["activation_eigenvectors"]["2"].shape

torch.Size([129, 129])

In [37]:
lambda_matrices = analyzer.fit_lambda_matrices(
    factors_name="ekfac",
    dataset=train_dataset,
    per_device_batch_size=None,
    overwrite_output_dir=True,
)

Fitting Lambda matrices [1/1] 100%|██████████ [time left: 00:00, time spent: 00:00]
Fitting Lambda matrices [1/1] 100%|██████████ [time left: 00:00, time spent: 00:00]


In [38]:
lambda_matrices["lambda_matrix"]["2"].shape

torch.Size([128, 129])

In [39]:
scores = analyzer.compute_pairwise_scores(
    scores_name="pairwise",
    factors_name="ekfac",
    query_dataset=query_dataset,
    train_dataset=train_dataset,
    per_device_query_batch_size=len(query_dataset),
    overwrite_output_dir=True,
)

Computing pairwise influence scores [0/1]   0%|           [time left: ?, time spent: 00:00]
Computing dot products on training dataset [1/1] 100%|██████████ [time left: 00:00, time spent: 00:00]
Computing pairwise influence scores [1/1] 100%|██████████ [time left: 00:00, time spent: 00:00]
Computing pairwise influence scores [0/1]   0%|           [time left: ?, time spent: 00:00]
Computing dot products on training dataset [1/1] 100%|██████████ [time left: 00:00, time spent: 00:00]
Computing pairwise influence scores [1/1] 100%|██████████ [time left: 00:00, time spent: 00:00]


In [41]:
scores["all_modules"]

tensor([[ 5.0061e-02, -3.6199e-02, -1.2702e+00,  ...,  2.3542e-03,
          5.4149e-01,  9.2300e-01],
        [-4.9688e-01, -2.7395e-02, -3.8494e+00,  ...,  5.6789e-01,
         -2.4231e+01, -2.9652e+00],
        [-7.3917e-01, -2.1874e-01, -4.9690e+01,  ..., -2.0099e-01,
         -7.6540e-01, -1.0384e+00],
        ...,
        [ 6.8861e-01,  6.6297e-01,  9.4206e-01,  ..., -8.3435e-01,
          6.2277e+00,  7.0717e-02],
        [-1.9911e+00,  6.5527e-01,  6.2227e+00,  ...,  9.6123e-01,
         -2.0532e+01,  7.6572e+00],
        [ 1.5878e+00, -2.6181e+00, -5.0869e+00,  ..., -1.7254e-01,
         -3.7389e+00,  1.4460e+01]])

In [44]:
score_args = ScoreArguments(per_module_score=True)

per_module_scores = analyzer.compute_pairwise_scores(
    scores_name="per_module",
    factors_name="ekfac",
    query_dataset=query_dataset,
    train_dataset=train_dataset,
    score_args=score_args,
    per_device_query_batch_size=len(query_dataset),
    overwrite_output_dir=True,
)

Computing pairwise influence scores [0/1]   0%|           [time left: ?, time spent: 00:00]
Computing dot products on training dataset [1/1] 100%|██████████ [time left: 00:00, time spent: 00:00]
Computing pairwise influence scores [1/1] 100%|██████████ [time left: 00:00, time spent: 00:00]
Computing pairwise influence scores [0/1]   0%|           [time left: ?, time spent: 00:00]
Computing dot products on training dataset [1/1] 100%|██████████ [time left: 00:00, time spent: 00:00]
Computing pairwise influence scores [1/1] 100%|██████████ [time left: 00:00, time spent: 00:00]


In [45]:
per_module_scores

{'0': tensor([[-2.6266e-02, -1.8842e-02,  1.1664e-01,  ...,  2.4500e-03,
           1.3508e-01, -4.2287e-02],
         [ 3.0197e-01, -1.4576e-01,  3.2981e-01,  ..., -1.7151e-01,
           2.1003e+00, -3.2144e+00],
         [ 1.1908e-02, -1.3091e-02, -3.8957e+00,  ...,  4.9687e-03,
          -1.2466e+00, -7.2670e-03],
         ...,
         [ 8.1234e-01, -6.6214e-01, -7.0342e-02,  ..., -2.6749e-01,
           6.9416e-02,  2.1124e-02],
         [ 1.4178e+00,  3.2288e+00, -9.8175e-01,  ...,  1.6657e-01,
          -4.7388e+00, -3.4997e+00],
         [-2.0218e-01, -1.7329e+00, -3.7108e+00,  ..., -3.1222e-01,
           1.6311e+00,  7.7284e+00]]),
 '2': tensor([[  0.1232,  -0.0557,  -0.7605,  ...,   0.1088,   0.3196,   0.7115],
         [ -1.1788,   0.1449,  -2.5917,  ...,   0.3668,  -2.0784,   1.3429],
         [ -0.1844,  -0.1768, -23.9932,  ...,   0.3725,  -0.4919,  -1.0844],
         ...,
         [ -0.3336,   0.6093,   0.3979,  ...,  -0.9447,   3.0782,  -0.2292],
         [ -2.0491,  -

In [47]:
per_module_scores["2"].shape

torch.Size([103, 927])

In [50]:
self_scores = analyzer.compute_self_scores(
    scores_name="self",
    factors_name="ekfac",
    train_dataset=train_dataset,
    per_device_train_batch_size=len(query_dataset),
    overwrite_output_dir=True,
)

Computing self-influence scores [9/9] 100%|██████████ [time left: 00:00, time spent: 00:00]


In [53]:
self_scores["all_modules"].shape

torch.Size([927])

## Counterfactual

In [56]:
small_query_dataset = torch.utils.data.Subset(query_dataset, list(range(10)))

In [57]:
len(small_query_dataset)

10

In [58]:
evaluate(model=model, dataset=small_query_dataset, batch_size=eval_batch_size)

0.10985381603240967

In [59]:
scores = analyzer.compute_pairwise_scores(
    scores_name="counterfactual",
    factors_name="ekfac",
    query_dataset=small_query_dataset,
    train_dataset=train_dataset,
    per_device_query_batch_size=len(small_query_dataset),
    overwrite_output_dir=True,
)

Computing pairwise influence scores [0/1]   0%|           [time left: ?, time spent: 00:00]
Computing dot products on training dataset [1/1] 100%|██████████ [time left: 00:00, time spent: 00:00]
Computing pairwise influence scores [1/1] 100%|██████████ [time left: 00:00, time spent: 00:00]
Computing pairwise influence scores [0/1]   0%|           [time left: ?, time spent: 00:00]
Computing dot products on training dataset [1/1] 100%|██████████ [time left: 00:00, time spent: 00:00]
Computing pairwise influence scores [1/1] 100%|██████████ [time left: 00:00, time spent: 00:00]


In [62]:
summed_scores = scores["all_modules"].sum(dim=0)

In [64]:
torch.topk(summed_scores, 3).indices

tensor([647, 503, 326])

In [65]:
def get_top_k_indices(current_score, top_k=1):
    return torch.topk(current_score, top_k).indices

In [66]:
get_top_k_indices(summed_scores, top_k=10)

tensor([647, 503, 326, 256, 217,  36, 221, 550, 288, 240])

In [74]:
def get_keep_indices(remove_indices):
    remove_indices = [tensor.item() for tensor in remove_indices]
    return list(set(list(range(len(train_dataset)))) - set(remove_indices))

In [76]:
keep_indices = get_top_k_indices(summed_scores, top_k=10)

{0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99,
 100,
 101,
 102,
 103,
 104,
 105,
 106,
 107,
 108,
 109,
 110,
 111,
 112,
 113,
 114,
 115,
 116,
 117,
 118,
 119,
 120,
 121,
 122,
 123,
 124,
 125,
 126,
 127,
 128,
 129,
 130,
 131,
 132,
 133,
 134,
 135,
 136,
 137,
 138,
 139,
 140,
 141,
 142,
 143,
 144,
 145,
 146,
 147,
 148,
 149,
 150,
 151,
 152,
 153,
 154,
 155,
 156,
 157,
 158,
 159,
 160,
 161,
 162,
 163,
 164,
 165,
 166,
 167,
 168,
 169,
 170,
 171,
 172,
 173,
 174,
 175,
 176,
 177,
 178,
 179,
 180,
 181,
 182,
 183,
 184,


In [73]:
set(get_top_k_indices(summed_scores, top_k=10))

{tensor(36),
 tensor(217),
 tensor(221),
 tensor(240),
 tensor(256),
 tensor(288),
 tensor(326),
 tensor(503),
 tensor(550),
 tensor(647)}

In [77]:
def train_and_evaluate(current_train_dataset, current_eval_dataset):
    current_model = train(
        dataset=current_train_dataset,
        batch_size=train_batch_size,
        num_train_epochs=num_train_epochs,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
    )
    return evaluate(model=current_model, dataset=current_eval_dataset, batch_size=eval_batch_size)

In [78]:
train_and_evaluate(train_dataset, current_eval_dataset=small_query_dataset)

Epoch 0: 100%|██████████| 28/28 [00:00<00:00, 585.87batch/s, loss=0.911]
Epoch 1: 100%|██████████| 28/28 [00:00<00:00, 638.57batch/s, loss=0.603]
Epoch 2: 100%|██████████| 28/28 [00:00<00:00, 648.91batch/s, loss=0.42]
Epoch 3: 100%|██████████| 28/28 [00:00<00:00, 635.11batch/s, loss=0.346]
Epoch 4: 100%|██████████| 28/28 [00:00<00:00, 629.13batch/s, loss=0.296]
Epoch 5: 100%|██████████| 28/28 [00:00<00:00, 619.84batch/s, loss=0.254]
Epoch 6: 100%|██████████| 28/28 [00:00<00:00, 576.69batch/s, loss=0.221]
Epoch 7: 100%|██████████| 28/28 [00:00<00:00, 671.24batch/s, loss=0.22]
Epoch 8: 100%|██████████| 28/28 [00:00<00:00, 635.33batch/s, loss=0.204]
Epoch 9: 100%|██████████| 28/28 [00:00<00:00, 637.90batch/s, loss=0.175]
Epoch 10: 100%|██████████| 28/28 [00:00<00:00, 633.97batch/s, loss=0.169]
Epoch 11: 100%|██████████| 28/28 [00:00<00:00, 661.69batch/s, loss=0.164]
Epoch 12: 100%|██████████| 28/28 [00:00<00:00, 651.29batch/s, loss=0.157]
Epoch 13: 100%|██████████| 28/28 [00:00<00:00, 666

0.07795017957687378

In [79]:
num_iter = 5
base_loss = 0
for _ in range(num_iter):
    base_loss += train_and_evaluate(current_train_dataset=train_dataset, current_eval_dataset=small_query_dataset)
base_loss /= num_iter

Epoch 0: 100%|██████████| 28/28 [00:00<00:00, 626.29batch/s, loss=0.916]
Epoch 1: 100%|██████████| 28/28 [00:00<00:00, 605.61batch/s, loss=0.577]
Epoch 2: 100%|██████████| 28/28 [00:00<00:00, 658.10batch/s, loss=0.399]
Epoch 3: 100%|██████████| 28/28 [00:00<00:00, 642.07batch/s, loss=0.346]
Epoch 4: 100%|██████████| 28/28 [00:00<00:00, 646.28batch/s, loss=0.299]
Epoch 5: 100%|██████████| 28/28 [00:00<00:00, 640.60batch/s, loss=0.257]
Epoch 6: 100%|██████████| 28/28 [00:00<00:00, 609.95batch/s, loss=0.23]
Epoch 7: 100%|██████████| 28/28 [00:00<00:00, 644.70batch/s, loss=0.205]
Epoch 8: 100%|██████████| 28/28 [00:00<00:00, 672.63batch/s, loss=0.191]
Epoch 9: 100%|██████████| 28/28 [00:00<00:00, 643.22batch/s, loss=0.188]
Epoch 10: 100%|██████████| 28/28 [00:00<00:00, 577.76batch/s, loss=0.191]
Epoch 11: 100%|██████████| 28/28 [00:00<00:00, 639.78batch/s, loss=0.158]
Epoch 12: 100%|██████████| 28/28 [00:00<00:00, 659.67batch/s, loss=0.157]
Epoch 13: 100%|██████████| 28/28 [00:00<00:00, 65

In [80]:
base_loss

0.1237425720691681

In [84]:
top_indices = get_top_k_indices(summed_scores, top_k=50)
keep_indices = get_keep_indices(top_indices)

new_loss = 0
for _ in range(num_iter):
    new_loss += train_and_evaluate(current_train_dataset=torch.utils.data.Subset(train_dataset, keep_indices), current_eval_dataset=small_query_dataset)
new_loss /= num_iter

Epoch 0: 100%|██████████| 27/27 [00:00<00:00, 570.47batch/s, loss=0.926]
Epoch 1: 100%|██████████| 27/27 [00:00<00:00, 567.78batch/s, loss=0.569]
Epoch 2: 100%|██████████| 27/27 [00:00<00:00, 570.60batch/s, loss=0.386]
Epoch 3: 100%|██████████| 27/27 [00:00<00:00, 584.24batch/s, loss=0.335]
Epoch 4: 100%|██████████| 27/27 [00:00<00:00, 543.31batch/s, loss=0.276]
Epoch 5: 100%|██████████| 27/27 [00:00<00:00, 611.63batch/s, loss=0.239]
Epoch 6: 100%|██████████| 27/27 [00:00<00:00, 579.03batch/s, loss=0.213]
Epoch 7: 100%|██████████| 27/27 [00:00<00:00, 587.25batch/s, loss=0.193]
Epoch 8: 100%|██████████| 27/27 [00:00<00:00, 632.51batch/s, loss=0.175]
Epoch 9: 100%|██████████| 27/27 [00:00<00:00, 581.18batch/s, loss=0.177]
Epoch 10: 100%|██████████| 27/27 [00:00<00:00, 577.28batch/s, loss=0.158]
Epoch 11: 100%|██████████| 27/27 [00:00<00:00, 419.68batch/s, loss=0.148]
Epoch 12: 100%|██████████| 27/27 [00:00<00:00, 531.87batch/s, loss=0.157]
Epoch 13: 100%|██████████| 27/27 [00:00<00:00, 5

In [85]:
new_loss

0.35664366722106927