In [1]:
import pandas as pd
from nebula.data.yg_ar.setup_data_image_hard import read_data
from nebula.common import to_scale_one, write_pickle, read_pickle
from sklearn import svm
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import GradientBoostingClassifier
import os
import os.path as osp
import numpy as np

  warn(f"Failed to load image Python extension: {e}")


In [2]:
def create_label_map(labels):
    """Assign a dense integer id to every distinct label.

    Labels are de-duplicated and sorted, then numbered 0, 1, 2, ... in that
    sorted order, so the same label set always yields the same encoding.

    Args:
        labels: Iterable of hashable, mutually comparable label values.

    Returns:
        dict mapping each distinct label to its integer id (insertion order is
        the sorted label order).
    """
    ordered_labels = sorted(set(labels))
    return {label: index for index, label in enumerate(ordered_labels)}

In [3]:
# Load the pre-split dataset: presumably the full frame plus train/test/validation
# splits (the file name suggests PCA-reduced image features — confirm).
# NOTE(review): absolute, machine-specific path; and read_pickle presumably unpickles,
# so only point it at trusted files.
df_path = "C:/Users/aphri/Documents/t0002/pycharm/data/yg_ar/image_easy2_ds_pca.pkl"
(df, train_df, test_df, valid_df) = read_pickle(df_path)

In [4]:
# Build both label encodings from the full frame so every split shares the same ids.
label_map_a, label_map_at = (create_label_map(df[column]) for column in ("label_a", "label_at"))

In [5]:
train_x = train_df["image"].tolist()  # feature vectors for the training split

In [6]:
# Integer-encode both training label columns with the shared maps
# (Series.map yields NaN for any label missing from a map).
train_y_a, train_y_at = (
    train_df[column].map(mapping).to_list()
    for column, mapping in (("label_a", label_map_a), ("label_at", label_map_at))
)

In [7]:
test_x = test_df["image"].tolist()  # feature vectors for the test split

In [8]:
# Integer-encode both test label columns with the same maps used for training.
test_y_a, test_y_at = (
    test_df[column].map(mapping).to_list()
    for column, mapping in (("label_a", label_map_a), ("label_at", label_map_at))
)

In [9]:
def train_svm(data_x, data_y):
    """Fit an ``svm.SVC`` classifier and return it.

    All SVC hyperparameters are left at their defaults except ``max_iter``,
    which is capped at 50 solver iterations; the fit may therefore stop before
    convergence.

    Args:
        data_x: Training feature vectors.
        data_y: Integer-encoded training labels.

    Returns:
        The fitted SVC instance.
    """
    # scikit-learn estimators return themselves from fit(), so the call chains.
    return svm.SVC(max_iter=50).fit(data_x, data_y)


def train_logistic(data_x, data_y, max_iter=100):
    """Fit a ``LogisticRegression`` classifier and return it.

    Args:
        data_x: Training feature vectors.
        data_y: Integer-encoded training labels.
        max_iter: Maximum solver iterations, forwarded to ``LogisticRegression``.
            The default of 100 equals scikit-learn's own default, so existing
            callers get exactly the previous model. Runs in this notebook emit
            "TOTAL NO. of ITERATIONS REACHED LIMIT" warnings; raising this value
            (or scaling the features) is how to address them.

    Returns:
        The fitted LogisticRegression instance.
    """
    # random_state pins any randomness in the solver so runs are reproducible.
    clf = LogisticRegression(random_state=0, max_iter=max_iter)
    clf.fit(data_x, data_y)
    return clf


def train_gbt(data_x, data_y):
    """Fit a ``GradientBoostingClassifier`` and return it.

    Uses up to 100 depth-3 trees with a small learning rate (0.01) and a fixed
    random_state. ``verbose=1`` prints the per-iteration training-loss log.

    Args:
        data_x: Training feature vectors.
        data_y: Integer-encoded training labels.

    Returns:
        The fitted GradientBoostingClassifier instance.
    """
    booster = GradientBoostingClassifier(
        random_state=0,
        verbose=1,
        learning_rate=0.01,
        max_depth=3,
        n_estimators=100,
        # Setting n_iter_no_change turns on scikit-learn's early stopping, which
        # holds out part of the training data for validation (validation_fraction).
        n_iter_no_change=2,
    )
    return booster.fit(data_x, data_y)


def evaluate(model, test_x, test_y):
    """Predict on a held-out set and compute plain accuracy.

    Args:
        model: Any fitted object exposing ``predict(test_x)``.
        test_x: Test feature vectors.
        test_y: True labels, same length as the predictions.

    Returns:
        (predictions, accuracy): the predictions as a NumPy array, and the
        fraction of predictions equal to ``test_y``. Accuracy is ``nan`` for an
        empty test set.

    Raises:
        ValueError: if the model returns a different number of predictions than
            there are labels.
    """
    # Normalise both sides to arrays: comparing a list to a list would yield a
    # single bool instead of an element-wise mask.
    predictions = np.asarray(model.predict(test_x))
    expected = np.asarray(test_y)
    if len(predictions) != len(expected):
        raise ValueError(
            f"got {len(predictions)} predictions for {len(expected)} labels"
        )
    if len(predictions) == 0:
        # Avoid a 0/0 division warning; accuracy is undefined here.
        return predictions, float("nan")
    accuracy = float(np.mean(predictions == expected))
    return predictions, accuracy


def create_dirs_to_file(path):
    """Make sure the directory that will contain ``path`` exists.

    Args:
        path: Path of a file about to be written. Its parent directories are
            created if missing. A bare file name (no directory component) needs
            no directory, so nothing is created in that case.
    """
    # osp.dirname understands the platform's separators (including "\" on
    # Windows), unlike splitting on "/".
    parent = osp.dirname(path)
    if parent:
        # exist_ok avoids the check-then-create race and makes repeat calls safe;
        # os.makedirs("") would raise, hence the guard above.
        os.makedirs(parent, exist_ok=True)


def load_or_train(train_x, train_y, test_x, test_y, train_func, label_map, path):
    """Return cached training results from ``path``, or train, evaluate and cache them.

    NOTE(review): the cache is keyed only on ``path``. Changes to the data,
    ``train_func`` or ``label_map`` are not detected, so delete the file to force
    a retrain. Loading uses ``read_pickle``; only read trusted files.

    Args:
        train_x, train_y: Training features and encoded labels passed to ``train_func``.
        test_x, test_y: Evaluation features and encoded labels.
        train_func: Callable ``(x, y) -> fitted model``.
        label_map: Label-name -> id mapping used to decode predictions.
        path: Cache file location; its parent directories are created if needed.

    Returns:
        Tuple ``(model, predictions, accuracy, df, df_incorrect, df_correct,
        label_map)`` — either freshly computed or as stored in the cache.
    """
    if not osp.exists(path):
        create_dirs_to_file(path)

        fitted = train_func(train_x, train_y)
        preds, acc = evaluate(fitted, test_x, test_y)
        all_rows, wrong_rows, right_rows = format_results(preds, test_y, label_map)

        bundle = (fitted, preds, acc, all_rows, wrong_rows, right_rows, label_map)
        write_pickle(path, bundle)
        return bundle

    return read_pickle(path)


def format_results(predictions, labels, label_map):
    """Tabulate predictions against labels and split them into wrong/right rows.

    Args:
        predictions: Predicted integer ids.
        labels: True integer ids, aligned with ``predictions``.
        label_map: Mapping label-name -> integer id. It is inverted here to
            decode ids back to names.

    Returns:
        (df, df_incorrect, df_correct): ``df`` has columns ``prediction``,
        ``label``, ``check`` (prediction == label), ``prediction_name`` and
        ``label_name``. The other two are the rows where ``check`` is False and
        True respectively.
    """
    id_to_name = dict(zip(label_map.values(), label_map.keys()))

    results = pd.DataFrame({"prediction": predictions, "label": labels})
    results = results.assign(
        check=results["prediction"] == results["label"],
        prediction_name=results["prediction"].map(id_to_name),
        label_name=results["label"].map(id_to_name),
    )

    hit = results["check"]
    return results, results.loc[~hit], results.loc[hit]

In [11]:
# SVM on `label_a`. Results are cached at save_path, so re-running reloads them.
# NOTE(review): absolute, machine-specific path.
save_path = "C:/Users/aphri/Documents/t0002/pycharm/data/yg_ar/classic_models_easy2_pca/svm_a.pkl"
(
    trained_svm_a,
    predictions_svm_a,
    accuracy_svm_a,
    df_svm_a,
    df_incorrect_svm_a,
    df_correct_svm_a,
    label_map_svm_a,  # map stored with the cached results (may be stale vs label_map_a)
) = load_or_train(
    train_x=train_x,
    train_y=train_y_a,
    test_x=test_x,
    test_y=test_y_a,
    train_func=train_svm,
    label_map=label_map_a,
    path=save_path,
)
print(accuracy_svm_a)
print(df_incorrect_svm_a.head())



0.48333333333333334
   prediction  label  check    prediction_name  label_name
0           0      8  False              camel  warrior_II
1           1      8  False              chair  warrior_II
2           3      8  False  lord_of_the_dance  warrior_II
4           1      8  False              chair  warrior_II
6           3      8  False  lord_of_the_dance  warrior_II


In [12]:
# SVM on `label_at`. Results are cached at save_path, so re-running reloads them.
# NOTE(review): absolute, machine-specific path.
save_path = "C:/Users/aphri/Documents/t0002/pycharm/data/yg_ar/classic_models_easy2_pca/svm_at.pkl"
(
    trained_svm_at,
    predictions_svm_at,
    accuracy_svm_at,
    df_svm_at,
    df_incorrect_svm_at,
    df_correct_svm_at,
    label_map_svm_at,  # map stored with the cached results (may be stale vs label_map_at)
) = load_or_train(
    train_x=train_x,
    train_y=train_y_at,
    test_x=test_x,
    test_y=test_y_at,
    train_func=train_svm,
    label_map=label_map_at,
    path=save_path,
)
print(accuracy_svm_at)
print(df_incorrect_svm_at.head())



0.4199074074074074
   prediction  label  check      prediction_name    label_name
2          13     37  False  lord_of_the_dance_2  warrior_II_2
3          36     37  False         warrior_II_1  warrior_II_2
4          36     37  False         warrior_II_1  warrior_II_2
5          39     37  False         warrior_II_4  warrior_II_2
6          36     37  False         warrior_II_1  warrior_II_2


In [13]:
# Logistic regression on `label_a`. Results are cached at save_path.
# NOTE(review): absolute, machine-specific path.
save_path = "C:/Users/aphri/Documents/t0002/pycharm/data/yg_ar/classic_models_easy2_pca/logistic_a.pkl"
(
    trained_logistic_a,
    predictions_logistic_a,
    accuracy_logistic_a,
    df_logistic_a,
    df_incorrect_logistic_a,
    df_correct_logistic_a,
    label_map_logistic_a,  # map stored with the cached results (may be stale vs label_map_a)
) = load_or_train(
    train_x=train_x,
    train_y=train_y_a,
    test_x=test_x,
    test_y=test_y_a,
    train_func=train_logistic,
    label_map=label_map_a,
    path=save_path,
)
print(accuracy_logistic_a)
print(df_incorrect_logistic_a.head())

0.44351851851851853
   prediction  label  check    prediction_name  label_name
2           5      8  False        thunderbolt  warrior_II
5           9      8  False        warrior_III  warrior_II
6           3      8  False  lord_of_the_dance  warrior_II
7           3      8  False  lord_of_the_dance  warrior_II
8           1      8  False              chair  warrior_II


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


In [14]:
# Logistic regression on `label_at`. Results are cached at save_path.
# NOTE(review): absolute, machine-specific path.
save_path = "C:/Users/aphri/Documents/t0002/pycharm/data/yg_ar/classic_models_easy2_pca/logistic_at.pkl"
(
    trained_logistic_at,
    predictions_logistic_at,
    accuracy_logistic_at,
    df_logistic_at,
    df_incorrect_logistic_at,
    df_correct_logistic_at,
    label_map_logistic_at,  # map stored with the cached results (may be stale vs label_map_at)
) = load_or_train(
    train_x=train_x,
    train_y=train_y_at,
    test_x=test_x,
    test_y=test_y_at,
    train_func=train_logistic,
    label_map=label_map_at,
    path=save_path,
)
print(accuracy_logistic_at)
print(df_incorrect_logistic_at.head())

0.16944444444444445
   prediction  label  check prediction_name    label_name
0          36     37  False    warrior_II_1  warrior_II_2
1          36     37  False    warrior_II_1  warrior_II_2
2          21     37  False   thunderbolt_2  warrior_II_2
3          38     37  False    warrior_II_3  warrior_II_2
4          39     37  False    warrior_II_4  warrior_II_2


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


In [16]:
# Gradient boosting on `label_a`. Training is slow (see the verbose log), so the
# cache at save_path matters for re-runs.
# NOTE(review): absolute, machine-specific path.
save_path = "C:/Users/aphri/Documents/t0002/pycharm/data/yg_ar/classic_models_easy2_pca/gbt_a.pkl"
(
    trained_gbt_a,
    predictions_gbt_a,
    accuracy_gbt_a,
    df_gbt_a,
    df_incorrect_gbt_a,
    df_correct_gbt_a,
    label_map_gbt_a,  # map stored with the cached results (may be stale vs label_map_a)
) = load_or_train(
    train_x=train_x,
    train_y=train_y_a,
    test_x=test_x,
    test_y=test_y_a,
    train_func=train_gbt,
    label_map=label_map_a,
    path=save_path,
)
print(accuracy_gbt_a)
print(df_incorrect_gbt_a.head())

      Iter       Train Loss   Remaining Time 
         1           2.2929           54.43m
         2           2.2836           59.06m
         3           2.2746           55.04m
         4           2.2660           53.68m
         5           2.2577           52.49m
         6           2.2495           55.07m
         7           2.2416           54.41m
         8           2.2338           53.26m
         9           2.2263           51.79m
        10           2.2189           50.60m
        20           2.1519           49.80m
        30           2.0938           46.60m
        40           2.0402           41.09m
        50           1.9916           34.64m
        60           1.9471           29.20m
        70           1.9057           22.21m
        80           1.8677           21.34m
        90           1.8334           23.38m
       100           1.8016            0.00s
0.462037037037037
   prediction  label  check    prediction_name  label_name
0           5      8  

In [15]:
# Gradient boosting on `label_at`. Training is slow (see the verbose log), so the
# cache at save_path matters for re-runs.
# NOTE(review): absolute, machine-specific path.
save_path = "C:/Users/aphri/Documents/t0002/pycharm/data/yg_ar/classic_models_easy2_pca/gbt_at.pkl"
(
    trained_gbt_at,
    predictions_gbt_at,
    accuracy_gbt_at,
    df_gbt_at,
    df_incorrect_gbt_at,
    df_correct_gbt_at,
    label_map_gbt_at,  # map stored with the cached results (may be stale vs label_map_at)
) = load_or_train(
    train_x=train_x,
    train_y=train_y_at,
    test_x=test_x,
    test_y=test_y_at,
    train_func=train_gbt,
    label_map=label_map_at,
    path=save_path,
)
print(accuracy_gbt_at)
print(df_incorrect_gbt_at.head())

      Iter       Train Loss   Remaining Time 
         1           3.6669          147.14m
         2           3.6492          146.56m
         3           3.6337          154.40m
         4           3.6194          155.92m
         5           3.6064          157.70m
         6           3.5945          158.72m
         7           3.5826          158.31m
         8           3.5719          158.01m
         9           3.5610          158.35m
        10           3.5505          157.51m
        20           3.4559          144.21m
        30           3.3749          133.33m
        40           3.3000          119.73m
        50           3.2296          102.98m
        60           3.1663           83.80m
        70           3.1052           64.11m
        80           3.0472           43.54m
        90           2.9936           21.79m
       100           2.9412            0.00s
0.18009259259259258
   prediction  label  check prediction_name    label_name
0          22     37 