In [None]:
import os
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
import pickle
import time
import numpy as np
import copy
from module.confprx import build_classifier_model, get_onehot_label, get_token, get_init_data, updata_data, top_k_recall, get_informativeness
from module.prxsampler import proxy, sample_pooling

In [None]:
# Read the pre-processed dataset (train / validation / test splits) from disk.
# NOTE: unpickling can execute arbitrary code -- only load files you trust.
pkl_p = 'data/processed_data.pkl'
with open(pkl_p, 'rb') as r:
    ds = pickle.load(r)

# Every split is indexed by the same three keys: 'inp', 'label' and 'enc'.
# Training split (later divided into a seed training set and a pool).
all_inp = ds['train']['inp']
all_label = ds['train']['label']
all_encode = ds['train']['enc']

# Validation split.
validation_inp = ds['validation']['inp']
validation_label = ds['validation']['label']
validation_encode = ds['validation']['enc']

# Test split.
test_inp = ds['test']['inp']
test_label = ds['test']['label']
test_encode = ds['test']['enc']

def test_acc(model,data_ds,label,batch_size=64):
    predict = model.predict(data_ds,verbose=1)
    y_pred = predict[:,:]
    res = tf.keras.metrics.categorical_accuracy(label, y_pred)
    total = len(res)
    correct = np.sum(res)
    acc = 100.0*correct/total
    return acc

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping, ReduceLROnPlateau
# Checkpoint / TensorBoard targets for the initial (round 0) classifier training.
tb_p = 'model-saved/tb/0/'
model_p = 'model-saved/cls_0.hdf5'
# exist_ok=True replaces the check-then-create pair (no race, no error on re-run).
os.makedirs(os.path.dirname(tb_p), exist_ok=True)
# The deprecated `period=1` argument was dropped: saving once per epoch is
# already the default, and newer Keras releases reject the keyword.
chechpoint = ModelCheckpoint(model_p, verbose=1, save_best_only=True, monitor='val_categorical_accuracy')
tensorboard = TensorBoard(log_dir=tb_p)
callbacks = [chechpoint, tensorboard]


# Same setup for the proxy model, monitored on validation MAE.
# NOTE(review): `pcallbacks` is built here but no fit() call visible in this
# notebook passes it -- confirm whether the proxy was meant to be checkpointed.
ptb_p = 'model-saved/tb-prx/0/'
pmodel_p = 'model-saved/prx_0.hdf5'
os.makedirs(os.path.dirname(ptb_p), exist_ok=True)
pchechpoint = ModelCheckpoint(pmodel_p, verbose=1, save_best_only=True, monitor='val_mae')
ptensorboard = TensorBoard(log_dir=ptb_p)
pcallbacks = [pchechpoint, ptensorboard]

In [None]:
# Initial round: fit the classifier on the seed training set.
# get_init_data divides the training split into `train_data` and `pool_data`;
# the last argument (50) is presumably the seed-set size -- confirm in module.confprx.
train_data, pool_data = get_init_data(all_inp, all_label, all_encode, 50)

cls_model = build_classifier_model()
opt = tf.keras.optimizers.Adam(learning_rate=0.0001)
cls_batch_size = 32
cls_model.compile(
    optimizer=opt,
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=[tf.keras.metrics.CategoricalAccuracy()],
)
# The test split doubles as the Keras validation set here (it drives the
# 'val_categorical_accuracy' checkpoint monitor configured in `callbacks`).
cls_model.fit(
    x=train_data['inp'],
    y=train_data['lab'],
    epochs=5,
    batch_size=cls_batch_size,
    validation_data=(test_inp, test_label),
    callbacks=callbacks,
)

In [None]:
# Initial proxy training: fit the proxy on the validation encodings so that it
# reproduces the informativeness scores derived from the current classifier.
val_info_lab = get_informativeness(cls_model, validation_inp)
tst_info_lab = get_informativeness(cls_model, test_inp)

# Number of validation samples used to fit the proxy; None means "use all".
# The former sentinel -1 made `[:proxy_little_num]` silently drop the last
# validation sample instead of selecting everything.
proxy_little_num = None

prx1 = proxy()
opt2 = tf.keras.optimizers.RMSprop(learning_rate=0.0001)
prx1.compile(optimizer=opt2, loss='binary_crossentropy', metrics=['mae'])
batch_size = 32
prx1.fit(
    x=validation_encode[:proxy_little_num],
    y=val_info_lab[:proxy_little_num],
    epochs=50,
    batch_size=batch_size,
    validation_data=(test_encode, tst_info_lab),
)

In [None]:
# Output locations. 'OP' is a placeholder that the round loop replaces with
# the round number.
names = locals()  # not referenced anywhere in the visible cells
Data_p = 'model-saved/data_OP.pkl'
Tensorb_p = 'model-saved/tb/OP/'
Model_p = 'model-saved/cls_OP.hdf5'
prx_Tensorb_p = 'model-saved/tb-prx/OP/'
prx_Model_p = 'model-saved/prx_OP.hdf5'

# Selection hyper-parameters.
W = 2     # number of windows the pool is split into each round
q = 0.2   # pool fraction used to derive `delete_num` for sample_pooling
N0 = 15   # forwarded to sample_pooling; meaning defined in module.prxsampler

# Loop sizing.
num_rounds = 20   # number of active-learning rounds
sample_num = 50   # samples requested per round, divided among the W windows

# Per-round bookkeeping.
time_cls_train = []
time_prx_select = []
time_proxy_train = []
pool_len_l = []
cls_acc_l = []
prx_top_sample_num_recall_l = []
prx_top_half_recall_l = []
trn_prx_top_half_recall_l = []
trn_prx_top_sample_num_recall_l = []

# Baseline: accuracy of the initially trained classifier, before any round.
cls_acc = test_acc(cls_model, test_inp, test_label)
cls_acc_l.append(cls_acc)
print(f'=====================================\n   cls_acc = {cls_acc}')



# Active-learning rounds. Each round:
#   1. prepares per-round checkpoint / TensorBoard callbacks,
#   2. walks the pool in W windows: sample_pooling (module.prxsampler) receives a
#      snapshot of the proxy model and returns the updated pool / training data,
#      then the classifier is trained on the returned `pipe_train_*` data,
#   3. re-evaluates the classifier and re-fits the proxy on fresh
#      informativeness labels.
for i in range(1, num_rounds + 1):  # rounds are numbered from 1
    # ------  per-round classifier checkpoint / TensorBoard targets -------
    model_p = Model_p.replace('OP', str(i))
    tb_p = Tensorb_p.replace('OP', str(i))
    os.makedirs(os.path.dirname(tb_p), exist_ok=True)
    # Deprecated `period=1` dropped: saving once per epoch is the default.
    chechpoint = ModelCheckpoint(model_p, verbose=1, save_best_only=True, monitor='val_categorical_accuracy')
    tensorboard = TensorBoard(log_dir=tb_p)
    callbacks = [chechpoint, tensorboard]

    # ------  per-round proxy checkpoint / TensorBoard targets -------
    # NOTE(review): `pcallbacks` is never passed to a fit() call below.
    pmodel_p = prx_Model_p.replace('OP', str(i))
    ptb_p = prx_Tensorb_p.replace('OP', str(i))
    os.makedirs(os.path.dirname(ptb_p), exist_ok=True)
    pchechpoint = ModelCheckpoint(pmodel_p, verbose=1, save_best_only=True, monitor='val_mae')
    ptensorboard = TensorBoard(log_dir=ptb_p)
    pcallbacks = [pchechpoint, ptensorboard]

    # Snapshot of the proxy weights: selection in this round uses this copy,
    # independent of the later update of `prx1`.
    select_proxy = proxy()
    select_proxy.set_weights(prx1.get_weights())

    window_len = len(pool_data['inp']) // W
    delete_num = int(len(pool_data['inp']) * q) // W
    select_time = 0
    # Classifier training time summed over all windows of this round
    # (previously only the last window's duration survived).
    timect = 0.0
    for w in range(W):
        print(f'!!!!!! {w}')
        begin = w * window_len
        if w != W - 1:
            sample_pool_num = sample_num // W
            end = (w + 1) * window_len
        else:
            # The last window absorbs the remainders of both integer divisions.
            sample_pool_num = sample_num - w * (sample_num // W)
            end = len(pool_data['inp'])

        if w == 0:
            # Working copy of the training set, updated window by window ...
            train_idx = copy.deepcopy(train_data['idx'])
            train_inp = copy.deepcopy(train_data['inp'])
            train_lab = copy.deepcopy(train_data['lab'])
            train_enc = copy.deepcopy(train_data['enc'])
            # ... and an untouched copy of the round's starting training set.
            train_idx0 = copy.deepcopy(train_data['idx'])
            train_inp0 = copy.deepcopy(train_data['inp'])
            train_lab0 = copy.deepcopy(train_data['lab'])
            train_enc0 = copy.deepcopy(train_data['enc'])

        (w_pool_idx, w_pool_inp, w_pool_lab, w_pool_enc,
         train_idx, train_inp, train_lab, train_enc,
         pipe_train_idx, pipe_train_inp, pipe_train_lab, pipe_train_enc,
         select_time) = sample_pooling(
            i, N0, w, W, delete_num, select_proxy, pool_data, sample_pool_num,
            begin, end, select_time,
            train_idx, train_inp, train_lab, train_enc,
            train_idx0, train_inp0, train_lab0, train_enc0)

        # Rebuild the pool for the next round from the per-window results.
        if w == 0:
            pool_idx = copy.deepcopy(w_pool_idx)
            pool_inp = copy.deepcopy(w_pool_inp)
            pool_lab = copy.deepcopy(w_pool_lab)
            pool_enc = copy.deepcopy(w_pool_enc)
        else:
            pool_idx = np.hstack((pool_idx, w_pool_idx))
            pool_inp = np.hstack((pool_inp, w_pool_inp))
            pool_lab = np.vstack((pool_lab, w_pool_lab))
            pool_enc = np.vstack((pool_enc, w_pool_enc))

        tc1 = time.time()
        ch = cls_model.fit(x=pipe_train_inp, y=pipe_train_lab, epochs=5, batch_size=cls_batch_size, validation_data=(test_inp, test_label), callbacks=callbacks)
        tc2 = time.time()
        timect += tc2 - tc1
        # The per-window test_acc() call was removed: its result was discarded
        # and the same evaluation runs once after the window loop.

    print(f" ---  round {i}, len(pool_idx) --- : {len(pool_idx)}")
    pool_len_l.append(len(pool_idx))
    new_train_d = {'idx': train_idx, 'inp': train_inp, 'lab': train_lab, 'enc': train_enc}
    new_pool_d = {'idx': pool_idx, 'inp': pool_inp, 'lab': pool_lab, 'enc': pool_enc}

    train_data, pool_data = copy.deepcopy(new_train_d), copy.deepcopy(new_pool_d)

    ## evaluate :
    cls_acc = test_acc(cls_model, test_inp, test_label)
    cls_acc_l.append(cls_acc)
    time_cls_train.append(timect)
    time_prx_select.append(select_time)

    #---- informativeness update, proxy model update
    val_info_lab = get_informativeness(cls_model, validation_inp)
    tst_info_lab = get_informativeness(cls_model, test_inp)
    tp3 = time.time()
    # A negative proxy_little_num trims samples from the end of the slice
    # rather than meaning "all"; None selects everything.
    ph = prx1.fit(x=validation_encode[:proxy_little_num], y=val_info_lab[:proxy_little_num], epochs=20, batch_size=batch_size, validation_data=(test_encode, tst_info_lab))
    tp4 = time.time()
    timept = tp4 - tp3
    # Previously measured but never stored, so the final report was always empty.
    time_proxy_train.append(timept)

# Final report: accuracy history followed by the collected timing series.
for _title, _series in (
    ('cls_acc_l = ', cls_acc_l),
    ('time_cls_train = ', time_cls_train),
    ('time_proxy_select = ', time_prx_select),
    ('time_proxy_train = ', time_proxy_train),
):
    print(_title, _series)