In [1]:
import os

import numpy as np
import math
import random

import torch

import torch.nn as nn 
from torch.nn.utils import parameters_to_vector
import torch.optim as optim
from torchinfo import summary

from brevitas.export import export_onnx_qcdq

import config_aimet
import models_aimet_medium_fasdd
import models
import utils
import datasets
import metrics
import loss
import val_epoch

In [2]:
if config_aimet.MODEL == "BED":
    
    print("Using Fixed Point Quantizers without BN")
    quant_model = models_aimet_medium_fasdd.QUANT_MEDIUM_PRUNING_AFTER_SVD_CLASSIFIER(
            weight_bw = config_aimet.WEIGHTS_BIT_WIDTH,
            big_layers_weight_bw = config_aimet.BIG_LAYERS_WEIGHTS_BIT_WIDTH,
            act_bw = config_aimet.ACTIVATIONS_BIT_WIDTH,
            bias_bw = config_aimet.BIAS_BIT_WIDTH,
            num_classes=config_aimet.N_CLASSES).to(config_aimet.DEVICE)
    # quant_model = models.QUANT_FixedPoint_NoBN_BED_CLASSIFIER(
    #         weight_bw = config_aimet.WEIGHTS_BIT_WIDTH,
    #         big_layers_weight_bw = config_aimet.BIG_LAYERS_WEIGHTS_BIT_WIDTH,
    #         act_bw = config_aimet.ACTIVATIONS_BIT_WIDTH,
    #         bias_bw = config_aimet.BIAS_BIT_WIDTH,
    #         num_classes=config_aimet.N_CLASSES).to(config_aimet.DEVICE)

else:
    print("Wrong Model")
    raise SystemExit("Wrong Model")

optimizer = optim.Adam(quant_model.parameters(), 
                       lr=config_aimet.LEARNING_RATE, 
                       weight_decay=config_aimet.WEIGHT_DECAY)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 
                                                 mode='min',
                                                 factor=config_aimet.FACTOR, 
                                                 patience=config_aimet.PATIENCE, 
                                                 threshold=config_aimet.THRES, 
                                                 threshold_mode='abs',
                                                 min_lr=config_aimet.MIN_LR)

# MODEL PARAMETERS
n_trainable = sum(p.numel() for p in quant_model.parameters() if p.requires_grad)
print(f'\nTrainable parameters = {n_trainable}')

n_params = parameters_to_vector(quant_model.parameters()).numel()
print(f'Total parameters = {n_params}\n')

Using Fixed Point Quantizers without BN

Trainable parameters = 63631
Total parameters = 63631



## Model Medium Compression

In [3]:
# model_dir = 'experiments_fuseBN_256_fasdd/test_v10_MED_w4W3a8b4_FxdPnt_MSE_PerChnlW_IntBiasIntScl/weights/'
# model_file = model_dir + 'BED_classifier__best_mean_F1.pt'
# epochs_trained = utils.load_checkpoint(model_file, quant_model, optimizer, scheduler, config_aimet.DEVICE)

### Model with Medium Compression: conv341 defined as big layer

In [4]:
model_dir = 'experiments_fuseBN_256_fasdd/test_v11_MED_w4W3a8b4_FxdPnt_MSE_PerChnlW_IntBiasIntScl/weights/'
model_file = model_dir + 'BED_classifier__smoke__precision=0.9096__recall=0.8942__epoch=93.pt'
epochs_trained = utils.load_checkpoint(model_file, quant_model, optimizer, scheduler, config_aimet.DEVICE)

Loading Model. Trained during 93 epochs


## Model with No Compression

In [5]:
# model_dir = 'experiments_fuseBN_256_fasdd/test_v00_NoCOMP_w4W2a8b4_FxdPnt_MSE_PerChnlW_IntBiasIntScl/weights/'
# model_file = model_dir + 'BED_classifier__best_smoke__precision=0.9217__epoch=59.pt'
# epochs_trained = utils.load_checkpoint(model_file, quant_model, optimizer, scheduler, config_aimet.DEVICE)

# Evaluate Brevitas Quant Model

## Dataset

In [6]:
val_loader = datasets.get_val_loader()


TEST DFire dataset
DFire Removed wrong images: 0
DFire empty images: 2005
DFire only smoke images: 1186
DFire only fire images: 220
DFire smoke and fire images: 895

Test dataset len: 4306

TEST FASDD UAV dataset
DFire Removed wrong images: 0
DFire empty images: 1997
DFire only smoke images: 846
DFire only fire images: 35
DFire smoke and fire images: 1303

Test FASDD UAV dataset len: 4181

TEST FASDD CV dataset
DFire Removed wrong images: 0
DFire empty images: 6533
DFire only smoke images: 3902
DFire only fire images: 2091
DFire smoke and fire images: 3358

Test FASDD CV dataset len: 15884
Concatenate Test DFire and FASDD UAV datasets
Test dataset len: 8487
Concatenate with FASDD CV dataset
Test dataset len: 24371


## Loss
Needed for evaluation function

In [7]:
if config_aimet.LOSS_FN == "BCE":
    print(f'Loss Function: BCE')
    print(f'Smoke Precision Weight: {config_aimet.SMOKE_PRECISION_WEIGHT}')
    loss_fn = loss.BCE_LOSS(device=config_aimet.DEVICE, smoke_precision_weight=config_aimet.SMOKE_PRECISION_WEIGHT)
else:
    print("Wrong loss function")
    raise SystemExit("Wrong loss function")

Loss Function: BCE
Smoke Precision Weight: 0.8


In [8]:
with torch.no_grad():
    val_losses, val_metrics = val_epoch.eval_fn(
        loader=val_loader, 
        model=quant_model,                         
        loss_fn=loss_fn,
        device=config_aimet.DEVICE)

  return super().rename(names)
Validating: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 380/380 [00:33<00:00, 11.19it/s]

Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
20.536      |13.376      |7.160       
SMOKE -> Precision: 0.910 - Recall: 0.894 - Accuracy: 0.908 - F1: 0.902
FIRE -> Precision: 0.903 - Recall: 0.969 - Accuracy: 0.956 - F1: 0.935





In [9]:
for k, v in val_metrics.items():
    print(f'{k}: smoke {v[0]:.4f} - fire: {v[1]:.4f}')
print(f'F1 Mean: {(val_metrics["F1"][0] + val_metrics["F1"][1])/2:.4f}')

Accuracy: smoke 0.9082 - fire: 0.9562
Precision: smoke 0.9096 - fire: 0.9031
Recall: smoke 0.8942 - fire: 0.9689
F1: smoke 0.9018 - fire: 0.9348
F1 Mean: 0.9183


# Model to CPU and ONNX Export

In [10]:
quant_model.to('cpu')

QUANT_MEDIUM_PRUNING_AFTER_SVD_CLASSIFIER(
  (model): Sequential(
    (input0): QuantIdentity(
      (input_quant): ActQuantProxyFromInjector(
        (_zero_hw_sentinel): StatelessBuffer()
      )
      (act_quant): ActQuantProxyFromInjector(
        (_zero_hw_sentinel): StatelessBuffer()
        (fused_activation_quant_proxy): FusedActivationQuantProxy(
          (activation_impl): Identity()
          (tensor_quant): RescalingIntQuant(
            (int_quant): IntQuant(
              (float_to_int_impl): RoundSte()
              (tensor_clamp_impl): TensorClamp()
              (delay_wrapper): DelayWrapper(
                (delay_impl): _NoDelay()
              )
            )
            (scaling_impl): ConstScaling(
              (restrict_clamp_scaling): _RestrictClampValue(
                (clamp_min_ste): ScalarClampMinSte()
                (restrict_value_impl): PowerOfTwoRestrictValue(
                  (float_to_int_impl): CeilSte()
                  (power_of_two): PowerOfT

In [11]:
export_onnx_qcdq(
    quant_model, 
    torch.randn(1, 3, config_aimet.IMG_H, config_aimet.IMG_W).to('cpu'), 
    export_path='./models/onnx_export/medium_fassd__conv341_big__epoch=93.onnx')

