In [None]:
import json
import time
import optuna
from optuna.storages import JournalStorage, JournalFileStorage
from joblib import Parallel, delayed

import os
from typing import Dict
import cv2
import numpy as np
from tqdm import tqdm
import pickle
from tmu.models.classification.vanilla_classifier import TMClassifier
from tmu.preprocessing.standard_binarizer.binarizer import StandardBinarizer
from tmu.data import TMUDataset
from tmu.composite.components.base import TMComponent
from tmu.composite.composite import TMComposite
from tmu.composite.config import TMClassifierConfig
from tmu.composite.callbacks.base import TMCompositeCallback
import logging
import scipy.io as sio
from PIL import Image
from matplotlib import pyplot as plt
import io
import sys
sys.path.insert(0,"..")
from TM_mat_comp.components.domf_component import DomfComponent
from TM_mat_comp.components.fft_component import FftComponent
from TM_mat_comp.components.std_component import StdComponent
from TM_mat_comp.components.psd_component import PsdComponent

# Raise the threshold of two chatty third-party loggers to INFO so their
# DEBUG records do not flood the notebook output.
for _noisy_logger_name in ('matplotlib.font_manager', 'PIL.PngImagePlugin'):
    logging.getLogger(_noisy_logger_name).setLevel(logging.INFO)
del _noisy_logger_name  # do not leave the loop variable in the notebook namespace

# Module-level logger shared by the classes defined below.
_LOGGER = logging.getLogger(__name__)


In [None]:
class DataProcessor(TMUDataset):
    """Load ``.mat`` samples from a class-per-folder tree into train/test arrays.

    The directory layout read by ``create_train_data`` is
    ``<data_root>/<dataset_name>/<train|test>/<class_folder>/<file>``, where
    ``<class_folder>`` must be one of the keys of ``self.classes``.

    Args:
        dataset_name: Sub-directory of ``data_root`` holding the dataset.
        use_cache: When true, each split is pickled under ``./cache`` after the
            first load and read back from there on later loads.
        data_root: Root directory containing the dataset folder. Defaults to
            ``'/data'``, the location that was previously hard-coded.
    """

    def __init__(self, dataset_name, use_cache=False, data_root='/data'):
        super().__init__()
        self.dataset_name = dataset_name
        self.use_cache = use_cache
        self.data_root = data_root
        # Folder name -> integer class label.
        self.classes = {'pulse':0,'31':1,'12':2,'sine':3,'34':4,'11':5,'chirp_uneven':6,'not_sweep':7,'32':8,'40':9}
        # Alternative (alphabetical) mapping kept for reference, not used:
        # {"11": 0, "12": 1, "31": 2, "32": 3, "34": 4, "40": 5, "chirp_uneven": 6, "not_sweep": 7, "pulse": 8, "sine": 9}
        self.cache_dir = './cache'
        os.makedirs(self.cache_dir, exist_ok=True)

    def _cache_path(self, dataset_type):
        """Return the pickle path used to cache the given split.

        The ``_{w}x{h}`` suffix is only added when the instance actually has
        ``w`` and ``h`` attributes. This class never sets them, so reading them
        unconditionally raised ``AttributeError`` whenever caching was enabled.
        """
        width = getattr(self, 'w', None)
        height = getattr(self, 'h', None)
        size_suffix = f"_{width}x{height}" if width is not None and height is not None else ""
        return os.path.join(self.cache_dir, f"{self.dataset_name}_{dataset_type}{size_suffix}.pkl")

    def _load_from_cache(self, dataset_type):
        """Return the cached dict for ``dataset_type``, or ``None`` if no cache file exists."""
        path = self._cache_path(dataset_type)
        if os.path.exists(path):
            # SECURITY: pickle.load executes arbitrary code from the file; only
            # point cache_dir at files this notebook wrote itself.
            with open(path, 'rb') as f:
                return pickle.load(f)
        return None

    def _save_to_cache(self, dataset_type, data):
        """Pickle ``data`` to the cache path of ``dataset_type``, creating parent dirs."""
        path = self._cache_path(dataset_type)
        # dataset_name may contain path separators, so the parent directory can
        # be deeper than cache_dir itself.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'wb') as f:
            pickle.dump(data, f)

    def _transform(self, name, dataset):
        """Identity transform: the data is returned unchanged."""
        return dataset

    def _retrieve_dataset(self) -> Dict[str, np.ndarray]:
        """Build the dataset dict with keys ``x_train``, ``y_train``, ``x_test``, ``y_test``.

        Each split is served from the pickle cache when ``use_cache`` is set and
        a cache file exists; otherwise it is read from disk (and cached if
        ``use_cache`` is set).
        """
        dataset = {}
        for dtype in ['train', 'test']:
            if self.use_cache:
                cached_data = self._load_from_cache(dtype)
                if cached_data:
                    dataset.update(cached_data)
                    continue
            samples, labels = self.create_train_data(dtype)
            if self.use_cache:
                self._save_to_cache(dtype, {f'x_{dtype}': samples, f'y_{dtype}': labels})
            dataset[f'x_{dtype}'] = samples
            dataset[f'y_{dtype}'] = labels
        return dataset

    def create_train_data(self, dataset_type, max_files_per_class=10):
        """Read up to ``max_files_per_class`` ``.mat`` files per class folder.

        Args:
            dataset_type: Split sub-directory name (``'train'`` or ``'test'``).
            max_files_per_class: Maximum number of files loaded from each class
                folder; ``None`` loads every file. Defaults to 10, the limit
                that was previously hard-coded.

        Returns:
            ``(X, y)`` where ``X`` is an array of the dicts returned by
            ``scipy.io.loadmat`` and ``y`` is the array of integer labels.
        """
        X_data = []
        y_data = []
        data_dir = os.path.join(self.data_root, self.dataset_name, dataset_type)

        # os.listdir order is arbitrary; sorting makes both the sample order and
        # the subset selected by the per-class limit reproducible.
        for class_folder in sorted(os.listdir(data_dir)):
            class_dir = os.path.join(data_dir, class_folder)
            if not os.path.isdir(class_dir):
                # Stray files (e.g. .DS_Store) would make os.listdir below fail.
                continue
            label = self.classes.get(class_folder)
            if label is None:
                # A None label would turn y into an object array and corrupt
                # every later accuracy computation, so skip and report instead.
                _LOGGER.warning(f"Skipping unknown class folder '{class_folder}' in {data_dir}")
                continue
            for file_name in sorted(os.listdir(class_dir))[:max_files_per_class]:
                mat_data = sio.loadmat(os.path.join(class_dir, file_name))
                X_data.append(mat_data)
                y_data.append(label)

        return np.array(X_data), np.array(y_data)

class TMCompositeEvaluationCallback(TMCompositeCallback):
    """Callback that evaluates the composite on a fixed dataset after every epoch.

    Args:
        data: Dict passed to ``composite.predict``; its ``"Y"`` entry holds the
            ground-truth labels the predictions are compared against.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, data):
        super().__init__()
        # Highest accuracy observed so far across epochs.
        self.best_acc = 0.0
        self.data = data

    def on_epoch_end(self, composite, epoch, logs=None):
        """Predict on ``self.data``, log the accuracy and update ``best_acc``."""
        preds = composite.predict(data=self.data)
        # The tuner's objective in this notebook reads preds['composite'], i.e.
        # predict() returns a dict there. Comparing that dict to the label array
        # directly gives an all-False result (accuracy 0), so pick the composite
        # vote when a dict comes back and fall through unchanged otherwise.
        if isinstance(preds, dict):
            preds = preds['composite']
        # Flatten labels the same way the tuner's objective does.
        y_true = np.asarray(self.data["Y"]).flatten()
        acc = float((np.asarray(preds) == y_true).mean())
        self.best_acc = max(self.best_acc, acc)
        _LOGGER.info(f"Epoch {epoch} - Accuracy: {acc:.2f}")

In [None]:

class TMCompositeTuner:
    """Optuna-based hyperparameter search for a ``TMComposite`` of signal components.

    Every trial builds ``self.n_components`` components, each with its own
    sampled component type and ``TMClassifierConfig``, fits the composite on
    ``data_train`` and scores it on ``data_test``. The number of components
    grows by one after a trial that beats the previous trial's accuracy and
    shrinks by one (never below 1) otherwise.

    Args:
        data_train: Dict passed to ``TMComposite.fit``.
        data_test: Dict passed to ``TMComposite.predict``; ``data_test['Y']``
            holds the labels used for scoring.
        platform: Forwarded to ``TMClassifierConfig``.
        max_epochs: Upper bound of the per-component ``epochs`` search range.
        n_jobs: Number of parallel workers used by ``tune``.
        callbacks: Callbacks forwarded to ``TMComposite.fit``.
        use_multiprocessing: Forwarded to ``TMComposite``.
        study_name: Optuna study name used when ``n_jobs != 1``.
        target_size: Forwarded to every component constructor.
    """

    # Component names offered to Optuna for each component slot; each name is
    # resolved to its class by ``_component_classes``.
    COMPONENT_TYPES = ('DomfComponent', 'FftComponent', 'StdComponent', 'PsdComponent')

    def __init__(
            self,
            data_train,
            data_test,
            platform="CPU",
            max_epochs=200,
            n_jobs: int = 1,
            callbacks=None,
            use_multiprocessing=True,
            study_name="TMComposite_study",
            target_size = (80,80)
    ):
        self.data_train = data_train
        self.data_test = data_test
        self.last_accuracy = 0.0
        self.n_components = 1
        self.n_jobs = n_jobs
        self.study_name = study_name
        self.platform = platform
        self.max_epochs = max_epochs
        self.target_size = target_size
        if callbacks is None:
            callbacks = []

        self.callbacks = callbacks
        self.use_multiprocessing = use_multiprocessing

    @staticmethod
    def _component_classes():
        """Map every name in ``COMPONENT_TYPES`` to its component class.

        Built lazily at call time so the class body does not depend on the
        component modules being importable when the class is defined.
        """
        return {
            'DomfComponent': DomfComponent,
            'FftComponent': FftComponent,
            'StdComponent': StdComponent,
            'PsdComponent': PsdComponent,
        }

    def _split_trials(self, n_trials):
        """Split ``n_trials`` across workers without losing the remainder.

        Plain ``n_trials // n_jobs`` silently dropped ``n_trials % n_jobs``
        trials and produced zero trials per worker when ``n_trials < n_jobs``.
        Workers that would receive no trials are omitted from the result.
        """
        workers = max(1, self.n_jobs)
        base, extra = divmod(n_trials, workers)
        shares = [base + (1 if k < extra else 0) for k in range(workers)]
        return [share for share in shares if share > 0]

    def objective(self, trial: "optuna.trial.Trial") -> float:
        """Build, train and score one composite; return its test accuracy.

        NOTE(review): the categorical choices for ``component_type_{i}`` used to
        contain the fused name 'DomfComponentFftComponent' because of a missing
        comma. A study already stored with that old choice set may be rejected
        by Optuna when resumed with the corrected one -- confirm before reusing
        an existing journal.
        """
        component_classes = self._component_classes()
        components_list = []

        for i in range(self.n_components):
            component_type = trial.suggest_categorical(f'component_type_{i}',
                                                       list(self.COMPONENT_TYPES))

            num_clauses = trial.suggest_int(f'num_clauses_{i}', 1000, 3000)
            T = trial.suggest_int(f'T_{i}', 100, 1500)
            s = trial.suggest_float(f's_{i}', 2.0, 15.0)
            max_included_literals = trial.suggest_int(f'max_literals_{i}', 16, 64)
            weighted_clauses = trial.suggest_categorical(f'weighted_clauses_{i}', [True, False])
            epochs = trial.suggest_int(f'epochs_{i}', 1, self.max_epochs)

            config = TMClassifierConfig(
                number_of_clauses=num_clauses,
                T=T,
                s=s,
                max_included_literals=max_included_literals,
                platform=self.platform,
                patch_dim=(10, 10),
                weighted_clauses=weighted_clauses
            )
            # The placeholder patch_dim above is always replaced by the sampled
            # one; kept as construct-then-assign to mirror the original flow.
            config.patch_dim = (
                trial.suggest_int(f'patch_dim_1_{i}', 1, 10),
                trial.suggest_int(f'patch_dim_2_{i}', 1, 10),
            )

            # All component types share one constructor signature, so a single
            # lookup replaces the four identical if/elif branches.
            component_cls = component_classes[component_type]
            components_list.append(
                component_cls(TMClassifier, config, epochs=epochs, target_size=self.target_size)
            )

        composite_model = TMComposite(components=components_list, use_multiprocessing=self.use_multiprocessing)

        # Training and evaluation
        composite_model.fit(
            data=self.data_train,
            callbacks=self.callbacks
        )

        preds = composite_model.predict(data=self.data_test)
        # Cast to a builtin float so the value matches the declared return type.
        accuracy = float((preds['composite'] == self.data_test['Y'].flatten()).mean())

        # Adjust number of components for next trial
        if accuracy > self.last_accuracy:
            self.n_components += 1
        else:
            self.n_components = max(1, self.n_components - 1)

        self.last_accuracy = accuracy
        return accuracy

    def save_best_params(self, study, trial, filename="best_params.json"):
        """Write ``study.best_params`` and ``trial.value`` to ``filename`` as JSON."""
        best_data = {
            'params': study.best_params,
            'value': trial.value
        }
        with open(filename, "w") as f:
            json.dump(best_data, f)

    def gradual_saving_callback(self, study, trial):
        """Optuna callback: dump the best params whenever ``trial`` is the current best.

        Trials without a value are ignored, as is a study that has no best
        value yet; both cases previously raised inside the callback.
        """
        if trial.value is None:
            return
        try:
            best_value = study.best_value
        except ValueError:
            # Raised by Optuna while no trial has completed yet.
            return
        # Use np.isclose to handle potential floating-point precision issues
        if np.isclose(trial.value, best_value, atol=1e-10):
            self.save_best_params(study, trial, filename=f"best_params_trial_{trial.number}.json")

    def retry_optimize(self, study, objective, n_trials, callbacks, max_retries=5, wait_time=2.0):
        """Run ``study.optimize``, retrying when the storage reports a lock.

        Only exceptions whose message contains "database is locked"
        (case-insensitive) are retried, after sleeping ``wait_time`` seconds;
        anything else propagates unchanged.

        NOTE(review): each retry requests the full ``n_trials`` again, so trials
        finished before the failure are not subtracted.

        Raises:
            RuntimeError: if every one of ``max_retries`` attempts hit the lock.
        """
        for _ in range(max_retries):
            try:
                study.optimize(objective, n_trials=n_trials, callbacks=callbacks)
                return
            except Exception as e:
                if "database is locked" in str(e).lower():
                    time.sleep(wait_time)
                else:
                    # Bare raise keeps the original traceback intact.
                    raise
        raise RuntimeError("Max retries reached for database access")

    def tune(self, n_trials: int = 100):
        """Run the search and return ``(best_params, best_value)``.

        With ``n_jobs == 1`` the trials run in-process. Otherwise the trial
        budget is split across joblib workers that share the journal-file
        storage ``optuna-journal.log``.
        """
        storage = JournalStorage(JournalFileStorage("optuna-journal.log"))
        saving_callbacks = [self.gradual_saving_callback]

        if self.n_jobs == 1:
            # NOTE(review): no study_name is passed here, so self.study_name is
            # unused on this path and load_if_exists has nothing to match --
            # confirm whether single-job runs are meant to resume a study.
            study = optuna.create_study(direction='maximize', pruner=optuna.pruners.MedianPruner(), storage=storage,
                                        load_if_exists=True)
            self.retry_optimize(study, self.objective, n_trials, saving_callbacks)
        else:
            study = optuna.create_study(study_name=self.study_name, direction='maximize', storage=storage,
                                        load_if_exists=True, pruner=optuna.pruners.MedianPruner())
            with Parallel(n_jobs=self.n_jobs) as parallel:
                parallel(
                    delayed(self.retry_optimize)(study, self.objective, share, saving_callbacks)
                    for share in self._split_trials(n_trials)
                )

        return study.best_params, study.best_value

In [None]:
if __name__ == "__main__":

    # ---- Run settings -----------------------------------------------------
    epochs = 2               # passed to the tuner as max_epochs
    target_size = (31, 31)   # passed to the tuner unchanged
    n_trials = 3             # number of tuning iterations requested from tune()

    # ---- Data -------------------------------------------------------------
    # Load both splits once, then repackage them under the "X"/"Y" keys that
    # the tuner receives.
    data = DataProcessor("data/", use_cache=False).get()
    X_train_org, Y_train = data["x_train"], data["y_train"]
    X_test_org, Y_test = data["x_test"], data["y_test"]

    data_train = {"X": X_train_org, "Y": Y_train}
    data_test = {"X": X_test_org, "Y": Y_test}

    # ---- Tuning -----------------------------------------------------------
    # No fit callbacks are supplied; trials run in a single job.
    tuner = TMCompositeTuner(
        data_train=data_train,
        data_test=data_test,
        max_epochs=epochs,
        target_size=target_size,
        n_jobs=1,
    )
    best_params, best_value = tuner.tune(n_trials=n_trials)

    # ---- Report -----------------------------------------------------------
    print("Best Parameters:", best_params)
    print("Best Value:", best_value)