# Apyori Analysis
對最佳預測結果進行錯誤類別的關聯分析，檢查錯誤類別的混淆情況

In [8]:
import os
import re

import pandas as pd
from apyori import apriori

In [10]:
# Project root, relative to the directory this notebook runs from.
path = '..'

# Pickled prediction table of the best-performing training run analysed here.
target = f"{path}/outputs/train/BiT-M-R50x1_batch256_lr0.0003_fullset.pkl"

## Analysis

In [16]:
# Load the prediction table and keep only rows belonging to the test split,
# renumbering the index from 0; tail() previews the last rows in the notebook.
data = pd.read_pickle(target)
is_test = data['set_name'] == 'test'
data = data.loc[is_test].reset_index(drop=True)
data.tail()

Unnamed: 0,file,label,set_name,predict,predict_label
13449,d41c1fe0-1775-4940-835a-7e18ae079267.jpg,greenonion,test,"[7.358785e-05, 0.00015296991, 0.000106066844, ...",greenonion
13450,9e318bca-3186-4458-8ed9-b1fc53cc2952.jpg,greenonion,test,"[0.00018206096, 0.00028757274, 0.00016238686, ...",greenonion
13451,d2ff6ca0-7ad0-48a3-beb9-f168d527f277.jpg,greenonion,test,"[0.00045041938, 0.0018955074, 0.0033849792, 0....",greenonion
13452,74af1a5a-c258-49de-bfbe-1ccc659d5d80.jpg,greenonion,test,"[0.0006068007, 6.7271336e-05, 0.0002129048, 0....",greenonion
13453,495e470a-5add-4a3e-8541-7fb97214aa8c.jpg,greenonion,test,"[0.00016374468, 0.00038540628, 0.00019040132, ...",greenonion


In [14]:
# Mine frequent {true label, predicted label} itemsets over the test rows:
# each row of the two-column array is handed to apriori as one transaction,
# filtered with min_support=0.003 and min_lift=1.1, then ranked by support.
# In the displayed output only two-label itemsets (confused pairs) remain.
transactions = data[['label', 'predict_label']].values
mined = apriori(transactions, min_support=0.003, min_lift=1.1)
association_rules = pd.DataFrame(list(mined))
association_rules = association_rules.sort_values(
    by=['support'], ascending=False, ignore_index=True
)
association_rules

Unnamed: 0,items,support,ordered_statistics
0,"(cauliflower, broccoli)",0.010331,"[((broccoli), (cauliflower), 0.228995057660626..."
1,"(onion, greenonion)",0.004831,"[((greenonion), (onion), 0.11796733212341198, ..."
2,"(longan, litchi)",0.00446,"[((litchi), (longan), 0.1466992665036675, 4.39..."
3,"(kale, cauliflower)",0.00327,"[((cauliflower), (kale), 0.10185185185185185, ..."
4,"(lettuce, chinesecabbage)",0.003122,"[((chinesecabbage), (lettuce), 0.1105263157894..."


In [51]:
# Persist the mined rules so they can be reloaded later without re-running apriori.
rules_csv_path = F"{path}/association_rules.csv"
association_rules.to_csv(rules_csv_path, index=False)

## Split train/valid/test data by confused label pair

In [17]:
# Image size used for the model input; also part of the vector pickle file names.
image_size = 224

# Training-set variant. Choose 'subset' or 'fullset'.
train_mode = 'fullset'

# Preprocessing variant baked into the vector pickles. Choose one of:
# 'crop70xy', 'correct_by_rule', 'crop70xy_correct_by_rule'.
image_preprocessing_method = 'crop70xy_correct_by_rule'

In [4]:
# Reload the mined rules. The CSV stores each itemset as the text repr of a
# frozenset, e.g. "frozenset({'broccoli', 'cauliflower'})", so pull the quoted
# label names back out into a list, in the order they appear in the stored text.
association_rules = pd.read_csv(F"{path}/association_rules.csv")
# Extracting the single-quoted tokens does not rely on the exact braces or the
# ", " separator of the repr, unlike chained replace()/split() string surgery.
association_rules['items'] = association_rules['items'].apply(
    lambda text: re.findall(r"'([^']*)'", text)
)
association_rules

Unnamed: 0,items,support,ordered_statistics
0,"[broccoli, cauliflower]",0.010331,[OrderedStatistic(items_base=frozenset({'brocc...
1,"[greenonion, onion]",0.004831,[OrderedStatistic(items_base=frozenset({'green...
2,"[litchi, longan]",0.00446,[OrderedStatistic(items_base=frozenset({'litch...
3,"[kale, cauliflower]",0.00327,[OrderedStatistic(items_base=frozenset({'cauli...
4,"[chinesecabbage, lettuce]",0.003122,[OrderedStatistic(items_base=frozenset({'chine...


In [5]:
def vector_pickle_path(split):
    """Return the preprocessed vector pickle path of *split* for the current config."""
    return F'{path}/data/{train_mode}_{split}_vector_{image_size}_{image_preprocessing_method}.pkl'


# Load the train, valid and test vector tables (in that order) and report their shapes.
train, valid, test = (pd.read_pickle(vector_pickle_path(name)) for name in ('train', 'valid', 'test'))
print(F"train: {train.shape}, valid: {valid.shape}, test: {test.shape}")

train: (62670, 22), valid: (13390, 22), test: (13454, 22)


In [6]:
# For each confusable label set mined above, save train/valid/test subsets that
# contain only samples whose true label is in that set. Files are written under
# data/binary/ (presumably to train a dedicated classifier per confused pair).
binary_dir = F'{path}/data/binary'
# DataFrame.to_pickle does not create missing parent directories.
os.makedirs(binary_dir, exist_ok=True)

splits = {'train': train, 'valid': valid, 'test': test}
for row in association_rules.itertuples():
    labels = list(row.items)
    if len(labels) < 2:
        # A single-label set has nothing to tell apart; skip it rather than
        # writing a one-class "binary" subset.
        print(F"skipping single-label itemset: {labels}")
        continue
    # For pairs this yields the "<a>_<b>" suffix; larger itemsets keep every
    # label in the file name instead of silently dropping the extras.
    label_tag = '_'.join(labels)
    for split_name, split_df in splits.items():
        subset = split_df[split_df['label'].isin(labels)].reset_index(drop=True)
        subset.to_pickle(
            F'{binary_dir}/{train_mode}_{split_name}_vector_{image_size}_{image_preprocessing_method}_{label_tag}.pkl',
            protocol=4,
        )