In [None]:
%load_ext autoreload
# Mode 2: re-import edited local modules before every cell executes.
%autoreload 2

In [None]:
import os
from pathlib import Path
import wget

import zipfile

W2V_MODEL_PATH = 'models/'
W2V_MODEL_NAME = 'wiki-news-300d-1M.vec.zip'  # 1.6G

directory = os.path.dirname(W2V_MODEL_PATH)
# `w2v_archive` is the file as downloaded; `w2v_model_file` is the file the
# loader cell actually reads. For a `.zip` name that is the unpacked file
# (same name without `.zip`), otherwise it is the download itself.
w2v_archive = os.path.join(directory, W2V_MODEL_NAME)
if W2V_MODEL_NAME.endswith('.zip'):
    w2v_model_file = w2v_archive[:-len('.zip')]
else:
    w2v_model_file = w2v_archive

if not Path(directory).is_dir():
    print(f'Creating directory at {directory} for saving word2vec pre-trained model')
    os.makedirs(directory, exist_ok=True)

# Only do work when the loadable model file is missing. This keeps re-runs cheap.
if not Path(w2v_model_file).is_file():
    if not Path(w2v_archive).is_file():
        url = f'https://dl.fbaipublicfiles.com/fasttext/vectors-english/{W2V_MODEL_NAME}'
        print(f'Downloading word2vec pre-trained model to {w2v_archive}')
        wget.download(url, w2v_archive)
    if w2v_model_file != w2v_archive:
        # Assumes the archive stores the model under its own name minus `.zip`
        # (the path the loader cell reads) — TODO confirm for other archives.
        print(f'Extracting {w2v_archive} into {directory}')
        with zipfile.ZipFile(w2v_archive) as archive:
            archive.extractall(directory)

In [None]:
from gensim.models import KeyedVectors
from gensim.models import Word2Vec
from gensim.models.wrappers import FastText


# Pick the gensim loader from the file extension. Every branch yields a
# KeyedVectors object, so downstream code can use `word in word2vec_model`
# and `word2vec_model[word]` uniformly.
if W2V_MODEL_NAME.endswith('.zip'):
    # A `.zip` archive must first be unpacked next to it. The unpacked file
    # (name without `.zip`) is loaded, and its own suffix decides the format.
    _w2v_unpacked = W2V_MODEL_NAME[:-len('.zip')]
    word2vec_model = KeyedVectors.load_word2vec_format(W2V_MODEL_PATH + _w2v_unpacked,
                                                       binary=_w2v_unpacked.endswith('.bin'))
elif W2V_MODEL_NAME.endswith(('.vec', '.bin', '.bin.gz')):
    # word2vec text format for `.vec`, binary format for `.bin` / `.bin.gz`.
    word2vec_model = KeyedVectors.load_word2vec_format(W2V_MODEL_PATH + W2V_MODEL_NAME,
                                                       binary=not W2V_MODEL_NAME.endswith('.vec'))
else:
    # A full gensim Word2Vec model: keep only its word vectors.
    word2vec_model = Word2Vec.load(W2V_MODEL_PATH + W2V_MODEL_NAME).wv

# Read the dimensionality directly instead of probing a word that may be out of vocabulary.
word2vec_vector_length = word2vec_model.vector_size

# Prepare a feature-rich dataset (`data/dataset.pkl`) from CoreNLP annotations

### Triplets exploration

In [None]:
# Candidate English prepositions, whitespace-separated. The separator between
# "over" and "through" is a tab; str.split() treats it like a space.
preps = ("above across after against along among around at away before behind below "
         "beneath beside between by down during for from in front inside into near next "
         "of off on onto out outside over\tthrough till to toward under underneath until up")

In [None]:
# Show the individual preposition tokens.
print(str.split(preps))

In [None]:
import networkx as nx
import multiprocessing
import numpy as np
from iteration_utilities import unique_everseen


def _extract_plain_features(document):
    def _extract(sentence):

        def get_postags_sequence(span, words, predicate=False):
            columns = ['JJ', 'CD', 'VBD', '', 'RB', 'VBN', 'PRP', 'IN', 'VBP', 'TO', 'NNP', 'VB',
                       'VBZ', 'VBG', 'POS', 'NNS', 'NN', 'MD']

            sequence = [token['pos'] for token in sentence['tokens'][span[0]:span[1]]
                        if token['originalText'] in words][:3]

#             if predicate or {'NN', 'NNP', 'NNS', 'CD'}.intersection(set(sequence)):
            if predicate or 'NNP' in set(sequence) or 'CD' in set(sequence):
                sequence = [[int(column == postag) for column in columns] for postag in sequence]
            else:
                sequence = []

            result = np.zeros((3, len(columns)))

            if sequence:
                result[:len(sequence)] = sequence

            return result

        def get_ner_occurrences(span, words, obj=True):
            _ner_kinds = ['TITLE', 'COUNTRY', 'DATE', 'PERSON', 'ORGANIZATION', 'MISC',
                          'LOCATION', 'NUMBER', 'CAUSE_OF_DEATH', 'NATIONALITY', 'ORDINAL',
                          'DURATION', 'CRIMINAL_CHARGE', 'CITY', 'RELIGION',
                          'STATE_OR_PROVINCE', 'IDEOLOGY', 'SET', 'URL', 'PERCENT', 'TIME',
                          'MONEY', 'HANDLE']

            mentions = [token['ner'] for token in sentence['tokens'][span[0]:span[1]]
                        if token['originalText'] in words]

            mentions = [[int(_ner_kind == mention) for _ner_kind in _ner_kinds] for mention in mentions][:3]
            result = np.zeros((3, len(_ner_kinds)))

            if mentions:
                result[:len(mentions)] = mentions

            return result
        
        def get_prep_sequence(words):
            _prep_kinds = ['above', 'across', 'after', 'against', 'along', 'among', 'around', 'at', 'away', 
                           'before', 'behind', 'below', 'beneath', 'beside', 'between', 'by', 
                           'down', 'during', 'for', 'from', 'in', 'front', 'inside', 'into', 
                           'near', 'next', 'of', 'off', 'on', 'onto', 'out', 'outside', 'over', 
                           'through', 'till', 'to', 'toward', 'under', 'underneath', 'until', 'up']
            
            words = words.split(' ')
            mentions = [int(prep in words) for prep in _prep_kinds]
            result = np.zeros(len(_prep_kinds))
            
            if mentions:
                result[:len(mentions)] = mentions
                
            return result

        def tag_lemma(span, words, tag=False):
            if tag:
                return [token['lemma'].lower() + '_' + _penn_tagset[token['pos']]['fPOS'] for token in
                        sentence['tokens'][span[0]:span[1]]
                        if token['originalText'] in words]
            else:
                return [token['lemma'].lower() for token in sentence['tokens'][span[0]:span[1]]
                        if token['originalText'] in words]

        def remove_repetition(words):
            if words[:len(words) // 2] == words[len(words) // 2:]:
                return words[:len(words) // 2]
            return words

        def get_tokens(words, span):
            return [token['originalText'].lower() for token in sentence['tokens'][span[0]:span[1]]
                    if token['originalText'] in words]

        def _build_dep_path(dependencies, tokens, start: int, end: int):
            edges = []
            deps = {}

            for edge in dependencies:
                edges.append((edge['governor'], edge['dependent']))
                deps[(min(edge['governor'], edge['dependent']),
                      max(edge['governor'], edge['dependent']))] = edge

            graph = nx.Graph(edges)
            try:
                path = nx.shortest_path(graph, source=start, target=end)
                return path[:-1]  # exclude right end
            except:
                return [start, ]

        def _tokens_by_index(indexes, tokens):
            return [token['originalText'] for token in tokens if token['index'] in indexes]

        def _lemmas_by_index(indexes, tokens):
            return [token['lemma'].lower() for token in tokens if token['index'] in indexes]

        def _embed(placeholder, words):
            for j in range(len(words)):
                if j == len(placeholder):
                    break

                word = words[j]
                if word and word in word2vec_model:
                    placeholder[j, :] = word2vec_model[word]
            return placeholder

        def _embed_arg(row):
            result = []
            result.append(_embed(np.zeros((3, word2vec_vector_length)), row['lemmas']))

            return result

        #         deprecated = set(['one', 'he', 'she', 'they', 'his', 'her', 'its', 'our', 'day', 'co.', 'inc.', 
        #               'society', 'people', 'inventor', 'head', 'poet', 'doctor', 'teacher', 'inventor', 
        #               'thanksgiving day', 'halloween',
        #               'sales person', 'model', 'board', 'technology', 'owner', 'one', 'two', 'university', 
        #                           'fbi', 'patricia churchland', 'century', 'association', 'laboratory', 'academy'])
        deprecated = []
        deprec_rels = []

        triplets = sentence['openie']
#         filtered_triplets = filter(
#             lambda obj: obj['object'].lower() not in deprecated and obj['subject'].lower() not in deprecated, 
#             triplets)
        filtered_triplets = filter(lambda obj: obj['subject'].lower().strip() not in deprecated, 
                                   triplets)
        filtered_triplets = filter(lambda obj: obj['object'].lower().strip() not in deprecated, 
                                   filtered_triplets)
        filtered_triplets = filter(lambda obj: obj['relation'].lower().strip() not in deprec_rels, 
                                   filtered_triplets)
        filtered_triplets = filter(
            lambda obj: len(obj['object']) > 2 and len(obj['subject']) > 2 and len(obj['relation']) > 2,
            filtered_triplets)
        filtered_triplets = filter(lambda obj: len(obj['relation'].split()) < 4, filtered_triplets)
        filtered_triplets = filter(lambda obj: len(obj['subject'].split()) < 4, filtered_triplets)
        filtered_triplets = filter(lambda obj: len(obj['object'].split()) < 4, filtered_triplets)
        filtered_triplets = list(filtered_triplets)

        subjects, relations, objects, dep_path = [], [], [], []

        for triplet in filtered_triplets:
            _subject = {
                'tokens': get_tokens(triplet['subject'], triplet['subjectSpan']),
                'lemmas': tag_lemma(triplet['subjectSpan'], triplet['subject']),
                'dist_to_rel': triplet['relationSpan'][0] - triplet['subjectSpan'][0],
                'rel_pos': triplet['subjectSpan'][0] / len(sentence['tokens']),
                'ner': get_ner_occurrences(triplet['subjectSpan'], triplet['subject']),
                'postag': get_postags_sequence(triplet['subjectSpan'], triplet['subject'], predicate=False),
            }
            _subject.update({
                'w2v': _embed(np.zeros((3, word2vec_vector_length)), _subject['lemmas']),
            })

            _relation = {
                'tokens': get_tokens(triplet['relation'], triplet['relationSpan']),
                'lemmas': tag_lemma(triplet['relationSpan'], triplet['relation']),
#                 'dist_to_rel': 0,
                'rel_pos': triplet['relationSpan'][0] / len(sentence['tokens']),
                'ner': get_ner_occurrences(triplet['relationSpan'], triplet['relation']),
                'postag': get_postags_sequence(triplet['relationSpan'], triplet['relation'], predicate=True),
                'prep': get_prep_sequence(triplet['relation']),
            }
            _relation.update({
                'w2v': _embed(np.zeros((3, word2vec_vector_length)), _relation['lemmas']),
            })

            _object = {
                'tokens': get_tokens(triplet['object'], triplet['objectSpan']),
                'lemmas': tag_lemma(triplet['objectSpan'], triplet['object']),
                'dist_to_rel': triplet['relationSpan'][0] - triplet['objectSpan'][0],
                'rel_pos': triplet['objectSpan'][0] / len(sentence['tokens']),
                'ner': get_ner_occurrences(triplet['objectSpan'], triplet['object']),
                'postag': get_postags_sequence(triplet['objectSpan'], triplet['object'], predicate=False),
            }
            _object.update({
                'w2v': _embed(np.zeros((3, word2vec_vector_length)), _object['lemmas']),
            })

            _dependency_path = ' '.join(_lemmas_by_index(_build_dep_path(sentence['basicDependencies'],
                                                                         sentence['tokens'],
                                                                         triplet['subjectSpan'][0],
                                                                         triplet['objectSpan'][-1]),
                                                         sentence['tokens']))
            subjects.append(_subject)
            relations.append(_relation)
            objects.append(_object)
            dep_path.append(_dependency_path)

        # return pd.DataFrame(result, columns=header)
        return subjects, relations, objects

    subjects, relations, objects = [], [], []
    for sentence in document:
        _subject, _relation, _object = _extract(sentence)
        subjects += _subject
        relations += _relation
        objects += _object

    return subjects, relations, objects


def _mark_ner_object(row):
    return row['relation'] + (row['DATE_obj'] == 1) * ' date' \
           + (row['LOCATION_obj'] == 1) * ' location'


def _extract_features(document):
    """Wrap the per-triplet features of one document in a DataFrame.

    Args:
        document: iterable of sentence annotations, passed unchanged to
            ``_extract_plain_features``.

    Returns:
        DataFrame with columns ``subject``, ``relation`` and ``object`` holding
        one feature dict per kept triplet (possibly zero rows).
    """
    # Imported locally: this cell never imports pandas itself, and the function
    # is executed by FeaturesProcessor's multiprocessing pool.
    import pandas as pd

    subjects, relations, objects = _extract_plain_features(document)
    return pd.DataFrame({'subject': subjects, 'relation': relations, 'object': objects})


def remove_repetitions(annot):
    """Drop duplicate OpenIE triplets from every sentence, keeping first occurrences.

    ``annot`` is a sequence of documents, each a sequence of sentence dicts.
    Each sentence's ``'openie'`` list is replaced in place, and ``annot``
    itself is returned.
    """
    for document in annot:
        for sentence in document:
            sentence['openie'] = list(unique_everseen(sentence['openie']))
    return annot


class FeaturesProcessor:
    """Extract triplet features from annotated documents and drop uninformative rows.

    ``_extract_features`` runs in a single-worker ``multiprocessing.Pool``
    created in ``__init__``. Call :meth:`close`, or use the instance as a
    context manager (``with FeaturesProcessor() as extr: ...``), to shut the
    worker process down.
    """

    # Relations that carry little meaning on their own (bare prepositions,
    # reporting verbs, auxiliaries) or look like tokenisation noise.
    # 'is', 'was' and 'had' are deliberately NOT listed.
    DEPRECATED_RELATIONS = frozenset({
        'in', 'of', "'s", 'to', 'for', 'by', 'with', 'also', 'as of',
        'said', 'said in', 'felt', 'on', 'gave', 'saw', 'found', 'did',
        'at', 'as', 'e', 'de', 'mo', '’s', 'v', 'yr', 'al',
        "'", 'na', 'v.', "d'", 'et', 'mp', 'di', 'y',
        'ne', 'c.', 'be', 'ao', 'mi', 'im', 'h',
        'has', 'between', 'are', 'returned', 'began', 'became',
        'along', 'doors as', 'subsequently terrytoons in',
    })

    def __init__(self):
        self.pool = multiprocessing.Pool(processes=1)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, traceback):
        self.close()
        return False  # never swallow exceptions

    def close(self):
        """Stop accepting work and wait for the worker process to exit."""
        self.pool.close()
        self.pool.join()

    @staticmethod
    def _join_tokens(argument):
        """Space-join an argument's ``tokens`` list, e.g. ``['new', 'york']`` -> ``'new york'``."""
        return ' '.join(argument['tokens'])

    def _is_garbage(self, row):
        """Return True for rows that should be dropped from the feature set.

        A row is garbage when its relation string is purely numeric or listed
        in ``DEPRECATED_RELATIONS``, or when the subject, object or relation
        has an all-zero ``postag`` matrix (no POS feature was recorded for it).
        """
        relation = row['_relation']
        if relation.isdigit() or relation in self.DEPRECATED_RELATIONS:
            return True
        return any(not np.any(row[part]['postag']) for part in ('subject', 'object', 'relation'))

    def __call__(self, data):
        """Build the filtered feature frame for a batch of documents.

        Args:
            data: iterable of documents. Each one is passed to
                ``_extract_features`` in the worker pool (the loop in this
                notebook passes the ``values()`` of a loaded JSON dict).

        Returns:
            DataFrame with the ``subject``/``relation``/``object`` feature dicts
            plus their space-joined ``_subject``/``_relation``/``_object`` token
            strings, with garbage rows removed.

        Raises:
            ValueError: if ``data`` is empty or yields no triplets at all. The
                processing loop reports such files as having no examples.
        """
        # Imported locally: this cell does not import pandas itself.
        import pandas as pd

        features = pd.concat(self.pool.map(_extract_features, data))  # ValueError if data is empty
        if features.empty:
            raise ValueError('No triplets could be extracted from data')

        for part in ('subject', 'relation', 'object'):
            features['_' + part] = features[part].map(self._join_tokens)

        garbage = features.apply(self._is_garbage, axis=1).to_numpy(dtype=bool)
        return features[~garbage]


In [None]:
from glob import glob
from tqdm.autonotebook import tqdm
import pandas as pd
import json

import os

# DATA_PATH = 'data/corenlp_annotations_ner_pairs'  # 'data/filtered_annotations'
trex_path = 'trex_data'
DATA_PATH = 'trex_corenlp_annotations'
RESULT_PATH = 'data/processed_separately'
# Unlike `! mkdir`, this also creates the missing parent `data/` and does not
# fail when the directory already exists (e.g. on a re-run).
os.makedirs(RESULT_PATH, exist_ok=True)
result = []
extr = FeaturesProcessor()


def extract_matrix(row):
    """Stack one argument's NER, POS, word2vec and position features column-wise.

    NOTE(review): unused in this cell, and redefined with a different signature
    in the feature-matrix cell further down.
    """
    _matrix = np.concatenate(
        [row['ner'], row['postag'], row['w2v'], np.array([[row['dist_to_rel'], row['rel_pos']]] * 3)], axis=1)
    return _matrix


def get_tokens(column):
    """Space-join the ``tokens`` list of one argument's feature dict.

    Kept as a module-level name because other cells may look it up globally.
    """
    return ' '.join(column['tokens'])


try:
    # Sorted so files are processed (and `result` ends up) in a reproducible order.
    for file in tqdm(sorted(glob(DATA_PATH + '/*.json'))):
        with open(file, 'r') as handle:
            tmp = json.load(handle)

        if not tmp.values():
            print('Unable to load examples from file:', file)
            continue

        try:
            result = extr(tmp.values())
            result = result.drop_duplicates(['_subject', '_relation', '_object'])
            result.to_pickle(file.replace(DATA_PATH, RESULT_PATH).replace('.json', '.pkl'))
        except ValueError:
            print('No examples in file:', file)
finally:
    # Release the feature-extraction worker process even if the loop fails.
    extr.pool.close()
    extr.pool.join()

In [None]:
# Distinct relation strings in the last successfully processed file.
result['_relation'].unique()

In [None]:
from glob import glob
from tqdm.autonotebook import tqdm
import pandas as pd
import json


RESULT_PATH = 'data/processed_separately'

# Sorted so row order, and therefore the feature matrix and k-means input,
# is reproducible across runs and machines.
processed_files = sorted(glob(RESULT_PATH + '/*.pkl'))
if not processed_files:
    raise FileNotFoundError(f'No processed .pkl files found in {RESULT_PATH}')

data = pd.concat([pd.read_pickle(path) for path in tqdm(processed_files)])

In [None]:
# (rows, columns) of the combined triplet frame.
len(data.index), len(data.columns)

In [None]:
data.head(n=5)

In [None]:
import numpy as np


def extract_matrix(row, predicate=False):
    """Flatten one argument's feature dict into a 1-D vector.

    The ``ner`` and ``postag`` matrices are placed side by side. For the
    predicate (relation), the ``w2v`` matrix and the ``prep`` vector repeated
    on each of the 3 rows are appended as well. The result is flattened row by
    row.
    """
    blocks = [row['ner'], row['postag']]
    if predicate:
        blocks += [row['w2v'], np.tile(row['prep'], (3, 1))]
    return np.concatenate(blocks, axis=1).flatten()

def extract_one_matrix(row):
    """Concatenate the subject, relation and object vectors of one triplet row."""
    return np.hstack([
        extract_matrix(row['subject']),
        extract_matrix(row['relation'], predicate=True),
        extract_matrix(row['object']),
    ])


# One flattened feature vector per triplet row, stacked into a 2-D matrix.
features = np.stack(data.apply(extract_one_matrix, axis=1).to_numpy())

In [None]:
np.shape(features)

In [None]:
# Save through an open handle: given a bare filename without a .npy suffix,
# np.save would append ".npy", and the np.load cell below would not find it.
with open('train_features_plain.pkl', 'wb') as out_file:
    np.save(out_file, features)

In [None]:
import numpy as np

# The file holds .npy-format bytes written by np.save, despite its .pkl name.
with open('train_features_plain.pkl', 'rb') as in_file:
    features = np.load(in_file)

### Train a simple k-means model

In [None]:
from sklearn.cluster import KMeans

KMEANS_N_CLUSTERS = 15
# Fixed seed so the k-means++ initialisation, and therefore the cluster ids
# inspected below (e.g. classes 3 and 14), is reproducible across runs.
KMEANS_SEED = 42

kmeans = KMeans(init='k-means++', n_clusters=KMEANS_N_CLUSTERS, n_init=10, random_state=KMEANS_SEED)
# `features` is already a 2-D float array; no need to round-trip through a Python list.
kmeans.fit(features)

In [None]:
import pickle

# The context manager guarantees the file is flushed and closed after dumping.
with open('simple_kmeans.pkl', 'wb') as model_file:
    pickle.dump(kmeans, model_file)

In [None]:
# Cluster id per triplet. Pass the float matrix directly instead of a nested Python list.
data['class'] = kmeans.predict(features)

In [None]:
data.tail(n=5)

In [None]:
# Export the readable triplet strings with their cluster id as a TSV.
data.loc[:, ["_subject", "_relation", "_object", "class"]].to_csv("trex_data_classified.csv", sep="\t")

In [None]:
number = 3  # cluster id to inspect
data.loc[data['class'] == number, '_relation'].value_counts()

In [None]:
data.loc[data['class'] == 14, '_relation'].value_counts()

In [None]:
data.loc[data['class'] == 3].head(n=10)

In [None]:
data.loc[data['class'] == 3, '_relation'].unique()