Skip to content
Permalink
Browse files

Run resnet (#3202)

Summary:
Run PyTorch Resnet/Resnext models on Glow except for AdaptiveAvgPool and FC

Documentation:
doxygen
Pull Request resolved: #3202

Test Plan:
added operator tests

`python setup.py test`

to run example:
`python setup.py develop --run_cmake --release`
then
`python examples/resnet_example.py`

Differential Revision: D16190914

Pulled By: jackm321

fbshipit-source-id: d20b63ce31e4b2dff3fdbec36954ca1cff1412b0
  • Loading branch information...
jackm321 authored and facebook-github-bot committed Jul 10, 2019
1 parent 99b6571 commit db6804cc2d8d7c286c5c3f6063339d5922bc2589
@@ -180,9 +180,10 @@ endif()
if (GLOW_BUILD_PYTORCH_INTEGRATION)
if(NOT EXISTS "${GLOW_THIRDPARTY_DIR}/pybind11")
message(FATAL_ERROR "No pybind11 git submodule. Run: git submodule update --init --recursive")
else()
add_subdirectory(${GLOW_THIRDPARTY_DIR}/pybind11)
endif()

add_subdirectory(${GLOW_THIRDPARTY_DIR}/pybind11)
add_subdirectory(torch_glow/src)
endif()

@@ -10,7 +10,6 @@ See `glow/torch_glow/examples` for illustrative examples.
## Setup
* Follow directions in Building.md to make sure Glow can be built
* install [PyTorch](https://pytorch.org/) nightly
* install torchvision: `pip install torchvision`
* cd to `glow/torch_glow`

## Usage
@@ -0,0 +1,52 @@
import torch
import torch_glow
from PIL import Image

import utils.torchvision_fake.transforms as torchvisionTransforms
import utils.torchvision_fake.resnet as resnet

import argparse

def load_image(image_path):
image = Image.open(image_path).convert('RGB')
transformed_image = transform_image(image)
return torch.reshape(transformed_image, (1, 3, 224, 224))

# given a PIL image, transform it to a normalized tensor for classification.
def transform_image(image):
image = torchvisionTransforms.resize(image, 256)
image = torchvisionTransforms.center_crop(image, 224)
image = torchvisionTransforms.to_tensor(image)
image = torchvisionTransforms.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
return image

def run_model(model, image, use_glow, print_graph):
if use_glow:
torch_glow.enableFusionPass()
with torch.no_grad():
traced = torch.jit.trace(model, image)
if print_graph:
print(traced.graph_for(image))
all_outputs = traced(image)
topk = all_outputs.topk(5)
return(topk[1], topk[0])

def run():
parser = argparse.ArgumentParser()
parser.add_argument("--image", type=str, required=True, help="Location of the image to be classified")
parser.add_argument("--k", type=int, default=5, help="how many results to show")
parser.add_argument("--skip_glow", action='store_true', default=False, help="Don't run using Glow")
parser.add_argument("--print_graph", action='store_true', default=False, help="Don't run using Glow")
args = parser.parse_args()

image = load_image(args.image)
model = resnet.resnet18(pretrained=True, progress=True)
model.eval()
use_glow = not args.skip_glow

(indices, scores) = run_model(model, image, use_glow=use_glow, print_graph=args.print_graph)
print ("rank", "class", "P")
for i in xrange(args.k):
print(i, int(indices[0][i]), float(scores[0][i]))

run()
@@ -25,7 +25,7 @@ target_link_libraries(_torch_glow
ExecutionEngine
Graph
Importer
Interpreter
Support
Backends
pybind11)

@@ -41,8 +41,8 @@ void CachingGraphRunner::runGraph(const torch::jit::Node *node,
at::ArrayRef<torch::jit::IValue> inputs = torch::jit::last(stack, numInputs);

// TODO: cache loaderResult so we don't have to recompile every run.
PyTorchModelLoader loader(executionEngine_.getModule(),
graphInfo.subgraph.get(), inputs);
PyTorchModelLoader loader(&executionEngine_.getModule(),
graphInfo.subgraph.get(), &inputs);
loader.load();

glow::Function *f = loader.getFunction();

0 comments on commit db6804c

Please sign in to comment.
You can’t perform that action at this time.