In [17]:
import pandas as pd
from nebula.data.yg_ar.setup_data_image_hard import read_data
from nebula.common import to_scale_one, write_pickle, read_pickle
from sklearn import svm
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import GradientBoostingClassifier
import os
import os.path as osp
import numpy as np

In [18]:
def create_label_map(labels):
    """Assign a dense integer id to every distinct label.

    The distinct values found in ``labels`` are sorted in ascending order and
    numbered from 0 in that order.

    Args:
        labels: Iterable of hashable, mutually orderable label values.

    Returns:
        dict mapping each distinct label to its integer id.
    """
    ordered_labels = sorted(set(labels))
    return {name: index for index, name in enumerate(ordered_labels)}

In [19]:
# Location of the pickled dataset. The object returned by read_pickle is
# unpacked into four names below (presumably the full frame followed by its
# train / test / validation splits — confirm against the code that wrote it).
# NOTE(review): this is an absolute, machine-specific path; the notebook will
# not run elsewhere without editing it.
df_path = "C:/Users/aphri/Documents/t0002/pycharm/data/yg_ar/bg_medium_ds_point2_scale.pkl"
df, train_df, test_df, valid_df = read_pickle(df_path)

In [20]:
# Build the label -> integer id maps from the full frame `df` (not from a
# single split), one map per label column.
label_map_a = create_label_map(df["label_a"])
label_map_at = create_label_map(df["label_at"])

In [21]:
# Training inputs: the "image" column of the train split as a Python list.
train_x = train_df["image"].to_list()

In [22]:
# Training targets: each label column encoded to integer ids through the
# corresponding label map.
train_y_a = train_df["label_a"].map(label_map_a).to_list()
train_y_at = train_df["label_at"].map(label_map_at).to_list()

In [23]:
# Test inputs: the "image" column of the test split as a Python list.
test_x = test_df["image"].to_list()

In [24]:
# Test targets, encoded with the same label maps used for the training targets.
test_y_a = test_df["label_a"].map(label_map_a).to_list()
test_y_at = test_df["label_at"].map(label_map_at).to_list()

In [26]:
def train_svm(data_x, data_y, max_iter=50):
    """Fit a support-vector classifier on the given samples.

    Args:
        data_x: Training samples, handed to ``SVC.fit`` unchanged.
        data_y: Training targets, one per sample.
        max_iter: Hard limit on solver iterations passed to ``svm.SVC``.
            Defaults to 50, the value this notebook has always used, so
            existing callers are unaffected; raise it (or pass -1 for no
            limit) if the solver stops before converging.

    Returns:
        The fitted ``svm.SVC`` instance.
    """
    clf = svm.SVC(max_iter=max_iter)
    clf.fit(data_x, data_y)
    return clf


def train_logistic(data_x, data_y, max_iter=100):
    """Fit a logistic-regression classifier on the given samples.

    Args:
        data_x: Training samples, handed to ``LogisticRegression.fit`` unchanged.
        data_y: Training targets, one per sample.
        max_iter: Maximum number of solver iterations. The default of 100
            matches scikit-learn's own default, so existing callers behave
            exactly as before. The runs recorded in this notebook stopped at
            the iteration limit, so a larger value (or scaled inputs) may be
            needed for the solver to converge.

    Returns:
        The fitted ``LogisticRegression`` instance.
    """
    clf = LogisticRegression(random_state=0, max_iter=max_iter)
    clf.fit(data_x, data_y)
    return clf


def train_gbt(data_x, data_y):
    """Fit a gradient-boosted tree classifier on the given samples.

    Args:
        data_x: Training samples, handed to ``fit`` unchanged.
        data_y: Training targets, one per sample.

    Returns:
        The fitted ``GradientBoostingClassifier`` instance.
    """
    # Fixed hyper-parameters for every run; random_state pins the seed and
    # verbose=1 makes the estimator print its per-iteration progress table.
    settings = {
        "n_estimators": 100,
        "learning_rate": 0.1,
        "max_depth": 3,
        "random_state": 0,
        "verbose": 1,
        "n_iter_no_change": 2,
    }
    # fit() returns the estimator itself, so the fitted model is returned directly.
    return GradientBoostingClassifier(**settings).fit(data_x, data_y)


def evaluate(model, test_x, test_y):
    """Predict on ``test_x`` and score the predictions against ``test_y``.

    Args:
        model: Any object exposing ``predict(test_x)``.
        test_x: Inputs handed to ``model.predict`` unchanged.
        test_y: Expected labels, one per prediction.

    Returns:
        Tuple ``(res, accuracy)`` where ``res`` is the raw output of
        ``model.predict`` and ``accuracy`` is the fraction of predictions
        equal to the corresponding label.

    Raises:
        ValueError: If predictions and labels do not have the same shape.
    """
    res = model.predict(test_x)

    # Compare as arrays so the check is elementwise even when the model (or
    # the caller) supplies plain Python lists; list == list would collapse to
    # a single bool with no .sum().
    predicted = np.asarray(res)
    expected = np.asarray(test_y)
    if predicted.shape != expected.shape:
        raise ValueError(
            f"predictions have shape {predicted.shape} but labels have shape {expected.shape}"
        )

    correct = predicted == expected
    accuracy = correct.sum() / len(predicted)
    return res, accuracy


def create_dirs_to_file(path):
    """Make sure the directory that will contain ``path`` exists.

    Args:
        path: Path of a file about to be written. Only its parent directory
            is created; the file itself is not touched.
    """
    # dirname understands the platform's separators, unlike splitting on "/".
    parent = osp.dirname(path)
    # A bare filename has no parent component; os.makedirs("") would raise.
    if parent:
        # exist_ok avoids failing if the directory appears between check and create.
        os.makedirs(parent, exist_ok=True)


def load_or_train(train_x, train_y, test_x, test_y, train_func, label_map, path):
    """Return the cached run stored at ``path``, or train, evaluate and cache one.

    If ``path`` already exists, whatever ``read_pickle(path)`` yields is
    returned as-is and nothing is trained; the other arguments are ignored in
    that case. Otherwise the model is trained with ``train_func``, scored on
    the test data, the results are tabulated, and the whole bundle is written
    to ``path`` before being returned.

    Args:
        train_x: Training inputs passed to ``train_func``.
        train_y: Training targets passed to ``train_func``.
        test_x: Inputs used for evaluation.
        test_y: Expected labels used for evaluation.
        train_func: Callable ``(train_x, train_y) -> fitted model``.
        label_map: Mapping from label name to integer id.
        path: Cache file location.

    Returns:
        Tuple ``(trained_model, predictions, accuracy, df, df_incorrect,
        df_correct, label_map)`` on a fresh run.
    """
    # Guard clause: a cache hit short-circuits all training work.
    if osp.exists(path):
        return read_pickle(path)

    create_dirs_to_file(path)

    model = train_func(train_x, train_y)
    predicted, score = evaluate(model, test_x, test_y)
    results, results_wrong, results_right = format_results(predicted, test_y, label_map)

    # Build the bundle once so the cached object and the returned one agree.
    bundle = (model, predicted, score, results, results_wrong, results_right, label_map)
    write_pickle(path, bundle)
    return bundle


def format_results(predictions, labels, label_map):
    """Tabulate predictions against labels and split them by correctness.

    Args:
        predictions: Predicted integer ids, one per sample.
        labels: Expected integer ids, aligned with ``predictions``.
        label_map: Mapping from label name to integer id; it is inverted to
            attach readable names to both columns.

    Returns:
        Tuple ``(df, df_incorrect, df_correct)``: the full table with columns
        ``prediction``, ``label``, ``check``, ``prediction_name`` and
        ``label_name``, followed by the rows where ``check`` is False and the
        rows where it is True.
    """
    id_to_name = dict((idx, name) for name, idx in label_map.items())

    results = pd.DataFrame({"prediction": predictions, "label": labels})
    results["check"] = results["prediction"] == results["label"]

    # Attach a readable name next to each integer id column.
    for column in ("prediction", "label"):
        results[f"{column}_name"] = results[column].map(id_to_name)

    is_right = results["check"]
    wrong_rows = results[~is_right]
    right_rows = results[is_right]
    return results, wrong_rows, right_rows

In [27]:
# SVM on the `label_a` targets: reuse the run cached at save_path if it
# exists, otherwise train, evaluate and cache it.
save_path = "C:/Users/aphri/Documents/t0002/pycharm/data/yg_ar/classic_models_medium_bg/svm_a.pkl"
(
    trained_svm_a,
    predictions_svm_a,
    accuracy_svm_a,
    df_svm_a,
    df_incorrect_svm_a,
    df_correct_svm_a,
    label_map_svm_a,
) = load_or_train(
    train_x=train_x,
    train_y=train_y_a,
    test_x=test_x,
    test_y=test_y_a,
    train_func=train_svm,
    label_map=label_map_a,
    path=save_path,
)
print(accuracy_svm_a)
print(df_incorrect_svm_a.head())



0.638425925925926
    prediction  label  check    prediction_name   label_name
2            3      9  False  lord_of_the_dance  warrior_III
11           1      9  False              chair  warrior_III
24           1      9  False              chair  warrior_III
34           1      9  False              chair  warrior_III
35           1      9  False              chair  warrior_III


In [28]:
# SVM on the `label_at` targets: reuse the run cached at save_path if it
# exists, otherwise train, evaluate and cache it.
save_path = "C:/Users/aphri/Documents/t0002/pycharm/data/yg_ar/classic_models_medium_bg/svm_at.pkl"
(
    trained_svm_at,
    predictions_svm_at,
    accuracy_svm_at,
    df_svm_at,
    df_incorrect_svm_at,
    df_correct_svm_at,
    label_map_svm_at,
) = load_or_train(
    train_x=train_x,
    train_y=train_y_at,
    test_x=test_x,
    test_y=test_y_at,
    train_func=train_svm,
    label_map=label_map_at,
    path=save_path,
)
print(accuracy_svm_at)
print(df_incorrect_svm_at.head())



0.35555555555555557
    prediction  label  check prediction_name     label_name
0           34     32  False   warrior_III_3  warrior_III_1
7           34     32  False   warrior_III_3  warrior_III_1
11          35     32  False   warrior_III_4  warrior_III_1
12          34     32  False   warrior_III_3  warrior_III_1
13          34     32  False   warrior_III_3  warrior_III_1


In [29]:
# Logistic regression on the `label_a` targets: reuse the run cached at
# save_path if it exists, otherwise train, evaluate and cache it.
save_path = "C:/Users/aphri/Documents/t0002/pycharm/data/yg_ar/classic_models_medium_bg/logistic_a.pkl"
(
    trained_logistic_a,
    predictions_logistic_a,
    accuracy_logistic_a,
    df_logistic_a,
    df_incorrect_logistic_a,
    df_correct_logistic_a,
    label_map_logistic_a,
) = load_or_train(
    train_x=train_x,
    train_y=train_y_a,
    test_x=test_x,
    test_y=test_y_a,
    train_func=train_logistic,
    label_map=label_map_a,
    path=save_path,
)
print(accuracy_logistic_a)
print(df_incorrect_logistic_a.head())

0.5958333333333333
    prediction  label  check    prediction_name   label_name
0            8      9  False         warrior_II  warrior_III
1            6      9  False           triangle  warrior_III
3            3      9  False  lord_of_the_dance  warrior_III
5            8      9  False         warrior_II  warrior_III
10           6      9  False           triangle  warrior_III


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


In [30]:
# Logistic regression on the `label_at` targets: reuse the run cached at
# save_path if it exists, otherwise train, evaluate and cache it.
save_path = "C:/Users/aphri/Documents/t0002/pycharm/data/yg_ar/classic_models_medium_bg/logistic_at.pkl"
(
    trained_logistic_at,
    predictions_logistic_at,
    accuracy_logistic_at,
    df_logistic_at,
    df_incorrect_logistic_at,
    df_correct_logistic_at,
    label_map_logistic_at,
) = load_or_train(
    train_x=train_x,
    train_y=train_y_at,
    test_x=test_x,
    test_y=test_y_at,
    train_func=train_logistic,
    label_map=label_map_at,
    path=save_path,
)
print(accuracy_logistic_at)
print(df_incorrect_logistic_at.head())

0.23194444444444445
   prediction  label  check prediction_name     label_name
0           3     32  False         camel_4  warrior_III_1
1          24     32  False      triangle_1  warrior_III_1
2          33     32  False   warrior_III_2  warrior_III_1
3          34     32  False   warrior_III_3  warrior_III_1
5          37     32  False    warrior_II_2  warrior_III_1


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


In [31]:
# Gradient-boosted trees on the `label_a` targets: reuse the run cached at
# save_path if it exists, otherwise train, evaluate and cache it.
save_path = "C:/Users/aphri/Documents/t0002/pycharm/data/yg_ar/classic_models_medium_bg/gbt01_a.pkl"
(
    trained_gbt_a,
    predictions_gbt_a,
    accuracy_gbt_a,
    df_gbt_a,
    df_incorrect_gbt_a,
    df_correct_gbt_a,
    label_map_gbt_a,
) = load_or_train(
    train_x=train_x,
    train_y=train_y_a,
    test_x=test_x,
    test_y=test_y_a,
    train_func=train_gbt,
    label_map=label_map_a,
    path=save_path,
)
print(accuracy_gbt_a)
print(df_incorrect_gbt_a.head())

      Iter       Train Loss   Remaining Time 
         1           2.1981           20.62m
         2           2.1159           20.41m
         3           2.0409           20.27m
         4           1.9773           19.97m
         5           1.9204           19.75m
         6           1.8693           19.45m
         7           1.8185           19.21m
         8           1.7746           18.98m
         9           1.7327           18.74m
        10           1.6921           18.53m
        20           1.3944           16.35m
        30           1.1954           14.27m
        40           1.0483           12.29m
        50           0.9332           10.33m
        60           0.8410            8.32m
        70           0.7623            6.33m
        80           0.6948            4.22m
        90           0.6394            2.12m
       100           0.5881            0.00s
0.7805555555555556
    prediction  label  check    prediction_name   label_name
2            3     

In [32]:
# Gradient-boosted trees on the `label_at` targets: reuse the run cached at
# save_path if it exists, otherwise train, evaluate and cache it.
save_path = "C:/Users/aphri/Documents/t0002/pycharm/data/yg_ar/classic_models_medium_bg/gbt01_at.pkl"
(
    trained_gbt_at,
    predictions_gbt_at,
    accuracy_gbt_at,
    df_gbt_at,
    df_incorrect_gbt_at,
    df_correct_gbt_at,
    label_map_gbt_at,
) = load_or_train(
    train_x=train_x,
    train_y=train_y_at,
    test_x=test_x,
    test_y=test_y_at,
    train_func=train_gbt,
    label_map=label_map_at,
    path=save_path,
)
print(accuracy_gbt_at)
print(df_incorrect_gbt_at.head())

      Iter       Train Loss   Remaining Time 
         1           3.4795           80.16m
         2           3.3319           78.81m
         3           3.2064           77.83m
         4           3.1075           76.90m
         5           3.0216           76.05m
         6           2.9358           75.25m
         7           2.8590           74.43m
         8           2.7875           73.82m
         9           2.7192           72.83m
        10           2.6562           71.93m
        20           2.1771           63.46m
        30           1.8566           55.11m
        40           1.6202           47.65m
        50           1.4306           39.94m
        60           1.2735           32.61m
        70           1.1447           24.50m
        80           1.0303           16.25m
        90           0.9321            8.08m
       100           0.8484            0.00s
0.31666666666666665
   prediction  label  check      prediction_name     label_name
0          34  