In [13]:
import numpy as np
import matplotlib.pylab as plt
import seaborn as sns

from abc import ABC, abstractmethod

In [14]:
class Neuron:
    """One unit of the map: a weight vector tied to a position on the lattice.

    Attributes:
        weight_: Current weight vector. Lattices in this notebook create
            neurons with ``None`` here and the SOM fills it in during ``fit``.
        coord_: Position of the neuron on the lattice.
    """

    def __init__(self, weight, coord) -> None:
        # Both values are stored as given: no copy and no validation.
        self.weight_ = weight
        self.coord_ = coord

In [15]:
class Lattice(ABC):
    """Abstract base for neuron layouts.

    A concrete lattice turns ``size_`` into a collection of neurons through
    :meth:`generate`. The base class cannot be instantiated directly because
    ``generate`` is abstract.
    """

    def __init__(self, size):
        # Stored untouched; each concrete lattice decides how to read it.
        self.size_ = size

    @abstractmethod
    def generate(self):
        """Create the neurons of this lattice (must be overridden).

        ``SOM.__init__`` unpacks the result as ``(neurons, neurons_nb)``.
        """

class RectangularLattice(Lattice):
    """Rectangular grid of neurons; ``size_`` is read as ``(height, width)``."""

    def generate(self):
        """Create one weightless neuron per grid cell, in row-major order.

        Returns:
            tuple: ``(neurons, neurons_nb)`` where ``neurons`` is a list of
            ``Neuron`` objects whose ``coord_`` is ``np.array([row, col])`` and
            whose ``weight_`` is ``None``, and ``neurons_nb`` is
            ``height * width``.
        """
        rows, cols = self.size_[0], self.size_[1]
        total = rows * cols

        # Row index varies slowest, matching a nested row/column traversal.
        grid = [
            Neuron(None, np.array([r, c]))
            for r in range(rows)
            for c in range(cols)
        ]

        return grid, total


In [16]:
class SOM:
    """Self-organising map whose best-matching unit is picked by Hellinger distance.

    Attributes:
        neurons_: Neuron objects returned by ``lattice.generate()``; each one
            exposes ``weight_`` and ``coord_``.
        neurons_nb_: Number of neurons, as reported by ``lattice.generate()``.
    """

    # The annotation is a string so that defining this class does not require
    # ``Lattice`` to already exist in the namespace.
    def __init__(self, lattice:"Lattice"):
        self.neurons_, self.neurons_nb_ = lattice.generate()

    def fit(self, X:np.ndarray, epochs):
        """Train the map on ``X`` in place (updates every neuron's ``weight_``).

        Weights are initialised from ``neurons_nb_`` distinct samples drawn
        with ``np.random.choice``; call ``np.random.seed`` beforehand for a
        reproducible run. Each weight is an independent float copy of its
        sample, so ``X`` itself is never modified by training.

        Args:
            X: Samples, one per entry along the first axis. Values must be
                non-negative because the Hellinger distance takes their
                square root.
            epochs: Number of full passes over ``X``.

        Raises:
            ValueError: If ``X`` holds fewer samples than there are neurons, or
                if no best-matching unit can be found for a sample (every
                distance is NaN).
        """
        X = np.asarray(X)
        if X.shape[0] < self.neurons_nb_:
            raise ValueError(
                f"need at least {self.neurons_nb_} samples to initialise "
                f"the neurons, got {X.shape[0]}"
            )

        indexes = np.random.choice(X.shape[0], self.neurons_nb_, replace=False)

        for i in range(0, self.neurons_nb_):
            # Indexing X returns a view; copy it (as float) so that the
            # in-place weight updates below cannot overwrite the training data.
            self.neurons_[i].weight_ = np.array(X[indexes[i]], dtype=float)

        # NOTE: with h0 = 1 the winning neuron gets a learning rate of exactly
        # 1 (exp(0)), i.e. it is moved all the way onto the current sample.
        h0 = 1

        for ep in range(epochs):
            # The neighbourhood width depends only on the epoch, so compute it
            # once per epoch rather than once per sample.
            sigma = 1/(ep + 1)

            for x in X:
                winner_coord = self.neurons_[self._find_bmu(x)].coord_

                # Updating all neurons
                for i in range(0, self.neurons_nb_):
                    neuron = self.neurons_[i]
                    learning_rate = self.neighbourhood_func(neuron.coord_, winner_coord, h0, sigma)
                    neuron.weight_ += learning_rate * (x - neuron.weight_)

    def _find_bmu(self, x):
        """Return the index of the neuron closest to ``x`` (first one on ties)."""
        min_d = np.inf
        winner_idx = None

        for i in range(0, self.neurons_nb_):
            d = self.hellinger_d(x, self.neurons_[i].weight_)
            if d < min_d:
                min_d = d
                winner_idx = i

        if winner_idx is None:
            # Only reachable when every comparison failed, e.g. NaN distances
            # produced by negative entries under the square root.
            raise ValueError(
                "could not find a best-matching unit: all Hellinger distances "
                "are NaN or infinite (inputs must be non-negative)"
            )

        return winner_idx

    @staticmethod
    def neighbourhood_func(r_i, r_c, h0, sigma):
        """Gaussian neighbourhood strength between lattice positions.

        Args:
            r_i: Coordinates of the neuron being updated.
            r_c: Coordinates of the winning neuron.
            h0: Peak value, returned when both positions coincide.
            sigma: Width of the Gaussian; must be non-zero.

        Returns:
            ``0`` when the squared Euclidean distance between the positions
            exceeds 1, otherwise ``h0 * exp(-distance / (2 * sigma**2))``.
        """
        # ``distance`` is the *squared* Euclidean distance.
        distance = np.sum(np.square(r_i - r_c))
        if distance > 1:
            return 0

        return h0 * np.exp(-distance/(2*sigma**2))

    @staticmethod
    def hellinger_d(v1:np.ndarray, v2:np.ndarray):
        """Hellinger distance between two non-negative vectors.

        Computed as ``||sqrt(v1) - sqrt(v2)||_2 / sqrt(2)``. Negative entries
        yield NaN through the square root.
        """
        return np.sqrt(np.sum((np.sqrt(v1) - np.sqrt(v2))**2)) / np.sqrt(2)



In [17]:
def generate_dataset(n_samples, p, repeat_nb=10000, rng=None):
    """Sample pairs of empirical frequency vectors from the distributions in ``p``.

    For every sample, two *distinct* rows of ``p`` are chosen at random. From
    each chosen row ``repeat_nb`` multinomial draws are taken and converted to
    relative frequencies; the two frequency vectors are concatenated into one
    row of the result.

    Args:
        n_samples: Number of rows to generate (``0`` yields an empty array).
        p: 2-D array-like, one categorical distribution per row. At least two
            rows are needed because the pair is drawn without replacement.
        repeat_nb: Number of multinomial trials behind each frequency vector.
        rng: Optional random source exposing ``choice`` and ``multinomial``
            (e.g. ``np.random.default_rng(seed)``). Defaults to the global
            ``np.random`` state, so ``np.random.seed`` keeps working.

    Returns:
        Float array of shape ``(n_samples, 2 * p.shape[1])``.
    """
    p = np.asarray(p, dtype=float)
    if rng is None:
        rng = np.random

    distributions_nb, categories_nb = p.shape
    data = np.empty((n_samples, 2, categories_nb))

    for i in range(n_samples):
        indexes = rng.choice(distributions_nb, size=2, replace=False)

        for slot, k in enumerate(indexes):
            data[i, slot] = rng.multinomial(repeat_nb, p[k]) / repeat_nb

    # Explicit width instead of -1: NumPy cannot infer a dimension when the
    # array is empty (n_samples == 0).
    return data.reshape(n_samples, 2 * categories_nb)

In [18]:
# Four source distributions over three outcomes, one per row; each row sums to 1.
p = np.array([[1/3, 1/3, 1/3],
              [1/10, 1/10, 8/10],
              [1/4, 1/4, 1/2],
              [2/5, 1/5, 2/5]])

In [19]:
# 2000 samples; each row concatenates the empirical frequencies drawn from two distinct rows of p.
X = generate_dataset(2000, p)

In [20]:
# Show the generated samples (6 columns: two 3-outcome frequency vectors per row).
X

array([[0.3987, 0.1943, 0.407 , 0.1025, 0.0972, 0.8003],
       [0.1031, 0.1022, 0.7947, 0.3352, 0.3316, 0.3332],
       [0.0979, 0.1083, 0.7938, 0.2639, 0.2508, 0.4853],
       ...,
       [0.2459, 0.2468, 0.5073, 0.3949, 0.1987, 0.4064],
       [0.4044, 0.2031, 0.3925, 0.3415, 0.3266, 0.3319],
       [0.2454, 0.2579, 0.4967, 0.3995, 0.2019, 0.3986]])

In [24]:
# 3 x 4 rectangular grid, i.e. 12 neurons.
lattice = RectangularLattice((3, 4))
som = SOM(lattice)

In [25]:
# Train for 100 epochs. fit() draws the initial weights with np.random, so
# results differ between runs unless a seed is set beforehand.
som.fit(X, 100)

In [26]:
# For every trained neuron, find the ordered pair (i, j), i != j, of source
# distributions whose concatenation is closest (Hellinger) to its weights.
for neuron in som.neurons_:
    min_distance = np.inf
    # Reset for each neuron so that a neuron with no comparable candidate
    # cannot silently report the previous neuron's match.
    best_distribution = None
    indexes = None

    for i in range(len(p)):
        for j in range(len(p)):
            if i == j:
                continue

            p_distribution = np.concatenate((p[i], p[j]))
            distance = SOM.hellinger_d(p_distribution, neuron.weight_)
            if distance < min_distance:
                min_distance = distance
                best_distribution = p_distribution
                indexes = (i, j)

    print(neuron.weight_)
    print(best_distribution)
    print(indexes)
    print()

[0.4044 0.2031 0.3925 0.3415 0.3266 0.3319]
[0.4        0.2        0.4        0.33333333 0.33333333 0.33333333]
(3, 0)

[0.2458 0.2477 0.5065 0.3312 0.3346 0.3342]
[0.25       0.25       0.5        0.33333333 0.33333333 0.33333333]
(2, 0)

[0.0971 0.0982 0.8047 0.3312 0.3375 0.3313]
[0.1        0.1        0.8        0.33333333 0.33333333 0.33333333]
(1, 0)

[0.1003 0.1059 0.7938 0.2545 0.2487 0.4968]
[0.1  0.1  0.8  0.25 0.25 0.5 ]
(1, 2)

[0.3999 0.1994 0.4007 0.2546 0.2549 0.4905]
[0.4  0.2  0.4  0.25 0.25 0.5 ]
(3, 2)

[0.2454 0.2579 0.4967 0.3995 0.2019 0.3986]
[0.25 0.25 0.5  0.4  0.2  0.4 ]
(2, 3)

[0.22439999 0.20344346 0.57215656 0.21076778 0.21035837 0.57887385]
[0.4  0.2  0.4  0.25 0.25 0.5 ]
(3, 2)

[0.2487 0.2465 0.5048 0.0981 0.104  0.7979]
[0.25 0.25 0.5  0.1  0.1  0.8 ]
(2, 1)

[0.3298 0.3307 0.3395 0.2537 0.2469 0.4994]
[0.33333333 0.33333333 0.33333333 0.25       0.25       0.5       ]
(0, 2)

[0.3281 0.3398 0.3321 0.4055 0.1982 0.3963]
[0.33333333 0.33333333 0.3333333