In this notebook, we calibrate our pretrained fake-quantized PyTorch model, evaluate its accuracy, and convert it to ONNX format.

In [None]:
import sys
sys.path.append('../')
sys.path.append('../../')
import os
import yaml
from os.path import join, dirname
import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
import pytorch_quantization.nn as quant_nn
import ruamel.yaml
from easydict import EasyDict as edict
from quantization_libs.calibrator import collect_stats, compute_amax
from utils.misc import setup_seed
from dataset.build_dataset import get_dataset

Load the configs, the evaluation dataset, and the precomputed CLIP text embeddings

In [None]:
# Compute device: CUDA if available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Input files, relative to this notebook's directory.
MODEL_CONFIG_PATH = '../world_swin/configs/scannet/swint_mix_ctrs.yaml'
QUANT_CONFIG_PATH = '../quantization_configs/jacob.yaml'
TEXT_FEATURES_PATH = 'text_features_scannet.pt'

# Load the model and quantization configs.
# The module-level `ruamel.yaml.safe_load` was deprecated and then removed in
# ruamel.yaml 0.18; the `YAML(typ='safe')` object API is the supported replacement
# (`pure=True` selects the pure-Python safe loader).
yaml_loader = ruamel.yaml.YAML(typ='safe', pure=True)
with open(MODEL_CONFIG_PATH, 'r') as stream:
    config = edict(yaml_loader.load(stream))
with open(QUANT_CONFIG_PATH, 'r') as stream:
    quant_config = edict(yaml_loader.load(stream))
# NOTE(review): EasyDict.update appears to be a shallow merge, so a top-level key present
# in both files is replaced wholesale by the quantization config -- confirm this is intended.
config.update(quant_config)

# Evaluation data: the 'test' split with ground-truth ('gt') labels, in a fixed order.
dataset_val = get_dataset(config.DATA.DATASET)(
    split='test',
    data_dir=join('../dbs', config.DATA.DATASET),
    depth_transform=config.DATA.DEPTH_TRANSFORM,
    label_type='gt',
)
data_loader_val = DataLoader(
    dataset_val,
    batch_size=config.DATA.VAL_BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    # Pinned host memory only speeds up host-to-GPU copies; without CUDA it is unused and warns.
    pin_memory=torch.cuda.is_available(),
    drop_last=False,
)

print(f"==> Dataset: {config.DATA.DATASET} Val set: {len(dataset_val)}, Batch size: {config.DATA.VAL_BATCH_SIZE}")

# Precomputed CLIP text embeddings (presumably one row per class), moved to `device`
# (which is the CPU when no GPU is available).
text_features = torch.load(TEXT_FEATURES_PATH, map_location='cpu').to(device)


In [None]:
def evaluate(model, data_loader=None, text_feats=None):
    """Return the zero-shot top-1 accuracy of ``model`` on the evaluation loader.

    Every batch of images is encoded by ``model`` and the features are
    L2-normalised. Each sample is then assigned the class whose text embedding
    has the largest dot product with it.

    The function also reads ``config.MODAL`` (which image stream to use),
    ``config.DATA.VAL_BATCH_SIZE`` (the fixed batch size inputs are padded to)
    and the notebook-level ``device``.

    Args:
        model: image encoder, called as ``model(input_imgs)``. Assumed to return one
            feature vector per sample, shape (B, D) -- TODO confirm.
        data_loader: iterable of ``(rgb_imgs, depth_imgs, class_id)`` batches that
            exposes ``.dataset``. Defaults to the notebook-level ``data_loader_val``.
        text_feats: class text embeddings, one row per class (presumably (C, D)).
            Defaults to the notebook-level ``text_features``.

    Returns:
        float: number of correctly classified samples divided by the dataset size.

    Raises:
        NotImplementedError: if ``config.MODAL`` is neither ``'rgb'`` nor ``'depth'``.
        ValueError: if the dataset is empty (accuracy would be undefined).
    """
    if data_loader is None:
        data_loader = data_loader_val
    if text_feats is None:
        text_feats = text_features

    modal = config.MODAL
    if modal not in ('rgb', 'depth'):
        raise NotImplementedError
    n_samples = len(data_loader.dataset)
    if n_samples == 0:
        raise ValueError('evaluate() received an empty dataset; accuracy is undefined')
    padded_batch = config.DATA.VAL_BATCH_SIZE

    # Cast and transpose the class embeddings once instead of once per batch.
    text_feats_t = text_feats.float().T

    model.eval()
    n_correct = 0
    with torch.no_grad():
        for rgb_imgs, depth_imgs, class_id in tqdm(data_loader, leave=False):
            input_imgs = rgb_imgs if modal == 'rgb' else depth_imgs
            n_valid = input_imgs.shape[0]

            # Zero-pad a short (final) batch up to the configured batch size so the model
            # always sees the same input shape. The padding keeps the input's dtype.
            # Labels need no padding because padded rows are never scored.
            if n_valid < padded_batch:
                pad = input_imgs.new_zeros((padded_batch - n_valid, *input_imgs.shape[1:]))
                input_imgs = torch.cat([input_imgs, pad], dim=0)

            image_features = F.normalize(model(input_imgs.to(device)), p=2, dim=-1)

            # Top-1 prediction = argmax of the image/text similarity over the real rows.
            # A temperature-scaled softmax is monotonic and would not change the argmax.
            # Using argmax (instead of topk(5)[0]) also works with fewer than five classes.
            preds = (image_features[:n_valid] @ text_feats_t).argmax(dim=-1).cpu()
            # reshape(-1) accepts labels shaped (B,) or (B, 1).
            labels = class_id[:n_valid].reshape(-1)
            n_correct += (preds == labels).sum().item()

    return n_correct / n_samples



In [None]:
from world_swin.build_model import build_model
from world_swin.build_model import load_weights
# NOTE(review): `quant_modules` is imported but not used in this cell. If build_model relies on
# pytorch_quantization's automatic layer substitution, `quant_modules.initialize()` would need
# to be called before `build_model` -- confirm whether build_model inserts quantizers itself.
from pytorch_quantization import quant_modules

# Trained fake-quantized checkpoint (run-specific path).
CHECKPOINT_PATH = '../logs/swint_0218_143734/wandb/latest-run/files/src/best_model.pth'

# Build the model from the merged config, load the trained weights, and move it to `device`.
model = build_model(config)
load_weights(model, CHECKPOINT_PATH)
model.to(device)

# A freshly built nn.Module is in training mode. Switch to eval mode before calibration,
# so any dropout layers are disabled and normalisation layers do not update their running
# statistics while the quantizer ranges are collected.
model.eval()

# Fix the quantization scales (amax), as TensorRT only supports static quantization.
# NOTE(review): calibration uses the same split that is evaluated below, so the reported
# accuracy may be optimistic; consider a held-out calibration subset.
with torch.no_grad():
    collect_stats(model, data_loader_val, config.quantization, device)
    compute_amax(model, config.quantization, device)

# Top-1 accuracy of the calibrated model (displayed as the cell output).
acc1 = evaluate(model)
acc1

Convert the PyTorch model to ONNX format

In [None]:
# Output path of the exported ONNX graph.
model_path = 'swin.onnx'
# Tracing input: a single 3x224x224 image.
# NOTE(review): evaluate() pads every batch to config.DATA.VAL_BATCH_SIZE. If the model needs
# that fixed batch size, trace with that batch size instead of 1.
dummy_input = torch.randn(1, 3, 224, 224).to(device)

# `use_fb_fake_quant` is a class-level switch that makes the quantizers use PyTorch's native
# fake-quant ops, which ONNX export can represent (per the pytorch_quantization docs).
# Restore the previous value afterwards, so later cells (e.g. re-running evaluate) do not
# silently inherit the export-mode setting.
prev_fb_fake_quant = quant_nn.TensorQuantizer.use_fb_fake_quant
quant_nn.TensorQuantizer.use_fb_fake_quant = True
try:
    with torch.no_grad():
        torch.onnx.export(model, dummy_input, model_path, verbose=False, opset_version=15)
finally:
    quant_nn.TensorQuantizer.use_fb_fake_quant = prev_fb_fake_quant