Skip to content

Commit

Permalink
[JIT] Fix archive file extension in examples and docs (#50649)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #50649

**Summary**
Tutorials, documentation and comments are not consistent with the file
extension they use for JIT archives. This commit modifies certain
instances of `*.pth` in `torch.jit.save` calls with `*.pt`.

**Test Plan**
Continuous integration.

**Fixes**
This commit fixes #49660.

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D25961628

Pulled By: SplitInfinity

fbshipit-source-id: a40c97954adc7c255569fcec1f389aa78f026d47
  • Loading branch information
Meghan Lele authored and facebook-github-bot committed Jan 20, 2021
1 parent e009665 commit 4aea007
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions docs/source/jit.rst
Original file line number Diff line number Diff line change
Expand Up @@ -548,17 +548,17 @@ best practices?
cpu_model = gpu_model.cpu()
sample_input_cpu = sample_input_gpu.cpu()
traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu)
torch.jit.save(traced_cpu, "cpu.pth")
torch.jit.save(traced_cpu, "cpu.pt")

traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu)
torch.jit.save(traced_gpu, "gpu.pth")
torch.jit.save(traced_gpu, "gpu.pt")

# ... later, when using the model:

if use_gpu:
model = torch.jit.load("gpu.pth")
model = torch.jit.load("gpu.pt")
else:
model = torch.jit.load("cpu.pth")
model = torch.jit.load("cpu.pt")

model(input)

Expand Down
2 changes: 1 addition & 1 deletion test/cpp/jit/test_graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ TEST(GraphExecutorTest, runAsync_executor) {
r2 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100))
return r1.wait() + r2.wait()
demo = DemoModule()
torch.jit.save(torch.jit.script(demo), 'test_interpreter_async.pth')
torch.jit.save(torch.jit.script(demo), 'test_interpreter_async.pt')
*/
std::string filePath(__FILE__);
auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1);
Expand Down
2 changes: 1 addition & 1 deletion test/cpp/jit/test_interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ TEST(InterpreterTest, runAsyncBasicTest) {
r2 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100))
return r1.wait() + r2.wait()
demo = DemoModule()
torch.jit.save(torch.jit.script(demo), 'test_interpreter_async.pth')
torch.jit.save(torch.jit.script(demo), 'test_interpreter_async.pt')
*/
std::string filePath(__FILE__);
auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1);
Expand Down

0 comments on commit 4aea007

Please sign in to comment.