Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
KsenijaS committed Oct 1, 2020
2 parents cb08a07 + 0b3ad54 commit 78c1cb4
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 73 deletions.
78 changes: 78 additions & 0 deletions .github/workflows/quantization_triage.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
name: quantization-triage

on:
issues:
types: [labeled]

jobs:
welcome:
runs-on: ubuntu-latest
steps:
- uses: actions/github-script@v2
with:
github-token: ${{secrets.GITHUB_TOKEN}}
script: |
// Arguments available:
// - github: A pre-authenticated octokit/rest.js client
// - context: An object containing the context of the workflow run
// - core: A reference to the @actions/core package
// - io: A reference to the @actions/io package
// Check if issue has a Quantization label.
const kQuantizationLabel = "oncall: quantization";
issue = await github.issues.get({
owner: context.issue.owner,
repo: context.issue.repo,
issue_number: context.issue.number,
})
const hasQuantizationLabel = issue.data.labels.filter(label => label.name == kQuantizationLabel).length > 0;
if (!hasQuantizationLabel) {
core.debug("Issue " + issue.data.title + " does not have Quantization label");
return;
}
// Get project column ID.
const kProjectName = "Quantization Triage";
const kColumnName = "Need Triage";
// Query all projects in the repository.
// TODO: Support pagination once there are > 30 projects.
const projects = await github.projects.listForRepo({
owner: context.issue.owner,
repo: context.issue.repo,
});
// Filter out unwanted projects and get the ID for the Quantization Triage project.
const filteredProjects = projects.data.filter(project => project.name == kProjectName);
if (filteredProjects.length != 1) {
core.setFailed("Unable to find a project named " + kProjectName);
return;
}
const projectId = filteredProjects[0].id;
// First, query all columns in the project.
// TODO: Support pagination once there are > 30 columns.
const columns = await github.projects.listColumns({
project_id: projectId,
});
// Filter out unwanted projects and get the ID for the Need triage column.
const filteredColumns = columns.data.filter(column => column.name == kColumnName);
if (filteredColumns.length != 1) {
core.setFailed("Unable to find a column named " + kColumnName);
return;
}
const columnId = filteredColumns[0].id;
// Create a project card for this new issue.
await github.projects.createCard({
column_id: columnId,
content_id: issue.data.id,
content_type: "Issue",
})
27 changes: 0 additions & 27 deletions test/test_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from pathlib import Path
from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Tracer, Graph
from torch.fx.experimental import GraphManipulation
from torch.fx.experimental import shape_prop

from torch.fx.proxy import TraceError

Expand Down Expand Up @@ -734,31 +733,5 @@ def test_wrong_topo(self):
with self.assertRaisesRegex(RuntimeError, 'was used before it has been defined'):
graph.lint()

def test_example_shape_prop(self):
class TestCase(torch.nn.Module):
def __init__(self):
super().__init__()
self.attr = torch.randn(3, 4)
self.submod = torch.nn.Linear(4, 4)

def forward(self, x):
return torch.neg(self.submod(x.relu() + self.attr))
tc = TestCase()
tc_traced = symbolic_trace(tc)
ref_out = tc_traced(torch.rand(3, 4))

# Make sure we're testing all opcodes
opcodes = set()
for node in tc_traced.graph.nodes:
opcodes.add(node.op)
self.assertEqual(opcodes, set(['placeholder', 'get_attr', 'call_function', 'call_method', 'call_module']))

# Test shape propogation and make sure results match actual
shape_prop.ShapeProp(tc_traced).propagate(torch.rand(3, 4))
self.assertEqual(tc_traced.graph.result.shape, ref_out.shape)




if __name__ == '__main__':
run_tests()
2 changes: 2 additions & 0 deletions torch/csrc/jit/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,8 @@ std::ostream& Graph::print(std::ostream& out, bool print_source_locations)
out << "with " << fg->kind().toQualString() << "_" << i++ << " = "
<< *fg->g(attr::Subgraph);
}
out.flush();

/*
// Uncomment this to debug all_nodes issues
{
Expand Down
46 changes: 0 additions & 46 deletions torch/fx/experimental/shape_prop.py

This file was deleted.

0 comments on commit 78c1cb4

Please sign in to comment.