In [5]:
import numpy as np
from astropy.io import fits
from sklearn import cluster


#===========================================Load in data=====================================
hdul = fits.open('PCL1002_cat_removeNaN.fits')
data = hdul[1].data
ra = data.field('ra')
dec = data.field('dec')
f250 = data.field('f350')
et250 = data.field('et250')
f350 = data.field('f350')
et350 = data.field('et350')
f500 = data.field('f500')
et500 = data.field('et500')
color_f250_f350_pre = data.field('f250/f350')
color_f350_f500_pre = data.field('f350/f500')

X = np.column_stack((ra, dec, f250, et250, f350, et350, f500, et500, color_f250_f350_pre, color_f350_f500_pre))
#X = np.column_stack((ra, dec, f250, et250, f350, et350, f500, et500))
#X = np.column_stack((ra, dec, color_f250_f350_pre, color_f350_f500_pre))

print(X)
print(type(X))
print(X.shape)

np.random.seed(0)
indices2 = np.random.permutation(len(X))

X_train = X[indices2[:-121]]  #from first to the last 10% (randomly permuated) datasets
X_test = X[indices2[-121:]]  #from the last 10% to the last (randomly permuated) datasets


#==============================setup classifier=========================================
#Mean shift clustering aims to discover “blobs” in a smooth density of samples. It is a centroid-based algorithm, which works by updating candidates for centroids to be the mean of the points within a given region.

mean_shift = cluster.MeanShift(cluster_all=False)
#cluster_all : boolean, default True. If true, then all points are clustered, even those orphans that are not within any kernel. Orphans are assigned to the nearest kernel. If false, then orphans are given cluster label -1.


mean_shift.fit(X)

#mean_shift.predict(X_test)

print(mean_shift.labels_) #Labels of each point. seq[::n] is a sequence of each n-th item in the entire sequence.

np.savetxt('output_meanshift.cat', mean_shift.labels_, delimiter=' ')

print(mean_shift.cluster_centers_)


[[149.99773      2.5780058   66.2635     ...   5.077003     0.9692827
    1.3175685 ]
 [150.11905      2.4570513   33.69685    ...   5.012378     2.5274227
    4.430969  ]
 [150.15826      2.139626    65.60122    ...   5.002751     1.2487625
    1.8051956 ]
 ...
 [150.1225       2.1299489    6.2136927  ...   5.0116863    0.69618124
    0.54290056]
 [150.24413      2.4969258    7.772629   ...   5.0401177    1.1416423
    1.1474057 ]
 [149.93184      2.3198414    7.176956   ...   4.991613     0.737808
    0.8937035 ]]
<class 'numpy.ndarray'>
(1206, 10)
[ 3 -1  3 ...  0  0  0]
[[1.5013509e+02 2.3688354e+00 8.6756496e+00 3.4098969e+00 8.6756496e+00
  6.5734034e+00 5.1890168e+00 5.0858178e+00 1.4510659e+00 2.3359704e+00]
 [1.5007050e+02 2.3256838e+00 4.4378780e+01 3.4099443e+00 4.4378780e+01
  6.5909619e+00 3.6717491e+01 5.1799579e+00 9.4812715e-01 1.2135097e+00]
 [1.5003903e+02 2.3615346e+00 7.4708929e+00 3.4100080e+00 7.4708929e+00
  6.5917759e+00 2.0744321e-01 5.2326713e+00 1.5398178e+00