PyTorchμ μ΄λ¦μμ μ μ μλ―μ΄ PyTorchλ Python νλ‘κ·Έλλ° μΈμ΄λ₯Ό κΈ°λ³Έ μΈν°νμ΄μ€λ‘ νκ³ μμ΅λλ€. Pythonμ λμ μ±κ³Ό μ μν μ΄ν°λ μ΄μ μ΄ νμν μν©μ μ ν©νκ³ μ νΈλλ μΈμ΄μ λλ€. νμ§λ§ λ§μ°¬κ°μ§λ‘ μ΄λ¬ν Pythonμ νΉμ§λ€μ΄ Pythonμ μ¬μ©νκΈ° μ ν©νμ§ μκ² λ§λλ μν©λ λ§μ΄ λ°μν©λλ€. Pythonμ μ¬μ©νκΈ° μ ν©νμ§ μμ λνμ μΈ μλ‘ μμ© νκ²½μ΄ μμ΅λλ€. μμ© νκ²½μμλ 짧μ μ§μ°μκ°μ΄ μ€μνκ³ λ°°ν¬νλ λ°μλ λ§μ μ μ½μ΄ λ°λ¦ λλ€. μ΄λ‘ μΈν΄ μμ© νκ²½μμλ λ§μ μ¬λλ€μ΄ C++λ₯Ό κ°λ°μΈμ΄λ‘ μ±ννκ² λ©λλ€. λ¨μ§ Java, Rust, λλ Goμ κ°μ λ€λ₯Έ μΈμ΄λ€μ λ°μΈλ©νκΈ° μν λͺ©μ μΌ λΏμΌμ§λΌλ λ§μ΄μ£ . μμΌλ‘ μ΄ νν 리μΌμμ μ΄λ»κ² PyTorchμμ PythonμΌλ‘ μμ±λ λͺ¨λΈλ€μ Python μμ‘΄μ±μ΄ μ ν μλ C++νκ²½μμλ μ½κ³ μ€νν μ μλ λ°©μμΌλ‘ μ§λ ¬νν μ μλμ§ μμλ³΄κ² μ΅λλ€.
Torch Script λ PyTorch λͺ¨λΈμ Pythonμμ C++λ‘ λ³ννλ κ²μ κ°λ₯νκ² ν΄μ€λλ€. TorchScriptλ TorchScript μ»΄νμΌλ¬κ° μ΄ν΄νκ³ , μ»΄νμΌνκ³ , μ§λ ¬νν μ μλ PyTorch λͺ¨λΈμ ν ννλ°©μμ λλ€. λ§μ½ κΈ°λ³Έμ μΈ "μ¦μ μ€ν"[μμ μ£Ό: eager execution] APIλ₯Ό μ¬μ©ν΄ μμ±λ PyTorch λͺ¨λΈμ΄ μλ€λ©΄, μ²μμΌλ‘ ν΄μΌ ν μΌμ μ΄ λͺ¨λΈμ TorchScript λͺ¨λΈλ‘ λ³ννλ κ²μ λλ€. μλμ μ€λͺ λμ΄ μλ―μ΄, λλΆλΆμ κ²½μ°μ μ΄ κ³Όμ μ λ§€μ° κ°λ¨ν©λλ€. μ΄λ―Έ TorchScript λͺ¨λμ κ°μ§κ³ μλ€λ©΄, μ΄ μΉμ μ 건λλ°μ΄λ μ’μ΅λλ€.
PyTorch λͺ¨λΈμ TorchScriptλ‘ λ³ννλ λ°©λ²μλ λκ°μ§κ° μμ΅λλ€. 첫λ²μ§Έλ νΈλ μ΄μ±(tracing)μ΄λΌλ λ°©λ²μΌλ‘ μ΄λ€ μ λ ₯κ°μ μ¬μ©νμ¬ λͺ¨λΈμ ꡬ쑰λ₯Ό νμ νκ³ μ΄ μ λ ₯κ°μ λͺ¨λΈ μμμμ νλ¦μ ν΅ν΄ λͺ¨λΈμ κΈ°λ‘νλ λ°©μμ λλ€. μ΄ λ°©λ²μ 쑰건문μ λ§μ΄ μ¬μ©νμ§ μλ λͺ¨λΈμ κ²½μ°μ μ ν©ν©λλ€. PyTorch λͺ¨λΈμ TorchScriptλ‘ λ³ννλ λλ²μ§Έ λ°©λ²μ λͺ¨λΈμ λͺ μμ μΈ μ΄λ Έν μ΄μ (annotation)μ μΆκ°νμ¬ TorchScript μ»΄νμΌλ¬λ‘ νμ¬κΈ μ§μ λͺ¨λΈ μ½λλ₯Ό λΆμνκ³ μ»΄νμΌνκ² νλ λ°©μμ λλ€. μ΄ λ°©μμ μ¬μ©ν λλ TorchScript μΈμ΄ μ체μ μ μ½μ΄ μμ μ μμ΅λλ€.
Tip
μ λ λ°©μμ κ΄λ ¨λ μ 보μ λ μ€ μ΄λ€ λ°©λ²μ μ¬μ©ν΄μΌ ν μ§ λ±μ λν κ°μ΄λλ 곡μ κΈ°μ λ¬ΈμμΈ Torch Script reference μμ νμΈνμ€ μ μμ΅λλ€.
PyTorch λͺ¨λΈμ νΈλ μ΄μ±μ ν΅ν΄ TorchScriptλ‘ λ³ννκΈ° μν΄μλ, μ¬λ¬λΆμ΄ ꡬνν λͺ¨λΈμ μΈμ€ν΄μ€λ₯Ό
μμ μ
λ ₯κ°κ³Ό ν¨κ» torch.jit.trace
ν¨μμ λ겨주μ΄μΌ ν©λλ€. κ·Έλ¬λ©΄ μ΄ ν¨μλ torch.jit.ScriptModule
κ°μ²΄λ₯Ό μμ±νκ² λ©λλ€. μ΄λ κ² μμ±λ κ°μ²΄μλ λͺ¨λμ forward
λ©μλμ λͺ¨λΈ μ€νμ λ°νμμ traceν
κ²°κ³Όκ° ν¬ν¨λκ² λ©λλ€:
import torch import torchvision # λͺ¨λΈ μΈμ€ν΄μ€ μμ± model = torchvision.models.resnet18() # μΌλ°μ μΌλ‘ λͺ¨λΈμ forward() λ©μλμ λ겨주λ μ λ ₯κ° example = torch.rand(1, 3, 224, 224) # torch.jit.traceλ₯Ό μ¬μ©νμ¬ νΈλ μ΄μ±μ μ΄μ©ν΄ torch.jit.ScriptModule μμ± traced_script_module = torch.jit.trace(model, example)
μ΄λ κ² traceλ ScriptModule
μ μΌλ°μ μΈ PyTorch λͺ¨λκ³Ό κ°μ λ°©μμΌλ‘ μ
λ ₯κ°μ λ°μ
μ²λ¦¬ν μ μμ΅λλ€:
In[1]: output = traced_script_module(torch.ones(1, 3, 224, 224)) In[2]: output[0, :5] Out[2]: tensor([-0.2698, -0.0381, 0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)
νΉμ ν νκ²½(κ°λ Ή λͺ¨λΈμ΄ μ΄λ€ μ μ΄νλ¦μ μ¬μ©νκ³ μλ κ²½μ°)μμλ μ¬λ¬λΆμ λͺ¨λΈμ μ΄λ Έν μ΄νΈ(annotate)νμ¬ TorchScriptλ‘ λ°λ‘ μμ±νλ κ²μ΄ λ°λμ§ν κ²½μ°κ° μμ΅λλ€. μλ₯Ό λ€μ΄, μλμ κ°μ PyTorch λͺ¨λΈμ΄ μλ€κ³ κ°μ νκ² μ΅λλ€:
import torch class MyModule(torch.nn.Module): def __init__(self, N, M): super(MyModule, self).__init__() self.weight = torch.nn.Parameter(torch.rand(N, M)) def forward(self, input): if input.sum() > 0: output = self.weight.mv(input) else: output = self.weight + input return output
μ΄ λͺ¨λμ forward
λ©μλλ μ
λ ₯κ°μ μν₯μ λ°λ μ μ΄νλ¦μ μ¬μ©νκ³ μκΈ° λλ¬Έμ, μ΄ λͺ¨λμ
νΈλ μ΄μ±μλ μ ν©νμ§ μμ΅λλ€. λμ μ°λ¦¬λ μ΄ λͺ¨λμ ScriptModule
λ‘ λ³νν μ μμ΅λλ€.
λͺ¨λμ ScriptModule
λ‘ λ³ννκΈ° μν΄μλ, μλμ κ°μ΄ torch.jit.script
ν¨μλ₯Ό μ¬μ©ν΄
λͺ¨λμ μ»΄νμΌν΄μΌ ν©λλ€:
class MyModule(torch.nn.Module): def __init__(self, N, M): super(MyModule, self).__init__() self.weight = torch.nn.Parameter(torch.rand(N, M)) def forward(self, input): if input.sum() > 0: output = self.weight.mv(input) else: output = self.weight + input return output my_module = MyModule(10,20) sm = torch.jit.script(my_module)
μμ§ TorchScriptμμ μ§μνμ§ μλ Python κΈ°λ₯μ μ¬μ©νκ³ μλ λ©μλλ€μ μ¬λ¬λΆμ nn.Module
μμ μ μΈνκ³ μΆλ€λ©΄, κ·Έ λ©μλλ€μ @torch.jit.ignore
λ‘ μ΄λ
Έν
μ΄νΈνλ©΄ λ©λλ€.
sm
μ μ§λ ¬ν(serialization) μ€λΉκ° λ ScriptModule
μ μΈμ€ν΄μ€μ
λλ€.
λͺ¨λΈμ νΈλ μ΄μ±μ΄λ μ΄λ
Έν
μ΄ν
μ ν΅ν΄ ScriptModule
λ‘ λ³ννμλ€λ©΄, μ΄μ κ·Έκ²μ νμΌλ‘ μ§λ ¬νν
μλ μμ΅λλ€. λμ€μ C++λ₯Ό μ΄μ©ν΄ νμΌλ‘λΆν° λͺ¨λμ μ½μ΄μ¬ μ μκ³ Pythonμ μ΄λ€ μμ‘΄μ±λ μμ΄
κ·Έ λͺ¨λμ μ€νν μ μμ΅λλ€. μλ₯Ό λ€μ΄ νΈλ μ΄μ± μμμμ λ€μλ ResNet18
λͺ¨λΈμ
μ§λ ¬ννκ³ μΆλ€κ³ κ°μ ν©μλ€. μ§λ ¬νλ₯Ό νκΈ° μν΄μλ, save
ν¨μλ₯Ό νΈμΆνκ³ λͺ¨λκ³Ό νμΌλͺ
λ§ λ겨주면 λ©λλ€:
traced_script_module.save("traced_resnet_model.pt")
μ΄ ν¨μλ traced_resnet_model.pt
νμΌμ μμ
λλ ν 리μ μμ±ν κ²μ
λλ€. λ§μ½ μ΄λ
Έν
μ΄μ
μμμ
sm
μ μ§λ ¬ννκ³ μΆλ€λ©΄, sm.save("my_module_model.pt")
λ₯Ό
νΈμΆνλ©΄ λ©λλ€. μ΄λ‘μ¨ μ΄μ Pythonμ μΈκ³μμ λ²μ΄λ C++ νκ²½μμ μμ
ν μ€λΉλ₯Ό λ§μ³€μ΅λλ€.
μ§λ ¬νλ PyTorch λͺ¨λΈμ C++μμ λ‘λνκΈ° μν΄μλ, μ΄ν리μΌμ΄μ μ΄ λ°λμ LibTorch λΌκ³ λΆλ¦¬λ PyTorch C++ APIλ₯Ό μ¬μ©ν΄μΌ ν©λλ€. LibTorchλ μ¬λ¬ 곡μ λΌμ΄λΈλ¬λ¦¬λ€, ν€λ νμΌλ€, κ·Έλ¦¬κ³ CMake λΉλ μ€μ νμΌλ€μ ν¬ν¨νκ³ μμ΅λλ€. CMakeλ LibTorchλ₯Ό μ°κΈ°μν νμ μꡬμ¬νμ μλμ§λ§, κΆμ₯λλ λ°©μμ΄κ³ ν₯νμλ κ³μ μ§μλ μμ μ λλ€. μ΄ νν 리μΌμμλ CMakeμ LibTorchλ₯Ό μ¬μ©νμ¬ μ§λ ¬νλ PyTorch λͺ¨λΈμ μ½κ³ μ€ννλ μμ£Ό κ°λ¨ν C++ μ΄ν리μΌμ΄μ μ λ§λ€μ΄λ³΄λλ‘ νκ² μ΅λλ€.
μ°μ λͺ¨λμ λ‘λνλ μ½λμ λν΄ μ΄ν΄λ³΄λλ‘ νκ² μ΅λλ€. μλμ κ°λ¨ν μ½λλ‘ λͺ¨λμ μ½κ² μ½μ΄μ¬ μ μμ΅λλ€:
#include <torch/script.h> // νμν λ¨ νλμ ν€λνμΌ.
#include <iostream>
#include <memory>
int main(int argc, const char* argv[]) {
if (argc != 2) {
std::cerr << "usage: example-app <path-to-exported-script-module>\n";
return -1;
}
torch::jit::script::Module module;
try {
// torch::jit::load()μ μ¬μ©ν΄ ScriptModuleμ νμΌλ‘λΆν° μμ§λ ¬ν
module = torch::jit::load(argv[1]);
}
catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
return -1;
}
std::cout << "ok\n";
}
<torch/script.h>
ν€λλ μμλ₯Ό μ€ννκΈ° μν λͺ¨λ LibTorch λΌμ΄λΈλ¬λ¦¬λ₯Ό ν¬ν¨νκ³ μμ΅λλ€.
μ°λ¦¬μ μ΄ν리μΌμ΄μ
μ μ§λ ¬νλ PyTorch ScriptModule
μ κ²½λ‘λ₯Ό μ μΌν λͺ
λ Ήν μΈμλ‘ μ
λ ₯λ°κ³
μ΄ νμΌκ²½λ‘λ₯Ό μΈμλ‘ λ°λ torch::jit::load()
λ₯Ό μ¬μ©ν΄ λͺ¨λμ μμ§λ ¬νν©λλ€. κ·Έ κ²°κ³Όλ‘
torch::jit::script::Module
λ₯Ό λλ €λ°μ΅λλ€. μ΄ λ¦¬ν΄λ°μ λͺ¨λμ μ΄λ»κ² μ¬μ©νλμ§μ λν΄μλ 곧 μ΄ν΄λ³΄κ² μ΅λλ€.
μμ μ½λλ₯Ό example-app.cpp
μ΄λΌλ νμΌμ μ μ₯νμλ€κ³ κ°μ ν©λλ€. μ μ½λλ₯Ό λΉλνκΈ° μν
κ°λ¨ν CMakeLists.txt
μ
λλ€:
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(custom_ops)
find_package(Torch REQUIRED)
add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 14)
μμ μ΄ν리μΌμ΄μ μ λΉλνκΈ° μν΄ λ§μ§λ§μΌλ‘ νμν κ²μ LibTorch λ°°ν¬νμ λλ€. μΈμ λ κ°μ₯ μ΅μ μ μμ λ²μ μ PyTorch μΉμ¬μ΄νΈμ download page λ‘λΆν° λ°μΌμ€ μ μμ΅λλ€. κ°μ₯ μ΅μ λ²μ μ λ€μ΄λ‘λ λ°μ μμΆμ νΈμλ©΄, μλμ κ°μ λλ ν 리 ꡬ쑰μ ν΄λλ₯Ό νμΈνμ€ μ μμ΅λλ€:
libtorch/
bin/
include/
lib/
share/
lib/
ν΄λλ λ§ν¬ν΄μΌ ν 곡μ λΌμ΄λΈλ¬λ¦¬λ₯Ό ν¬ν¨νκ³ μμ΅λλ€.include/
ν΄λλ μ¬λ¬λΆμ νλ‘κ·Έλ¨μ΄ include ν΄μΌ ν ν€λ νμΌλ€μ λ΄κ³ μμ΅λλ€.share/
ν΄λλ μμμ μ€νν κ°λ¨ν λͺ λ Ήμ΄μΈfind_package(Torch)
λ₯Ό μ€ννκ² ν΄μ£Όλ CMake μ€μ μ λ΄κ³ μμ΅λλ€.
Tip
μλμ°μμλ λλ²κ·Έ λΉλμ λ¦΄λ¦¬μ¦ λΉλκ° ABI-compatibleνμ§ μμ΅λλ€. λ§μ½ νλ‘μ νΈλ₯Ό debug λͺ¨λμμ λΉλνκ³ μΆλ€λ©΄, LibTorchμ debug λ²μ μ μ¬μ©ν΄μΌν©λλ€. κ·Έλ¦¬κ³ cmake --build .` μ μλ§μ μ€μ μ λͺ μν΄ μ£Όμ΄μΌ ν©λλ€.
λ§μ§λ§ λ¨κ³λ μ΄ν리μΌμ΄μ μ λΉλνλ κ²μ λλ€. μ΄λ₯Ό μν΄μ λλ ν 리 κ΅¬μ‘°κ° μλμ κ°μ΄ κ°λ€κ³ κ°μ νκ² μ΅λλ€.
example-app/
CMakeLists.txt
example-app.cpp
μ΄μ μλ λͺ
λ Ήμ΄λ€μ μ¬μ©ν΄ example-app/
ν΄λ μμμ μ΄ν리μΌμ΄μ
μ λΉλν μ μμ΅λλ€.
mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
cmake --build . --config Release
μ¬κΈ°μ /path/to/libtorch
λ LibTorch λ°°ν¬νμ μμΆμ νΌ μ 체 κ²½λ‘μ
λλ€. λͺ¨λ κ²μ΄ μ λμλ€λ©΄,
μλμ κ°μ κ²μ΄ λνλ κ²μ
λλ€:
root@4b5a67132e81:/example-app# mkdir build
root@4b5a67132e81:/example-app# cd build
root@4b5a67132e81:/example-app/build# cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
-- The C compiler identification is GNU 5.4.0
-- The CXX compiler identification is GNU 5.4.0
-- Check for working C compiler: /usr/bin/cc
-- Check for working C compiler: /usr/bin/cc -- works
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Detecting C compile features
-- Detecting C compile features - done
-- Check for working CXX compiler: /usr/bin/c++
-- Check for working CXX compiler: /usr/bin/c++ -- works
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Looking for pthread_create
-- Looking for pthread_create - not found
-- Looking for pthread_create in pthreads
-- Looking for pthread_create in pthreads - not found
-- Looking for pthread_create in pthread
-- Looking for pthread_create in pthread - found
-- Found Threads: TRUE
-- Configuring done
-- Generating done
-- Build files have been written to: /example-app/build
root@4b5a67132e81:/example-app/build# make
Scanning dependencies of target example-app
[ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o
[100%] Linking CXX executable example-app
[100%] Built target example-app
μ΄μ traceλ ResNet18
λͺ¨λΈμΈ traced_resnet_model.pt
κ²½λ‘λ₯Ό example-app
λ°μ΄λ리μ
μ
λ ₯νλ€λ©΄, "ok" λ©μμ§λ₯Ό νμΈν μ μμ κ²μ
λλ€. λ§μ½ μ΄ μμ μ my_module_model.pt
λ₯Ό
μΈμλ‘ λκ²Όλ€λ©΄, μ
λ ₯κ°μ΄ νΈνλμ§ μλ λͺ¨μμ΄λΌλ μλ¬λ©μμ§κ° μΆλ ₯λ©λλ€. my_module_model.pt
λ
4Dκ° μλ 1D ν
μλ₯Ό λ°λλ‘ λμ΄μκΈ° λλ¬Έμ
λλ€.
root@4b5a67132e81:/example-app/build# ./example-app <path_to_model>/traced_resnet_model.pt
ok
ResNet18
μ C++μμ μ±κ³΅μ μΌλ‘ λ‘λ©ν λ€, μ΄μ λͺ μ€μ μ½λλ§ λ μΆκ°νλ©΄ λͺ¨λμ μ€νν μ μμ΅λλ€.
C++ μ΄ν리μΌμ΄μ
μ main()
ν¨μμ μλμ μ½λλ₯Ό μΆκ°νκ² μ΅λλ€.
// μ
λ ₯κ° λ²‘ν°λ₯Ό μμ±ν©λλ€.
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3, 224, 224}));
// λͺ¨λΈμ μ€νν λ€ λ¦¬ν΄κ°μ ν
μλ‘ λ³νν©λλ€.
at::Tensor output = module.forward(inputs).toTensor();
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';
첫 λμ€μ λͺ¨λΈμ μ
λ ₯κ°μ μμ±ν©λλ€. torch::jit::IValue
(script::Module
λ©μλλ€μ΄
μ
λ ₯λ°κ³ λ 리ν΄ν μ μλ νμ
μ΄ μκ±°λ μλ£ν)μ 벑ν°λ₯Ό λ§λ€κ³ κ·Έ 벑ν°μ νλμ μ
λ ₯κ°μ μΆκ°ν©λλ€.
μ
λ ₯κ° ν
μλ₯Ό λ§λ€κΈ° μν΄μ μ°λ¦¬λ torch::ones()
μ μ¬μ©ν©λλ€. μ΄ ν¨μλ torch.ones
μ C++ API λ²μ μ
λλ€.
μ΄μ script::Module
μ forward
λ©μλμ μ
λ ₯κ° λ²‘ν°λ₯Ό λκ²¨μ£Όμ΄ μ€ννλ©΄, μ°λ¦¬λ μλ‘μ΄
IValue
λ₯Ό 리ν΄λ°κ² λκ³ , μ΄ κ°μ toTensor()
λ₯Ό ν΅ν΄ ν
μλ‘ λ³νν μ μμ΅λλ€.
Tip
torch::ones
λ₯Ό λΉλ‘―ν PyTorch C++ APIμ λν΄ λ μκ³ μΆλ€λ©΄ https://pytorch.org/cppdocsμ μλ
λ¬Έμλ₯Ό μ°Έκ³ νμλ©΄ λ©λλ€. PyTorch C++ APIλ Python APIμ κ±°μ λμΌν κΈ°λ₯μ μ 곡νμ¬ μ¬μ©μλ€μ΄
ν
μλ₯Ό λ€λ£¨κ³ μ¬μ©νλ κ²μ Pythonκ³Ό λμΌνκ² ν μ μλλ‘ ν©λλ€.
λ§μ§λ§ μ€μμ μΆλ ₯κ°μ 첫 λ€μ― κ°λ€μ νλ¦°νΈν©λλ€. μ΄λ² νν 리μΌμ μλΆλΆμμ Python λͺ¨λΈμ λμΌν μ λ ₯κ°μ λ겨주μκΈ° λλ¬Έμ, μ΄ λΆλΆμμλ μΆλ ₯κ°μ κ°μ κ²μ΄λΌκ³ μμν μ μμ΅λλ€. κ·ΈλΌ μ΄ν리μΌμ΄μ μ λ€μ μ»΄νμΌνκ³ κ°μ μ§λ ¬νλ λͺ¨λΈμ λν΄ μ€νν΄ λ³΄κ² μ΅λλ€:
root@4b5a67132e81:/example-app/build# make
Scanning dependencies of target example-app
[ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o
[100%] Linking CXX executable example-app
[100%] Built target example-app
root@4b5a67132e81:/example-app/build# ./example-app traced_resnet_model.pt
-0.2698 -0.0381 0.4023 -0.3010 -0.0448
[ Variable[CPUFloatType]{1,5} ]
μ°Έκ³ λ‘, μ΄μ μ Pythonμμμ μΆλ ₯κ°μ μλμ κ°μμ΅λλ€:
tensor([-0.2698, -0.0381, 0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)
λ μΆλ ₯κ°μ΄ μΌμΉνλ κ±Έ νμΈνμ€ μ μμ΅λλ€!
Tip
λͺ¨λΈμ GPU λ©λͺ¨λ¦¬μ μ¬λ¦¬κΈ° μν΄μλ, model.to(at::kCUDA);
λ₯Ό μ¬μ©νλ©΄ λ©λλ€.
λͺ¨λΈμ λ겨주λ μ
λ ₯κ°λ€μ λν΄μλ tensor.to(at::kCUDA)
λ₯Ό ν΅ν΄ CUDA λ©λͺ¨λ¦¬μ μ¬λ¦° λ€
μ¬μ©ν΄μΌν©λλ€. tensor.to(at::kCUDA)
λ CUDA λ©λͺ¨λ¦¬μ μλ μλ‘μ΄ ν
μλ₯Ό 리ν΄ν©λλ€.
μ΄ νν 리μΌμ΄ PyTorch λͺ¨λΈμ PythonμμλΆν° C++λ‘ λ³ννλ κ³Όμ μ μ΄ν΄νλλ° λμμ΄ λμκΈΈ λ°λλλ€.
λ³Έ νν 리μΌμμ λ€λ£¬ κ°λ
λ€λ‘, μ¬λ¬λΆμ μ΄μ "μ¦μ μ€ν" λ²μ μ PyTorch λͺ¨λΈμμλΆν° Pythonμμ μ»΄νμΌλ ScriptModule
λ‘,
λ λμκ° λμ€ν¬ μμ μ§λ ¬νλ νμΌλ‘, κ·Έλ¦¬κ³ λ§μ§λ§μΌλ‘ C++μμ μ€νκ°λ₯ν script::Module
κΉμ§ λ§λ€
μ μκ² λμμ΅λλ€.
λ¬Όλ‘ μ΄ νν 리μΌμμ λ€λ£¨μ§ λͺ»ν κ°λ
λ€λ λ§μ΅λλ€. μλ₯Ό λ€μ΄ μ¬λ¬λΆμ ScriptModule
μ΄ C++λ CUDAλ‘
μ μλ 컀μ€ν
μ°μ°μλ₯Ό μ¬μ©ν μ μκ²νλ λ°©λ² λλ μ΄λ¬ν 컀μ€ν
μ°μ°μλ₯Ό C++ μμ© νκ²½μ ScriptModule
μμ
μ¬μ©ν μ μκ²νλ λ°©λ²μ λν΄μλ λ³Έ νν 리μΌμμ λ€λ£¨μ§ μμμ΅λλ€. μ’μ μμμ μ΄λ¬ν κ²λ€μ΄ κ°λ₯νλ€λ κ²μ΄κ³ μ§μλκ³
μλ€λ μ μ
λλ€! μ ν¬κ° 곧 μ΄κ²μ κ΄ν νν 리μΌμ μ
λ‘λν λκΉμ§ μ΄ ν΄λ
λ₯Ό μμλ‘ μΌμ μ°Έκ³ νμλ©΄ λκ² μ΅λλ€. λ μλ λ§ν¬λ€μ΄ λμμ΄ λ κ²μ
λλ€:
- The Torch Script reference: https://pytorch.org/docs/master/jit.html
- The PyTorch C++ API documentation: https://pytorch.org/cppdocs/
- The PyTorch Python API documentation: https://pytorch.org/docs/
μΈμ λ κ·Έλ λ―μ΄, λ¬Έμ λ₯Ό λ§λ₯λ¨λ¦¬μκ±°λ μ§λ¬Έμ΄ μμΌμλ©΄ μ ν¬ forum λλ GitHub issues μ μ¬λ €μ£Όμλ©΄ λκ² μ΅λλ€.