In [None]:
import os
import argparse
import numpy as np
import time 
import matplotlib.pyplot as plt
%matplotlib inline

### 新建一个seed参数

In [None]:
# Command-line options for the notebook/script. add_argument registers one
# option (see https://www.jianshu.com/p/fef2d215b91d or the argparse docs).
parser = argparse.ArgumentParser()
parser.add_argument('-s', '--seed', type=int, default=42, help="Random seed.")
# parse_known_args (not parse_args): inside Jupyter, sys.argv carries the
# kernel's own "-f <connection-file>" argument, which parse_args would reject
# with SystemExit. Unknown arguments are ignored here; --seed still works when
# the code is run as a plain script.
args, _unknown_args = parser.parse_known_args()

### 生成三个不同高斯分布的样本点

In [None]:
# Seed NumPy's global RNG so that every later np.random call (sampling,
# shuffling, center initialisation) is reproducible for a given --seed.
np.random.seed(args.seed)

# Means of the three Gaussian clusters (np.asarray converts a list to an ndarray).
u1, u2, u3 = np.asarray([1.0, 1.0]), np.asarray([-1.0, 1.0]), np.asarray([5.0, 5.0])

# Random 2x2 matrices A; A @ A.T is symmetric positive semi-definite, so each
# product is usable as a covariance matrix. The scalar factor scales the spread.
sigma1, sigma2, sigma3 = np.random.rand(2, 2), np.random.rand(2, 2), np.random.rand(2, 2)
sigma1 = np.dot(sigma1, sigma1.T) * 1.0
sigma2 = np.dot(sigma2, sigma2.T) * 5.0
sigma3 = np.dot(sigma3, sigma3.T) * 2.0

num_pts = 100  # number of points drawn per cluster
# Draw samples from the multivariate normal N(u, sigma). `size` is the number
# of samples (not the dimension), so each array has shape (num_pts, 2).
samples1 = np.random.multivariate_normal(u1, sigma1, size=num_pts)
samples2 = np.random.multivariate_normal(u2, sigma2, size=num_pts)
samples3 = np.random.multivariate_normal(u3, sigma3, size=num_pts)

### 可视化刚才生成的三组样本（散点图）

In [None]:
# Scatter each generated cluster in its own colour to show the ground truth.
# s = marker size, alpha = marker opacity (0.5 makes overlapping points visible).
plt.figure()
for cluster_pts, colour in ((samples1, "r"), (samples2, "b"), (samples3, "g")):
    plt.scatter(cluster_pts[:, 0], cluster_pts[:, 1], s=40, c=colour, alpha=0.5)
plt.grid(True)  # draw grid lines
plt.title("Ground Truth Clustering")
plt.show()

### 合并并打乱样本，随机初始化聚类中心

In [None]:
# Stack the three (num_pts, 2) arrays vertically (row-wise) into one
# (3 * num_pts, 2) array of 2-D points.
samples = np.vstack([samples1, samples2, samples3])
# Shuffle the rows so the data is no longer ordered by source cluster.
# np.random.shuffle works in place and returns None; assigning its result back
# made the index None, so no shuffling happened. permutation returns the
# shuffled index array instead.
rorder = np.random.permutation(len(samples))
samples = samples[rorder, :]

k = 3  # number of clusters
# Forgy initialisation: start the k centers at k distinct, randomly chosen
# samples. The centers therefore lie inside the data, whatever its scale,
# instead of in the unit square.
centers = samples[np.random.choice(len(samples), k, replace=False)]
num_iters = 10
losses = []
# Squared Euclidean norm of every sample (sum over each row's coordinates),
# reused by the distance computation in every K-means iteration.
xdist = np.sum(samples * samples, axis=1)

### K-means核心算法

In [None]:
# K-means (Lloyd's algorithm): alternately assign each sample to its nearest
# center, then move each center to the mean of the samples assigned to it.
for _ in range(num_iters):
    # Squared distances via ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x.c, which
    # yields an (n_samples, k) matrix without forming pairwise differences.
    cdist = np.sum(centers * centers, axis=1)
    consts = xdist[:, np.newaxis] + cdist
    dists = consts - 2 * np.dot(samples, centers.T)
    # Index of the nearest center for each sample.
    ids = np.argmin(dists, axis=1)
    # Loss = sum of squared distances to the assigned centers. It is recorded
    # before the centers are moved in this iteration.
    losses.append(np.sum(np.min(dists, axis=1)))
    for i in range(k):
        members = samples[ids == i]
        # If a cluster is empty, np.mean would return NaN (with a
        # RuntimeWarning). np.argmin then selects the NaN column for every
        # sample and the clustering collapses, so keep the previous center.
        if len(members) > 0:
            centers[i, :] = np.mean(members, axis=0)

### 聚类可视化

In [None]:
# Plot the loss recorded in each K-means iteration (the Lloyd loss is
# expected to be non-increasing). The x-axis uses len(losses) rather than
# num_iters, so the plot still works if the loop cell was re-run and
# appended more entries.
plt.figure()
plt.plot(np.arange(len(losses)), losses, "bo-", linewidth=4, markersize=10)
plt.grid(True)
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.title("K-means loss function")
plt.show()

# Plot the cluster assignment. `ids` holds the assignments computed in the last
# iteration of the K-means loop.
plt.figure()
colors = ["r", "b", "g"]
for i in range(k):
    in_cluster = ids == i
    # Cycle through the palette so k > len(colors) does not raise IndexError.
    plt.scatter(samples[in_cluster, 0], samples[in_cluster, 1],
                c=colors[i % len(colors)], s=40, alpha=0.5)
plt.grid(True)
plt.title("K-means Clustering")
plt.savefig("./kmeans_{}.pdf".format(args.seed), bbox_inches="tight")
plt.show()