In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules
# before each cell executes, so edits to imported code are picked up.
%load_ext autoreload
%autoreload 2

In [None]:
!pip install 'tensorflow_text'

In [None]:
import json
import textwrap

import pandas as pd
import tensorflow_hub as hub
import numpy as np
import tensorflow_text
from sklearn.model_selection import train_test_split
import tensorflow as tf
import matplotlib.pyplot as plt
from tabulate import tabulate

In [None]:
# Columns stored in the CSV as Python-literal strings and parsed back on load.
# NOTE(review): `eval` executes arbitrary code found in the file. If these
# columns only ever hold plain list literals, `ast.literal_eval` would be a
# safer drop-in -- confirm against the dataset before switching.
list_columns = ("idiom_indices", "idiom_words", "lemmas", "words")
tr_data = pd.read_csv(
    "../input/dodiom-dataset/tr_corpus_second_run.csv",
    converters={column: eval for column in list_columns},
)

# NOTE(review): if vote counts are never negative, the first mask only drops
# rows with missing values -- confirm that `>= 0` (not `> 0`) is intended.
vote_mask = (tr_data.likes + tr_data.dislikes + tr_data.reports) >= 0
rating_mask = tr_data.rating >= 0.0

# reset_index() keeps the pre-filter row labels in a new "index" column.
tr_data = tr_data[vote_mask & rating_mask].reset_index()
print(len(tr_data))
print(tr_data.columns)

In [None]:
# TF-Hub handle of the multilingual Universal Sentence Encoder (version 3).
use_model_url = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"
embed = hub.load(use_model_url)

In [None]:
%%time
idiom_embeds = {idiom: embed([idiom])[0].numpy() for idiom in tr_data.idiom.unique()}
sentence_embeds = embed([x for x in tr_data.submission])
sentence_idiom_embeds = embed([" ".join(x) for x in tr_data.idiom_words])

In [None]:
def is_seperate(arr):
    """Return True when every element of ``arr`` is its predecessor plus one.

    Despite the name, the result is True for a *contiguous* run of indices
    (and for empty or single-element input) and False as soon as a gap or an
    out-of-order element is found.
    """
    return not any(arr[pos + 1] != arr[pos] + 1 for pos in range(len(arr) - 1))

In [None]:
# Per-row feature vector: sentence embedding, idiom embedding, embedding of
# the idiom words as they occur in the sentence, and a final 0/1 flag from
# is_seperate() on that row's own idiom indices.
# The flag used to be computed from row 0 wrapped in an extra list, which made
# it a constant 1.0 for every row.
# Assumes each idiom_indices entry is a sequence of ints -- TODO confirm.
tr_data["embed"] = [
    list(np.concatenate((
        sentence_embeds[ix],
        idiom_embeds[idiom],
        sentence_idiom_embeds[ix],
        np.array([is_seperate(indices)], dtype=float),
    )))
    for ix, (idiom, indices) in enumerate(zip(tr_data.idiom, tr_data.idiom_indices))
]

In [None]:
# X holds row positions rather than feature vectors; the split cell uses the
# selected positions to slice tr_data and stacks the "embed" column afterwards.
X = np.arange(len(tr_data))

# Binary target: 1 where the category is "idiom", 0 for anything else.
y = (tr_data.category == "idiom").to_numpy().astype(int)

In [None]:
print(X.shape, y.shape)

# A fixed random_state puts the same rows in train and test on every run,
# so the metrics reported further down are comparable between runs.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, shuffle=True, random_state=42
)

# X contains row positions, so the split result selects the matching rows.
df_train = tr_data.iloc[X_train]
df_test = tr_data.iloc[X_test]

# Replace the positions with the actual feature matrices (one row per sample).
X_train = np.stack(df_train.embed, axis=0)
X_test = np.stack(df_test.embed, axis=0)

print(len(df_train.idiom.unique()), "idioms")

In [None]:
# Display the shapes of the training feature matrix and the training labels.
X_train.shape, y_train.shape

In [None]:
# Per-idiom sample counts: how many "idiom" (pos) and "nonidiom" (neg) rows
# of each idiom ended up in the training and test splits.
split_rows = []
for idiom_name in tr_data.idiom.unique():
    train_categories = df_train[df_train.idiom == idiom_name].category
    test_categories = df_test[df_test.idiom == idiom_name].category
    split_rows.append({
        "idiom": idiom_name,
        "train pos": int((train_categories == "idiom").sum()),
        "train neg": int((train_categories == "nonidiom").sum()),
        "test pos": int((test_categories == "idiom").sum()),
        "test neg": int((test_categories == "nonidiom").sum()),
    })

df_data = pd.DataFrame(
    split_rows,
    columns=["idiom", "train pos", "train neg", "test pos", "test neg"],
)
df_data

In [None]:
# Seed TensorFlow's global random generator so that weight initialisation
# (and other seeded TF random ops) repeat from run to run.
tf.random.set_seed(42)

# Small feed-forward binary classifier over the concatenated embedding features.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(1, activation="sigmoid")
])

# Stop once val_loss has not improved for 10 consecutive epochs.
earlystopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10)
opt = tf.keras.optimizers.Adam(learning_rate=0.0001)
model.compile(optimizer=opt,
              loss=tf.keras.losses.BinaryCrossentropy(),
              metrics=['accuracy'])

# NOTE(review): validation_data is the held-out test split, so early stopping
# is tuned on the same rows the per-idiom evaluation reports on. Consider
# carving a separate validation set out of the training data.
hist = model.fit(X_train, y_train,
                 validation_data=(X_test, y_test),
                 batch_size=32,
                 epochs=100,
                 callbacks=[earlystopping])

In [None]:
# Left panel: accuracy per epoch; right panel: loss per epoch.
# The "val_*" history entries are labelled as "test" curves.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5), edgecolor="black", facecolor="white")

for panel, metric, axis_label in ((ax1, "accuracy", "Accuracy"), (ax2, "loss", "Loss")):
    panel.plot(hist.history[metric], label=f"train {metric}")
    panel.plot(hist.history[f"val_{metric}"], label=f"test {metric}")
    panel.set_xlabel("Epoch")
    panel.set_ylabel(axis_label)
    panel.legend(loc="best")

# Accuracy is bounded by 1, so fix the range with a little headroom.
ax1.set_ylim(0, 1.1)

plt.tight_layout()
plt.show()

In [None]:
# Per-idiom accumulators; each list receives one entry per evaluated idiom.
idioms, poss, negs = [], [], []
tps, fps, tns, fns = [], [], [], []
precisions, recalls, f1s = [], [], []

# Table rows describing every misclassified test sample.
misclass = []

def wrap(text: str) -> str:
    """Re-flow ``text`` to textwrap's default line width, joined by newlines."""
    return textwrap.fill(text)
    

# Evaluate the trained model separately on the test rows of every idiom and
# record confusion counts, precision, recall and F1 in the accumulator lists.
for idiom in tr_data.idiom.unique():
    # Filter the test split once per idiom and reuse the result below.
    idiom_rows = df_test[df_test.idiom == idiom]
    if len(idiom_rows) == 0:
        continue

    idioms.append(idiom)

    # Dedicated local names: X and y are notebook-level variables consumed by
    # the train/test split cell and must not be overwritten here.
    features = np.stack(idiom_rows.embed, axis=0)
    labels = np.array([int(category == "idiom") for category in idiom_rows.category])
    scores = model(features).numpy().flatten()

    # A score strictly above 0.5 is a positive prediction. One prediction
    # vector drives both the misclassification table and the counts.
    predictions = (scores > 0.5).astype(int)

    for pos in np.flatnonzero(predictions != labels):
        item = idiom_rows.iloc[pos]
        misclass.append([len(misclass) + 1, item.idiom, wrap(item.submission), item.category, scores[pos]])

    # Confusion-matrix counts for this idiom.
    tp = int(np.sum((predictions == 1) & (labels == 1)))
    fp = int(np.sum((predictions == 1) & (labels == 0)))
    tn = int(np.sum((predictions == 0) & (labels == 0)))
    fn = int(np.sum((predictions == 0) & (labels == 1)))
    tps.append(tp)
    fps.append(fp)
    tns.append(tn)
    fns.append(fn)

    poss.append(tp + fn)
    negs.append(fp + tn)

    # Undefined ratios (zero denominator) are reported as 0.
    precision = tp / (tp + fp) if tp + fp != 0 else 0
    precisions.append(precision)

    recall = tp / (tp + fn) if tp + fn != 0 else 0
    recalls.append(recall)

    if precision + recall != 0:
        f1_score = 2 * ((precision * recall) / (precision + recall))
    else:
        f1_score = 0
    f1s.append(f1_score)

print(tabulate(misclass, headers=["index", "Idiom", "Submission", "Category" , "Prediction"]))

In [None]:
# Collect the per-idiom accumulator lists into one table, one row per idiom.
eval_data = pd.DataFrame(dict(
    idiom=idioms,
    POS=poss,
    NEG=negs,
    TP=tps,
    FP=fps,
    TN=tns,
    FN=fns,
    Precision=precisions,
    Recall=recalls,
    F1=f1s,
))

# Show the idioms ordered from best to worst F1.
eval_data.sort_values(by="F1", ascending=False).reset_index()

In [None]:
# Unweighted mean of the per-idiom F1 scores.
macro_f1 = eval_data["F1"].mean()
print(f"Macro avg. F1: {macro_f1:.3f}")