In [30]:
from google.colab import drive
# Mount Google Drive at /content/drive; the notebook changes into a directory
# under this mount point in the next cell.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [31]:
%cd drive/MyDrive/task

[Errno 2] No such file or directory: 'drive/MyDrive/task'
/content/drive/MyDrive/task


In [32]:
import numpy as np
import nltk

# Download each required NLTK package in turn (same order as before).
for nltk_package in ('punkt', 'averaged_perceptron_tagger', 'wordnet'):
    nltk.download(nltk_package)

from tensorflow import keras

from Datasets import SequenceDataset
from models import SequenceModel
from utility import top_k_metric

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [33]:
# The training dataset is built first; the validation and test datasets are
# handed the tokenizer objects exposed by the training dataset.
train_ds = SequenceDataset(mode="train", tag_func=nltk.pos_tag_sents)

shared_tokenizers = dict(
    input_tokenizer=train_ds.input_tokenizer,
    target_tokenizer=train_ds.target_tokenizer,
)
val_ds = SequenceDataset(mode="valid", tag_func=nltk.pos_tag_sents, **shared_tokenizers)
test_ds = SequenceDataset(mode="test", tag_func=nltk.pos_tag_sents, **shared_tokenizers)

In [34]:
# Pull (data, labels) out of every split, in train / valid / test order.
(train_data, train_labels), (val_data, val_labels), (test_data, test_labels) = (
    dataset.get_data_target(whole_dialog=True)
    for dataset in (train_ds, val_ds, test_ds)
)

finished loading corpus
finished loading descriptions
finished extracting contexts
finished extracting targets (OOCs)
finished loading corpus
finished loading descriptions
finished extracting contexts
finished extracting targets (OOCs)
finished loading corpus
finished loading descriptions
finished extracting contexts
finished extracting targets (OOCs)


## Model

In [35]:
import tensorflow as tf
from sklearn.utils.class_weight import compute_class_weight
import tensorflow.keras.backend as K

In [36]:
def calculating_class_weights(y_true):
    """Compute per-label "balanced" class weights for a 2-D target matrix.

    Every column of ``y_true`` (except column 0, see below) is treated as its
    own binary problem and passed to scikit-learn's ``compute_class_weight``
    with ``class_weight='balanced'`` and the classes ``[0., 1.]``.

    Args:
        y_true: 2-D array-like indexed as ``y_true[:, i]``; entries are assumed
            to be 0/1 indicators (TODO confirm against the dataset code).

    Returns:
        ``np.ndarray`` of shape ``(n_columns, 2)``. ``weights[i, 0]`` is the
        weight for class 0 and ``weights[i, 1]`` the weight for class 1 of
        column ``i``. Row 0 is left as ``[0, 0]``.
    """
    number_dim = np.shape(y_true)[1]
    weights = np.zeros([number_dim, 2])
    classes = np.array([0., 1.])
    # Column 0 is deliberately not visited, so its weights stay at zero.
    # NOTE(review): presumably index 0 is a reserved/padding label - confirm.
    for i in range(1, number_dim):
        # Keyword arguments are required by current scikit-learn releases
        # (``classes`` and ``y`` are keyword-only) and are also accepted by
        # older releases, so this form works on both.
        weights[i] = compute_class_weight(
            class_weight='balanced', classes=classes, y=y_true[:, i]
        )
    return weights

def get_weighted_loss(weights):
    """Create a binary cross-entropy loss scaled by per-label class weights.

    Args:
        weights: array indexed as ``weights[:, 0]`` and ``weights[:, 1]``;
            column 0 is the factor raised to ``1 - y_true`` and column 1 the
            factor raised to ``y_true``.

    Returns:
        A function ``weighted_loss(y_true, y_pred)`` that multiplies
        ``K.binary_crossentropy(y_true, y_pred)`` by those factors and
        averages the result over the last axis.
    """
    def weighted_loss(y_true, y_pred):
        # For y_true == 0 the first factor is weights[:, 0] and the second is 1;
        # for y_true == 1 it is the other way round.
        scale_for_zero = weights[:, 0] ** (1 - y_true)
        scale_for_one = weights[:, 1] ** y_true
        scaled_bce = scale_for_zero * scale_for_one * K.binary_crossentropy(y_true, y_pred)
        return K.mean(scaled_bce, axis=-1)

    return weighted_loss

# Per-column class weights computed from the training labels; passed to
# get_weighted_loss when the model is built.
class_weights = calculating_class_weights(train_labels)

In [37]:
model = SequenceModel(whole_dialog=True)

# Size arguments derived from the training data and tokenizer.
vocab_size = len(train_ds.input_tokenizer.word_index) + 1
num_labels = train_labels.shape[-1]
input_shape = (train_data.shape[1],)

# Precision / recall restricted to the 5 highest-scoring predictions.
top5_metrics = [
    keras.metrics.Precision(name='precision', top_k=5),
    keras.metrics.Recall(name='recall', top_k=5),
]

model.build_model(
    vocab_size,
    512,  # NOTE(review): meaning is defined by SequenceModel.build_model (not visible here)
    num_labels,
    input_shape,
    loss=get_weighted_loss(class_weights),
    metrics=top5_metrics,
    hidden_sizes=[1024, 1024],
)

In [38]:
# Train on the training split; the validation split is passed alongside.
model.train(
    train_data,
    train_labels,
    val_data=val_data,
    val_labels=val_labels,
    batch_size=32,
    epochs=20,
)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<tensorflow.python.keras.callbacks.History at 0x7f7e23f31f98>

## Evaluation

In [39]:
# Run inference on every split, in train / valid / test order.
train_pred, val_pred, test_pred = (
    model.infer(split_data)
    for split_data in (train_data, val_data, test_data)
)

In [40]:
# Report precision / recall / F1 on the training split for k = 1..10.
for k in range(1, 11):
    scores = top_k_metric(train_pred, train_labels, k)
    precision, recall, f1 = scores[0], scores[1], scores[2]
    print(f"top {k}: \nprecision: {precision}\nrecall: {recall}\nf1: {f1}")

top 1: 
precision: 0.00011611704598235021
recall: 3.8705681994116735e-05
f1: 5.8058522991175105e-05
top 2: 
precision: 0.00023223409196470042
recall: 0.0001644991484749961
f1: 0.00019258436894633693
top 3: 
precision: 0.0005031738659235176
recall: 0.0005844557981111626
f1: 0.0005407776067932465
top 4: 
precision: 0.000841848583372039
recall: 0.0014785570521752593
f1: 0.001072847902746018
top 5: 
precision: 0.004691128657686949
recall: 0.010885143653374028
f1: 0.00655658919099931
top 6: 
precision: 0.0049156216132528255
recall: 0.013608088381660142
f1: 0.007222334336082069
top 7: 
precision: 0.004263154402494857
recall: 0.013801616791630725
f1: 0.00651416203997324
top 8: 
precision: 0.003875406409660938
recall: 0.014255440913011744
f1: 0.006094103171642029
top 9: 
precision: 0.004218919337358724
recall: 0.017415759847831377
f1: 0.006792398941335704
top 10: 
precision: 0.0038086391082210874
recall: 0.01743898325702785
f1: 0.006251880092612975


In [41]:
# Report precision / recall / F1 on the validation split for k = 1..10.
# f1 is printed as nan whenever precision and recall are both 0 (see the
# division warning emitted in this cell's output).
for k in range(1, 11):
    scores = top_k_metric(val_pred, val_labels, k)
    precision, recall, f1 = scores[0], scores[1], scores[2]
    print(f"top {k}: \nprecision: {precision}\nrecall: {recall}\nf1: {f1}")

top 1: 
precision: 0.0
recall: 0.0
f1: nan


  f1 = 2 * (precision * recall) / (precision + recall)


top 2: 
precision: 0.0
recall: 0.0
f1: nan
top 3: 
precision: 0.0
recall: 0.0
f1: nan
top 4: 
precision: 0.0007062146892655367
recall: 0.0018832391713747645
f1: 0.0010272213662044169
top 5: 
precision: 0.0062146892655367235
recall: 0.014940364092906465
f1: 0.00877801807239015
top 6: 
precision: 0.006434400502197111
recall: 0.01854990583804143
f1: 0.009554599740699731
top 7: 
precision: 0.005515200430454667
recall: 0.01854990583804143
f1: 0.008502472211958621
top 8: 
precision: 0.0049435028248587575
recall: 0.018863779033270557
f1: 0.007834001839789356
top 9: 
precision: 0.0051265955220757475
recall: 0.021563088512241053
f1: 0.008283742352801346
top 10: 
precision: 0.004613935969868173
recall: 0.021563088512241053
f1: 0.007601376525897925


In [42]:
# Report precision / recall / F1 on the test split for k = 1..10.
# f1 is printed as nan whenever precision and recall are both 0 (see the
# division warning emitted in this cell's output).
for k in range(1, 11):
    scores = top_k_metric(test_pred, test_labels, k)
    precision, recall, f1 = scores[0], scores[1], scores[2]
    print(f"top {k}: \nprecision: {precision}\nrecall: {recall}\nf1: {f1}")

top 1: 
precision: 0.0
recall: 0.0
f1: nan


  f1 = 2 * (precision * recall) / (precision + recall)


top 2: 
precision: 0.0
recall: 0.0
f1: nan
top 3: 
precision: 0.0006277463904582548
recall: 0.0005492780916509729
f1: 0.0005858966310943711
top 4: 
precision: 0.0014124293785310734
recall: 0.003217200251098556
f1: 0.0019630374413482717
top 5: 
precision: 0.006026365348399247
recall: 0.015756434400502197
f1: 0.008718257650998333
top 6: 
precision: 0.00549278091650973
recall: 0.016933458882611427
f1: 0.00829490638056543
top 7: 
precision: 0.004708097928436911
recall: 0.016933458882611427
f1: 0.007367712349214543
top 8: 
precision: 0.004119585687382297
recall: 0.016933458882611427
f1: 0.006626959309259037
top 9: 
precision: 0.004917346725256329
recall: 0.020652856246076583
f1: 0.007943406248490993
top 10: 
precision: 0.004425612052730698
recall: 0.020652856246076583
f1: 0.007289243380968208
