
Multi anchor crop #41

Merged
aaprasad merged 34 commits into main from aadi/multi-anchor-crop on May 14, 2024
Conversation

aaprasad (Contributor) commented May 7, 2024

As per #32, the major failure mode we're seeing while tracking animals is that nodes disappear, shifting the visual features and causing an identity switch. We've tried randomly selecting an available node, using the pose centroid as the crop center, and reducing the crop size, but none of these worked. Instead, here we add functionality to stack crops from arbitrarily many nodes along the channel dimension and use that stack as input to the model (see the sketch after the list below).

  • Added timm as the backend for the visual encoder to easily enable arbitrary channel counts
  • Generalized the positional encoding to work for an arbitrary number of boxes
  • Added an MLP component to project the embedding back to the correct dimensionality
  • Removed mutable defaults wherever we saw them
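A minimal sketch of the channel-stacking idea (illustrative only; crop extraction and anchor selection are simplified, and the shapes here are assumptions, not the exact implementation in this PR):

```python
import torch

# Suppose each instance has crops centered on 3 anchor nodes, each an RGB patch.
n_anchors, crop_size = 3, 128
crops = [torch.rand(3, crop_size, crop_size) for _ in range(n_anchors)]

# Stack along the channel dimension: (n_anchors * 3, H, W).
stacked = torch.cat(crops, dim=0)
assert stacked.shape == (n_anchors * 3, crop_size, crop_size)

# A backbone created with in_chans = n_anchors * 3 can consume this directly.
```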

Summary by CodeRabbit

  • New Features

    • Improved positional embedding handling for multiple anchors.
    • Enhanced visual encoder initialization with more flexible parameters.
  • Bug Fixes

    • Added exception handling for embedding dimension mismatches.
    • Corrected input box shape handling in embeddings.
  • Enhancements

    • Switched visual encoder's backend to timm for better model support.
    • Updated visual encoder to use torch.nn.LazyLinear for output layer initialization.
  • Tests

    • Expanded test coverage for visual encoder and embeddings with new scenarios and parameters.
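For orientation, here is a minimal sketch of how a timm backbone plus a lazily-initialized output projection can be wired together (the class name and defaults are assumptions for illustration, not the exact VisualEncoder in this PR):

```python
import timm
import torch

class TinyVisualEncoder(torch.nn.Module):
    def __init__(self, model_name: str = "resnet18", d_model: int = 512, in_chans: int = 9):
        super().__init__()
        # num_classes=0 strips the classification head so the backbone returns pooled features.
        self.feature_extractor = timm.create_model(
            model_name, in_chans=in_chans, pretrained=False, num_classes=0
        )
        # LazyLinear infers its input dimension on the first forward pass,
        # so the output layer works for any backbone feature size.
        self.out_layer = torch.nn.LazyLinear(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.out_layer(self.feature_extractor(x))

feats = TinyVisualEncoder()(torch.rand(2, 9, 128, 128))  # e.g. 3 anchors x 3 channels
assert feats.shape == (2, 512)
```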

coderabbitai bot (Contributor) commented May 7, 2024

Walkthrough

The updates enhance the Embedding and VisualEncoder classes in the biogtr module by adding new parameters, improved exception handling, and support for multiple anchors and MLP transformations. The VisualEncoder now uses timm for model creation, eliminating the need for torchvision, and includes a new method to compute embedding dimensions. Corresponding tests have been updated to reflect these changes.

Changes

  • biogtr/models/embedding.py: Added parameters to the Embedding class, improved exception handling, support for multiple anchors and MLP.
  • biogtr/models/visual_encoder.py: Replaced torchvision with timm, added new parameters, removed the select_feature_extractor method, added an encoder_dim method.
  • tests/test_models.py: Updated imports, added new tests for VisualEncoder, modified existing tests for Embedding.

In code's vast realm, new paths we trod,
To timm we venture, with MLPs we nod.
Anchors multiply, dimensions align,
Models now robust, in brilliance they shine.
🐇✨



aaprasad marked this pull request as ready for review May 8, 2024 17:33
aaprasad requested a review from talmo May 8, 2024 17:38
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 20

Out of diff range and nitpick comments (9)
tests/test_training.py (2)

Line range hint 9-10: Remove unused imports to clean up the code and avoid confusion.

- from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer
- from omegaconf.DictConfig import DictConfig

Line range hint 46-46: The local variable feats is assigned but never used in the test function. Consider removing it if it's not necessary for the test logic.

- feats = 128
tests/test_datasets.py (1)

Line range hint 52-107: Add assertions to check the values of the crops in addition to their shapes for more comprehensive testing.

+ assert torch.allclose(expected_crops, instances[0].get_crops()), "Crops do not match expected values."
tests/test_models.py (2)

Line range hint 363-363: Remove the unused variable img_shape.

- img_shape = (1, 100, 100)

This change removes the unused variable img_shape, which is not needed in the function and could potentially lead to confusion.


Line range hint 435-435: Remove the unused variable cfg.

- cfg = {"resnet18", "ResNet18_Weights.DEFAULT"}

This change removes the unused variable cfg, which is not needed in the function and could potentially lead to confusion.

biogtr/datasets/sleap_dataset.py (1)

Line range hint 9-9: Remove unused import of warnings.

- import warnings
biogtr/data_structures.py (3)

335-343: Clarify the return type of anchor property.

The anchor property's docstring could be clearer about its return type. It currently suggests a list of strings but ends with a return of an empty string if no centroids are defined. This could lead to type inconsistencies.
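One way to make the type consistent (a hypothetical sketch assuming a centroids dict as described elsewhere in this PR, not the actual implementation):

```python
from typing import List

@property
def anchor(self) -> List[str]:
    """Names of the anchor nodes; empty list when no centroids are defined."""
    if self.centroids:
        return list(self.centroids.keys())
    return []
```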


Line range hint 546-586: Ensure consistent handling of default values in Frame constructor.

The constructor of Frame class does not consistently handle default values for parameters like img_shape and instances. This could lead to unexpected behavior or errors when these values are accessed later in the code.


652-652: Document the track_lookup parameter in to_slp method.

The track_lookup parameter in the to_slp method of the Frame class is not documented in the method's docstring. Adding this documentation would improve code readability and maintainability.

Review Details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits: files that changed from the base of the PR and between 1b7e817 and 4578eeb.
Files selected for processing (22)
  • .gitignore (1 hunks)
  • biogtr/data_structures.py (8 hunks)
  • biogtr/datasets/sleap_dataset.py (5 hunks)
  • biogtr/inference/boxes.py (6 hunks)
  • biogtr/models/__init__.py (1 hunks)
  • biogtr/models/attention_head.py (2 hunks)
  • biogtr/models/embedding.py (9 hunks)
  • biogtr/models/global_tracking_transformer.py (3 hunks)
  • biogtr/models/mlp.py (1 hunks)
  • biogtr/models/model_utils.py (1 hunks)
  • biogtr/models/transformer.py (2 hunks)
  • biogtr/models/visual_encoder.py (3 hunks)
  • biogtr/training/configs/base.yaml (1 hunks)
  • biogtr/training/losses.py (1 hunks)
  • environment.yml (1 hunks)
  • environment_cpu.yml (1 hunks)
  • environment_osx-arm64.yml (1 hunks)
  • tests/configs/base.yaml (5 hunks)
  • tests/test_data_structures.py (1 hunks)
  • tests/test_datasets.py (2 hunks)
  • tests/test_models.py (6 hunks)
  • tests/test_training.py (2 hunks)
Files skipped from review due to trivial changes (3)
  • .gitignore
  • environment.yml
  • environment_cpu.yml
Additional Context Used
Ruff (14)
biogtr/datasets/sleap_dataset.py (1)

9-9: warnings imported but unused

biogtr/models/__init__.py (4)

3-3: .attention_head.ATTWeightHead imported but unused; consider removing, adding to __all__, or using a redundant alias


5-5: .embedding.Embedding imported but unused; consider removing, adding to __all__, or using a redundant alias


6-6: .transformer.Transformer imported but unused; consider removing, adding to __all__, or using a redundant alias


7-7: .visual_encoder.VisualEncoder imported but unused; consider removing, adding to __all__, or using a redundant alias

biogtr/models/visual_encoder.py (1)

3-3: typing.Tuple imported but unused

tests/test_datasets.py (1)

8-8: biogtr.models.model_utils.get_device imported but unused

tests/test_models.py (4)

243-243: Local variable d_model is assigned to but never used


250-250: Local variable times is assigned to but never used


363-363: Local variable img_shape is assigned to but never used


435-435: Local variable cfg is assigned to but never used

tests/test_training.py (3)

9-9: biogtr.models.global_tracking_transformer.GlobalTrackingTransformer imported but unused


10-10: omegaconf.DictConfig imported but unused


46-46: Local variable feats is assigned to but never used

Additional comments not posted (15)
environment_osx-arm64.yml (1)

20-21: Ensure the addition of timm aligns with the project's dependency management strategy. If timm is a crucial dependency for the new features, this change is justified.

biogtr/models/attention_head.py (1)

27-28: The use of the MLP class for projections in the attention mechanism is a good design choice, promoting modularity and reuse of the MLP functionality.

biogtr/models/mlp.py (1)

46-60: The forward method of the MLP class is correctly implemented, ensuring that each layer is followed by a ReLU activation and dropout, except for the last layer. This is a standard practice for MLPs.
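That pattern in general form (a generic sketch of the standard MLP forward, not the biogtr implementation verbatim):

```python
import torch.nn as nn
import torch.nn.functional as F

class SketchMLP(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
                 num_layers: int, dropout: float = 0.0):
        super().__init__()
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(a, b) for a, b in zip(dims[:-1], dims[1:])
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < len(self.layers) - 1:  # no activation/dropout after the last layer
                x = self.dropout(F.relu(x))
        return x
```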

biogtr/models/visual_encoder.py (1)

20-53: The VisualEncoder class has been effectively updated to utilize the timm library for model creation. This aligns with the PR's objectives to enhance model flexibility and performance.

tests/configs/base.yaml (2)

Line range hint 3-22: The updates to the model configuration in the YAML file are consistent with the changes in the model's codebase, ensuring that the configuration aligns with the new functionalities.


Line range hint 68-86: The dataset configuration has been updated to include new parameters like anchors, which is crucial for the new multi-anchor cropping feature. This change is necessary and well-implemented.

biogtr/training/configs/base.yaml (1)

17-20: The model configuration updates in this YAML file are consistent with the intended changes to the model's architecture and functionalities. This ensures that the configuration is in sync with the code changes.

biogtr/models/global_tracking_transformer.py (2)

Line range hint 16-64: Ensure validation for encoder_cfg to prevent runtime errors.


61-61: Verify that Frame objects have has_instances() and has_features() methods.

biogtr/inference/boxes.py (4)

36-36: LGTM! Good use of assertion to ensure tensor dimensionality.


66-66: LGTM! Correct handling of 3D tensors in area calculation.


80-83: LGTM! Correct implementation of clipping with 3D tensor support.


100-101: LGTM! Correct identification of non-empty boxes with 3D tensor support.

biogtr/training/losses.py (1)

55-55: Ensure NaN values are expected and handled correctly throughout the model's pipeline.

Verification successful

Handling of NaN values is evident across various parts of the model's pipeline, including data structures, visualization, training losses, inference metrics, and dataset processing. This indicates a comprehensive approach to managing NaN values, which aligns with the use of torch.nanmean in the AssoLoss class to handle NaN values effectively.

Here are the specific files and contexts where NaN handling is implemented:

  • Data Structures: Usage of nanmean for bounding box calculations.
  • Visualization: Interpolation and masking strategies to handle NaN values in visualization data.
  • Training Losses: Direct use of torch.nanmean in the forward method of AssoLoss.
  • Inference Metrics: Initialization of evaluation matrices with NaN values and handling during metric calculations.
  • Datasets: Various checks and handling mechanisms for NaN values in dataset processing functions.

This comprehensive handling suggests that the model's pipeline is designed to expect and correctly manage NaN values, supporting the changes made in the AssoLoss class.
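For context, torch.nanmean ignores NaN entries where a plain mean would propagate them (a small self-contained example):

```python
import torch

assoc = torch.tensor([1.0, float("nan"), 3.0])
print(assoc.mean())     # tensor(nan)  -> NaN propagates
print(assoc.nanmean())  # tensor(2.)   -> NaN entries are ignored
```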

Scripts Executed

The following scripts were executed for the analysis:


Script:

#!/bin/bash
# Description: Check for handling of NaN values in the model pipeline.

# Test: Search for NaN handling in the model pipeline. Expect: Handling logic for NaN values.
rg --type python 'nan'

Length of output: 53



Script:

#!/bin/bash
# Description: Check for handling of NaN values in the model pipeline.

# Test: Search for NaN handling in the model pipeline. Expect: Handling logic for NaN values.
rg --type py 'nan'

Length of output: 2247

biogtr/models/transformer.py (1)

174-174: Proper handling of NaN values in pred_box.

def test_multianchor_embedding():
    frames = 32
    objects = 10
    d_model = 256

Remove the unused variable d_model.

- d_model = 256

This change removes the unused variable d_model, which is not needed in the function and could potentially lead to confusion.



Comment on lines 42 to 49

            anchors: One of:
                * a string indicating a single node to center crops around
                * a list of skeleton node names to be used as the center of crops
                * an int indicating the number of anchors to randomly select
            anchors: One of:
                * a string indicating a single node to center crops around
                * a list of skeleton node names to be used as the center of crops
                * an int indicating the number of anchors to randomly select

Remove duplicate documentation and code for anchors parameter.

-            anchors: One of:
-                        * a string indicating a single node to center crops around
-                        * a list of skeleton node names to be used as the center of crops
-                        * an int indicating the number of anchors to randomly select


Comment on lines +135 to +140
if not torch.nn.parameter.is_lazy(p) and p.dim() > 1:
    try:
        nn.init.xavier_uniform_(p)
    except ValueError as e:
        print(f"Failed Trying to initialize {p}")
        raise (e)

Handle exceptions more gracefully during parameter initialization.

-                except ValueError as e:
-                    print(f"Failed Trying to initialize {p}")
-                    raise (e)
+                except ValueError as e:
+                    logging.error(f"Failed to initialize parameter {p}: {e}")
+                    raise


@@ -60,7 +58,10 @@ def __init__(
"""
super().__init__()

self.visual_encoder = VisualEncoder(encoder_model, encoder_cfg, d_model)
if encoder_cfg is not None:

Consider refactoring the lambda function for better readability.

- for frame in filter(lambda f: f.has_instances() and not f.has_features(), frames):
+ def frame_filter(f):
+     return f.has_instances() and not f.has_features()
+ for frame in filter(frame_filter, frames):


Comment on lines +25 to +31
bbox[:, :, [0, 2]] /= w
bbox[:, :, [1, 3]] /= h

boxes.append(bbox)
times.append(torch.full((bbox.shape[0],), fidx))

boxes = torch.cat(boxes, dim=0) # N x 4
boxes = torch.cat(boxes, dim=0) # N, n_anchors, 4

Ensure all frames have consistent dimensions before normalization.

+ if not all(f.img_shape == frames[0].img_shape for f in frames):
+     raise ValueError("All frames must have the same dimensions.")
  _, h, w = frames[0].img_shape.flatten()


@@ -10,7 +10,7 @@ def test_instance():

gt_track_id = 0
pred_track_id = 0
bbox = torch.randn((1, 4))
bbox = torch.randn((1, 1, 4))

Refactor the test to use a testing framework for better structure and output.

+ import unittest
+ class TestInstance(unittest.TestCase):
+     def test_instance(self):
- def test_instance():


Comment on lines +125 to +128
return Boxes(self.tensor[item])
b = self.tensor[item]
assert (
b.dim() == 2
b.dim() == 3

Add more comments to explain the indexing logic for better clarity.

+ # Indexing can be done with an int, slice, or a BoolTensor.
+ # This ensures that the result is always a valid `Boxes` object.


coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

Out of diff range and nitpick comments (1)
biogtr/inference/post_processing.py (1)

57-60: Consider optimizing the tensor operations for intersection calculation.

The current implementation involves multiple tensor operations which might be optimized further. Consider using built-in PyTorch functions that could potentially handle these operations more efficiently if available.
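One common vectorized formulation of pairwise intersections (a sketch of the general pattern under the usual (x1, y1, x2, y2) box convention, not necessarily what biogtr should adopt):

```python
import torch

def pairwise_intersection(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
    """Intersection areas between all pairs of (x1, y1, x2, y2) boxes."""
    lt = torch.max(boxes1[:, None, :2], boxes2[None, :, :2])  # (N, M, 2) top-left
    rb = torch.min(boxes1[:, None, 2:], boxes2[None, :, 2:])  # (N, M, 2) bottom-right
    wh = (rb - lt).clamp(min=0)                               # clip negative overlap to zero
    return wh[..., 0] * wh[..., 1]                            # (N, M) areas
```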

Review Details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits: files that changed from the base of the PR and between 4578eeb and 68915be.
Files selected for processing (1)
  • biogtr/inference/post_processing.py (2 hunks)
Additional comments not posted (4)
biogtr/inference/post_processing.py (4)

Line range hint 25-38: Good use of assertions to ensure necessary variables are present.

The function weight_decay_time correctly uses assertions to ensure that reid_features, T, and k are not None when decay_time is greater than 0. This prevents potential runtime errors from missing data.


85-89: Robust handling of empty boxes in IoU calculation.

The function _pairwise_iou effectively handles cases where the intersection might be zero by assigning NaN and then using nanmean to compute the average while ignoring NaN values. This ensures that the function remains robust against empty input boxes.


Line range hint 100-117: Effective error handling and method flexibility in IoU weighting.

The function weight_iou provides flexibility in how the IoU is used to weight the association matrix and includes robust error handling to ensure that last_ious is provided when necessary. It also correctly raises a ValueError for unsupported methods, guiding the user towards correct usage.


Line range hint 130-152: Well-implemented logic for filtering based on maximum center distance.

The function filter_max_center_dist effectively uses tensor operations to compute distances between box centers and filters the association matrix based on these distances. It also uses assertions to ensure all necessary inputs are provided, preventing runtime errors.

aaprasad changed the title from "Aadi/multi anchor crop" to "Multi anchor crop" May 8, 2024
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 4

Out of diff range and nitpick comments (1)
tests/test_inference.py (1)

Line range hint 88-88: Remove the unused variable N to clean up the code.

- N = N_t * T
Review Details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits: files that changed from the base of the PR and between 68915be and bd806e6.
Files selected for processing (9)
  • .gitignore (1 hunks)
  • biogtr/inference/post_processing.py (3 hunks)
  • biogtr/models/attention_head.py (1 hunks)
  • biogtr/models/embedding.py (9 hunks)
  • biogtr/models/global_tracking_transformer.py (3 hunks)
  • biogtr/models/mlp.py (1 hunks)
  • biogtr/models/visual_encoder.py (3 hunks)
  • tests/test_inference.py (1 hunks)
  • tests/test_models.py (6 hunks)
Files skipped from review as they are similar to previous changes (5)
  • .gitignore
  • biogtr/inference/post_processing.py
  • biogtr/models/attention_head.py
  • biogtr/models/global_tracking_transformer.py
  • biogtr/models/mlp.py
Additional Context Used
Ruff (5)
biogtr/models/visual_encoder.py (1)

3-3: typing.Tuple imported but unused

tests/test_inference.py (2)

4-4: pytest imported but unused


88-88: Local variable N is assigned to but never used

tests/test_models.py (2)

250-250: Local variable times is assigned to but never used


373-373: Local variable img_shape is assigned to but never used

Additional comments not posted (17)
biogtr/models/visual_encoder.py (1)

Line range hint 70-89: The forward method is implemented correctly and handles grayscale images appropriately.

tests/test_inference.py (2)

46-46: The configuration for the encoder in the GlobalTrackingTransformer is correctly set up for the test.


Line range hint 88-88: The metrics testing is comprehensive and covers various scenarios effectively.

tests/test_models.py (10)

6-6: The MLP model is correctly imported and used in the tests.


7-7: The ATTWeightHead model is correctly imported and used in the tests.


52-52: The VisualEncoder is correctly set up with various channel configurations for testing.


243-243: Remove the unused variable d_model to clean up the code.


69-69: The Embedding class is tested thoroughly for various configurations to ensure exceptions are raised as expected.


118-118: The basic functionality of the Embedding class is correctly tested with various configurations.


200-200: The Embedding class is tested with different keyword arguments to ensure correct handling.


282-282: The TransformerEncoderLayer is correctly tested with and without positional embeddings.


Line range hint 373-373: The TransformerDecoderLayer is correctly tested with and without positional embeddings.


Line range hint 373-373: The full transformer logic is correctly tested and the output shapes are asserted as expected.

biogtr/models/embedding.py (4)

149-158: The forward method correctly computes the embeddings and checks the output dimensions.


Line range hint 178-215: The _sine_box_embedding method correctly computes sine positional embeddings for boxes, handling normalization and scaling appropriately.


215-215: The _sine_temp_embedding method correctly computes sine temporal embeddings, handling the computation of denominators and lookup table correctly.


264-306: The _learned_pos_embedding method correctly computes learned positional embeddings for boxes, handling the computation of weights and lookup table correctly.

Comment on lines +41 to +44
if in_chans == 1:
self.in_chans = 3
else:
self.in_chans = in_chans

Review the logic for handling in_chans. Currently, if in_chans is 1, it is forcibly set to 3, which might not be intended for all use cases.
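For context, forcing in_chans to 3 pairs with the grayscale handling noted elsewhere in this review, where single-channel inputs are repeated to match a 3-channel backbone (a sketch with assumed shapes):

```python
import torch

x = torch.rand(8, 1, 128, 128)   # batch of grayscale crops
if x.shape[1] == 1:
    x = x.repeat(1, 3, 1, 1)     # replicate the channel to match a 3-channel backbone
assert x.shape == (8, 3, 128, 128)
```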

N = frames * objects

boxes = torch.rand(size=(N, n_anchors, 4))
times = torch.rand(size=(N,))

Remove the unused variable times to clean up the code.

- times = torch.rand(size=(N,))


Comment on lines 76 to 91

if self.emb_type == "pos" and mlp_cfg is not None:
    if self.mode == "fixed":
        self.mlp = MLP(
            input_dim=n_points * self.features,
            output_dim=self.features,
            **mlp_cfg,
        )
    else:
        in_dim = (self.features // (4 * n_points)) * (4 * n_points)
        self.mlp = MLP(
            input_dim=in_dim,
            output_dim=self.features,
            **mlp_cfg,
        )
else:
    self.mlp = torch.nn.Identity()

Review the initialization logic for mlp_cfg. Ensure that it is correctly initialized based on the mode and type of embedding.
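A quick worked example of those two input dimensions (using illustrative values, not values from the PR): with self.features = 256 and n_points = 3, the fixed mode gives input_dim = 3 * 256 = 768, while the learned mode gives in_dim = (256 // (4 * 3)) * (4 * 3) = 21 * 12 = 252, i.e. the largest multiple of 4 * n_points that fits in the feature dimension.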

codecov bot commented May 13, 2024

Codecov Report

Attention: Patch coverage is 81.42077%, with 34 lines in your changes missing coverage. Please review.

Project coverage is 69.81%. Comparing base (041d0a4) to head (68cf717).
Report is 9 commits behind head on main.

❗ Current head 68cf717 differs from pull request most recent head c772453. Consider uploading reports for the commit c772453 to get more accurate results

Files                              Patch %   Lines
biogtr/datasets/sleap_dataset.py   72.00%    14 Missing ⚠️
biogtr/inference/boxes.py          16.66%    10 Missing ⚠️
biogtr/data_structures.py          86.11%    5 Missing ⚠️
biogtr/models/transformer.py       57.14%    3 Missing ⚠️
biogtr/models/embedding.py         96.66%    1 Missing ⚠️
biogtr/models/mlp.py               94.73%    1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main      #41      +/-   ##
==========================================
+ Coverage   69.60%   69.81%   +0.21%     
==========================================
  Files          28       29       +1     
  Lines        2385     2508     +123     
==========================================
+ Hits         1660     1751      +91     
- Misses        725      757      +32     


coderabbitai bot (Contributor) left a comment

Actionable comments posted: 4

Out of diff range and nitpick comments (2)
biogtr/models/embedding.py (1)

Line range hint 6-103: Constructor of Embedding is well-implemented, handling various configurations and initializing the embeddings correctly.

- self.mlp = MLP(...)

Ensure the MLP setup line is covered by tests to verify its functionality.
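A minimal sketch of such a test (the import path, constructor arguments, and call signature here are assumptions pieced together from snippets in this review, not the verified Embedding API):

```python
import torch
# Hypothetical import/signature; adjust to the actual Embedding API.
from biogtr.models.embedding import Embedding

def test_pos_embedding_mlp_projection():
    n, n_anchors, features = 16, 3, 256
    emb = Embedding(
        emb_type="pos", mode="fixed", features=features,
        n_points=n_anchors, mlp_cfg={"num_layers": 2, "hidden_dim": 512},
    )
    boxes = torch.rand(n, n_anchors, 4)
    out = emb(boxes)
    # The MLP should project the stacked per-anchor embeddings back to `features`.
    assert out.shape == (n, features)
```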

biogtr/datasets/sleap_dataset.py (1)

Line range hint 9-9: Remove unused import.

- import warnings
Review Details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits: files that changed from the base of the PR and between 68915be and 68cf717.
Files selected for processing (10)
  • .gitignore (1 hunks)
  • biogtr/datasets/sleap_dataset.py (5 hunks)
  • biogtr/inference/post_processing.py (3 hunks)
  • biogtr/models/attention_head.py (1 hunks)
  • biogtr/models/embedding.py (9 hunks)
  • biogtr/models/global_tracking_transformer.py (3 hunks)
  • biogtr/models/mlp.py (1 hunks)
  • biogtr/models/visual_encoder.py (3 hunks)
  • tests/test_inference.py (1 hunks)
  • tests/test_models.py (6 hunks)
Files skipped from review as they are similar to previous changes (4)
  • .gitignore
  • biogtr/inference/post_processing.py
  • biogtr/models/attention_head.py
  • biogtr/models/global_tracking_transformer.py
Additional Context Used
Ruff (6)
biogtr/datasets/sleap_dataset.py (1)

9-9: warnings imported but unused

biogtr/models/visual_encoder.py (1)

3-3: typing.Tuple imported but unused

tests/test_inference.py (2)

4-4: pytest imported but unused


88-88: Local variable N is assigned to but never used

tests/test_models.py (2)

250-250: Local variable times is assigned to but never used


373-373: Local variable img_shape is assigned to but never used

GitHub Check Runs (1)
codecov/patch success (22)

biogtr/datasets/sleap_dataset.py: [warning] 95-95: biogtr/datasets/sleap_dataset.py#L95
Added line #L95 was not covered by tests


biogtr/datasets/sleap_dataset.py: [warning] 184-184: biogtr/datasets/sleap_dataset.py#L184
Added line #L184 was not covered by tests


biogtr/datasets/sleap_dataset.py: [warning] 188-188: biogtr/datasets/sleap_dataset.py#L188
Added line #L188 was not covered by tests


biogtr/datasets/sleap_dataset.py: [warning] 294-294: biogtr/datasets/sleap_dataset.py#L294
Added line #L294 was not covered by tests


biogtr/datasets/sleap_dataset.py: [warning] 300-303: biogtr/datasets/sleap_dataset.py#L300-L303
Added lines #L300 - L303 were not covered by tests


biogtr/datasets/sleap_dataset.py: [warning] 305-307: biogtr/datasets/sleap_dataset.py#L305-L307
Added lines #L305 - L307 were not covered by tests


biogtr/datasets/sleap_dataset.py: [warning] 310-310: biogtr/datasets/sleap_dataset.py#L310
Added line #L310 was not covered by tests


biogtr/datasets/sleap_dataset.py: [warning] 313-313: biogtr/datasets/sleap_dataset.py#L313
Added line #L313 was not covered by tests


biogtr/datasets/sleap_dataset.py: [warning] 322-322: biogtr/datasets/sleap_dataset.py#L322
Added line #L322 was not covered by tests


biogtr/models/embedding.py: [warning] 152-152: biogtr/models/embedding.py#L152
Added line #L152 was not covered by tests


biogtr/models/mlp.py: [warning] 45-45: biogtr/models/mlp.py#L45
Added line #L45 was not covered by tests

Additional comments not posted (18)
biogtr/models/mlp.py (2)

7-43: Constructor setup for MLP looks robust and handles various configurations well.


47-61: Forward method implementation in MLP is correctly applying activations and handling dropout.

biogtr/models/visual_encoder.py (1)

Line range hint 70-86: Forward method in VisualEncoder correctly handles image inputs, including grayscale, and properly uses the feature extractor and output layer.

tests/test_inference.py (3)

46-46: The test_tracker function correctly sets up and tests the tracker module with various configurations.


Line range hint 88-88: The test_post_processing function effectively tests different post-processing methods under various conditions.


Line range hint 88-88: The test_metrics function correctly tests the metrics calculation for tracking, ensuring accurate match detection and switch counting.

tests/test_models.py (8)

6-6: The test_mlp function correctly tests the MLP logic, ensuring it processes inputs as expected and produces outputs of the correct shape.


7-7: The test_att_weight_head function correctly tests the self-attention head logic, ensuring it processes inputs as expected and produces outputs of the correct shape.


52-76: The test_encoder function correctly tests the feature extractor logic, ensuring it handles different channel configurations and produces outputs of the correct shape.


69-69: The test_embedding_validity function correctly tests the embedding usage, ensuring exceptions are raised under invalid configurations.


118-118: The test_embedding_basic function correctly tests the basic embedding logic, ensuring it processes inputs as expected and produces outputs of the correct shape.


200-200: The test_embedding_kwargs function correctly tests the embedding configuration logic, ensuring it handles different keyword arguments and produces outputs of the correct shape.


240-240: The test_multianchor_embedding function correctly tests the multi-anchor embedding logic, ensuring it handles multiple anchors and produces outputs of the correct shape.


282-282: The test_transformer_encoder function correctly tests the transformer encoder layer logic, ensuring it handles positional embeddings and produces outputs of the correct shape.

biogtr/models/embedding.py (3)

143-158: Forward method in Embedding correctly processes inputs through the appropriate embedding function and checks the output shape.


Line range hint 178-215: The _sine_box_embedding method correctly computes sine positional embeddings for boxes, handling different configurations such as normalization and scaling.


264-306: The _learned_pos_embedding method correctly computes learned positional embeddings for boxes, handling different configurations such as over_boxes.

biogtr/datasets/sleap_dataset.py (1)

42-46: Remove duplicate documentation for anchors parameter.

Comment on lines +85 to +95

if isinstance(anchors, int):
    self.anchors = anchors
elif isinstance(anchors, str):
    self.anchors = [anchors.lower()]
else:
    self.anchors = [anchor.lower() for anchor in anchors]

if (
    isinstance(self.anchors, list) and len(self.anchors) == 0
) or self.anchors == 0:
    raise ValueError(f"Must provide at least one anchor but got {self.anchors}")

Refactor the handling of anchors to simplify the logic.

-        if isinstance(anchors, int):
-            self.anchors = anchors
-        elif isinstance(anchors, str):
-            self.anchors = [anchors.lower()]
-        else:
-            self.anchors = [anchor.lower() for anchor in anchors]
-        if (
-            isinstance(self.anchors, list) and len(self.anchors) == 0
-        ) or self.anchors == 0:
-            raise ValueError(f"Must provide at least one anchor but got {self.anchors}")
+        # Simplified handling of anchors
+        self.anchors = [anchors.lower()] if isinstance(anchors, str) else anchors
+        if not self.anchors:
+            raise ValueError("Must provide at least one anchor.")


Comment on lines +187 to +191

if c == 1:
    img = np.concatenate(
        [img, img, img], axis=-1
    )  # convert grayscale to rgb

Optimize the conversion of grayscale images to RGB.

-                img = np.concatenate(
-                    [img, img, img], axis=-1
-                )  # convert to grayscale to rgb
+                img = np.stack([img] * 3, axis=-1)  # convert grayscale to RGB


Comment on lines +281 to +339
if isinstance(self.anchors, int):
anchors_to_choose = list(pose.keys()) + ["midpoint"]
anchors = np.random.choice(anchors_to_choose, self.anchors)
else:
if self.verbose:
warnings.warn(
f"{self.anchor} not in {[key for key in pose.keys()]}! Defaulting to midpoint"
)
anchor = "midpoint"
anchors = self.anchors

if anchor != "midpoint":
centroid = pose[anchor]
for anchor in anchors:
if anchor == "midpoint" or anchor == "centroid":
centroid = np.nanmean(np.array(list(pose.values())), axis=0)

elif anchor in pose:
centroid = np.array(pose[anchor])
if np.isnan(centroid).any():
centroid = np.array([np.nan, np.nan])

if np.isnan(centroid).any():
elif anchor not in pose and len(anchors) == 1:
anchor = "midpoint"
centroid = np.nanmean(np.array(list(pose.values())), axis=0)
else:
# print(f'{self.anchor} not an available option amongst {pose.keys()}. Using midpoint')
centroid = np.nanmean(np.array(list(pose.values())), axis=0)

bbox = data_utils.pad_bbox(
data_utils.get_bbox(centroid, self.crop_size),
padding=self.padding,
)
elif anchor in pose:
centroid = np.array(pose[anchor])
if np.isnan(centroid).any():
centroid = np.array([np.nan, np.nan])

elif anchor not in pose and len(anchors) == 1:
anchor = "midpoint"
centroid = np.nanmean(np.array(list(pose.values())), axis=0)

else:
centroid = np.array([np.nan, np.nan])

if np.isnan(centroid).all():
bbox = torch.tensor([np.nan, np.nan, np.nan, np.nan])

else:
bbox = data_utils.pad_bbox(
data_utils.get_bbox(centroid, self.crop_size),
padding=self.padding,
)

if bbox.isnan().all():
crop = torch.zeros(
c,
self.crop_size + 2 * self.padding,
self.crop_size + 2 * self.padding,
dtype=img.dtype,
)
else:
crop = data_utils.crop_bbox(img, bbox)

crops.append(crop)
centroids[anchor] = centroid
boxes.append(bbox)

if len(crops) > 0:
crops = torch.concat(crops, dim=0)

crop = data_utils.crop_bbox(img, bbox)
if len(boxes) > 0:
boxes = torch.stack(boxes, dim=0)

Refactor the handling of anchor points to simplify the logic.

-                if isinstance(self.anchors, int):
-                    anchors_to_choose = list(pose.keys()) + ["midpoint"]
-                    anchors = np.random.choice(anchors_to_choose, self.anchors)
-                else:
-                    anchors = self.anchors
-                for anchor in anchors:
-                    if anchor == "midpoint" or anchor == "centroid":
-                        centroid = np.nanmean(np.array(list(pose.values())), axis=0)
-                    elif anchor in pose:
-                        centroid = np.array(pose[anchor])
-                        if np.isnan(centroid).any():
-                            centroid = np.array([np.nan, np.nan])
-                    elif anchor not in pose and len(anchors) == 1:
-                        anchor = "midpoint"
-                        centroid = np.nanmean(np.array(list(pose.values())), axis=0)
-                    else:
-                        centroid = np.array([np.nan, np.nan])
-                    if np.isnan(centroid).all():
-                        bbox = torch.tensor([np.nan, np.nan, np.nan, np.nan])
-                    else:
-                        bbox = data_utils.pad_bbox(
-                            data_utils.get_bbox(centroid, self.crop_size),
-                            padding=self.padding,
-                        )
-                    if bbox.isnan().all():
-                        crop = torch.zeros(
-                            c,
-                            self.crop_size + 2 * self.padding,
-                            self.crop_size + 2 * self.padding,
-                            dtype=img.dtype,
-                        )
-                    else:
-                        crop = data_utils.crop_bbox(img, bbox)
-                    crops.append(crop)
-                    centroids[anchor] = centroid
-                    boxes.append(bbox)
-                if len(crops) > 0:
-                    crops = torch.concat(crops, dim=0)
-                if len(boxes) > 0:
-                    boxes = torch.stack(boxes, dim=0)
+                # Proposed refactoring code here


Comment on lines 20 to 51
def __init__(
self,
model_name: str = "resnet18",
d_model: int = 512,
in_chans: int = 3,
pretrained: bool = False,
**kwargs,
):
"""Initialize Visual Encoder.

Args:
model_name (str): Name of the CNN architecture to use (e.g. "resnet18", "resnet50").
cfg (dict): Dictionary of arguments to pass to the CNN constructor,
e.g: `cfg = {"weights": "ResNet18_Weights.DEFAULT"}`
d_model (int): Output embedding dimension.
in_chans: the number of input channels of the image.
pretrained: whether or not to use pretrained weights from hugging_face
kwargs: see `timm.create_model` for kwargs.
"""
super().__init__()

self.model_name = model_name
self.model_name = model_name.lower()
self.d_model = d_model

self.feature_extractor, out_dim = self.select_feature_extractor(model_name, cfg)

self.feature_extractor = torch.nn.Sequential(
*list(self.feature_extractor.children())[:-1]
if in_chans == 1:
self.in_chans = 3
else:
self.in_chans = in_chans

self.feature_extractor = timm.create_model(
model_name=self.model_name,
in_chans=self.in_chans,
pretrained=pretrained,
num_classes=0,
**kwargs,

Constructor of VisualEncoder is well-implemented, handling different configurations and initializing the model correctly.

- from typing import Tuple

Remove the unused import to clean up the code.



coderabbitai bot (Contributor) left a comment

Actionable comments posted: 3

Review Details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits: files that changed from the base of the PR and between 68cf717 and c772453.
Files selected for processing (3)
  • biogtr/models/embedding.py (9 hunks)
  • biogtr/models/visual_encoder.py (3 hunks)
  • tests/test_models.py (6 hunks)
Additional Context Used
Ruff (3)
biogtr/models/visual_encoder.py (1)

3-3: typing.Tuple imported but unused

tests/test_models.py (2)

302-302: Local variable times is assigned to but never used


425-425: Local variable img_shape is assigned to but never used

Additional comments not posted (12)
biogtr/models/visual_encoder.py (4)

43-46: Review the logic for handling in_chans. Currently, if in_chans is 1, it is forcibly set to 3, which might not be intended for all use cases.


60-112: The select_feature_extractor method is well-implemented, handling different configurations and initializing the model correctly.


114-126: The encoder_dim method correctly computes the embedding dimension based on the model.


Line range hint 128-142: The forward method correctly handles grayscale images by repeating the channels and is well-implemented.
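A plausible way such an encoder_dim helper can work (a sketch under the assumption that it probes the backbone with a dummy forward pass; the actual implementation may differ):

```python
import timm
import torch

def encoder_dim(feature_extractor: torch.nn.Module, in_chans: int = 3) -> int:
    """Infer the feature dimension by running a dummy image through the backbone."""
    with torch.no_grad():
        dummy = torch.zeros(1, in_chans, 224, 224)
        return feature_extractor(dummy).shape[-1]

backbone = timm.create_model("resnet18", num_classes=0)
assert encoder_dim(backbone) == 512  # resnet18's global-pooled feature size
```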

biogtr/models/embedding.py (4)

Line range hint 34-92: The __init__ method correctly initializes the MLP based on n_points and mlp_cfg parameters.


Line range hint 178-215: The _sine_box_embedding method correctly handles multiple anchors and includes the necessary MLP transformations for positional embeddings.


264-306: The _learned_pos_embedding method correctly handles multiple anchors and includes the necessary exception handling for incorrect n_points.


143-158: The forward method correctly includes the necessary exception handling for embedding dimension mismatches and logic for correct dimension projection.

tests/test_models.py (4)

45-85: The test_encoder_timm function correctly tests the VisualEncoder with different input channels and verifies the output shape.


88-128: The test_encoder_torch function correctly tests the VisualEncoder with different input channels and verifies the output shape.


Line range hint 170-218: The test_embedding_basic function correctly tests the Embedding class with different configurations and verifies the output shape.


292-333: The test_multianchor_embedding function correctly tests the Embedding class with multiple anchors and verifies the output shape.

@@ -3,6 +3,7 @@
from typing import Tuple

Remove the unused import of Tuple to clean up the code.

- from typing import Tuple


@@ -3,6 +3,7 @@
from typing import Tuple, Optional
import math
import torch
from biogtr.models.mlp import MLP

Remove the unused import of Tuple to clean up the code.

- from typing import Tuple


@@ -3,7 +3,8 @@
import pytest
import torch
from biogtr.data_structures import Frame, Instance
from biogtr.models.attention_head import MLP, ATTWeightHead
from biogtr.models.mlp import MLP

Remove the unused import of ATTWeightHead to clean up the code.

- from biogtr.models.attention_head import ATTWeightHead


aaprasad merged commit 3d6bf9a into main May 14, 2024
3 checks passed
talmo deleted the aadi/multi-anchor-crop branch May 14, 2024 21:58