1
+ from sklearn import datasets
2
+ import matplotlib .pyplot as plt
3
+ import numpy as np
4
+
5
+ iris = datasets .load_iris ()
6
+ X , y = iris .data , iris .target
7
+
8
+ # 为了便于可视化,只取两个维度
9
+ data = X [:,[1 ,3 ]]
10
+
11
+ print (data )
12
+
13
+ plt .scatter (data [:,0 ],data [:,1 ])
14
+
15
+ ck = 3
16
+ '''
17
+ 随机选取k个点为聚类的初始代表点,即质点
18
+ '''
19
+ def rand_center (data ,k ):
20
+ """Generate k center within the range of data set."""
21
+ n = data .shape [1 ] # features
22
+ centroids = np .zeros ((k ,n )) # init with (0,0)....
23
+ for i in range (n ):
24
+ dmin , dmax = np .min (data [:,i ]), np .max (data [:,i ])
25
+ centroids [:,i ] = dmin + (dmax - dmin ) * np .random .rand (k )
26
+ return centroids
27
+
28
+ # 初始化点列表
29
+ centroids = rand_center (data , ck )
30
+ print (centroids )
31
+
32
+ def kmeans (data ,k = 2 ):
33
+ def _distance (p1 ,p2 ):
34
+ """
35
+ Return Eclud distance between two points.
36
+ p1 = np.array([0,0]), p2 = np.array([1,1]) => 1.414
37
+ """
38
+ tmp = np .sum ((p1 - p2 )** 2 )
39
+ return np .sqrt (tmp )
40
+ def _rand_center (data ,k ):
41
+ """Generate k center within the range of data set."""
42
+ n = data .shape [1 ] # features
43
+ centroids = np .zeros ((k ,n )) # init with (0,0)....
44
+ for i in range (n ):
45
+ dmin , dmax = np .min (data [:,i ]), np .max (data [:,i ])
46
+ centroids [:,i ] = dmin + (dmax - dmin ) * np .random .rand (k )
47
+ return centroids
48
+
49
+ def _converged (centroids1 , centroids2 ):
50
+
51
+ # if centroids not changed, we say 'converged'
52
+ set1 = set ([tuple (c ) for c in centroids1 ])
53
+ set2 = set ([tuple (c ) for c in centroids2 ])
54
+ return (set1 == set2 )
55
+
56
+
57
+ n = data .shape [0 ] # number of entries
58
+ centroids = _rand_center (data ,k )
59
+ label = np .zeros (n ,dtype = np .int ) # track the nearest centroid
60
+ assement = np .zeros (n ) # for the assement of our model
61
+ converged = False
62
+
63
+ while not converged :
64
+ old_centroids = np .copy (centroids )
65
+ for i in range (n ):
66
+ # determine the nearest centroid and track it with label
67
+ min_dist , min_index = np .inf , - 1
68
+ for j in range (k ):
69
+ dist = _distance (data [i ],centroids [j ])
70
+ if dist < min_dist :
71
+ min_dist , min_index = dist , j
72
+ label [i ] = j
73
+ assement [i ] = _distance (data [i ],centroids [label [i ]])** 2
74
+
75
+ # update centroid
76
+ for m in range (k ):
77
+ centroids [m ] = np .mean (data [label == m ],axis = 0 )
78
+ converged = _converged (old_centroids ,centroids )
79
+ return centroids , label , np .sum (assement )
80
+
81
+
82
+ # 多运行
83
+ best_assement = np .inf
84
+ best_centroids = None
85
+ best_label = None
86
+
87
+ for i in range (10 ):
88
+ centroids , label , assement = kmeans (data ,ck )
89
+ if assement < best_assement :
90
+ best_assement = assement
91
+ best_centroids = centroids
92
+ best_label = label
93
+
94
+ data0 = data [best_label == 0 ]
95
+ data1 = data [best_label == 1 ]
96
+
97
+ # 打印展示
98
+ fig , (ax1 ,ax2 ) = plt .subplots (1 ,2 ,figsize = (12 ,5 ))
99
+ ax1 .scatter (data [:,0 ],data [:,1 ],c = 'c' ,s = 30 ,marker = 'o' )
100
+ ax2 .scatter (data0 [:,0 ],data0 [:,1 ],c = 'r' )
101
+ ax2 .scatter (data1 [:,0 ],data1 [:,1 ],c = 'c' )
102
+ ax2 .scatter (centroids [:,0 ],centroids [:,1 ],c = 'b' ,s = 120 ,marker = 'o' )
103
+ plt .show ()
0 commit comments