
Add gpu support to tracincp rand projection #969

Closed
wants to merge 27 commits

Conversation

NarineK
Contributor

@NarineK NarineK commented Jun 6, 2022

Adds GPU support to TracInCP random projection.
Cleans up unused arguments passed to `_load_flexible_state_dict`.

Contributor

@vivekmig vivekmig left a comment


Looks great :) !

return (
    wrap_model_in_dataparallel(net) if use_gpu else net,
    dataset,
    [s.cuda() for s in test_samples],
)

nit: Maybe a helper function for moving either a single tensor or a list of tensors to GPU (`[elem.cuda() for elem in val] if isinstance(val, list) else val.cuda()`) would help make the if/elses here more concise, as well as the use cases above?

@@ -568,10 +563,28 @@ def _basic_computation_tracincp_fast(
         targets (tensor): If computing influence scores on a loss function,
             these are the labels corresponding to the batch `inputs`.
         """
-    global layer_inputs
-    layer_inputs = []
+    layer_inputs: Dict[Module, Dict[device, Tuple[Tensor, ...]]] = defaultdict(dict)

nit: Looks like we only use this dictionary for the results of one module; could we instead just make this `Dict[device, Tuple[Tensor, ...]]`?

It seems `extract_device_ids` expects this dict-of-dicts structure, but that can probably be replaced by reading the model's `device_ids` attribute when it is available.

@facebook-github-bot
Contributor

@NarineK has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@NarineK NarineK force-pushed the add_gpu_support_to_tracincp branch from b4bc0fa to a1bcfba Compare August 7, 2022 05:37
