# Runs Detector and Classifier on Sample Test Images

## Load Libraries

In [None]:
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torchsummary import summary
from torchvision.transforms import transforms

## Set Up Seeds, Labels, and Paths

In [None]:
# Seed both NumPy and PyTorch so repeated runs draw the same random numbers.
randomSeed = 2024
torch.manual_seed(randomSeed)
np.random.seed(randomSeed)

# All tensors and models in this notebook are placed on this device.
device = 'cpu'
print(f'PyTorch Version: {torch.__version__}')

# Class names and the matplotlib colour code drawn for each, indexed by class id.
trackingLabels = ['person', 'car']
colors = ['m', 'b']
numClasses = len(trackingLabels)
maxObjects = 20

# Saved checkpoints for the two models.
detectorModelPath = './DetectionModel_20240505_new.pt'
classifierModelPath = './classificationModel_20240505_2.pt'

# Prefer the images shipped under finalPackage/; otherwise fall back to a
# testingImages/ folder directly inside the working directory.
_packagedImages = os.path.join(os.getcwd(), 'finalPackage/testingImages')
sampleImagesPath = (
    _packagedImages
    if os.path.isdir(_packagedImages)
    else os.path.join(os.getcwd(), 'testingImages')
)

## Define Sizing Function and Plotting Function

In [None]:
def sizeBoxImage(img, bbox):
    """Crop ``bbox`` out of ``img`` and rescale the crop to 80x80 pixels.

    The crop is resized to a square whose side (80, 160, 320 or 640) is
    chosen from the box size, then halved with ``cv2.pyrDown`` until it is
    80x80.

    Args:
        img: Image array indexed as ``img[row, col]``.
        bbox: Four values ``(x, y, w, h)``; each is truncated with ``int``.

    Returns:
        The 80x80 array produced by ``cv2.resize`` / ``cv2.pyrDown``.
    """
    x, y, w, h = (int(v) for v in bbox)
    crop = img[y:y + h, x:x + w]

    # (largest box side the tier accepts, number of pyrDown halvings needed
    # to get from that side back to 80). The tiers are mutually exclusive so
    # a small crop is resampled exactly once rather than being enlarged to a
    # bigger tier and shrunk again.
    tiers = ((80, 0), (160, 1), (320, 2))
    side, levels = 640, 3  # anything larger than 320 in either dimension
    for max_side, halvings in tiers:
        if w <= max_side and h <= max_side:
            side, levels = max_side, halvings
            break

    crop = cv2.resize(crop, (side, side))
    for _ in range(levels):
        crop = cv2.pyrDown(crop)
    return crop

def _as_sequence(value):
    """Return ``value`` as a list, wrapping a lone scalar in a one-item list."""
    if hasattr(value, 'tolist'):  # tensors / arrays -> plain Python objects
        value = value.tolist()
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def _is_scalar(value):
    """Return True when ``value`` has no length (a number or 0-d tensor)."""
    try:
        len(value)
    except TypeError:
        return True
    return False


def _label_index(label):
    """Return ``label`` as an int, unwrapping a one-element container if needed."""
    try:
        return int(label)
    except (TypeError, ValueError, RuntimeError):
        return int(label[0])


def plot_sample(image, labels, bboxes, num=None, showbb=True):
    """Show ``image`` in grayscale and optionally overlay labelled boxes.

    Args:
        image: Image to display; singleton dimensions are squeezed away
            before plotting.
        labels: A single class index or a sequence of them (each entry may
            itself be a one-element sequence). Used to index
            ``trackingLabels`` and ``colors``.
        bboxes: A single ``(x, y, w, h)`` box or a sequence of such boxes.
            Coordinates are truncated with ``int`` before drawing.
        num: Optional object count shown in the title when boxes are drawn.
        showbb: When true, draw every box with its class name and hide the
            axes; otherwise only title the plot with the class name of the
            first label.
    """
    # squeeze() only drops size-1 dimensions (e.g. a leading channel axis).
    plt.imshow(image.squeeze(), cmap="gray")

    label_ids = [_label_index(label) for label in _as_sequence(labels)]

    if not showbb:
        plt.title(f'Sample {trackingLabels[label_ids[0]]}')
        return

    if num is not None:
        plt.title(f'Number of Objects: {num}')

    boxes = _as_sequence(bboxes)
    if boxes and _is_scalar(boxes[0]):
        # A flat (x, y, w, h) was passed instead of a list of boxes.
        boxes = [boxes]

    axes = plt.gca()
    for box, label in zip(boxes, label_ids):
        x, y, w, h = (int(v) for v in box)
        axes.add_patch(plt.Rectangle((x, y), w, h, linewidth=1,
                                     edgecolor=colors[label], facecolor='none'))
        plt.text(x, y - 5, f'{trackingLabels[label]}', color=colors[label])
    plt.axis('off')

## Load Classifier

In [None]:
class Classifier(nn.Module):
    """Single-channel image classifier built on a truncated ResNet-18.

    The input is batch-normalised, passed through the first six child
    modules of a randomly initialised ResNet-18 (whose first convolution is
    replaced by a 1-channel 3x3 one), average-pooled to one value per
    channel and mapped to class probabilities by a small fully connected
    head ending in a softmax.

    Args:
        num_classes: Number of output classes.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.batchNorm = nn.BatchNorm2d(1)

        backbone = models.resnet18(weights=None)
        self.relu = nn.ReLU(inplace=False)
        # Swap the stock first convolution for one that accepts 1 input channel.
        backbone.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1, bias=False)

        # Keep only the first six children of the backbone; the head below
        # expects 128 feature channels from them.
        kept_children = list(backbone.children())[:6]
        self.features = nn.Sequential(*kept_children)

        head_layers = [
            nn.Linear(128, 32),
            nn.ReLU(inplace=False),
            nn.Linear(32, num_classes),
            nn.Softmax(1),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return per-class softmax scores for the batch ``x``."""
        normalised = self.batchNorm(x)
        feature_map = self.relu(self.features(normalised))
        pooled = F.adaptive_avg_pool2d(feature_map, (1, 1))
        flat = pooled.view(pooled.size(0), -1)
        return self.classifier(flat)
    
# Build the classifier and print a per-layer summary for one 1x80x80 input.
classifier = Classifier(numClasses)

# torchsummary places its dummy input on "cuda" by default; pass the
# notebook's device so the input lands on the same device as the model.
summary(classifier, (1, 80, 80), device=device)
print()

In [None]:
# Rebuild the classifier and restore it (and its optimizer) from the checkpoint.
criterionClassifier = nn.CrossEntropyLoss()
classifier = Classifier(numClasses).to(device)
# Create the optimizer from the parameters of the model being restored, so
# the loaded optimizer state refers to the live model.
optimizerClassifier = optim.Adam(classifier.parameters(), lr=0.001)

# map_location lets a checkpoint saved on another device load onto `device`.
checkpoint = torch.load(classifierModelPath, map_location=device)
classifier.load_state_dict(checkpoint['model_state_dict'])
optimizerClassifier.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
trainLoss = [checkpoint['loss']]

# Inference mode: BatchNorm layers use their stored running statistics
# instead of the statistics of each single-image batch.
classifier.eval()

## Load Detector

In [None]:
class Detector(nn.Module):
    """Bounding-box regressor on top of a ResNet-18 backbone.

    The backbone is every child of a ResNet-18 (loaded with its default
    weights) except the last one. The network outputs ``4 * max_objects``
    numbers per image, i.e. four box values for each object slot.

    Args:
        max_objects: Number of box slots predicted per image.
    """

    def __init__(self, max_objects):
        super().__init__()
        self.max_objects = max_objects

        backbone = models.resnet18(weights='DEFAULT')
        self.relu = nn.ReLU(inplace=True)

        # Freeze every parameter the backbone has at this point.
        for weight in backbone.parameters():
            weight.requires_grad_(False)

        # Replaced after the freeze loop, so this new 1-channel convolution
        # is not covered by it.
        backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=1, bias=False)

        # Drop the final child of the backbone and keep the rest.
        self.features = nn.Sequential(*list(backbone.children())[:-1])

        regression_head = [
            nn.Linear(512, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 4*max_objects),
        ]
        self.bb = nn.Sequential(*regression_head)

    def forward(self, x):
        """Return a ``(batch, 4 * max_objects)`` tensor of box values."""
        activations = self.relu(self.features(x))
        pooled = F.adaptive_avg_pool2d(activations, (1, 1))
        flat = pooled.view(pooled.size(0), -1)
        return self.bb(flat)
    
# Build the detector and print a per-layer summary for one 1x640x640 input.
detector = Detector(maxObjects)

# torchsummary places its dummy input on "cuda" by default; pass the
# notebook's device so the input lands on the same device as the model.
summary(detector, (1, 640, 640), device=device)
print()

In [None]:
# Define optimizer and learning rate scheduler
# Rebuild the detector and restore it (and its optimizer) from the checkpoint.
detector = Detector(maxObjects).to(device)
# Create the optimizer from the parameters of the model being restored, so
# the loaded optimizer state refers to the live model.
optimizerDetector = optim.Adam(detector.parameters(), lr=0.001)

# map_location lets a checkpoint saved on another device load onto `device`.
checkpoint = torch.load(detectorModelPath, map_location=device)
detector.load_state_dict(checkpoint['model_state_dict'])
optimizerDetector.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
trainLossDetect = checkpoint['loss']

# Inference mode: BatchNorm layers use their stored running statistics
# instead of the statistics of each single-image batch.
detector.eval()

## Load Sample Imagery and Call Detector, then Classifier

In [None]:
# Preprocessing applied to both the full frame and every classifier crop.
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert the NumPy image to a PyTorch tensor
    transforms.Normalize(mean=[262.6299], std=[117.4840])  # Normalize image
])

# Inference only: make BatchNorm use its running statistics (eval) and skip
# autograd bookkeeping (no_grad below).
detector.eval()
classifier.eval()

for ifile in sorted(os.listdir(sampleImagesPath)):
    image_file = os.path.join(sampleImagesPath, f"{ifile}")

    raw = cv2.imread(str(image_file), cv2.IMREAD_GRAYSCALE)
    if raw is None:
        # cv2.imread returns None for anything it cannot decode
        # (sub-directories, non-image files, ...).
        print(f'Skipping unreadable file: {image_file}')
        continue

    image = cv2.resize(raw.astype(np.float32), (640, 640))
    image = transform(image)
    frame = image.squeeze().numpy()  # 2-D array the crops are cut from

    predLabels = []
    predScores = []
    boxesKeep = []
    with torch.no_grad():
        predBoxes = detector(image.unsqueeze(0).to(device))  ### DETECTOR
        predBoxes = predBoxes.view(-1, maxObjects, 4)

        # squeeze(0) drops only the batch axis, so this still iterates over
        # boxes when maxObjects == 1.
        for box in predBoxes.squeeze(0):
            # Keep only slots whose four values are all greater than 1.
            if not torch.all(box > 1):
                continue

            # A box that starts outside the frame yields an empty crop,
            # which cv2.resize cannot handle.
            if int(box[0]) >= frame.shape[1] or int(box[1]) >= frame.shape[0]:
                continue

            img = sizeBoxImage(frame, box)
            # NOTE(review): the crop is cut from the already-normalized frame
            # and normalized again here — confirm this matches how the
            # classifier was trained.
            img = transform(img).unsqueeze(0)
            output = classifier(img.to(device))  ### CLASSIFIER
            score, label = torch.max(output, 1)
            predLabels.append(label.item())
            predScores.append(score.item())
            boxesKeep.append(box.tolist())

    boxesKeep = torch.tensor(boxesKeep)
    plt.figure(figsize=(10, 10))
    plot_sample(image, predLabels, boxesKeep.tolist())
    plt.show()