In [42]:
import numpy as np

In [43]:
data = np.loadtxt('data/animals.dat',delimiter=",",dtype=int)
names = np.loadtxt('data/animalnames.txt',dtype=str)

# Get data of correct shape
data = np.reshape(data, (-1, 84))



# Data is binary data with 84 attributes for 32 different animals 
print(data)
print(names)

[[1 0 0 ..., 0 0 0]
 [0 0 0 ..., 0 0 0]
 [0 0 0 ..., 0 0 0]
 ..., 
 [0 0 0 ..., 0 0 0]
 [0 1 0 ..., 1 0 0]
 [0 0 0 ..., 0 0 0]]
["'antelop'" "'ape'" "'bat'" "'bear'" "'beetle'" "'butterfly'" "'camel'"
 "'cat'" "'crocodile'" "'dog'" "'dragonfly'" "'duck'" "'elephant'" "'frog'"
 "'giraffe'" "'grasshopper'" "'horse'" "'housefly'" "'hyena'" "'kangaroo'"
 "'lion'" "'moskito'" "'ostrich'" "'pelican'" "'penguin'" "'pig'"
 "'rabbit'" "'rat'" "'seaturtle'" "'skunk'" "'spider'" "'walrus'"]


In [70]:
class SOM:
    def __init__(self): 
        # Weights 100 nodes, 84 inputs
        self.weights = np.random.rand(100,84)
        self.nsize = 50
        
    def similarity(self, data, weight):
        distance = data - weight
        
        # similarity is the length of the distance vector!
        return np.dot(distance.T, distance)

    def get_winner(self, data):
        # Returns the index of the winning node(weight), the most similar node to the input data
        similarities = np.array([])
        for i, weight in enumerate(self.weights):
            similarities = np.append(similarities, self.similarity(data, weight))
        
        return np.argmin(similarities) # returns the index of the minimum weight

    def train(self, epochs=20):
        #For each sample in input
        for epoch in range(epochs):
            for i, sample in enumerate(data): 
                # For each animal you will have to pick out the corresponding row from the props matrix.
                
                # Then find the row of the weight matrix with the shortest distance to this attribute vector
                # aka the winner
                winner = self.get_winner(sample)
                
                # Once you have the index to the winning node, it is time to update the weights.
                neighbours = self.get_neighbours(winner) #Get list of neighbours with winnerNode in center
      
                self.update_weights(neighbours, sample)

            # Update neighbourhood size
            if self.nsize > 5:
                self.nsize -= 2
            elif self.nsize > 2:
                self.nsize -= 1
                
    def get_neighbours(self, winner):
        # returns the neighbourhood around the winner
        neighbours = []
        i = 1
        
        neighbours.append(winner) # assuming we want to update the neighbour as well
        
        while (len(neighbours) < self.nsize):
            if (winner - i > 0):
                neighbours.append(winner - i)
            if (winner + i < 99):
                neighbours.append(winner + i)
            i += 1
        
        return np.array(neighbours)


    def update_weights(self, neighbours, sample, lr = 0.2):
        # Update weights of all neighbours to become slightly more similar to the input pattern (sample)
        for i in neighbours:
            self.weights[i] = self.weights[i] + lr * (sample - self.weights[i])

    def predict(self):
        output = []
        for i, sample in enumerate(data):
                winner = self.get_winner(sample) #Find best node
                output.append([winner, animal_names[i]])

        output = np.array(output, dtype=object)
        
        # sort by index
        sort_idx = output[:,0].argsort()
        print(output[sort_idx])


In [71]:
model = SOM()
model.train()       
model.predict()

[[1 "'dragonfly'"]
 [2 "'grasshopper'"]
 [9 "'beetle'"]
 [10 "'butterfly'"]
 [12 "'moskito'"]
 [12 "'housefly'"]
 [19 "'spider'"]
 [29 "'penguin'"]
 [29 "'pelican'"]
 [32 "'duck'"]
 [33 "'ostrich'"]
 [40 "'seaturtle'"]
 [42 "'crocodile'"]
 [42 "'frog'"]
 [49 "'walrus'"]
 [53 "'bear'"]
 [56 "'dog'"]
 [56 "'hyena'"]
 [60 "'skunk'"]
 [63 "'lion'"]
 [63 "'cat'"]
 [64 "'ape'"]
 [71 "'bat'"]
 [72 "'rat'"]
 [74 "'elephant'"]
 [81 "'rabbit'"]
 [82 "'kangaroo'"]
 [84 "'antelop'"]
 [91 "'horse'"]
 [93 "'giraffe'"]
 [98 "'pig'"]
 [98 "'camel'"]]
