In [2]:
import dataclasses
import numpy
import os
import pathlib

from preprocess import *
from samples import get_index, get_samples_and_labels
from experiments import run, run_cnn

In [3]:
@dataclasses.dataclass(init=True, repr=True, eq=True)
class Patch:
    """Description of how a feature set was split into patches.

    Attributes:
        type: Patch layout label, e.g. ``"horizontal"`` in the MobileNetV2 run.
        n_patch: Number of patches (3 in the MobileNetV2 run).
    """

    type: str  # layout label, e.g. "horizontal"
    n_patch: int  # how many patches the features were split into


@dataclasses.dataclass
class Dataset:
    """A feature dataset that is loaded from disk and passed through ``preprocess``.

    Loading happens eagerly in ``__post_init__``. Patch-based datasets are read
    from a directory tree of ``numpy.load``-able files, and plain datasets from a
    single text file. After loading, ``list_data`` holds whatever ``preprocess(x)``
    returned, and every item in it carries a ``y`` attribute with the label vector.

    Attributes:
        cfg: Experiment configuration; ``load_patch`` reads ``cfg["n_labels"]``.
        list_data: Result of ``preprocess``; filled by ``load`` (not an init argument).
        name: Dataset name (presumably used by the experiment runners for logging -- confirm).
        patch: Patch description. When falsy (e.g. ``None``), ``path`` is read as a text file instead.
        path: Directory of patch feature files, or a text file for plain datasets.
    """

    cfg: dict
    list_data: list = dataclasses.field(init=False)
    name: str
    patch: "Patch"
    path: str

    def __post_init__(self):
        # Load eagerly so a freshly constructed Dataset is immediately usable.
        self.load()

    def get_n_features(self):
        """Return the number of columns of the first array file found under ``path``.

        Raises:
            FileNotFoundError: if ``path`` contains no files.
        """
        for p in pathlib.Path(self.path).rglob("*"):
            if p.is_file():
                x = numpy.load(str(p.resolve()))
                return x.shape[1]
        raise FileNotFoundError(f"problems in {self.path}")

    def load(self):
        """Populate ``list_data`` with the loader that matches ``patch``."""
        if self.patch:
            self.load_patch()
        else:
            self.load_normal()

    def _data_files(self):
        # Return every regular file below ``path`` in sorted order. Directories are skipped
        # because ``rglob("*")`` yields them too and ``numpy.load`` cannot read them.
        return sorted(p for p in pathlib.Path(self.path).rglob("*") if p.is_file())

    def _set_data(self, x, y):
        # Apply the shared preprocessing and attach the same label vector to every result.
        self.list_data = preprocess(x)
        for d in self.list_data:
            d.y = y

    def load_patch(self):
        """Stack every patch feature file under ``path`` and build matching labels.

        Files are read in sorted path order and concatenated along axis 0. Labels
        ``1..cfg["n_labels"]`` are assigned in equal consecutive runs. This assumes
        the sorted file order groups samples by label -- TODO confirm against the data layout.

        Raises:
            FileNotFoundError: if ``path`` contains no files.
            ValueError: if the row count is not divisible by ``cfg["n_labels"]``.
        """
        files = self._data_files()
        if not files:
            raise FileNotFoundError(f"problems in {self.path}")
        arrays = [numpy.load(str(f.resolve())) for f in files]
        # Concatenate once rather than inside a loop, which would be quadratic.
        # The float64 empty seed keeps the previous dtype promotion, so the result is at least float64.
        x = numpy.concatenate([numpy.empty((0, arrays[0].shape[1]))] + arrays, axis=0)

        n_labels = self.cfg["n_labels"]
        n_per_label, remainder = divmod(x.shape[0], n_labels)
        if remainder:
            raise ValueError(
                f"{x.shape[0]} samples cannot be split evenly into {n_labels} labels"
            )
        # numpy.repeat needs an integer repeat count (true division would give a float).
        y = numpy.repeat(numpy.arange(1, n_labels + 1), n_per_label)
        self._set_data(x, y)

    def load_normal(self):
        """Load samples and labels from the text file at ``path`` via ``get_samples_and_labels``."""
        x, y = get_samples_and_labels(numpy.loadtxt(self.path))
        self._set_data(x, y)

In [4]:
# Experiment configuration shared by every dataset and runner below.
cfg = dict(
    fold=5,  # presumably the number of cross-validation folds (the output shows runs 0-4) -- confirm
    n_labels=5,  # number of classes; Dataset.load_patch assigns labels 1..n_labels
    path_base="dataset",  # root directory that the input paths are joined onto
    path_out="out",  # presumably where run/run_cnn write results -- confirm
    pca=True,  # presumably toggles PCA inside the experiment runners -- confirm
    test_size=0.2,
    train_size=0.8,
)

# Index derived from the surf64 text file; its structure is defined by samples.get_index.
list_index = get_index(cfg, os.path.join(cfg["path_base"], "surf64.txt"))

# Experiments

## MobileNetV2

In [5]:
# MobileNetV2 features split into 3 horizontal patches; the directory name mirrors the Patch settings.
mb = Dataset(
    cfg=cfg,
    name="mobilenet",
    patch=Patch(type="horizontal", n_patch=3),
    path=os.path.join(cfg["path_base"], "mobilenetv2/patch=3/horizontal"),
)
run_cnn(cfg, mb, list_index)

0 svm mobilenet (1125, 128) (1125,)
type: sum, accuracy (%): 80.0
+++++++++++++++++++++++++++++++++
1 svm mobilenet (1125, 128) (1125,)
type: sum, accuracy (%): 85.3333
+++++++++++++++++++++++++++++++++
2 svm mobilenet (1125, 128) (1125,)
type: sum, accuracy (%): 84.0
+++++++++++++++++++++++++++++++++
3 svm mobilenet (1125, 128) (1125,)
type: sum, accuracy (%): 77.3333
+++++++++++++++++++++++++++++++++
4 svm mobilenet (1125, 128) (1125,)
type: sum, accuracy (%): 76.0
+++++++++++++++++++++++++++++++++
mean accuracy (%): 80.5333, std deviation: 0.0364, rule: sum, mean elapsed time: 00:00:00 (0.4614260673522949)
best_accuracy: 80.5333 rule:sum

---------------------------------
