In [1]:
import os
from pathlib import Path

from tqdm import tqdm

import torch
from torch.utils.data import DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2

import torch.nn as nn 
import torch.optim as optim
from torchinfo import summary

import config
import dataset
import dataset_fasdd
import models
import loss
import val_epoch
import utils

# Validation Dataset

In [2]:
# VALIDATION DATASET
# Deterministic validation/test pipeline (no augmentation): resize every image to the
# configured network input size, then convert it to a torch tensor.
# NOTE(review): there is no Normalize step here; pixel scaling presumably happens inside
# the dataset classes — confirm.
val_transform = A.Compose(
    transforms=[
        A.Resize(config.IMG_H, config.IMG_W, p=1),
        ToTensorV2(p=1),
    ]
)

### DFire

In [7]:
print("\nTEST DFire dataset")
# DFire test split (image dir + per-image label dir), resized by val_transform.
val_dfire_dataset = dataset.DFireDataset(
    img_dir=config.VAL_IMG_DIR,
    label_dir=config.VAL_LABEL_DIR,
    img_h=config.IMG_H,
    img_w=config.IMG_W,
    num_classes=config.N_CLASSES,
    ds_len=config.DS_LEN,
    transform=val_transform,
)
print(f'Test dataset len: {len(val_dfire_dataset)}')



TEST DFire dataset
DFire Removed wrong images: 0
DFire empty images: 2005
DFire only smoke images: 1186
DFire only fire images: 220
DFire smoke and fire images: 895
Test dataset len: 4306


### FASDD UAV

In [4]:
print("\nTEST FASDD UAV dataset")
# FASDD UAV test split (labels come from a single labels file), same transform as DFire.
# NOTE(review): the recorded output shows FASDDDataset logging its counts with a "DFire" prefix.
val_fasdd_uav_ds = dataset_fasdd.FASDDDataset(
    imgs_dir=config.FASDD_UAV_IMGS_DIR,
    labels_file=config.FASDD_UAV_TEST_LABELS_FILE,
    img_h=config.IMG_H,
    img_w=config.IMG_W,
    num_classes=config.N_CLASSES,
    ds_len=config.DS_LEN,
    transform=val_transform,
)
print(f'\nTest FASDD UAV dataset len: {len(val_fasdd_uav_ds)}')


TEST FASDD UAV dataset
DFire Removed wrong images: 0
DFire empty images: 1997
DFire only smoke images: 846
DFire only fire images: 35
DFire smoke and fire images: 1303

Test FASDD UAV dataset len: 4181


### FASDD CV

In [5]:
print("\nTEST FASDD CV dataset")
# FASDD CV test split; built exactly like the UAV split, only the image dir and labels file differ.
val_fasdd_cv_ds = dataset_fasdd.FASDDDataset(
    imgs_dir=config.FASDD_CV_IMGS_DIR,
    labels_file=config.FASDD_CV_TEST_LABELS_FILE,
    img_h=config.IMG_H,
    img_w=config.IMG_W,
    num_classes=config.N_CLASSES,
    ds_len=config.DS_LEN,
    transform=val_transform,
)
print(f'\nTest FASDD CV dataset len: {len(val_fasdd_cv_ds)}')


TEST FASDD CV dataset
DFire Removed wrong images: 0
DFire empty images: 6533
DFire only smoke images: 3902
DFire only fire images: 2091
DFire smoke and fire images: 3358

Test FASDD CV dataset len: 15884


## Concatenate

In [8]:
# Build the combined test set in two steps: DFire + FASDD UAV first, then FASDD CV appended.
# ConcatDataset indexes its members back to back, in the order given.
print("Concatenate Test DFire and FASDD UAV datasets")
val_ds_concat = torch.utils.data.ConcatDataset([val_dfire_dataset, val_fasdd_uav_ds])
print(f'Test dataset len: {len(val_ds_concat)}')

print("Concatenate with FASDD CV dataset")
val_ds = torch.utils.data.ConcatDataset([val_ds_concat, val_fasdd_cv_ds])
print(f'Test dataset len: {len(val_ds)}')

Concatenate Test DFire and FASDD UAV datasets
Test dataset len: 8487
Concatenate with FASDD CV dataset
Test dataset len: 24371


### LOADER

In [9]:
# LOADERS
val_loader = DataLoader(dataset=val_ds,
                        batch_size=config.BATCH_SIZE,
                        num_workers=config.NUM_WORKERS,
                        pin_memory=config.PIN_MEMORY,
                        shuffle=False,
                        drop_last=True)

# Loss Function

In [10]:
# LOSS FUNCTION
if config.LOSS_FN == "BCE":
    print(f'Loss Function: BCE')
    print(f'Smoke Precision Weight: {config.SMOKE_PRECISION_WEIGHT}')
    loss_fn = loss.BCE_LOSS(device=config.DEVICE, smoke_precision_weight=config.SMOKE_PRECISION_WEIGHT)
else:
    print("Wrong loss function")
    raise SystemExit("Wrong loss function")

Loss Function: BCE
Smoke Precision Weight: 0.8


# Setup Model and Load Weights

In [11]:
# Build the architecture selected by config.MODEL; only the BED classifier is supported here.
if config.MODEL == "BED":
    print("Using BED Classifier")
    model = models.BED_CLASSIFIER(num_classes=config.N_CLASSES).to(config.DEVICE)
else:
    # Name the offending value so a config typo is obvious; SystemExit halts this cell
    # (the message is shown by Jupyter, so no separate print is needed).
    raise SystemExit(f"Wrong Model: {config.MODEL!r} (expected 'BED')")

Using BED Classifier


In [12]:
# Optimizer and LR scheduler configured as in training; they are handed to
# utils.load_checkpoint below (presumably so their saved state can be restored too).
optimizer = optim.Adam(
    params=model.parameters(),
    weight_decay=config.WEIGHT_DECAY,
    lr=config.LEARNING_RATE,
)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode='min',
    threshold_mode='abs',
    factor=config.FACTOR,
    patience=config.PATIENCE,
    threshold=config.THRES,
    min_lr=config.MIN_LR,
)

# Load Model WITH Batch Norm

In [13]:
# Disabled: alternative checkpoints from the earlier experiments_256 run
# (presumably trained without FASDD — confirm). Uncomment one model_name plus the
# two lines below to load it instead of the DFire + FASDD checkpoint.
#model_name = 'BED_classifier__smoke__precision=0.9219__recall=0.9152__epoch=124.pt'
#model_name = 'BED_classifier__smoke__precision=0.9123__recall=0.9127__epoch=133.pt'
# model_name = 'BED_classifier__smoke__precision=0.9111__recall=0.904__epoch=127.pt'
# model_path = 'experiments_256/v0_img_divided_by_256/weights/' + model_name
# epoch_saved = utils.load_checkpoint(model_path, model, optimizer, scheduler, config.DEVICE)

### Model trained with DFire and FASDD

In [14]:
# BN model trained on DFire + FASDD: best smoke-precision checkpoint (epoch 87).
weights_dir = 'experiments_256_add_fasdd/test_00/weights/'
model_name = 'BED_classifier__best_smoke__precision=0.935__epoch=87.pt'
model_path = weights_dir + model_name
# Loads the weights into `model`; the returned value (presumably the saved epoch) is
# reused when saving derived checkpoints.
epoch_saved = utils.load_checkpoint(model_path, model, optimizer, scheduler, config.DEVICE)

Loading Model. Trained during 87 epochs


# Fuse Conv2d and BatchNorm

In [15]:
# [conv, bn] name pairs inside the `model` Sequential to fold together with fuse_modules.
# Each conv shares its numeric suffix with the BatchNorm that follows it.
_fuse_suffixes = ("1", "2", "31", "32", "33", "34", "41", "42", "43", "44", "45", "46")
modules_to_fuse = [[f"model.conv{suffix}", f"model.bn{suffix}"] for suffix in _fuse_suffixes]

In [16]:
# Conv+BN folding in fuse_modules requires eval mode (BN then contributes its running
# statistics). Module.eval() returns the module itself, so it can be passed straight in.
fused_model = torch.ao.quantization.fuse_modules(model.eval(), modules_to_fuse)

# Print Fused Model

In [17]:
# Layer table for one RGB image at the configured size; each folded BN shows up as Identity.
fused_summary = summary(fused_model, input_size=(1, 3, config.IMG_H, config.IMG_W))
print(fused_summary)

Layer (type:depth-idx)                   Output Shape              Param #
BED_CLASSIFIER                           [1, 2]                    --
├─Sequential: 1-1                        [1, 2]                    --
│    └─Conv2d: 2-1                       [1, 32, 224, 224]         896
│    └─Identity: 2-2                     [1, 32, 224, 224]         --
│    └─ReLU: 2-3                         [1, 32, 224, 224]         --
│    └─Dropout2d: 2-4                    [1, 32, 224, 224]         --
│    └─MaxPool2d: 2-5                    [1, 32, 112, 112]         --
│    └─Conv2d: 2-6                       [1, 16, 112, 112]         4,624
│    └─Identity: 2-7                     [1, 16, 112, 112]         --
│    └─ReLU: 2-8                         [1, 16, 112, 112]         --
│    └─Dropout2d: 2-9                    [1, 16, 112, 112]         --
│    └─MaxPool2d: 2-10                   [1, 16, 56, 56]           --
│    └─Conv2d: 2-11                      [1, 16, 56, 56]           272
│    └─Ide

# Evaluate Fused Model vs Un-Fused Model

In [18]:
# Evaluate the original and the Conv+BN-fused model on the same test loader;
# matching losses/metrics confirm that fusion preserved the network's function.
model.eval()
fused_model.eval()

fusion_comparison = (
    ("____________________________ MODEL BEFORE FUSION ____________________________", model),
    ("\n____________________________ MODEL AFTER FUSION ____________________________", fused_model),
)

with torch.no_grad():
    for banner, candidate_model in fusion_comparison:
        print(banner)
        # Results are overwritten each iteration; the printed report is what is compared.
        val_losses, val_metrics = val_epoch.eval_fn(
            loader=val_loader,
            model=candidate_model,
            loss_fn=loss_fn,
            device=config.DEVICE,
        )

____________________________ MODEL BEFORE FUSION ____________________________


Validating: 100%|████████████████████████████████████████████████████████████████████████████████| 380/380 [00:22<00:00, 16.85it/s]


Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
19.195      |12.298      |6.897       
SMOKE -> Precision: 0.935 - Recall: 0.878 - Accuracy: 0.914 - F1: 0.906
FIRE -> Precision: 0.913 - Recall: 0.974 - Accuracy: 0.961 - F1: 0.942

____________________________ MODEL AFTER FUSION ____________________________


Validating: 100%|████████████████████████████████████████████████████████████████████████████████| 380/380 [00:22<00:00, 16.94it/s]

Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
19.193      |12.296      |6.897       
SMOKE -> Precision: 0.935 - Recall: 0.878 - Accuracy: 0.914 - F1: 0.906
FIRE -> Precision: 0.913 - Recall: 0.974 - Accuracy: 0.961 - F1: 0.942





# Load Model WITHOUT Batch Norm

In [19]:
# Same BED layout but without BatchNorm layers; it receives the fused Conv weights below.
model_woBN = models.NoBN_BED_CLASSIFIER(num_classes=config.N_CLASSES).to(config.DEVICE)

# Optimizer/scheduler mirror the original model's setup so the no-BN model can be passed
# to utils.save_checkpoint with the same signature.
optimizer_noBN = optim.Adam(
    params=model_woBN.parameters(),
    weight_decay=config.WEIGHT_DECAY,
    lr=config.LEARNING_RATE,
)
scheduler_noBN = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer_noBN,
    mode='min',
    threshold_mode='abs',
    factor=config.FACTOR,
    patience=config.PATIENCE,
    threshold=config.THRES,
    min_lr=config.MIN_LR,
)

# Last expression: switches to eval mode and displays the module repr.
model_woBN.eval()

NoBN_BED_CLASSIFIER(
  (model): Sequential(
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu1): ReLU()
    (dropout1): Dropout2d(p=0.3, inplace=False)
    (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu2): ReLU()
    (dropout2): Dropout2d(p=0.3, inplace=False)
    (maxpool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv31): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
    (relu31): ReLU()
    (conv32): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu32): ReLU()
    (conv33): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
    (relu33): ReLU()
    (conv34): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu34): ReLU()
    (maxpool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv41): Conv

In [20]:
# Layer table of the BN-free model; its parameter counts should match the fused model's.
noBN_summary = summary(model_woBN, input_size=(1, 3, config.IMG_H, config.IMG_W))
print(noBN_summary)

Layer (type:depth-idx)                   Output Shape              Param #
NoBN_BED_CLASSIFIER                      [1, 2]                    --
├─Sequential: 1-1                        [1, 2]                    --
│    └─Conv2d: 2-1                       [1, 32, 224, 224]         896
│    └─ReLU: 2-2                         [1, 32, 224, 224]         --
│    └─Dropout2d: 2-3                    [1, 32, 224, 224]         --
│    └─MaxPool2d: 2-4                    [1, 32, 112, 112]         --
│    └─Conv2d: 2-5                       [1, 16, 112, 112]         4,624
│    └─ReLU: 2-6                         [1, 16, 112, 112]         --
│    └─Dropout2d: 2-7                    [1, 16, 112, 112]         --
│    └─MaxPool2d: 2-8                    [1, 16, 56, 56]           --
│    └─Conv2d: 2-9                       [1, 16, 56, 56]           272
│    └─ReLU: 2-10                        [1, 16, 56, 56]           --
│    └─Conv2d: 2-11                      [1, 32, 56, 56]           4,640
│    └─

# Load Pretrained Weights Into No Batch Norm Model

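- Iterate over the fused model's modules
- Find the module with the same qualified name in the model without BN
- If the names match and the layer is a Conv2d or Linear, copy its weights (state dict) into the no-BN model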

In [21]:
# Copy pretrained weights from the Conv+BN-fused model into the BN-free model.
# Layers are matched by qualified module name; only Conv2d and Linear layers carry
# weights to transfer (the fused model should no longer contain BatchNorm2d layers).
fused_model.eval()

# Index the no-BN model's modules by name once, so each fused module is matched with a
# dict lookup instead of rescanning every module of model_woBN (the previous nested loop
# was O(n*m) and re-checked the same name condition inside each branch).
woBN_modules_by_name = dict(model_woBN.named_modules())
transferred_layers = set()

for ori_model_name, ori_model_mod in fused_model.named_modules():
    # Use below line if model wo BN has no names defined, but self.convxx = ...
    #ori_model_name = ori_model_name.split('.')[-1]
    model_woBN_mod = woBN_modules_by_name.get(ori_model_name)
    if model_woBN_mod is None:
        continue  # no module with this name in the no-BN model
    model_woBN_name = ori_model_name
    print(f'original model name: {ori_model_name} - no BN model name: {model_woBN_name}')
    if isinstance(ori_model_mod, nn.Conv2d):
        model_woBN_mod.load_state_dict(ori_model_mod.state_dict())
        transferred_layers.add(model_woBN_name)
        print(f'\t****** Loading weights of Conv2d layer ori {ori_model_name} into no BN {model_woBN_name}')
    elif isinstance(ori_model_mod, nn.BatchNorm2d):
        print('\toooooo BN should never print here, as Fused Model should not have such layers')
    elif isinstance(ori_model_mod, nn.Linear):
        model_woBN_mod.load_state_dict(ori_model_mod.state_dict())
        transferred_layers.add(model_woBN_name)
        print(f'\t****** Loading weights of Linear layer ori {ori_model_name} into no BN {model_woBN_name}')
    else:
        print(f'\t______ Ignore weights or params of layer ori {ori_model_name} and no BN {model_woBN_name}')

# Fail loudly if a weighted layer of the no-BN model was left at its random initialization;
# otherwise a later evaluation would silently measure an untrained layer.
untransferred_layers = sorted(
    name
    for name, mod in model_woBN.named_modules()
    if isinstance(mod, (nn.Conv2d, nn.Linear)) and name not in transferred_layers
)
if untransferred_layers:
    raise RuntimeError(f'No pretrained weights were copied into no-BN layers: {untransferred_layers}')

original model name:  - no BN model name: 
	______ Ignore weights or params of layer ori  and no BN 
original model name: model - no BN model name: model
	______ Ignore weights or params of layer ori model and no BN model
original model name: model.conv1 - no BN model name: model.conv1
	****** Loading weights of Conv2d layer ori model.conv1 into no BN model.conv1
original model name: model.relu1 - no BN model name: model.relu1
	______ Ignore weights or params of layer ori model.relu1 and no BN model.relu1
original model name: model.dropout1 - no BN model name: model.dropout1
	______ Ignore weights or params of layer ori model.dropout1 and no BN model.dropout1
original model name: model.maxpool2 - no BN model name: model.maxpool2
	______ Ignore weights or params of layer ori model.maxpool2 and no BN model.maxpool2
original model name: model.conv2 - no BN model name: model.conv2
	****** Loading weights of Conv2d layer ori model.conv2 into no BN model.conv2
original model name: model.relu

# Evaluate the No-BN Model (with Fused Weights) vs the Original Un-Fused Model

In [22]:
# Evaluate the original BN model against the BN-free model carrying the fused weights;
# identical reports confirm the weight transfer reproduced the fused network.
model.eval()
model_woBN.eval()

noBN_comparison = (
    ("____________________________ MODEL BEFORE FUSION ____________________________", model),
    ("\n____________________________ MODEL AFTER FUSION ____________________________", model_woBN),
)

with torch.no_grad():
    for header_text, evaluated_model in noBN_comparison:
        print(header_text)
        # Results are overwritten each iteration; the printed report is what is compared.
        val_losses, val_metrics = val_epoch.eval_fn(
            loader=val_loader,
            model=evaluated_model,
            loss_fn=loss_fn,
            device=config.DEVICE,
        )

____________________________ MODEL BEFORE FUSION ____________________________


Validating: 100%|████████████████████████████████████████████████████████████████████████████████| 380/380 [00:22<00:00, 16.78it/s]


Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
19.195      |12.298      |6.897       
SMOKE -> Precision: 0.935 - Recall: 0.878 - Accuracy: 0.914 - F1: 0.906
FIRE -> Precision: 0.913 - Recall: 0.974 - Accuracy: 0.961 - F1: 0.942

____________________________ MODEL AFTER FUSION ____________________________


Validating: 100%|████████████████████████████████████████████████████████████████████████████████| 380/380 [00:22<00:00, 16.89it/s]

Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
19.193      |12.296      |6.897       
SMOKE -> Precision: 0.935 - Recall: 0.878 - Accuracy: 0.914 - F1: 0.906
FIRE -> Precision: 0.913 - Recall: 0.974 - Accuracy: 0.961 - F1: 0.942





# Save Fused Model and Model WO BN with Pretrained Weights

In [None]:
# Disabled: uncomment to save the fused BN model and the no-BN model as checkpoints,
# stamped with the epoch of the source checkpoint.
# checkpoint_name = 'BED_classifier__fused_ConvBN.pt'
# utils.save_checkpoint(epoch_saved, fused_model, optimizer, scheduler, checkpoint_name)
# checkpoint_name_noBN = 'BED_classifier__NoBN_fused_ConvBN_v2.pt'
# utils.save_checkpoint(epoch_saved, model_woBN, optimizer_noBN, scheduler_noBN, checkpoint_name_noBN)

### Model Trained with DFire and FASDD

In [None]:
# Disabled: uncomment to save the no-BN model built from the DFire + FASDD checkpoint.
# checkpoint_name_noBN = 'BED_classifier__fused__dfire_fasdd.pt'
# utils.save_checkpoint(epoch_saved, model_woBN, optimizer_noBN, scheduler_noBN, checkpoint_name_noBN)

# Export to ONNX

In [None]:
# Disabled: uncomment to export the no-BN model (optionally also the fused model) to ONNX,
# using a random example input of shape (1, 3, IMG_H, IMG_W) on the configured device.
# NOTE(review): torch.onnx.dynamo_export may be deprecated in newer PyTorch releases — check.
# torch_input = torch.randn(1, 3, config.IMG_H, config.IMG_W).to(config.DEVICE)
# #onnx_program_fused = torch.onnx.dynamo_export(fused_model, torch_input)
# onnx_program_noBN = torch.onnx.dynamo_export(model_woBN, torch_input)

In [None]:
# Disabled: write the exported ONNX programs to disk; requires the export cell above to
# have been run (onnx_program_fused / onnx_program_noBN are defined there).
#onnx_program_fused.save("BED_classifier__fused_ConvBN.onnx")
#onnx_program_noBN.save("BED_classifier__NoBN_fused_ConvBN_v2.onnx")

# FASDD
#onnx_program_noBN.save("BED_classifier__fused__dfire_fasdd.onnx")

# Optimize with JIT

In [None]:
# Disabled: TorchScript-compile and freeze a model for inference.
# NOTE(review): this scripts `model` (the original, un-fused BN model), not fused_model or
# model_woBN — confirm that is intended.
# frozen_mod = torch.jit.optimize_for_inference(torch.jit.script(model.eval()))

In [None]:
# Disabled: save the frozen TorchScript module; requires frozen_mod from the cell above.
# frozen_mod.save('BED_classifier__jit_optimized.pt')

# Inspect Parameters and Module Names

In [None]:
# Display the first BatchNorm's weight parameter of the original model; `model` still has
# its BN layers because the fused network was produced as the separate `fused_model` object.
model.model.bn1.weight

In [None]:
# Name is of type string
for name, mod in fused_model.named_modules():
    print(name, mod)