In [8]:
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader
#from torchvision.transforms import ToTensor
from itertools import combinations
from queue import PriorityQueue
import os

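## MLP modules

A generic two-layer `MLP` plus the wrappers built on it: `MLP_singleton` (embeds single points), `MLP_init` (embeds pairs) and `MLP_update` (embeds merged clusters).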

In [9]:
class MLP(nn.Module):
    """Two-layer perceptron: Flatten -> Linear -> ReLU -> Linear.

    Args:
        input_dim: number of features per sample after flattening.
        hidden_dim: width of the single hidden layer.
        output_dim: number of output features.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # nn.Flatten() defaults to start_dim=1: dim 0 is kept as the batch axis.
        self.flatten = nn.Flatten()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        # Attribute name is part of the state_dict keys; keep it stable.
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten each sample of ``x`` and map it to ``output_dim`` features."""
        return self.linear_relu_stack(self.flatten(x))

In [10]:
class MLP_singleton(nn.Module):
    """Embed a single data point (a leaf of the hierarchy) with an ``MLP``.

    Args:
        input_dim: features per point.
        hidden_dim: hidden width of the underlying MLP.
        output_dim: width of the produced embedding.
    """

    def __init__(self, input_dim = 2, hidden_dim = 3, output_dim=8):
        super().__init__()
        # "node" is the state_dict prefix for the wrapped MLP.
        self.node = MLP(input_dim, hidden_dim, output_dim)

    def forward(self, x):
        """Map point ``x`` to its embedding."""
        return self.node(x)

In [11]:
class MLP_init(nn.Module):
    """Embed a pair of inputs by summing them and applying an ``MLP``.

    Args:
        input_dim: width of each input (and of their sum).
        hidden_dim: hidden width of the underlying MLP.
        output_dim: width of the produced pair embedding.
    """

    def __init__(self, input_dim = 2, hidden_dim = 3, output_dim = 8):
        super().__init__()
        # "init" is the state_dict prefix for the wrapped MLP.
        self.init = MLP(input_dim, hidden_dim, output_dim)

    def forward(self, x, y):
        """Return the embedding of the pair (``x``, ``y``)."""
        combined = x + y  # summing makes the result independent of argument order
        return self.init(combined)
        

In [12]:
class MLP_update(nn.Module):
    """Embed a candidate merge from three related cluster embeddings.

    Args:
        input_dim: width of the concatenated input, i.e. the sum of the last
            dimensions of ``x``, ``y`` and ``z``.
        hidden_dim: hidden width of the underlying MLP.
        output_dim: width of the produced embedding.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super(MLP_update, self).__init__()
        self.update = MLP(input_dim, hidden_dim, output_dim)

    def forward(self, x, y, z):
        """Concatenate ``x``, ``y``, ``z`` on the last axis and run the MLP.

        ``np.concatenate(x+y, z)`` treated ``z`` as the axis argument and left
        torch/autograd; ``torch.cat`` keeps tensors and the gradient graph.
        NOTE(review): assumes a plain 3-way concatenation (per the "list of 3
        clusters" comment) rather than cat(x + y, z) - confirm intent.
        """
        final = torch.cat((x, y, z), dim=-1)
        return self.update(final)

In [13]:
class pairwise_ranker(nn.Module):
    """Score whether cluster ``x`` should be ranked ahead of cluster ``y``.

    score = tanh(sum(w * (x.embedding - y.embedding), last axis)), so the
    score lies in (-1, 1) and ranker(x, y) == -ranker(y, x).

    Args:
        input_dim: width of the cluster embeddings (size of ``w``).
    """

    def __init__(self, input_dim):
        super(pairwise_ranker, self).__init__()
        self.w = nn.Parameter(torch.rand(input_dim))

    def forward(self, x, y):
        """Return the ranking score for (``x``, ``y``); positive favours ``x``.

        ``x`` and ``y`` are objects exposing an ``.embedding`` tensor whose last
        dimension equals ``input_dim``.
        """
        diff = x.embedding - y.embedding
        # The original multiplied by the cluster object x (TypeError) and never
        # used the learned weight vector w.
        return torch.tanh((self.w * diff).sum(dim=-1))

In [14]:
# Embed user data
class cluster:
    """Node of the cluster hierarchy: label, member data, embedding, children.

    Instances are ordered with ``<`` via a shared ``pairwise_ranker`` so they
    can be stored in a ``queue.PriorityQueue``.
    """

    # One ranker shared by all comparisons so the ordering is consistent;
    # built lazily from the first compared embedding's width unless set.
    ranker = None

    def __init__(self, label, data, embedding):
        self.label = label
        self.data = data
        self.embedding = embedding
        self.children = []

    def add_child(self, child_cluster):
        """Append ``child_cluster``; raise TypeError if it is not a cluster."""
        if isinstance(child_cluster, cluster):
            self.children.append(child_cluster)
        else:
            raise TypeError("Child is not a cluster.")

    def __lt__(self, other):
        """Return True when the ranker scores ``self`` ahead of ``other``.

        PriorityQueue pops the smallest item first, so a positive score makes
        ``self`` come out before ``other``.
        """
        # The original called pairwise_ranker(self, other), i.e. constructed a
        # new module with invalid arguments instead of scoring the pair.
        if cluster.ranker is None:
            cluster.ranker = pairwise_ranker(self.embedding.shape[-1])
        rank = cluster.ranker(self, other)
        # heapq needs a plain truth value, not a (possibly multi-element) tensor.
        return bool(torch.as_tensor(rank).sum() > 0)

In [36]:
def main(data, hidden_dim=3, embed_dim=8):
    """Greedily build a cluster hierarchy over ``data`` with the MLP modules.

    Pipeline:
      1. MLP_singleton embeds every point (one leaf cluster per point).
      2. MLP_init embeds every pair of points as a candidate merge.
      3. Candidates are popped from a PriorityQueue (ordered by
         ``cluster.__lt__``). Each accepted merge AB adds candidates (AB, C)
         for every other active cluster C, embedded by MLP_update from the
         (A, C), (B, C) and (A, B) embeddings.

    Args:
        data: iterable of points; each point is a number or a flat sequence of
            numbers, all of the same length.
        hidden_dim: hidden width used by every MLP.
        embed_dim: embedding width produced by every MLP.

    Returns:
        1 (kept from the original notebook).
    """
    # Each point becomes a (1, d) float tensor: MLP's nn.Flatten treats dim 0
    # as the batch axis.
    points = [torch.as_tensor(point, dtype=torch.float32).reshape(1, -1) for point in data]
    input_dim = points[0].shape[-1] if points else 1

    # Build every model once and call it; the original passed data to the
    # constructors, creating fresh modules instead of computing embeddings.
    singleton_model = MLP_singleton(input_dim, hidden_dim, embed_dim)
    init_model = MLP_init(input_dim, hidden_dim, embed_dim)
    # MLP_update concatenates its three embedding inputs.
    update_model = MLP_update(3 * embed_dim, hidden_dim, embed_dim)

    # Keys are frozensets of leaf labels: the union of two keys names their
    # merge, and overlap tests are exact (comparing characters of random hex
    # labels produced false overlaps).
    embeddings = {}
    merges = PriorityQueue()

    # MLP_singleton
    leaves = []
    for index, point in enumerate(points):
        label = frozenset([str(index)])
        leaf = cluster(label, [point], singleton_model(point))
        embeddings[label] = leaf
        leaves.append(leaf)
    print("Finished running MLP_singleton")

    # MLP_init: one candidate merge per pair of points.
    # NOTE(review): like the original, MLP_init receives the raw points rather
    # than the singleton embeddings - confirm this is intended.
    for left, right in combinations(leaves, 2):
        pair = cluster(left.label | right.label, left.data + right.data,
                       init_model(left.data[0], right.data[0]))
        pair.add_child(left)
        pair.add_child(right)
        embeddings[pair.label] = pair
        merges.put(pair)
    print("Finished running MLP_init")

    # MLP_update
    active = {leaf.label for leaf in leaves}  # roots of the current forest
    while not merges.empty():
        merge = merges.get()
        left, right = merge.children[0], merge.children[1]
        # Skip stale candidates: one side was already merged into another cluster.
        if left.label not in active or right.label not in active:
            continue
        active -= {left.label, right.label}
        for other_label in active:
            other = embeddings[other_label]
            # A candidate exists for every pair of active clusters, so these
            # (A, C) and (B, C) lookups always succeed.
            new_embedding = update_model(
                embeddings[left.label | other_label].embedding,
                embeddings[right.label | other_label].embedding,
                merge.embedding,
            )
            candidate = cluster(merge.label | other_label, merge.data + other.data, new_embedding)
            candidate.add_child(merge)
            candidate.add_child(other)
            embeddings[candidate.label] = candidate
            merges.put(candidate)
        active.add(merge.label)
    print("Finished running MLP_update")
    print("Size of priority queue:", merges.qsize())
    return 1

if __name__ == "__main__":
    # Smoke run on a tiny two-point dataset.
    sample_data = [1, 2]
    main(sample_data)

Finished running MLP_singleton
Finished running MLP_init
Finished running MLP_update
Size of priority queue: 0


In [16]:
# NOTES / TODOs for this notebook.
# NOTE(review): the saved output of this cell lists different items than the
# string below, so it is stale - re-run the cell to refresh it.

'''
- implement training pipeline
'''

'\n- test basic functionality on synthetic datasets\n- implement tree object\n'