์ด ๋ ์ํผ์์๋ ๋ค์ ๋ด์ฉ์ ์์๋ด ๋๋ค:
- TorchScript๋?
- ํ์ต๋ ๋ชจ๋ธ์ TorchScript ํ์์ผ๋ก ๋ด๋ณด๋ด๊ธฐ
- TorchScript ๋ชจ๋ธ์ C++๋ก ๋ถ๋ฌ์ค๊ณ ์ถ๋ก ํ๊ธฐ
- PyTorch 1.5
- TorchVision 0.6.0
- libtorch 1.5
- C++ ์ปดํ์ผ๋ฌ
3๊ฐ์ง PyTorch ์ปดํฌ๋ํธ๋ฅผ ์ค์นํ๋ ๋ฐฉ๋ฒ์ `pytorch.org`_์์ ํ์ธํ ์ ์์ต๋๋ค. C++ ์ปดํ์ผ๋ฌ๋ ๋น์ ์ ํ๋ซํผ์ ๋ฐ๋ผ ๋ฌ๋ผ์ง๋๋ค.
TorchScript**๋ C++ ๊ฐ์ ๊ณ ์ฑ๋ฅ ํ๊ฒฝ์์ ์คํํ ์ ์๋ PyTorch ๋ชจ๋ธ์ ์ค๊ฐ ํํ(``nn.Module``์ ํ์ ํด๋์ค)์ ๋๋ค. Python์ ๊ณ ์ฑ๋ฅ ํ์ ์งํฉ์ด๋ฉฐ ๋ชจ๋ธ ์ฐ์ฐ์ ๋ฐํ์ ์ต์ ํ๋ฅผ ์ํํ๋ **PyTorch JIT Compiler, ์์ ์ฌ์ฉ๋ฉ๋๋ค. TorchScript๋ PyTorch ๋ชจ๋ธ์์ ์ค์ผ์ผ ์ถ๋ก ์ ์ํํ ๋ ๊ถ์ฅ๋๋ ๋ชจ๋ธ ํ์์ ๋๋ค. ์์ธํ ๋ด์ฉ์ `pytorch.org`_์ ์๋ `Introduction to TorchScript tutorial`_, Loading A TorchScript Model in C++ tutorial, full TorchScript documentation ์์ ํ์ธํ์ธ์.
์ฌ์ ํ์ต๋ ์๊ฐ ๋ชจ๋ธ์ ์ดํด๋ด ์๋ค. TorchVision์ ๋ชจ๋ ์ฌ์ ํ์ต ๋ชจ๋ธ์ TorchScript์ ํธํ๋ฉ๋๋ค.
์คํฌ๋ฆฝํธ๋ REPL์์ ๋ค์์ Python 3 ์ฝ๋๋ฅผ ์คํํ์ธ์:
import torch
import torch.nn.functional as F
import torchvision.models as models
r18 = models.resnet18(pretrained=True) # ์ด์ ์ฌ์ ํ์ต๋ ๋ชจ๋ธ์ ์ธ์คํด์ค๊ฐ ์์ต๋๋ค.
r18_scripted = torch.jit.script(r18) # *** ์ฌ๊ธฐ๊ฐ TorchScript๋ก ๋ด๋ณด๋ด๋ ๋ถ๋ถ์
๋๋ค.
dummy_input = torch.rand(1, 3, 224, 224) # ๋น ๋ฅด๊ฒ ํ
์คํธ ํด๋ด
๋๋ค.
๋ ๋ชจ๋ธ์ด ์ ๋ง ๊ฐ์์ง์ ๋ํด ์ ๋ฐ ๊ฒ์ฌ๋ฅผ ํด๋ณด๊ฒ ์ต๋๋ค.
unscripted_output = r18(dummy_input) # ์คํฌ๋ฆฝํธํ ๋์ง ์์ ๋ชจ๋ธ์ ์์ธก์ ์ป๊ณ ... scripted_output = r18_scripted(dummy_input) # ...์คํฌ๋ฆฝํธํ ๋ ๋ชจ๋ธ๋ ๋๊ฐ์ด ๋ฐ๋ณตํฉ๋๋ค. unscripted_top5 = F.softmax(unscripted_output, dim=1).topk(5).indices scripted_top5 = F.softmax(scripted_output, dim=1).topk(5).indices print('Python model top 5 results:\n {}'.format(unscripted_top5)) print('TorchScript model top 5 results:\n {}'.format(scripted_top5))
๋ ๋ชจ๋ธ์ ๊ฒฐ๊ณผ๊ฐ ๋์ผํจ์ ํ์ธํ ์ ์์ต๋๋ค:
Python model top 5 results: tensor([[463, 600, 731, 899, 898]]) TorchScript model top 5 results: tensor([[463, 600, 731, 899, 898]])
๊ฒ์ฌ๊ฐ ๋๋ฌ์ผ๋ฉด ๋ชจ๋ธ์ ์ ์ฅํฉ๋๋ค:
r18_scripted.save('r18_scripted.pt')
๋ค์๊ณผ ๊ฐ์ C++ ํ์ผ์ ๋ง๋ค๊ณ ํ์ผ๋ช
์ ts-infer.cpp
๋ผ๊ณ ์ง์ต๋๋ค.
#include <torch/script.h>
#include <torch/nn/functional/activation.h>
int main(int argc, const char* argv[]) {
if (argc != 2) {
std::cerr << "usage: ts-infer <path-to-exported-model>\n";
return -1;
}
std::cout << "Loading model...\n";
// ScriptModule์ ์ญ์ง๋ ฌํ(deserialize) ํฉ๋๋ค.
torch::jit::script::Module module;
try {
module = torch::jit::load(argv[1]);
} catch (const c10::Error& e) {
std::cerr << "Error loading model\n";
std::cerr << e.msg_without_backtrace();
return -1;
}
std::cout << "Model loaded successfully\n";
torch::NoGradGuard no_grad; // autograd๊ฐ ๊บผ์ ธ์๋์ง ํ์ธํฉ๋๋ค.
module.eval(); // dropout๊ณผ ํ์ต ๋จ์ ๋ ์ด์ด ๋ฐ ํจ์๋ค์ ๋๋๋ค.
// ์
๋ ฅ "์ด๋ฏธ์ง"๋ฅผ ์์ฑํฉ๋๋ค.
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::rand({1, 3, 224, 224}));
// ๋ชจ๋ธ์ ์คํํ๊ณ ์ถ๋ ฅ ๊ฐ์ tensor๋ก ๋ฝ์๋
๋๋ค.
at::Tensor output = module.forward(inputs).toTensor();
namespace F = torch::nn::functional;
at::Tensor output_sm = F::softmax(output, F::SoftmaxFuncOptions(1));
std::tuple<at::Tensor, at::Tensor> top5_tensor = output_sm.topk(5);
at::Tensor top5 = std::get<1>(top5_tensor);
std::cout << top5[0] << "\n";
std::cout << "\nDONE\n";
return 0;
}
์ด๋ฐ ๊ฒ๋ค์ ์์๋ณด์์ต๋๋ค:
- ๋ช ๋ น ์ค์์ ์ง์ ํ ๋ชจ๋ธ ๋ถ๋ฌ์ค๊ธฐ
- ๋๋ฏธ ์ ๋ ฅ "์ด๋ฏธ์ง" tensor ์์ฑํ๊ธฐ
- ์ ๋ ฅ์ ๋ํ ์ถ๋ก ์ํํ๊ธฐ
๋ํ, ์ด ์ฝ๋์๋ TorchVision์ ๋ํ ์ข ์์ฑ์ด ์๋ค๋ ๊ฒ์ ์ ์ํ์ธ์. ์ ์ฅ๋ TorchScript ๋ชจ๋ธ์๋ ํ์ต ๊ฐ์ค์น์ ์ฐ์ฐ ๊ทธ๋ํ๊ฐ ์์ผ๋ฉฐ ๋ค๋ฅธ ๊ฒ์ ํ์ํ์ง ์์ต๋๋ค.
๋ค์๊ณผ ๊ฐ์ CMakeLists.txt
ํ์ผ์ ์์ฑํฉ๋๋ค:
cmake_minimum_required(VERSION 3.0 FATAL_ERROR) project(custom_ops) find_package(Torch REQUIRED) add_executable(ts-infer ts-infer.cpp) target_link_libraries(ts-infer "${TORCH_LIBRARIES}") set_property(TARGET ts-infer PROPERTY CXX_STANDARD 11)
ํ๋ก๊ทธ๋จ์ ์คํํฉ๋๋ค:
cmake -DCMAKE_PREFIX_PATH=<path to your libtorch installation> make
์ด์ C++์์ ์ถ๋ก ์ ์ํํ๊ณ ๊ฒฐ๊ณผ๋ฅผ ํ์ธํ ์ ์์ต๋๋ค.
$ ./ts-infer r18_scripted.pt Loading model... Model loaded successfully 418 845 111 892 644 [ CPULongType{5} ] DONE
- pytorch.org ์์ ์ค์น ๋ฐฉ๋ฒ๊ณผ ์ถ๊ฐ ๋ฌธ์ ๋ฐ ํํ ๋ฆฌ์ผ๋ค์ ํ์ธํ ์ ์์ต๋๋ค.
- Introduction to TorchScript tutorial ์์ ๋ ์ฌํ๋ TorchScript ๊ธฐ์ด ์ค๋ช ์ ํ์ธํ ์ ์์ต๋๋ค.
- Full TorchScript documentation ์์ ์ ์ฒด TorchScript ์ธ์ด ๋ฐ API๋ฅผ ์ฐธ์กฐํ ์ ์์ต๋๋ค.