
Instance/node dropout #44

Merged
merged 20 commits into from
May 15, 2024

Conversation

aaprasad (Contributor) commented May 14, 2024

As per #42, there is some correlation between instances/nodes disappearing and model failures. Here we add a node/instance dropout augmentation, as well as the ability to handle missing nodes when using single-node crops in one of several ways: using a zeros mask, dropping the instance, or defaulting to the centroid.
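The three missing-node strategies described above can be sketched roughly as follows. This is illustrative only: the function and argument names here are hypothetical stand-ins, not the actual `sleap_dataset.py` API.

```python
import numpy as np

def crop_around_anchor(pose: dict, anchor: str, handle_missing: str = "drop"):
    """Return the (x, y) point to crop around, or None to drop the instance.

    `pose` maps node names to (x, y) coordinates that may contain NaNs
    when a node was not detected.
    """
    point = pose.get(anchor)
    if point is not None and not np.isnan(point).any():
        return np.asarray(point)

    if handle_missing == "drop":
        return None  # skip instances missing the anchor node
    if handle_missing == "ignore":
        return np.array([np.nan, np.nan])  # downstream substitutes a zeros mask
    if handle_missing == "centroid":
        # fall back to the mean of all visible nodes (the pose centroid)
        visible = np.array([p for p in pose.values() if not np.isnan(p).any()])
        return visible.mean(axis=0)
    raise ValueError(f"unknown handle_missing: {handle_missing!r}")
```

For example, with an instance whose `head` node is missing, `"drop"` skips it, `"ignore"` yields NaN coordinates for masking, and `"centroid"` averages the remaining visible nodes.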

Summary by CodeRabbit

  • New Features

    • Introduced a node dropout augmentation that randomly drops nodes during training to improve model robustness.
    • Added support for handling missing data using the "centroid" method in datasets.
  • Enhancements

    • Improved augmentation handling logic to better manage training conditions.
    • Expanded dataset modes to include 'test' in addition to 'train' and 'val'.
  • Bug Fixes

    • Adjusted initialization conditions in various classes to ensure proper handling of bounding boxes and centroids.
  • Tests

    • Added new test cases for node dropout functionality to ensure robustness.

coderabbitai bot (Contributor) commented May 14, 2024

Walkthrough

The recent updates enhance the handling of augmentations and missing data in datasets, introduce a new NodeDropout class for node dropout augmentation, and refine the initialization logic in several classes. Additionally, new test cases have been added to ensure these functionalities work as intended.

Changes

File Change Summary
biogtr/data_structures.py Modified __init__ method in SomeClass to handle _bbox and centroid conditions.
biogtr/datasets/base_dataset.py Added instance_dropout and node_dropout attributes; adjusted augmentations assignment logic.
biogtr/datasets/sleap_dataset.py Changed mode and augmentations attributes; updated __init__ method to include handle_missing.
biogtr/datasets/data_utils.py Introduced NodeDropout class with methods for node dropout augmentation.
tests/test_models.py Removed times parameter from test_multianchor_embedding().
tests/test_datasets.py Added import for NodeDropout and a new test case for node dropout functionality.

In code's realm where data flows,
New nodes drop and structure grows.
Augmentations now refined,
Handling modes of every kind.
Tests ensure we're on the right track,
With every change, there's no lack.
🐇✨



@aaprasad aaprasad marked this pull request as ready for review May 14, 2024 23:13
@aaprasad aaprasad requested a review from talmo May 14, 2024 23:13
coderabbitai bot (Contributor) left a comment
Actionable comments posted: 13

Out of diff range and nitpick comments (1)
biogtr/datasets/sleap_dataset.py (1)

Line range hint 30-55: Improve the readability of the docstring for the handle_missing parameter by formatting it as a list.

- handle_missing: how to handle missing single nodes. one of `["drop", "ignore", "centroid"]`.
-                 if "drop" then we dont include instances which are missing the `anchor`.
-                 if "ignore" then we use a mask instead of a crop and nan centroids/bboxes.
-                 if "centroid" then we default to the pose centroid as the node to crop around.
+ handle_missing: How to handle missing single nodes. One of:
+     - "drop": Do not include instances which are missing the `anchor`.
+     - "ignore": Use a mask instead of a crop and NaN centroids/bboxes.
+     - "centroid": Default to the pose centroid as the node to crop around.
Review Details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits Files that changed from the base of the PR and between 3d6bf9a and af67669.
Files selected for processing (4)
  • biogtr/data_structures.py (1 hunks)
  • biogtr/datasets/base_dataset.py (1 hunks)
  • biogtr/datasets/sleap_dataset.py (10 hunks)
  • tests/test_models.py (1 hunks)
Additional Context Used
Ruff (8)
biogtr/datasets/sleap_dataset.py (1)

9-9: warnings imported but unused

tests/test_models.py (7)

302-302: Local variable times is assigned to but never used


335-335: Redefinition of unused test_multianchor_embedding from line 292


338-338: Local variable d_model is assigned to but never used


345-345: Local variable times is assigned to but never used


368-368: Redefinition of unused test_multianchor_embedding from line 335


378-378: Local variable times is assigned to but never used


501-501: Local variable img_shape is assigned to but never used

Additional comments not posted (1)
biogtr/datasets/sleap_dataset.py (1)

87-87: Initialization of the handle_missing attribute looks good.

Comment on lines 84 to 86
if self._bbox.shape[0] and len(self._bbox.shape) == 1:
self._bbox = self._bbox.unsqueeze(0) # (n_anchors, 4)


Remove redundant code block.

The code block from lines 84 to 86 is redundant as it duplicates the logic already present earlier in the __init__ method. This can be safely removed to avoid unnecessary repetition.

-        if self._bbox.shape[0] and len(self._bbox.shape) == 1:
-            self._bbox = self._bbox.unsqueeze(0)  # (n_anchors, 4)


Comment on lines 87 to 89
if self._bbox.shape[1] and len(self._bbox.shape) == 2:
self._bbox = self._bbox.unsqueeze(0) # (1, n_anchors, 4)


Remove redundant code block.

The code block from lines 87 to 89 is redundant as it duplicates the logic already present earlier in the __init__ method. This can be safely removed to avoid unnecessary repetition.

-        if self._bbox.shape[1] and len(self._bbox.shape) == 2:
-            self._bbox = self._bbox.unsqueeze(0)  # (1, n_anchors, 4)


Comment on lines 90 to 92
if centroid is not None:
self._centroid = centroid


Remove redundant code block.

The code block from lines 90 to 92 is redundant as it duplicates the logic already present earlier in the __init__ method. This can be safely removed to avoid unnecessary repetition.

-        if centroid is not None:
-            self._centroid = centroid


Comment on lines 93 to 96
elif self.bbox.shape[1]:
y1, x1, y2, x2 = self.bbox.squeeze(dim=0).nanmean(dim=0)
self._centroid = {"centroid": np.array([(x1 + x2) / 2, (y1 + y2) / 2])}


Remove redundant code block.

The code block from lines 93 to 96 is redundant as it duplicates the logic already present earlier in the __init__ method. This can be safely removed to avoid unnecessary repetition.

-        elif self.bbox.shape[1]:
-            y1, x1, y2, x2 = self.bbox.squeeze(dim=0).nanmean(dim=0)
-            self._centroid = {"centroid": np.array([(x1 + x2) / 2, (y1 + y2) / 2])}


Comment on lines 97 to 99
else:
self._centroid = {}


Remove redundant code block.

The code block from lines 97 to 99 is redundant as it duplicates the logic already present earlier in the __init__ method. This can be safely removed to avoid unnecessary repetition.

-        else:
-            self._centroid = {}


Comment on lines 180 to 182
if len(img.shape) == 2:
img = img.expand_dims(0)
h, w, c = img.shape

Use np.expand_dims instead of expand_dims to avoid potential errors.

- img = img.expand_dims(0)
+ img = np.expand_dims(img, 0)

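A side note on the suggestion above: for an `(h, w)` grayscale frame, expanding axis 0 yields shape `(1, h, w)`, so the following `h, w, c = img.shape` unpack would mislabel the dimensions. Expanding the last axis instead preserves the channel-last layout that unpack expects. A minimal check:

```python
import numpy as np

img = np.zeros((480, 640))           # grayscale frame, shape (h, w)

img_hwc = np.expand_dims(img, -1)    # -> (480, 640, 1), channel-last
h, w, c = img_hwc.shape
assert (h, w, c) == (480, 640, 1)

img_chw = np.expand_dims(img, 0)        # -> (1, 480, 640); unpacking this
assert img_chw.shape == (1, 480, 640)   # as (h, w, c) mislabels the dims
```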

Comment on lines 335 to 366
def test_multianchor_embedding():
frames = 32
objects = 10
d_model = 256
n_anchors = 10
features = 256

N = frames * objects

boxes = torch.rand(size=(N, n_anchors, 4))
times = torch.rand(size=(N,))

fixed_emb = Embedding("pos", "fixed", features=features, n_mlp_layers=3)
learned_emb = Embedding(
"pos", "learned", features=features, n_mlp_layers=3, n_points=n_anchors
)
assert not isinstance(fixed_emb.mlp, torch.nn.Identity)
assert not isinstance(learned_emb.mlp, torch.nn.Identity)

emb = fixed_emb(boxes)
assert emb.size() == (N, features)

emb = learned_emb(boxes)
assert emb.size() == (N, features)

fixed_emb = Embedding("pos", "fixed", features=features, n_mlp_layers=0)
learned_emb = Embedding("pos", "learned", features=features, n_mlp_layers=0)
with pytest.raises(RuntimeError):
_ = fixed_emb(boxes)
with pytest.raises(RuntimeError):
_ = learned_emb(boxes)


Consolidate the redefinitions of the test_multianchor_embedding function to avoid redundancy.

- def test_multianchor_embedding():
-     frames = 32
-     objects = 10
-     d_model = 256
-     n_anchors = 10
-     features = 256
- 
-     N = frames * objects
- 
-     boxes = torch.rand(size=(N, n_anchors, 4))
-     times = torch.rand(size=(N,))
- 
-     fixed_emb = Embedding("pos", "fixed", features=features, n_mlp_layers=3)
-     learned_emb = Embedding(
-         "pos", "learned", features=features, n_mlp_layers=3, n_points=n_anchors
-     )
-     assert not isinstance(fixed_emb.mlp, torch.nn.Identity)
-     assert not isinstance(learned_emb.mlp, torch.nn.Identity)
- 
-     emb = fixed_emb(boxes)
-     assert emb.size() == (N, features)
- 
-     emb = learned_emb(boxes)
-     assert emb.size() == (N, features)
- 
-     fixed_emb = Embedding("pos", "fixed", features=features, n_mlp_layers=0)
-     learned_emb = Embedding("pos", "learned", features=features, n_mlp_layers=0)
-     with pytest.raises(RuntimeError):
-         _ = fixed_emb(boxes)
-     with pytest.raises(RuntimeError):
-         _ = learned_emb(boxes)


Comment on lines 368 to 409
def test_multianchor_embedding():
frames = 32
objects = 10
d_model = 256
n_anchors = 10
features = 256

N = frames * objects

boxes = torch.rand(size=(N, n_anchors, 4))
times = torch.rand(size=(N,))

fixed_emb = Embedding(
"pos",
"fixed",
features=features,
n_points=n_anchors,
mlp_cfg={"num_layers": 3, "hidden_dim": 2 * d_model},
)
learned_emb = Embedding(
"pos",
"learned",
features=features,
n_points=n_anchors,
mlp_cfg={"num_layers": 3, "hidden_dim": 2 * d_model},
)
assert not isinstance(fixed_emb.mlp, torch.nn.Identity)
assert not isinstance(learned_emb.mlp, torch.nn.Identity)

emb = fixed_emb(boxes)
assert emb.size() == (N, features)

emb = learned_emb(boxes)
assert emb.size() == (N, features)

fixed_emb = Embedding("pos", "fixed", features=features)
learned_emb = Embedding("pos", "learned", features=features)
with pytest.raises(RuntimeError):
_ = fixed_emb(boxes)
with pytest.raises(RuntimeError):
_ = learned_emb(boxes)


Remove the redundant redefinition of the test_multianchor_embedding function.

- def test_multianchor_embedding():
-     frames = 32
-     objects = 10
-     d_model = 256
-     n_anchors = 10
-     features = 256
- 
-     N = frames * objects
- 
-     boxes = torch.rand(size=(N, n_anchors, 4))
-     times = torch.rand(size=(N,))
- 
-     fixed_emb = Embedding(
-         "pos",
-         "fixed",
-         features=features,
-         n_points=n_anchors,
-         mlp_cfg={"num_layers": 3, "hidden_dim": 2 * d_model},
-     )
-     learned_emb = Embedding(
-         "pos",
-         "learned",
-         features=features,
-         n_points=n_anchors,
-         mlp_cfg={"num_layers": 3, "hidden_dim": 2 * d_model},
-     )
-     assert not isinstance(fixed_emb.mlp, torch.nn.Identity)
-     assert not isinstance(learned_emb.mlp, torch.nn.Identity)
- 
-     emb = fixed_emb(boxes)
-     assert emb.size() == (N, features)
- 
-     emb = learned_emb(boxes)
-     assert emb.size() == (N, features)
- 
-     fixed_emb = Embedding("pos", "fixed", features=features)
-     learned_emb = Embedding("pos", "learned", features=features)
-     with pytest.raises(RuntimeError):
-         _ = fixed_emb(boxes)
-     with pytest.raises(RuntimeError):
-         _ = learned_emb(boxes)


N = frames * objects

boxes = torch.rand(size=(N, n_anchors, 4))
times = torch.rand(size=(N,))

Remove the unused times variable to clean up the code.

- times = torch.rand(size=(N,))

Also applies to: 378-378



@@ -202,7 +219,7 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict
poses.append(
dict(
zip(
[n.name for n in instance.skeleton.nodes],
[n.name.lower() for n in instance.skeleton.nodes],

Why .lower() here? I think this one should probably be case sensitive TBH since it's a config-specifiable thing and we want to be precise.

Comment on lines 304 to 314
anchors_to_drop = np.random.permutation(anchors)
anchor_dropout_p = np.random.uniform(size=len(anchors_to_drop))
dropped_anchor_inds = np.where(
anchor_dropout_p < self.node_dropout["p"]
)
anchor_dropout_p = anchor_dropout_p[dropped_anchor_inds]
n_anchors_to_drop = min(self.node_dropout["n"], len(anchor_dropout_p))
dropped_anchor_inds = np.argpartition(
anchor_dropout_p, -n_anchors_to_drop
)[-n_anchors_to_drop:]
dropped_anchors = anchors_to_drop[dropped_anchor_inds]

This is all a bit wild and should be in its own, very heavily tested function.

There are a lot of vectorized conditionals going on and I'm not super clear on the edge case behaviors.
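For illustration, a self-contained version of the selection logic with its edge cases pinned down by assertions might look like the sketch below. Names and structure are hypothetical: this is a simplified reimplementation of the behavior described in this thread, not the PR's code.

```python
import numpy as np

def drop_nodes(nodes: list, p: float, n: int) -> list:
    """Pick up to `n` nodes to drop, where each node independently
    passes a Bernoulli(p) trial before being eligible for dropping."""
    if n == 0 or p == 0 or len(nodes) == 0:
        return []
    permuted = np.random.permutation(nodes)
    # Keep only nodes whose uniform draw falls below the drop probability.
    passed = permuted[np.random.uniform(size=len(permuted)) < p]
    return list(passed[:n])  # cap the number of dropped nodes at n

# Edge cases worth covering in a dedicated test:
# p=0 or n=0 drops nothing; p=1 with n >= len(nodes) drops everything;
# the cap n is respected; an empty skeleton never crashes.
```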

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 3

Out of diff range and nitpick comments (3)
tests/test_models.py (2)

Line range hint 378-378: Remove the unused times variable to clean up the code.

-    times = torch.rand(size=(N,))

Line range hint 424-424: The local variable img_shape is assigned but never used. Consider removing it to clean up the code.

-    img_shape = (1, 50, 50)
biogtr/datasets/data_utils.py (1)

Line range hint 383-383: Avoid using bare except clauses. Replace it with a more specific exception to make debugging easier.

-            except Exception as e:
+            except (ValueError, TypeError) as e:
Review Details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits Files that changed from the base of the PR and between af67669 and 6b9f070.
Files selected for processing (5)
  • biogtr/datasets/base_dataset.py (1 hunks)
  • biogtr/datasets/data_utils.py (1 hunks)
  • biogtr/datasets/sleap_dataset.py (10 hunks)
  • tests/test_datasets.py (2 hunks)
  • tests/test_models.py (1 hunks)
Files skipped from review as they are similar to previous changes (1)
  • biogtr/datasets/base_dataset.py
Additional Context Used
Ruff (4)
biogtr/datasets/data_utils.py (1)

383-383: Do not use bare except

biogtr/datasets/sleap_dataset.py (1)

9-9: warnings imported but unused

tests/test_datasets.py (1)

448-449: SyntaxError: Expected ':', found newline

tests/test_models.py (1)

424-424: Local variable img_shape is assigned to but never used

Additional comments not posted (8)
tests/test_datasets.py (2)

4-4: The import statement for NodeDropout looks good.


445-446: The list of nodes is correctly defined and looks good.

biogtr/datasets/sleap_dataset.py (4)

30-30: Review the necessity of using .lower() for the handle_missing parameter. It might be better to keep it case-sensitive for precision.


52-55: The docstring update to include the handle_missing parameter looks good.


87-87: The addition of the handle_missing attribute in the __init__ method looks good.


304-308: The node dropout logic is correctly implemented and looks good.

biogtr/datasets/data_utils.py (2)

183-191: The docstring for the __call__ method is clear and well-documented. Ensure that the forward method is consistent with this documentation.


195-203: The docstring for the forward method is clear and well-documented. Ensure that the logic for dropping nodes is correctly implemented.

Comment on lines 172 to 179
class NodeDropout:
def __init__(self, p: float, n: int) -> None:
"""Initialize Node Dropout Augmentation

Args
p: the probability with which to drop the nodes
n: the maximum number of nodes to drop
"""

The docstring for the __init__ method is missing a colon after "Args". Add the colon for consistency and clarity.

-        Args
+        Args:


Comment on lines 447 to 449
for p in [0, 1]:
for n in [0, 1, len(nodes)]


Fix the syntax error by adding a colon after the for statement.

-        for n in [0, 1, len(nodes)]
+        for n in [0, 1, len(nodes)]:


coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

Out of diff range and nitpick comments (1)
tests/test_datasets.py (1)

Line range hint 8-8: Remove unused import get_device.

- from biogtr.models.model_utils import get_device
Review Details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits Files that changed from the base of the PR and between 6b9f070 and 1622722.
Files selected for processing (1)
  • tests/test_datasets.py (2 hunks)
Additional Context Used
Ruff (1)
tests/test_datasets.py (1)

8-8: biogtr.models.model_utils.get_device imported but unused

@aaprasad aaprasad requested a review from talmo May 15, 2024 17:42
@aaprasad aaprasad merged commit f0c6f3c into main May 15, 2024
3 checks passed
@aaprasad aaprasad deleted the aadi/instance-node-dropout branch May 15, 2024 17:57
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 8

Out of diff range and nitpick comments (3)
biogtr/datasets/sleap_dataset.py (1)

Line range hint 9-9: Remove unused import.

The warnings module is imported but not used in this file. Removing it will clean up the code.

- import warnings
biogtr/datasets/data_utils.py (2)

172-176: Add a more detailed class-level docstring.

Consider adding a more detailed class-level docstring to explain the purpose and usage of the NodeDropout class. This will help future developers understand the context and functionality of the class.


Line range hint 390-390: Avoid using bare except.

Using a bare except can catch unexpected exceptions and make debugging difficult. Specify the exception type to catch.

- except Exception as e:
+ except (SpecificExceptionType) as e:

Replace SpecificExceptionType with the appropriate exception type you expect.

Review Details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits Files that changed from the base of the PR and between 1622722 and 4e1021c.
Files selected for processing (3)
  • biogtr/datasets/data_utils.py (1 hunks)
  • biogtr/datasets/sleap_dataset.py (10 hunks)
  • tests/test_datasets.py (2 hunks)
Additional Context Used
Ruff (3)
biogtr/datasets/data_utils.py (1)

390-390: Do not use bare except

biogtr/datasets/sleap_dataset.py (1)

9-9: warnings imported but unused

tests/test_datasets.py (1)

8-8: biogtr.models.model_utils.get_device imported but unused

Comment on lines +178 to +186
def __init__(self, p: float, n: int) -> None:
"""Initialize Node Dropout Augmentation.

Args:
p: the probability with which to drop the nodes
n: the maximum number of nodes to drop
"""
self.n = n
self.p = p

Add type hints for p and n parameters.

The __init__ method is missing type hints for the p and n parameters. Adding type hints improves code readability and helps with static analysis.

- def __init__(self, p: float, n: int) -> None:
+ def __init__(self, p: float, n: int) -> None:


Comment on lines +188 to +197
def __call__(self, nodes: list[str]) -> list[str]:
"""Wrap `drop_nodes` to enable class call.

Args:
nodes: A list of available node names to drop.

Returns:
dropped_nodes: A list of up to `self.n` nodes to drop.
"""
return self.forward(nodes)

Consider renaming forward method to drop_nodes.

The forward method name is not very descriptive. Consider renaming it to drop_nodes to better reflect its functionality.

- def forward(self, nodes: list[str]) -> list[str]:
+ def drop_nodes(self, nodes: list[str]) -> list[str]:


Comment on lines +199 to +225

    def forward(self, nodes: list[str]) -> list[str]:
        """Drop up to `n` random nodes with probability p.

        Args:
            nodes: A list of available node names to drop.

        Returns:
            dropped_nodes: A list of up to `self.n` nodes to drop.
        """
        if self.n == 0 or self.p == 0:
            return []

        nodes_to_drop = np.random.permutation(nodes)
        node_dropout_p = np.random.uniform(size=len(nodes_to_drop))

        dropped_node_inds = np.where(node_dropout_p < self.p)
        node_dropout_p = node_dropout_p[dropped_node_inds]

        n_nodes_to_drop = min(self.n, len(node_dropout_p))

        dropped_node_inds = np.argpartition(node_dropout_p, -n_nodes_to_drop)[
            -n_nodes_to_drop:
        ]

        dropped_nodes = nodes_to_drop[dropped_node_inds]

        return dropped_nodes

Optimize the node dropout logic.

The node dropout logic can be simplified and optimized. Here is a suggested refactor:

- def forward(self, nodes: list[str]) -> list[str]:
-     if self.n == 0 or self.p == 0:
-         return []
-
-     nodes_to_drop = np.random.permutation(nodes)
-     node_dropout_p = np.random.uniform(size=len(nodes_to_drop))
-
-     dropped_node_inds = np.where(node_dropout_p < self.p)
-     node_dropout_p = node_dropout_p[dropped_node_inds]
-
-     n_nodes_to_drop = min(self.n, len(node_dropout_p))
-
-     dropped_node_inds = np.argpartition(node_dropout_p, -n_nodes_to_drop)[
-         -n_nodes_to_drop:
-     ]
-
-     dropped_nodes = nodes_to_drop[dropped_node_inds]
-
-     return dropped_nodes
+ def drop_nodes(self, nodes: list[str]) -> list[str]:
+     if self.n == 0 or self.p == 0:
+         return []
+
+     nodes_to_drop = np.random.choice(nodes, size=self.n, replace=False, p=[self.p]*len(nodes))
+
+     return nodes_to_drop

This refactor simplifies the logic by using np.random.choice to directly select the nodes to drop.

Suggested change

    def drop_nodes(self, nodes: list[str]) -> list[str]:
        if self.n == 0 or self.p == 0:
            return []

        nodes_to_drop = np.random.choice(nodes, size=self.n, replace=False, p=[self.p]*len(nodes))

        return nodes_to_drop
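For quick experimentation outside the class, the selection logic can be exercised as a standalone function. This is a sketch adapted from the snippet above (assuming only NumPy), not the PR's code verbatim: the mapping between the filtered probability array and the permuted node list is made explicit, since the indices returned by `np.argpartition` refer to the filtered array, not the original one.

```python
import numpy as np

def drop_nodes(nodes, p, n):
    """Drop up to `n` random nodes, each considered a candidate with probability `p`."""
    if n == 0 or p == 0:
        return []

    nodes_to_drop = np.random.permutation(nodes)
    node_dropout_p = np.random.uniform(size=len(nodes_to_drop))

    # Indices (into nodes_to_drop) whose draw falls below the dropout probability.
    candidate_inds = np.where(node_dropout_p < p)[0]
    candidate_p = node_dropout_p[candidate_inds]

    n_nodes_to_drop = min(n, len(candidate_p))
    if n_nodes_to_drop == 0:
        return []

    # Take the candidates with the largest draws, mapped back to node names.
    top = np.argpartition(candidate_p, -n_nodes_to_drop)[-n_nodes_to_drop:]
    return list(nodes_to_drop[candidate_inds[top]])
```

With p=1 every node becomes a candidate, so exactly min(n, len(nodes)) names come back; with p=0 or n=0 the result is always empty.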

@@ -1,7 +1,7 @@
 """Test dataset logic."""

 from biogtr.datasets.base_dataset import BaseDataset
-from biogtr.datasets.data_utils import get_max_padding
+from biogtr.datasets.data_utils import get_max_padding, NodeDropout

Remove unused import.

The get_device function from biogtr.models.model_utils is imported but not used in this file. Removing it will clean up the code.

- from biogtr.models.model_utils import get_device

Suggested change

    from biogtr.datasets.data_utils import get_max_padding, NodeDropout
Comment on lines +445 to +459

    nodes = ["head", "nose", "tti", "tail"]

    for p in [0, 1]:
        for n in [0, 1, len(nodes)]:

            node_dropout = NodeDropout(p=p, n=n)
            dropped_nodes = node_dropout(nodes)

            if p == 0:
                assert len(dropped_nodes) == 0
            else:
                assert (
                    len(dropped_nodes) == n
                ), f"p={node_dropout.p}, n={node_dropout.n},n_dropped={len(dropped_nodes)}"

Fix the syntax error by adding a colon after the for statement.

There is a missing colon after the for statement on line 448, which will cause a syntax error.

-        for n in [0, 1, len(nodes)]
+        for n in [0, 1, len(nodes)]:

Suggested change

    nodes = ["head", "nose", "tti", "tail"]

    for p in [0, 1]:
        for n in [0, 1, len(nodes)]:
            node_dropout = NodeDropout(p=p, n=n)
            dropped_nodes = node_dropout(nodes)

            if p == 0:
                assert len(dropped_nodes) == 0
            else:
                assert (
                    len(dropped_nodes) == n
                ), f"p={node_dropout.p}, n={node_dropout.n},n_dropped={len(dropped_nodes)}"

@@ -27,6 +27,7 @@ def __init__(
     chunk: bool = True,
     clip_length: int = 500,
     mode: str = "train",
+    handle_missing: str = "centroid",

Ensure case sensitivity for handle_missing.

The handle_missing parameter is converted to lowercase, which might not be necessary and could lead to unexpected behavior if case sensitivity is required.

- self.handle_missing = handle_missing.lower()
+ self.handle_missing = handle_missing

Suggested change

    self.handle_missing = handle_missing
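To make the trade-off concrete, here is a minimal sketch (function name and whitelist are hypothetical, not from the PR) of normalizing the parameter case-insensitively while still rejecting unknown values:

```python
def normalize_handle_missing(handle_missing: str) -> str:
    """Lowercase and validate a `handle_missing` argument.

    Storing the lowercase form keeps later comparisons such as
    `self.handle_missing == "drop"` working for inputs like "Drop" or
    "CENTROID", while the explicit whitelist surfaces typos instead of
    letting them silently fall through.
    """
    valid = {"centroid", "drop"}  # modes visible in this diff; the PR may support others
    normalized = handle_missing.lower()
    if normalized not in valid:
        raise ValueError(
            f"handle_missing must be one of {sorted(valid)}, got {handle_missing!r}"
        )
    return normalized
```

This keeps the ergonomic benefit of `.lower()` without the silent-mismatch risk the review raises.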

Comment on lines +358 to +359
if self.handle_missing == "drop" and boxes.isnan().any():
continue
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a log message when instances are dropped due to missing anchors.

Adding a log message when instances are dropped due to missing anchors can help with debugging and understanding the dataset processing.

- if self.handle_missing == "drop" and boxes.isnan().any():
+ if self.handle_missing == "drop" and boxes.isnan().any():
+     if self.verbose:
+         print(f"Dropping instance {j} due to missing anchors.")

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
if self.handle_missing == "drop" and boxes.isnan().any():
continue
if self.handle_missing == "drop" and boxes.isnan().any():
if self.verbose:
print(f"Dropping instance {j} due to missing anchors.")
continue

Comment on lines +199 to +211
n_instances_dropped = 0

gt_instances = lf.instances
if self.mode == "train":
np.random.shuffle(gt_instances)

for instance in gt_instances:
if (
np.random.uniform() < self.instance_dropout["p"]
and n_instances_dropped < self.instance_dropout["n"]
):
n_instances_dropped += 1
continue
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Combine conditions for instance dropout.

Combining the conditions for instance dropout can improve readability and maintainability.

if np.random.uniform() < self.instance_dropout["p"] and n_instances_dropped < self.instance_dropout["n"]:
    n_instances_dropped += 1
    continue
