In [1]:
import logging
import datetime

import numpy as np
import torch
from torch.nn.utils import parameters_to_vector
import torch.optim as optim
from torchinfo import summary

import matplotlib.pyplot as plt

import config
import modules.dataloaders as dataloaders

# import modules.brevitas.model_mobilenetv2_mini_Brevitas as cnv_model
import modules.brevitas.model_mobilenetv2_1M_Brevitas as cnv_model


import modules.loss as loss
import modules.metrics as metrics
import modules.train_epoch as train_epoch
import modules.val_epoch as val_epoch
import modules.utils as utils

In [2]:
# from brevitas.export import export_onnx_qcdq
# from brevitas.export import export_qonnx

# Logger

In [3]:
log_path = config.LOGS_FOLDER

logger = logging.getLogger("GonLogger")
logger.propagate = False
logger.setLevel(logging.INFO)
file_handler = logging.FileHandler(log_path + 'logfile.log')
formatter = logging.Formatter('%(message)s')
file_handler.setFormatter(formatter)

# add file handler to logger
logger.addHandler(file_handler)

logger.info('BED Classifier.\n' +  
            '\tOne Head.\n' +
            '\tWeighted for Precision.\n' +
            '\tBrevitas Default.\n'+ 
            '\tDataset images divided by 255.\n')

# Hyperparameters Log

In [4]:
''' ============================
    Print Config Values
============================ '''
print('\nDatasets Length')
print(f'\tTrain and Val: {"Full" if config.DS_LEN == None else config.DS_LEN}')
print(f'\nLoad Model: {config.LOAD_MODEL}')
if (config.LOAD_MODEL == True):
    print(f'\tModel: {config.LOAD_MODEL_FILE}')
print(f'Device: {config.DEVICE}')
print('Optimizer:')
print(f'\tLearning Rate: {config.LEARNING_RATE}')
print(f'\tWeight Decay: {config.WEIGHT_DECAY}')
print('Scheduler:')
print(f'\tScheduler factor: {config.FACTOR}')
print(f'\tScheduler patience: {config.PATIENCE}')
print(f'\tScheduler threshold: {config.THRES}')
print(f'\tScheduler min learning rate: {config.MIN_LR}')
print(f'Batch Size: {config.BATCH_SIZE}')
print(f'Num Workers: {config.NUM_WORKERS}')
print(f'Pin Memory: {config.PIN_MEMORY}')
print(f'Epochs: {config.EPOCHS}')
print('\nIMG DIMS:')
print(f'\tWidth: {config.IMG_W}\n\tHeight: {config.IMG_H}')
print('\nBrevitas Config:')
print(f'\tFixed Point: {config.FIXED_POINT}')
print(f'\tWeights Bit Width: {config.WEIGHTS_BIT_WIDTH}')
print(f'\tBig Layers Weights Bit Width: {config.BIG_LAYERS_WEIGHTS_BIT_WIDTH}')
print(f'\tBias Bit Width: {config.BIAS_BIT_WIDTH}')
print(f'\tActivations Bit Width: {config.ACTIVATIONS_BIT_WIDTH}')

logger.info('\nDatasets Length')
logger.info(f'\tTrain and Val: {"Full" if config.DS_LEN == None else config.DS_LEN}')
logger.info(f'\nLoad Model: {config.LOAD_MODEL}')
if (config.LOAD_MODEL == True):
    logger.info(f'\tModel: {config.LOAD_MODEL_FILE}')
logger.info(f'\nDevice: {config.DEVICE}')
logger.info('Optimizer:')
logger.info(f'\tLearning Rate: {config.LEARNING_RATE}')
logger.info(f'\tWeight Decay: {config.WEIGHT_DECAY}')
logger.info('Scheduler:')
logger.info(f'\tScheduler factor: {config.FACTOR}')
logger.info(f'\tScheduler patience: {config.PATIENCE}')
logger.info(f'\tScheduler threshold: {config.THRES}')
logger.info(f'\tScheduler min learning rate: {config.MIN_LR}')
logger.info(f'\nBatch Size: {config.BATCH_SIZE}')
logger.info(f'Num Workers: {config.NUM_WORKERS}')
logger.info(f'Pin Memory: {config.PIN_MEMORY}')
logger.info(f'Epochs: {config.EPOCHS}')
logger.info('\nIMG DIMS:')
logger.info(f'\tWidth: {config.IMG_W}\n\tHeight: {config.IMG_H}')
logger.info('\nBrevitas Config:')
logger.info(f'\tFixed Point: {config.FIXED_POINT}')
logger.info(f'\tWeights Bit Width: {config.WEIGHTS_BIT_WIDTH}')
logger.info(f'\tBig Layers Weights Bit Width: {config.BIG_LAYERS_WEIGHTS_BIT_WIDTH}')
logger.info(f'\tBias Bit Width: {config.BIAS_BIT_WIDTH}')
logger.info(f'\tActivations Bit Width: {config.ACTIVATIONS_BIT_WIDTH}')


Datasets Length
	Train and Val: Full

Load Model: False
Device: cuda
Optimizer:
	Learning Rate: 0.001
	Weight Decay: 0.001
Scheduler:
	Scheduler factor: 0.8
	Scheduler patience: 2
	Scheduler threshold: 0.001
	Scheduler min learning rate: 1e-06
Batch Size: 64
Num Workers: 8
Pin Memory: True
Epochs: 100

IMG DIMS:
	Width: 224
	Height: 224

Brevitas Config:
	Fixed Point: True
	Weights Bit Width: 4
	Big Layers Weights Bit Width: 4
	Bias Bit Width: 4
	Activations Bit Width: 4


# Dataloaders

In [5]:
train_loader = dataloaders.get_train_loader()
val_loader = dataloaders.get_val_loader()


TRAIN DFIRE dataset


Corrupt JPEG data: 1 extraneous bytes before marker 0xd9


DFire Removed wrong images: 0
DFire empty images: 7833
DFire only smoke images: 4681
DFire only fire images: 944
DFire smoke and fire images: 3763

Train DFire dataset len: 17221

TRAIN FASDD UAV dataset
FASDD Removed wrong images: 0
FASDD empty images: 5994
FASDD only smoke images: 2541
FASDD only fire images: 105
FASDD smoke and fire images: 3911

Train FASDD UAV dataset len: 12551

VAL FASDD UAV dataset
FASDD Removed wrong images: 0
FASDD empty images: 3995
FASDD only smoke images: 1693
FASDD only fire images: 70
FASDD smoke and fire images: 2607

Val FASDD UAV dataset len: 8365

TRAIN FASDD CV dataset


Corrupt JPEG data: 1 extraneous bytes before marker 0xd9


FASDD Removed wrong images: 0
FASDD empty images: 19600
FASDD only smoke images: 11708
FASDD only fire images: 6276
FASDD smoke and fire images: 10076

Train FASDD CV dataset len: 47660

Val FASDD CV dataset


Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Corrupt JPEG data: 1 extraneous bytes before marker 0xd9


FASDD Removed wrong images: 0
FASDD empty images: 13066
FASDD only smoke images: 7804
FASDD only fire images: 4183
FASDD smoke and fire images: 6717

Val FASDD CV dataset len: 31770

Concatenate Train DFire and Train FASDD UAV datasets
Train dataset len: 29772
Concatenate with Val FASDD UAV dataset
Train dataset len: 38137
Concatenate with Train FASDD CV dataset
Train dataset len: 85797
Concatenate with Val FASDD CV dataset
Train dataset len: 117567

TEST DFire dataset
DFire Removed wrong images: 0
DFire empty images: 2005
DFire only smoke images: 1186
DFire only fire images: 220
DFire smoke and fire images: 895

Test dataset len: 4306

TEST FASDD UAV dataset
FASDD Removed wrong images: 0
FASDD empty images: 1997
FASDD only smoke images: 846
FASDD only fire images: 35
FASDD smoke and fire images: 1303

Test FASDD UAV dataset len: 4181

TEST FASDD CV dataset
FASDD Removed wrong images: 0
FASDD empty images: 6533
FASDD only smoke images: 3902
FASDD only fire images: 2091
FASDD smoke and 

### Plot Some Train Pictures

In [6]:
for i, (img, label) in enumerate(train_loader):

    plt.subplots(8,4, figsize=(8, 16))
    
    for idx in range(config.BATCH_SIZE):
        plt.subplot(8, 4, idx+1)
        plt.imshow(img[idx].permute(1, 2, 0))
        title = ""
        if label[idx][0] == 1 and label[idx][1] == 1:
            title += "Smoke and Fire"
        elif label[idx][0] == 1 and label[idx][1] == 0:
            title += "Only Smoke"
        elif label[idx][0] == 0 and label[idx][1] == 1:
            title += "Only Fire"
        else:
            title += "Empty"
        plt.title(title)
        
        if (idx == 31):
            break
    plt.tight_layout()
    plt.savefig(config.RUN_FOLDER + 'train_pictures.png')
    #plt.show()
    plt.close()
    break

### Plot Some Val Pictures

In [7]:
for i, (img, label) in enumerate(val_loader):

    plt.subplots(8,4, figsize=(8, 16))
    
    for idx in range(config.BATCH_SIZE):
        plt.subplot(8, 4, idx+1)
        plt.imshow(img[idx].permute(1, 2, 0))
        title = ""
        if label[idx][0] == 1 and label[idx][1] == 1:
            title += "Smoke and Fire"
        elif label[idx][0] == 1 and label[idx][1] == 0:
            title += "Only Smoke"
        elif label[idx][0] == 0 and label[idx][1] == 1:
            title += "Only Fire"
        else:
            title += "Empty"
        plt.title(title)
        
        if (idx == 31):
            break
    plt.tight_layout()
    plt.savefig(config.RUN_FOLDER + 'val_pictures.png')
    plt.close()
    break

# Load Model

In [8]:
import importlib
importlib.reload(cnv_model)

<module 'modules.brevitas.model_mobilenetv2_1M_Brevitas' from '/home/gmoreno/uav/code/classifier_my_mobilenetv2/modules/brevitas/model_mobilenetv2_1M_Brevitas.py'>

In [9]:
# model = cnv_model.MobileNetV2_MINI().to(config.DEVICE)

model = cnv_model.MobileNetV2_1M().to(config.DEVICE)

### Optimizer and Scheduler

In [10]:
optimizer = optim.Adam(model.parameters(), 
                       lr=config.LEARNING_RATE, 
                       weight_decay=config.WEIGHT_DECAY)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 
                                                 mode='min',
                                                 factor=config.FACTOR, 
                                                 patience=config.PATIENCE, 
                                                 threshold=config.THRES, 
                                                 threshold_mode='abs',
                                                 min_lr=config.MIN_LR)



### Parameters

In [11]:
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'\nTrainable parameters = {n_trainable}')
logger.info(f'\nTrainable parameters = {n_trainable}')

n_params = parameters_to_vector(model.parameters()).numel()
print(f'Total parameters = {n_params}\n')
logger.info(f'Total parameters = {n_params}\n')


Trainable parameters = 1078274
Total parameters = 1078274



### Check Model Shape: Random Input

In [12]:
dummy_input = np.random.rand(4, config.NUM_CHANNELS, config.IMG_H, config.IMG_W)
dummy_input = torch.tensor(dummy_input, dtype=torch.float32, device=config.DEVICE)
out_test = model(dummy_input)
print(f'Model shape is {out_test}')
#print(f'BED Model Arquitecture\n{cnv_model}')

  return super().rename(names)


Model shape is tensor([[ -0.1089, -23.5607],
        [ -0.1488,   6.7013],
        [ -6.4327, -15.8676],
        [ -1.7720,   1.8813]], device='cuda:0', grad_fn=<AddmmBackward0>)


### Torchinfo

In [13]:
print(summary(model, input_size=(1, config.NUM_CHANNELS, config.IMG_H, config.IMG_W)))

Layer (type:depth-idx)                                                           Output Shape              Param #
MobileNetV2_1M                                                                   [1, 2]                    --
├─Sequential: 1-1                                                                [1, 320, 7, 7]            869,984
│    └─QuantIdentity: 2-1                                                        [1, 3, 224, 224]          --
│    │    └─ActQuantProxyFromInjector: 3-1                                       [1, 3, 224, 224]          --
│    │    └─ActQuantProxyFromInjector: 3-2                                       [1, 3, 224, 224]          --
│    └─Sequential: 2-2                                                           [1, 32, 112, 112]         --
│    │    └─QuantConv2d: 3-3                                                     [1, 32, 112, 112]         864
│    │    └─BatchNorm2d: 3-4                                                     [1, 32, 112, 112]         64

# Loss Function

In [14]:
if config.LOSS_FN == "BCE":
    print(f'Loss Function: BCE')
    logger.info(f'\nLoss Function: BCE')
    print(f'Smoke Precision Weight: {config.SMOKE_PRECISION_WEIGHT}')
    logger.info(f'Smoke Precision Weight: {config.SMOKE_PRECISION_WEIGHT}')
    loss_fn = loss.BCE_LOSS(device=config.DEVICE, smoke_precision_weight=config.SMOKE_PRECISION_WEIGHT)
else:
    print("Wrong loss function")
    logger.info("Wrong loss function")
    raise SystemExit("Wrong loss function")

Loss Function: BCE
Smoke Precision Weight: 0.8


# Loggers and Plotters for Losses and Metrics

In [15]:
train_losses_logger = utils.LogLosses()
train_metrics_logger = utils.LogMetrics()
lr_logger = utils.LogLR(log_path=config.PLOTS_FOLDER)

val_losses_logger = utils.LogLosses()
val_metrics_logger = utils.LogMetrics()

loss_plotter = utils.PlotMetrics(log_path=config.PLOTS_FOLDER, model_name=config.MODEL, loss_or_metric='Loss')
metrics_plotter = utils.PlotMetrics(log_path=config.PLOTS_FOLDER, model_name=config.MODEL, loss_or_metric='Metric')

# Main Function to Train

In [16]:
def train_loop(model, start_epoch=0, epochs_to_train=config.EPOCHS):

    ''' ==============================================================
                                TRAINING LOOP
    ============================================================== '''
    start = datetime.datetime.now()
    start_time = start.strftime("%H:%M:%S")
    print(f'\n***Start Training: {start_time}\n')
    logger.info(f'\n***Start Training: {start_time}\n')
    
    # Start with infinite validation loss
    best_valid_loss = np.inf
    best_smoke_precision = 0. #torch.tensor([0.])
    smoke_f1_min_save = 0.9 #torch.tensor([0.9])
    best_mean_f1 = 0.

    if start_epoch == 0:
        epochs_plot = []
    else:
        epochs_plot = [i for i in range(start_epoch)]    

    end_epoch = start_epoch + epochs_to_train
        
    for epoch in range(start_epoch, end_epoch):

        print(f'\n=== EPOCH {epoch}/{end_epoch-1} ===')
        logger.info(f'\n=== EPOCH {epoch}/{end_epoch-1} ===')
        
        #====================== TRAINING ========================#
        current_lr = train_epoch.get_lr(optimizer=optimizer)
        logger.info(f'Learning Rate = {current_lr}\n')
        lr_logger.log_lr(current_lr)
                
        train_losses, train_metrics = train_epoch.train_fn(
            loader=train_loader, 
            model=model, 
            optimizer=optimizer, 
            loss_fn=loss_fn,
            device=config.DEVICE)
        
        train_losses_logger.update_metrics(train_losses)
        train_metrics_logger.update_metrics(train_metrics)
                
        logger.info(utils.print_metrics_to_logger("TRAIN Stats", train_losses, train_metrics))
        
        #===================== VALIDATING =======================#
        with torch.no_grad():
            val_losses, val_metrics = val_epoch.eval_fn(
                loader=val_loader, 
                model=model,                         
                loss_fn=loss_fn,
                device=config.DEVICE)
            
            scheduler.step(val_losses['Total'])
            
            val_losses_logger.update_metrics(val_losses)
            val_metrics_logger.update_metrics(val_metrics)

            logger.info(utils.print_metrics_to_logger("VAL Stats", val_losses, val_metrics))
            
        epochs_plot.append(epoch)

        loss_plotter.plot_all_metrics(
            train_losses_logger.get_metrics(),
            val_losses_logger.get_metrics(),
            epochs_plot)

        metrics_plotter.plot_all_metrics(
            train_metrics_logger.get_metrics(),
            val_metrics_logger.get_metrics(),
            epochs_plot)

        lr_logger.plot_lr(epochs_plot)
        #======================= SAVING =========================#
        if ( (epoch+1) % 5 ) == 0:
            save_name = config.WEIGHTS_FOLDER + config.MODEL + '_classifier__5epoch.pt'
            utils.save_checkpoint(epoch, model, optimizer, scheduler, save_name) 
            
        if best_valid_loss > val_losses['Total']:
            best_valid_loss = val_losses['Total']
            print(f"\nSaving model with new best validation loss: {best_valid_loss:.4f}")
            logger.info(f"Saving model with new best validation loss: {best_valid_loss:.4f}")
            save_name = config.WEIGHTS_FOLDER + config.MODEL + '_classifier__' + 'best_loss'  + '.pt'
            utils.save_checkpoint(epoch, model, optimizer, scheduler, save_name) 
            save_onnx = config.ONNX_FOLDER + config.MODEL + '_classifier__' + 'best_loss'  #+ '.onnx'
            utils.export_onnx(model, (1, config.NUM_CHANNELS, config.IMG_H, config.IMG_W), save_onnx, config.DEVICE)

        # # Save model if precision increases and F1 > 0.9
        # if ( best_smoke_precision < val_metrics['Precision'][0] ) and ( val_metrics['F1'][0] > smoke_f1_min_save ) :
        #     best_smoke_precision = val_metrics['Precision'][0]
        #     print(f"\nSaving model with new best smoke precision: {best_smoke_precision:.4f}")
        #     logger.info(f"Saving model with new best smoke precision: {best_smoke_precision:.4f}")
        #     save_precision_name = f'best_smoke__precision={np.round(best_smoke_precision, decimals=4)}__epoch={epoch}'
        #     save_name = config.WEIGHTS_FOLDER + config.MODEL + '_classifier__' + save_precision_name + '.pt'
        #     utils.save_checkpoint(epoch, model, optimizer, scheduler, save_name)  
        #     save_onnx = config.ONNX_FOLDER + config.MODEL + '_classifier__' + save_precision_name #+ '.onnx'
        #     utils.export_onnx(model, (1, config.NUM_CHANNELS, config.IMG_H, config.IMG_W), save_onnx, config.DEVICE)
            
        # Save model if precision > 0.9 and recall > 0.9
        if ( val_metrics['Precision'][0] > 0.9 ) and ( val_metrics['Recall'][0] > 0.9 ) :
            print("\nSaving model with precision > 0.9 and recall > 0.9")
            logger.info("Saving model with precision > 0.9 and recall > 0.9")
            save_pre_name = f'smoke__precision={np.round(val_metrics["Precision"][0], decimals=4)}__' 
            save_rec_name = f'recall={np.round(val_metrics["Recall"][0], decimals=4)}__'
            save_pre_rec_name = save_pre_name + save_rec_name + f'epoch={epoch}'
            save_name = config.WEIGHTS_FOLDER + config.MODEL + '_classifier__' + save_pre_rec_name + '.pt'
            utils.save_checkpoint(epoch, model, optimizer, scheduler, save_name) 
            save_onnx = config.ONNX_FOLDER + config.MODEL + '_classifier__' + save_pre_rec_name #+ '.onnx'
            utils.export_onnx(model, (1, config.NUM_CHANNELS, config.IMG_H, config.IMG_W), save_onnx, config.DEVICE)

        # Save model if best mean F1 increases
        val_f1_mean = (val_metrics['F1'][0] + val_metrics['F1'][1]) / 2
        if (val_f1_mean > best_mean_f1) :
            best_mean_f1 = val_f1_mean
            print(f'Saving model with best Mean F1: {best_mean_f1:.4f}')
            logger.info(f'Saving model with best Mean F1: {best_mean_f1:.4f}')
            save_f1_name = 'best_mean_F1'
            save_name = config.WEIGHTS_FOLDER + config.MODEL + '_classifier__' + save_f1_name + '.pt'
            utils.save_checkpoint(epoch, model, optimizer, scheduler, save_name) 
            save_onnx = config.ONNX_FOLDER + config.MODEL + '_classifier__' + save_f1_name #+ '.onnx'
            utils.export_onnx(model, (1, config.NUM_CHANNELS, config.IMG_H, config.IMG_W), save_onnx, config.DEVICE)
        
    logger.info('Saving last model')   
    torch.save(model.state_dict(), config.WEIGHTS_FOLDER + 'last_' + config.MODEL + '_classifier.pt') 
    
    #======================= FINISH =========================#
    end = datetime.datetime.now()
    end_time = end.strftime("%H:%M:%S")
    print(f'\n***Script finished: {end_time}\n')  
    print(f'Time elapsed: {end-start}')
    logger.info(f'\n***Script finished: {end_time}\n')  
    logger.info(f'Time elapsed: {end-start}')
    
    return model

In [17]:
# print(len(val_losses_logger.total))

# Training Loop

In [18]:
if __name__ == "__main__":
    
    print("Starting script\n")
    logger.info("Starting script\n")
    
    model = train_loop(model)

Starting script


***Start Training: 16:30:16


=== EPOCH 0/99 ===
Learning Rate = 0.001



Training:  34%|███████████████████████████████████████████████▍                                                                                           | 626/1836 [02:03<02:39,  7.57it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  42%|██████████████████████████████████████████████████████████▋                                                                                | 775/1836 [02:32<02:27,  7.17it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  45%|██████████████████████████████████████████████████████████████▋                                                                            | 828/1836 [02:43<03:44,  4.49it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  75%|████████████████████████████████████████████████████████████████████████████████████████████████████████                                  | 1384/1836 [04:30<00:56,  8.01it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training: 100%|█

Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
143.579     |68.002      |75.577      



Validating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 380/380 [00:42<00:00,  8.87it/s]


Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
89.920      |40.761      |49.159      
SMOKE -> Precision: 0.4973 - Recall: 0.1039 - Accuracy: 0.5280 - F1: 0.1719
FIRE -> Precision: 0.3121 - Recall: 0.0658 - Accuracy: 0.6501 - F1: 0.1087

Saving model with new best validation loss: 89.9198
Model exported to ONNX: experiments_brevitas/test_v03_1M_w1a1_full_ds/onnx/MY_MBLNET_V2_classifier__best_loss
Saving model with best Mean F1: 0.1403
Model exported to ONNX: experiments_brevitas/test_v03_1M_w1a1_full_ds/onnx/MY_MBLNET_V2_classifier__best_mean_F1

=== EPOCH 1/99 ===
Learning Rate = 0.001



Training:   6%|████████▍                                                                                                                                  | 112/1836 [00:23<03:51,  7.43it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:   8%|██████████▉                                                                                                                                | 145/1836 [00:29<04:13,  6.67it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  38%|█████████████████████████████████████████████████████▏                                                                                     | 703/1836 [02:20<02:56,  6.41it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  91%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋            | 1673/1836 [05:29<00:54,  2.98it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training: 100%|█

Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
314.347     |160.513     |153.834     



Validating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 380/380 [00:44<00:00,  8.61it/s]


Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
1919.706    |722.861     |1196.845    
SMOKE -> Precision: 0.0000 - Recall: 0.0000 - Accuracy: 0.5286 - F1: 0.0000
FIRE -> Precision: 0.3242 - Recall: 1.0000 - Accuracy: 0.3242 - F1: 0.4896
Saving model with best Mean F1: 0.2448
Model exported to ONNX: experiments_brevitas/test_v03_1M_w1a1_full_ds/onnx/MY_MBLNET_V2_classifier__best_mean_F1

=== EPOCH 2/99 ===
Learning Rate = 0.001



Training:  38%|████████████████████████████████████████████████████▊                                                                                      | 697/1836 [02:19<03:17,  5.77it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  47%|█████████████████████████████████████████████████████████████████▋                                                                         | 868/1836 [02:51<03:45,  4.29it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  74%|█████████████████████████████████████████████████████████████████████████████████████████████████████▋                                    | 1353/1836 [04:15<01:13,  6.53it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  75%|███████████████████████████████████████████████████████████████████████████████████████████████████████▉                                  | 1383/1836 [04:21<01:02,  7.26it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training: 100%|█

Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
716.785     |370.249     |346.537     



Validating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 380/380 [00:35<00:00, 10.60it/s]


Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
389.402     |191.279     |198.122     
SMOKE -> Precision: 0.4665 - Recall: 0.7368 - Accuracy: 0.4787 - F1: 0.5713
FIRE -> Precision: 0.3665 - Recall: 0.0164 - Accuracy: 0.6720 - F1: 0.0313
Saving model with best Mean F1: 0.3013
Model exported to ONNX: experiments_brevitas/test_v03_1M_w1a1_full_ds/onnx/MY_MBLNET_V2_classifier__best_mean_F1

=== EPOCH 3/99 ===
Learning Rate = 0.001



Training:  63%|██████████████████████████████████████████████████████████████████████████████████████▉                                                   | 1157/1836 [03:34<02:01,  5.61it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  74%|█████████████████████████████████████████████████████████████████████████████████████████████████████▉                                    | 1357/1836 [04:09<01:13,  6.50it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  90%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍             | 1656/1836 [04:58<00:38,  4.69it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  99%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1815/1836 [05:24<00:03,  6.20it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training: 100%|█

Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
717.988     |350.773     |367.215     



Validating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 380/380 [00:36<00:00, 10.42it/s]


Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
1796.202    |579.394     |1216.809    
SMOKE -> Precision: 0.4702 - Recall: 0.8234 - Accuracy: 0.4793 - F1: 0.5985
FIRE -> Precision: 0.3242 - Recall: 1.0000 - Accuracy: 0.3242 - F1: 0.4896
Saving model with best Mean F1: 0.5441
Model exported to ONNX: experiments_brevitas/test_v03_1M_w1a1_full_ds/onnx/MY_MBLNET_V2_classifier__best_mean_F1

=== EPOCH 4/99 ===
Learning Rate = 0.0008



Training:   9%|███████████▉                                                                                                                               | 157/1836 [00:29<04:50,  5.79it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  45%|██████████████████████████████████████████████████████████████▊                                                                            | 829/1836 [02:26<02:06,  7.98it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  61%|████████████████████████████████████████████████████████████████████████████████████▍                                                     | 1124/1836 [03:20<01:31,  7.82it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  87%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                  | 1597/1836 [04:40<00:33,  7.10it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training: 100%|█

Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
516.633     |258.026     |258.606     



Validating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 380/380 [00:35<00:00, 10.69it/s]


Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
594.979     |342.930     |252.049     
SMOKE -> Precision: 0.4715 - Recall: 0.9996 - Accuracy: 0.4717 - F1: 0.6408
FIRE -> Precision: 0.0000 - Recall: 0.0000 - Accuracy: 0.6758 - F1: 0.0000

=== EPOCH 5/99 ===
Learning Rate = 0.0008



Training:  20%|███████████████████████████▋                                                                                                               | 365/1836 [01:05<03:48,  6.43it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  52%|████████████████████████████████████████████████████████████████████████▌                                                                  | 959/1836 [02:51<02:12,  6.60it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  77%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▏                               | 1413/1836 [07:00<02:53,  2.44it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  83%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                        | 1516/1836 [08:09<03:58,  1.34it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training: 100%|█

Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
449.757     |241.442     |208.315     



Validating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 380/380 [00:54<00:00,  7.00it/s]


Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
444.539     |145.849     |298.690     
SMOKE -> Precision: 0.4702 - Recall: 0.8458 - Accuracy: 0.4780 - F1: 0.6044
FIRE -> Precision: 0.3206 - Recall: 0.6643 - Accuracy: 0.4349 - F1: 0.4325

=== EPOCH 6/99 ===
Learning Rate = 0.0008



Training:   0%|                                                                                                                                                     | 0/1836 [00:00<?, ?it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  61%|███████████████████████████████████████████████████████████████████████████████████▌                                                      | 1112/1836 [08:42<02:56,  4.09it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  72%|███████████████████████████████████████████████████████████████████████████████████████████████████▏                                      | 1319/1836 [10:23<05:48,  1.49it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  90%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋              | 1645/1836 [12:47<00:53,  3.59it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training: 100%|█

Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
437.208     |224.757     |212.451     



Validating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 380/380 [00:42<00:00,  8.88it/s]


Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
1362.372    |1194.229    |168.143     
SMOKE -> Precision: 0.4714 - Recall: 1.0000 - Accuracy: 0.4714 - F1: 0.6408
FIRE -> Precision: 0.3272 - Recall: 0.2560 - Accuracy: 0.5882 - F1: 0.2872

=== EPOCH 7/99 ===
Learning Rate = 0.00064



Training:  15%|████████████████████▋                                                                                                                      | 273/1836 [00:55<05:59,  4.35it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  17%|███████████████████████▎                                                                                                                   | 308/1836 [01:03<07:11,  3.54it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  37%|███████████████████████████████████████████████████▏                                                                                       | 676/1836 [02:16<04:07,  4.68it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  43%|███████████████████████████████████████████████████████████▍                                                                               | 785/1836 [02:37<03:06,  5.63it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training: 100%|█

Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
356.712     |180.677     |176.035     



Validating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 380/380 [00:43<00:00,  8.84it/s]


Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
427.639     |55.318      |372.321     
SMOKE -> Precision: 0.4729 - Recall: 0.7254 - Accuracy: 0.4894 - F1: 0.5726
FIRE -> Precision: 0.3240 - Recall: 0.9807 - Accuracy: 0.3306 - F1: 0.4871

=== EPOCH 8/99 ===
Learning Rate = 0.00064



Training:  37%|███████████████████████████████████████████████████▌                                                                                       | 681/1836 [02:14<02:56,  6.56it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  39%|██████████████████████████████████████████████████████▊                                                                                    | 724/1836 [02:23<03:27,  5.36it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  48%|██████████████████████████████████████████████████████████████████▌                                                                        | 880/1836 [02:53<02:15,  7.06it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  92%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎          | 1693/1836 [05:33<00:29,  4.85it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training: 100%|█

Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
358.091     |177.997     |180.094     



Validating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 380/380 [00:43<00:00,  8.83it/s]


Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
278.741     |57.196      |221.544     
SMOKE -> Precision: 0.4681 - Recall: 0.5460 - Accuracy: 0.4935 - F1: 0.5040
FIRE -> Precision: 0.3242 - Recall: 1.0000 - Accuracy: 0.3242 - F1: 0.4896

=== EPOCH 9/99 ===
Learning Rate = 0.00064



Training:   1%|█                                                                                                                                           | 14/1836 [00:04<06:42,  4.52it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:   4%|██████                                                                                                                                      | 80/1836 [00:18<05:34,  5.25it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:   9%|████████████                                                                                                                               | 159/1836 [00:33<05:18,  5.27it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  11%|███████████████▋                                                                                                                           | 207/1836 [00:43<04:19,  6.27it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training: 100%|█

Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
374.796     |183.178     |191.618     



Validating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 380/380 [00:42<00:00,  8.88it/s]


Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
142.699     |68.911      |73.789      
SMOKE -> Precision: 0.4611 - Recall: 0.0067 - Accuracy: 0.5280 - F1: 0.0132
FIRE -> Precision: 0.3185 - Recall: 0.2438 - Accuracy: 0.5857 - F1: 0.2762

=== EPOCH 10/99 ===
Learning Rate = 0.0005120000000000001



Training:   1%|█▉                                                                                                                                          | 26/1836 [00:05<04:49,  6.26it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:   5%|██████▎                                                                                                                                     | 83/1836 [00:17<03:27,  8.46it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  62%|█████████████████████████████████████████████████████████████████████████████████████▋                                                    | 1140/1836 [03:50<02:42,  4.28it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  69%|██████████████████████████████████████████████████████████████████████████████████████████████▉                                           | 1263/1836 [04:13<01:47,  5.33it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training: 100%|█

Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
275.614     |140.017     |135.597     



Validating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 380/380 [00:43<00:00,  8.71it/s]


Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
868.003     |729.304     |138.699     
SMOKE -> Precision: 0.4714 - Recall: 1.0000 - Accuracy: 0.4714 - F1: 0.6408
FIRE -> Precision: 0.5000 - Recall: 0.0003 - Accuracy: 0.6758 - F1: 0.0005

=== EPOCH 11/99 ===
Learning Rate = 0.0005120000000000001



Training:   8%|██████████▌                                                                                                                                | 140/1836 [00:29<03:43,  7.59it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  50%|████████████████████████████████████████████████████████████████████▊                                                                      | 909/1836 [03:16<02:23,  6.44it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  63%|███████████████████████████████████████████████████████████████████████████████████████                                                   | 1159/1836 [04:15<02:29,  4.53it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  91%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋            | 1672/1836 [06:18<00:27,  5.98it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training: 100%|█

Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
297.986     |152.190     |145.796     



Validating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 380/380 [00:54<00:00,  7.03it/s]


Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
681.125     |361.534     |319.592     
SMOKE -> Precision: 0.0000 - Recall: 0.0000 - Accuracy: 0.5286 - F1: 0.0000
FIRE -> Precision: 0.3242 - Recall: 1.0000 - Accuracy: 0.3242 - F1: 0.4896

=== EPOCH 12/99 ===
Learning Rate = 0.0005120000000000001



Training:  19%|██████████████████████████▉                                                                                                                | 356/1836 [01:31<11:49,  2.09it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  21%|█████████████████████████████▏                                                                                                             | 386/1836 [01:37<03:31,  6.84it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  22%|██████████████████████████████▍                                                                                                            | 402/1836 [01:41<03:38,  6.56it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  91%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████             | 1664/1836 [06:45<00:25,  6.64it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training: 100%|█

Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
313.282     |157.812     |155.470     



Validating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 380/380 [00:52<00:00,  7.26it/s]


Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
593.855     |528.691     |65.164      
SMOKE -> Precision: 0.4714 - Recall: 1.0000 - Accuracy: 0.4714 - F1: 0.6408
FIRE -> Precision: 0.0625 - Recall: 0.0001 - Accuracy: 0.6752 - F1: 0.0003

=== EPOCH 13/99 ===
Learning Rate = 0.0004096000000000001



Training:  16%|██████████████████████▎                                                                                                                    | 294/1836 [01:13<03:35,  7.17it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  41%|█████████████████████████████████████████████████████████▍                                                                                 | 759/1836 [03:05<02:48,  6.41it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  56%|████████████████████████████████████████████████████████████████████████████▉                                                             | 1024/1836 [04:08<03:27,  3.91it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  74%|██████████████████████████████████████████████████████████████████████████████████████████████████████▋                                   | 1366/1836 [05:30<01:04,  7.24it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training: 100%|█

Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
206.812     |99.045      |107.767     



Validating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 380/380 [00:50<00:00,  7.55it/s]


Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
184.770     |72.442      |112.328     
SMOKE -> Precision: 0.3200 - Recall: 0.0007 - Accuracy: 0.5282 - F1: 0.0014
FIRE -> Precision: 0.0000 - Recall: 0.0000 - Accuracy: 0.6758 - F1: 0.0000

=== EPOCH 14/99 ===
Learning Rate = 0.0004096000000000001



Training:  16%|██████████████████████▎                                                                                                                    | 294/1836 [01:12<06:23,  4.02it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  16%|██████████████████████▎                                                                                                                    | 295/1836 [01:12<05:16,  4.87it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  77%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▋                               | 1420/1836 [05:12<01:17,  5.40it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  90%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊              | 1647/1836 [05:57<00:26,  7.09it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training: 100%|█

Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
290.441     |146.506     |143.936     



Validating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 380/380 [00:42<00:00,  8.85it/s]


Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
807.903     |390.541     |417.362     
SMOKE -> Precision: 0.4714 - Recall: 1.0000 - Accuracy: 0.4714 - F1: 0.6408
FIRE -> Precision: 0.3242 - Recall: 1.0000 - Accuracy: 0.3242 - F1: 0.4896
Saving model with best Mean F1: 0.5652
Model exported to ONNX: experiments_brevitas/test_v03_1M_w1a1_full_ds/onnx/MY_MBLNET_V2_classifier__best_mean_F1

=== EPOCH 15/99 ===
Learning Rate = 0.0004096000000000001



Training:  25%|██████████████████████████████████▉                                                                                                        | 462/1836 [01:38<06:54,  3.31it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  40%|██████████████████████████████████████████████████████▉                                                                                    | 726/1836 [02:32<06:02,  3.06it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  62%|██████████████████████████████████████████████████████████████████████████████████████▏                                                   | 1147/1836 [03:59<02:06,  5.45it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  99%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1813/1836 [06:20<00:04,  4.63it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training: 100%|█

Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
349.341     |178.320     |171.022     



Validating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 380/380 [00:44<00:00,  8.63it/s]


Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
144.253     |67.963      |76.289      
SMOKE -> Precision: 0.0000 - Recall: 0.0000 - Accuracy: 0.5286 - F1: 0.0000
FIRE -> Precision: 0.3235 - Recall: 0.9779 - Accuracy: 0.3299 - F1: 0.4862

=== EPOCH 16/99 ===
Learning Rate = 0.0003276800000000001



Training:  33%|█████████████████████████████████████████████▊                                                                                             | 605/1836 [02:11<04:17,  4.78it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  47%|█████████████████████████████████████████████████████████████████▎                                                                         | 862/1836 [03:12<04:04,  3.98it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  68%|██████████████████████████████████████████████████████████████████████████████████████████████                                            | 1252/1836 [04:31<01:24,  6.89it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  71%|█████████████████████████████████████████████████████████████████████████████████████████████████▋                                        | 1300/1836 [04:41<01:19,  6.78it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training: 100%|█

Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
199.177     |98.188      |100.988     



Validating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 380/380 [00:42<00:00,  8.90it/s]


Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
279.104     |223.790     |55.313      
SMOKE -> Precision: 0.0000 - Recall: 0.0000 - Accuracy: 0.5286 - F1: 0.0000
FIRE -> Precision: 0.3074 - Recall: 0.0110 - Accuracy: 0.6713 - F1: 0.0213

=== EPOCH 17/99 ===
Learning Rate = 0.0003276800000000001



Training:  10%|█████████████▏                                                                                                                             | 175/1836 [00:37<05:29,  5.03it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  24%|████████████████████████████████▉                                                                                                          | 435/1836 [01:30<03:06,  7.53it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  58%|████████████████████████████████████████████████████████████████████████████████▋                                                         | 1074/1836 [03:42<01:56,  6.56it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:  81%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                         | 1494/1836 [05:08<00:43,  7.88it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training: 100%|█

Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
175.058     |88.214      |86.844      



Validating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 380/380 [00:43<00:00,  8.77it/s]


Total Loss  |Smoke Loss  |Fire Loss   
------------ ------------ ------------
341.479     |162.823     |178.656     
SMOKE -> Precision: 0.4714 - Recall: 1.0000 - Accuracy: 0.4714 - F1: 0.6408
FIRE -> Precision: 0.3240 - Recall: 0.9803 - Accuracy: 0.3304 - F1: 0.4870

=== EPOCH 18/99 ===
Learning Rate = 0.0003276800000000001



Training:   1%|█▊                                                                                                                                          | 24/1836 [00:07<06:38,  4.54it/s]Corrupt JPEG data: 1 extraneous bytes before marker 0xd9
Training:   2%|██▊                                                                                                                                         | 37/1836 [00:10<08:46,  3.41it/s]


KeyboardInterrupt: 

In [None]:
# if __name__ == "__main__":
    
#     print("Train More script\n")
#     logger.info("Train More\n")
    
#     model = train_loop(model, start_epoch=75, epochs_to_train=15)

# Test with DFire MINI Dataset: Train and Test

In [None]:
# import importlib
# importlib.reload(config)
# importlib.reload(dataloaders)

In [None]:
train_dfire_mini_loader = dataloaders.get_dfire_mini_train_loader()
test_dfire_mini_loader = dataloaders.get_dfire_mini_test_loader()

### Load Checkpoint with Best F1 Mean

In [None]:
model_path = config.WEIGHTS_FOLDER + config.MODEL + '_classifier__best_mean_F1.pt'

In [None]:
utils.load_checkpoint(
    model_path, 
    model=model, 
    optimizer=optimizer, 
    scheduler=scheduler, 
    device=config.DEVICE)

In [None]:
model.to('cuda');

### Whole Test Loader, to check it is the same as training

In [None]:
with torch.no_grad():
    val_losses, val_metrics = val_epoch.eval_fn(
        loader=val_loader, 
        model=model,                         
        loss_fn=loss_fn,
        device=config.DEVICE)

In [None]:
logger.info('\nTesting with FULL TEST LOADER')  
#logger.info(val_losses)
logger.info(val_metrics)

### Train DFire MINI

In [None]:
with torch.no_grad():
    val_losses, val_metrics = val_epoch.eval_fn(
        loader=train_dfire_mini_loader, 
        model=model,                         
        loss_fn=loss_fn,
        device=config.DEVICE)

In [None]:
logger.info('\nTesting with DFire MINI TRAIN after LOADING F1 Best Mean CHECKPOINT')  
#logger.info(val_losses)
logger.info(val_metrics)

### Test DFire MINI

In [None]:
with torch.no_grad():
    val_losses, val_metrics = val_epoch.eval_fn(
        loader=test_dfire_mini_loader, 
        model=model,                         
        loss_fn=loss_fn,
        device=config.DEVICE)

In [None]:
logger.info('\nTesting with DFire MINI TEST after LOADING F1 Best Mean CHECKPOINT')  
#logger.info(val_losses)
logger.info(val_metrics)

# Convert the Model to BIPOLAR OUT

In [None]:
import brevitas.nn as qnn
import torch.nn as nn

In [None]:
class CNV_BIPOLAR_OUT(nn.Module):
    def __init__(self, base_model):
        super(CNV_BIPOLAR_OUT, self).__init__()
        self.base_model = base_model
        self.qnt_output = qnn.QuantIdentity(
            quant_type='binary', 
            scaling_impl_type='const',
            bit_width=1, min_val=-1.0, max_val=1.0)

    def forward(self, x):
        x = self.base_model(x)
        x = self.qnt_output(x)
        return x

In [None]:
cnv_bipolar_out = CNV_BIPOLAR_OUT(model).to(config.DEVICE)

### New Evaluation for BIPOLAR Out Model

In [None]:
from tqdm import tqdm

def eval_bipolar_fn(loader, model, device):
    
    model.eval()
    loop = tqdm(loader, desc='Validating', leave=True)

    for batch_idx, (x, y) in enumerate(loop):
        x, y = x.to(device), y.to(device)
        yhat = model(x)

        # print(y.shape)
        # print(yhat.shape)
        
        yhat[yhat < 1] = 0
    
        metrics.precision_metric.update(yhat, y)
        metrics.recall_metric.update(yhat, y)
        metrics.accuracy_metric.update(yhat, y)
        metrics.f1_metric.update(yhat, y)
   
    precision = metrics.precision_metric.compute()
    recall = metrics.recall_metric.compute()
    accuracy = metrics.accuracy_metric.compute()
    f1 = metrics.f1_metric.compute()
    
    metrics.precision_metric.reset()
    metrics.recall_metric.reset()
    metrics.accuracy_metric.reset()
    metrics.f1_metric.reset()

    print(f'SMOKE -> Precision: {precision[0]:.4f} - Recall: {recall[0]:.4f} - Accuracy: {accuracy[0]:.4f} - F1: {f1[0]:.4f}')
    print(f'FIRE -> Precision: {precision[1]:.4f} - Recall: {recall[1]:.4f} - Accuracy: {accuracy[1]:.4f} - F1: {f1[1]:.4f}')
    
    return (
        {
        'Accuracy': [accuracy[0].item(), accuracy[1].item()],
        'Precision': [precision[0].item(), precision[1].item()],
        'Recall': [recall[0].item(), recall[1].item()],
        'F1': [f1[0].item(), f1[1].item()] 
        }
    )

### Full DS

In [None]:
cnv_bipolar_out.eval()
with torch.no_grad():
    val_metrics = eval_bipolar_fn(
        loader=val_loader, 
        model=cnv_bipolar_out,                         
        device=config.DEVICE)

### Mini Train

In [None]:
with torch.no_grad():
    val_metrics = eval_bipolar_fn(
        loader=train_dfire_mini_loader, 
        model=cnv_bipolar_out,                         
        device=config.DEVICE)

### Mini Test

In [None]:
with torch.no_grad():
    val_metrics = eval_bipolar_fn(
        loader=test_dfire_mini_loader, 
        model=cnv_bipolar_out,                         
        device=config.DEVICE)

# Export Bipolar to QONNX

In [None]:
save_f1_name = 'best_mean_F1'
save_bipolar_onnx = config.ONNX_FOLDER + config.MODEL + '_classifier__' + save_f1_name + '__BIPOLAR_Out'
utils.export_onnx(cnv_bipolar_out, (1, config.NUM_CHANNELS, config.IMG_H, config.IMG_W), save_bipolar_onnx, config.DEVICE)