diff --git a/docs/source/jit.rst b/docs/source/jit.rst index 710f2f928c5f..ccd37738277f 100644 --- a/docs/source/jit.rst +++ b/docs/source/jit.rst @@ -547,10 +547,10 @@ best practices? cpu_model = gpu_model.cpu() sample_input_cpu = sample_input_gpu.cpu() - traced_cpu = torch.jit.trace(traced_cpu, sample_input_cpu) + traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu) torch.jit.save(traced_cpu, "cpu.pth") - traced_gpu = torch.jit.trace(traced_gpu, sample_input_gpu) + traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu) torch.jit.save(traced_gpu, "gpu.pth") # ... later, when using the model: