# Environment Preparation

In [1]:
!pip install keras_bert
!pip install prettytable

In [2]:
from driver_amount import addh

[Locale] Using address head: /mnt/d/My Drive


In [3]:
!git clone https://github.com/supplient/bachelor_design.git
import os
os.chdir("bachelor_design")
!pwd

In [4]:
# Switch the working tree to the `equal_realize` branch, then pull the latest
# commits for it from the remote.
!git checkout equal_realize
!git pull

# Data Preparation

Original Data

In [5]:
from preprocess import load_file
import config
import json

In [6]:
# Load the character sequences (and their tag sequences) from the main data file.
char_seqs, tag_seqs = load_file(addh + config.DATA_PATH)

# Load the equivalent sequences. JSON text is UTF-8 by specification, so the
# encoding is stated explicitly instead of relying on the platform locale.
with open(addh + config.EQUAL_DATA_PATH, "r", encoding="utf-8") as fd:
    equal_seqs = json.load(fd)

# Keep only as many origin sequences as there are equivalent sequences.
# NOTE(review): presumably equal_seqs[i] is the rewrite of char_seqs[i] — confirm.
origin_seqs = char_seqs[:len(equal_seqs)]

# Param Load

In [7]:
import json

# TODO the better way to do this is
#    check whether the file exists first,
#    then create it and set default params when it does not exist.
params = {}
with open(addh + config.EQUAL_PARAM_PATH, "r") as param_file:
    params = json.loads(param_file.read())

# Pull the individual hyper-parameters out of the loaded mapping.
sif_alpha = params["sif_alpha"]  # passed to SIF(...) when it is constructed
cos_theta = params["cos_theta"]
dist_theta = cos_theta  # starting threshold; the training loop reassigns it

# Train

Because the embedder is designed to work on batches, it is better to combine `origin_seqs` and `equal_seqs` into a single list and embed them together.

In [8]:
# Origin sequences come first and the equivalent sequences follow;
# origin_num marks the boundary so the two groups can be split again later.
origin_num = len(origin_seqs)
all_char_seqs = [*origin_seqs, *equal_seqs]

In [None]:
from char_emb import CharEmbedder
from SIF import SIF
from dist_cal import DistCal
from tqdm.notebook import trange, tqdm

In [None]:
# Training results, filled in as train_rec[dist_method][emb_method] by the training loop.
train_rec = dict()

# Pipeline components: character embedder, SIF sentence composer, distance calculator.
char_embedder, sif = CharEmbedder(), SIF(sif_alpha)
dist_caler = DistCal(all_char_seqs)

In [None]:
# Train a distance threshold (dist_theta) for every (distance method, embedding
# method) combination and cache the resulting confusion counts in train_rec.
for dist_method in tqdm(DistCal.methods, desc="Distance"):
    # Configured starting threshold for this distance method.
    # None means "seed it from the first measured origin/equal distance".
    init_dist_theta = params.get(dist_method + "_theta", None)
    for emb_method in tqdm(CharEmbedder.methods, desc="Char embed", leave=False):
        # NOTE(review): neither `emb_method` nor `dist_method` is passed to the
        # embedder / distance calculator below, so every combination appears to
        # run the same computation — confirm how the method is meant to be selected.

        # Start every combination from the same initial threshold so that the
        # result does not depend on what the previous combination converged to.
        dist_theta = init_dist_theta

        # Embedding
        all_char_emb_seqs = char_embedder.embed(all_char_seqs)
        all_sen_vecs = sif.compose(all_char_seqs, all_char_emb_seqs)

        # Split
        origin_sen_vecs = all_sen_vecs[:origin_num]
        equal_sen_vecs = all_sen_vecs[origin_num:]
        if len(origin_sen_vecs) != len(equal_sen_vecs):
            raise Exception("Length should be the same.")

        # Train theta
        ## Set train params
        epoch = 10000
        delta = 0.001          # current step size for dist_theta
        min_delta = 10**(-10)  # stop once the step size is this small

        ## Use two experiments
        ## * Check whether origin and equal are similar
        ## * Check whether origins are different
        N = len(origin_sen_vecs)
        if N < 2:
            # With fewer than two sentences there are no origin/origin pairs,
            # and the rates below would divide by zero.
            raise Exception("Need at least two origin sentences to train theta.")
        similiar_total = N  # TP + FN
        different_total = N * (N - 1) // 2  # FP + TN (N*(N-1) is always even)

        similiar_count = 0  # TP
        different_count = 0  # TN
        last_delta = 0
        with trange(epoch, desc="Train theta", leave=False) as epoch_tqdm:
            for epoch_count in epoch_tqdm:
                # Do experiments
                similiar_count = 0
                for origin_sen, equal_sen in zip(origin_sen_vecs, equal_sen_vecs):
                    dist = dist_caler.cal(origin_sen, equal_sen)
                    # When dist_theta is not set, take the first dist as its
                    # initial value. This must happen before the comparison,
                    # because comparing a number with None raises TypeError.
                    if dist_theta is None:
                        dist_theta = dist
                    if dist < dist_theta:
                        similiar_count += 1

                different_count = 0
                for i in range(N - 1):
                    for j in range(i + 1, N):
                        dist = dist_caler.cal(origin_sen_vecs[i], origin_sen_vecs[j])
                        if dist > dist_theta:
                            different_count += 1

                similiar_rate = similiar_count / similiar_total
                different_rate = different_count / different_total

                # Finetune theta: step towards the point where both rates are equal
                now_delta = 0
                if similiar_rate > different_rate:
                    now_delta = -delta
                elif similiar_rate < different_rate:
                    now_delta = delta
                else:
                    now_delta = -last_delta

                # The step direction flipped (or the rates are tied):
                # we overshot, so refine the step size by a factor of 10.
                if now_delta == -last_delta:
                    delta /= 10
                    now_delta /= 10
                dist_theta += now_delta
                last_delta = now_delta

                if dist_theta <= 0:
                    raise Exception("dist_theta reach 0")
                if dist_theta >= 1:
                    raise Exception("dist_theta reach 1")

                # Update progress info
                epoch_tqdm.set_description("Epoch %i" % epoch_count)
                epoch_tqdm.set_postfix(
                    similiar_rate=similiar_rate,
                    different_rate=different_rate,
                    dist_theta=dist_theta,
                    delta=delta
                )

                # Stop when delta's precision is enough
                if delta <= min_delta:
                    epoch_tqdm.write("Delta reach " + str(min_delta) + ", which is enough")
                    break

        # Cache train records
        train_rec.setdefault(dist_method, {})[emb_method] = {
            "TP": similiar_count,
            "FP": different_total - different_count,
            "FN": similiar_total - similiar_count,
            "TN": different_count,
            "dist_theta": dist_theta
        }

# Analyze train record

In [3]:
# Variables for drawing table
dist_labels = []  # row labels: one per distance method
emb_labels = []   # column labels: one per distinct embedding method
table_vals = []   # table_vals[row] holds the F1 scores of that distance method


def _safe_div(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


# Derive precision / recall / F1 from the cached confusion counts, store them
# back into each record, and collect the F1 scores for the summary table.
for dist_method, emb_recs in train_rec.items():
    dist_labels.append(dist_method)
    f1_row = []
    for emb_method, rec in emb_recs.items():
        # Register each embedding method once only; the same methods recur
        # under every distance method and column names must stay unique.
        if emb_method not in emb_labels:
            emb_labels.append(emb_method)

        # Degenerate counts (no predicted or no actual positives) score 0.0
        # instead of raising ZeroDivisionError.
        p = _safe_div(rec["TP"], rec["TP"] + rec["FP"])
        r = _safe_div(rec["TP"], rec["TP"] + rec["FN"])
        f1 = _safe_div(2 * p * r, p + r)
        rec["precision"] = p
        rec["recall"] = r
        rec["F1"] = f1

        f1_row.append(f1)
    table_vals.append(f1_row)

In [17]:
from prettytable import PrettyTable

# First column names the metric; the remaining columns are the embedding methods.
header = ["F1", *emb_labels]
t = PrettyTable(field_names=header, header=True)

# One row per distance method: its label followed by its F1 scores.
for i, dist_label in enumerate(dist_labels):
    t.add_row([dist_label, *table_vals[i]])
print(t)

+----+-----+-----+
| F1 |  e1 |  e2 |
+----+-----+-----+
| d1 | 0.4 | 0.4 |
+----+-----+-----+


# Save train record

In [None]:
# Write the training records to disk as JSON indented by four spaces.
with open(addh + config.EQUAL_TRAIN_REC_PATH, "w") as rec_file:
    rec_file.write(json.dumps(train_rec, indent=4))