In [42]:
import numpy as np

In [43]:
data = np.loadtxt('data/animals.dat',delimiter=",",dtype=int)
names = np.loadtxt('data/animalnames.txt',dtype=str)

# Get data of correct shape
data = np.reshape(data, (-1, 84))



# Data is binary data with 84 attributes for 32 different animals 
print(data)
print(names)

[[1 0 0 ..., 0 0 0]
 [0 0 0 ..., 0 0 0]
 [0 0 0 ..., 0 0 0]
 ..., 
 [0 0 0 ..., 0 0 0]
 [0 1 0 ..., 1 0 0]
 [0 0 0 ..., 0 0 0]]
["'antelop'" "'ape'" "'bat'" "'bear'" "'beetle'" "'butterfly'" "'camel'"
 "'cat'" "'crocodile'" "'dog'" "'dragonfly'" "'duck'" "'elephant'" "'frog'"
 "'giraffe'" "'grasshopper'" "'horse'" "'housefly'" "'hyena'" "'kangaroo'"
 "'lion'" "'moskito'" "'ostrich'" "'pelican'" "'penguin'" "'pig'"
 "'rabbit'" "'rat'" "'seaturtle'" "'skunk'" "'spider'" "'walrus'"]


In [51]:
class SOM:
    def __init__(self): 
        # Weights 100 nodes, 84 inputs
        self.weights = np.random.rand(100,84)
        self.nsize = 50
        
    def similarity(self, data, weight):
        distance = data - weight
        
        # similarity is the length of the distance vector!
        return np.dot(distance.T, distance)

    def get_winner(self, data):
        # Returns the index of the winning node(weight), the most similar node to the input data
        similarities = np.array([])
        for i, weight in enumerate(self.weights):
            similarities = np.append(similarities, self.similarity(data, weight))
        
        return np.argmin(similarities) # returns the index of the minimum weight

    
    def getNeighbours(self, winner, ind):
        # returns the neighbourhood around the winner
        left = []
        right = []
        i = 1
        
        while (len(left) + len(right) < self.nsize)
            if (winner - i > 0):
                left.append(winner - i)
            if (winner - i < 99):
                right.append(winner + i)
            i += 1

        left = np.array(left)
        right = np.array(right)
        
        return left, right


    def updateWeights(self, idx, ind, lr = 0.2):
        # Update weight i
        for i in np.nditer(idx):
            self.weights[i] = self.weights[i] + lr * (np.subtract(ind,self.weights[i]))

    def train(self, epochs=20, size = 20):
        #For each sample in input
        for epoch in range(epochs):
            for i in range(data.shape[0]): 
                # For each animal you will have to pick out the corresponding row from the props matrix.
                
                # Then find the row of the weight matrix with the shortest distance to this attribute vector (p)
                # aka the winner
                winner = self.get_winner(data[i])
                
                # Once you have the index to the winning node, it is time to update the weights.
                n_left, n_right = getNeighbours(winner, data[i]) #Get list of neighbours with winnerNode in center

                
                updateWeights(n_left, ind)
                updateWeights(n_right, ind)

            #Update neighbourhood size
            if self.nsize > 5:
                self.nsize -= 2
            elif self.nsize > 2:
                self.nsize -= 1

    #Creates a SOM based on training
    def predict(self, ):
        pos=[]
        #Loop through animals
        for i in range(data.shape[0]):
                winnerNode = self.get_winner(data[i]) #Find best node
                pos.append([winnerNode, animal_names[i]])

        pos = np.array(pos,dtype=object)
        print(pos[pos[:,0].argsort()])


In [52]:
model = SOM()
model.train()       
model.predict()

[[0 "'duck'"]
 [26 "'hyena'"]
 [29 "'moskito'"]
 [33 "'lion'"]
 [33 "'giraffe'"]
 [33 "'elephant'"]
 [33 "'kangaroo'"]
 [33 "'crocodile'"]
 [33 "'camel'"]
 [33 "'pig'"]
 [33 "'rat'"]
 [33 "'skunk'"]
 [33 "'bat'"]
 [33 "'ape'"]
 [33 "'cat'"]
 [33 "'horse'"]
 [35 "'housefly'"]
 [35 "'walrus'"]
 [35 "'penguin'"]
 [35 "'ostrich'"]
 [35 "'pelican'"]
 [58 "'antelop'"]
 [62 "'frog'"]
 [62 "'seaturtle'"]
 [80 "'dragonfly'"]
 [80 "'butterfly'"]
 [80 "'beetle'"]
 [80 "'bear'"]
 [80 "'spider'"]
 [80 "'grasshopper'"]
 [83 "'rabbit'"]
 [88 "'dog'"]]
