In [1]:
import pickle
import numpy as np
from sklearn import svm

In [2]:
def extract(filename, threshold=20):
    '''Load a pickled candidate table and turn it into features and label indices.

    Args:
        filename (str): path of the feature pickle (*.pkl) produced by
            confusionTable_extractFeature
        threshold (int): accepted for interface compatibility; it is not read
            anywhere in this function
    Return:
        train_feature (np.array): float array with one row per candidate
        train_label (dict): {1: row indices of positive samples,
                             2: row indices of negative samples}
    '''

    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # use this on pickles you produced yourself.
    with open(filename, 'rb') as handle:
        table = pickle.load(handle)

    rows = []
    classes = []

    # Outer dict: error character -> {candidate: (score, log)}.
    # Only `log` is used; the keys and the score are ignored.
    for candidates in table.values():
        for _score, log in candidates.values():
            # Row layout: first four log entries, the entries of log[4]
            # flattened in, then the final log entry.
            rows.append(list(log[:4]) + list(log[4]) + [log[-1]])

            # A log longer than six entries marks a positive (error) sample.
            classes.append(1 if len(log) > 6 else 2)

    train_feature = np.asarray(rows, dtype='float')
    class_ids = np.asarray(classes, dtype='int')

    # Store row indices per class rather than a flat label vector.
    train_label = {
        1: np.flatnonzero(class_ids == 1),
        2: np.flatnonzero(class_ids == 2),
    }

    print('The number of sample = {}'.format(train_feature.shape))
    print('Positive case (candidate) = {}'.format(len(train_label[1])))
    print('Negative case (uncandidate) = {}'.format(len(train_label[2])))

    return (train_feature, train_label)

In [29]:
def train(feature, label, train_cnt, test_cnt=0, cross_validation=False,
          return_stats=False):
    '''Train an RBF-kernel SVM on a class-balanced random subset of the samples.

    `train_cnt` positive and `train_cnt` negative rows are drawn for training,
    and `test_cnt` further rows of each class are held out for testing.

    Args:
        feature (np.array): feature array, one row per sample
        label (dict): {1: row indices of positive samples,
                       2: row indices of negative samples}
        train_cnt (int): number of training samples taken from EACH class
        test_cnt (int): number of test samples taken from EACH class;
            0 (default) means "all positives left after training"
        cross_validation (bool): accepted for interface compatibility;
            currently not used
        return_stats (bool): if True, return the evaluation row instead of
            the classifier
    Return:
        The fitted svm.SVC by default.  With return_stats=True, the list
        [test accuracy, train_cnt, total number of test samples].
    Raises:
        AssertionError: train_cnt is zero or not below the positive count
        ValueError: a count is negative, or either class has fewer than
            train_cnt + test_cnt samples
    '''

    assert train_cnt < len(label[1]), 'Train count must less than positive samples'

    if train_cnt==0:
        assert False, 'sample count cannot be zero'

    if train_cnt < 0 or test_cnt < 0:
        raise ValueError('train_cnt and test_cnt must be non-negative')

    # If test_cnt not declared, use all the remaining positives as test set
    if test_cnt==0:
        test_cnt = len(label[1]) - train_cnt

    # Fail before the expensive fit: a short slice would otherwise produce
    # feature and label arrays of different lengths.
    needed = train_cnt + test_cnt
    for cls in (1, 2):
        if needed > len(label[cls]):
            raise ValueError(
                'class {} has {} samples but train_cnt + test_cnt = {}'.format(
                    cls, len(label[cls]), needed))

    # Shuffle copies so the caller's index arrays are left untouched.
    pos_idx = np.random.permutation(label[1])
    neg_idx = np.random.permutation(label[2])

    # Pick feature/label for the train & test sets (positives first).
    train_idx = np.concatenate((pos_idx[:train_cnt], neg_idx[:train_cnt]))
    train_label = np.concatenate(
        (np.full(train_cnt, 1, dtype=int), np.full(train_cnt, 2, dtype=int)))
    train_feature = feature[train_idx]

    test_idx = np.concatenate(
        (pos_idx[train_cnt:needed], neg_idx[train_cnt:needed]))
    test_label = np.concatenate(
        (np.full(test_cnt, 1, dtype=int), np.full(test_cnt, 2, dtype=int)))
    test_feature = feature[test_idx]

    # Training
    clf = svm.SVC(kernel='rbf',C=5, probability=True)
    clf.fit(train_feature,train_label)

    # Accuracy on the training data itself (printed for a quick sanity check)
    print(clf.score(train_feature, train_label))

    if return_stats:
        # Held-out accuracy is only evaluated when the caller asks for it.
        accuracy = clf.score(test_feature, test_label)
        return [accuracy, train_cnt, test_cnt*2]

    return clf

In [27]:
if __name__ == '__main__':
    # Pickled candidate table to load.
    filename = 'confu.pkl'

    # Build the feature matrix and the per-class index lookup.
    feature, label = extract(filename)

    # Fit on 1000 samples per class and hold out 5000 per class;
    # `output` holds whatever `train` returns.
    output = train(feature, label, train_cnt=1000, test_cnt=5000)

The number of sample = (3949878, 7)
Positive case (candidate) = 11465
Negative case (uncandidate) = 3938413
0.9205


In [30]:
# Fit a classifier on 1000 samples per class, with 5000 per class held out.
clf = train(feature, label, train_cnt=1000, test_cnt=5000)

0.9285


In [53]:
# Log-probabilities for one hand-written 7-feature sample.
# The sample is wrapped in an outer list because scikit-learn estimators take
# a 2-D (n_samples, n_features) input; a bare 1-D sample is rejected by
# current releases.
clf.predict_log_proba([[4, 0, 0, 0, 0, 0, 6]])



array([[-0.46991276, -0.98098072]])

In [52]:
# Predicted class for one hand-written 7-feature sample.
# The sample is wrapped in an outer list because scikit-learn estimators take
# a 2-D (n_samples, n_features) input; a bare 1-D sample is rejected by
# current releases.
clf.predict([[4, 0, 0, 0, 0, 0, 6]])



array([1])

In [19]:
# Per-class training-set sizes to sweep over.
train_set = [
    3, 5, 10, 20, 50,
    100, 200, 300, 500, 700,
    1000, 1500, 2000, 3000, 5000, 7000, 10000,
]

In [10]:
# Display the current contents of `performance`.
# NOTE(review): this cell sits above the cell that assigns `performance`, so
# running the notebook top to bottom on a fresh kernel raises NameError here.
performance

[[0.7985953585761647, 3, 22924],
 [0.74864746945898775, 5, 22920],
 [0.81335661283282412, 10, 22910],
 [0.86845784185233732, 20, 22890],
 [0.86141042487954445, 50, 22830],
 [0.84879014518257812, 100, 22730],
 [0.86977363515312911, 200, 22530],
 [0.88571428571428568, 300, 22330],
 [0.89238486092111258, 500, 21930],
 [0.90204366000928937, 700, 21530],
 [0.89961777353081696, 1000, 20930],
 [0.90280983442047169, 1500, 19930],
 [0.90417326994189118, 2000, 18930],
 [0.90265800354400472, 3000, 16930],
 [0.90100541376643462, 5000, 12930],
 [0.90179171332586783, 7000, 8930],
 [0.90307167235494878, 10000, 2930]]

In [25]:
# Sweep over the sizes in `train_set`, collecting one `train` result per size.
# NOTE(review): `train` as defined in this notebook ends with `return clf`
# (its `return output` line is commented out), so a fresh run fills this list
# with fitted classifiers rather than the [accuracy, train_cnt, test size]
# rows shown in the recorded outputs. Confirm which return value is intended.
performance = []
for t in train_set:
    performance.append(train(feature, label, t))

1.0
0.8
1.0
1.0
0.96
0.89
0.94
0.916666666667
0.905
0.913571428571
0.9035
0.907666666667
0.912
0.912
0.9057
0.910071428571
0.9086


In [17]:
# Display the current contents of `performance`.
# NOTE(review): the recorded values differ from the other displays of this
# variable; presumably the sweep was re-run and the unseeded shuffle in
# `train` picked different samples.
performance

[[0.69499214796719599, 3, 22924],
 [0.78630017452006984, 5, 22920],
 [0.81981667394151025, 10, 22910],
 [0.82629969418960247, 20, 22890],
 [0.80779675865089795, 50, 22830],
 [0.87817861856577206, 100, 22730],
 [0.8917887261429206, 200, 22530],
 [0.88392297357814598, 300, 22330],
 [0.89648882808937524, 500, 21930],
 [0.90102183000464464, 700, 21530],
 [0.90167224080267561, 1000, 20930],
 [0.90697441043652782, 1500, 19930],
 [0.90433174854727949, 2000, 18930],
 [0.90472533963378621, 3000, 16930],
 [0.90402165506573862, 5000, 12930],
 [0.90817469204927215, 7000, 8930],
 [0.90409556313993178, 10000, 2930]]

In [21]:
# Display the current contents of `performance`.
# NOTE(review): the recorded values differ from the other displays of this
# variable; presumably the sweep was re-run and the unseeded shuffle in
# `train` picked different samples.
performance

[[0.76094922352120054, 3, 22924],
 [0.79550610820244327, 5, 22920],
 [0.84718463553033607, 10, 22910],
 [0.809916994320664, 20, 22890],
 [0.88589575120455544, 50, 22830],
 [0.87338319401671805, 100, 22730],
 [0.88211273857079453, 200, 22530],
 [0.8838781907747425, 300, 22330],
 [0.89758321933424534, 500, 21930],
 [0.90315838365071988, 700, 21530],
 [0.90243669374104152, 1000, 20930],
 [0.90260913196186654, 1500, 19930],
 [0.90385631273111466, 2000, 18930],
 [0.90626107501476671, 3000, 16930],
 [0.90208816705336425, 5000, 12930],
 [0.90190369540873461, 7000, 8930],
 [0.90546075085324229, 10000, 2930]]

In [12]:
# Write the sweep results to CSV: a header row, then one row per entry of
# `performance`.
# NOTE(review): open(..., 'w') does not create './confusionTable/'; the
# directory must already exist.
with open('./confusionTable/SVMacc_setNumber.csv', 'w', encoding='utf8') as wp:
    wp.write('Train_set,Test_set,Acc\n')
    for i in performance:
        # Each entry is indexed as [1], [2], [0] to line up with the
        # Train_set, Test_set, Acc header, so entries must be indexable
        # sequences with at least three items.
        wp.write('{p[1]},{p[2]},{p[0]}\n'.format(p=i))