μ΄ λ μνΌλ Pytorch λͺ¨λΈμ μμννλ λ°©λ²μ μ€λͺ ν©λλ€. μμνλ λͺ¨λΈμ μλ³Έ λͺ¨λΈκ³Ό κ±°μ κ°μ μ νλλ₯Ό λ΄λ©΄μ, μ¬μ΄μ¦κ° μ€μ΄λ€κ³ μΆλ‘ μλκ° λΉ¨λΌμ§λλ€. μμν μμ μ μλ² λͺ¨λΈκ³Ό λͺ¨λ°μΌ λͺ¨λΈ λ°°ν¬μ λͺ¨λ μ μ©λ μ μμ§λ§, λͺ¨λ°μΌ νκ²½μμ νΉν μ€μνκ³ λ§€μ° νμν©λλ€. κ·Έ μ΄μ λ μμνλ₯Ό μ μ©νμ§ μμ λͺ¨λΈμ ν¬κΈ°κ° iOSλ Android μ±μ΄ νμ©νλ ν¬κΈ° νλλ₯Ό μ΄κ³Όνκ³ , κ·Έλ‘ μΈν΄ λͺ¨λΈμ λ°°ν¬λ OTA μ λ°μ΄νΈκ° λ무 μ€λ 걸리며, λν μΆλ‘ μλκ° λ무 λλ €μ μ¬μ©μμ μΎμ ν¨μ λ°©ν΄νκΈ° λλ¬Έμ λλ€.
μμνλ λͺ¨λΈ 맀κ°λ³μλ₯Ό ꡬμ±νλ 32λΉνΈ ν¬κΈ°μ μ€μ μλ£νμ μ«μλ₯Ό 8λΉνΈ ν¬κΈ°μ μ μ μλ£νμ μ«μλ‘ μ ννλ κΈ°λ²μ λλ€. μμν κΈ°λ²μ μ μ©νλ©΄, μ νλλ κ±°μ κ°κ² μ μ§νλ©΄μ, λͺ¨λΈμ ν¬κΈ°μ λ©λͺ¨λ¦¬ μ 체 μ¬μ©λμ μλ³Έ λͺ¨λΈμ 4λΆμ 1κΉμ§ κ°μμν¬ μ μκ³ , μΆλ‘ μ 2~4λ°° μ λ λΉ λ₯΄κ² λ§λ€ μ μμ΅λλ€.
λͺ¨λΈμ μμννλ λ°λ μ λΆ μΈ κ°μ§μ μ κ·Όλ² λ° μμ λ°©μμ΄ μμ΅λλ€. νμ΅ ν λμ μμν(post training dynamic quantization), νμ΅ ν μ μ μμν(post training static quantization), κ·Έλ¦¬κ³ μμνλ₯Ό κ³ λ €ν νμ΅(quantization aware training)μ΄ μμ΅λλ€. νμ§λ§ μ¬μ©νλ €λ λͺ¨λΈμ΄ μ΄λ―Έ μμνλ λ²μ μ΄ μλ€λ©΄, μμ μΈ κ°μ§ λ°©μμ κ±°μΉμ§ μκ³ κ·Έ λ²μ μ λ°λ‘ μ¬μ©νλ©΄ λ©λλ€. μλ₯Ό λ€μ΄, torchvision λΌμ΄λΈλ¬λ¦¬μλ μ΄λ―Έ MobileNet v2, ResNet 18, ResNet 50, Inception v3, GoogleNetμ ν¬ν¨ν λͺ¨λΈμ μμνλ λ²μ μ΄ μ‘΄μ¬ν©λλ€. λ°λΌμ λΉλ‘ λ¨μν μμ μ΄κ² μ§λ§, μ¬μ νμ΅ λ° μμνλ λͺ¨λΈ μ¬μ©(use pretrained quantized model)μ λ λ€λ₯Έ μμ λ°©μ μ€ νλλ‘ ν¬ν¨νλ € ν©λλ€.
Note
μμνλ μΌλΆ μ νλ λ²μμ μ°μ°μμλ§ μ§μλ©λλ€. λ λ§μ μ 보λ μ¬κΈ° λ₯Ό μ°Έκ³ νμΈμ.
PyTorch 1.6.0 or 1.7.0
torchvision 0.6.0 or 0.7.0
λͺ¨λΈμ μμννλ €λ©΄ λ€μ 4κ°μ§ λ°©μ μ€ νλλ₯Ό μ¬μ©νμΈμ.
μ¬μ νμ΅λ MobileNet v2 λͺ¨λΈμ λΆλ¬μ€λ €λ©΄, λ€μμ μ λ ₯νμΈμ.
import torchvision model_quantized = torchvision.models.quantization.mobilenet_v2(pretrained=True, quantize=True)
μμν μ μ MobileNet v2 λͺ¨λΈκ³Ό μμνλ λ²μ μ λͺ¨λΈμ ν¬κΈ°λ₯Ό λΉκ΅ν©λλ€.
model = torchvision.models.mobilenet_v2(pretrained=True) import os import torch def print_model_size(mdl): torch.save(mdl.state_dict(), "tmp.pt") print("%.2f MB" %(os.path.getsize("tmp.pt")/1e6)) os.remove('tmp.pt') print_model_size(model) print_model_size(model_quantized)
μΆλ ₯μ λ€μκ³Ό κ°μ΅λλ€.
14.27 MB 3.63 MB
λμ μμνλ₯Ό μ μ©νλ©΄, λͺ¨λΈμ λͺ¨λ κ°μ€μΉλ 32λΉνΈ ν¬κΈ°μ μ€μ μλ£νμμ 8λΉνΈ ν¬κΈ°μ μ μ μλ£νμΌλ‘ μ νλμ§λ§, νμ±νμ λν κ³μ°μ μ§ννκΈ° μ§μ κΉμ§λ νμ± ν¨μλ 8λΉνΈ μ μνμΌλ‘ μ ννμ§ μκ² λ©λλ€. λμ μμνλ₯Ό μ μ©νλ €λ©΄, torch.quantization.quantize_dynamic μ μ¬μ©νλ©΄ λ©λλ€.
model_dynamic_quantized = torch.quantization.quantize_dynamic( model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8 )
μ¬κΈ°μ qconfig_spec μΌλ‘ model λ΄μμ μμν μ μ© λμμΈ λ΄λΆ λͺ¨λ(submodules)μ μ§μ ν©λλ€.
Warning
λμ μμνλ μ¬μ νμ΅λ μμν μ μ© λͺ¨λΈμ΄ μ€λΉλμ§ μμμ λ μ¬μ©νκΈ° κ°μ₯ μ¬μ΄ λ°©μμ΄μ§λ§, μ΄ λ°©μμ μ£Όμ νκ³λ qconfig_spec μ΅μ μ΄ νμ¬λ nn.Linear κ³Ό nn.LSTM λ§ μ§μνλ€λ κ²μ λλ€. μ΄λ nn.Conv2d κ°μ λ€λ₯Έ λͺ¨λμ μμνν λ, λμ€μ λ Όμλ μ μ μμνλ μμνλ₯Ό κ³ λ €ν νμ΅μ μ¬μ©ν΄μΌ νλ€λ κ±Έ μλ―Έν©λλ€.
quantize_dynamic API call κ΄λ ¨ μ 체 λ¬Έμλ μ¬κΈ° λ₯Ό μ°Έκ³ νμΈμ. νμ΅ ν λμ μμνλ₯Ό μ¬μ©νλ μΈ κ°μ§ μμ μλ the Bert example, an LSTM model example, demo LSTM example μ΄ μμ΅λλ€.
μ΄ λ°©μμ λͺ¨λΈμ κ°μ€μΉμ νμ± ν¨μ λͺ¨λλ₯Ό 8λΉνΈ ν¬κΈ°μ μ μ μλ£νμΌλ‘ 미리 λ³ννλ―λ‘, λμ μμνμ²λΌ μΆλ‘ κ³Όμ μ€μ νμ±νμ λν μ¦κ°μ μΈ μμνλ₯Ό μ§ννμ§ μμ΅λλ€. νμ΅ ν μ μ μμνλ μΆλ‘ μλλ₯Ό ν¬κ² ν₯μμν€κ³ λͺ¨λΈμ ν¬κΈ°λ₯Ό μ€μΌ μ μμ§λ§, μ΄ λ°©λ²μ λμ μμνμ λΉν΄ μλ³Έ λͺ¨λΈ λλΉ μ νλκ° λ λ¨μ΄μ§ μ μμ΅λλ€.
μ μ μμνλ₯Ό λͺ¨λΈμ μ μ©νλ μ½λλ λ€μκ³Ό κ°μ΅λλ€.
backend = "qnnpack" model.qconfig = torch.quantization.get_default_qconfig(backend) torch.backends.quantized.engine = backend model_static_quantized = torch.quantization.prepare(model, inplace=False) model_static_quantized = torch.quantization.convert(model_static_quantized, inplace=False)
μ΄λ€μμ print_model_size(model_static_quantized) λ₯Ό μ€ννλ©΄ μ μ μμνκ° μ μ©λ λͺ¨λΈμ΄ 3.98MB λΌ νμλ©λλ€.
λͺ¨λΈμ μ 체 μ μμ μ μ μμνμ μμ λ μ¬κΈ° μμ νμΈνμΈμ. νΉμν μ μ μμν νν 리μΌμ μ¬κΈ° μμ νμΈνμΈμ.
Note
λͺ¨λ°μΌ μ₯λΉλ μΌλ°μ μΌλ‘ ARM μν€ν μ²λ₯Ό νμ¬νλλ° μ¬κΈ°μ λͺ¨λΈμ΄ μλνκ² νλ €λ©΄, qnnpack μ backend λ‘ μ¬μ©ν΄μΌ ν©λλ€. μ΄μ λ¬λ¦¬ x86 μν€ν μ²λ₯Ό νμ¬ν μ»΄ν¨ν°μμ λͺ¨λΈμ΄ μλνκ² νλ €λ©΄, x86 μ backend λ‘ μ¬μ©νμΈμ. (μ΄μ μ 'fbgemm' λν μ¬μ ν μ¬μ© κ°λ₯νμ§λ§, 'x86'μ κΈ°λ³ΈμΌλ‘ μ¬μ©νλ κ²μ κΆμ₯ν©λλ€.)
μμνλ₯Ό κ³ λ €ν νμ΅μ λͺ¨λΈ νμ΅ κ³Όμ μμ λͺ¨λ κ°μ€μΉμ νμ± ν¨μμ κ°μ§ μμνλ₯Ό μ½μ νκ² λκ³ , νμ΅ ν μμννλ λ°©λ²λ³΄λ€ λμ μΆλ‘ μ νλλ₯Ό κ°μ§λλ€. μ΄λ μ£Όλ‘ CNN λͺ¨λΈμ μ¬μ©λ©λλ€.
λͺ¨λΈμ μμνλ₯Ό κ³ λ €ν νμ΅μ κ°λ₯νκ² νλ €λ©΄, λͺ¨λΈ μ μ λΆλΆμ __init__ λ©μλμμ QuantStub κ³Ό DeQuantStub μ μ μν΄μΌ ν©λλ€. μ΄λ€μ κ°κ° tensorλ₯Ό μ€μνμμ μμνλ μλ£νμΌλ‘ μ ννκ±°λ λ°λλ‘ μ ννλ μν μ λλ€.
self.quant = torch.quantization.QuantStub() self.dequant = torch.quantization.DeQuantStub()
κ·Έλ€μ, λͺ¨λΈ μ μ λΆλΆμ forward λ©μλμ μμ λΆλΆκ³Ό λλΆλΆμμ, x = self.quant(x) μ x = self.dequant(x) λ₯Ό νΈμΆνμΈμ.
μμνλ₯Ό κ³ λ €ν νμ΅μ μ§ννλ €λ©΄, λ€μμ μ½λ μ‘°κ°μ μ¬μ©νμμμ€.
model.qconfig = torch.quantization.get_default_qat_qconfig(backend) model_qat = torch.quantization.prepare_qat(model, inplace=False) # μμνλ₯Ό κ³ λ €ν νμ΅μ΄ μ¬κΈ°μ μ§νλ©λλ€. model_qat = torch.quantization.convert(model_qat.eval(), inplace=False)
μμνλ₯Ό κ³ λ €ν νμ΅μ λ μμΈν μμλ μ¬κΈ° μ μ¬κΈ° λ₯Ό μ°Έκ³ νμΈμ.
μ¬μ νμ΅λ μμν μ μ© λͺ¨λΈλ μμνλ₯Ό κ³ λ €ν μ μ΄ νμ΅μ μ¬μ©λ μ μμ΅λλ€. μ΄λλ μμμ μ¬μ©ν quant μ dequant λ₯Ό λκ°μ΄ μ¬μ©ν©λλ€. μ 체 μμ λ μ¬κΈ° λ₯Ό νμΈνμΈμ.
μμ λ¨κ³ μ€ νλλ₯Ό μ΄μ©ν΄ μμνλ λͺ¨λΈμ΄ μμ±λ νμ, λͺ¨λ°μΌ μ₯μΉμμ μλλκ² νλ €λ©΄ μΆκ°λ‘ TorchScript νμμΌλ‘ μ ννκ³ λͺ¨λ°μΌ appμ μ΅μ νλ₯Ό μ§νν΄μΌ ν©λλ€. μμΈν λ΄μ©μ Script and Optimize for Mobile recipe λ₯Ό νμΈνμΈμ.
λ€λ₯Έ μμν μ μ©λ²μ λν μΆκ° μ 보λ μ¬κΈ° μ μ¬κΈ° λ₯Ό μ°Έκ³ νμΈμ.