In [1]:
#python -u OWL2Vec_Plus.py --walker wl --walk_depth 4 --URI_Doc yes --Lit_Doc yes --Embed_Out_URI no --Embed_Out_Words yes


In [2]:
import numpy as np
import pandas as pd
import random
import re
import multiprocessing
import gensim
import sys
from tqdm import tqdm
from nltk import word_tokenize
from owl2vec_star.lib.Evaluator import Evaluator
from owl2vec_star.lib.RDF2Vec_Embed import get_rdf2vec_walks

In [3]:
class AttributeDict(dict):
    def __getattr__(self, attr):
        return self[attr]

    def __setattr__(self, attr, value):
        self[attr] = value

# Usage
FLAGS = AttributeDict()
FLAGS['onto_file'] = "files/foodon-merged.train.owl"
FLAGS['train_file'] = "files/train.csv"
FLAGS['valid_file'] = "files/valid.csv"
FLAGS['test_file'] = "files/test.csv"
FLAGS['class_file'] = "files/classes.txt"
FLAGS['inferred_ancestor_file'] = "files/inferred_ancestors.txt"
FLAGS["embedsize"] = 100

FLAGS["URI_Doc"] ="yes"
FLAGS["Lit_Doc"] ="no"
FLAGS["Mix_Doc"] ="no"
FLAGS["Mix_Type"] ="random"
FLAGS["Embed_Out_URI"] ="yes"
FLAGS["Embed_Out_Words"] ="yes"

FLAGS["input_type"] ="concatenate"
FLAGS["walk_depth"] = 4
FLAGS["walker"] ="wl"
FLAGS["axiom_file"] ='files/axioms.txt'
FLAGS["annotation_file"] ='files/annotations.txt'

classes = [line.strip() for line in open(FLAGS.class_file).readlines()]
candidate_num = len(classes)

In [4]:

def URI_parse(uri):
    """Parse a URI: remove the prefix, parse the name part (Camel cases are plit)"""
    uri = re.sub("http[a-zA-Z0-9:/._-]+#", "", uri)
    uri = uri.replace('_', ' ').replace('-', ' ').replace('.', ' ').replace('/', ' '). \
        replace('"', ' ').replace("'", ' ')
    words = []
    for item in uri.split():
        matches = re.finditer('.+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)', item)
        for m in matches:
            word = m.group(0)
            if word.isalpha():
                words.append(word.lower())
    return words


def embed(model, instances):

    def word_embeding(inst):
        v = np.zeros(model.vector_size)
        if inst in uri_label:
            words = uri_label.get(inst)
            n = 0
            for word in words:
                if word in model.wv.index_to_key:
                    v += model.wv.get_vector(word)
                    n += 1
            return v / n if n > 0 else v
        else:
            return v

    feature_vectors = []
    for instance in instances:
        if FLAGS.Embed_Out_Words.lower() == 'yes' and FLAGS.Embed_Out_URI.lower() == 'yes':
            v_uri = model.wv.get_vector(instance) if instance in model.wv.index_to_key else np.zeros(model.vector_size)
            v_word = word_embeding(inst=instance)
            feature_vectors.append(np.concatenate((v_uri, v_word)))

        elif FLAGS.Embed_Out_Words.lower() == 'no' and FLAGS.Embed_Out_URI.lower() == 'yes':
            v_uri = model.wv.get_vector(instance) if instance in model.wv.index_to_key else np.zeros(model.vector_size)
            feature_vectors.append(v_uri)

        elif FLAGS.Embed_Out_Words.lower() == 'yes' and FLAGS.Embed_Out_URI.lower() == 'no':
            v_word = word_embeding(inst=instance)
            feature_vectors.append(v_word)

        else:
            print("Unknown embed out type")
            sys.exit(0)

    return feature_vectors


def pre_process_words(words):
    text = ' '.join([re.sub(r'https?:\/\/.*[\r\n]*', '', word, flags=re.MULTILINE) for word in words])
    tokens = word_tokenize(text)
    processed_tokens = [token.lower() for token in tokens if token.isalpha()]
    return processed_tokens


# Extract corpus and learning embedding

In [5]:

uri_label = dict()
annotations = list()
for line in open(FLAGS.annotation_file, encoding="utf8").readlines():
    tmp = line.strip().split()
    if tmp[1] == 'http://www.w3.org/2000/01/rdf-schema#label':
        uri_label[tmp[0]] = pre_process_words(tmp[2:])
    elif tmp[0] in classes:
        annotations.append(tmp)

In [6]:

walk_sentences, axiom_sentences = list(), list()
if FLAGS.URI_Doc.lower() == 'yes':
    walks_ = get_rdf2vec_walks(onto_file=FLAGS.onto_file, walker_type=FLAGS.walker,
                               walk_depth=FLAGS.walk_depth, classes=classes)
    print('Extracted {} walks for {} classes!'.format(len(walks_), len(classes)))
    walk_sentences += [list(map(str, x)) for x in walks_]
    for line in open(FLAGS.axiom_file).readlines():
        axiom_sentence = [item for item in line.strip().split()]
        axiom_sentences.append(axiom_sentence)
    print('Extracted %d axiom sentences' % len(axiom_sentences))
URI_Doc = walk_sentences + axiom_sentences

Extracted 2218855 walks for 28182 classes!
Extracted 34184 axiom sentences


In [7]:
Lit_Doc = list()
if FLAGS.Lit_Doc.lower() == 'yes':
    for annotation in annotations:
        processed_words = pre_process_words(annotation[2:])
        if len(processed_words) > 0:
            Lit_Doc.append(uri_label[annotation[0]] + processed_words)
    print('Extracted %d literal annotations' % len(Lit_Doc))

    for sentence in walk_sentences:
        lit_sentence = list()
        for item in sentence:
            if item in uri_label:
                lit_sentence += uri_label[item]
            elif item.startswith('http://www.w3.org'):
                lit_sentence += [item.split('#')[1].lower()]
            else:
                lit_sentence += [item]
        Lit_Doc.append(lit_sentence)

    for sentence in axiom_sentences:
        lit_sentence = list()
        for item in sentence:
            lit_sentence += uri_label[item] if item in uri_label else [item.lower()]
        Lit_Doc.append(lit_sentence)

In [8]:

Mix_Doc = list()
if FLAGS.Mix_Doc.lower() == 'yes':
    for sentence in walk_sentences:
        if FLAGS.Mix_Type.lower() == 'all':
            for index in range(len(sentence)):
                mix_sentence = list()
                for i, item in enumerate(sentence):
                    if i == index:
                        mix_sentence += [item]
                    else:
                        if item in uri_label:
                            mix_sentence += uri_label[item]
                        elif item.startswith('http://www.w3.org'):
                            mix_sentence += [item.split('#')[1].lower()]
                        else:
                            mix_sentence += [item]
                Mix_Doc.append(mix_sentence)
        elif FLAGS.Mix_Type.lower() == 'random':
            random_index = random.randint(0, len(sentence)-1)
            mix_sentence = list()
            for i, item in enumerate(sentence):
                if i == random_index:
                    mix_sentence += [item]
                else:
                    if item in uri_label:
                        mix_sentence += uri_label[item]
                    elif item.startswith('http://www.w3.org'):
                        mix_sentence += [item.split('#')[1].lower()]
                    else:
                        mix_sentence += [item]
            Mix_Doc.append(mix_sentence)

    for sentence in axiom_sentences:
        if FLAGS.Mix_Type.lower() == 'all':
            for index in range(len(sentence)):
                random_index = random.randint(0, len(sentence) - 1)
                mix_sentence = list()
                for i, item in enumerate(sentence):
                    if i == random_index:
                        mix_sentence += [item]
                    else:
                        mix_sentence += uri_label[item] if item in uri_label else [item.lower()]
                Mix_Doc.append(mix_sentence)
        elif FLAGS.Mix_Type.lower() == 'random':
            random_index = random.randint(0, len(sentence)-1)
            mix_sentence = list()
            for i, item in enumerate(sentence):
                if i == random_index:
                    mix_sentence += [item]
                else:
                    mix_sentence += uri_label[item] if item in uri_label else [item.lower()]
            Mix_Doc.append(mix_sentence)

In [9]:

print('URI_Doc: %d, Lit_Doc: %d, Mix_Doc: %d' % (len(URI_Doc), len(Lit_Doc), len(Mix_Doc)))
all_doc = URI_Doc + Lit_Doc + Mix_Doc
random.shuffle(all_doc)

URI_Doc: 2253039, Lit_Doc: 0, Mix_Doc: 0


In [10]:
pd.DataFrame(URI_Doc)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,15,16,17,18,19,20,21,22,23,24
0,http://purl.obolibrary.org/obo/FOODON_03305399,http://purl.obolibrary.org/obo/IAO_0000114,b'u\x1a\xbeX\\\xec\x85\x96\xec9D\x05\xa2)q\x9d',http://purl.obolibrary.org/obo/IAO_0000114,"b""\xa9\x16\xbd\x01\x97:\x1f\xb5$ n\xe1q't\x04""",,,,,,...,,,,,,,,,,
1,http://purl.obolibrary.org/obo/FOODON_03542476,http://purl.obolibrary.org/obo/IAO_0000114,b'F\xfbpd\xe9 =\xc9\r*\xf6\xdc\xce5Q/',http://purl.obolibrary.org/obo/IAO_0000114,b'P\x06\x97\xac\xa2\xa98E\tp\xf7h\xaf\xc0\xa5\...,,,,,,...,,,,,,,,,,
2,http://purl.obolibrary.org/obo/FOODON_03414141,http://www.geneontology.org/formats/oboInOwl#h...,b'\x0f0M\x17_=\x1da\xa3\xdd\xec\x919(~\x19',http://www.geneontology.org/formats/oboInOwl#h...,b'm\xf3\xeb\x01\xca\x0b\xf4\xb03\x13a!l\xe5\xa...,,,,,,...,,,,,,,,,,
3,http://purl.obolibrary.org/obo/FOODON_03400829,http://www.w3.org/1999/02/22-rdf-syntax-ns#type,b'\xcdy\xf2\x1dr\x89(X+\xbd \xa2!\xd4>Q',,,,,,,,...,,,,,,,,,,
4,http://purl.obolibrary.org/obo/FOODON_03315876,http://www.w3.org/2000/01/rdf-schema#subClassOf,http://purl.obolibrary.org/obo/FOODON_00002007,http://www.w3.org/2000/01/rdf-schema#subClassOf,http://purl.obolibrary.org/obo/FOODON_00001792,,,,,,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2253034,http://purl.obolibrary.org/obo/FOODON_03303025,SubClassOf,http://purl.obolibrary.org/obo/RO_0001000,some,http://purl.obolibrary.org/obo/FOODON_03411457,,,,,,...,,,,,,,,,,
2253035,http://purl.obolibrary.org/obo/FOODON_03303025,SubClassOf,http://purl.obolibrary.org/obo/RO_0002350,some,http://purl.obolibrary.org/obo/FOODON_03400212,,,,,,...,,,,,,,,,,
2253036,http://purl.obolibrary.org/obo/FOODON_03413352,EquivalentTo,http://purl.obolibrary.org/obo/FOODON_00001303,some,http://purl.obolibrary.org/obo/NCBITaxon_394708,and,http://purl.obolibrary.org/obo/FOODON_00001303,only,http://purl.obolibrary.org/obo/NCBITaxon_394708,,...,,,,,,,,,,
2253037,http://purl.obolibrary.org/obo/FOODON_03413352,SubClassOf,http://purl.obolibrary.org/obo/FOODON_03411084,,,,,,,,...,,,,,,,,,,


In [None]:
pd.DataFrame([t[2] for t in URI_Doc if t[1] == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"]).value_counts()

IndexError: list index out of range

In [14]:
model_ = gensim.models.Word2Vec(all_doc, vector_size=FLAGS.embedsize, window=5, workers=multiprocessing.cpu_count(),
                                    sg=1, epochs=10, negative=25, min_count=1, seed=42)

In [None]:
classes.extend(["http://www.w3.org/2002/07/owl#Class", "http://www.w3.org/2002/07/owl#ObjectProperty"])

In [15]:

classes_e = embed(model=model_, instances=classes)
new_embedsize = classes_e[0].shape[0]


In [19]:
#pd.DataFrame(classes_e).to_csv(r'D:\.VScode\projects\Owls2Vec_star\Workspace\TTL\OWL2VEC.csv', index=False, header=False)

In [18]:
classes_e = pd.read_csv(r'D:\.VScode\projects\Owls2Vec_star\Workspace\TTL\OWL2VEC.csv', header=None).to_numpy()
new_embedsize = classes_e[0].shape[0]

# 2.Train and test

In [5]:
train_samples = [line.strip().split(',') for line in open(FLAGS.train_file).readlines()]
valid_samples = [line.strip().split(',') for line in open(FLAGS.valid_file).readlines()]
test_samples = [line.strip().split(',') for line in open(FLAGS.test_file).readlines()]
random.shuffle(train_samples)



In [9]:
train_x_list, train_y_list = list(), list()
for s in tqdm(train_samples):
    sub, sup, label = s[0], s[1], s[2]
    sub_v = classes_e[classes.index(sub)]
    sup_v = classes_e[classes.index(sup)]
    if not (np.all(sub_v == 0) or np.all(sup_v == 0)):
        if FLAGS.input_type == 'concatenate':
            train_x_list.append(np.concatenate((sub_v, sup_v)))
        else:
            train_x_list.append(sub_v - sup_v)
        train_y_list.append(int(label))
train_X, train_y = np.array(train_x_list), np.array(train_y_list)
print('train_X: %s, train_y: %s' % (str(train_X.shape), str(train_y.shape)))

  0%|          | 0/41688 [00:00<?, ?it/s]

100%|██████████| 41688/41688 [00:20<00:00, 2080.33it/s]

train_X: (26279, 200), train_y: (26279,)





In [33]:
inferred_ancestors = dict()
with open(FLAGS.inferred_ancestor_file) as f:
    for line in f.readlines():
        all_infer_classes = line.strip().split(',')
        cls = all_infer_classes[0]
        inferred_ancestors[cls] = all_infer_classes

In [34]:

class InclusionEvaluator(Evaluator):
    def __init__(self, valid_samples, test_samples, train_X, train_y):
        super(InclusionEvaluator, self).__init__(valid_samples, test_samples, train_X, train_y)

    def evaluate(self, model, eva_samples, mem_limit_GiB = 5):
        print('Evaluating...')
        sample_mem_size = 1_000_000_000 * 8 / (32 * candidate_num * train_X.shape[1])
        array_size_limit = int(mem_limit_GiB * sample_mem_size)
        if array_size_limit < 1:
            raise ValueError(f'Memory limit is too small! at least {1/sample_mem_size:.3f} GiB is required.')
        else:
            print(f'Array size limit: {array_size_limit}')
        MRR_sum, hits1_sum, hits5_sum, hits10_sum = 0, 0, 0, 0
        arange = tqdm(range(0, len(eva_samples), array_size_limit), desc='Evaluating...')
        accumulate_total = 0
        for i in arange:
            sub_eva_samples = eva_samples[i:i + array_size_limit]
            X_array = np.empty((len(sub_eva_samples), candidate_num, train_X.shape[1]), dtype=np.float32)
            for index, (individual, gt) in enumerate(sub_eva_samples):
                sub_index = classes.index(sub)
                sub_v = classes_e[sub_index]
                X_array[index] = np.concatenate((np.array([sub_v] * candidate_num), classes_e), axis=1)
                
            predicted_proba_array = X_array.reshape(-1, train_X.shape[1])
            predicted_proba_array = model.predict_proba(predicted_proba_array)[:, 1].reshape((len(sub_eva_samples), candidate_num))
            
            for P, (individual, gt) in zip(predicted_proba_array, sub_eva_samples):
                sorted_indexes = np.argsort(P)[::-1]
                sorted_classes = list()
                for j in sorted_indexes:
                    if classes[j] not in inferred_ancestors[individual]:
                        sorted_classes.append(classes[j])
                rank = sorted_classes.index(gt) + 1
                MRR_sum += 1.0 / rank
                hits1_sum += 1 if gt in sorted_classes[:1] else 0
                hits5_sum += 1 if gt in sorted_classes[:5] else 0
                hits10_sum += 1 if gt in sorted_classes[:10] else 0
                
            accumulate_total += len(sub_eva_samples)
            e_MRR, hits1, hits5, hits10 = MRR_sum / accumulate_total, hits1_sum / accumulate_total, hits5_sum / accumulate_total, hits10_sum / accumulate_total
            desc = f'({accumulate_total}) Evaluated MRR {e_MRR:.3f}, Hits@1 {hits1:.3f}, Hits@5 {hits5:.3f}, Hits@10 {hits10:.3f}'
            arange.set_description(desc)
        return e_MRR, hits1, hits5, hits10

In [35]:
print("\n		2.Train and test ... \n")
evaluator = InclusionEvaluator(valid_samples, test_samples, train_X, train_y)
evaluator.run_random_forest()


		2.Train and test ... 

Evaluating...
Array size limit: 221


Evaluated MRR 0.003, Hits@1 0.000, Hits@5 0.003, Hits@10 0.005:  11%|█         | 3/27 [10:22<1:23:03, 207.66s/it]


KeyboardInterrupt: 

# Concept2Vec

In [16]:
TRAIN_PATH = r"D:\.VScode\projects\Owls2Vec_star\Workspace\TTL\KGE_data\FoodOn\train_dataset.tsv"

In [1]:
class ConceptClass:
    def __init__(self, concept, concept_embedded_vector, entities, entities_embedded_vectors):
        self.concept = concept
        self.concept_embedded_vector = concept_embedded_vector
        self.entities = entities 
        self.entities_embedded_vectors = entities_embedded_vectors

        self.average_entities_vectors = self.average_entities_vectors()
        
    def __repr__(self):
        return f'concept: {self.concept}'
    
    def similarity(self, vec1, vec2):
        return np.dot(vec1,vec2)/(np.linalg.norm(vec1)*np.linalg.norm(vec2))
    
    def coherence_score(self, res_concepts, target_concept):
        return sum(res_concepts == target_concept) / res_concepts.shape[0]
    
    def average_entities_vectors(self):
        return np.mean(self.entities_embedded_vectors, axis = 0)
    
class OntologyEvaluation:
    def __init__(self, triples, entity_to_id, embedded_vectors, type_relations = 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type'):
        self.type_relations = triples[(triples.iloc[:,1] == type_relations)]
        self.entity_to_id = entity_to_id
        self.embedded_vector = embedded_vectors
        
        self.concepts = self.filter_concepts_from_triples()
        self.concat_df = self.create_concat_embedded_vectors()

    def create_concat_embedded_vectors(self):
        concat_df = []
        for concept in self.concepts:
            _df = pd.DataFrame({'vector': [i for i in concept.entities_embedded_vectors]})
            _df['label'] = concept.concept
            concat_df.append(_df)
        concat_df = pd.concat(concat_df, axis=0).reset_index(drop=True)
        return concat_df
        
    def filter_concepts_from_triples(self, filter_num = 1):
        concepts_count = pd.DataFrame(self.type_relations.iloc[:, 2]).value_counts(sort = True)
        evalution_concepts = concepts_count[concepts_count > filter_num]
        print(evalution_concepts)
        
        filter_concepts = evalution_concepts.index.to_list()
        for i, concept in enumerate(filter_concepts):
            concept = concept[0]
            concept_embedded_vector = self.get_embedded_vectors_from_label(concept)
            
            entities = self.get_entities_from_concept(concept)
            entities_embedded_vector = self.get_embedded_vectors_from_labels(entities)
            
            filter_concepts[i] = ConceptClass(concept, concept_embedded_vector, entities, entities_embedded_vector)
            
        return filter_concepts
    
    def get_entities_from_concept(self, concept):
        entities = self.type_relations[self.type_relations.iloc[:,2] == concept].iloc[:,0]
        return entities
    
    def get_embedded_vectors_from_label(self, label):
        #Get index from label
        index = self.entity_to_id[label]
        
        #Get embedded vector from index
        concept_embedded_vector = self.embedded_vector[index]
        return concept_embedded_vector
    
    def get_embedded_vectors_from_labels(self, labels):
        concept_embedded_vectors = []
        for label in labels:
            concept_embedded_vectors.append(self.get_embedded_vectors_from_label(label))
        return np.array(concept_embedded_vectors)
        

KeyboardInterrupt: 

In [17]:
triples = pd.read_csv(TRAIN_PATH, header=None, names=['head', 'relation', 'tail'])
triples.head()

Unnamed: 0,head,relation,tail
0,SIREN DB annotation:\n* has quality 'semisolid...,http://www.w3.org/2002/07/owl#sameAs,SIREN DB annotation:\n* has quality 'semisolid...
1,Anabantoidei,http://www.w3.org/2002/07/owl#sameAs,Anabantoidei
2,http://purl.obolibrary.org/obo/UO_0000113,http://www.w3.org/1999/02/22-rdf-syntax-ns#type,http://www.w3.org/2002/07/owl#Class
3,http://purl.obolibrary.org/obo/FOODON_03307105,http://www.w3.org/2000/01/rdf-schema#comment,"SIREN DB annotation:\n* has quality 'whole, sh..."
4,http://purl.obolibrary.org/obo/FOODON_03316729,http://www.w3.org/2000/01/rdf-schema#comment,"SIREN DB annotation:\n* surrounded by 'can, bo..."


In [21]:
classes_dict = {cls: i for i, cls in enumerate(classes)}

In [31]:
evalution_concepts = OntologyEvaluation(triples, classes_dict, classes_e)

tail                                        
http://www.w3.org/2002/07/owl#Class             28182
http://www.w3.org/2002/07/owl#ObjectProperty       56
Name: count, dtype: int64


KeyError: 'http://www.w3.org/2002/07/owl#Class'

In [None]:
for task in evalution_concepts.concepts:
    print(f'{task.concept: <35}: {task.similarity(task.concept_embedded_vector, task.average_entities_vectors):.3f}')

In [None]:
top_k = 10
concat_df = evalution_concepts.concat_df
for task in evalution_concepts.concepts:

    res_sims = concat_df.iloc[:, 0].map(lambda x: task.similarity(x, task.concept_embedded_vector)).sort_values(ascending=False).head(top_k)
    res_concepts = concat_df['label'].loc[res_sims.index]
    res_score = task.coherence_score(res_concepts, task.concept)
    print(f'{task.concept: <35}: {res_score:.3f} | {task.entities.shape[0] / concat_df.shape[0]:.3f}')

# Distance in the embedding space of training samples 

In [None]:
accumulate_distances = np.zeros((2))
count = 0
for s in tqdm(train_samples):
    ind, cls, label = s[0], s[1], s[2]
    ind_v = classes_e[classes.index(ind)]
    cls_v = classes_e[classes.index(cls)]
    if not (np.all(ind_v == 0) or np.all(cls_v == 0)):
        dist_ind = int(label)
        accumulate_distances[dist_ind] += np.linalg.norm(ind_v - cls_v)
        count += 1
average_distance = accumulate_distances / count
ratio = average_distance[1] / average_distance[0]
print('average_distance:\nfor positive: {:.3f}\nfor negative: {:.3f}\nfor ratio {:.3f}'.format(average_distance[1], average_distance[0], ratio))

# T-SNE visualization

In [None]:
% pip install seaborn matplotlib ipympl

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
%matplotlib widget

In [None]:
train_samples_df = pd.DataFrame(train_samples, columns=['individual', 'class', 'label'])
positive_samples_df = train_samples_df[train_samples_df['label'] == '1'].drop(columns=['label'])
positive_samples_df.head()

In [None]:
def _plot(X_embedded, combined_label, class_labels, color_list, label = ''):
    plt.clf()
    for label_i, (x, y), color in zip(combined_label, X_embedded, color_list):
        plt.plot(x, y, 'o', c = sns.color_palette()[color])
        if label_i in class_labels:
            plt.text(x * (1.02), y * (1.02) , label_i, fontsize=10)

    if label != '':
        plt.title(label)
    plt.show()

In [None]:
tsne_classes = []
tsne_X_label = []
combined_data = []
tsne_class_label = []
clean = lambda x: x.split('#')[-1] 


for i, cls in enumerate(tsne_classes, 1):
    samples = positive_samples_df.loc[positive_samples_df["class"] == cls, "individual"]
    
    tsne_X_label.append(cls)
    tsne_X_label.extend(samples)
    
    tsne_class_label.extend([0] + [i] * (len(samples)))
    
    cls_v = classes_e[classes.index(cls)]
    combined_data.append(cls_v)
    for sample in samples:
        ind_v = classes_e[classes.index(sample)]
        combined_data.append(ind_v)

combined_data = np.array(combined_data)
tsne_X_label = [clean(l) for l in tsne_X_label]
tsne_classes_clean = [clean(l) for l in tsne_classes]


In [None]:
X_embedded = TSNE(n_components=2, n_jobs = -1, early_exaggeration= 20).fit_transform(combined_data)

_plot(X_embedded, tsne_X_label, tsne_classes_clean, tsne_class_label)

## class-entity experiment

In [None]:
tsne_classes = positive_samples_df["class"].unique().tolist()  
tsne_classes_emgedded = [classes_e[classes.index(cls)] for cls in tsne_classes]

tsne_entity = positive_samples_df["individual"].sample(1000, replace=False).tolist()
tsne_entity_embedded = [classes_e[classes.index(ind)] for ind in tsne_entity]

combined_data = tsne_classes_emgedded + tsne_entity_embedded
combined_color = ['red'] * len(tsne_classes_emgedded) + ['grey'] * len(tsne_entity_embedded)
combined_data = np.array(combined_data)


In [None]:
X_embedded = TSNE(n_components=2, n_jobs = -1, early_exaggeration= 20).fit_transform(combined_data)

In [None]:
plt.clf()
for (x, y), color in zip(X_embedded, combined_color):
    plt.plot(x, y, 'o', c = color)
    
plt.show()