Using CUDA Graphs in PyTorch C++ API

λ²ˆμ—­: μž₯효영

Note

View and edit this tutorial on GitHub. The full source code is also available on GitHub.

μ„ μˆ˜ 지식:

NVIDIA's CUDA Graphs have been a part of the CUDA Toolkit library since the release of version 10. They greatly reduce CPU overhead, improving the performance of applications.

이 νŠœν† λ¦¬μ–Όμ—μ„œλŠ”, CUDA κ·Έλž˜ν”„ μ‚¬μš©μ— μ΄ˆμ μ„ 맞좜 κ²ƒμž…λ‹ˆλ‹€ PyTorch C++ ν”„λ‘ νŠΈμ—”λ“œ μ‚¬μš©ν•˜κΈ°. C++ ν”„λ‘ νŠΈμ—”λ“œλŠ” νŒŒμ΄ν† μΉ˜ μ‚¬μš© μ‚¬λ‘€μ˜ μ€‘μš”ν•œ 뢀뢄인데, 주둜 μ œν’ˆ 및 배포 μ• ν”Œλ¦¬μΌ€μ΄μ…˜μ—μ„œ ν™œμš©λ©λ‹ˆλ‹€. 첫번째 λ“±μž₯ μ΄ν›„λ‘œ CUDA κ·Έλž˜ν”„λŠ” 맀우 μ„±λŠ₯이 μ’‹κ³  μ‚¬μš©ν•˜κΈ° μ‰¬μ›Œμ„œ, μ‚¬μš©μžμ™€ 개발자의 λ§ˆμŒμ„ μ‚¬λ‘œμž‘μ•˜μŠ΅λ‹ˆλ‹€. μ‹€μ œλ‘œ, CUDA κ·Έλž˜ν”„λŠ” νŒŒμ΄ν† μΉ˜ 2.0의 torch.compile μ—μ„œ 기본적으둜 μ‚¬μš©λ˜λ©°, ν›ˆλ ¨κ³Ό μΆ”λ‘  μ‹œμ— 생산성을 λ†’μ—¬μ€λ‹ˆλ‹€.

νŒŒμ΄ν† μΉ˜μ—μ„œ CUDA κ·Έλž˜ν”„ μ‚¬μš©λ²•μ„ λ³΄μ—¬λ“œλ¦¬κ³ μž ν•©λ‹ˆλ‹€ MNIST 예제. LibTorch(C++ ν”„λ‘ νŠΈμ—”λ“œ)μ—μ„œμ˜ CUDA κ·Έλž˜ν”„ μ‚¬μš©λ²•μ€ λ‹€μŒκ³Ό 맀우 μœ μ‚¬ν•˜μ§€λ§Œ Python μ‚¬μš©μ˜ˆμ œ μ•½κ°„μ˜ ꡬ문과 κΈ°λŠ₯의 차이가 μžˆμŠ΅λ‹ˆλ‹€.

μ‹œμž‘ν•˜κΈ°

μ£Όμš” ν›ˆλ ¨ λ£¨ν”„λŠ” μ—¬λŸ¬ λ‹¨κ³„λ‘œ κ΅¬μ„±λ˜μ–΄ 있으며 λ‹€μŒ μ½”λ“œ λͺ¨μŒμ— μ„€λͺ…λ˜μ–΄ μžˆμŠ΅λ‹ˆλ‹€.

for (auto& batch : data_loader) {
  auto data = batch.data.to(device);
  auto targets = batch.target.to(device);
  optimizer.zero_grad();
  auto output = model.forward(data);
  auto loss = torch::nll_loss(output, targets);
  loss.backward();
  optimizer.step();
}

μœ„μ˜ μ˜ˆμ‹œμ—λŠ” μˆœμ „νŒŒ, μ—­μ „νŒŒ, κ°€μ€‘μΉ˜ μ—…λ°μ΄νŠΈκ°€ ν¬ν•¨λ˜μ–΄ μžˆμŠ΅λ‹ˆλ‹€.

이 νŠœν† λ¦¬μ–Όμ—μ„œλŠ” 전체 λ„€νŠΈμ›Œν¬ κ·Έλž˜ν”„ 캑처λ₯Ό 톡해 λͺ¨λ“  계산 단계에 CUDA κ·Έλž˜ν”„λ₯Ό μ μš©ν•©λ‹ˆλ‹€.
ν•˜μ§€λ§Œ κ·Έ 전에 μ•½κ°„μ˜ μ†ŒμŠ€ μ½”λ“œ μˆ˜μ •μ΄ ν•„μš”ν•©λ‹ˆλ‹€. μš°λ¦¬κ°€ ν•΄μ•Ό ν•  일은 μ£Ό ν›ˆλ ¨ λ£¨ν”„μ—μ„œ

tensorλ₯Ό μž¬μ‚¬μš©ν•  수 μžˆλ„λ‘ tensorλ₯Ό 미리 ν• λ‹Ήν•˜λŠ” κ²ƒμž…λ‹ˆλ‹€. λ‹€μŒμ€ κ΅¬ν˜„ μ˜ˆμ‹œμž…λ‹ˆλ‹€.

torch::TensorOptions FloatCUDA =
    torch::TensorOptions(device).dtype(torch::kFloat);
torch::TensorOptions LongCUDA =
    torch::TensorOptions(device).dtype(torch::kLong);

torch::Tensor data = torch::zeros({kTrainBatchSize, 1, 28, 28}, FloatCUDA);
torch::Tensor targets = torch::zeros({kTrainBatchSize}, LongCUDA);
torch::Tensor output = torch::zeros({1}, FloatCUDA);
torch::Tensor loss = torch::zeros({1}, FloatCUDA);

for (auto& batch : data_loader) {
  // Copy each batch into the preallocated tensors in place, keeping their addresses fixed.
  data.copy_(batch.data);
  targets.copy_(batch.target);
  training_step(model, optimizer, data, targets, output, loss);
}

μ—¬κΈ°μ„œ ``training_step``은 λ‹¨μˆœνžˆ ν•΄λ‹Ή μ˜΅ν‹°λ§ˆμ΄μ € 호좜과 ν•¨κ»˜ μˆœμ „νŒŒ 및 μ—­μ „νŒŒλ‘œ κ΅¬μ„±λ©λ‹ˆλ‹€

void training_step(
    Net& model,
    torch::optim::Optimizer& optimizer,
    torch::Tensor& data,
    torch::Tensor& targets,
    torch::Tensor& output,
    torch::Tensor& loss) {
  optimizer.zero_grad();
  output = model.forward(data);
  loss = torch::nll_loss(output, targets);
  loss.backward();
  optimizer.step();
}

νŒŒμ΄ν† μΉ˜μ˜ CUDA κ·Έλž˜ν”„ APIλŠ” 슀트림 μΊ‘μ²˜μ— μ˜μ‘΄ν•˜κ³  있으며, 이 경우 λ‹€μŒμ²˜λŸΌ μ‚¬μš©λ©λ‹ˆλ‹€

at::cuda::CUDAGraph graph;
// Capture has to run on a non-default stream.
at::cuda::CUDAStream captureStream = at::cuda::getStreamFromPool();
at::cuda::setCurrentCUDAStream(captureStream);

graph.capture_begin();
// The GPU work launched here is recorded into the graph, not executed.
training_step(model, optimizer, data, targets, output, loss);
graph.capture_end();

μ‹€μ œ κ·Έλž˜ν”„ 캑처 전에, μ‚¬μ΄λ“œ μŠ€νŠΈλ¦Όμ—μ„œ μ—¬λŸ¬ 번의 μ›Œλ°μ—… λ°˜λ³΅μ„ μ‹€ν–‰ν•˜μ—¬ CUDA μΊμ‹œλΏλ§Œ μ•„λ‹ˆλΌ ν›ˆλ ¨ 쀑에 μ‚¬μš©ν•  CUDA 라이브러리(CUBLAS와 CUDNN같은)λ₯Ό μ€€λΉ„ν•˜λŠ” 것이 μ€‘μš”ν•©λ‹ˆλ‹€.

// Warm up on a side stream so that these launches are not captured later.
at::cuda::CUDAStream warmupStream = at::cuda::getStreamFromPool();
at::cuda::setCurrentCUDAStream(warmupStream);
for (int iter = 0; iter < num_warmup_iters; iter++) {
  training_step(model, optimizer, data, targets, output, loss);
}
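
One detail the snippet above leaves implicit: the warm-up work should be finished before switching to the capture stream. A minimal sketch, assuming the standard LibTorch synchronization call:

// Make sure all warm-up iterations have completed before capturing.
torch::cuda::synchronize();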

κ·Έλž˜ν”„ μΊ‘μ²˜μ— μ„±κ³΅ν•˜λ©΄ training_step(model, optimizer, data, target, output, loss); ν˜ΈμΆœμ„ ``graph.replay()``둜 λŒ€μ²΄ν•˜μ—¬ ν•™μŠ΅ 단계λ₯Ό 진행할 수 μžˆμŠ΅λ‹ˆλ‹€.

ν›ˆλ ¨ κ²°κ³Ό

μ½”λ“œλ₯Ό ν•œ 번 μ‚΄νŽ΄λ³΄λ©΄ κ·Έλž˜ν”„κ°€ μ•„λ‹Œ 일반 ν›ˆλ ¨μ—μ„œ λ‹€μŒκ³Ό 같은 κ²°κ³Όλ₯Ό λ³Ό 수 μžˆμŠ΅λ‹ˆλ‹€

$ time ./mnist
Train Epoch: 1 [59584/60000] Loss: 0.3921
Test set: Average loss: 0.2051 | Accuracy: 0.938
Train Epoch: 2 [59584/60000] Loss: 0.1826
Test set: Average loss: 0.1273 | Accuracy: 0.960
Train Epoch: 3 [59584/60000] Loss: 0.1796
Test set: Average loss: 0.1012 | Accuracy: 0.968
Train Epoch: 4 [59584/60000] Loss: 0.1603
Test set: Average loss: 0.0869 | Accuracy: 0.973
Train Epoch: 5 [59584/60000] Loss: 0.2315
Test set: Average loss: 0.0736 | Accuracy: 0.978
Train Epoch: 6 [59584/60000] Loss: 0.0511
Test set: Average loss: 0.0704 | Accuracy: 0.977
Train Epoch: 7 [59584/60000] Loss: 0.0802
Test set: Average loss: 0.0654 | Accuracy: 0.979
Train Epoch: 8 [59584/60000] Loss: 0.0774
Test set: Average loss: 0.0604 | Accuracy: 0.980
Train Epoch: 9 [59584/60000] Loss: 0.0669
Test set: Average loss: 0.0544 | Accuracy: 0.984
Train Epoch: 10 [59584/60000] Loss: 0.0219
Test set: Average loss: 0.0517 | Accuracy: 0.983

real    0m44.287s
user    0m44.018s
sys    0m1.116s

CUDA κ·Έλž˜ν”„λ₯Ό μ‚¬μš©ν•œ ν›ˆλ ¨μ€ λ‹€μŒκ³Ό 같은 좜λ ₯을 μƒμ„±ν•©λ‹ˆλ‹€

$ time ./mnist --use-train-graph
Train Epoch: 1 [59584/60000] Loss: 0.4092
Test set: Average loss: 0.2037 | Accuracy: 0.938
Train Epoch: 2 [59584/60000] Loss: 0.2039
Test set: Average loss: 0.1274 | Accuracy: 0.961
Train Epoch: 3 [59584/60000] Loss: 0.1779
Test set: Average loss: 0.1017 | Accuracy: 0.968
Train Epoch: 4 [59584/60000] Loss: 0.1559
Test set: Average loss: 0.0871 | Accuracy: 0.972
Train Epoch: 5 [59584/60000] Loss: 0.2240
Test set: Average loss: 0.0735 | Accuracy: 0.977
Train Epoch: 6 [59584/60000] Loss: 0.0520
Test set: Average loss: 0.0710 | Accuracy: 0.978
Train Epoch: 7 [59584/60000] Loss: 0.0935
Test set: Average loss: 0.0666 | Accuracy: 0.979
Train Epoch: 8 [59584/60000] Loss: 0.0744
Test set: Average loss: 0.0603 | Accuracy: 0.981
Train Epoch: 9 [59584/60000] Loss: 0.0762
Test set: Average loss: 0.0547 | Accuracy: 0.983
Train Epoch: 10 [59584/60000] Loss: 0.0207
Test set: Average loss: 0.0525 | Accuracy: 0.983

real    0m6.952s
user    0m7.048s
sys    0m0.619s

κ²°λ‘ 

μœ„ μ˜ˆμ‹œμ—μ„œ λ³Ό 수 μžˆλ“―μ΄, λ°”λ‘œ MNIST 예제 에 CUDA κ·Έλž˜ν”„λ₯Ό μ μš©ν•˜λŠ” κ²ƒλ§ŒμœΌλ‘œλ„ ν›ˆλ ¨ μ„±λŠ₯을 6λ°° 이상 ν–₯μƒμ‹œν‚¬ 수 μžˆμ—ˆμŠ΅λ‹ˆλ‹€. μ΄λ ‡κ²Œ 큰 μ„±λŠ₯ ν–₯상이 κ°€λŠ₯ν–ˆλ˜ 것은 λͺ¨λΈ 크기가 μž‘μ•˜κΈ° λ•Œλ¬Έμž…λ‹ˆλ‹€. GPU μ‚¬μš©λŸ‰μ΄ λ§Žμ€ λŒ€ν˜• λͺ¨λΈμ˜ 경우 CPU κ³ΌλΆ€ν•˜μ˜ 영ν–₯이 적기 λ•Œλ¬Έμ— κ°œμ„  νš¨κ³Όκ°€ 더 μž‘μ„ 수 μžˆμŠ΅λ‹ˆλ‹€. 그런 κ²½μš°λΌλ„, GPU의 μ„±λŠ₯을 μ΄λŒμ–΄λ‚΄λ €λ©΄ CUDA κ·Έλž˜ν”„λ₯Ό μ‚¬μš©ν•˜λŠ” 것이 항상 μœ λ¦¬ν•©λ‹ˆλ‹€.