Add gpu support to tracincp rand projection #969

Closed · wants to merge 27 commits
62 changes: 42 additions & 20 deletions captum/influence/_core/tracincp_fast_rand_proj.py
@@ -1,11 +1,15 @@
 #!/usr/bin/env python3
 
+import threading
 import warnings
-from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
+from collections import defaultdict
+from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Tuple, Union
 
 import torch
-from captum._utils.common import _format_inputs, _get_module_from_name
+from captum._utils.common import _format_inputs, _get_module_from_name, _sort_key_list
+from captum._utils.gradient import _gather_distributed_tensors
 from captum._utils.progress import progress
+
 from captum.influence._core.tracincp import (
     _influence_route_to_helpers,
     KMostInfluentialResults,
@@ -25,19 +29,10 @@
     NearestNeighbors,
 )
 from captum.log import log_usage
-from torch import Tensor
+from torch import device, Tensor
 from torch.nn import Module
 from torch.utils.data import DataLoader, Dataset
 
-layer_inputs = []
-
-
-def _capture_inputs(layer: Module, input: Tensor, output: Tensor) -> None:
-    r"""Save activations into layer.activations in forward pass"""
-
-    layer_inputs.append(input[0].detach())
-
-
 r"""
 Implements abstract DataInfluence class and also provides implementation details for
 influence computation based on the logic provided in TracIn paper
@@ -713,10 +708,26 @@ def _basic_computation_tracincp_fast(
         targets (tensor): If computing influence scores on a loss function,
             these are the labels corresponding to the batch `inputs`.
     """
-    global layer_inputs
-    layer_inputs = []
+    layer_inputs: Dict[device, Tuple[Tensor, ...]] = defaultdict()
+    lock = threading.Lock()
+
+    def hook_wrapper(original_module):
+        def _capture_inputs(layer, input, output) -> None:
+            r"""Save activations into layer_inputs in forward pass"""
+            with lock:
+                is_eval_tuple = isinstance(input, tuple)
+                if is_eval_tuple:
+                    layer_inputs_val = tuple(inp.detach() for inp in input)
+                else:
+                    layer_inputs_val = input.detach()
+                layer_inputs[layer_inputs_val[0].device] = layer_inputs_val
+
+        return _capture_inputs
+
     assert isinstance(influence_instance.final_fc_layer, Module)
-    handle = influence_instance.final_fc_layer.register_forward_hook(_capture_inputs)
+    handle = influence_instance.final_fc_layer.register_forward_hook(
+        hook_wrapper(influence_instance.final_fc_layer)
+    )
     out = influence_instance.model(*inputs)
 
     assert influence_instance.loss_fn is not None, "loss function is required"
@@ -732,7 +743,16 @@ def _basic_computation_tracincp_fast(
         influence_instance.reduction_type,
     )
     handle.remove()
-    _layer_inputs = layer_inputs[0]
+
+    device_ids = cast(
+        Union[None, List[int]],
+        influence_instance.model.device_ids
+        if hasattr(influence_instance.model, "device_ids")
+        else None,
+    )
+    key_list = _sort_key_list(list(layer_inputs.keys()), device_ids)
+
+    _layer_inputs = _gather_distributed_tensors(layer_inputs, key_list=key_list)[0]
 
     assert len(input_jacobians.shape) == 2
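
A note on the hook change above: under `nn.DataParallel` the forward pass runs once per replica, each on its own device and thread, so the old module-level `layer_inputs` list could interleave activations from different GPUs. Keying the captured inputs by device (guarded by a lock) and then gathering them in a canonical device order via `_sort_key_list` and `_gather_distributed_tensors` avoids that. A minimal self-contained sketch of the same capture pattern, where the model, layer, and batch are hypothetical stand-ins:

```python
import threading
from typing import Dict, Tuple

import torch
from torch import Tensor


def capture_layer_inputs(
    model: torch.nn.Module, layer: torch.nn.Module, batch: Tensor
) -> Dict[torch.device, Tuple[Tensor, ...]]:
    """Run a forward pass and record `layer`'s inputs, keyed by device.

    Under nn.DataParallel the hook fires once per replica, so a lock
    guards the dict against concurrent writes from per-device threads.
    """
    layer_inputs: Dict[torch.device, Tuple[Tensor, ...]] = {}
    lock = threading.Lock()

    def hook(module, inp, out) -> None:
        # forward hooks receive positional inputs as a tuple
        vals = tuple(t.detach() for t in inp) if isinstance(inp, tuple) else (inp.detach(),)
        with lock:
            layer_inputs[vals[0].device] = vals

    handle = layer.register_forward_hook(hook)
    try:
        model(batch)
    finally:
        handle.remove()  # always remove the hook, even if forward raises
    return layer_inputs


# Usage sketch: on a multi-GPU box, wrap `net` in nn.DataParallel and the
# returned dict holds one entry per device.
net = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 2))
inputs_by_device = capture_layer_inputs(net, net[1], torch.randn(4, 10))
print({dev: tuple(t.shape for t in ts) for dev, ts in inputs_by_device.items()})
```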

@@ -1242,6 +1262,7 @@ def _set_projections_tracincp_fast_rand_proj(
         layer_input_dim = batch_layer_inputs.shape[
             1
         ]  # this is the dimension of the input of the last fully-connected layer
+        device = batch_jacobians.device
 
         # choose projection if needed
         # without projection, the dimension of the intermediate quantities returned
@@ -1270,7 +1291,9 @@ def _set_projections_tracincp_fast_rand_proj(
                 1.0 / layer_input_projection_dim**0.5,
             )
 
-        projection_quantities = jacobian_projection, layer_input_projection
+        projection_quantities = jacobian_projection.to(
+            device
+        ), layer_input_projection.to(device)
 
         return projection_quantities
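
The device detail in this hunk: the random projection matrices are drawn once (on CPU by default) but multiplied against per-batch tensors that may live on a GPU, so they are moved to `batch_jacobians.device` before being returned. A standalone sketch of the pitfall and the fix, with made-up dimensions:

```python
import torch

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

feature_dim, projection_dim = 512, 32  # hypothetical sizes
batch_jacobians = torch.randn(8, feature_dim, device=device)

# Scaled Gaussian random projection (Johnson-Lindenstrauss style), created on CPU.
projection = torch.normal(
    torch.zeros(feature_dim, projection_dim), 1.0 / projection_dim**0.5
)

# Without .to(...), a CPU @ CUDA matmul raises a device-mismatch RuntimeError.
projected = batch_jacobians @ projection.to(batch_jacobians.device)
print(projected.shape)  # torch.Size([8, 32])
```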

@@ -1341,9 +1364,8 @@ def _get_intermediate_quantities_tracincp_fast_rand_proj(
         # each element in this list will be of shape (batch_size, projection_dim)
         checkpoint_projections: List[Any] = [[] for _ in self.checkpoints]
 
-        if projection_quantities is None:
-            project = False
-        else:
+        project = False
+        if projection_quantities is not None:
             project = True
             jacobian_projection, layer_input_projection = projection_quantities
 
26 changes: 12 additions & 14 deletions captum/influence/_utils/common.py
@@ -6,6 +6,7 @@
 import torch
 import torch.nn as nn
 from captum._utils.progress import progress
+
 from torch import Tensor
 from torch.nn import Module
 from torch.utils.data import DataLoader, Dataset
@@ -55,7 +56,6 @@ def _gradient_dot_product(
     total = _tensor_batch_dot(*next(iterator))
     for input_grad, src_grad in iterator:
         total += _tensor_batch_dot(input_grad, src_grad)
-    total = torch.Tensor(total)
 
     return total
@@ -141,9 +141,7 @@ def _jacobian_loss_wrt_inputs(
     return input_jacobians
 
 
-def _load_flexible_state_dict(
-    model: Module, path: str, device_ids: str = "cpu", keyname: Optional[str] = None
-) -> int:
+def _load_flexible_state_dict(model: Module, path: str) -> float:
     r"""
     Helper to load pytorch models. This function attempts to find compatibility for
     loading models that were trained on different devices / with DataParallel but are
@@ -156,21 +154,15 @@ def _load_flexible_state_dict(
     Args:
         model: The model for which to load a checkpoint
         path: The filepath to the checkpoint
-        keyname: The key under which the model state_dict is stored, if any.
 
     The module state_dict is modified in-place, and the learning rate is returned.
     """
 
-    device = device_ids
-
-    checkpoint = torch.load(path, map_location=device)
+    checkpoint = torch.load(path)
 
-    learning_rate = checkpoint.get("learning_rate", 1)
+    learning_rate = checkpoint.get("learning_rate", 1.0)
     # can get learning rate from optimizer state_dict?
 
-    if keyname is not None:
-        checkpoint = checkpoint[keyname]
-
     if "module." in next(iter(checkpoint)):
         if isinstance(model, nn.DataParallel):
             model.load_state_dict(checkpoint)
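
For context on the `_load_flexible_state_dict` simplification: the logic that remains reconciles the `module.` prefix that `nn.DataParallel` adds to every state-dict key. A rough standalone sketch of that idea; the `model_state_dict` key is hypothetical, and the Captum helper stores its state dict differently:

```python
import torch
import torch.nn as nn


def load_flexible_state_dict(model: nn.Module, path: str) -> float:
    """Load a checkpoint saved with or without nn.DataParallel.

    DataParallel prefixes every parameter key with "module.", so the
    prefix is added or stripped to match the receiving model.
    """
    checkpoint = torch.load(path)
    learning_rate = checkpoint.get("learning_rate", 1.0)
    state_dict = checkpoint["model_state_dict"]  # hypothetical key

    saved_parallel = next(iter(state_dict)).startswith("module.")
    target_parallel = isinstance(model, nn.DataParallel)

    if saved_parallel and not target_parallel:
        state_dict = {k[len("module."):]: v for k, v in state_dict.items()}
    elif target_parallel and not saved_parallel:
        state_dict = {f"module.{k}": v for k, v in state_dict.items()}

    model.load_state_dict(state_dict)
    return learning_rate
```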
@@ -288,9 +280,15 @@ def _get_k_most_influential_helper(
         num_instances_processed += batch_size
 
         # combine the top-k for the batch with those for previously seen batches
-        topk_indices = torch.cat([topk_indices, batch_topk_indices], dim=1)
+        topk_indices = torch.cat(
+            [topk_indices.to(batch_topk_indices.device), batch_topk_indices], dim=1
+        )
         topk_tracin_scores = torch.cat(
-            [topk_tracin_scores, batch_topk_tracin_scores], dim=1
+            [
+                topk_tracin_scores.to(batch_topk_tracin_scores.device),
+                batch_topk_tracin_scores,
+            ],
+            dim=1,
         )
 
         # retain only the top-k in terms of tracin_scores
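
The `.to(...)` calls above matter because the running top-k can start life on one device while each batch's top-k arrives on another. The overall pattern, concatenate the running and per-batch top-k on a common device and re-select the k best, can be sketched in isolation; the function name and shapes here are illustrative:

```python
import torch


def running_topk(score_batches, k: int):
    """Keep a running top-k over a stream of (1, batch_size) score tensors,
    moving the running tensors to each batch's device before concatenating."""
    topk_scores = torch.empty(1, 0)
    topk_indices = torch.empty(1, 0, dtype=torch.long)
    offset = 0
    for batch in score_batches:
        batch_scores, batch_idx = torch.topk(batch, min(k, batch.shape[1]), dim=1)
        batch_idx = batch_idx + offset  # convert to dataset-wide indices
        offset += batch.shape[1]
        # align devices before concatenating, as the PR does
        topk_scores = torch.cat(
            [topk_scores.to(batch_scores.device), batch_scores], dim=1
        )
        topk_indices = torch.cat(
            [topk_indices.to(batch_idx.device), batch_idx], dim=1
        )
        # retain only the k best seen so far
        topk_scores, pos = torch.topk(topk_scores, min(k, topk_scores.shape[1]), dim=1)
        topk_indices = torch.gather(topk_indices, 1, pos)
    return topk_scores, topk_indices


scores, indices = running_topk([torch.randn(1, 8) for _ in range(5)], k=3)
print(scores.shape, indices.shape)  # torch.Size([1, 3]) torch.Size([1, 3])
```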
@@ -18,59 +18,82 @@
 
 
 class TestTracInGetKMostInfluential(BaseTest):
-    """
-    This test constructs a random BasicLinearNet, and checks that the proponents
-    obtained by calling `influence` and sorting are equal to the proponents
-    obtained by calling `_get_k_most_influential`. Those calls are made through
-    the calls to wrapper method `influence`.
-    """
+
+    use_gpu_list = (
+        [True, False]
+        if torch.cuda.is_available() and torch.cuda.device_count() != 0
+        else [False]
+    )
+
+    param_list = []
+    for (batch_size, k) in [(4, 7), (7, 4), (40, 5), (5, 40), (40, 45)]:
+        for unpack_inputs in [True, False]:
+            for proponents in [True, False]:
+                for use_gpu in use_gpu_list:
+                    for reduction, constr in [
+                        ("none", DataInfluenceConstructor(TracInCP)),
+                        (
+                            "sum",
+                            DataInfluenceConstructor(
+                                TracInCP,
+                                name="TracInCPFastRandProjTests",
+                                sample_wise_grads_per_batch=True,
+                            ),
+                        ),
+                        ("sum", DataInfluenceConstructor(TracInCPFast)),
+                        ("sum", DataInfluenceConstructor(TracInCPFastRandProj)),
+                        ("mean", DataInfluenceConstructor(TracInCPFast)),
+                        ("mean", DataInfluenceConstructor(TracInCPFastRandProj)),
+                    ]:
+                        if not (
+                            "sample_wise_grads_per_batch" in constr.kwargs
+                            and constr.kwargs["sample_wise_grads_per_batch"]
+                            and use_gpu
+                        ):
+                            param_list.append(
+                                (
+                                    reduction,
+                                    constr,
+                                    unpack_inputs,
+                                    proponents,
+                                    batch_size,
+                                    k,
+                                    use_gpu,
+                                )
+                            )
 
     @parameterized.expand(
-        [
-            (reduction, constr, unpack_inputs, proponents, batch_size, k)
-            # calls test helper method `test_tracin_get_k_most_influential` for several
-            # combinations of `batch_size` and `k`. This is important because the
-            # behavior of `_get_k_most_influential` depends on whether `k` is larger
-            # than `batch_size`.
-            for (batch_size, k) in [(4, 7), (7, 4), (40, 5), (5, 40), (40, 45)]
-            for unpack_inputs in [True, False]
-            for proponents in [True, False]
-            for reduction, constr in [
-                ("none", DataInfluenceConstructor(TracInCP)),
-                (
-                    "sum",
-                    DataInfluenceConstructor(
-                        TracInCP,
-                        name="TracInCPFastRandProjTests",
-                        sample_wise_grads_per_batch=True,
-                    ),
-                ),
-                ("sum", DataInfluenceConstructor(TracInCPFast)),
-                ("sum", DataInfluenceConstructor(TracInCPFastRandProj)),
-                ("mean", DataInfluenceConstructor(TracInCPFast)),
-                ("mean", DataInfluenceConstructor(TracInCPFastRandProj)),
-            ]
-        ],
+        param_list,
         name_func=build_test_name_func(),
     )
-    def test_tracin_get_k_most_influential(
+    def test_tracin_k_most_influential(
         self,
         reduction: str,
         tracin_constructor: Callable,
         unpack_inputs: bool,
         proponents: bool,
         batch_size: int,
         k: int,
+        use_gpu: bool,
     ) -> None:
-
+        """
+        This test constructs a random BasicLinearNet, and checks that the proponents
+        obtained by calling `influence` and sorting are equal to the proponents
+        obtained by calling `_k_most_influential`. Those calls are made through
+        the calls to wrapper method `influence`.
+        """
         with tempfile.TemporaryDirectory() as tmpdir:
-
             (
                 net,
                 train_dataset,
                 test_samples,
                 test_labels,
-            ) = get_random_model_and_data(tmpdir, unpack_inputs, return_test_data=True)
+            ) = get_random_model_and_data(
+                tmpdir,
+                unpack_inputs,
+                True,
+                use_gpu,
+            )
 
             self.assertTrue(isinstance(reduction, str))
             self.assertTrue(callable(tracin_constructor))
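
The test refactor swaps the inline list comprehension for an eagerly built `param_list`, so GPU variants are generated only when CUDA is available and unsupported combinations (here, `sample_wise_grads_per_batch` together with `use_gpu`) can be filtered out before `parameterized.expand` sees them. A condensed sketch of the pattern with made-up parameters and filter:

```python
import torch
from parameterized import parameterized

use_gpu_list = [True, False] if torch.cuda.is_available() else [False]

param_list = [
    (batch_size, k, use_gpu)
    for batch_size, k in [(4, 7), (7, 4), (40, 5)]
    for use_gpu in use_gpu_list
    # hypothetical filter: skip combinations a feature cannot run on GPU
    if not (use_gpu and batch_size > 30)
]


class ExampleTest:  # stand-in for a unittest.TestCase subclass
    @parameterized.expand(param_list)
    def test_example(self, batch_size: int, k: int, use_gpu: bool) -> None:
        device = torch.device("cuda" if use_gpu else "cpu")
        scores = torch.randn(batch_size, device=device)
        assert scores.topk(min(k, batch_size)).values.shape[0] == min(k, batch_size)
```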