In [1]:
import numpy as np
from sklearn.preprocessing import OneHotEncoder
import tensorflow as tf
import pandas as pd
from collections import Counter
from ast import literal_eval

# FUNCTIONS

In [2]:
def get_trimmed_glove_vectors(filename):
    """Load a pre-trimmed embedding matrix from an ``.npz`` archive.

    Args:
        filename: path to the npz file. The archive must contain an array
            stored under the key ``"embeddings"``.

    Returns:
        matrix of embeddings (np array)

    Raises:
        IOError: if the file cannot be opened or read. The underlying error
            is chained as ``__cause__``.
        KeyError: if the archive has no ``"embeddings"`` entry.
    """
    try:
        # Indexing the NpzFile reads the array fully into memory, so it is
        # safe to return it from inside the ``with`` that closes the file.
        with np.load(filename) as data:
            return data["embeddings"]

    except IOError as err:
        # No custom error type is defined in this notebook, so re-raise a
        # plain IOError that names the offending path.
        raise IOError(f"Could not load embeddings from file {filename!r}") from err

# CONSTANTS

In [7]:
# All inputs live in the same data directory; change it here only.
DATA_DIR = "../../data"

PATH_TO_TRAIN = f"{DATA_DIR}/train_preprocessed.csv"
PATH_TO_IDS_TRAIN = f"{DATA_DIR}/train_context_ids.npy"
PATH_TO_IDS_TEST = f"{DATA_DIR}/test_context_ids.npy"
PATH_TO_IDS_VAL = f"{DATA_DIR}/val_context_ids.npy"

# LOAD THE PREPROCESSED TRAIN DATA AND SPLIT IT INTO TRAIN / TEST / VAL BY CONTEXT ID

In [9]:
# read_csv does not parse list-like fields such as "[0, 0, 1]"; they come back as plain strings.
data = pd.read_csv(filepath_or_buffer=PATH_TO_TRAIN)

In [10]:
# Preview the first rows of the full preprocessed set.
data.head(n=5)

Unnamed: 0,context_id,context_2,context_1,context_0,reply_id,reply,label,confidence,merged_contexts,contexts_and_reply,weighted_label,initial_labels
0,22579918886,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[1830003, 1463197, 334254, 676100, 917832]","[1293580, 1463197, 1501890, 1000204]",0,"[743031, 1101989]","[0, 0, 1]",0.875352,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[654394, 1561605, 16734, 1554817, 306154, 1561...","[0.0, 0.0, 0.87535161750000001]",good
1,22579918886,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[1830003, 1463197, 334254, 676100, 917832]","[1293580, 1463197, 1501890, 1000204]",1,"[1141967, 1463197, 1431000, 1100672, 334254]","[0, 1, 0]",0.900968,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[654394, 1561605, 16734, 1554817, 306154, 1561...","[0.0, 0.90096821130000004, 0.0]",neutral
2,22579918886,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[1830003, 1463197, 334254, 676100, 917832]","[1293580, 1463197, 1501890, 1000204]",2,"[1448996, 1463197, 334254, 743031, 1688119, 15...","[1, 0, 0]",0.88432,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[654394, 1561605, 16734, 1554817, 306154, 1561...","[0.88432021449999998, 0.0, 0.0]",bad
3,22579918886,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[1830003, 1463197, 334254, 676100, 917832]","[1293580, 1463197, 1501890, 1000204]",3,"[334254, 743031, 1103971, 1431000]","[0, 0, 1]",0.98253,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[654394, 1561605, 16734, 1554817, 306154, 1561...","[0.0, 0.0, 0.98253046730000004]",good
4,22579918886,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[1830003, 1463197, 334254, 676100, 917832]","[1293580, 1463197, 1501890, 1000204]",4,"[334254, 355899, 1430584]","[0, 0, 1]",0.838054,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[654394, 1561605, 16734, 1554817, 306154, 1561...","[0.0, 0.0, 0.83805350959999991]",good


# Load the context ids that define each split

In [8]:
# One .npy file of context ids per split, loaded in train / test / val order.
train_ids, test_ids, val_ids = (
    np.load(ids_path)
    for ids_path in (PATH_TO_IDS_TRAIN, PATH_TO_IDS_TEST, PATH_TO_IDS_VAL)
)

In [13]:
# Rows are assigned to a split by their context_id, so every reply that shares
# a context ends up in the same split.
train = data.loc[data['context_id'].isin(train_ids)]
test = data.loc[data['context_id'].isin(test_ids)]
val = data.loc[data['context_id'].isin(val_ids)]

# isin() can silently match nothing (e.g. ids saved with a different dtype than
# the column), and overlapping id lists would leak contexts across splits.
for split_name, split_df in (("train", train), ("test", test), ("val", val)):
    assert not split_df.empty, f"{split_name} split is empty - check the context id dtype"
assert set(train['context_id']).isdisjoint(test['context_id']), "train and test share context ids"
assert set(train['context_id']).isdisjoint(val['context_id']), "train and val share context ids"
assert set(test['context_id']).isdisjoint(val['context_id']), "test and val share context ids"

In [14]:
# Preview the first rows of the train split.
train.head(n=5)

Unnamed: 0,context_id,context_2,context_1,context_0,reply_id,reply,label,confidence,merged_contexts,contexts_and_reply,weighted_label,initial_labels
0,22579918886,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[1830003, 1463197, 334254, 676100, 917832]","[1293580, 1463197, 1501890, 1000204]",0,"[743031, 1101989]","[0, 0, 1]",0.875352,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[654394, 1561605, 16734, 1554817, 306154, 1561...","[0.0, 0.0, 0.87535161750000001]",good
1,22579918886,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[1830003, 1463197, 334254, 676100, 917832]","[1293580, 1463197, 1501890, 1000204]",1,"[1141967, 1463197, 1431000, 1100672, 334254]","[0, 1, 0]",0.900968,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[654394, 1561605, 16734, 1554817, 306154, 1561...","[0.0, 0.90096821130000004, 0.0]",neutral
2,22579918886,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[1830003, 1463197, 334254, 676100, 917832]","[1293580, 1463197, 1501890, 1000204]",2,"[1448996, 1463197, 334254, 743031, 1688119, 15...","[1, 0, 0]",0.88432,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[654394, 1561605, 16734, 1554817, 306154, 1561...","[0.88432021449999998, 0.0, 0.0]",bad
3,22579918886,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[1830003, 1463197, 334254, 676100, 917832]","[1293580, 1463197, 1501890, 1000204]",3,"[334254, 743031, 1103971, 1431000]","[0, 0, 1]",0.98253,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[654394, 1561605, 16734, 1554817, 306154, 1561...","[0.0, 0.0, 0.98253046730000004]",good
4,22579918886,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[1830003, 1463197, 334254, 676100, 917832]","[1293580, 1463197, 1501890, 1000204]",4,"[334254, 355899, 1430584]","[0, 0, 1]",0.838054,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[654394, 1561605, 16734, 1554817, 306154, 1561...","[0.0, 0.0, 0.83805350959999991]",good


In [15]:
# Preview the first rows of the test split.
test.head(n=5)

Unnamed: 0,context_id,context_2,context_1,context_0,reply_id,reply,label,confidence,merged_contexts,contexts_and_reply,weighted_label,initial_labels
24,127768564286,"[1061960, 397658, 1247951, 1134945, 1147066, 1...","[334254, 1017975, 613707, 200685, 1463197, 613...","[786497, 1504830, 980317, 938763]",0,"[397658, 786497, 1489432, 1376676, 1463197, 90...","[0, 0, 1]",0.867679,"[1061960, 397658, 1247951, 1134945, 1147066, 1...","[1061960, 397658, 1247951, 1134945, 1147066, 1...","[0.0, 0.0, 0.86767855680000006]",good
25,127768564286,"[1061960, 397658, 1247951, 1134945, 1147066, 1...","[334254, 1017975, 613707, 200685, 1463197, 613...","[786497, 1504830, 980317, 938763]",1,"[786497, 1489432]","[0, 1, 0]",0.653608,"[1061960, 397658, 1247951, 1134945, 1147066, 1...","[1061960, 397658, 1247951, 1134945, 1147066, 1...","[0.0, 0.65360824549999996, 0.0]",neutral
26,127768564286,"[1061960, 397658, 1247951, 1134945, 1147066, 1...","[334254, 1017975, 613707, 200685, 1463197, 613...","[786497, 1504830, 980317, 938763]",2,"[1554817, 334254, 1017975, 1134945, 989147, 11...","[0, 0, 1]",0.903552,"[1061960, 397658, 1247951, 1134945, 1147066, 1...","[1061960, 397658, 1247951, 1134945, 1147066, 1...","[0.0, 0.0, 0.90355219860000002]",good
27,127768564286,"[1061960, 397658, 1247951, 1134945, 1147066, 1...","[334254, 1017975, 613707, 200685, 1463197, 613...","[786497, 1504830, 980317, 938763]",3,"[1702518, 786497, 1489432, 884730, 980317]","[1, 0, 0]",0.94458,"[1061960, 397658, 1247951, 1134945, 1147066, 1...","[1061960, 397658, 1247951, 1134945, 1147066, 1...","[0.94457976109999997, 0.0, 0.0]",bad
28,127768564286,"[1061960, 397658, 1247951, 1134945, 1147066, 1...","[334254, 1017975, 613707, 200685, 1463197, 613...","[786497, 1504830, 980317, 938763]",4,"[109082, 1463197, 1702518, 786497, 1489432]","[0, 0, 1]",0.87135,"[1061960, 397658, 1247951, 1134945, 1147066, 1...","[1061960, 397658, 1247951, 1134945, 1147066, 1...","[0.0, 0.0, 0.87134977730000007]",good


In [16]:
# Preview the first rows of the validation split.
val.head(n=5)

Unnamed: 0,context_id,context_2,context_1,context_0,reply_id,reply,label,confidence,merged_contexts,contexts_and_reply,weighted_label,initial_labels
121,521831731666,[1304112],"[1663321, 488315, 1862038]","[1748603, 1098780, 1120957, 166293, 307305]",0,"[642764, 1704903, 743031, 1066137]","[0, 1, 0]",0.936427,"[1304112, 1663321, 488315, 1862038, 1748603, 1...","[1304112, 1663321, 488315, 1862038, 1748603, 1...","[0.0, 0.93642739580000001, 0.0]",neutral
122,521831731666,[1304112],"[1663321, 488315, 1862038]","[1748603, 1098780, 1120957, 166293, 307305]",1,"[1141967, 637337]","[0, 0, 1]",0.586733,"[1304112, 1663321, 488315, 1862038, 1748603, 1...","[1304112, 1663321, 488315, 1862038, 1748603, 1...","[0.0, 0.0, 0.58673283279999999]",good
123,521831731666,[1304112],"[1663321, 488315, 1862038]","[1748603, 1098780, 1120957, 166293, 307305]",2,"[1704903, 1098780, 1059132, 307305, 1463197, 1...","[0, 0, 1]",0.958358,"[1304112, 1663321, 488315, 1862038, 1748603, 1...","[1304112, 1663321, 488315, 1862038, 1748603, 1...","[0.0, 0.0, 0.95835794429999999]",good
124,521831731666,[1304112],"[1663321, 488315, 1862038]","[1748603, 1098780, 1120957, 166293, 307305]",3,"[1776104, 581435, 1064784]","[1, 0, 0]",0.965069,"[1304112, 1663321, 488315, 1862038, 1748603, 1...","[1304112, 1663321, 488315, 1862038, 1748603, 1...","[0.96506939290000004, 0.0, 0.0]",bad
125,521831731666,[1304112],"[1663321, 488315, 1862038]","[1748603, 1098780, 1120957, 166293, 307305]",4,"[1141967, 1463197, 743031, 307305]","[0, 0, 1]",0.865941,"[1304112, 1663321, 488315, 1862038, 1748603, 1...","[1304112, 1663321, 488315, 1862038, 1748603, 1...","[0.0, 0.0, 0.86594092]",good


In [18]:
# Write each split to its own CSV; the DataFrame index is not written.
for split_frame, out_path in (
    (train, "../../data/train_splitted.csv"),
    (test, "../../data/test_splitted.csv"),
    (val, "../../data/val_splitted.csv"),
):
    split_frame.to_csv(out_path, index=False)

# Check the saved datasets by reloading them from disk

In [20]:
# Reload each saved split so the CSV round trip can be verified.
uploaded_train = pd.read_csv("../../data/train_splitted.csv")
uploaded_test = pd.read_csv("../../data/test_splitted.csv")
uploaded_val = pd.read_csv("../../data/val_splitted.csv")

# The head() previews below only show five rows; make sure no rows or columns
# were lost or renamed on the way through disk.
for original_split, reloaded_split in (
    (train, uploaded_train),
    (test, uploaded_test),
    (val, uploaded_val),
):
    assert reloaded_split.shape == original_split.shape
    assert list(reloaded_split.columns) == list(original_split.columns)

In [21]:
# Preview the reloaded train split.
uploaded_train.head(n=5)

Unnamed: 0,context_id,context_2,context_1,context_0,reply_id,reply,label,confidence,merged_contexts,contexts_and_reply,weighted_label,initial_labels
0,22579918886,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[1830003, 1463197, 334254, 676100, 917832]","[1293580, 1463197, 1501890, 1000204]",0,"[743031, 1101989]","[0, 0, 1]",0.875352,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[654394, 1561605, 16734, 1554817, 306154, 1561...","[0.0, 0.0, 0.87535161750000001]",good
1,22579918886,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[1830003, 1463197, 334254, 676100, 917832]","[1293580, 1463197, 1501890, 1000204]",1,"[1141967, 1463197, 1431000, 1100672, 334254]","[0, 1, 0]",0.900968,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[654394, 1561605, 16734, 1554817, 306154, 1561...","[0.0, 0.90096821130000004, 0.0]",neutral
2,22579918886,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[1830003, 1463197, 334254, 676100, 917832]","[1293580, 1463197, 1501890, 1000204]",2,"[1448996, 1463197, 334254, 743031, 1688119, 15...","[1, 0, 0]",0.88432,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[654394, 1561605, 16734, 1554817, 306154, 1561...","[0.88432021449999998, 0.0, 0.0]",bad
3,22579918886,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[1830003, 1463197, 334254, 676100, 917832]","[1293580, 1463197, 1501890, 1000204]",3,"[334254, 743031, 1103971, 1431000]","[0, 0, 1]",0.98253,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[654394, 1561605, 16734, 1554817, 306154, 1561...","[0.0, 0.0, 0.98253046730000004]",good
4,22579918886,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[1830003, 1463197, 334254, 676100, 917832]","[1293580, 1463197, 1501890, 1000204]",4,"[334254, 355899, 1430584]","[0, 0, 1]",0.838054,"[654394, 1561605, 16734, 1554817, 306154, 1561...","[654394, 1561605, 16734, 1554817, 306154, 1561...","[0.0, 0.0, 0.83805350959999991]",good


In [22]:
# Preview the reloaded test split.
uploaded_test.head(n=5)

Unnamed: 0,context_id,context_2,context_1,context_0,reply_id,reply,label,confidence,merged_contexts,contexts_and_reply,weighted_label,initial_labels
0,127768564286,"[1061960, 397658, 1247951, 1134945, 1147066, 1...","[334254, 1017975, 613707, 200685, 1463197, 613...","[786497, 1504830, 980317, 938763]",0,"[397658, 786497, 1489432, 1376676, 1463197, 90...","[0, 0, 1]",0.867679,"[1061960, 397658, 1247951, 1134945, 1147066, 1...","[1061960, 397658, 1247951, 1134945, 1147066, 1...","[0.0, 0.0, 0.86767855680000006]",good
1,127768564286,"[1061960, 397658, 1247951, 1134945, 1147066, 1...","[334254, 1017975, 613707, 200685, 1463197, 613...","[786497, 1504830, 980317, 938763]",1,"[786497, 1489432]","[0, 1, 0]",0.653608,"[1061960, 397658, 1247951, 1134945, 1147066, 1...","[1061960, 397658, 1247951, 1134945, 1147066, 1...","[0.0, 0.65360824549999996, 0.0]",neutral
2,127768564286,"[1061960, 397658, 1247951, 1134945, 1147066, 1...","[334254, 1017975, 613707, 200685, 1463197, 613...","[786497, 1504830, 980317, 938763]",2,"[1554817, 334254, 1017975, 1134945, 989147, 11...","[0, 0, 1]",0.903552,"[1061960, 397658, 1247951, 1134945, 1147066, 1...","[1061960, 397658, 1247951, 1134945, 1147066, 1...","[0.0, 0.0, 0.90355219860000002]",good
3,127768564286,"[1061960, 397658, 1247951, 1134945, 1147066, 1...","[334254, 1017975, 613707, 200685, 1463197, 613...","[786497, 1504830, 980317, 938763]",3,"[1702518, 786497, 1489432, 884730, 980317]","[1, 0, 0]",0.94458,"[1061960, 397658, 1247951, 1134945, 1147066, 1...","[1061960, 397658, 1247951, 1134945, 1147066, 1...","[0.94457976109999997, 0.0, 0.0]",bad
4,127768564286,"[1061960, 397658, 1247951, 1134945, 1147066, 1...","[334254, 1017975, 613707, 200685, 1463197, 613...","[786497, 1504830, 980317, 938763]",4,"[109082, 1463197, 1702518, 786497, 1489432]","[0, 0, 1]",0.87135,"[1061960, 397658, 1247951, 1134945, 1147066, 1...","[1061960, 397658, 1247951, 1134945, 1147066, 1...","[0.0, 0.0, 0.87134977730000007]",good


In [23]:
# Preview the reloaded validation split.
uploaded_val.head(n=5)

Unnamed: 0,context_id,context_2,context_1,context_0,reply_id,reply,label,confidence,merged_contexts,contexts_and_reply,weighted_label,initial_labels
0,521831731666,[1304112],"[1663321, 488315, 1862038]","[1748603, 1098780, 1120957, 166293, 307305]",0,"[642764, 1704903, 743031, 1066137]","[0, 1, 0]",0.936427,"[1304112, 1663321, 488315, 1862038, 1748603, 1...","[1304112, 1663321, 488315, 1862038, 1748603, 1...","[0.0, 0.93642739580000001, 0.0]",neutral
1,521831731666,[1304112],"[1663321, 488315, 1862038]","[1748603, 1098780, 1120957, 166293, 307305]",1,"[1141967, 637337]","[0, 0, 1]",0.586733,"[1304112, 1663321, 488315, 1862038, 1748603, 1...","[1304112, 1663321, 488315, 1862038, 1748603, 1...","[0.0, 0.0, 0.58673283279999999]",good
2,521831731666,[1304112],"[1663321, 488315, 1862038]","[1748603, 1098780, 1120957, 166293, 307305]",2,"[1704903, 1098780, 1059132, 307305, 1463197, 1...","[0, 0, 1]",0.958358,"[1304112, 1663321, 488315, 1862038, 1748603, 1...","[1304112, 1663321, 488315, 1862038, 1748603, 1...","[0.0, 0.0, 0.95835794429999999]",good
3,521831731666,[1304112],"[1663321, 488315, 1862038]","[1748603, 1098780, 1120957, 166293, 307305]",3,"[1776104, 581435, 1064784]","[1, 0, 0]",0.965069,"[1304112, 1663321, 488315, 1862038, 1748603, 1...","[1304112, 1663321, 488315, 1862038, 1748603, 1...","[0.96506939290000004, 0.0, 0.0]",bad
4,521831731666,[1304112],"[1663321, 488315, 1862038]","[1748603, 1098780, 1120957, 166293, 307305]",4,"[1141967, 1463197, 743031, 307305]","[0, 0, 1]",0.865941,"[1304112, 1663321, 488315, 1862038, 1748603, 1...","[1304112, 1663321, 488315, 1862038, 1748603, 1...","[0.0, 0.0, 0.86594092]",good
