In [None]:
from __future__ import annotations

In [None]:
# `display` and `HTML` are public API of `IPython.display`; importing `display`
# from `IPython.core.display` is deprecated.
from IPython.display import display, HTML

# Stretch the notebook container to the full browser width.
display(HTML('<style>.container { width:100% !important; }</style>'))

In [None]:
import os
import datetime
import numpy as np
import tensorflow as tf
import tqdm.notebook as tqdm
import tensorflow_addons as tfa

from tokens import tokens
from gensim.models import Word2Vec
from collections import defaultdict
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split

In [None]:
def make_dataset(X, y, batch_size):
    """Build a shuffled, batched `tf.data.Dataset` of `(X, y)` pairs.

    Args:
        X: Features, sliced along the first axis.
        y: Targets, sliced along the first axis.
        batch_size: Number of pairs per batch.

    Returns:
        A dataset that zips the two slice datasets, shuffles with a buffer
        equal to the number of feature rows, and then batches.
    """
    features = tf.data.Dataset.from_tensor_slices(X)
    targets = tf.data.Dataset.from_tensor_slices(y)

    paired = tf.data.Dataset.zip((features, targets))
    # Buffer spans the whole feature dataset, so every element can be drawn at any step.
    shuffled = paired.shuffle(len(features))
    return shuffled.batch(batch_size)

In [None]:
def prepare_data(path, max_length=1000, labels=None, min_count=100):
    """Read a tab-separated token/tag file and build padded inputs and multi-hot targets.

    Every non-blank line must contain exactly four tab-separated fields: an
    integer id, a field that is ignored, whitespace-separated tokens, and
    whitespace-separated tags.

    Args:
        path: Path of the text file to read.
        max_length: Token sequences are truncated and zero-padded (both at the
            end) to this length.
        labels: Label vocabulary to encode against. When None it is derived
            from the file: every tag occurring at least `min_count` times.
        min_count: Minimum tag frequency; only used when `labels` is None.

    Returns:
        Tuple `(X_data, y_data, index, labels, label_to_id)`:
        X_data is the padded token array of shape (n_lines, max_length),
        y_data is an integer multi-hot array of shape (n_lines, len(labels)),
        index is the array of per-line integer ids, labels is the label
        vocabulary and label_to_id maps each label to its column in y_data.

    Raises:
        ValueError: If a non-blank line does not have exactly four fields or
            its first field is not an integer.
    """
    X_data, tag_lists, index = [], [], []
    # `with` guarantees the file handle is closed even if parsing fails part-way.
    with open(path) as source:
        for line in tqdm.tqdm(source):
            if not line.strip():
                # Tolerate blank lines instead of failing in the 4-way unpack below.
                continue
            ind, _, token_field, tag_field = line.split('\t')
            X_data.append(token_field.split()[:max_length])
            tag_lists.append(tag_field.split())
            index.append(int(ind))

    index = np.array(index)

    # Tokens are still strings here; pad_sequences converts them to its default dtype.
    # NOTE(review): presumably tokens are integer ids in text form — confirm against the tokenizer.
    X_data = tf.keras.preprocessing.sequence.pad_sequences(
        X_data, maxlen=max_length, truncating='post', padding='post', value=0)

    if labels is None:
        labels, counts = np.unique([tag for tags in tag_lists for tag in tags], return_counts=True)
        labels = labels[counts >= min_count]

    label_to_id = {label: i for i, label in enumerate(labels)}

    # Preallocating keeps the (n_lines, n_labels) shape even when the file is empty.
    y_data = np.zeros((len(tag_lists), len(labels)), dtype=int)
    for row, tags in enumerate(tag_lists):
        for tag in tags:
            column = label_to_id.get(tag)
            if column is not None:  # tags outside the label vocabulary are dropped
                y_data[row, column] = 1

    return X_data, y_data, index, labels, label_to_id

In [None]:
# Parse the training file with prepare_data's defaults (max_length=1000, min_count=100);
# the label vocabulary is derived from this file because `labels` is not passed.
X_data, y_data, index, labels, label_to_id = prepare_data('train/tokenized.txt')

In [None]:
# Number of samples carrying each label: accumulate the multi-hot rows of y_data.
counts = np.zeros(len(labels), dtype=int)
for label_row in y_data:
    counts += label_row

In [None]:
# One "<count>:<TAB><label>" line per label.
for label_count, label_name in zip(counts, labels):
    print(f'{label_count}:\t{label_name}')

In [None]:
# Distinct values of `index`, the integer id parsed from the first field of every line.
# NOTE(review): the name suggests these are contest ids — confirm against the data source.
contests = np.unique(index)

In [None]:
# Fixed-seed generator so the random train/test split below can be repeated.
rng = np.random.default_rng(42)

In [None]:
# Split by id rather than by sample, so all samples sharing an id land on the same side.
n_train_contests = int(len(contests) * 0.8) + 1
train_contests = rng.choice(contests, size=n_train_contests, replace=False)
# Every id that was not drawn for training goes to the test side.
test_contests = np.array([contest for contest in contests if contest not in train_contests])

In [None]:
# Boolean per-sample masks: True where the sample's id belongs to the respective id set.
train_mask = np.isin(index, train_contests)
test_mask  = np.isin(index, test_contests)

In [None]:
# Both splits use the same batch size; note that make_dataset shuffles each of them.
BATCH_SIZE = 64

train_dataset = make_dataset(X_data[train_mask], y_data[train_mask], BATCH_SIZE)
test_dataset = make_dataset(X_data[test_mask], y_data[test_mask], BATCH_SIZE)

In [None]:
# Load a previously saved gensim Word2Vec model; the file must already exist on disk.
w2v_model = Word2Vec.load('w2v.model')

In [None]:
emb_size = 128
# One row per token plus row 0, which stays all-zero (id 0 is masked via mask_zero=True below).
weights = np.zeros((len(tokens) + 1, emb_size))

# NOTE(review): assumes token ids in the input are `position in tokens + 1` — confirm against the tokenizer.
for i, token in enumerate(tokens):
    try:
        weights[i + 1] = w2v_model.wv[token]
    except KeyError:
        # Token absent from the Word2Vec vocabulary: keep its zero vector.
        # Anything else (e.g. a vector-size mismatch) is a real error and must surface.
        continue

# Build the layer first so its weight variable exists, then load the pretrained matrix and freeze it.
embedding = tf.keras.layers.Embedding(len(tokens) + 1, emb_size, name='token_embedding', mask_zero=True)
embedding.build((None, ))
embedding.set_weights([weights])
embedding.trainable = False

In [None]:
# Shared trunk: token ids -> pretrained embedding -> dropout.
inputs = tf.keras.layers.Input((None, ), dtype=tf.int32, name='token_input')
embedded = embedding(inputs)
dropout = tf.keras.layers.Dropout(0.2, name='embedding_dropout')(embedded)

n_layers = 4
kernels = [3, 5, 7]
swish = tf.keras.activations.swish
layers = []

# One convolutional branch per kernel size, all reading the same dropped-out embeddings.
for k in kernels:
    x, n = dropout, dropout.shape[-1]
    for depth in range(n_layers):
        conv = tf.keras.layers.Conv1D(n, k, activation=swish, padding='same', name=f'conv1d_{k}_{n}')
        norm = tf.keras.layers.BatchNormalization(name=f'batch_norm_{k}_{n}')
        x = norm(conv(x))
        n *= 2  # the filter count doubles at every depth, which also keeps layer names unique

    # Collapse the sequence axis so each branch contributes one fixed-size vector.
    layers.append(tf.keras.layers.GlobalMaxPooling1D(name=f'max_pool_{k}')(x))

# Head: concatenate the branches, then dropout, a hidden dense layer and one sigmoid unit per label.
merged = tf.keras.layers.Concatenate(axis=-1, name='pool_concatenate')(layers)
merged = tf.keras.layers.Dropout(0.2, name='concatenate_dropout')(merged)
hidden = tf.keras.layers.Dense(units=512, activation=swish, name='dense_1')(merged)
prediction = tf.keras.layers.Dense(units=len(labels), activation='sigmoid', name='prediction')(hidden)
model = tf.keras.models.Model(inputs=inputs, outputs=prediction, name='multilabel_model')

In [None]:
# Print the layer-by-layer summary of the model built above.
model.summary()

In [None]:
# Render the model graph as an image.
# NOTE(review): Keras needs pydot and graphviz installed for this — confirm in the environment.
tf.keras.utils.plot_model(model)

In [None]:
# Multi-label setup: one independent sigmoid per label, hence binary cross-entropy.
# F1Score needs an explicit threshold here: with the default threshold=None it scores
# only the arg-max label of each sample, which is wrong when several labels can be active.
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy', tfa.metrics.F1Score(len(labels), threshold=0.5)],
)

In [None]:
# Every run logs into its own timestamped directory under logs/.
run_stamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
logdir = os.path.join('logs', run_stamp)
checkpoint_pattern = os.path.join(logdir, 'weights_{epoch}')  # {epoch} is filled in by Keras

callbacks = [
    tf.keras.callbacks.TensorBoard(logdir, update_freq=10),
    tf.keras.callbacks.ModelCheckpoint(checkpoint_pattern, save_freq='epoch'),
]

# Disabled alternative: checkpoints only, under weights/<timestamp>/<epoch>, without TensorBoard.
# callbacks = [
#     tf.keras.callbacks.ModelCheckpoint(f'weights/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/{{epoch}}'),
# ]

In [None]:
# First training stage: run up to epoch 30, validating on test_dataset;
# the callbacks write TensorBoard logs and a checkpoint per epoch.
model.fit(train_dataset, epochs=30, validation_data=test_dataset, callbacks=callbacks)

In [None]:
# Continue training from epoch 30 up to epoch 40; initial_epoch keeps epoch numbering continuous.
model.fit(train_dataset, epochs=40, validation_data=test_dataset, callbacks=callbacks, initial_epoch=30)

In [None]:
# Continue training from epoch 40 up to epoch 50.
model.fit(train_dataset, epochs=50, validation_data=test_dataset, callbacks=callbacks, initial_epoch=40)

In [None]:
# Recompile with a fresh Adam optimizer at learning rate 1e-4 for the next training stage.
# F1Score needs an explicit threshold here: with the default threshold=None it scores
# only the arg-max label of each sample, which is wrong when several labels can be active.
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss='binary_crossentropy',
    metrics=['accuracy', tfa.metrics.F1Score(len(labels), threshold=0.5)],
)

In [None]:
# Continue training from epoch 50 up to epoch 80.
model.fit(train_dataset, epochs=80, validation_data=test_dataset, callbacks=callbacks, initial_epoch=50)

In [None]:
# Recompile with a fresh Adam optimizer at learning rate 1e-5 for the final training stage.
# F1Score needs an explicit threshold here: with the default threshold=None it scores
# only the arg-max label of each sample, which is wrong when several labels can be active.
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-5),
    loss='binary_crossentropy',
    metrics=['accuracy', tfa.metrics.F1Score(len(labels), threshold=0.5)],
)

In [None]:
# Continue training from epoch 80 up to epoch 90.
model.fit(train_dataset, epochs=90, validation_data=test_dataset, callbacks=callbacks, initial_epoch=80)

In [None]:
# Loss and compiled metrics on the training split.
model.evaluate(train_dataset)

In [None]:
# Loss and compiled metrics on the held-out split.
model.evaluate(test_dataset)

In [None]:
# Replace the in-memory model with a checkpoint loaded from disk.
# NOTE(review): the run directory and epoch number are hard-coded, so this cell fails
# unless that exact path exists — presumably a checkpoint picked from an earlier run.
model = tf.keras.models.load_model('logs/20210624-173954/weights_6')

In [None]:
# Predict on the masked array directly (not on the shuffled test_dataset), so the rows of
# y_pred_test stay in the same order as y_data[test_mask].
y_pred_test = model.predict(X_data[test_mask], verbose=1)

In [None]:
# Predict on the masked array directly (not on the shuffled train_dataset), so the rows of
# y_pred_train stay in the same order as y_data[train_mask].
y_pred_train = model.predict(X_data[train_mask], verbose=1)

In [None]:
def __calc_metrics(y_true: np.ndarray, y_pred: np.ndarray, thresholds: list = None, threshold_count: int = 100):
    """Compute binary-classification metrics for one label at each threshold.

    Args:
        y_true: 1-D array of ground-truth values; non-zero entries count as positive.
        y_pred: 1-D array of scores; a sample is predicted positive when its
            score is strictly greater than the threshold.
        thresholds: Thresholds to evaluate (list or array). When None or empty,
            `threshold_count` evenly spaced values in [0, 1] are used.
        threshold_count: Size of the default threshold grid.

    Returns:
        A list with one dict per threshold holding 'threshold', 'tpr', 'tnr',
        'precision', 'recall', 'accuracy' and 'f1', each rounded to 3 decimals.
        A ratio whose denominator is zero is reported as 0 instead of NaN.
    """
    # An explicit emptiness test also accepts NumPy arrays, whose truth value is ambiguous.
    if thresholds is None or len(thresholds) == 0:
        thresholds = list(np.linspace(0, 1, threshold_count))

    # Everything that depends only on y_true is computed once, outside the loop.
    total = len(y_true)
    t = y_true.astype(bool)
    not_t = np.logical_not(t)
    positives = y_true.sum()
    negatives = total - positives

    metrics = []
    for threshold in thresholds:
        p = y_pred > threshold
        not_p = np.logical_not(p)

        tp = np.logical_and(p, t).sum()
        tn = np.logical_and(not_p, not_t).sum()
        fp = np.logical_and(p, not_t).sum()
        fn = np.logical_and(not_p, t).sum()

        # Guard every ratio: a label with no positives (or no negatives) must not yield NaN.
        tpr = tp / positives if positives else 0
        tnr = tn / negatives if negatives else 0
        precision = tp / (tp + fp) if tp + fp else 0
        recall = tp / (tp + fn) if tp + fn else 0
        sample_count = tp + fn + tn + fp
        accuracy = (tp + tn) / sample_count if sample_count else 0
        f1 = 2 * precision * recall / (precision + recall) if precision or recall else 0

        metrics.append({
            'threshold': round(threshold, 3),
            'tpr': round(tpr, 3),
            'tnr': round(tnr, 3),
            'precision': round(precision, 3),
            'recall': round(recall, 3),
            'accuracy': round(accuracy, 3),
            'f1': round(f1, 3)
        })

    return metrics


def __calc_rocauc(metrics):
    """Integrate TPR over TNR with the trapezoid rule.

    Args:
        metrics: Sequence of dicts that each provide 'tpr' and 'tnr'.

    Returns:
        Tuple `(coords, auc)`: coords is the list of (tpr, tnr) pairs in input
        order and auc is the summed trapezoid area between consecutive pairs
        (0 when fewer than two pairs are given).
    """
    coords = [(entry['tpr'], entry['tnr']) for entry in metrics]
    auc = 0
    # Walk consecutive point pairs: mean TPR of the pair times the TNR step between them.
    for (tpr_a, tnr_a), (tpr_b, tnr_b) in zip(coords, coords[1:]):
        auc += (tpr_a + tpr_b) / 2 * (tnr_b - tnr_a)
    return coords, auc

def _crossover_threshold(metric_batch, default=0.5):
    """Return the threshold around the first point where TNR exceeds TPR, else `default`.

    The result is the midpoint between that entry's threshold and the previous
    entry's threshold (0 when it is the first entry).
    """
    last_threshold = 0
    for metrics in metric_batch:
        if metrics['tnr'] > metrics['tpr']:
            return (last_threshold + metrics['threshold']) / 2
        last_threshold = metrics['threshold']
    return default


def calc_metrics(y_true_np: np.ndarray, y_pred_np: np.ndarray, plot_name: str, thresholds: list = None, do_text: bool = True):
    """Plot per-label TPR/TNR curves, choose thresholds and report averaged metrics.

    Relies on the notebook-level `labels` for the number of columns and the legend.
    Side effects: draws on the current pyplot figure, prints the thresholds and
    metrics, and saves the figure as a PNG named after `plot_name` (spaces become
    underscores, with a `_text` suffix when `do_text` is true).

    Args:
        y_true_np: 2-D ground-truth array, one column per label.
        y_pred_np: 2-D score array, one column per label.
        plot_name: Plot title; also the base of the output file name.
        thresholds: One decision threshold per label. When None or empty, each
            label gets the threshold around the first point where TNR exceeds
            TPR on a 500-step grid, falling back to 0.5 if that never happens.
        do_text: Whether to draw the averaged metrics onto the figure.

    Returns:
        Tuple `(thresholds, total_mean_metrics)`: the per-label thresholds used and
        a mapping with the mean 'precision', 'recall', 'accuracy' and 'f1' over labels.
    """
    thresholds = thresholds or []

    metrics_for_predicted = [__calc_metrics(y_true_np[:,i], y_pred_np[:,i], threshold_count=500)
                             for i in range(len(labels))]
    plt.title(plot_name, fontsize=20)
    plt.xlabel('True negative ratio', fontsize=20)
    plt.ylabel('True positive ratio', fontsize=20)
    plt.grid(color = 'green', linestyle = '--', linewidth = 0.5)
    plt.plot([0, 1], [1, 0], '--', color='coral')
    for label, metric_batch in zip(labels, metrics_for_predicted):
        coords, _ = __calc_rocauc(metric_batch)
        # coords are (tpr, tnr) pairs: TNR goes on the x axis, TPR on the y axis.
        plt.plot([point[1] for point in coords], [point[0] for point in coords], label=label)
    plt.legend()

    if not thresholds:
        # Exactly one threshold per label, so thresholds[i] always belongs to label i
        # even when a label's curves never cross.
        thresholds = [_crossover_threshold(metric_batch) for metric_batch in metrics_for_predicted]

    metrics_for_predicted_with_thresholds = [__calc_metrics(y_true_np[:,i], y_pred_np[:,i], thresholds=[thresholds[i]])
                                             for i in range(len(labels))]

    total_mean_metrics = defaultdict(float)
    for metrics in metrics_for_predicted_with_thresholds:
        total_mean_metrics['precision'] += metrics[0]['precision']
        total_mean_metrics['recall'] += metrics[0]['recall']
        total_mean_metrics['accuracy'] += metrics[0]['accuracy']
        total_mean_metrics['f1'] += metrics[0]['f1']

    # Turn the per-label sums into unweighted means over labels.
    for key in total_mean_metrics.keys():
        total_mean_metrics[key] /= len(labels)

    s_t = '\n'.join(map(str, thresholds))
    s_m = '\n'.join(map(lambda x: f'{str.upper(x[0]) + x[1:]}: {round(total_mean_metrics[x], 4)}', total_mean_metrics))
    print(f'[{plot_name}]\nThresholds:\n{s_t}\n\nMetrics:\n{s_m}\n')
    if do_text:
        plt.text(0.3, 0.04, s_m, fontsize=14)
    fig = plt.gcf()
    fig.set_size_inches(12.5, 10.5)
    fig.savefig(f'{plot_name.replace(" ", "_")}{"_text" if do_text else ""}.png', dpi=200)
    return thresholds, total_mean_metrics

In [None]:
# Choose per-label thresholds on the training split; the saved plot carries the metrics
# text (file name ends in `_text`).
class_thresholds, train_metrics = calc_metrics(y_data[train_mask], y_pred_train, plot_name=f'Token Based Model ROC (train)', do_text=True)

In [None]:
# Same computation as the do_text=True call on the training split, saved as a second
# image without the metrics text overlay.
class_thresholds, train_metrics = calc_metrics(y_data[train_mask], y_pred_train, plot_name=f'Token Based Model ROC (train)', do_text=False)

In [None]:
# Score the held-out split with the thresholds chosen on the training split
# (passed in, so they are not re-derived from test data); plot saved with the metrics text.
_, test_metrics = calc_metrics(y_data[test_mask], y_pred_test, plot_name=f'Token Based Model ROC (test)', thresholds=class_thresholds, do_text=True)

In [None]:
# Same held-out evaluation with the training-split thresholds, saved as a second image
# without the metrics text overlay.
_, test_metrics = calc_metrics(y_data[test_mask], y_pred_test, plot_name=f'Token Based Model ROC (test)', thresholds=class_thresholds, do_text=False)