In [1]:
import os

import utils

import models
import models_aimet_high
import models_aimet_medium
import models_aimet_low

import config
import dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
import validate

import torch
from torchinfo import summary

import cv2
import matplotlib.pyplot as plt

from brevitas.export import export_onnx_qcdq

# Models

## Model No Compression

In [2]:
model_no_comp = models.QUANT_FixedPoint_NoBN_BED_CLASSIFIER(
    weight_bw = config.NO_COMP_WEIGHTS_BIT_WIDTH,
    big_layers_weight_bw = config.NO_COMP_BIG_LAYERS_WEIGHTS_BIT_WIDTH,
    act_bw = config.NO_COMP_ACTIVATIONS_BIT_WIDTH,
    bias_bw = config.NO_COMP_BIAS_BIT_WIDTH,
    num_classes=config.N_CLASSES).to(config.DEVICE)

In [3]:
model_no_comp_folder = './models/'
model_no_comp_name = 'BED_classifier__NOCOMP__smoke__precision=0.9025__recall=0.9021__epoch=35.pt'
model_no_comp_pt = model_no_comp_folder + model_no_comp_name

In [4]:
utils.load_checkpoint(model_path = model_no_comp_pt, 
                      model = model_no_comp, 
                      optimizer= None, 
                      scheduler= None, 
                      device = config.DEVICE)

Loading Model. Trained during 35 epochs


35

### Export to ONNX

In [5]:
export_onnx_qcdq(
    model_no_comp, 
    torch.randn(1, 3, config.IMG_H, config.IMG_W).to(config.DEVICE), 
    export_path='./models/onnx/no_comp_model.onnx')



## Model Low Compression

In [6]:
model_low_comp = models_aimet_low.QUANT_SOFT_PRUNING_AFTER_SVD_CLASSIFIER(
    weight_bw = config.LOW_COMP_WEIGHTS_BIT_WIDTH,
    big_layers_weight_bw = config.LOW_COMP_BIG_LAYERS_WEIGHTS_BIT_WIDTH,
    act_bw = config.LOW_COMP_ACTIVATIONS_BIT_WIDTH,
    bias_bw = config.LOW_COMP_BIAS_BIT_WIDTH,
    num_classes=config.N_CLASSES).to(config.DEVICE)

In [7]:
model_low_comp_folder = './models/'
model_low_comp_name = 'BED_classifier__LOWCOMP__smoke__precision=0.9024__recall=0.9011__epoch=80.pt'
model_low_comp_pt = model_low_comp_folder + model_low_comp_name

In [8]:
utils.load_checkpoint(model_path = model_low_comp_pt, 
                      model = model_low_comp, 
                      optimizer= None, 
                      scheduler= None, 
                      device = config.DEVICE)

Loading Model. Trained during 80 epochs


80

### Export to ONNX

In [9]:
export_onnx_qcdq(
    model_low_comp, 
    torch.randn(1, 3, config.IMG_H, config.IMG_W).to(config.DEVICE), 
    export_path='./models/onnx/low_comp_model.onnx')



## Model Medium Compression

In [10]:
model_med_comp = models_aimet_medium.QUANT_MEDIUM_PRUNING_AFTER_SVD_CLASSIFIER(
    weight_bw = config.MED_COMP_WEIGHTS_BIT_WIDTH,
    big_layers_weight_bw = config.MED_COMP_BIG_LAYERS_WEIGHTS_BIT_WIDTH,
    act_bw = config.MED_COMP_ACTIVATIONS_BIT_WIDTH,
    bias_bw = config.MED_COMP_BIAS_BIT_WIDTH,
    num_classes=config.N_CLASSES).to(config.DEVICE)

In [11]:
model_med_comp_folder = './models/'
model_med_comp_name = 'BED_classifier__MEDCOMP__smoke__precision=0.9028__recall=0.9001__epoch=49.pt'
model_med_comp_pt = model_med_comp_folder + model_med_comp_name

In [12]:
utils.load_checkpoint(model_path = model_med_comp_pt, 
                      model = model_med_comp, 
                      optimizer= None, 
                      scheduler= None, 
                      device = config.DEVICE)

Loading Model. Trained during 49 epochs


49

### Export to ONNX

In [13]:
export_onnx_qcdq(
    model_med_comp, 
    torch.randn(1, 3, config.IMG_H, config.IMG_W).to(config.DEVICE), 
    export_path='./models/onnx/med_comp_model.onnx')



## Model High Compression

In [14]:
model_high_comp = models_aimet_high.QUANT_PRUNING_AFTER_SVD_CLASSIFIER(
    weight_bw = config.HIGH_COMP_WEIGHTS_BIT_WIDTH,
    big_layers_weight_bw = config.HIGH_COMP_BIG_LAYERS_WEIGHTS_BIT_WIDTH,
    act_bw = config.HIGH_COMP_ACTIVATIONS_BIT_WIDTH,
    bias_bw = config.HIGH_COMP_BIAS_BIT_WIDTH,
    num_classes=config.N_CLASSES).to(config.DEVICE)

In [15]:
model_high_comp_folder = './models/'
model_high_comp_name = 'BED_classifier__HIGHCOMP__smoke__precision=0.9081__recall=0.9006__epoch=90.pt'
model_high_comp_pt = model_high_comp_folder + model_high_comp_name

In [16]:
utils.load_checkpoint(model_path = model_high_comp_pt, 
                      model = model_high_comp, 
                      optimizer= None, 
                      scheduler= None, 
                      device = config.DEVICE)

Loading Model. Trained during 90 epochs


90

### Export to ONNX

In [17]:
export_onnx_qcdq(
    model_high_comp, 
    torch.randn(1, 3, config.IMG_H, config.IMG_W).to(config.DEVICE), 
    export_path='./models/onnx/high_comp_model.onnx')



# Evaluate all Models

## Val Dataset and Loader

In [18]:
# VALIDATION DATASET
val_transform = A.Compose([
    A.Resize(config.IMG_H, config.IMG_W, p=1),
    ToTensorV2(p=1),
    ]
)

print("\nTEST DFire dataset")
val_dataset = dataset.DFireDataset(
    img_h = config.IMG_H,
    img_w = config.IMG_W,
    img_dir = config.VAL_IMG_DIR,
    label_dir = config.VAL_LABEL_DIR,
    num_classes = config.N_CLASSES,
    ds_len = config.DS_LEN,
    transform=val_transform)

print(f'Test dataset len: {len(val_dataset)}')

# LOADERS
val_loader = DataLoader(dataset=val_dataset,
                        batch_size=config.BATCH_SIZE,
                        num_workers=config.NUM_WORKERS,
                        pin_memory=config.PIN_MEMORY,
                        shuffle=False,
                        drop_last=True)


TEST DFire dataset
DFire Removed wrong images: 0
DFire empty images: 2005
DFire only smoke images: 1186
DFire only fire images: 220
DFire smoke and fire images: 895
Test dataset len: 4306


## Run Evaluations

In [19]:
print('___________________________ NO COMPRESSION MODEL ___________________________')
metrics_model_no_comp = validate.eval_fn(val_loader, model_no_comp, config.DEVICE)
print('___________________________ LOW COMPRESSION MODEL ___________________________')
metrics_model_low_comp = validate.eval_fn(val_loader, model_low_comp, config.DEVICE)
print('___________________________ MEDIUM COMPRESSION MODEL ___________________________')
metrics_model_med_comp = validate.eval_fn(val_loader, model_med_comp, config.DEVICE)
print('___________________________ HIGH COMPRESSION MODEL ___________________________')
metrics_model_high_comp = validate.eval_fn(val_loader, model_high_comp, config.DEVICE)

___________________________ NO COMPRESSION MODEL ___________________________


Validating: 100%|██████████████████████████████████████████████████████████████████████████████████| 67/67 [00:04<00:00, 15.13it/s]


SMOKE -> Precision: 0.9025 - Recall: 0.9021 - Accuracy: 0.9060 - F1: 0.9023
FIRE -> Precision: 0.9352 - Recall: 0.9099 - Accuracy: 0.9604 - F1: 0.9224
Mean F1 Score: 0.9123
___________________________ LOW COMPRESSION MODEL ___________________________


Validating: 100%|██████████████████████████████████████████████████████████████████████████████████| 67/67 [00:04<00:00, 16.52it/s]


SMOKE -> Precision: 0.9024 - Recall: 0.9011 - Accuracy: 0.9056 - F1: 0.9018
FIRE -> Precision: 0.9216 - Recall: 0.9324 - Accuracy: 0.9620 - F1: 0.9270
Mean F1 Score: 0.9144
___________________________ MEDIUM COMPRESSION MODEL ___________________________


Validating: 100%|██████████████████████████████████████████████████████████████████████████████████| 67/67 [00:04<00:00, 16.65it/s]


SMOKE -> Precision: 0.9028 - Recall: 0.9001 - Accuracy: 0.9053 - F1: 0.9015
FIRE -> Precision: 0.9346 - Recall: 0.9144 - Accuracy: 0.9613 - F1: 0.9244
Mean F1 Score: 0.9129
___________________________ HIGH COMPRESSION MODEL ___________________________


Validating: 100%|██████████████████████████████████████████████████████████████████████████████████| 67/67 [00:04<00:00, 16.27it/s]

SMOKE -> Precision: 0.9081 - Recall: 0.9006 - Accuracy: 0.9083 - F1: 0.9044
FIRE -> Precision: 0.9416 - Recall: 0.9144 - Accuracy: 0.9632 - F1: 0.9278
Mean F1 Score: 0.9161





# FASDD UAV

In [20]:
# uav_imgs_dir = '../../../datasets/fasdd/fasdd_cv/images/'
# uav_imgs_list = os.listdir(uav_imgs_dir)
# num_uav_imgs = len(uav_imgs_list)
# print(f'Number of UAV images: {num_uav_imgs}')

In [21]:
# img = cv2.imread(uav_imgs_dir + uav_imgs_list[1500])
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# img = cv2.resize(img, (config.IMG_H, config.IMG_H))
# plt.imshow(img)
# plt.show()

In [22]:
# img = img / 256.
# img = torch.tensor(img, dtype=torch.float)
# img = torch.permute(img, (2, 0, 1))
# img = img.unsqueeze(dim=0)

In [23]:
# print(img.shape)

In [24]:
# out = model(img.to(config.DEVICE))

In [25]:
# print(out)