
Precomputing Transposes for frozen linear layers #65631


Closed
Gamrix wants to merge 9 commits

Conversation

@Gamrix Gamrix commented Sep 24, 2021

facebook-github-bot commented Sep 24, 2021

🔗 Helpful links

💊 CI failures summary and remediations

As of commit fa959e5 (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-scanned failure(s)

ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI.

@facebook-github-bot facebook-github-bot added labels oncall: jit (Add this issue/PR to JIT oncall triage queue) and cla signed on Sep 24, 2021
Gamrix added a commit that referenced this pull request Sep 24, 2021
ghstack-source-id: f69787b
Pull Request resolved: #65631
Comment on lines +1 to 18
import io
import unittest
from itertools import product
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch._C import parse_ir
from torch.jit._recursive import wrap_cpp_module
from torch.testing import FileCheck
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
from torch.testing._internal.common_quantized import override_quantized_engine
from torch.testing._internal.common_utils import set_default_dtype
from torch.testing._internal.jit_utils import JitTestCase
from torch.utils import mkldnn as mkldnn_utils

try:
Gamrix (Contributor Author):
I guess some linting step I did ran isort.

Gamrix added a commit that referenced this pull request Sep 24, 2021
ghstack-source-id: 2239b0e
Pull Request resolved: #65631
@eellison eellison self-requested a review September 27, 2021 20:24
@eellison eellison left a comment

Looks good! Can you add more comments to this — both to the PR description, detailing the benchmarking that was done, and also inline in the pass itself?

Also, I thought we determined this was good for CPU as well?

if (node->kind() != aten::linear) {
return false;
}
auto weight = node->namedInput("weight");
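The check above matches on the op kind; for the rewrite to be valid in a frozen graph, the weight input presumably also has to be a compile-time constant. A torch-free Python sketch of such a predicate, with plain dicts standing in for `torch::jit::Node` (all names here are illustrative, not the pass's actual API):

```python
def is_constant_linear(node):
    """Match aten::linear nodes whose weight input is a graph constant.

    `node` is a dict standing in for a torch::jit::Node; this mirrors
    the shape of the C++ check in the snippet above.
    """
    if node["kind"] != "aten::linear":
        return False
    weight = node["inputs"]["weight"]
    # In a frozen module, the weight of an optimizable linear is
    # produced by a prim::Constant node.
    return weight["producer"] == "prim::Constant"

frozen = {"kind": "aten::linear",
          "inputs": {"weight": {"producer": "prim::Constant"}}}
dynamic = {"kind": "aten::linear",
           "inputs": {"weight": {"producer": "prim::GetAttr"}}}
```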
Contributor:
weight is never none, do you mean bias ? Can we handle the no-bias case as well here ?

Contributor:
looks like bias is actually handled
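For readers following along, the rewrite this pass performs can be sketched in plain, torch-free Python (illustrative names, not the pass's actual API): `aten::linear` computes `x @ w^T (+ b)`, and when the module is frozen `w` is a constant, so `w^T` can be computed once at freeze time rather than on every forward call.

```python
def transpose(m):
    """Transpose a matrix given as a list of rows."""
    return [list(col) for col in zip(*m)]

def matmul(a, b):
    """Naive matrix multiply: (n x k) @ (k x m)."""
    bt = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in bt] for row in a]

def linear(x, w, b=None):
    """aten::linear semantics: x @ w^T (+ b), transposing w on every call."""
    out = matmul(x, transpose(w))
    if b is not None:
        out = [[v + bv for v, bv in zip(row, b)] for row in out]
    return out

# Freeze-time rewrite: w is constant, so precompute w^T once and make the
# hot path a plain matmul (+ bias).
w = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # out_features=3, in_features=2
b = [0.5, -0.5, 1.0]
w_t = transpose(w)                         # done once, at freeze time

x = [[1.0, 1.0]]
before = linear(x, w, b)
after = [[v + bv for v, bv in zip(row, b)] for row in matmul(x, w_t)]
assert before == after  # the rewrite preserves the result
```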


void handleBlockAndSubblocks(Block* block) {
// Can't delete nodes while also iterating over it
std::vector<Node*> constant_linear_nodes;
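The comment about not deleting nodes while iterating reflects a general pattern: collect the matches in a first pass, then destroy them in a second. A generic, torch-free Python sketch of that pattern (hypothetical names):

```python
def remove_matching(nodes, predicate):
    """Deleting while iterating invalidates the iteration, so first
    collect the matches, then remove them in a second pass."""
    to_destroy = [n for n in nodes if predicate(n)]  # pass 1: collect
    for n in to_destroy:                             # pass 2: delete
        nodes.remove(n)
    return nodes
```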
node->destroy();
};

void handleBlockAndSubblocks(Block* block) {
Contributor:
this block is a good use case for

class DepthFirstGraphNodeIterator {
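`DepthFirstGraphNodeIterator` (a real class in the TorchScript C++ sources) visits every node in a graph, including those nested inside sub-blocks, so a pass does not need its own recursive block handler. A generic Python sketch of the traversal order, with dicts standing in for blocks and nodes (illustrative only, not the actual API):

```python
def depth_first_nodes(block):
    """Yield the kind of every node in a block, descending into each
    node's sub-blocks (e.g. prim::If / prim::Loop bodies) before
    moving on to the next sibling node."""
    for node in block["nodes"]:
        yield node["kind"]
        for sub in node.get("blocks", []):
            yield from depth_first_nodes(sub)

graph = {"nodes": [
    {"kind": "aten::linear", "blocks": []},
    {"kind": "prim::If", "blocks": [
        {"nodes": [{"kind": "aten::matmul", "blocks": []}]},
    ]},
]}
```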

namespace torch {
namespace jit {

// Concats multiple linear ops with the same Tensor input
Contributor:
bad copy pasta


private:
std::shared_ptr<Graph> graph_;
bool graph_modified = false;
Contributor:
nit: graph_modified -> graph_modified_

Gamrix commented Sep 29, 2021

Given that we determined in the meeting that single core CPU is the main metric to look at (due to prod uses usually running many concurrent processes with 1 core per process), and that multicore CPU didn't significantly regress, I will remove the restriction on only running this on GPU.

Link to benchmarking results
https://fb.quip.com/6RmbAN7ZBXrD

Gamrix added a commit that referenced this pull request Sep 29, 2021
ghstack-source-id: de711dd
Pull Request resolved: #65631
Gamrix added a commit that referenced this pull request Sep 30, 2021
ghstack-source-id: d6ca2ea
Pull Request resolved: #65631
Gamrix commented Sep 30, 2021

@Gamrix has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@Gamrix Gamrix requested a review from eellison October 1, 2021 00:08
@facebook-github-bot facebook-github-bot deleted the gh/gamrix/10/head branch October 9, 2021 14:18
Labels
cla signed · oncall: jit (Add this issue/PR to JIT oncall triage queue)
3 participants