In [None]:
import warnings
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

# modelling
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.exceptions import UndefinedMetricWarning

# generators
from imblearn.over_sampling import SMOTE

# custom
from datasets import IMB_DATASETS, load_data, prepare_data


In [2]:
# Session-wide warning filters to keep the experiment log readable.
# UndefinedMetricWarning comes from sklearn's metric functions (e.g. precision
# when nothing is predicted positive); FutureWarning / UserWarning cover general
# library deprecation and usage notices.
for _category in (FutureWarning, UndefinedMetricWarning, UserWarning):
    warnings.simplefilter(action="ignore", category=_category)

In [3]:
# SMOTE, train_test_split and RandomForestClassifier are all created below without
# a random_state, so they fall back to numpy's global RandomState; seeding it here
# makes the run reproducible as long as the cells execute in order.
np.random.seed(seed=42)

In [4]:
# Repetitions: each dataset gets N_GEN_FITS independent SMOTE fits, and each of
# those is scored with N_CLF_FITS fresh discriminator train/test splits.
N_GEN_FITS, N_CLF_FITS = 5, 5

# SMOTE configuration.
K = 5  # number of nearest neighbours used by SMOTE (k_neighbors)
SAMPLING_STRATEGY = 1  # forwarded unchanged to SMOTE(sampling_strategy=...)

In [None]:
# Real-vs-synthetic discriminator experiment.
# For every dataset, SMOTE is fitted N_GEN_FITS times. For each fit, a random forest
# is trained N_CLF_FITS times to separate real minority rows (label 1) from
# SMOTE-generated rows (label 0). Precision/recall/F1 on the "real" class are
# reported as mean +- std over all N_GEN_FITS * N_CLF_FITS runs.
columns = ["dataset", "precision", "recall", "f1"]


def format_mean_std(values):
    """Summarise repeated scores as ``"<mean> +- <std>"`` with 4 decimals.

    ``std`` is numpy's default population standard deviation (ddof=0).
    """
    scores = np.asarray(values, dtype=float)
    return f"{scores.mean():.4f} +- {scores.std():.4f}"


# One dict per dataset. The DataFrame is built once after the loop instead of
# being grown with pd.concat on every iteration (quadratic, and concatenating onto
# an empty frame triggers a pandas FutureWarning).
score_rows = []

for data_name in tqdm(IMB_DATASETS, leave=False):
    X, y = load_data(data_name)
    X, y = prepare_data(X, y)

    # Real minority-class rows (label 1). They do not depend on the generator,
    # so select them once per dataset rather than once per SMOTE fit.
    X_real = X[y == 1]

    precs, recs, f1s = [], [], []
    for _ in range(N_GEN_FITS):
        generator = SMOTE(k_neighbors=K, sampling_strategy=SAMPLING_STRATEGY)
        X_augmented, y_augmented = generator.fit_resample(X, y)

        # Synthetic rows are taken to be everything after the first len(y) rows of
        # the resampled output; check that this prefix really is the original input.
        assert np.array_equal(np.asarray(X_augmented[:len(y)]), np.asarray(X))
        X_synthetic = X_augmented[len(y):]
        y_synthetic = y_augmented[len(y):]
        assert np.all(y_synthetic == 1)

        for _ in range(N_CLF_FITS):
            # Discriminator task: real minority rows (label 1) vs SMOTE rows (label 0),
            # each split 50/50 into train and test.
            # NOTE(review): synthetic rows are not subsampled. When they far outnumber
            # the real minority rows, the "real" class is a small minority of both splits,
            # and near-zero precision/recall/F1 then partly reflects that class imbalance.
            X_real_train, X_real_test = train_test_split(X_real, test_size=0.5)
            X_synthetic_train, X_synthetic_test = train_test_split(X_synthetic, test_size=0.5)

            X_train = np.concatenate((X_real_train, X_synthetic_train))
            y_train = np.concatenate((np.ones(len(X_real_train)), np.zeros(len(X_synthetic_train)))).astype(int)
            X_test = np.concatenate((X_real_test, X_synthetic_test))
            y_test = np.concatenate((np.ones(len(X_real_test)), np.zeros(len(X_synthetic_test)))).astype(int)

            clf = RandomForestClassifier()
            clf.fit(X_train, y_train)
            y_pred = clf.predict(X_test)

            # zero_division=0 gives the same 0.0 sklearn returns by default for an
            # ill-defined metric (e.g. no rows predicted "real"), without depending
            # on the global UndefinedMetricWarning filter.
            precs.append(precision_score(y_test, y_pred, zero_division=0))
            recs.append(recall_score(y_test, y_pred, zero_division=0))
            f1s.append(f1_score(y_test, y_pred, zero_division=0))

    data_scores = {
        "dataset": f"{data_name}",
        "precision": format_mean_std(precs),
        "recall": format_mean_std(recs),
        "f1": format_mean_std(f1s),
    }
    score_rows.append(data_scores)
    print(f"{data_name}: f1 -- {data_scores['f1']}")

scores_df = pd.DataFrame(score_rows, columns=columns)
# scores_df.to_csv("results/augment/naive.csv", index=False)
scores_df
    

  0%|          | 0/8 [00:00<?, ?it/s]

ecoli: f1 -- 0.0000 +- 0.0000
yeast_me2: f1 -- 0.0000 +- 0.0000
solar_flare_m0: f1 -- 0.0034 +- 0.0117
abalone: f1 -- 0.0073 +- 0.0088
car_eval_34: f1 -- 0.0081 +- 0.0139
car_eval_4: f1 -- 0.0000 +- 0.0000
mammography: f1 -- 0.0011 +- 0.0037
abalone_19: f1 -- 0.0000 +- 0.0000
          dataset         precision            recall                f1
0           ecoli  0.0000 +- 0.0000  0.0000 +- 0.0000  0.0000 +- 0.0000
1       yeast_me2  0.0000 +- 0.0000  0.0000 +- 0.0000  0.0000 +- 0.0000
2  solar_flare_m0  0.0071 +- 0.0262  0.0024 +- 0.0080  0.0034 +- 0.0117
3         abalone  0.0253 +- 0.0306  0.0043 +- 0.0051  0.0073 +- 0.0088
4     car_eval_34  0.0128 +- 0.0215  0.0060 +- 0.0103  0.0081 +- 0.0139
5      car_eval_4  0.0000 +- 0.0000  0.0000 +- 0.0000  0.0000 +- 0.0000
6     mammography  0.0055 +- 0.0196  0.0006 +- 0.0021  0.0011 +- 0.0037
7      abalone_19  0.0000 +- 0.0000  0.0000 +- 0.0000  0.0000 +- 0.0000
