In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

import os

import torch.nn.utils.prune as prune

In [2]:
# Pruning fraction handed to prune.global_unstructured(amount=...); it is also
# embedded, as a percentage, in every exported model name.
PRUNE_AMOUNT = 0.7

# File names used inside each Triton model directory.
TRITON_MODEL_FILE = 'model.pt'
TRITON_CONFIG_FILE = 'config.pbtxt'

# Root output directories for the two export targets.
TRITON_MODEL_DIR = './Triton/Models'
FLASK_MODEL_DIR = './Flask/Models'

# Create both roots up front; exist_ok keeps the cell safe to re-run.
os.makedirs(FLASK_MODEL_DIR, exist_ok=True)
os.makedirs(TRITON_MODEL_DIR, exist_ok=True)

In [3]:
def is_leaf_module(module):
    """Tell whether ``module`` has no child modules.

    Args:
        module: Object exposing a ``children()`` iterator (e.g. a
            ``torch.nn.Module``).

    Returns:
        bool: True when ``module.children()`` yields nothing, False otherwise.
    """
    return len(list(module.children())) == 0

def is_conv3x3_module(module):
    """Tell whether ``module`` is a ``Conv2d`` layer with a 3x3 kernel.

    Args:
        module: Any object; only ``torch.nn.Conv2d`` instances can match.

    Returns:
        bool: True only for a ``Conv2d`` whose ``kernel_size`` equals
        ``(3, 3)``; False for every other module or kernel shape.
    """
    # Guard clause: anything that is not a Conv2d can never qualify.
    if not isinstance(module, torch.nn.modules.conv.Conv2d):
        return False
    return module.kernel_size == (3, 3)

In [4]:
def get_prune_modules(model):
    """Collect the pruning targets of ``model``.

    Walks ``model.modules()`` and keeps every module accepted by
    ``is_conv3x3_module``.

    Args:
        model: Object exposing a ``modules()`` iterator (e.g. a
            ``torch.nn.Module``).

    Returns:
        list: ``(module, 'weight')`` tuples, in ``model.modules()`` order.
    """
    targets = []
    for submodule in model.modules():
        if is_conv3x3_module(submodule):
            targets.append((submodule, 'weight'))
    return targets

In [5]:
def save_flask_model(model, model_name):
    """Save ``model`` with ``torch.jit.save`` under ``FLASK_MODEL_DIR``.

    The file is written to ``<FLASK_MODEL_DIR>/<model_name>.pt`` and the
    resulting path is printed.

    Args:
        model: Object accepted by ``torch.jit.save`` (callers in this
            notebook pass the result of ``torch.jit.script``).
        model_name: Base file name, without the ``.pt`` extension.

    Returns:
        None
    """
    filename = model_name + '.pt'
    path = os.path.join(FLASK_MODEL_DIR, filename)
    torch.jit.save(model, path)
    print(path)

In [6]:
def save_triton_model(model, config, model_name):
    """Write a model directory for ``model`` under ``TRITON_MODEL_DIR``.

    Two files are produced, and each path is printed once written:

    * ``<TRITON_MODEL_DIR>/<model_name>/<TRITON_CONFIG_FILE>`` holding
      ``config`` with surrounding whitespace stripped;
    * ``<TRITON_MODEL_DIR>/<model_name>/1/<TRITON_MODEL_FILE>`` holding the
      model saved with ``torch.jit.save`` (the ``1`` subdirectory is
      presumably the Triton model version -- confirm against the deployment).

    Args:
        model: Object accepted by ``torch.jit.save``.
        config: Text of the Triton model configuration.
        model_name: Name of the per-model directory to create.

    Returns:
        None
    """
    # Configuration file, placed directly in the model directory.
    config_path = os.path.join(TRITON_MODEL_DIR, model_name, TRITON_CONFIG_FILE)
    os.makedirs(os.path.dirname(config_path), exist_ok=True)
    with open(config_path, 'w') as config_file:
        config_file.write(config.strip())
    print(config_path)

    # Serialized model, placed in the '1' subdirectory.
    weights_path = os.path.join(TRITON_MODEL_DIR, model_name, '1', TRITON_MODEL_FILE)
    os.makedirs(os.path.dirname(weights_path), exist_ok=True)
    torch.jit.save(model, weights_path)
    print(weights_path)

In [7]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

cuda:0


In [8]:
# Export ResNet-34: prune, script, and save for both Flask and Triton.
# The export name encodes the architecture and the pruning percentage.
model_name = f'resnet34-prune{int(PRUNE_AMOUNT * 100)}-script'

# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=` -- confirm against the installed version.
model = torchvision.models.resnet34(pretrained=True)

# (module, 'weight') pairs for every 3x3 Conv2d in the model.
modules_to_prune = get_prune_modules(model)

# Prune the selected weights as one global pool: PRUNE_AMOUNT is the fraction
# removed and L1Unstructured ranks individual weights by magnitude.
prune.global_unstructured(
    modules_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=PRUNE_AMOUNT,
)

# Remove the pruning reparametrization from each pruned parameter.
# Presumably this is so scripting below sees plain `weight` tensors -- confirm.
for m, n in modules_to_prune:
    prune.remove(m, n)

# Compile to TorchScript, move to the selected device, switch to eval mode.
model = torch.jit.script(model)
model = model.to(device)
model = model.eval()

# Triton model configuration; save_triton_model() strips the surrounding
# whitespace and writes the text to the model's config file.
triton_config = """
platform: "pytorch_libtorch"
max_batch_size: 32
input [
 {
    name: "input__0"
    data_type: TYPE_FP32
    format: FORMAT_NCHW
    dims: [ 3, 224, 224 ]
  }
]
output {
    name: "output__0"
    data_type: TYPE_FP32
    dims: [ 1000 ]
  }
"""

# Write the same scripted model to both deployment targets.
save_flask_model(model, model_name)
save_triton_model(model, triton_config, model_name)

./Flask/Models/resnet34-prune70-script.pt
./Triton/Models/resnet34-prune70-script/config.pbtxt
./Triton/Models/resnet34-prune70-script/1/model.pt


In [9]:
# Export MobileNetV2: prune, script, and save for both Flask and Triton.
# The export name encodes the architecture and the pruning percentage.
model_name = f'mobilenet_v2-prune{int(PRUNE_AMOUNT * 100)}-script'

# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=` -- confirm against the installed version.
model = torchvision.models.mobilenet_v2(pretrained=True)

# (module, 'weight') pairs for every 3x3 Conv2d in the model.
modules_to_prune = get_prune_modules(model)

# Prune the selected weights as one global pool: PRUNE_AMOUNT is the fraction
# removed and L1Unstructured ranks individual weights by magnitude.
prune.global_unstructured(
    modules_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=PRUNE_AMOUNT,
)

# Remove the pruning reparametrization from each pruned parameter.
# Presumably this is so scripting below sees plain `weight` tensors -- confirm.
for m, n in modules_to_prune:
    prune.remove(m, n)

# Compile to TorchScript, move to the selected device, switch to eval mode.
model = torch.jit.script(model)
model = model.to(device)
model = model.eval()

# Triton model configuration; save_triton_model() strips the surrounding
# whitespace and writes the text to the model's config file.
# Fixed two typos that made this config unusable: the input tensor must be
# named "input__0" (the libtorch backend's <name>__<index> convention, with a
# double underscore) and the datatype enum is TYPE_FP32, not TYPE__FP32.
triton_config = """
platform: "pytorch_libtorch"
max_batch_size: 32
input [
 {
    name: "input__0"
    data_type: TYPE_FP32
    format: FORMAT_NCHW
    dims: [ 3, 224, 224 ]
  }
]
output {
    name: "output__0"
    data_type: TYPE_FP32
    dims: [ 1000 ]
  }
"""

# Write the same scripted model to both deployment targets.
save_flask_model(model, model_name)
save_triton_model(model, triton_config, model_name)

./Flask/Models/mobilenet_v2-prune70-script.pt
./Triton/Models/mobilenet_v2-prune70-script/config.pbtxt
./Triton/Models/mobilenet_v2-prune70-script/1/model.pt


In [10]:
# Export EfficientNet-B0: prune, script, and save for both Flask and Triton.
# The export name encodes the architecture and the pruning percentage.
model_name = f'efficientnet_b0-prune{int(PRUNE_AMOUNT * 100)}-script'

# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=` -- confirm against the installed version.
model = torchvision.models.efficientnet_b0(pretrained=True)

# (module, 'weight') pairs for every 3x3 Conv2d in the model.
modules_to_prune = get_prune_modules(model)

# Prune the selected weights as one global pool: PRUNE_AMOUNT is the fraction
# removed and L1Unstructured ranks individual weights by magnitude.
prune.global_unstructured(
    modules_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=PRUNE_AMOUNT,
)

# Remove the pruning reparametrization from each pruned parameter.
# Presumably this is so scripting below sees plain `weight` tensors -- confirm.
for m, n in modules_to_prune:
    prune.remove(m, n)

# Compile to TorchScript, move to the selected device, switch to eval mode.
model = torch.jit.script(model)
model = model.to(device)
model = model.eval()

# Triton model configuration; save_triton_model() strips the surrounding
# whitespace and writes the text to the model's config file.
triton_config = """
platform: "pytorch_libtorch"
max_batch_size: 32
input [
 {
    name: "input__0"
    data_type: TYPE_FP32
    format: FORMAT_NCHW
    dims: [ 3, 224, 224 ]
  }
]
output {
    name: "output__0"
    data_type: TYPE_FP32
    dims: [ 1000 ]
  }
"""

# Write the same scripted model to both deployment targets.
save_flask_model(model, model_name)
save_triton_model(model, triton_config, model_name)

./Flask/Models/efficientnet_b0-prune70-script.pt
./Triton/Models/efficientnet_b0-prune70-script/config.pbtxt
./Triton/Models/efficientnet_b0-prune70-script/1/model.pt


In [11]:
# Export EfficientNet-B7: prune, script, and save for both Flask and Triton.
# The export name encodes the architecture and the pruning percentage.
model_name = f'efficientnet_b7-prune{int(PRUNE_AMOUNT * 100)}-script'

# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=` -- confirm against the installed version.
model = torchvision.models.efficientnet_b7(pretrained=True)

# (module, 'weight') pairs for every 3x3 Conv2d in the model.
modules_to_prune = get_prune_modules(model)

# Prune the selected weights as one global pool: PRUNE_AMOUNT is the fraction
# removed and L1Unstructured ranks individual weights by magnitude.
prune.global_unstructured(
    modules_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=PRUNE_AMOUNT,
)

# Remove the pruning reparametrization from each pruned parameter.
# Presumably this is so scripting below sees plain `weight` tensors -- confirm.
for m, n in modules_to_prune:
    prune.remove(m, n)

# Compile to TorchScript, move to the selected device, switch to eval mode.
model = torch.jit.script(model)
model = model.to(device)
model = model.eval()

# Triton model configuration; save_triton_model() strips the surrounding
# whitespace and writes the text to the model's config file.
# This config declares 3x600x600 inputs.
triton_config = """
platform: "pytorch_libtorch"
max_batch_size: 32
input [
 {
    name: "input__0"
    data_type: TYPE_FP32
    format: FORMAT_NCHW
    dims: [ 3, 600, 600 ]
  }
]
output {
    name: "output__0"
    data_type: TYPE_FP32
    dims: [ 1000 ]
  }
"""

# Write the same scripted model to both deployment targets.
save_flask_model(model, model_name)
save_triton_model(model, triton_config, model_name)

./Flask/Models/efficientnet_b7-prune70-script.pt
./Triton/Models/efficientnet_b7-prune70-script/config.pbtxt
./Triton/Models/efficientnet_b7-prune70-script/1/model.pt
