In [None]:
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('TkAgg')
import numpy as np
import torch

# Ground-truth parameters of the synthetic two-class data; the script compares
# the fitted estimates against these ("*_origin" in the final report).
m1, m2 = 3.0, -3.0  # class A / class B mean (used for both coordinates)
v1, v2 = 2.0, 1.5   # class A / class B variance; samples use std = sqrt(v)
p0 = 0.5            # fraction of samples assigned to class A

# 生成数据
def sample_generater(sample_nums, render=True, *, mean_a=None, var_a=None,
                     mean_b=None, var_b=None, prob_a=None):
    """Draw 2-D samples from a two-class isotropic Gaussian mixture.

    Each class uses the same mean and variance on both coordinates.
    ``int(sample_nums * prob_a)`` samples go to class A; the rest go to class B.

    Args:
        sample_nums: total number of samples to draw.
        render: if True, print a short summary and show a scatter plot
            (class A as green "x", class B as red "o").
        mean_a, var_a: mean and variance of class A (default: ``m1``, ``v1``).
        mean_b, var_b: mean and variance of class B (default: ``m2``, ``v2``).
        prob_a: fraction of samples in class A (default: ``p0``).

    Returns:
        ndarray of shape (sample_nums, 2). Class-A rows come first, followed
        by class-B rows.
    """
    # Defaults are resolved at call time so the module constants stay the
    # single source of truth.
    mean_a = m1 if mean_a is None else mean_a
    var_a = v1 if var_a is None else var_a
    mean_b = m2 if mean_b is None else mean_b
    var_b = v2 if var_b is None else var_b
    prob_a = p0 if prob_a is None else prob_a

    num_a = int(sample_nums * prob_a)
    num_b = int(sample_nums - num_a)
    # randn has unit variance, so scale by the standard deviation.
    std_a, std_b = np.sqrt(var_a), np.sqrt(var_b)

    sample_a = np.random.randn(2, num_a) * std_a + mean_a
    sample_b = np.random.randn(2, num_b) * std_b + mean_b
    sample = np.hstack((sample_a, sample_b)).T

    if render:
        print(f"sample.shape = {sample.shape}")
        print(f"sample[:10] = \n{sample[:10]}")
        plt.figure()
        plt.scatter(x=sample_a[0], y=sample_a[1], s=3, marker="x", color="green")
        plt.scatter(x=sample_b[0], y=sample_b[1], s=3, marker="o", color="red")
        # The first positional argument of grid() is `visible`; the dash
        # style must be passed as `linestyle`.
        plt.grid(linestyle="--")
        plt.show()

    return sample


def Gaussian_distribution(value, mean, var):
    """Evaluate the univariate normal density N(value; mean, var) element-wise.

    Args:
        value: tensor (or scalar) of points at which to evaluate the density.
        mean: mean of the distribution.
        var: variance (not standard deviation) of the distribution; a tensor
            or a plain number.

    Returns:
        Tensor of densities, broadcast over ``value``, ``mean`` and ``var``.
    """
    # torch.sqrt/torch.exp reject plain Python floats. as_tensor returns an
    # existing tensor unchanged (keeping its autograd history) and wraps
    # scalars otherwise.
    var = torch.as_tensor(var)
    return torch.exp(-(value - mean) ** 2 / (2 * var)) / torch.sqrt(2. * np.pi * var)


def get_prob(sample, mean1, var1, prob1, mean2, var2, prob2, axis="x"):
    """Return the posterior responsibilities of both mixture components.

    Args:
        sample: iterable of (x, y) pairs, e.g. an (N, 2) array.
        mean1, var1, prob1: mean, variance and mixing weight of component 1.
        mean2, var2, prob2: mean, variance and mixing weight of component 2.
        axis: "x" evaluates the first coordinate; any other value evaluates
            the second.

    Returns:
        Tuple (z1, z2) of tensors, one value per row of ``sample``, where
        z_k = prob_k * N(v; mean_k, var_k) / sum_j prob_j * N(v; mean_j, var_j).
    """
    x, y = zip(*sample)
    # Only the requested coordinate is needed, so evaluate just that one.
    values = torch.tensor(x if axis == "x" else y, dtype=torch.float32)

    p1 = Gaussian_distribution(values, mean1, var1) * prob1
    p2 = Gaussian_distribution(values, mean2, var2) * prob2
    total = p1 + p2
    return (p1 / total, p2 / total)


def E_step(samples, *parameters_):
    """E-step: compute the posterior responsibilities of both components.

    ``parameters_`` is forwarded positionally to ``get_prob``
    (mean1, var1, prob1, mean2, var2, prob2[, axis]).

    Returns:
        Tuple (Q1, Q2): responsibilities of component 1 and component 2.
    """
    resp_first, resp_second = get_prob(samples, *parameters_)
    return resp_first, resp_second


def M_step(samples, Q1, Q2, *parameters, iter_num_=1000, epoch_num_=1):
    """M-step: fit the mixture parameters to fixed responsibilities with Adam.

    Maximizes ``criterian`` (by minimizing its negative) on the first
    coordinate of ``samples``. The parameter tensors are updated in place.

    Args:
        samples: iterable of (x, y) pairs; only x is used.
        Q1, Q2: responsibilities of component 1 / 2 from the E-step.
        *parameters: mean1, var1, prob1, mean2, var2, prob2. The first five
            are optimized and must be leaf tensors with requires_grad=True.
            prob2 is not optimized, because the objective uses 1 - prob1.
        iter_num_: number of Adam steps.
        epoch_num_: epoch index; printed as ``epoch_num_ + 1`` in the log.

    Returns:
        New list [mean1, var1, prob1, mean2, var2, prob2] with
        prob2 = 1 - prob1.
    """
    x, _ = zip(*samples)
    sample_x = torch.tensor(x, dtype=torch.float32)

    mean1, var1, prob1, mean2, var2, _prob2 = parameters
    optimizer = torch.optim.Adam([mean1, var1, prob1, mean2, var2])
    print(optimizer.param_groups[0]['params'])

    for step in range(iter_num_):
        loss = -criterian(sample_x, Q1, Q2, *parameters)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (step + 1) % 100 == 0:
            print(f"epoch_num= {epoch_num_+1}, iter_num = {step+1}, loss = {loss}")

    # Copy first: param_groups[0]['params'] is the optimizer's own list, and
    # appending to it would register the derived 1 - prob1 as a parameter.
    parameters_final = list(optimizer.param_groups[0]['params'])
    parameters_final.append(1 - parameters_final[2])
    return parameters_final


def _weighted_log_gaussian(value, mean, var, weight):
    """Return log(weight * N(value; mean, var)), computed directly in log space."""
    var = torch.as_tensor(var)
    weight = torch.as_tensor(weight)
    return (torch.log(weight)
            - (value - mean) ** 2 / (2 * var)
            - 0.5 * torch.log(2. * np.pi * var))


def criterian(sample_input, Q1, Q2, *parameters):
    """Return the EM objective sum_k sum_i Q_k[i] * log(w_k * N(x_i; mean_k, var_k) / Q_k[i]).

    The mixing weights are w_1 = prob1 and w_2 = 1 - prob1. The ``prob2``
    entry of ``parameters`` is unpacked but not used.

    Args:
        sample_input: tensor of observations.
        Q1, Q2: responsibilities of component 1 / 2, same shape as
            ``sample_input``.
        *parameters: mean1, var1, prob1, mean2, var2, prob2.

    Returns:
        Objective summed over dim 0 (a 0-d tensor for 1-D input). The M-step
        maximizes it, i.e. minimizes its negative.
    """
    mean1, var1, prob1, mean2, var2, _prob2 = parameters
    log_joint1 = _weighted_log_gaussian(sample_input, mean1, var1, prob1)
    log_joint2 = _weighted_log_gaussian(sample_input, mean2, var2, 1. - prob1)
    # Q*log(p/Q) == Q*log(p) - xlogy(Q, Q). In this form a responsibility that
    # underflowed to exactly 0 contributes 0, instead of 0 * log(p/0) = nan,
    # which would turn every gradient into nan.
    loss1 = torch.sum(Q1 * log_joint1 - torch.xlogy(Q1, Q1), dim=0)
    loss2 = torch.sum(Q2 * log_joint2 - torch.xlogy(Q2, Q2), dim=0)
    return loss1 + loss2


def judge(sample, mean1, var1, prob1, mean2, var2, Samples, prob2=None):
    """Assign each point to its more probable component and scatter-plot the result.

    Args:
        sample: values used for classification, one per row of ``Samples``
            (the script passes the x coordinates).
        mean1, var1, prob1: mean, variance and mixing weight of component 1.
        mean2, var2: mean and variance of component 2.
        Samples: points to plot; columns 0 and 1 are used as plot x/y
            (presumably an (N, 2) tensor — confirm with callers).
        prob2: mixing weight of component 2. Defaults to ``1 - prob1``,
            which keeps it in sync with the current ``prob1``.
    """
    if prob2 is None:
        prob2 = 1 - prob1
    weighted1 = Gaussian_distribution(sample, mean1, var1) * prob1
    weighted2 = Gaussian_distribution(sample, mean2, var2) * prob2
    # Class 1 = points where component 1 is at least as likely. They are
    # drawn with the same marker/colour sample_generater uses for class A.
    in_class1 = weighted1 >= weighted2
    class1 = Samples[in_class1].T
    class2 = Samples[~in_class1].T

    plt.figure()
    plt.scatter(x=class1[0], y=class1[1], s=3, marker="x", color="green")
    plt.scatter(x=class2[0], y=class2[1], s=3, marker="o", color="red")
    # grid()'s first positional argument is `visible`; pass the dash style
    # explicitly.
    plt.grid(linestyle="--")
    plt.show()


if __name__ == "__main__":
    sample_nums = 5000
    epoch_num = 100
    iter_num = 500
    render = False  # controls only the periodic classification plots below

    # The raw data is always printed and plotted once, independent of `render`.
    Samples = sample_generater(sample_nums, render=True)
    # The mixture is fitted on the first (x) coordinate only.
    x, _ = zip(*Samples)
    sample_x = torch.tensor(x, dtype=torch.float32, requires_grad=False)
    Samples = torch.tensor(Samples, dtype=torch.float32, requires_grad=False)

    mean1_init, var1_init, prob1_init, mean2_init, var2_init = 1., 1., 0.3, -1., 1.

    mean1 = torch.tensor(mean1_init, dtype=torch.float32, requires_grad=True)
    mean2 = torch.tensor(mean2_init, dtype=torch.float32, requires_grad=True)
    var1 = torch.tensor(var1_init, dtype=torch.float32, requires_grad=True)
    var2 = torch.tensor(var2_init, dtype=torch.float32, requires_grad=True)
    # The second mixing weight is not a free parameter; it is 1 - prob1
    # everywhere below.
    prob1 = torch.tensor(prob1_init, dtype=torch.float32, requires_grad=True)

    # One optimizer for the whole run, so Adam's moment estimates carry over
    # between EM epochs. The tensors above are updated in place by step().
    # NOTE(review): var1/var2/prob1 are unconstrained. A large step could make
    # a variance negative or push prob1 outside (0, 1), which yields nan.
    optimizer = torch.optim.Adam([mean1, var1, prob1, mean2, var2])

    for epoch in range(epoch_num):
        # E-step: responsibilities under the current parameters. They are
        # computed without autograd so the M-step treats them as constants.
        with torch.no_grad():
            p1_x = Gaussian_distribution(sample_x, mean1, var1) * prob1
            p2_x = Gaussian_distribution(sample_x, mean2, var2) * (1 - prob1)
            Q1 = p1_x / (p1_x + p2_x)
            Q2 = p2_x / (p1_x + p2_x)

        # M-step: minimize the negative EM objective with the Q's held fixed.
        for step in range(iter_num):
            loss_tot = -criterian(sample_x, Q1, Q2,
                                  mean1, var1, prob1, mean2, var2, 1. - prob1)
            optimizer.zero_grad()
            loss_tot.backward()
            optimizer.step()
            if (step + 1) % 100 == 0:
                print(f"epoch_num = {epoch + 1}, iter_num = {step + 1}, loss = {loss_tot.item()}")

        if (epoch + 1) % 5 == 0:
            print(f"mean1 = {mean1.item()}")
            print(f"mean2 = {mean2.item()}")
            print(f"var1 = {var1.item()}")
            print(f"var2 = {var2.item()}")
            print(f"prob1 = {prob1.item()}")
            print(f"prob2 = {1-prob1.item()}")
            if render:
                # judge() may read a module-level `prob2`; keep it in sync
                # with the current prob1 before plotting.
                prob2 = (1 - prob1).detach()
                judge(sample_x, mean1, var1, prob1, mean2, var2, Samples)

    mean1, var1, prob1, mean2, var2 = [para.item() for para in optimizer.param_groups[0]['params']]
    print(f"mean1_final = {mean1}, mean1_origin = {m1}")
    print(f"mean2_final = {mean2}, mean2_origin = {m2}")
    print(f"var1_final = {var1}, var1_origin = {v1}")
    print(f"var2_final = {var2}, var2_origin = {v2}")
    print(f"prob1_final = {prob1}, prob1_origin = {p0}  ")
    print(f"prob2_final = {1-prob1}, prob2_origin = {1-p0}")

'''

mean1_final = 2.9956448078155518, mean1_origin = 3.0
mean2_final = -3.0056324005126953, mean2_origin = -3.0
var1_final = 2.0761613845825195, var1_origin = 2.0
var1_final = 1.4601083993911743, var1_origin = 1.5
prob1_final = 0.5011658668518066, prob1_origin = 0.5  
prob2_final = 0.49883413314819336, prob2_origin = 0.5

'''

sample.shape = (5000, 2)
sample[:10] = 
[[2.88676981 2.48932276]
 [2.76622852 3.60118025]
 [4.45744344 4.58833376]
 [3.34710196 3.39227115]
 [2.47407757 1.09735596]
 [1.6481559  4.23144052]
 [3.03827208 2.32332117]
 [1.22050439 2.45919052]
 [2.33962504 1.75954251]
 [2.87476474 4.13609534]]
epoch_num = 1, iter_num = 100, loss = 20333.560546875
epoch_num = 1, iter_num = 200, loss = 18873.56640625
epoch_num = 1, iter_num = 300, loss = 17784.78515625
epoch_num = 1, iter_num = 400, loss = 16923.443359375
epoch_num = 1, iter_num = 500, loss = 16219.94140625
epoch_num = 2, iter_num = 100, loss = 15594.4716796875
epoch_num = 2, iter_num = 200, loss = 15097.28515625
epoch_num = 2, iter_num = 300, loss = 14672.8203125
epoch_num = 2, iter_num = 400, loss = 14307.267578125
epoch_num = 2, iter_num = 500, loss = 13990.3125
epoch_num = 3, iter_num = 100, loss = 13710.634765625
epoch_num = 3, iter_num = 200, loss = 13467.466796875
epoch_num = 3, iter_num = 300, loss = 13253.84765625
epoch_num = 3, ite

epoch_num = 27, iter_num = 400, loss = 11765.5537109375
epoch_num = 27, iter_num = 500, loss = 11765.5537109375
epoch_num = 28, iter_num = 100, loss = 11765.5537109375
epoch_num = 28, iter_num = 200, loss = 11765.5537109375
epoch_num = 28, iter_num = 300, loss = 11765.5537109375
epoch_num = 28, iter_num = 400, loss = 11765.5537109375
epoch_num = 28, iter_num = 500, loss = 11765.5537109375
epoch_num = 29, iter_num = 100, loss = 11765.5537109375
epoch_num = 29, iter_num = 200, loss = 11765.5537109375
epoch_num = 29, iter_num = 300, loss = 11765.5537109375
epoch_num = 29, iter_num = 400, loss = 11765.5537109375
epoch_num = 29, iter_num = 500, loss = 11765.5537109375
epoch_num = 30, iter_num = 100, loss = 11765.552734375
epoch_num = 30, iter_num = 200, loss = 11765.552734375
epoch_num = 30, iter_num = 300, loss = 11765.552734375
epoch_num = 30, iter_num = 400, loss = 11765.5537109375
epoch_num = 30, iter_num = 500, loss = 11765.552734375
mean1 = 2.990649938583374
mean2 = -2.982599973678589

epoch_num = 54, iter_num = 300, loss = 11765.5537109375
epoch_num = 54, iter_num = 400, loss = 11765.5537109375
epoch_num = 54, iter_num = 500, loss = 11765.552734375
epoch_num = 55, iter_num = 100, loss = 11765.552734375
epoch_num = 55, iter_num = 200, loss = 11765.5537109375
epoch_num = 55, iter_num = 300, loss = 11765.552734375
epoch_num = 55, iter_num = 400, loss = 11765.552734375
epoch_num = 55, iter_num = 500, loss = 11765.5537109375
mean1 = 2.9906179904937744
mean2 = -2.9825947284698486
var1 = 1.971841812133789
var2 = 1.5175108909606934
prob1 = 0.4988667368888855
prob2 = 0.5011332631111145
epoch_num = 56, iter_num = 100, loss = 11765.5537109375
epoch_num = 56, iter_num = 200, loss = 11765.552734375
epoch_num = 56, iter_num = 300, loss = 11765.5546875
epoch_num = 56, iter_num = 400, loss = 11765.552734375
epoch_num = 56, iter_num = 500, loss = 11765.5546875
epoch_num = 57, iter_num = 100, loss = 11765.5537109375
epoch_num = 57, iter_num = 200, loss = 11765.552734375
epoch_num = 5

epoch_num = 81, iter_num = 200, loss = 11765.5546875
epoch_num = 81, iter_num = 300, loss = 11765.5537109375
epoch_num = 81, iter_num = 400, loss = 11765.5537109375
epoch_num = 81, iter_num = 500, loss = 11765.5537109375
epoch_num = 82, iter_num = 100, loss = 11765.552734375
epoch_num = 82, iter_num = 200, loss = 11765.5546875
epoch_num = 82, iter_num = 300, loss = 11765.552734375
epoch_num = 82, iter_num = 400, loss = 11765.5537109375
epoch_num = 82, iter_num = 500, loss = 11765.552734375
epoch_num = 83, iter_num = 100, loss = 11765.5537109375
epoch_num = 83, iter_num = 200, loss = 11765.552734375
epoch_num = 83, iter_num = 300, loss = 11765.5546875
epoch_num = 83, iter_num = 400, loss = 11765.5537109375
epoch_num = 83, iter_num = 500, loss = 11765.552734375
