Skip to content

Files

Latest commit

 

History

History
356 lines (267 loc) Β· 18.2 KB

cpp_export.rst

File metadata and controls

356 lines (267 loc) Β· 18.2 KB

C++μ—μ„œ TorchScript λͺ¨λΈ λ‘œλ”©ν•˜κΈ°

PyTorch의 μ΄λ¦„μ—μ„œ μ•Œ 수 μžˆλ“―μ΄ PyTorchλŠ” Python ν”„λ‘œκ·Έλž˜λ° μ–Έμ–΄λ₯Ό κΈ°λ³Έ μΈν„°νŽ˜μ΄μŠ€λ‘œ ν•˜κ³  μžˆμŠ΅λ‹ˆλ‹€. Python은 동적성과 μ‹ μ†ν•œ μ΄ν„°λ ˆμ΄μ…˜μ΄ ν•„μš”ν•œ 상황에 μ ν•©ν•˜κ³  μ„ ν˜Έλ˜λŠ” μ–Έμ–΄μž…λ‹ˆλ‹€. ν•˜μ§€λ§Œ λ§ˆμ°¬κ°€μ§€λ‘œ μ΄λŸ¬ν•œ Python의 νŠΉμ§•λ“€μ΄ Python을 μ‚¬μš©ν•˜κΈ° μ ν•©ν•˜μ§€ μ•Šκ²Œ λ§Œλ“œλŠ” 상황도 많이 λ°œμƒν•©λ‹ˆλ‹€. Python을 μ‚¬μš©ν•˜κΈ° μ ν•©ν•˜μ§€ μ•Šμ€ λŒ€ν‘œμ μΈ 예둜 μƒμš© ν™˜κ²½μ΄ μžˆμŠ΅λ‹ˆλ‹€. μƒμš© ν™˜κ²½μ—μ„œλŠ” 짧은 μ§€μ—°μ‹œκ°„μ΄ μ€‘μš”ν•˜κ³  λ°°ν¬ν•˜λŠ” 데에도 λ§Žμ€ μ œμ•½μ΄ λ”°λ¦…λ‹ˆλ‹€. 이둜 인해 μƒμš© ν™˜κ²½μ—μ„œλŠ” λ§Žμ€ μ‚¬λžŒλ“€μ΄ C++λ₯Ό κ°œλ°œμ–Έμ–΄λ‘œ μ±„νƒν•˜κ²Œ λ©λ‹ˆλ‹€. 단지 Java, Rust, λ˜λŠ” Go와 같은 λ‹€λ₯Έ 언어듀을 λ°”μΈλ”©ν•˜κΈ° μœ„ν•œ λͺ©μ μΌ 뿐일지라도 말이죠. μ•žμœΌλ‘œ 이 νŠœν† λ¦¬μ–Όμ—μ„œ μ–΄λ–»κ²Œ PyTorchμ—μ„œ Python으둜 μž‘μ„±λœ λͺ¨λΈλ“€μ„ Python μ˜μ‘΄μ„±μ΄ μ „ν˜€ μ—†λŠ” C++ν™˜κ²½μ—μ„œλ„ 읽고 μ‹€ν–‰ν•  수 μžˆλŠ” λ°©μ‹μœΌλ‘œ 직렬화할 수 μžˆλŠ”μ§€ μ•Œμ•„λ³΄κ² μŠ΅λ‹ˆλ‹€.

단계 1. PyTorch λͺ¨λΈμ„ TorchScript λͺ¨λΈλ‘œ λ³€ν™˜ν•˜κΈ°

Torch Script λŠ” PyTorch λͺ¨λΈμ„ Pythonμ—μ„œ C++둜 λ³€ν™˜ν•˜λŠ” 것을 κ°€λŠ₯ν•˜κ²Œ ν•΄μ€λ‹ˆλ‹€. TorchScriptλŠ” TorchScript μ»΄νŒŒμΌλŸ¬κ°€ μ΄ν•΄ν•˜κ³ , μ»΄νŒŒμΌν•˜κ³ , 직렬화할 수 μžˆλŠ” PyTorch λͺ¨λΈμ˜ ν•œ ν‘œν˜„λ°©μ‹μž…λ‹ˆλ‹€. λ§Œμ•½ 기본적인 "μ¦‰μ‹œ μ‹€ν–‰"[μ—­μž μ£Ό: eager execution] APIλ₯Ό μ‚¬μš©ν•΄ μž‘μ„±λœ PyTorch λͺ¨λΈμ΄ μžˆλ‹€λ©΄, 처음으둜 ν•΄μ•Ό ν•  일은 이 λͺ¨λΈμ„ TorchScript λͺ¨λΈλ‘œ λ³€ν™˜ν•˜λŠ” κ²ƒμž…λ‹ˆλ‹€. μ•„λž˜μ— μ„€λͺ…λ˜μ–΄ μžˆλ“―μ΄, λŒ€λΆ€λΆ„μ˜ κ²½μš°μ— 이 과정은 맀우 κ°„λ‹¨ν•©λ‹ˆλ‹€. 이미 TorchScript λͺ¨λ“ˆμ„ 가지고 μžˆλ‹€λ©΄, 이 μ„Ήμ…˜μ„ κ±΄λ„ˆλ›°μ–΄λ„ μ’‹μŠ΅λ‹ˆλ‹€.

PyTorch λͺ¨λΈμ„ TorchScript둜 λ³€ν™˜ν•˜λŠ” λ°©λ²•μ—λŠ” 두가지가 μžˆμŠ΅λ‹ˆλ‹€. μ²«λ²ˆμ§ΈλŠ” νŠΈλ ˆμ΄μ‹±(tracing)μ΄λΌλŠ” λ°©λ²•μœΌλ‘œ μ–΄λ–€ μž…λ ₯값을 μ‚¬μš©ν•˜μ—¬ λͺ¨λΈμ˜ ꡬ쑰λ₯Ό νŒŒμ•…ν•˜κ³  이 μž…λ ₯κ°’μ˜ λͺ¨λΈ μ•ˆμ—μ„œμ˜ 흐름을 톡해 λͺ¨λΈμ„ κΈ°λ‘ν•˜λŠ” λ°©μ‹μž…λ‹ˆλ‹€. 이 방법은 쑰건문을 많이 μ‚¬μš©ν•˜μ§€ μ•ŠλŠ” λͺ¨λΈμ˜ κ²½μš°μ— μ ν•©ν•©λ‹ˆλ‹€. PyTorch λͺ¨λΈμ„ TorchScript둜 λ³€ν™˜ν•˜λŠ” λ‘λ²ˆμ§Έ 방법은 λͺ¨λΈμ— λͺ…μ‹œμ μΈ μ–΄λ…Έν…Œμ΄μ…˜(annotation)을 μΆ”κ°€ν•˜μ—¬ TorchScript 컴파일러둜 ν•˜μ—¬κΈˆ 직접 λͺ¨λΈ μ½”λ“œλ₯Ό λΆ„μ„ν•˜κ³  μ»΄νŒŒμΌν•˜κ²Œ ν•˜λŠ” λ°©μ‹μž…λ‹ˆλ‹€. 이 방식을 μ‚¬μš©ν•  λ•ŒλŠ” TorchScript μ–Έμ–΄ μžμ²΄μ— μ œμ•½μ΄ μžˆμ„ 수 μžˆμŠ΅λ‹ˆλ‹€.

Tip

μœ„ 두 방식에 κ΄€λ ¨λœ 정보와 λ‘˜ 쀑 μ–΄λ–€ 방법을 μ‚¬μš©ν•΄μ•Ό 할지 등에 λŒ€ν•œ κ°€μ΄λ“œλŠ” 곡식 κΈ°μˆ λ¬Έμ„œμΈ Torch Script reference μ—μ„œ ν™•μΈν•˜μ‹€ 수 μžˆμŠ΅λ‹ˆλ‹€.

νŠΈλ ˆμ΄μ‹±(tracing)을 톡해 TorchScript둜 λ³€ν™˜ν•˜κΈ°

PyTorch λͺ¨λΈμ„ νŠΈλ ˆμ΄μ‹±μ„ 톡해 TorchScript둜 λ³€ν™˜ν•˜κΈ° μœ„ν•΄μ„œλŠ”, μ—¬λŸ¬λΆ„μ΄ κ΅¬ν˜„ν•œ λͺ¨λΈμ˜ μΈμŠ€ν„΄μŠ€λ₯Ό 예제 μž…λ ₯κ°’κ³Ό ν•¨κ»˜ torch.jit.trace ν•¨μˆ˜μ— λ„˜κ²¨μ£Όμ–΄μ•Ό ν•©λ‹ˆλ‹€. 그러면 이 ν•¨μˆ˜λŠ” torch.jit.ScriptModule 객체λ₯Ό μƒμ„±ν•˜κ²Œ λ©λ‹ˆλ‹€. μ΄λ ‡κ²Œ μƒμ„±λœ κ°μ²΄μ—λŠ” λͺ¨λ“ˆμ˜ forward λ©”μ†Œλ“œμ˜ λͺ¨λΈ μ‹€ν–‰μ‹œ λŸ°νƒ€μž„μ„ traceν•œ κ²°κ³Όκ°€ ν¬ν•¨λ˜κ²Œ λ©λ‹ˆλ‹€:

import torch
import torchvision

# λͺ¨λΈ μΈμŠ€ν„΄μŠ€ 생성
model = torchvision.models.resnet18()

# 일반적으둜 λͺ¨λΈμ˜ forward() λ©”μ†Œλ“œμ— λ„˜κ²¨μ£ΌλŠ” μž…λ ₯κ°’
example = torch.rand(1, 3, 224, 224)

# torch.jit.traceλ₯Ό μ‚¬μš©ν•˜μ—¬ νŠΈλ ˆμ΄μ‹±μ„ μ΄μš©ν•΄ torch.jit.ScriptModule 생성
traced_script_module = torch.jit.trace(model, example)

μ΄λ ‡κ²Œ trace된 ScriptModule 은 일반적인 PyTorch λͺ¨λ“ˆκ³Ό 같은 λ°©μ‹μœΌλ‘œ μž…λ ₯값을 λ°›μ•„ μ²˜λ¦¬ν•  수 μžˆμŠ΅λ‹ˆλ‹€:

In[1]: output = traced_script_module(torch.ones(1, 3, 224, 224))
In[2]: output[0, :5]
Out[2]: tensor([-0.2698, -0.0381,  0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)

μ–΄λ…Έν…Œμ΄μ…˜(annotation)을 톡해 TorchScript둜 λ³€ν™˜ν•˜κΈ°

νŠΉμ •ν•œ ν™˜κ²½(κ°€λ Ή λͺ¨λΈμ΄ μ–΄λ–€ μ œμ–΄νλ¦„μ„ μ‚¬μš©ν•˜κ³  μžˆλŠ” 경우)μ—μ„œλŠ” μ—¬λŸ¬λΆ„μ˜ λͺ¨λΈμ„ μ–΄λ…Έν…Œμ΄νŠΈ(annotate)ν•˜μ—¬ TorchScript둜 λ°”λ‘œ μž‘μ„±ν•˜λŠ” 것이 λ°”λžŒμ§ν•œ κ²½μš°κ°€ μžˆμŠ΅λ‹ˆλ‹€. 예λ₯Ό λ“€μ–΄, μ•„λž˜μ™€ 같은 PyTorch λͺ¨λΈμ΄ μžˆλ‹€κ³  κ°€μ •ν•˜κ² μŠ΅λ‹ˆλ‹€:

import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

이 λͺ¨λ“ˆμ˜ forward λ©”μ†Œλ“œλŠ” μž…λ ₯값에 영ν–₯을 λ°›λŠ” μ œμ–΄νλ¦„μ„ μ‚¬μš©ν•˜κ³  있기 λ•Œλ¬Έμ—, 이 λͺ¨λ“ˆμ€ νŠΈλ ˆμ΄μ‹±μ—λŠ” μ ν•©ν•˜μ§€ μ•ŠμŠ΅λ‹ˆλ‹€. λŒ€μ‹  μš°λ¦¬λŠ” 이 λͺ¨λ“ˆμ„ ScriptModule 둜 λ³€ν™˜ν•  수 μžˆμŠ΅λ‹ˆλ‹€. λͺ¨λ“ˆμ„ ScriptModule 둜 λ³€ν™˜ν•˜κΈ° μœ„ν•΄μ„œλŠ”, μ•„λž˜μ™€ 같이 torch.jit.script ν•¨μˆ˜λ₯Ό μ‚¬μš©ν•΄ λͺ¨λ“ˆμ„ μ»΄νŒŒμΌν•΄μ•Ό ν•©λ‹ˆλ‹€:

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

my_module = MyModule(10,20)
sm = torch.jit.script(my_module)

아직 TorchScriptμ—μ„œ μ§€μ›ν•˜μ§€ μ•ŠλŠ” Python κΈ°λŠ₯을 μ‚¬μš©ν•˜κ³  μžˆλŠ” λ©”μ†Œλ“œλ“€μ„ μ—¬λŸ¬λΆ„μ˜ nn.Module μ—μ„œ μ œμ™Έν•˜κ³  μ‹Άλ‹€λ©΄, κ·Έ λ©”μ†Œλ“œλ“€μ„ @torch.jit.ignore 둜 μ–΄λ…Έν…Œμ΄νŠΈν•˜λ©΄ λ©λ‹ˆλ‹€.

sm 은 직렬화(serialization) μ€€λΉ„κ°€ 된 ScriptModule 의 μΈμŠ€ν„΄μŠ€μž…λ‹ˆλ‹€.

단계 2. Script λͺ¨λ“ˆμ„ 파일둜 μ§λ ¬ν™”ν•˜κΈ°

λͺ¨λΈμ„ νŠΈλ ˆμ΄μ‹±μ΄λ‚˜ μ–΄λ…Έν…Œμ΄νŒ…μ„ 톡해 ScriptModule 둜 λ³€ν™˜ν•˜μ˜€λ‹€λ©΄, 이제 그것을 파일둜 직렬화할 μˆ˜λ„ μžˆμŠ΅λ‹ˆλ‹€. λ‚˜μ€‘μ— C++λ₯Ό μ΄μš©ν•΄ νŒŒμΌλ‘œλΆ€ν„° λͺ¨λ“ˆμ„ μ½μ–΄μ˜¬ 수 있고 Python에 μ–΄λ–€ μ˜μ‘΄μ„±λ„ 없이 κ·Έ λͺ¨λ“ˆμ„ μ‹€ν–‰ν•  수 μžˆμŠ΅λ‹ˆλ‹€. 예λ₯Ό λ“€μ–΄ νŠΈλ ˆμ΄μ‹± μ˜ˆμ‹œμ—μ„œ λ“€μ—ˆλ˜ ResNet18 λͺ¨λΈμ„ μ§λ ¬ν™”ν•˜κ³  μ‹Άλ‹€κ³  κ°€μ •ν•©μ‹œλ‹€. 직렬화λ₯Ό ν•˜κΈ° μœ„ν•΄μ„œλŠ”, save ν•¨μˆ˜λ₯Ό ν˜ΈμΆœν•˜κ³  λͺ¨λ“ˆκ³Ό 파일λͺ…λ§Œ λ„˜κ²¨μ£Όλ©΄ λ©λ‹ˆλ‹€:

traced_script_module.save("traced_resnet_model.pt")

이 ν•¨μˆ˜λŠ” traced_resnet_model.pt νŒŒμΌμ„ μž‘μ—… 디렉토리에 생성할 κ²ƒμž…λ‹ˆλ‹€. λ§Œμ•½ μ–΄λ…Έν…Œμ΄μ…˜ μ˜ˆμ‹œμ˜ sm 을 μ§λ ¬ν™”ν•˜κ³  μ‹Άλ‹€λ©΄, sm.save("my_module_model.pt") λ₯Ό ν˜ΈμΆœν•˜λ©΄ λ©λ‹ˆλ‹€. 이둜써 이제 Python의 μ„Έκ³„μ—μ„œ λ²—μ–΄λ‚˜ C++ ν™˜κ²½μ—μ„œ μž‘μ—…ν•  μ€€λΉ„λ₯Ό λ§ˆμ³€μŠ΅λ‹ˆλ‹€.

단계 3. C++μ—μ„œ Script λͺ¨λ“ˆ λ‘œλ”©ν•˜κΈ°

μ§λ ¬ν™”λœ PyTorch λͺ¨λΈμ„ C++μ—μ„œ λ‘œλ“œν•˜κΈ° μœ„ν•΄μ„œλŠ”, μ–΄ν”Œλ¦¬μΌ€μ΄μ…˜μ΄ λ°˜λ“œμ‹œ LibTorch 라고 λΆˆλ¦¬λŠ” PyTorch C++ APIλ₯Ό μ‚¬μš©ν•΄μ•Ό ν•©λ‹ˆλ‹€. LibTorchλŠ” μ—¬λŸ¬ 곡유 λΌμ΄λΈŒλŸ¬λ¦¬λ“€, 헀더 νŒŒμΌλ“€, 그리고 CMake λΉŒλ“œ μ„€μ •νŒŒμΌλ“€μ„ ν¬ν•¨ν•˜κ³  μžˆμŠ΅λ‹ˆλ‹€. CMakeλŠ” LibTorchλ₯Ό μ“°κΈ°μœ„ν•œ ν•„μˆ˜ μš”κ΅¬μ‚¬ν•­μ€ μ•„λ‹ˆμ§€λ§Œ, ꢌμž₯λ˜λŠ” 방식이고 ν–₯후에도 계속 지원될 μ˜ˆμ •μž…λ‹ˆλ‹€. 이 νŠœν† λ¦¬μ–Όμ—μ„œλŠ” CMake와 LibTorchλ₯Ό μ‚¬μš©ν•˜μ—¬ μ§λ ¬ν™”λœ PyTorch λͺ¨λΈμ„ 읽고 μ‹€ν–‰ν•˜λŠ” μ•„μ£Ό κ°„λ‹¨ν•œ C++ μ–΄ν”Œλ¦¬μΌ€μ΄μ…˜μ„ λ§Œλ“€μ–΄λ³΄λ„λ‘ ν•˜κ² μŠ΅λ‹ˆλ‹€.

κ°„λ‹¨ν•œ C++ μ–΄ν”Œλ¦¬μΌ€μ΄μ…˜

μš°μ„  λͺ¨λ“ˆμ„ λ‘œλ“œν•˜λŠ” μ½”λ“œμ— λŒ€ν•΄ μ‚΄νŽ΄λ³΄λ„λ‘ ν•˜κ² μŠ΅λ‹ˆλ‹€. μ•„λž˜μ˜ κ°„λ‹¨ν•œ μ½”λ“œλ‘œ λͺ¨λ“ˆμ„ μ‰½κ²Œ μ½μ–΄μ˜¬ 수 μžˆμŠ΅λ‹ˆλ‹€:

#include <torch/script.h> // ν•„μš”ν•œ 단 ν•˜λ‚˜μ˜ ν—€λ”νŒŒμΌ.

#include <iostream>
#include <memory>

int main(int argc, const char* argv[]) {
  if (argc != 2) {
    std::cerr << "usage: example-app <path-to-exported-script-module>\n";
    return -1;
  }


  torch::jit::script::Module module;
  try {
    // torch::jit::load()을 μ‚¬μš©ν•΄ ScriptModule을 νŒŒμΌλ‘œλΆ€ν„° 역직렬화
    module = torch::jit::load(argv[1]);
  }
  catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    return -1;
  }

  std::cout << "ok\n";
}

<torch/script.h> ν—€λ”λŠ” μ˜ˆμ‹œλ₯Ό μ‹€ν–‰ν•˜κΈ° μœ„ν•œ λͺ¨λ“  LibTorch 라이브러리λ₯Ό ν¬ν•¨ν•˜κ³  μžˆμŠ΅λ‹ˆλ‹€. 우리의 μ–΄ν”Œλ¦¬μΌ€μ΄μ…˜μ€ μ§λ ¬ν™”λœ PyTorch ScriptModule 의 경둜λ₯Ό μœ μΌν•œ λͺ…λ Ήν–‰ 인자둜 μž…λ ₯λ°›κ³  이 파일경둜λ₯Ό 인자둜 λ°›λŠ” torch::jit::load() λ₯Ό μ‚¬μš©ν•΄ λͺ¨λ“ˆμ„ μ—­μ§λ ¬ν™”ν•©λ‹ˆλ‹€. κ·Έ 결과둜 torch::jit::script::Module λ₯Ό λŒλ €λ°›μŠ΅λ‹ˆλ‹€. 이 리턴받은 λͺ¨λ“ˆμ„ μ–΄λ–»κ²Œ μ‚¬μš©ν•˜λŠ”μ§€μ— λŒ€ν•΄μ„œλŠ” 곧 μ‚΄νŽ΄λ³΄κ² μŠ΅λ‹ˆλ‹€.

LibTorch μ‚¬μš© 및 μ–΄ν”Œλ¦¬μΌ€μ΄μ…˜ λΉŒλ“œ 방법

μœ„μ˜ μ½”λ“œλ₯Ό example-app.cpp μ΄λΌλŠ” νŒŒμΌμ— μ €μž₯ν•˜μ˜€λ‹€κ³  κ°€μ •ν•©λ‹ˆλ‹€. μœ„ μ½”λ“œλ₯Ό λΉŒλ“œν•˜κΈ° μœ„ν•œ κ°„λ‹¨ν•œ CMakeLists.txt μž…λ‹ˆλ‹€:

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(custom_ops)

find_package(Torch REQUIRED)

add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 14)

μ˜ˆμ‹œ μ–΄ν”Œλ¦¬μΌ€μ΄μ…˜μ„ λΉŒλ“œν•˜κΈ° μœ„ν•΄ λ§ˆμ§€λ§‰μœΌλ‘œ ν•„μš”ν•œ 것은 LibTorch λ°°ν¬νŒμž…λ‹ˆλ‹€. μ–Έμ œλ‚˜ κ°€μž₯ μ΅œμ‹ μ˜ μ•ˆμ • 버전을 PyTorch μ›Ήμ‚¬μ΄νŠΈμ˜ download page λ‘œλΆ€ν„° λ°›μœΌμ‹€ 수 μžˆμŠ΅λ‹ˆλ‹€. κ°€μž₯ μ΅œμ‹  버전을 λ‹€μš΄λ‘œλ“œ λ°›μ•„ 압좕을 ν‘Έμ‹œλ©΄, μ•„λž˜μ™€ 같은 디렉토리 ꡬ쑰의 폴더λ₯Ό ν™•μΈν•˜μ‹€ 수 μžˆμŠ΅λ‹ˆλ‹€:

libtorch/
  bin/
  include/
  lib/
  share/
  • lib/ ν΄λ”λŠ” 링크해야 ν•  곡유 라이브러리λ₯Ό ν¬ν•¨ν•˜κ³  μžˆμŠ΅λ‹ˆλ‹€.
  • include/ ν΄λ”λŠ” μ—¬λŸ¬λΆ„μ˜ ν”„λ‘œκ·Έλž¨μ΄ include ν•΄μ•Ό ν•  헀더 νŒŒμΌλ“€μ„ λ‹΄κ³  μžˆμŠ΅λ‹ˆλ‹€.
  • share/ ν΄λ”λŠ” μœ„μ—μ„œ μ‹€ν–‰ν•œ κ°„λ‹¨ν•œ λͺ…령어인 find_package(Torch) λ₯Ό μ‹€ν–‰ν•˜κ²Œ ν•΄μ£ΌλŠ” CMake 섀정을 λ‹΄κ³ μžˆμŠ΅λ‹ˆλ‹€.

Tip

μœˆλ„μš°μ—μ„œλŠ” 디버그 λΉŒλ“œμ™€ 릴리즈 λΉŒλ“œκ°€ ABI-compatibleν•˜μ§€ μ•ŠμŠ΅λ‹ˆλ‹€. λ§Œμ•½ ν”„λ‘œμ νŠΈλ₯Ό debug λͺ¨λ“œμ—μ„œ λΉŒλ“œν•˜κ³  μ‹Άλ‹€λ©΄, LibTorch의 debug 버전을 μ‚¬μš©ν•΄μ•Όν•©λ‹ˆλ‹€. 그리고 cmake --build .` 에 μ•Œλ§žμ€ 섀정을 λͺ…μ‹œν•΄ μ£Όμ–΄μ•Ό ν•©λ‹ˆλ‹€.

λ§ˆμ§€λ§‰ λ‹¨κ³„λŠ” μ–΄ν”Œλ¦¬μΌ€μ΄μ…˜μ„ λΉŒλ“œν•˜λŠ” κ²ƒμž…λ‹ˆλ‹€. 이λ₯Ό μœ„ν•΄μ„œ 디렉토리 ꡬ쑰가 μ•„λž˜μ™€ 같이 κ°™λ‹€κ³  κ°€μ •ν•˜κ² μŠ΅λ‹ˆλ‹€.

example-app/
  CMakeLists.txt
  example-app.cpp

이제 μ•„λž˜ λͺ…령어듀을 μ‚¬μš©ν•΄ example-app/ 폴더 μ•ˆμ—μ„œ μ–΄ν”Œλ¦¬μΌ€μ΄μ…˜μ„ λΉŒλ“œν•  수 μžˆμŠ΅λ‹ˆλ‹€.

mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
cmake --build . --config Release

μ—¬κΈ°μ„œ /path/to/libtorch λŠ” LibTorch 배포판의 압좕을 ν‘Ό 전체 κ²½λ‘œμž…λ‹ˆλ‹€. λͺ¨λ“  것이 잘 λ˜μ—ˆλ‹€λ©΄, μ•„λž˜μ™€ 같은 것이 λ‚˜νƒ€λ‚  κ²ƒμž…λ‹ˆλ‹€:

root@4b5a67132e81:/example-app# mkdir build
root@4b5a67132e81:/example-app# cd build
root@4b5a67132e81:/example-app/build# cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
-- The C compiler identification is GNU 5.4.0
-- The CXX compiler identification is GNU 5.4.0
-- Check for working C compiler: /usr/bin/cc
-- Check for working C compiler: /usr/bin/cc -- works
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Detecting C compile features
-- Detecting C compile features - done
-- Check for working CXX compiler: /usr/bin/c++
-- Check for working CXX compiler: /usr/bin/c++ -- works
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Looking for pthread_create
-- Looking for pthread_create - not found
-- Looking for pthread_create in pthreads
-- Looking for pthread_create in pthreads - not found
-- Looking for pthread_create in pthread
-- Looking for pthread_create in pthread - found
-- Found Threads: TRUE
-- Configuring done
-- Generating done
-- Build files have been written to: /example-app/build
root@4b5a67132e81:/example-app/build# make
Scanning dependencies of target example-app
[ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o
[100%] Linking CXX executable example-app
[100%] Built target example-app

이제 trace된 ResNet18 λͺ¨λΈμΈ traced_resnet_model.pt 경둜λ₯Ό example-app λ°”μ΄λ„ˆλ¦¬μ— μž…λ ₯ν–ˆλ‹€λ©΄, "ok" λ©”μ‹œμ§€λ₯Ό 확인할 수 μžˆμ„ κ²ƒμž…λ‹ˆλ‹€. λ§Œμ•½ 이 μ˜ˆμ œμ— my_module_model.pt λ₯Ό 인자둜 λ„˜κ²Όλ‹€λ©΄, μž…λ ₯값이 ν˜Έν™˜λ˜μ§€ μ•ŠλŠ” λͺ¨μ–‘μ΄λΌλŠ” μ—λŸ¬λ©”μ‹œμ§€κ°€ 좜λ ₯λ©λ‹ˆλ‹€. my_module_model.pt λŠ” 4Dκ°€ μ•„λ‹Œ 1D ν…μ„œλ₯Ό 받도둝 λ˜μ–΄μžˆκΈ° λ•Œλ¬Έμž…λ‹ˆλ‹€.

root@4b5a67132e81:/example-app/build# ./example-app <path_to_model>/traced_resnet_model.pt
ok

단계 4. Script λͺ¨λ“ˆμ„ C++μ—μ„œ μ‹€ν–‰ν•˜κΈ°

ResNet18 을 C++μ—μ„œ μ„±κ³΅μ μœΌλ‘œ λ‘œλ”©ν•œ λ’€, 이제 λͺ‡ μ€„μ˜ μ½”λ“œλ§Œ 더 μΆ”κ°€ν•˜λ©΄ λͺ¨λ“ˆμ„ μ‹€ν–‰ν•  수 μžˆμŠ΅λ‹ˆλ‹€. C++ μ–΄ν”Œλ¦¬μΌ€μ΄μ…˜μ˜ main() ν•¨μˆ˜μ— μ•„λž˜μ˜ μ½”λ“œλ₯Ό μΆ”κ°€ν•˜κ² μŠ΅λ‹ˆλ‹€.

// μž…λ ₯κ°’ 벑터λ₯Ό μƒμ„±ν•©λ‹ˆλ‹€.
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3, 224, 224}));

// λͺ¨λΈμ„ μ‹€ν–‰ν•œ λ’€ 리턴값을 ν…μ„œλ‘œ λ³€ν™˜ν•©λ‹ˆλ‹€.
at::Tensor output = module.forward(inputs).toTensor();
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';

첫 두쀄은 λͺ¨λΈμ˜ μž…λ ₯값을 μƒμ„±ν•©λ‹ˆλ‹€. torch::jit::IValue (script::Module λ©”μ†Œλ“œλ“€μ΄ μž…λ ₯λ°›κ³  또 리턴할 수 μžˆλŠ” νƒ€μž…μ΄ μ†Œκ±°λœ μžλ£Œν˜•)의 벑터λ₯Ό λ§Œλ“€κ³  κ·Έ 벑터에 ν•˜λ‚˜μ˜ μž…λ ₯값을 μΆ”κ°€ν•©λ‹ˆλ‹€. μž…λ ₯κ°’ ν…μ„œλ₯Ό λ§Œλ“€κΈ° μœ„ν•΄μ„œ μš°λ¦¬λŠ” torch::ones() 을 μ‚¬μš©ν•©λ‹ˆλ‹€. 이 ν•¨μˆ˜λŠ” torch.ones 의 C++ API λ²„μ „μž…λ‹ˆλ‹€. 이제 script::Module 의 forward λ©”μ†Œλ“œμ— μž…λ ₯κ°’ 벑터λ₯Ό λ„˜κ²¨μ£Όμ–΄ μ‹€ν–‰ν•˜λ©΄, μš°λ¦¬λŠ” μƒˆλ‘œμš΄ IValue λ₯Ό λ¦¬ν„΄λ°›κ²Œ 되고, 이 값을 toTensor() λ₯Ό 톡해 ν…μ„œλ‘œ λ³€ν™˜ν•  수 μžˆμŠ΅λ‹ˆλ‹€.

Tip

torch::ones λ₯Ό λΉ„λ‘―ν•œ PyTorch C++ API에 λŒ€ν•΄ 더 μ•Œκ³  μ‹Άλ‹€λ©΄ https://pytorch.org/cppdocs에 μžˆλŠ” λ¬Έμ„œλ₯Ό μ°Έκ³ ν•˜μ‹œλ©΄ λ©λ‹ˆλ‹€. PyTorch C++ APIλŠ” Python API와 거의 λ™μΌν•œ κΈ°λŠ₯을 μ œκ³΅ν•˜μ—¬ μ‚¬μš©μžλ“€μ΄ ν…μ„œλ₯Ό 닀루고 μ‚¬μš©ν•˜λŠ” 것을 Pythonκ³Ό λ™μΌν•˜κ²Œ ν•  수 μžˆλ„λ‘ ν•©λ‹ˆλ‹€.

λ§ˆμ§€λ§‰ μ€„μ—μ„œ 좜λ ₯κ°’μ˜ 첫 λ‹€μ„― 값듀을 ν”„λ¦°νŠΈν•©λ‹ˆλ‹€. 이번 νŠœν† λ¦¬μ–Όμ˜ μ•žλΆ€λΆ„μ—μ„œ Python λͺ¨λΈμ— λ™μΌν•œ μž…λ ₯값을 λ„˜κ²¨μ£Όμ—ˆκΈ° λ•Œλ¬Έμ—, 이 λΆ€λΆ„μ—μ„œλ„ 좜λ ₯값은 같을 것이라고 μ˜ˆμƒν•  수 μžˆμŠ΅λ‹ˆλ‹€. 그럼 μ–΄ν”Œλ¦¬μΌ€μ΄μ…˜μ„ λ‹€μ‹œ μ»΄νŒŒμΌν•˜κ³  같은 μ§λ ¬ν™”λœ λͺ¨λΈμ— λŒ€ν•΄ μ‹€ν–‰ν•΄ λ³΄κ² μŠ΅λ‹ˆλ‹€:

root@4b5a67132e81:/example-app/build# make
Scanning dependencies of target example-app
[ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o
[100%] Linking CXX executable example-app
[100%] Built target example-app
root@4b5a67132e81:/example-app/build# ./example-app traced_resnet_model.pt
-0.2698 -0.0381  0.4023 -0.3010 -0.0448
[ Variable[CPUFloatType]{1,5} ]

참고둜, μ΄μ „μ˜ Pythonμ—μ„œμ˜ 좜λ ₯값은 μ•„λž˜μ™€ κ°™μ•˜μŠ΅λ‹ˆλ‹€:

tensor([-0.2698, -0.0381,  0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)

두 좜λ ₯값이 μΌμΉ˜ν•˜λŠ” κ±Έ ν™•μΈν•˜μ‹€ 수 μžˆμŠ΅λ‹ˆλ‹€!

Tip

λͺ¨λΈμ„ GPU λ©”λͺ¨λ¦¬μ— 올리기 μœ„ν•΄μ„œλŠ”, model.to(at::kCUDA); λ₯Ό μ‚¬μš©ν•˜λ©΄ λ©λ‹ˆλ‹€. λͺ¨λΈμ— λ„˜κ²¨μ£ΌλŠ” μž…λ ₯값듀에 λŒ€ν•΄μ„œλ„ tensor.to(at::kCUDA) λ₯Ό 톡해 CUDA λ©”λͺ¨λ¦¬μ— 올린 λ’€ μ‚¬μš©ν•΄μ•Όν•©λ‹ˆλ‹€. tensor.to(at::kCUDA) λŠ” CUDA λ©”λͺ¨λ¦¬μ— μžˆλŠ” μƒˆλ‘œμš΄ ν…μ„œλ₯Ό λ¦¬ν„΄ν•©λ‹ˆλ‹€.

단계 5. API 더 μ•Œμ•„λ³΄κΈ°

이 νŠœν† λ¦¬μ–Όμ΄ PyTorch λͺ¨λΈμ„ Pythonμ—μ„œλΆ€ν„° C++둜 λ³€ν™˜ν•˜λŠ” 과정을 μ΄ν•΄ν•˜λŠ”λ° 도움이 λ˜μ—ˆκΈΈ λ°”λžλ‹ˆλ‹€. λ³Έ νŠœν† λ¦¬μ–Όμ—μ„œ 닀룬 κ°œλ…λ“€λ‘œ, μ—¬λŸ¬λΆ„μ€ 이제 "μ¦‰μ‹œ μ‹€ν–‰" λ²„μ „μ˜ PyTorch λͺ¨λΈμ—μ„œλΆ€ν„° Pythonμ—μ„œ 컴파일된 ScriptModule 둜, 더 λ‚˜μ•„κ°€ λ””μŠ€ν¬ μƒμ˜ μ§λ ¬ν™”λœ 파일둜, 그리고 λ§ˆμ§€λ§‰μœΌλ‘œ C++μ—μ„œ μ‹€ν–‰κ°€λŠ₯ν•œ script::Module κΉŒμ§€ λ§Œλ“€ 수 있게 λ˜μ—ˆμŠ΅λ‹ˆλ‹€.

λ¬Όλ‘  이 νŠœν† λ¦¬μ–Όμ—μ„œ 닀루지 λͺ»ν•œ κ°œλ…λ“€λ„ λ§ŽμŠ΅λ‹ˆλ‹€. 예λ₯Ό λ“€μ–΄ μ—¬λŸ¬λΆ„μ˜ ScriptModule 이 C++λ‚˜ CUDA둜 μ •μ˜λœ μ»€μŠ€ν…€ μ—°μ‚°μžλ₯Ό μ‚¬μš©ν•  수 μžˆκ²Œν•˜λŠ” 방법 λ˜λŠ” μ΄λŸ¬ν•œ μ»€μŠ€ν…€ μ—°μ‚°μžλ₯Ό C++ μƒμš© ν™˜κ²½μ˜ ScriptModule μ—μ„œ μ‚¬μš©ν•  수 μžˆκ²Œν•˜λŠ” 방법에 λŒ€ν•΄μ„œλŠ” λ³Έ νŠœν† λ¦¬μ–Όμ—μ„œ 닀루지 μ•Šμ•˜μŠ΅λ‹ˆλ‹€. 쒋은 μ†Œμ‹μ€ μ΄λŸ¬ν•œ 것듀이 κ°€λŠ₯ν•˜λ‹€λŠ” 것이고 μ§€μ›λ˜κ³  μžˆλ‹€λŠ” μ μž…λ‹ˆλ‹€! 저희가 곧 이것에 κ΄€ν•œ νŠœν† λ¦¬μ–Όμ„ μ—…λ‘œλ“œν•  λ•ŒκΉŒμ§€ 이 폴더 λ₯Ό μ˜ˆμ‹œλ‘œ μ‚Όμ•„ μ°Έκ³ ν•˜μ‹œλ©΄ λ˜κ² μŠ΅λ‹ˆλ‹€. 또 μ•„λž˜ 링크듀이 도움이 될 κ²ƒμž…λ‹ˆλ‹€:

μ–Έμ œλ‚˜ 그렇듯이, 문제λ₯Ό 맞λ‹₯λœ¨λ¦¬μ‹œκ±°λ‚˜ 질문이 μžˆμœΌμ‹œλ©΄ 저희 forum λ˜λŠ” GitHub issues 에 μ˜¬λ €μ£Όμ‹œλ©΄ λ˜κ² μŠ΅λ‹ˆλ‹€.