In [1]:
# %pip uninstall -y coniferest
# %pip install 'git+https://github.com/snad-space/coniferest@fix-devent-celeba'

In [2]:
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from coniferest.aadforest import AADForest
from coniferest.datasets import Dataset, DevNetDataset
from coniferest.isoforest import IsolationForest
from coniferest.label import Label
from coniferest.pineforest import PineForest
from coniferest.session.oracle import OracleSession, create_oracle_session

In [3]:
class Compare:
    """Benchmark several coniferest models in oracle (fully labelled) active-learning sessions.

    For each random seed, one oracle session per entry of ``models`` is created on the
    same dataset and run. The cumulative number of objects labelled ``Label.A`` after
    each iteration is stored per model and per seed. ``plot`` summarises these curves
    across seeds as a median line with a 10%-90% quantile band.
    """

    # Display name -> model class. Each class is instantiated with ``model_kwargs`` plus a seed.
    models = {
        'Isolation Forest': IsolationForest,
        'AAD': AADForest,
        'Pine Forest': PineForest,
    }

    def __init__(self, dataset: Dataset, *, iterations=100, n_jobs=-1):
        """Store model/session settings for ``dataset``.

        Parameters
        ----------
        dataset : Dataset
            Supplies ``data`` and ``labels`` to every oracle session.
        iterations : int
            ``max_iterations`` for every session; also the length of the plotted x axis.
        n_jobs : int
            Forwarded to every model constructor.
        """
        self.model_kwargs = {
            'n_trees': 128,
            'n_jobs': n_jobs,
        }
        self.session_kwargs = {
            'data': dataset.data,
            'labels': dataset.labels,
            'max_iterations': iterations,
        }
        # model name -> {random_seed -> cumulative anomaly counts per iteration}
        self.results = {}
        self.steps = np.arange(1, iterations + 1)
        # Fraction of objects labelled as anomalies; the slope of the random-selection baseline.
        self.total_anomaly_fraction = np.mean(dataset.labels == Label.A)

    def get_sessions(self, random_seed):
        """Return ``{model name: oracle session}`` with every model seeded by ``random_seed``."""
        model_kwargs = self.model_kwargs | {'random_seed': random_seed}

        return {
            name: create_oracle_session(model=model(**model_kwargs), **self.session_kwargs)
            for name, model in self.models.items()
        }

    def run(self, random_seeds):
        """Run every model's session once per seed and record cumulative anomaly counts.

        New results are merged into ``self.results`` per model and per seed. Repeated
        calls with different seeds therefore accumulate rather than discard earlier seeds.
        Results are only stored once all seeds have finished.

        Returns ``self`` so calls can be chained.
        """
        results = defaultdict(dict)

        for random_seed in tqdm(random_seeds):
            sessions = self.get_sessions(random_seed)
            for name, session in sessions.items():
                session.run()
                # NOTE(review): assumes known_labels iterates in labelling order -- TODO confirm in coniferest.
                labels = np.array(list(session.known_labels.values()))
                results[name][random_seed] = np.cumsum(labels == Label.A)

        # Merge per seed. A plain ``self.results |= results`` would replace each model's whole
        # seed dict and drop seeds recorded by earlier run() calls.
        for name, per_seed in results.items():
            self.results.setdefault(name, {}).update(per_seed)
        return self

    def plot(self, dataset_name: str, savefig=False):
        """Plot cumulative anomalies found versus iteration for every model.

        Each model is drawn as the median over seeds with a 10%-90% quantile band.
        A dashed line shows the expected count for random selection.

        Parameters
        ----------
        dataset_name : str
            Used in the title and, if ``savefig`` is true, as the file name ``<dataset_name>.pdf``.
        savefig : bool
            Whether to save the figure as a PDF in the working directory.

        Returns ``self`` so calls can be chained.
        """
        plt.figure(figsize=(8, 6))
        plt.title(f'Dataset: {dataset_name}')

        for name, anomalies_dict in self.results.items():
            # NOTE(review): np.stack and plotting against self.steps require every seed's curve
            # to have exactly `iterations` entries.
            anomalies = np.stack(list(anomalies_dict.values()))
            q10, median, q90 = np.quantile(anomalies, [0.1, 0.5, 0.9], axis=0)

            (median_line,) = plt.plot(self.steps, median, alpha=0.75, label=name)
            # Reuse the line's colour so each band is paired with its median line explicitly.
            plt.fill_between(self.steps, q10, q90, alpha=0.5, color=median_line.get_color())

        plt.plot(self.steps, self.steps * self.total_anomaly_fraction, ls='--', color='grey', label='Theoretical random')

        plt.xlabel('Iteration')
        plt.ylabel('Number of anomalies')
        plt.grid()
        plt.legend()
        if savefig:
            # Name the file after the argument, not a notebook-level variable, so it matches the title.
            plt.savefig(f'{dataset_name}.pdf')

        return self

In [None]:
# Benchmark every DevNet dataset: 20 random seeds, 100 oracle iterations per model.
# NOTE: `avialble_datasets` (sic) is the attribute's actual spelling; the printed output shows it resolves.
print(DevNetDataset.avialble_datasets)

seeds = range(20)

for dataset in DevNetDataset.avialble_datasets:
    print(dataset)
    # %time reports the wall time of building, running and plotting one comparison;
    # with savefig=True the figure is also written to `<dataset>.pdf`.
    %time compare = Compare(DevNetDataset(dataset), iterations=100, n_jobs=10).run(seeds).plot(dataset, savefig=True)
    plt.show()

['donors', 'census', 'fraud', 'celeba', 'backdoor', 'campaign', 'thyroid']
donors


 60%|██████████████████████████████████▏                      | 12/20 [1:56:30<1:25:02, 637.84s/it]

In [None]:
%time compare = Compare(DevNetDataset("thyroid"), iterations=7200, n_jobs=15).run([0]).plot(f'{dataset}_full', savefig=True)
plt.show()