Turn off visual encoder #68
base: main
Conversation
Walkthrough

The changes streamline the visual encoder so that it can be turned off entirely: when disabled, the encoder returns zero-valued feature vectors and the tracker learns associations from point coordinates alone.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant Tracker
    participant VisualEncoder
    User->>Tracker: Initialize tracking
    Tracker->>VisualEncoder: Check for features
    alt Features available
        VisualEncoder-->>Tracker: Extract features
    else No features
        VisualEncoder-->>Tracker: Return zeros
    end
    Tracker-->>User: Track results
```
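The branch in the diagram can be sketched as follows. This is a minimal illustration, not code from the PR: `extract_features` and its parameters are hypothetical stand-ins for the tracker's feature-extraction step.

```python
import torch

def extract_features(crops: torch.Tensor, encoder=None, d_model: int = 8) -> torch.Tensor:
    # Hypothetical helper mirroring the diagram: use the visual encoder
    # when one is configured, otherwise return zero features so the
    # tracker can still run on coordinates alone.
    if encoder is not None:
        return encoder(crops)  # "Features available" branch
    return torch.zeros(crops.shape[0], d_model)  # "Return zeros" branch

feats = extract_features(torch.randn(4, 1, 32, 32))
print(tuple(feats.shape))  # (4, 8)
print(bool(feats.any()))   # False
```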
Actionable comments posted: 4
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (3)
- dreem/inference/tracker.py (4 hunks)
- dreem/models/visual_encoder.py (2 hunks)
- tests/test_models.py (2 hunks)
Additional context used
Ruff

- dreem/models/visual_encoder.py — 77-77: SyntaxError: Expected ':', found 'lambda'
- dreem/inference/tracker.py — 123-123: f-string without any placeholders. Remove extraneous `f` prefix (F541)
Additional comments not posted (3)

dreem/models/visual_encoder.py (2)

- 55-60: LGTM! The conditional assignment of `self.out_layer` enhances flexibility. The implementation correctly assigns `torch.nn.Identity()` to `self.out_layer` when `model_name` is "off", an empty string, or `None`.
- 118-118: LGTM! The updated error message improves clarity. The implementation correctly raises a `ValueError` for unsupported backends.

dreem/inference/tracker.py (1)

- 112-114: LGTM! Removing the conditional check simplifies the code. The implementation correctly streamlines feature extraction by dropping the `self.use_vis_feats` check.
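The conditional `out_layer` assignment praised above can be sketched like this. `EncoderHeadSketch` is a hypothetical minimal class, not the actual `VisualEncoder` from the PR; only the `Identity`-when-off pattern is taken from the review.

```python
import torch

class EncoderHeadSketch(torch.nn.Module):
    # Hypothetical minimal class illustrating the reviewed pattern:
    # out_layer becomes an Identity when the visual encoder is "off".
    def __init__(self, model_name, d_model: int = 16):
        super().__init__()
        if model_name in ["", "off", None]:
            self.out_layer = torch.nn.Identity()  # pass features through unchanged
        else:
            self.out_layer = torch.nn.Linear(d_model, d_model)

enc_off = EncoderHeadSketch(model_name="off")
enc_on = EncoderHeadSketch(model_name="resnet18")
print(type(enc_off.out_layer).__name__)  # Identity
print(type(enc_on.out_layer).__name__)   # Linear
```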
dreem/models/visual_encoder.py (outdated)

```python
if model_name in ["", "off", None]:
    feature_extractor = lambda lambda tensor: torch.zeros(
        (tensor.shape[0], self.d_model), dtype=tensor.dtype, device=tensor.device
    )  # turn off visual features by returning zeros
```

**Fix the syntax error in the lambda definition.** The `lambda` keyword is duplicated, which is why Ruff reports `SyntaxError: Expected ':', found 'lambda'` at line 77.

```diff
- feature_extractor = lambda lambda tensor: torch.zeros(
+ feature_extractor = lambda tensor: torch.zeros(
```
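Once the duplicate keyword is removed, the extractor behaves as below. This is a standalone sketch: `d_model` is a local variable standing in for the `self.d_model` attribute used in the original method.

```python
import torch

d_model = 8  # stand-in for self.d_model in the original method

# The corrected extractor: a single `lambda` keyword. It returns a zero
# tensor shaped (batch, d_model) matching the input's dtype and device.
feature_extractor = lambda tensor: torch.zeros(
    (tensor.shape[0], d_model), dtype=tensor.dtype, device=tensor.device
)

x = torch.randn(3, 1, 32, 32)
feats = feature_extractor(x)
print(tuple(feats.shape))  # (3, 8)
print(bool(feats.any()))   # False
```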
tests/test_models.py (outdated)

```python
encoder = VisualEncoder(
    model_name="off", in_chans=c, d_model=features, backend=backend
)
output = encoder(input_tensor)

assert output.shape == (b, features)
assert not torch.is_nonzero.any()
```

**Fix the broken assertion for non-zero values.** `torch.is_nonzero` is a function, not a tensor, so calling `.any()` on it raises an `AttributeError`; the assertion should operate on `output` instead.

```diff
- assert not torch.is_nonzero.any()
+ assert not output.any()
```
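A small demonstration of why the original assertion fails and the suggested one works, using a zero tensor in place of the encoder output (the `(2, 16)` shape here is arbitrary, not from the PR):

```python
import torch

output = torch.zeros(2, 16)  # stand-in for the encoder's zero-feature output

# torch.is_nonzero is a function object, so torch.is_nonzero.any() raises
# AttributeError; the all-zeros check belongs on the output tensor itself.
try:
    torch.is_nonzero.any()
except AttributeError:
    print("torch.is_nonzero.any() raised AttributeError")

assert output.shape == (2, 16)
assert not output.any()  # passes for an all-zero tensor
print("tensor assertions passed")
```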
tests/test_models.py (outdated) — the same broken assertion appears in a second test hunk; the identical fix applies:

```diff
- assert not torch.is_nonzero.any()
+ assert not output.any()
```
dreem/inference/tracker.py

```diff
  instances_pred = self.sliding_inference(model, frames)

  if not self.persistent_tracking:
-     logger.debug(f"Clearing Queue after tracking")
+     logger.debug(f"Clearing queue after tracking single batch")
```

**Remove the extraneous `f` prefix in the logging message.** The `f` prefix is unnecessary because the string contains no placeholders (Ruff F541, line 123).

```diff
- logger.debug(f"Clearing queue after tracking single batch")
+ logger.debug("Clearing queue after tracking single batch")
```
Here we add functionality to turn off the visual encoder and instead learn associations based only on point coordinates.