In [6]:
#!/usr/bin/env python
# coding: utf-8

import os
from typing import Dict
import cv2
import numpy as np
from tqdm import tqdm
import pickle
from tmu.models.classification.vanilla_classifier import TMClassifier
from tmu.preprocessing.standard_binarizer.binarizer import StandardBinarizer
from tmu.data import TMUDataset
from tmu.composite.components.base import TMComponent
from tmu.composite.composite import TMComposite
from tmu.composite.config import TMClassifierConfig
from tmu.composite.callbacks.base import TMCompositeCallback
import logging
import scipy.io as sio
from PIL import Image
from matplotlib import pyplot as plt
import io
import sys
sys.path.insert(0,"..")
from TM_mat_comp.components.domf_component import DomfComponent
from TM_mat_comp.components.fft_component import FftComponent
from TM_mat_comp.components.std_component import StdComponent
from TM_mat_comp.components.psd_component import PsdComponent

logging.getLogger('matplotlib.font_manager').setLevel(logging.INFO)
logging.getLogger('PIL.PngImagePlugin').setLevel(logging.INFO)
_LOGGER = logging.getLogger(__name__)

In [7]:
class DataProcessor(TMUDataset):
    def __init__(self, dataset_name, use_cache=False):
        super().__init__()
        self.dataset_name = dataset_name
        self.use_cache = use_cache
        self.classes = {'pulse':0,'31':1,'12':2,'sine':3,'34':4,'11':5,'chirp_uneven':6,'not_sweep':7,'32':8,'40':9}
        # {"11": 0, "12": 1, "31": 2, "32": 3, "34": 4, "40": 5, "chirp_uneven": 6, "not_sweep": 7, "pulse": 8, "sine": 9}
        self.cache_dir = './cache'
        # self.w = w
        # self.h = h
        os.makedirs(self.cache_dir, exist_ok=True)

    def _cache_path(self, dataset_type):
        return os.path.join(self.cache_dir, f"{self.dataset_name}_{dataset_type}_{self.w}x{self.h}.pkl")

    def _load_from_cache(self, dataset_type):
        path = self._cache_path(dataset_type)
        if os.path.exists(path):
            with open(path, 'rb') as f:
                return pickle.load(f)
        return None

    def _save_to_cache(self, dataset_type, data):
        path = self._cache_path(dataset_type)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'wb') as f:
            pickle.dump(data, f)

    def _transform(self, name, dataset):
        return dataset

    def _retrieve_dataset(self) -> Dict[str, np.ndarray]:
        dataset = {}
        for dtype in ['train', 'test']:
            if self.use_cache:
                cached_data = self._load_from_cache(dtype)
                if cached_data:
                    dataset.update(cached_data)
                    continue
            data = self.create_train_data(dtype)
            if self.use_cache:
                self._save_to_cache(dtype, {f'x_{dtype}': data[0], f'y_{dtype}': data[1]})
            dataset[f'x_{dtype}'] = data[0]
            dataset[f'y_{dtype}'] = data[1]
        return dataset
 

    def create_train_data(self, dataset_type): 
        X_data = []
        y_data = []
        main_path = os.path.join('/data', self.dataset_name) 

        data_dir = os.path.join(main_path, dataset_type)

        for class_folder in os.listdir(data_dir):
            class_dir = os.path.join(data_dir, class_folder)
            for j in os.listdir(class_dir)[0:5]:
                class_path = os.path.join(class_dir, j)
                mat_data = sio.loadmat(class_path)
                X_data.append(mat_data)
                y_data.append(self.classes.get(class_folder))

        return np.array(X_data), np.array(y_data)



In [8]:
class TMCompositeEvaluationCallback(TMCompositeCallback):

    def __init__(self, data):
        super().__init__()
        self.best_acc = 0.0
        self.data = data

    def on_epoch_end(self, composite, epoch, logs=None):
        preds = composite.predict(data=self.data)
        acc = (preds == self.data["Y"]).mean()
        _LOGGER.info(f"Epoch {epoch} - Accuracy: {acc:.2f}")

if __name__ == "__main__":
    
    # General hyperparameters
    epochs = 2
    device = "CPU"
    size_w, size_h = 80, 80
    multiprocessing_mode = True
    p_w, p_h = 10, 10
    
    
    
    data = DataProcessor("data/", use_cache=False).get()
    X_train_org = data["x_train"]
    Y_train = data["y_train"]
    X_test_org = data["x_test"]
    Y_test = data["y_test"]

    data_train = dict(
        X=X_train_org,
        Y=Y_train
    )

    data_test = dict(
        X=X_test_org,
        Y=Y_test
    )


    composite_model = TMComposite(
        use_multiprocessing=multiprocessing_mode,
        components=[
            DomfComponent(TMClassifier, TMClassifierConfig(
                number_of_clauses=10,
                T=5,
                s=10,
                max_included_literals=32,
                weighted_clauses=True,
                patch_dim=(p_w, p_h),
                platform=device
            ), epochs=epochs, target_size=(size_w, size_h)),

            FftComponent(TMClassifier, TMClassifierConfig(
                number_of_clauses=10,
                T=5,
                s=10,
                max_included_literals=32,
                weighted_clauses=True,
                patch_dim=(p_w, p_h),
                platform=device
            ), epochs=epochs, target_size=(size_w, size_h)),

            StdComponent(TMClassifier, TMClassifierConfig(
                number_of_clauses=10,
                T=5,
                s=10,
                max_included_literals=32,
                weighted_clauses=True,
                patch_dim=(p_w, p_h),
                platform=device
            ), epochs=epochs, target_size=(size_w, size_h)),

            PsdComponent(TMClassifier, TMClassifierConfig(
                number_of_clauses=10,
                T=5,
                s=10,
                max_included_literals=32,
                weighted_clauses=True,
                patch_dim=(p_w, p_h),
                platform=device
            ), epochs=epochs, target_size=(size_w, size_h)),
            
        ])
    
    # Train the composite model
    composite_model.fit(
        data=data_train,
        callbacks=[
            TMCompositeEvaluationCallback(data=data_test)
        ]
    )

Component 0: DomfComponent:   0%|          | 0/2 [00:00<?, ?it/s]
[A
[A

[A[A

Component 0: DomfComponent:  50%|█████     | 1/2 [00:00<00:00,  5.36it/s]

[A[A
[A
Component 0: DomfComponent: 100%|██████████| 2/2 [00:00<00:00,  5.53it/s]

Component 2: StdComponent: 100%|██████████| 2/2 [00:00<00:00,  4.15it/s]
Component 0: DomfComponent: 100%|██████████| 2/2 [00:00<00:00,  3.80it/s]
Component 1: FftComponent: 100%|██████████| 2/2 [00:00<00:00,  3.81it/s]
Component 3: PsdComponent: 100%|██████████| 2/2 [00:00<00:00,  3.85it/s]


In [9]:
preds = composite_model.predict(data=data_test)

y_true = data_test["Y"].flatten()
print(y_true)
for k, v in preds.items():
    print(v)
    print(f"{k} Accuracy: %.1f" % (100 * (v == y_true).mean()))

[0 0 0 0 0 1 1 1 1 1 2 2 2 2 2 3 3 3 3 3 4 4 4 4 4 5 5 5 5 5 6 6 6 6 6 7 7
 7 7 7 8 8 8 8 8 9 9 9 9 9]
[7 7 7 7 7 0 8 0 4 7 0 7 7 7 7 7 7 7 7 7 7 0 7 7 7 9 0 0 7 7 0 8 4 4 8 4 4
 0 7 7 3 3 3 3 3 9 7 7 9 9]
composite Accuracy: 10.0
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 3 0 0 0 0 0 0 9 9 0 0 0 0 0 0 0 0 0 9 0
 0 0 0 3 3 3 3 3 9 0 0 9 9]
DomfComponent-TMClassifier-number_of_clauses=10-T=5-s=10.0-max_included_literals=32-platform=CPU-weighted_clauses=True-patch_dim=(10, 10)) Accuracy: 18.0
[0 0 4 4 0 0 8 8 4 0 0 0 0 0 0 8 8 8 8 8 8 0 4 8 8 4 0 0 8 8 0 8 4 4 8 4 4
 0 8 0 8 8 4 8 8 8 0 8 8 8]
FftComponent-TMClassifier-number_of_clauses=10-T=5-s=10.0-max_included_literals=32-platform=CPU-weighted_clauses=True-patch_dim=(10, 10)) Accuracy: 16.0
[7 7 7 7 7 0 0 0 0 7 0 7 7 7 7 7 7 7 7 7 7 0 7 7 7 7 0 0 7 7 0 0 0 0 0 7 7
 0 7 7 0 0 0 0 0 7 7 7 7 7]
StdComponent-TMClassifier-number_of_clauses=10-T=5-s=10.0-max_included_literals=32-platform=CPU-weighted_clauses=True-patch_dim=(10, 10)) Accuracy: 8.0
[

In [10]:
from sklearn.metrics import (
    confusion_matrix,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score
)
for k, v in preds.items():
    cm = confusion_matrix(y_true, v)
    _LOGGER.info(f"{k} cm : \n{cm}")
    # if(k == "composite"):
    #     _LOGGER.info(f"{k} cm : \n{cm}")

2024-04-19 07:42:24,221 - __main__ - INFO - composite cm : 
[[0 0 0 0 0 0 0 5 0 0]
 [2 0 0 0 1 0 0 1 1 0]
 [1 0 0 0 0 0 0 4 0 0]
 [0 0 0 0 0 0 0 5 0 0]
 [1 0 0 0 0 0 0 4 0 0]
 [2 0 0 0 0 0 0 2 0 1]
 [1 0 0 0 2 0 0 0 2 0]
 [1 0 0 0 2 0 0 2 0 0]
 [0 0 0 5 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 2 0 3]]
2024-04-19 07:42:24,223 - __main__ - INFO - DomfComponent-TMClassifier-number_of_clauses=10-T=5-s=10.0-max_included_literals=32-platform=CPU-weighted_clauses=True-patch_dim=(10, 10)) cm : 
[[5 0 0 0 0 0 0 0 0 0]
 [5 0 0 0 0 0 0 0 0 0]
 [5 0 0 0 0 0 0 0 0 0]
 [4 0 0 1 0 0 0 0 0 0]
 [4 0 0 0 0 0 0 0 0 1]
 [4 0 0 0 0 0 0 0 0 1]
 [5 0 0 0 0 0 0 0 0 0]
 [4 0 0 0 0 0 0 0 0 1]
 [0 0 0 5 0 0 0 0 0 0]
 [2 0 0 0 0 0 0 0 0 3]]
2024-04-19 07:42:24,224 - __main__ - INFO - FftComponent-TMClassifier-number_of_clauses=10-T=5-s=10.0-max_included_literals=32-platform=CPU-weighted_clauses=True-patch_dim=(10, 10)) cm : 
[[3 0 0 0 2 0 0 0 0 0]
 [2 0 0 0 1 0 0 0 2 0]
 [5 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 5 0]
 [1 0 0 0