
@codeflash-ai
Contributor

@codeflash-ai codeflash-ai bot commented May 13, 2025

⚡️ This pull request contains optimizations for PR #1250

If you approve this dependent PR, these changes will be merged into the original PR branch feature/inference-v1-models.

This PR will be automatically closed if the original PR is merged.


📄 94% (0.94x) speedup for Dinov2WithRegistersDropPath.forward in inference/v1/models/rfdetr/dinov2_with_windowed_attn.py

⏱️ Runtime: 16.2 milliseconds → 8.36 milliseconds (best of 62 runs)

📝 Explanation and details

Here is an optimized version of your code, focusing on runtime and memory reduction. The profiler indicates that the vast majority of the time is spent in a single line: the one that computes `.div(keep_prob) * random_tensor`.

We can optimize this by performing in-place operations (to reduce memory allocations and speed up computation), and by fusing more operations. Also, there is no need to construct shape using Python arithmetic every call—let's use tensor broadcasting and expand_as for efficiency.

Changes and rationale:

  • Replace .div(keep_prob) * random_tensor with input.mul_(random_tensor).div_(keep_prob) in-place, if it is safe (as no reuse of input).
  • Use expand_as(input) instead of shape tuple math.
  • Reuse allocated tensors when possible for memory efficiency.
  • Move some scalar ops out of the batch loop.
  • Only one allocation for the random tensor which is then modified in-place.
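The change described above can be sketched as follows. This is a minimal illustration, not the code from this PR: the function names `drop_path_baseline` and `drop_path_inplace` are hypothetical, and the baseline is assumed to follow the common stochastic-depth pattern implied by the `.div(keep_prob) * random_tensor` expression.

```python
import torch


def drop_path_baseline(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """Out-of-place stochastic depth (assumed shape of the original logic)."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One mask value per sample, broadcast over all remaining dimensions.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    return x.div(keep_prob) * random_tensor  # allocates two full-size tensors


def drop_path_inplace(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """In-place variant (hypothetical): overwrites x instead of allocating."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    # Single small allocation for the mask, then modified in place.
    mask = torch.rand(shape, dtype=x.dtype, device=x.device).add_(keep_prob).floor_()
    # The caller's tensor is mutated; only safe if x is not reused elsewhere.
    return x.mul_(mask).div_(keep_prob)
```

The in-place form changes the caller's tensor, so it is only a drop-in replacement when nothing else reads the input afterwards. It is also worth checking under autograd: PyTorch raises a `RuntimeError` for in-place operations on a leaf tensor that requires grad, and in-place writes can invalidate values saved for the backward pass.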

Performance rationale:

  • Only a single random tensor is allocated and modified in-place before use.
  • The shape creation is lightweight, and broadcasting/multiplication is fast.
  • We avoid an explicit .div() followed by a *, doing only the minimum required math using fused operations.
  • No unnecessary temporary allocations.

You could go further by:

  • Making this a custom CUDA function for maximal performance, or
  • Avoiding mul/div altogether with some bitmasking, if needed.

But as a drop-in change, this is likely close to the fastest you can get in plain PyTorch with the existing logic.
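To sanity-check a claim like this locally, a "best of N runs" measurement can be sketched with the standard library. The helper name `best_of`, the tensor sizes, and the fixed mask are assumptions made for illustration; this is not how the runtime figures in this report were produced.

```python
import timeit

import torch


def best_of(fn, runs: int = 62) -> float:
    """Return the fastest single-call wall time in seconds over `runs` calls."""
    return min(timeit.repeat(fn, number=1, repeat=runs))


keep_prob = 0.5
x = torch.rand(64, 3, 32, 32)
# Fixed per-sample 0/1 mask so both variants do the same arithmetic.
mask = torch.rand(64, 1, 1, 1).add_(keep_prob).floor_()

# Both variants start from a fresh clone so the in-place one never
# re-scales its own previous output; the clone cost is paid by both.
t_out_of_place = best_of(lambda: x.clone().div(keep_prob) * mask)
t_in_place = best_of(lambda: x.clone().mul_(mask).div_(keep_prob))

print(f"out-of-place: {t_out_of_place * 1e3:.3f} ms, in-place: {t_in_place * 1e3:.3f} ms")
```

Timings from a sketch like this depend heavily on device, tensor size, and allocator state, so they indicate a trend rather than reproduce the numbers reported above.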

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 45 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage
🌀 Generated Regression Tests Details
# imports
import pytest  # used for our unit tests
import torch
from inference.v1.models.rfdetr.dinov2_with_windowed_attn import (
    Dinov2WithRegistersDropPath,
)

# unit tests

@pytest.fixture
def model():
    # Fixture to create a model instance with a specific drop probability
    return Dinov2WithRegistersDropPath(drop_prob=0.5)

def test_zero_drop_prob_training(model):
    # Test with zero drop probability and training mode on
    model.drop_prob = 0.0
    model.train()
    input_tensor = torch.rand(10, 5)
    codeflash_output = model.forward(input_tensor); output_tensor = codeflash_output

def test_non_zero_drop_prob_training(model):
    # Test with non-zero drop probability and training mode on
    model.drop_prob = 0.5
    model.train()
    input_tensor = torch.ones(10, 5)
    codeflash_output = model.forward(input_tensor); output_tensor = codeflash_output

def test_training_mode_off(model):
    # Test with training mode off
    model.drop_prob = 0.5
    model.eval()
    input_tensor = torch.rand(10, 5)
    codeflash_output = model.forward(input_tensor); output_tensor = codeflash_output

def test_empty_tensor(model):
    # Test with an empty tensor
    input_tensor = torch.empty(0, 5)
    codeflash_output = model.forward(input_tensor); output_tensor = codeflash_output

def test_single_element_tensor(model):
    # Test with a single element tensor
    input_tensor = torch.tensor([[1.0]])
    codeflash_output = model.forward(input_tensor); output_tensor = codeflash_output

def test_high_dimensional_tensor(model):
    # Test with a high dimensional tensor
    input_tensor = torch.rand(10, 3, 32, 32)
    codeflash_output = model.forward(input_tensor); output_tensor = codeflash_output

def test_large_batch_size(model):
    # Test with a large batch size
    input_tensor = torch.rand(1024, 10)
    codeflash_output = model.forward(input_tensor); output_tensor = codeflash_output

def test_large_tensor_size(model):
    # Test with a large tensor size
    input_tensor = torch.rand(128, 3, 224, 224)
    codeflash_output = model.forward(input_tensor); output_tensor = codeflash_output

def test_randomness_consistency(model):
    # Test randomness consistency with a fixed seed
    torch.manual_seed(42)
    model.drop_prob = 0.3
    model.train()
    input_tensor = torch.rand(10, 5)
    codeflash_output = model.forward(input_tensor); output_tensor1 = codeflash_output
    torch.manual_seed(42)
    codeflash_output = model.forward(input_tensor); output_tensor2 = codeflash_output

def test_high_drop_prob(model):
    # Test with high drop probability
    model.drop_prob = 0.99
    model.train()
    input_tensor = torch.rand(10, 5)
    codeflash_output = model.forward(input_tensor); output_tensor = codeflash_output

def test_low_drop_prob(model):
    # Test with low drop probability
    model.drop_prob = 0.01
    model.train()
    input_tensor = torch.rand(10, 5)
    codeflash_output = model.forward(input_tensor); output_tensor = codeflash_output
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

# imports
import pytest  # used for our unit tests
import torch
from inference.v1.models.rfdetr.dinov2_with_windowed_attn import (
    Dinov2WithRegistersDropPath,
)

# unit tests

def test_no_dropout_training():
    """Test case for no dropout during training."""
    model = Dinov2WithRegistersDropPath(drop_prob=0.0)
    model.train()  # Set model to training mode
    input_tensor = torch.rand(10, 3, 32, 32)
    codeflash_output = model.forward(input_tensor); output_tensor = codeflash_output

def test_no_dropout_not_training():
    """Test case for no dropout during inference."""
    model = Dinov2WithRegistersDropPath(drop_prob=0.0)
    model.eval()  # Set model to evaluation mode
    input_tensor = torch.rand(10, 3, 32, 32)
    codeflash_output = model.forward(input_tensor); output_tensor = codeflash_output

def test_full_dropout_training():
    """Test case for full dropout during training."""
    model = Dinov2WithRegistersDropPath(drop_prob=1.0)
    model.train()  # Set model to training mode
    input_tensor = torch.rand(10, 3, 32, 32)
    codeflash_output = model.forward(input_tensor); output_tensor = codeflash_output

def test_full_dropout_not_training():
    """Test case for full dropout during inference."""
    model = Dinov2WithRegistersDropPath(drop_prob=1.0)
    model.eval()  # Set model to evaluation mode
    input_tensor = torch.rand(10, 3, 32, 32)
    codeflash_output = model.forward(input_tensor); output_tensor = codeflash_output

def test_zero_tensor_input():
    """Test case for zero tensor input."""
    model = Dinov2WithRegistersDropPath(drop_prob=0.5)
    model.train()  # Set model to training mode
    input_tensor = torch.zeros(10, 3, 32, 32)
    codeflash_output = model.forward(input_tensor); output_tensor = codeflash_output

def test_single_element_tensor():
    """Test case for single element tensor."""
    model = Dinov2WithRegistersDropPath(drop_prob=0.5)
    model.train()  # Set model to training mode
    input_tensor = torch.tensor([1.0])
    codeflash_output = model.forward(input_tensor); output_tensor = codeflash_output

def test_high_dimensional_tensor():
    """Test case for high dimensional tensor."""
    model = Dinov2WithRegistersDropPath(drop_prob=0.5)
    model.train()  # Set model to training mode
    input_tensor = torch.rand(10, 3, 32, 32, 32)
    codeflash_output = model.forward(input_tensor); output_tensor = codeflash_output

def test_random_drop_probability():
    """Test case for random drop probability."""
    drop_prob = torch.rand(1).item()
    model = Dinov2WithRegistersDropPath(drop_prob=drop_prob)
    model.train()  # Set model to training mode
    input_tensor = torch.rand(10, 3, 32, 32)
    codeflash_output = model.forward(input_tensor); output_tensor = codeflash_output

def test_large_batch_size():
    """Test case for large batch size."""
    model = Dinov2WithRegistersDropPath(drop_prob=0.5)
    model.train()  # Set model to training mode
    input_tensor = torch.rand(1000, 3, 32, 32)
    codeflash_output = model.forward(input_tensor); output_tensor = codeflash_output

def test_large_tensor_size():
    """Test case for large tensor size."""
    model = Dinov2WithRegistersDropPath(drop_prob=0.5)
    model.train()  # Set model to training mode
    input_tensor = torch.rand(10, 3, 100, 100)
    codeflash_output = model.forward(input_tensor); output_tensor = codeflash_output

def test_invalid_drop_probability():
    """Test case for invalid drop probability."""
    model = Dinov2WithRegistersDropPath(drop_prob=-0.1)
    model.train()  # Set model to training mode
    input_tensor = torch.rand(10, 3, 32, 32)
    with pytest.raises(ValueError):
        model.forward(input_tensor)

    model = Dinov2WithRegistersDropPath(drop_prob=1.1)
    model.train()  # Set model to training mode
    with pytest.raises(ValueError):
        model.forward(input_tensor)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
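The comparison that `codeflash_output` is captured for can be approximated by hand. The sketch below is an assumption about how such a check might look, not Codeflash's actual mechanism: `outputs_match` is a hypothetical helper that runs two implementations under the same seed on separate copies of the input and compares the results.

```python
import torch


def outputs_match(original_fn, optimized_fn, x: torch.Tensor, seed: int = 0) -> bool:
    """Run both callables under the same RNG seed and compare their outputs."""
    torch.manual_seed(seed)
    expected = original_fn(x.clone())  # clone: the optimized path may mutate its input
    torch.manual_seed(seed)
    actual = optimized_fn(x.clone())
    torch.testing.assert_close(actual, expected)  # raises AssertionError on mismatch
    return True


keep_prob = 0.5


def original(t: torch.Tensor) -> torch.Tensor:
    mask = (keep_prob + torch.rand(t.shape[0], 1)).floor()
    return t.div(keep_prob) * mask


def optimized(t: torch.Tensor) -> torch.Tensor:
    mask = torch.rand(t.shape[0], 1).add_(keep_prob).floor_()
    return t.mul_(mask).div_(keep_prob)


print(outputs_match(original, optimized, torch.rand(10, 5)))
```

Reseeding before each call matters because both variants draw from the global RNG; without it, the two masks differ and the outputs cannot be compared element by element.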

To edit these changes git checkout codeflash/optimize-pr1250-2025-05-13T13.24.23 and push.

@codeflash-ai codeflash-ai bot added the ⚡️ codeflash Optimization PR opened by Codeflash AI label May 13, 2025
@codeflash-ai codeflash-ai bot mentioned this pull request May 13, 2025
4 tasks
@grzegorz-roboflow
Collaborator

Not relevant anymore, source branch received further updates.

@codeflash-ai codeflash-ai bot deleted the codeflash/optimize-pr1250-2025-05-13T13.24.23 branch June 10, 2025 17:59