In [17]:
import pandas as pd
from nebula.data.yg_ar.setup_data_image_hard import read_data
from nebula.common import to_scale_one, write_pickle, read_pickle
from sklearn import svm
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import GradientBoostingClassifier
import os
import os.path as osp
import numpy as np

In [18]:
def create_label_map(labels):
    """Build a deterministic label -> integer id encoding.

    Distinct labels are sorted and numbered 0..k-1 in sorted order, so the
    resulting ids do not depend on the order in which labels appear.

    Args:
        labels: any iterable of hashable, mutually comparable labels
            (e.g. a pandas Series column).

    Returns:
        dict mapping each distinct label to its index in sorted order.
    """
    distinct_sorted = sorted(set(labels))
    return {label: idx for idx, label in enumerate(distinct_sorted)}

In [21]:
# Load the SIFT-feature DataFrame; the pickle unpacks into the full frame plus
# the train / test / valid splits used below.
# TODO(review): absolute local path — consider making the data root configurable.
df_path = (
    "C:/Users/aphri/Documents/t0002/pycharm/data/yg_ar/"
    "bg_medium_ds_sift.pkl"
)
df, train_df, test_df, valid_df = read_pickle(df_path)

In [22]:
# Show the first example row to see the available columns (image, label_a, label_at, file_name).
print(df.loc[0, :])

image        [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, ...
label_a                                                  camel
label_at                                               camel_1
file_name    camel_1_hair_0_cloth_0_pants_2_Z1005_XOP5_YOP1...
Name: 0, dtype: object


In [23]:
# Integer encodings for both target columns, built from the full frame so that
# the train and test splits share one consistent encoding.
label_map_a, label_map_at = (
    create_label_map(df[column]) for column in ("label_a", "label_at")
)

In [24]:
# Training features: the "image" column as a plain list for the sklearn estimators.
train_x = train_df["image"].tolist()

In [25]:
# Encode both training target columns with their label maps.
train_y_a, train_y_at = (
    train_df[column].map(mapping).tolist()
    for column, mapping in (("label_a", label_map_a), ("label_at", label_map_at))
)

In [26]:
# Test features, in the same list form as train_x.
test_x = test_df["image"].tolist()

In [27]:
# Encode both test target columns with the same label maps used for training.
test_y_a, test_y_at = (
    test_df[column].map(mapping).tolist()
    for column, mapping in (("label_a", label_map_a), ("label_at", label_map_at))
)

In [28]:
def train_svm(data_x, data_y, max_iter=50):
    """Fit an ``svm.SVC`` (default kernel) on the given data and return it.

    Args:
        data_x: training feature rows.
        data_y: training targets aligned with ``data_x``.
        max_iter: hard cap on solver iterations, forwarded to ``svm.SVC``.
            The default of 50 keeps the original behaviour. It is far below
            sklearn's own default of -1 (no limit), so the solver is likely
            stopped before convergence. Pass -1 to train to convergence.

    Returns:
        The fitted classifier.
    """
    clf = svm.SVC(max_iter=max_iter)
    clf.fit(data_x, data_y)
    return clf


def train_svm_poly8(data_x, data_y):
    """Fit a degree-8 polynomial-kernel ``svm.SVC`` with penalty parameter C=20.

    Args:
        data_x: training feature rows.
        data_y: training targets aligned with ``data_x``.

    Returns:
        The fitted classifier (sklearn's ``fit`` returns the estimator itself).
    """
    model = svm.SVC(C=20, degree=8, kernel="poly")
    return model.fit(data_x, data_y)


def train_logistic(data_x, data_y, max_iter=100):
    """Fit a ``LogisticRegression`` (random_state=0) and return it.

    Args:
        data_x: training feature rows.
        data_y: training targets aligned with ``data_x``.
        max_iter: solver iteration cap, forwarded to ``LogisticRegression``.
            100 is sklearn's default, so a call without this argument behaves
            as before. The recorded run of this notebook hit that limit
            ("TOTAL NO. of ITERATIONS REACHED LIMIT"). Raise it, or scale the
            features, to let the solver converge.

    Returns:
        The fitted classifier.
    """
    clf = LogisticRegression(random_state=0, max_iter=max_iter)
    clf.fit(data_x, data_y)
    return clf


def train_gbt(data_x, data_y):
    """Fit a ``GradientBoostingClassifier`` with fixed hyperparameters.

    Uses up to 100 stages of depth-3 trees at learning rate 0.1, seeded with
    random_state=0. ``verbose=1`` prints the per-iteration loss table.
    ``n_iter_no_change=2`` enables sklearn's early stopping. Per the sklearn
    docs, this holds out ``validation_fraction`` of the training data for
    monitoring, so fewer than 100 stages may be fitted.

    Args:
        data_x: training feature rows.
        data_y: training targets aligned with ``data_x``.

    Returns:
        The fitted classifier.
    """
    params = dict(
        n_estimators=100,
        learning_rate=0.1,
        max_depth=3,
        n_iter_no_change=2,
        random_state=0,
        verbose=1,
    )
    booster = GradientBoostingClassifier(**params)
    booster.fit(data_x, data_y)
    return booster


def evaluate(model, test_x, test_y):
    """Predict on a held-out set and compute plain accuracy.

    Args:
        model: any object with a ``predict`` method (e.g. a fitted sklearn estimator).
        test_x: features passed unchanged to ``model.predict``.
        test_y: true labels, one per prediction.

    Returns:
        (res, accuracy): ``res`` is exactly what ``model.predict`` returned.
        ``accuracy`` is the fraction of predictions equal to ``test_y``, as a
        numpy float. It is NaN, with a numpy RuntimeWarning, for an empty set.

    Raises:
        ValueError: if the number of predictions differs from the number of labels.
    """
    res = model.predict(test_x)
    # Coerce both sides to arrays so ``==`` is element-wise even when either side
    # is a plain list. A list == list comparison returns a single bool, and
    # ``.sum()`` on that bool fails.
    predicted = np.asarray(res)
    expected = np.asarray(test_y)
    if len(predicted) != len(expected):
        raise ValueError(
            f"model produced {len(predicted)} predictions for {len(expected)} labels"
        )
    accuracy = np.mean(predicted == expected)
    return res, accuracy


def create_dirs_to_file(path):
    """Ensure the directory that will contain ``path`` exists.

    Creates every missing ancestor directory of ``path``, but not ``path``
    itself. Calling it again for the same path is harmless.

    Args:
        path: file path whose parent directory should exist. The parent is
            found with ``os.path.dirname``, so the platform's separators are
            honoured (on Windows both "/" and "\\").
    """
    parent = osp.dirname(path)
    # A bare file name has no directory component. There is nothing to create,
    # and os.makedirs("") would raise FileNotFoundError.
    if parent:
        # exist_ok avoids a race between checking for the directory and creating it.
        os.makedirs(parent, exist_ok=True)


def load_or_train(train_x, train_y, test_x, test_y, train_func, label_map, path):
    """Return cached results from ``path`` if it exists; otherwise train, evaluate and cache.

    Args:
        train_x, train_y: training features and encoded targets.
        test_x, test_y: evaluation features and encoded targets.
        train_func: callable ``(data_x, data_y) -> fitted model``.
        label_map: label -> id mapping, used to name predictions in the result tables.
        path: pickle file used as the cache. Missing parent directories are created.

    Returns:
        7-tuple (trained_model, predictions, accuracy, df, df_incorrect,
        df_correct, label_map). On a cache hit this is whatever
        ``read_pickle(path)`` returns.

    Notes:
        The cache is keyed only on ``path``. Changing the data or ``train_func``
        without deleting the file silently returns stale results.
        ``read_pickle`` presumably unpickles, so only point it at trusted files.
    """
    if not osp.exists(path):
        create_dirs_to_file(path)

        model = train_func(train_x, train_y)
        predictions, accuracy = evaluate(model, test_x, test_y)
        results_df, incorrect_df, correct_df = format_results(predictions, test_y, label_map)

        bundle = (model, predictions, accuracy, results_df, incorrect_df, correct_df, label_map)
        write_pickle(path, bundle)
        return bundle

    return read_pickle(path)


def format_results(predictions, labels, label_map):
    """Tabulate predictions against ground truth and split the rows by correctness.

    Args:
        predictions: encoded predictions (values of ``label_map``), one per example.
        labels: encoded true labels, aligned with ``predictions``.
        label_map: label -> id mapping. It is inverted here to attach readable names.

    Returns:
        (df, df_incorrect, df_correct): the full frame with columns
        prediction, label, check, prediction_name, label_name; the rows where
        ``check`` is False; and the rows where ``check`` is True.
    """
    index_to_name = dict(zip(label_map.values(), label_map.keys()))

    results = pd.DataFrame({"prediction": predictions, "label": labels})
    results["check"] = results["prediction"] == results["label"]
    results["prediction_name"] = results["prediction"].map(index_to_name)
    results["label_name"] = results["label"].map(index_to_name)

    is_correct = results["check"]
    return results, results[~is_correct], results[is_correct]

In [30]:
# train_svm on the label_a targets (loaded from the cache file if it already exists).
save_path = (
    "C:/Users/aphri/Documents/t0002/pycharm/data/yg_ar/"
    "classic_models_medium_bg_sift/svm_a.pkl"
)
(
    trained_svm_a, predictions_svm_a, accuracy_svm_a,
    df_svm_a, df_incorrect_svm_a, df_correct_svm_a,
    label_map_svm_a,
) = load_or_train(
    train_x=train_x,
    train_y=train_y_a,
    test_x=test_x,
    test_y=test_y_a,
    train_func=train_svm,
    label_map=label_map_a,
    path=save_path,
)
print(accuracy_svm_a)
print(df_incorrect_svm_a.head())



0.3337962962962963
   prediction  label  check prediction_name         label_name
1           0      3  False           camel  lord_of_the_dance
2           0      3  False           camel  lord_of_the_dance
3           0      3  False           camel  lord_of_the_dance
4           1      3  False           chair  lord_of_the_dance
7           0      3  False           camel  lord_of_the_dance


In [31]:
# train_svm on the label_at targets (loaded from the cache file if it already exists).
save_path = (
    "C:/Users/aphri/Documents/t0002/pycharm/data/yg_ar/"
    "classic_models_medium_bg_sift/svm_at.pkl"
)
(
    trained_svm_at, predictions_svm_at, accuracy_svm_at,
    df_svm_at, df_incorrect_svm_at, df_correct_svm_at,
    label_map_svm_at,
) = load_or_train(
    train_x=train_x,
    train_y=train_y_at,
    test_x=test_x,
    test_y=test_y_at,
    train_func=train_svm,
    label_map=label_map_at,
    path=save_path,
)
print(accuracy_svm_at)
print(df_incorrect_svm_at.head())



0.22453703703703703
   prediction  label  check      prediction_name           label_name
0          12     14  False  lord_of_the_dance_1  lord_of_the_dance_3
1          12     14  False  lord_of_the_dance_1  lord_of_the_dance_3
2          39     14  False         warrior_II_4  lord_of_the_dance_3
4          12     14  False  lord_of_the_dance_1  lord_of_the_dance_3
5          39     14  False         warrior_II_4  lord_of_the_dance_3


In [32]:
# train_logistic on the label_a targets (loaded from the cache file if it already exists).
save_path = (
    "C:/Users/aphri/Documents/t0002/pycharm/data/yg_ar/"
    "classic_models_medium_bg_sift/logistic_a.pkl"
)
(
    trained_logistic_a, predictions_logistic_a, accuracy_logistic_a,
    df_logistic_a, df_incorrect_logistic_a, df_correct_logistic_a,
    label_map_logistic_a,
) = load_or_train(
    train_x=train_x,
    train_y=train_y_a,
    test_x=test_x,
    test_y=test_y_a,
    train_func=train_logistic,
    label_map=label_map_a,
    path=save_path,
)
print(accuracy_logistic_a)
print(df_incorrect_logistic_a.head())

0.6125
    prediction  label  check prediction_name         label_name
2            8      3  False      warrior_II  lord_of_the_dance
5            8      3  False      warrior_II  lord_of_the_dance
7            1      3  False           chair  lord_of_the_dance
11           6      3  False        triangle  lord_of_the_dance
12           8      3  False      warrior_II  lord_of_the_dance


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


In [34]:
# train_logistic on the label_at targets (loaded from the cache file if it already exists).
save_path = (
    "C:/Users/aphri/Documents/t0002/pycharm/data/yg_ar/"
    "classic_models_medium_bg_sift/logistic_at.pkl"
)
(
    trained_logistic_at, predictions_logistic_at, accuracy_logistic_at,
    df_logistic_at, df_incorrect_logistic_at, df_correct_logistic_at,
    label_map_logistic_at,
) = load_or_train(
    train_x=train_x,
    train_y=train_y_at,
    test_x=test_x,
    test_y=test_y_at,
    train_func=train_logistic,
    label_map=label_map_at,
    path=save_path,
)
print(accuracy_logistic_at)
print(df_incorrect_logistic_at.head())

0.32731481481481484
   prediction  label  check      prediction_name           label_name
0          12     14  False  lord_of_the_dance_1  lord_of_the_dance_3
1          12     14  False  lord_of_the_dance_1  lord_of_the_dance_3
2          39     14  False         warrior_II_4  lord_of_the_dance_3
3          12     14  False  lord_of_the_dance_1  lord_of_the_dance_3
4          36     14  False         warrior_II_1  lord_of_the_dance_3


In [35]:
# train_gbt on the label_a targets (loaded from the cache file if it already exists).
save_path = (
    "C:/Users/aphri/Documents/t0002/pycharm/data/yg_ar/"
    "classic_models_medium_bg_sift/gbt01_a.pkl"
)
(
    trained_gbt_a, predictions_gbt_a, accuracy_gbt_a,
    df_gbt_a, df_incorrect_gbt_a, df_correct_gbt_a,
    label_map_gbt_a,
) = load_or_train(
    train_x=train_x,
    train_y=train_y_a,
    test_x=test_x,
    test_y=test_y_a,
    train_func=train_gbt,
    label_map=label_map_a,
    path=save_path,
)
print(accuracy_gbt_a)
print(df_incorrect_gbt_a.head())

      Iter       Train Loss   Remaining Time 
         1           2.1219           54.26s
         2           2.0058           52.62s
         3           1.9165           52.27s
         4           1.8428           51.04s
         5           1.7821           50.29s
         6           1.7287           49.36s
         7           1.6821           48.35s
         8           1.6407           48.29s
         9           1.6033           47.72s
        10           1.5698           46.66s
        20           1.3489           41.58s
        30           1.2256           37.40s
        40           1.1425           31.81s
        50           1.0820           26.08s
        60           1.0340           20.98s
        70           0.9949           15.59s
        80           0.9615           10.36s
        90           0.9325            5.19s
       100           0.9061            0.00s
0.6111111111111112
    prediction  label  check prediction_name         label_name
11           6  

In [36]:
# train_gbt on the label_at targets (loaded from the cache file if it already exists).
save_path = (
    "C:/Users/aphri/Documents/t0002/pycharm/data/yg_ar/"
    "classic_models_medium_bg_sift/gbt01_at.pkl"
)
(
    trained_gbt_at, predictions_gbt_at, accuracy_gbt_at,
    df_gbt_at, df_incorrect_gbt_at, df_correct_gbt_at,
    label_map_gbt_at,
) = load_or_train(
    train_x=train_x,
    train_y=train_y_at,
    test_x=test_x,
    test_y=test_y_at,
    train_func=train_gbt,
    label_map=label_map_at,
    path=save_path,
)
print(accuracy_gbt_at)
print(df_incorrect_gbt_at.head())

      Iter       Train Loss   Remaining Time 
         1           3.4428            3.49m
         2           3.2947            3.58m
         3           3.1801            3.61m
         4           3.0847            3.56m
         5           3.0031            3.52m
         6           2.9329            3.43m
         7           2.8711            3.39m
         8           2.8156            3.37m
         9           2.7647            3.35m
        10           2.7186            3.34m
        20           2.3852            2.95m
        30           2.1754            2.56m
        40           2.0222            2.21m
        50           1.8965            1.85m
        60           1.7914            1.49m
        70           1.7000            1.11m
        80           1.6214           44.86s
        90           1.5494           22.47s
       100           1.4840            0.00s
0.30185185185185187
   prediction  label  check      prediction_name           label_name
0        