Skip to content

Conversation

shunting314
Copy link
Contributor

@shunting314 shunting314 commented Nov 15, 2022

Sometimes it's really convenient to run simple models thru the torchbench.py script rather than those from pytorch/benchmark. This PR add the ability to run any model from a specified path by overloading the --only argument.

This PR is split out from #88904

Here is the usage:

    Specify the path and class name of the model in format like:
    --only=path:<MODEL_FILE_PATH>,class:<CLASS_NAME>

    Due to the fact that dynamo changes current working directory,
    the path should be an absolute path.

    The class should have a method get_example_inputs to return the inputs
    for the model. An example looks like
    ```
    class LinearModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(10, 10)

        def forward(self, x):
            return self.linear(x)

        def get_example_inputs(self):
            return (torch.randn(2, 10),)
    ```

Test command:

# python benchmarks/dynamo/torchbench.py --performance --only=path:/pytorch/myscripts/model_collection.py,class:LinearModel --backend=eager
WARNING:common:torch.cuda.is_available() == False, using CPU
cpu  eval  LinearModel                        0.824x p=0.00

Content of model_collection.py

from torch import nn
import torch

class LinearModel(nn.Module):
    """
    AotAutogradStrategy.compile_fn ignore graph with at most 1 call nodes.
    Make sure this model calls 2 linear layers to avoid being skipped.
    """
    def __init__(self, nlayer=2):
        super().__init__()
        layers = []
        for _ in range(nlayer):
            layers.append(nn.Linear(10, 10))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

    def get_example_inputs(self):
        return (torch.randn(2, 10),)

cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire @JackCaoG

@pytorch-bot
Copy link

pytorch-bot bot commented Nov 15, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89028

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Merge Blocking SEVs

There is 1 active merge blocking SEVs. Please view them below:

If you must merge, use @pytorchbot merge -f.

✅ No Failures

As of commit e8feb41:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@shunting314
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 15, 2022
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: Not merging any PRs at the moment because there is a merge blocking https://github.com/pytorch/pytorch/labels/ci:%20sev issue open at:
#89092

Details for Dev Infra team Raised by workflow job

@shunting314
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
Sometimes it's really convenient to run simple models thru the torchbench.py script rather than those from pytorch/benchmark. This PR add the ability to run any model from a specified path by overloading the --only argument.

This PR is split out from pytorch#88904

Here is the usage:

        Specify the path and class name of the model in format like:
        --only=path:<MODEL_FILE_PATH>,class:<CLASS_NAME>

        Due to the fact that dynamo changes current working directory,
        the path should be an absolute path.

        The class should have a method get_example_inputs to return the inputs
        for the model. An example looks like
        ```
        class LinearModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(10, 10)

            def forward(self, x):
                return self.linear(x)

            def get_example_inputs(self):
                return (torch.randn(2, 10),)
        ```

Test command:
```
# python benchmarks/dynamo/torchbench.py --performance --only=path:/pytorch/myscripts/model_collection.py,class:LinearModel --backend=eager
WARNING:common:torch.cuda.is_available() == False, using CPU
cpu  eval  LinearModel                        0.824x p=0.00
```

Content of model_collection.py
```
from torch import nn
import torch

class LinearModel(nn.Module):
    """
    AotAutogradStrategy.compile_fn ignore graph with at most 1 call nodes.
    Make sure this model calls 2 linear layers to avoid being skipped.
    """
    def __init__(self, nlayer=2):
        super().__init__()
        layers = []
        for _ in range(nlayer):
            layers.append(nn.Linear(10, 10))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

    def get_example_inputs(self):
        return (torch.randn(2, 10),)
```

Pull Request resolved: pytorch#89028
Approved by: https://github.com/jansel
@github-actions github-actions bot deleted the dynamo-benchmark-model-from-path branch May 16, 2024 01:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants