Add gpu support to tracincp rand projection #969
Conversation
c846d1d
to
b652271
Compare
Looks great :) !
tests/influence/_utils/common.py
Outdated
return (
    wrap_model_in_dataparallel(net) if use_gpu else net,
    dataset,
    [s.cuda() for s in test_samples]
nit: Maybe a helper function for moving either a single tensor or a list of tensors to the GPU (e.g. `[elem.cuda() for elem in val] if isinstance(val, list) else val.cuda()`) would make the if/elses here, as well as the use cases above, more concise?
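The suggested helper might look like the following sketch. The name `move_to_cuda` is hypothetical; it only assumes each value exposes PyTorch's `.cuda()` method:

```python
def move_to_cuda(val):
    """Move a single tensor, or a list of tensors, to the GPU.

    Hypothetical helper sketching the reviewer's suggestion: dispatch on
    whether `val` is a list, and call `.cuda()` on each element (or on the
    single value). Assumes each element has a `.cuda()` method, as PyTorch
    tensors do.
    """
    return [elem.cuda() for elem in val] if isinstance(val, list) else val.cuda()
```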
@@ -568,10 +563,28 @@ def _basic_computation_tracincp_fast(
     targets (tensor): If computing influence scores on a loss function,
         these are the labels corresponding to the batch `inputs`.
     """
-    global layer_inputs
-    layer_inputs = []
+    layer_inputs: Dict[Module, Dict[device, Tuple[Tensor, ...]]] = defaultdict(dict)
nit: Looks like we only use this dictionary for the results of one module; could we instead just make it `Dict[device, Tuple[Tensor, ...]]`?
It seems `extract_device_ids` expects this dict-of-dicts structure, but that could probably be replaced by reading the model's `device_ids` attribute when available.
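A sketch of the simpler alternative the comment suggests. It assumes the model is either a plain module or wrapped in `nn.DataParallel`, which exposes a `device_ids` attribute; a plain module does not, so the sketch falls back to an empty list:

```python
def extract_device_ids(model):
    """Return the model's GPU device ids, if any.

    Sketch of the reviewer's suggestion: instead of inferring devices from a
    dict-of-dicts of layer inputs, read the wrapper's `device_ids` directly.
    `nn.DataParallel` exposes this attribute; for a plain module we return [].
    """
    return list(getattr(model, "device_ids", []))
```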
@NarineK has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
b4bc0fa
to
a1bcfba
Compare
Adds GPU support to TracInCP random projection.
Also cleans up arguments to `_load_flexible_state_dict` that were never passed.