Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Option to run LabelModel on GPU #1466

Merged
merged 4 commits into from
Sep 20, 2019
Merged

Option to run LabelModel on GPU #1466

merged 4 commits into from
Sep 20, 2019

Conversation

paroma
Copy link
Contributor

@paroma paroma commented Sep 20, 2019

Description of proposed changes

Fixed some parameter definitions to allow LabelModel to run on GPU.

Related issue(s)

Fixes #1430

Test plan

  • Edited existing tests to work with change in parameter definition
  • No new tests added. Tested locally by adding device='cuda' for LabelModel tests.

Checklist

Need help on these? Just ask!

  • I have read the CONTRIBUTING document.
  • I have updated the documentation accordingly.
  • I have added tests to cover my changes.
  • I have run tox -e complex and/or tox -e spark if appropriate.
  • All new and existing tests passed.

@codecov
Copy link

codecov bot commented Sep 20, 2019

Codecov Report

Merging #1466 into master will increase coverage by <.01%.
The diff coverage is 100%.

@@            Coverage Diff             @@
##           master    #1466      +/-   ##
==========================================
+ Coverage   97.58%   97.59%   +<.01%     
==========================================
  Files          55       55              
  Lines        2032     2034       +2     
  Branches      334      334              
==========================================
+ Hits         1983     1985       +2     
  Misses         22       22              
  Partials       27       27
Impacted Files Coverage Δ
snorkel/labeling/model/label_model.py 95.83% <100%> (+0.02%) ⬆️

Copy link
Member

@bhancock8 bhancock8 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approved assuming there's a good answer for the clone() question.

@@ -334,7 +336,7 @@ def get_conditional_probs(self) -> np.ndarray:
np.ndarray
An [m, k + 1, k] np.ndarray conditional probabilities table.
"""
return self._get_conditional_probs(self.mu.detach().numpy())
return self._get_conditional_probs(self.mu.clone().cpu().detach().numpy())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is clone() necessary here but not the other places where you're adding .cpu().detach()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch, probably added during debugging. removed

@@ -750,7 +752,7 @@ def _count_accurate_lfs(self, mu: np.ndarray) -> int:
int
Number of LFs better than random
"""
P = self.P.numpy()
P = self.P.clone().cpu().detach().numpy()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same Q.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed

@@ -95,7 +95,9 @@ def test_generate_O(self):
[1 / 4, 0, 0, 1 / 4, 0, 1 / 4],
]
)
np.testing.assert_array_almost_equal(label_model.O.numpy(), true_O)
np.testing.assert_array_almost_equal(
label_model.O.cpu().detach().numpy(), true_O
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feels very strange to see `label_model.O', but I know that's not part of this PR's scope.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Error when using LabelModel with GPU
2 participants