Skip to content

Files

Latest commit

 

History

History
182 lines (124 loc) ยท 6.22 KB

torchscript_inference.rst

File metadata and controls

182 lines (124 loc) ยท 6.22 KB

TorchScript๋กœ ๋ฐฐํฌํ•˜๊ธฐ

์ด ๋ ˆ์‹œํ”ผ์—์„œ๋Š” ๋‹ค์Œ ๋‚ด์šฉ์„ ์•Œ์•„๋ด…๋‹ˆ๋‹ค:

  • TorchScript๋ž€?
  • ํ•™์Šต๋œ ๋ชจ๋ธ์„ TorchScript ํ˜•์‹์œผ๋กœ ๋‚ด๋ณด๋‚ด๊ธฐ
  • TorchScript ๋ชจ๋ธ์„ C++๋กœ ๋ถˆ๋Ÿฌ์˜ค๊ณ  ์ถ”๋ก ํ•˜๊ธฐ

์š”๊ตฌ ์‚ฌํ•ญ

  • PyTorch 1.5
  • TorchVision 0.6.0
  • libtorch 1.5
  • C++ ์ปดํŒŒ์ผ๋Ÿฌ

3๊ฐ€์ง€ PyTorch ์ปดํฌ๋„ŒํŠธ๋ฅผ ์„ค์น˜ํ•˜๋Š” ๋ฐฉ๋ฒ•์€ `pytorch.org`_์—์„œ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. C++ ์ปดํŒŒ์ผ๋Ÿฌ๋Š” ๋‹น์‹ ์˜ ํ”Œ๋žซํผ์— ๋”ฐ๋ผ ๋‹ฌ๋ผ์ง‘๋‹ˆ๋‹ค.

TorchScript๋ž€?

TorchScript**๋Š” C++ ๊ฐ™์€ ๊ณ ์„ฑ๋Šฅ ํ™˜๊ฒฝ์—์„œ ์‹คํ–‰ํ•  ์ˆ˜ ์žˆ๋Š” PyTorch ๋ชจ๋ธ์˜ ์ค‘๊ฐ„ ํ‘œํ˜„(``nn.Module``์˜ ํ•˜์œ„ ํด๋ž˜์Šค)์ž…๋‹ˆ๋‹ค. Python์˜ ๊ณ ์„ฑ๋Šฅ ํ•˜์œ„ ์ง‘ํ•ฉ์ด๋ฉฐ ๋ชจ๋ธ ์—ฐ์‚ฐ์˜ ๋Ÿฐํƒ€์ž„ ์ตœ์ ํ™”๋ฅผ ์ˆ˜ํ–‰ํ•˜๋Š” **PyTorch JIT Compiler, ์—์„œ ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค. TorchScript๋Š” PyTorch ๋ชจ๋ธ์—์„œ ์Šค์ผ€์ผ ์ถ”๋ก ์„ ์ˆ˜ํ–‰ํ•  ๋•Œ ๊ถŒ์žฅ๋˜๋Š” ๋ชจ๋ธ ํ˜•์‹์ž…๋‹ˆ๋‹ค. ์ž์„ธํ•œ ๋‚ด์šฉ์€ `pytorch.org`_์— ์žˆ๋Š” `Introduction to TorchScript tutorial`_, Loading A TorchScript Model in C++ tutorial, full TorchScript documentation ์—์„œ ํ™•์ธํ•˜์„ธ์š”.

๋ชจ๋ธ ๋‚ด๋ณด๋‚ด๊ธฐ

์‚ฌ์ „ ํ•™์Šต๋œ ์‹œ๊ฐ ๋ชจ๋ธ์„ ์‚ดํŽด๋ด…์‹œ๋‹ค. TorchVision์˜ ๋ชจ๋“  ์‚ฌ์ „ ํ•™์Šต ๋ชจ๋ธ์€ TorchScript์™€ ํ˜ธํ™˜๋ฉ๋‹ˆ๋‹ค.

์Šคํฌ๋ฆฝํŠธ๋‚˜ REPL์—์„œ ๋‹ค์Œ์˜ Python 3 ์ฝ”๋“œ๋ฅผ ์‹คํ–‰ํ•˜์„ธ์š”:

import torch
import torch.nn.functional as F
import torchvision.models as models

r18 = models.resnet18(pretrained=True)       # ์ด์ œ ์‚ฌ์ „ ํ•™์Šต๋œ ๋ชจ๋ธ์˜ ์ธ์Šคํ„ด์Šค๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค.
r18_scripted = torch.jit.script(r18)         # *** ์—ฌ๊ธฐ๊ฐ€ TorchScript๋กœ ๋‚ด๋ณด๋‚ด๋Š” ๋ถ€๋ถ„์ž…๋‹ˆ๋‹ค.
dummy_input = torch.rand(1, 3, 224, 224)     # ๋น ๋ฅด๊ฒŒ ํ…Œ์ŠคํŠธ ํ•ด๋ด…๋‹ˆ๋‹ค.

๋‘ ๋ชจ๋ธ์ด ์ •๋ง ๊ฐ™์€์ง€์— ๋Œ€ํ•ด ์ •๋ฐ€ ๊ฒ€์‚ฌ๋ฅผ ํ•ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.

unscripted_output = r18(dummy_input)         # ์Šคํฌ๋ฆฝํŠธํ™” ๋˜์ง€ ์•Š์€ ๋ชจ๋ธ์˜ ์˜ˆ์ธก์„ ์–ป๊ณ ...
scripted_output = r18_scripted(dummy_input)  # ...์Šคํฌ๋ฆฝํŠธํ™” ๋œ ๋ชจ๋ธ๋„ ๋˜‘๊ฐ™์ด ๋ฐ˜๋ณตํ•ฉ๋‹ˆ๋‹ค.

unscripted_top5 = F.softmax(unscripted_output, dim=1).topk(5).indices
scripted_top5 = F.softmax(scripted_output, dim=1).topk(5).indices

print('Python model top 5 results:\n  {}'.format(unscripted_top5))
print('TorchScript model top 5 results:\n  {}'.format(scripted_top5))

๋‘ ๋ชจ๋ธ์˜ ๊ฒฐ๊ณผ๊ฐ€ ๋™์ผํ•จ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

Python model top 5 results:
  tensor([[463, 600, 731, 899, 898]])
TorchScript model top 5 results:
  tensor([[463, 600, 731, 899, 898]])

๊ฒ€์‚ฌ๊ฐ€ ๋๋‚ฌ์œผ๋ฉด ๋ชจ๋ธ์„ ์ €์žฅํ•ฉ๋‹ˆ๋‹ค:

r18_scripted.save('r18_scripted.pt')

C++๋กœ TorchScript ๋ชจ๋ธ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ

๋‹ค์Œ๊ณผ ๊ฐ™์€ C++ ํŒŒ์ผ์„ ๋งŒ๋“ค๊ณ  ํŒŒ์ผ๋ช…์„ ts-infer.cpp ๋ผ๊ณ  ์ง“์Šต๋‹ˆ๋‹ค.

#include <torch/script.h>
#include <torch/nn/functional/activation.h>


int main(int argc, const char* argv[]) {
    if (argc != 2) {
        std::cerr << "usage: ts-infer <path-to-exported-model>\n";
        return -1;
    }

    std::cout << "Loading model...\n";

    // ScriptModule์„ ์—ญ์ง๋ ฌํ™”(deserialize) ํ•ฉ๋‹ˆ๋‹ค.
    torch::jit::script::Module module;
    try {
        module = torch::jit::load(argv[1]);
    } catch (const c10::Error& e) {
        std::cerr << "Error loading model\n";
        std::cerr << e.msg_without_backtrace();
        return -1;
    }

    std::cout << "Model loaded successfully\n";

    torch::NoGradGuard no_grad; // autograd๊ฐ€ ๊บผ์ ธ์žˆ๋Š”์ง€ ํ™•์ธํ•ฉ๋‹ˆ๋‹ค.
    module.eval(); // dropout๊ณผ ํ•™์Šต ๋‹จ์˜ ๋ ˆ์ด์–ด ๋ฐ ํ•จ์ˆ˜๋“ค์„ ๋•๋‹ˆ๋‹ค.

    // ์ž…๋ ฅ "์ด๋ฏธ์ง€"๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::rand({1, 3, 224, 224}));

    // ๋ชจ๋ธ์„ ์‹คํ–‰ํ•˜๊ณ  ์ถœ๋ ฅ ๊ฐ’์„ tensor๋กœ ๋ฝ‘์•„๋ƒ…๋‹ˆ๋‹ค.
    at::Tensor output = module.forward(inputs).toTensor();

    namespace F = torch::nn::functional;
    at::Tensor output_sm = F::softmax(output, F::SoftmaxFuncOptions(1));
    std::tuple<at::Tensor, at::Tensor> top5_tensor = output_sm.topk(5);
    at::Tensor top5 = std::get<1>(top5_tensor);

    std::cout << top5[0] << "\n";

    std::cout << "\nDONE\n";
    return 0;
}

์ด๋Ÿฐ ๊ฒƒ๋“ค์„ ์•Œ์•„๋ณด์•˜์Šต๋‹ˆ๋‹ค:

  • ๋ช…๋ น ์ค„์—์„œ ์ง€์ •ํ•œ ๋ชจ๋ธ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ
  • ๋”๋ฏธ ์ž…๋ ฅ "์ด๋ฏธ์ง€" tensor ์ƒ์„ฑํ•˜๊ธฐ
  • ์ž…๋ ฅ์— ๋Œ€ํ•œ ์ถ”๋ก  ์ˆ˜ํ–‰ํ•˜๊ธฐ

๋˜ํ•œ, ์ด ์ฝ”๋“œ์—๋Š” TorchVision์— ๋Œ€ํ•œ ์ข…์†์„ฑ์ด ์—†๋‹ค๋Š” ๊ฒƒ์— ์œ ์˜ํ•˜์„ธ์š”. ์ €์žฅ๋œ TorchScript ๋ชจ๋ธ์—๋Š” ํ•™์Šต ๊ฐ€์ค‘์น˜์™€ ์—ฐ์‚ฐ ๊ทธ๋ž˜ํ”„๊ฐ€ ์žˆ์œผ๋ฉฐ ๋‹ค๋ฅธ ๊ฒƒ์€ ํ•„์š”ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.

C++ ์ถ”๋ก  ์—”์ง„ ๋นŒ๋“œํ•˜๊ณ  ์‹คํ–‰ํ•˜๊ธฐ

๋‹ค์Œ๊ณผ ๊ฐ™์€ CMakeLists.txt ํŒŒ์ผ์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค:

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(custom_ops)

find_package(Torch REQUIRED)

add_executable(ts-infer ts-infer.cpp)
target_link_libraries(ts-infer "${TORCH_LIBRARIES}")
set_property(TARGET ts-infer PROPERTY CXX_STANDARD 11)

ํ”„๋กœ๊ทธ๋žจ์„ ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค:

cmake -DCMAKE_PREFIX_PATH=<path to your libtorch installation>
make

์ด์ œ C++์—์„œ ์ถ”๋ก ์„ ์ˆ˜ํ–‰ํ•˜๊ณ  ๊ฒฐ๊ณผ๋ฅผ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

$ ./ts-infer r18_scripted.pt
Loading model...
Model loaded successfully
 418
 845
 111
 892
 644
[ CPULongType{5} ]

DONE

์ค‘์š” ์ฐธ๊ณ ์ž๋ฃŒ

  • pytorch.org ์—์„œ ์„ค์น˜ ๋ฐฉ๋ฒ•๊ณผ ์ถ”๊ฐ€ ๋ฌธ์„œ ๋ฐ ํŠœํ† ๋ฆฌ์–ผ๋“ค์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
  • Introduction to TorchScript tutorial ์—์„œ ๋” ์‹ฌํ™”๋œ TorchScript ๊ธฐ์ดˆ ์„ค๋ช…์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
  • Full TorchScript documentation ์—์„œ ์ „์ฒด TorchScript ์–ธ์–ด ๋ฐ API๋ฅผ ์ฐธ์กฐํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.