<a href="https://colab.research.google.com/github/supplient/bachelor_design/blob/equal_sif_freq/EqualTrain.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Environment Preparation

In [None]:
!pip install keras_bert

In [None]:
!git clone https://github.com/supplient/bachelor_design.git
import os
os.chdir("bachelor_design")
!pwd

In [None]:
# Switch the repository in the current working directory to the
# `equal_sif_freq` branch, then pull the latest commits for it.
!git checkout equal_sif_freq
!git pull

In [1]:
from driver_amount import addh

[Locale] Using address head: /mnt/d/My Drive


# Data Preparation

Original data

In [2]:
from cut_and_tag import load_stopwords, cut_and_remove_stopwords, cut_and_tag
from preprocess import seq2str
import config
import json

In [3]:
# Load the original corpus: word-level cuts, character sequences and tags.
# (tag_seqs is not used in the visible part of this notebook.)
cut_seqs, char_seqs, tag_seqs = cut_and_tag(
    addh + config.DATA_PATH,
    addh + config.STOPWORDS_PATH
)

# Load the "equal" sentences from JSON.
# The encoding is given explicitly so that reading does not depend on the
# platform's default text encoding.
with open(addh + config.EQUAL_DATA_PATH, "r", encoding="utf-8") as fd:
    equal_strs = json.load(fd)

# Cut every equal sentence into words with the same stopword list as the originals.
stopwords = load_stopwords(addh + config.STOPWORDS_PATH)
equal_cut_seqs = [
    cut_and_remove_stopwords(equal_str, stopwords)
    for equal_str in equal_strs
]

# Flatten each cut sentence (sequence of words) into its sequence of characters.
equal_seqs = [
    [c for w in equal_cut for c in w]
    for equal_cut in equal_cut_seqs
]

# The i-th equal sentence is compared against the i-th original sentence later on,
# so there must be at least as many originals as equal sentences. Fail here rather
# than after the expensive embedding step.
if len(char_seqs) < len(equal_seqs) or len(cut_seqs) < len(equal_seqs):
    raise ValueError(
        "Expected at least %i original sentences, got %i char seqs and %i cut seqs."
        % (len(equal_seqs), len(char_seqs), len(cut_seqs))
    )

# Keep only the originals that have an equal counterpart.
origin_seqs = char_seqs[:len(equal_seqs)]
origin_cut_seqs = cut_seqs[:len(equal_seqs)]

Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 1.176 seconds.
Prefix dict has been built successfully.


# Parameter Loading

In [4]:
import json
# Load the parameter dictionary from JSON; later cells look values up with params.get(...).
# The encoding is given explicitly so that reading does not depend on the
# platform's default text encoding.
with open(addh + config.EQUAL_PARAM_PATH, "r", encoding="utf-8") as fd:
    params = json.load(fd)

# Train

Because the embedder is designed to work on batches, we concatenate `origin_seqs` and `equal_seqs` (originals first, then the equal sentences) so that both sets are embedded and composed in a single pass.

In [5]:
# Originals come first and the equal sentences second, so index `origin_num`
# is the boundary between the two groups in the combined lists.
origin_num = len(origin_seqs)
all_char_seqs = [*origin_seqs, *equal_seqs]
all_cut_seqs = [*origin_cut_seqs, *equal_cut_seqs]

In [None]:
from char_emb import CharEmbedder
from SIF import SIF
from dist_cal import DistCal
from tqdm.notebook import trange, tqdm

Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [None]:
# Collected results of the alpha search, filled in as train_rec[alpha_precision][x].
train_rec = dict()

# Method names handed to the character embedder and to the distance calculator.
emb_method, dist_method = "sum_last_four", "cos"

# Starting threshold for the chosen distance; None when params has no
# "<dist_method>_theta" entry.
dist_theta = params.get("{}_theta".format(dist_method))

In [None]:
# Char Embed
# Embed every character sequence (originals followed by equal sentences) in one call,
# using the embedding method selected in `emb_method`.
char_embedder = CharEmbedder()
all_char_emb_seqs = char_embedder.embed(all_char_seqs, emb_method)

In [None]:
# Orders of magnitude to scan for SIF's alpha, from coarse (1e-1) to fine (1e-6).
alpha_precision_list = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6]

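Search SIF's `alpha`: for every candidate value, compose the sentence vectors, tune the distance threshold `dist_theta`, and record the resulting TP/FP/FN/TN counts.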

In [None]:
# Settings of the theta search; they are the same for every alpha candidate.
epoch = 10000          # upper bound on threshold updates per alpha candidate
min_delta = 10**(-7)   # stop once the step size has shrunk to this precision

for alpha_precision in tqdm(alpha_precision_list, desc="Alpha precision"):
    for x in trange(1, 10, desc="Alpha scale"):
        # Candidate alpha: digit x (1..9) at the current order of magnitude.
        alpha = alpha_precision * x

        # Compose sentence vectors
        sif = SIF(alpha)
        all_sen_vecs = sif.compose(all_cut_seqs, all_char_seqs, all_char_emb_seqs)

        # Split back into originals and their equal counterparts
        origin_sen_vecs = all_sen_vecs[:origin_num]
        equal_sen_vecs = all_sen_vecs[origin_num:]
        if len(origin_sen_vecs) != len(equal_sen_vecs):
            raise Exception("Length should be the same.")

        N = len(origin_sen_vecs)
        if N < 2:
            # With fewer than two sentences there is no "different" pair and the
            # rates below would divide by zero.
            raise ValueError("At least two sentence pairs are required, got %i." % N)

        # Init distance calculator
        dist_cal = DistCal(all_sen_vecs)

        # Two experiments are evaluated against the threshold:
        # * whether an original and its equal sentence are similar (dist < theta)
        # * whether two distinct originals are different (dist > theta)
        # The distances only depend on the sentence vectors, not on theta, so they
        # are computed once here instead of once per epoch.
        # NOTE(review): assumes dist_cal.cal is deterministic for fixed inputs — confirm.
        similiar_dists = [
            dist_cal.cal(origin_sen, equal_sen, dist_method)
            for origin_sen, equal_sen in zip(origin_sen_vecs, equal_sen_vecs)
        ]
        different_dists = [
            dist_cal.cal(origin_sen_vecs[i], origin_sen_vecs[j], dist_method)
            for i in range(N - 1)
            for j in range(i + 1, N)
        ]
        similiar_total = N                    # TP + FN
        different_total = N * (N - 1) // 2    # FP + TN (integer, like the other counts)

        # When dist_theta is not set, take the first distance as its initial value.
        # NOTE(review): dist_theta is not reset between alpha candidates, so each
        # search starts from the previous candidate's final threshold (and re-running
        # this cell starts from the last value). Confirm this warm start is intended.
        if dist_theta is None:
            dist_theta = similiar_dists[0]

        # Initial step size is a tenth of the starting threshold.
        delta = dist_theta / 10
        last_delta = 0
        similiar_count = 0   # TP
        different_count = 0  # TN

        with tqdm(total=epoch, desc="Train theta", leave=False) as epoch_tqdm:
            for epoch_count in range(epoch):
                # Do experiments
                similiar_count = sum(1 for dist in similiar_dists if dist < dist_theta)
                different_count = sum(1 for dist in different_dists if dist > dist_theta)

                similiar_rate = similiar_count / similiar_total
                different_rate = different_count / different_total

                # Finetune theta: move it towards the point where both rates are equal.
                if similiar_rate > different_rate:
                    now_delta = -delta
                elif similiar_rate < different_rate:
                    now_delta = delta
                else:
                    now_delta = -last_delta

                # A reversal of direction means the balance point was stepped over:
                # shrink the step size by one decimal digit.
                if now_delta == -last_delta:
                    delta /= 10
                    now_delta /= 10
                dist_theta += now_delta
                last_delta = now_delta

                if dist_theta <= 0:
                    dist_theta = 0

                # Update progress info
                epoch_tqdm.set_description("Train theta-Epoch %i" % epoch_count)
                epoch_tqdm.set_postfix(
                    similiar_rate=similiar_rate,
                    different_rate=different_rate,
                    dist_theta=dist_theta,
                    delta=delta
                )
                # update() takes an increment, not an absolute position.
                epoch_tqdm.update(1)

                # Stop when delta's precision is enough
                if delta <= min_delta:
                    # Fill the remainder of the bar before leaving.
                    epoch_tqdm.update(epoch_tqdm.total - epoch_tqdm.n)
                    break

        # Cache train records
        train_rec.setdefault(alpha_precision, {})[x] = {
            "TP": similiar_count,
            "FP": different_total - different_count,
            "FN": similiar_total - similiar_count,
            "TN": different_count,
            "dist_theta": dist_theta
        }

# Save train record

In [None]:
# Serialize the collected training records as indented JSON and write them out.
rec_path = addh + config.EQUAL_SIF_TRAIN_REC_PATH
with open(rec_path, "w") as rec_file:
    rec_file.write(json.dumps(train_rec, indent=4))