In [1]:
import numpy as np
import matplotlib.pylab as plt
import seaborn as sns

from itertools import chain

In [43]:
class Neuron:
    def __init__(self, weight, coord):
        self.weight_ = weight
        self.coord_ = coord

class SOM:
    def __init__(self):
        self.neurons_ = None
        self.neurons_nb_ = 9
        self.generate_lattice()
    

    def generate_lattice(self):
        n = self.neurons_nb_//3
        self.neurons_ = []

        for i in range(0, n):
            for j in range(0, n):
                self.neurons_.append(Neuron(None, np.array([i, j])))
                       

    def fit(self, X:np.ndarray, epochs, learning_rate=0.9):
        indexes = np.random.choice(X.shape[0], self.neurons_nb_, replace=False)

        for i in range(0, self.neurons_nb_):
            self.neurons_[i].weight_ = X[indexes[i]]

        for neuron in self.neurons_:
            print(neuron.weight_)

        for ep in range(epochs):
            for x in X:
                # Find BMU
                min_d = np.inf
                winner_idx = None

                for i in range(0, self.neurons_nb_):
                    d = self.hellinger_d(x, self.neurons_[i].weight_)
                    if d < min_d:
                        min_d = d
                        winner_idx = i


                h0 = 1
                sigma = ep + 1

                # Updating all neurons
                for i in range(0, self.neurons_nb_):
                    learning_rate = self.neighbourhood_func(self.neurons_[i].coord_, self.neurons_[winner_idx].coord_, h0, sigma)
                    self.neurons_[i].weight_ +=  learning_rate * (x - self.neurons_[i].weight_)

                

    @staticmethod
    def neighbourhood_func(r_i, r_c, h0, sigma):
        distance = np.sum(np.square(r_i - r_c))
        if distance > 1:
            return 0

        return h0 * np.exp(-distance/sigma**2)

    @staticmethod
    def hellinger_d(v1:np.ndarray, v2:np.ndarray):
        # Calculating Hellinger distance
        return np.sqrt(np.sum((np.sqrt(v1) - np.sqrt(v2))**2)) / np.sqrt(2)



In [44]:
def generate_dataset(n_samples, p):
    # p - vector of probability distributions

    distributions_nb = p.shape[0]
    repeat_nb = 10000

    data = []
    indexes_res = []

    for i in range(0, n_samples):
        indexes = np.random.choice(distributions_nb, replace=False, size=2).astype(np.uint)

        sample = []
        for k in indexes:
            sample.append(np.random.multinomial(repeat_nb, p[k])/repeat_nb)
        
        indexes_res.append(indexes)
        data.append(sample)
    
    return np.array(data).reshape(n_samples, -1), indexes_res

In [45]:
p = np.array([[1/3, 1/3, 1/3],
              [1/10, 1/10, 8/10],
              [1/4, 1/4, 1/2],
              [2/5, 1/5, 2/5],
              [2/5, 1/5, 2/5]])

In [46]:
X, indexes = generate_dataset(1000, p)

In [52]:
X

array([[0.1003, 0.0974, 0.8023, 0.337 , 0.3261, 0.3369],
       [0.0992, 0.0971, 0.8037, 0.3356, 0.3317, 0.3327],
       [0.3908, 0.2039, 0.4053, 0.3901, 0.2006, 0.4093],
       ...,
       [0.3917, 0.2059, 0.4024, 0.2501, 0.2438, 0.5061],
       [0.2465, 0.2476, 0.5059, 0.3279, 0.3401, 0.332 ],
       [0.3994, 0.1952, 0.4054, 0.2425, 0.2506, 0.5069]])

In [53]:
som = SOM()

In [54]:
som.fit(X, 300)

[0.3267 0.3298 0.3435 0.1037 0.1008 0.7955]
[0.2476 0.2561 0.4963 0.3986 0.2047 0.3967]
[0.3886 0.2073 0.4041 0.3905 0.2046 0.4049]
[0.3988 0.1954 0.4058 0.4041 0.2039 0.392 ]
[0.335  0.3279 0.3371 0.4    0.197  0.403 ]
[0.4022 0.1961 0.4017 0.1031 0.1015 0.7954]
[0.3369 0.3275 0.3356 0.4014 0.2072 0.3914]
[0.4009 0.2032 0.3959 0.2479 0.2474 0.5047]
[0.3374 0.3309 0.3317 0.3982 0.2006 0.4012]


In [55]:
for neuron in som.neurons_:
    print(neuron.weight_)
    print()

[0.24650089 0.24760104 0.50589807 0.32790084 0.34009841 0.33200075]

[0.3264 0.3411 0.3325 0.4032 0.1971 0.3997]

[0.32639915 0.34109896 0.33250189 0.40319996 0.19710003 0.39970001]

[0.2465 0.2476 0.5059 0.3279 0.3401 0.332 ]

[0.24650089 0.24760104 0.50589807 0.32790084 0.34009841 0.33200075]

[0.39939991 0.19520012 0.40539997 0.24250008 0.25059992 0.50689999]

[0.24650166 0.24759949 0.50589885 0.32789751 0.34009736 0.33200513]

[0.39939991 0.19520012 0.40539997 0.24250008 0.25059992 0.50689999]

[0.3994 0.1952 0.4054 0.2425 0.2506 0.5069]

