In [None]:
import os
from copy import deepcopy
import numpy as np
import os

import torch
import cv2 as cv

import albumentations as A
from tqdm import tqdm

from torch.utils.data import DataLoader

In [None]:
# File paths (placeholders: point these at the extracted dataset root and the saved model file)
extracted_data_path = 'extracted-data-path'
model_path = "your-model-path"

# Image processing parameters
RESIZE_SIZE = (128, 128)
DAY_TO_USE = 8
IMAGES_IN_A_DAY = 4
# Number of images in one sequence: every image of every day that is used
SEQ_LEN = DAY_TO_USE * IMAGES_IN_A_DAY

# Device for inference: use the GPU when one is available, otherwise fall back to the CPU
# so the notebook still runs on machines without CUDA
device = "cuda" if torch.cuda.is_available() else "cpu"

# Class labels dictionary (model output index -> class name)
class_dict = {0: 'Col-0', 1: 'Cvi-0', 2: 'Is-1', 3: 'Kz-9', 4: 'Ler-1', 5: 'TOU-I-17', 6: 'Uk-1', 7: 'Zdr-1'}

# List of available class outputs (indices into class_dict that take part in the evaluation)
available_outputs = [0, 1]

In [None]:
def _fix_images_paths(images_paths, targets):
    images_paths = deepcopy(images_paths)
    targets = deepcopy(targets)

    while os.path.basename(images_paths[0]).split('_')[1:4] != os.path.basename(images_paths[1]).split('_')[1:4]:
        images_paths = images_paths[1:]
        targets = targets[1:]
    while os.path.basename(images_paths[-1]).split('_')[1:4] != os.path.basename(images_paths[-2]).split('_')[1:4]:
        images_paths = images_paths[:-1]
        targets = targets[:-1]
    return images_paths, targets


def prepare_dataset(image_files, targets, fix_len=10, stride=2):
    """Cut the images of every repetition into fixed-length, overlapping sequences.

    Images are first trimmed with ``_fix_images_paths`` and then grouped by the two
    directory components directly above the file name (class folder, repetition
    folder). Each group is padded by repeating its last item until its length is a
    multiple of ``fix_len`` and is then cut into windows of ``fix_len`` items.

    Args:
        image_files: Image paths laid out as ``.../<class>/<repetition>/<file>``.
        targets: Per-image targets, aligned with ``image_files``.
        fix_len: Number of images in every sequence (must be >= 1).
        stride: Offset between the starts of consecutive windows (must be >= 1).

    Returns:
        Dict with the single key ``'train'`` mapping to a list of sequences, each a
        list of ``(image_path, target)`` pairs. The dict is empty when there are no
        images.

    Raises:
        ValueError: If ``fix_len`` or ``stride`` is smaller than 1.
    """
    if fix_len < 1:
        raise ValueError(f"fix_len must be >= 1, got {fix_len}")
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")

    images_paths, targets_cls = _fix_images_paths(image_files, targets)

    # Group (path, target) pairs by (class folder, repetition folder)
    data_as_dict = {}
    for el, _trg in zip(images_paths, targets_cls):
        _cls, _rep_name = el.split(os.sep)[-3:-1]
        data_as_dict.setdefault((_cls, _rep_name), []).append((el, _trg))

    # Repeat the last image so every group length is divisible by fix_len
    for items in data_as_dict.values():
        remainder = len(items) % fix_len
        if remainder:
            items.extend([items[-1]] * (fix_len - remainder))

    datas = {}
    for items in data_as_dict.values():
        sequences = datas.setdefault('train', [])
        # The range bound guarantees that every window holds exactly fix_len items
        for start_idx in range(0, len(items) - fix_len + 1, stride):
            sequences.append(items[start_idx:start_idx + fix_len])

    return datas


def load_data_from_ds(extracted_data_path, class_dict):
    """Collect image paths laid out as ``<root>/<class>/<repetition>/<image>``.

    Only class folders whose name is a key of ``class_dict`` are visited, in the
    iteration order of ``class_dict``. Repetition folders and images are visited in
    sorted order, which keeps consecutive images of a repetition adjacent. Hidden
    entries (names starting with ``'.'``), non-directory entries at the repetition
    level and non-file entries at the image level are skipped.

    Args:
        extracted_data_path: Root folder of the extracted dataset.
        class_dict: Mapping whose keys are the class folder names to load.

    Returns:
        Tuple ``(image_files, class_name_img, class_dict)``: the image paths, the
        class folder name of every image (aligned with ``image_files``) and the
        ``class_dict`` argument itself.
    """
    if not extracted_data_path.endswith(os.sep):
        extracted_data_path += os.sep

    image_files = []
    class_name_img = []

    # loading all image_paths (IN A SORTED ORDER, this is really important to avoid any weird exceptions)
    for class_name in class_dict.keys():
        if class_name.startswith('.'):
            continue

        class_dir = extracted_data_path + class_name
        if not os.path.isdir(class_dir):
            continue

        for repetition in sorted(os.listdir(class_dir)):
            repetition_dir = class_dir + os.sep + repetition
            # Skip hidden entries and stray files that sit next to the repetition folders
            if repetition.startswith('.') or not os.path.isdir(repetition_dir):
                continue

            image_list = sorted(
                img for img in os.listdir(repetition_dir)
                if not img.startswith('.') and os.path.isfile(repetition_dir + os.sep + img)
            )
            image_files.extend(repetition_dir + os.sep + img for img in image_list)
            class_name_img.extend([class_name] * len(image_list))

    return image_files, class_name_img, class_dict


class ClassificationPlantSequenceDataset(torch.utils.data.Dataset):
    """Dataset that loads image sequences from disk for sequence classification.

    Every entry of ``data`` is one sequence: an iterable of ``(image_path, class)``
    pairs in which all pairs carry the same class.
    """

    def __init__(self, data, img_size):
        """Store the sequences and build the resize transform.

        Args:
            data: Indexable collection of sequences of ``(image_path, class)`` pairs.
            img_size: Two-element target size applied to every image.
        """
        self.data = data
        self.img_size = img_size

        # Built from the img_size argument rather than a notebook-level constant, so
        # the dataset does not depend on state defined outside the class.
        # NOTE(review): cv.resize below reads img_size as (width, height) while
        # A.Resize reads it as (height, width); both agree for square sizes.
        self.transform = A.Compose([
            A.Resize(img_size[0], img_size[1]),
        ])

    def __getitem__(self, index):
        """Load one sequence.

        Returns:
            Tuple ``(images, seq_cls)`` where ``images`` is a float32 tensor laid out
            as (channels, sequence, height, width) with pixel values divided by 255,
            and ``seq_cls`` is the class shared by all images of the sequence.

        Raises:
            OSError: If an image cannot be read.
        """
        seq, classes = list(zip(*self.data[index]))
        assert len(set(classes)) == 1, 'wrong seq clses'
        seq_cls = classes[0]

        images = []
        for im_path in seq:
            img = cv.imread(im_path)
            # cv.imread signals failure by returning None instead of raising
            if img is None:
                raise OSError(f"Could not read image (missing or not decodable): {im_path}")
            img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
            img = cv.resize(img, self.img_size)
            img = self.transform(image=img)['image']
            img = img.astype(np.float32) / 255.
            # HWC -> CHW
            img = torch.from_numpy(img.transpose((2, 0, 1)))
            images.append(img)

        # (sequence, channels, H, W) -> (channels, sequence, H, W)
        images = torch.stack(images)
        images = torch.permute(images, (1, 0, 2, 3))
        return images, seq_cls

    def __len__(self):
        """Return the number of sequences."""
        return len(self.data)


def reverse_dict(d):
    """Return a new dict that maps every value of ``d`` back to its key.

    When several keys share a value, the key that comes last in iteration order wins.
    """
    return {value: key for key, value in d.items()}


def remove_unavailable_outputs(opt_list, available_outputs):
    """Return a copy of ``opt_list`` with every unavailable position set to 0.

    A position is kept only when its index is contained in ``available_outputs``;
    the input list itself is left untouched.
    """
    masked = deepcopy(opt_list)
    for position in range(len(masked)):
        if position in available_outputs:
            continue
        masked[position] = 0
    return masked


def _confusion_counts(true_classes, predicted_classes, class_names):
    """Return micro-aggregated one-vs-rest counts ``(TP, TN, FP, FN)`` over ``class_names``."""
    tp = tn = fp = fn = 0
    for true_cls, pred_cls in zip(true_classes, predicted_classes):
        for name in class_names:
            is_true = true_cls == name
            is_pred = pred_cls == name
            if is_true and is_pred:
                tp += 1
            elif is_pred:
                fp += 1
            elif is_true:
                fn += 1
            else:
                tn += 1
    return tp, tn, fp, fn


def _metrics_from_counts(tp, tn, fp, fn):
    """Derive the rate metrics from aggregated confusion counts (eps guards empty denominators)."""
    eps = 1e-7

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    specificity = tn / (tn + fp + eps)
    mcc_denominator = ((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) + eps) ** 0.5

    return {
        'Precision': precision,
        'Recall': recall,
        'F1 Score': 2 * precision * recall / (precision + recall + eps),
        'Specificity': specificity,
        'False Positive Rate': 1 - specificity,
        'False Negative Rate': fn / (tp + fn + eps),
        'False Discovery Rate': fp / (tp + fp + eps),
        'False Omission Rate': fn / (tn + fn + eps),
        'Misclassification Rate': (fp + fn) / (tp + tn + fp + fn + eps),
        'MCC': (tp * tn - fp * fn) / mcc_denominator,
    }


def evaluate_sequence_model(extracted_data_path, model_path, resize_size, device, class_dict, available_outputs, seq_len):
    """Run a TorchScript sequence classifier over a dataset folder and report metrics.

    Sequences whose true class is not one of the available classes are ignored.
    For the remaining sequences the prediction is the highest softmax score among
    the available outputs. Confusion counts are aggregated one-vs-rest over the
    available classes: for every evaluated sequence and every available class, the
    pair (true class, predicted class) contributes one TP, FP, FN or TN.

    Args:
        extracted_data_path: Dataset root laid out as ``<class>/<repetition>/<image>``.
        model_path: Path of the model file loaded with ``torch.jit.load``.
        resize_size: Image size handed to the dataset.
        device: Device used for the model and the input tensors.
        class_dict: Mapping from model output index to class name.
        available_outputs: Output indices that take part in the evaluation.
        seq_len: Requested sequence length; sequences have ``max(seq_len, 10)`` images.

    Returns:
        Dict of metric name to value, including ``'Overall Accuracy'`` (correct
        predictions divided by evaluated sequences, 0.0 when nothing was evaluated).

    Raises:
        ValueError: If no image sequence could be built from ``extracted_data_path``.
    """
    image_files, targets, _ = load_data_from_ds(extracted_data_path, class_dict=reverse_dict(class_dict))
    datas = prepare_dataset(image_files, targets, fix_len=max(seq_len, 10))

    sequences = datas.get('train', [])
    if not sequences:
        raise ValueError(f"No image sequences found under {extracted_data_path!r}")

    available_class_names = [class_dict[i] for i in available_outputs]

    infer_dataset = ClassificationPlantSequenceDataset(sequences, img_size=resize_size)
    dataloader = DataLoader(infer_dataset, batch_size=4, shuffle=False)

    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    model.to(device)

    true_classes = []
    predicted_classes = []

    # Inference only: no autograd graph is needed
    with torch.no_grad():
        for images_tensor, seq_cls_batch in tqdm(dataloader, total=len(dataloader), desc="Inference on seqs"):
            outputs = model(images_tensor.to(device))
            outputs = torch.softmax(outputs, dim=1).cpu().numpy()

            for seq_cls, output in zip(seq_cls_batch, outputs):
                if seq_cls not in available_class_names:
                    continue

                # Make unused outputs 0 so the argmax picks among the available ones
                output_list = remove_unavailable_outputs(output.tolist(), available_outputs)
                max_val = int(np.argmax(output_list))

                true_classes.append(seq_cls)
                predicted_classes.append(class_dict[max_val])

    total = len(true_classes)
    TP_all, TN_all, FP_all, FN_all = _confusion_counts(true_classes, predicted_classes, available_class_names)

    results = _metrics_from_counts(TP_all, TN_all, FP_all, FN_all)

    # TP_all counts exactly the sequences whose prediction equals the true class
    overall_accuracy = TP_all / total if total else 0.0

    print("Overall Accuracy: ", overall_accuracy)

    results['Overall Accuracy'] = overall_accuracy

    return results

## Calculate Accuracy

In [None]:
# Evaluate the sequence model with the configuration defined at the top of the notebook
results = evaluate_sequence_model(
    extracted_data_path=extracted_data_path,
    model_path=model_path,
    resize_size=RESIZE_SIZE,
    device=device,
    class_dict=class_dict,
    available_outputs=available_outputs,
    seq_len=SEQ_LEN,
)

# Show the resulting metrics dictionary
print(results)