<a href="https://colab.research.google.com/github/stepanbabayan/DFBS-Object-Classification/blob/colab/infer_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Clone Repository

In [None]:
# Fetch the project code (model definitions, inference helpers, checkpoints).
!git clone https://github.com/stepanbabayan/DFBS-Object-Classification.git

## Rename the repository folder and switch to the `colab` branch

In [None]:
# Rename the cloned folder to Model/ so its modules import as the `Model` package.
!mv ./DFBS-Object-Classification/ ./Model/

In [None]:
# Enter the repository; later relative paths (data zip, checkpoints) resolve from here.
%cd Model

In [None]:
# This notebook expects the repository's `colab` branch.
!git checkout colab

In [None]:
# Optional: uncomment to pull the latest commits for the checked-out branch.
# !git pull

## Additional Environment Setup

In [None]:
# Unpack the inference data archive into the current working directory.
import zipfile

with zipfile.ZipFile('./data_Inference.zip', mode='r') as archive:
    archive.extractall(path='.')

In [None]:
import sys

# Add 'Model/' to the module search path, guarded so re-running this cell does
# not append a duplicate entry each time.
# NOTE(review): relative sys.path entries resolve against the working directory,
# which the earlier `cd Model` cell already changed, so this points at Model/
# *inside* the repository folder — confirm this is the intended directory.
if 'Model/' not in sys.path:
    sys.path.append('Model/')

## Imports

In [None]:
import os

import torch
from torchsummary import summary

import numpy as np

# from Model.test import evaluate
from Model.load_data import load_images
from Model.models import Model
from Model.infer import infer_evaluate

# from sklearn.metrics import classification_report

## Device configuration

In [None]:
# Set to False to force CPU inference even when a CUDA GPU is available.
use_gpu = True

In [None]:
# Inference device: the first CUDA GPU when it is both requested (use_gpu)
# and available, otherwise the CPU.
device = torch.device(
    "cuda:0" if (use_gpu and torch.cuda.is_available()) else 'cpu'
)

print(f'Device: {device}')

## Class set, checkpoint and data path

In [None]:
# Choose the class setup; this selects the class list and checkpoint in the next cell.
num_classes = 10 # Choices: {5, 6, 10}

In [None]:
# Class-name lists for the three supported setups.
classes_5 = ['C-H', 'C-N', 'Mrk SB', 'sdA', 'sdB']
classes_6 = ['C-H', 'C-N', 'Mrk Abs', 'Mrk SB', 'sdA', 'sdB']
classes_10 = ['C Ba', 'C-H', 'C-N', 'C-R', 'Mrk Abs', 'Mrk AGN', 'Mrk SB', 'sdA', 'sdB', 'sdO']

assert num_classes in {5, 6, 10}

# Pick the label list and the trained checkpoint that belong to this class count.
classes, checkpoint_name = {
    5: (classes_5, 'Dense_5_High_Focal_25_3_Final/59.pth'),
    6: (classes_6, 'Dense_6_High_Focal_25_3_Final/136.pth'),
    10: (classes_10, 'Dense_10_Focal_25_3_Final/139.pth'),
}[num_classes]


In [None]:
# Root folder of the inference data (presumably created by extracting
# data_Inference.zip earlier). Plain string: there is nothing to interpolate.
data_root = './data_Inference'

In [None]:
# Folder of images to classify.
# NOTE(review): the meaning of the '19_2' sub-folder is not documented here.
infer_dir = '/'.join((data_root, 'Subtypes', '19_2'))
# Spatial size of each input image; passed to the model constructor below.
input_shape = (160, 50)

In [None]:
print(f'Num classes: {num_classes}')

## Project Parameters

In [None]:
# Absolute path of the current working directory; checkpoint paths are built from it.
root_dir = os.path.abspath('./')

In [None]:
# Trained weights are stored under <root_dir>/Checkpoint/.
checkpoint_path = '/'.join((root_dir, 'Checkpoint', checkpoint_name))
print('Selected checkpoint:', checkpoint_path)

## Inference Parameters

In [None]:
# Number of images per inference batch, passed to load_images below.
infer_batch_size = 256

## Data Loaders

In [None]:
# Batched loader over the inference images. _drop_last=False presumably keeps
# the final partial batch so every image gets a prediction — confirm in load_images.
infer_data = load_images(
    path=infer_dir,
    batch_size=infer_batch_size,
    domain='inference',
    _drop_last=False,
)

## Model Setup

In [None]:
# Build the classifier and move it to the selected device.
# `arch` must be one of:
#   'default'      - the proposed network
#   'default_bn'   - like 'default', but with more BatchNorm layers
#   'default_prev' - the network proposed in the previous work
#   'mobilenet'    - MobileNetV2
#   'resnet'       - ResNet
# NOTE(review): the architecture must match the one the selected checkpoint was trained with.
net = Model(
    num_classes=num_classes,
    input_shape=input_shape,
    arch='default',
).to(device)

### Layers

In [None]:
# Layer-by-layer structure of the network.
print(net)

### Output Summary

In [None]:
# Per-layer output shapes and parameter counts for a single-channel input of
# size `input_shape` (kept in sync with the model instead of hard-coding it).
# torchsummary builds its dummy input on device="cuda" by default, so pass the
# model's actual device type; otherwise use_gpu=False on a GPU runtime would
# feed CUDA tensors to a CPU model.
summary(net, (1, *input_shape), device=device.type)

In [None]:
# Load the trained weights and switch the network to evaluation mode.
# map_location remaps every stored tensor onto the selected device, so one call
# works whether the checkpoint was saved on a GPU or a CPU and whichever device
# is in use now.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
net.load_state_dict(torch.load(checkpoint_path, map_location=device))
net.eval()

print('Loaded checkpoint:', checkpoint_path)

In [None]:
print('\nEvaluation started:')

# Inference never needs gradients; no_grad guarantees no autograd graph (and its
# memory) is built, whether or not infer_evaluate already disables it internally.
with torch.no_grad():
    predictions = infer_evaluate(dataloader=infer_data, model=net, device=device)

In [None]:
# Number of predictions returned (presumably one per inference image — confirm).
len(predictions)

In [None]:
# Distinct predicted labels and how many times each occurs.
# NOTE(review): if these are class indices, map them through `classes` for readable names.
np.unique(predictions, return_counts=True)