In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

import numpy as np
import pandas as pd
# The wildcard imports stay ahead of the sklearn imports: moving them later
# could let names they export shadow the explicit imports below.
from zwad.ad.preprocess import *
from zwad.ad.postprocess import *

from sklearn.cluster import Birch, KMeans, DBSCAN, SpectralClustering
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler

In [3]:
# All inputs are read from ../data, relative to the notebook's working directory.
# `os` is imported explicitly in the imports cell rather than relying on a
# wildcard import to provide it.
datadir = os.path.join('..', 'data')
oid_path = os.path.join(datadir, 'oid_m31.dat')
feature_path = os.path.join(datadir, 'feature_m31.dat')

# load_dataset returns the object ids and their feature rows as a pair.
oid, feature = load_dataset(oid_path, feature_path)
# Optionally with feature names:
# oid, feature = load_dataset(oid_path, feature_path, os.path.join(datadir, 'feature_m31.name'))

# Anomaly table stored in m31_gmm.csv (presumably the output of a GMM-based
# anomaly detection run -- confirm against the zwad docs). Only its 'oid'
# column is used below.
anomalies_gmm = load_ad_table(os.path.join(datadir, 'm31_gmm.csv'))
anomaly_oid = anomalies_gmm['oid']

# Feature rows for the listed anomalies; presumably selected by matching
# anomaly_oid against oid -- TODO confirm in zwad.ad.
anomaly_feature = extract_anomaly_features(anomaly_oid, oid, feature)

In [4]:
# Inspect the feature matrix extracted for the anomalies (the repr of a large
# NumPy array is abbreviated with "...").
anomaly_feature

array([[2.3800001e+00, 2.9370630e-01, 4.8951048e-02, ..., 1.0970883e+00,
        8.7504232e-01, 1.6457632e+01],
       [2.0444999e+00, 3.5802469e-01, 6.1728396e-03, ..., 1.0822589e+00,
        9.0305829e-01, 1.6626255e+01],
       [8.0999947e-01, 3.9370079e-03, 3.9370079e-03, ..., 9.9654287e-02,
        3.1402901e-01, 1.6848158e+01],
       ...,
       [2.0720000e+00, 2.2222222e-01, 4.2735044e-02, ..., 6.2597132e-01,
        7.3348981e-01, 1.6854887e+01],
       [3.2999992e-01, 6.1224490e-02, 8.1632650e-03, ..., 4.4723526e-02,
        5.1094729e-01, 1.6834017e+01],
       [1.3650036e-01, 3.5826772e-01, 3.9370079e-03, ..., 7.0678927e-02,
        8.8457823e-01, 1.5309025e+01]], dtype=float32)

In [5]:
# Load the expert's tag table and renumber its rows from 0 so the displayed
# index is contiguous (the previous index is discarded, not kept as a column).
expert_csv = os.path.join(datadir, 'm31_maria.csv')
expert_table = load_expert_table(expert_csv)
expert_table = expert_table.reset_index(drop=True)
expert_table

Unnamed: 0,oid,tag,tag_detailed,comments,alerts
0,695211400034403,artefact,bright star,bright star,
1,695211400124577,artefact,bright star,bright star,
2,695211400102351,artefact,bright star,bright star,
3,695211400053697,artefact,bright star,bright star,
4,695211200075348,transient,red star,"MASTER transient, red star; has spectra but st...",
5,695211400000352,artefact,bright star,bright star,
6,695211400088968,artefact,,empty field fits,
7,695211400117334,artefact,bright star,bright star,
8,695211400028274,artefact,bright star,bright star,
9,695211400133827,artefact,,empty field fits,


In [6]:
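# NOTE: StandardScaler is not used in the cells shown here; every clustering
# cell scales its input with scale_values() instead.
#scaler = StandardScaler()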

## Birch

In [7]:
# Birch clustering of the anomaly features into four final clusters.
birch = Birch(threshold=0.5, branching_factor=50, n_clusters=4)

# scale_values is not defined in this notebook (presumably it comes from the
# zwad wildcard imports). fit_predict fits the model and returns one cluster
# label per input row.
birch_input = scale_values(anomaly_feature)
clustering_birch = birch.fit_predict(birch_input)

In [8]:
# Pair each anomaly's Birch label with the expert's tags. merge() is called
# without `on`/`how`, so it inner-joins on the only shared column, 'oid';
# anomalies that are missing from the expert table are dropped.
birch_labels = pd.DataFrame({'oid': anomaly_oid, 'cluster': clustering_birch})
expert_tags = expert_table[['oid', 'tag', 'tag_detailed']]
# mergesort is stable, so rows keep their merge order inside each cluster.
table_birch = birch_labels.merge(expert_tags).sort_values(by='cluster', kind='mergesort')
table_birch

Unnamed: 0,oid,cluster,tag,tag_detailed
10,695211200075348,0,transient,red star
30,695211400009049,0,uncat,red star
36,695211200058391,0,uncat,red star
2,695211200009221,1,artefact,bad column
6,695211200020939,1,artefact,bad column
7,695211200008801,1,artefact,bad column
8,695211200015230,1,artefact,bad column
13,695211200008024,1,artefact,bad column
14,695211200009296,1,artefact,bad column
15,695211200009492,1,artefact,bad column


In [9]:
# Number of tagged objects for every (Birch cluster, expert tag) combination.
birch_groups = table_birch.groupby(["cluster", "tag"])
birch_summary = birch_groups.size()
birch_summary

cluster  tag       
0        transient      1
         uncat          2
1        artefact      20
2        Cl*            1
         HII region     1
         artefact      11
3        Cepheid        1
         RSG            1
         uncat          2
dtype: int64

## KMeans

In [10]:
kmeans_params = dict(
    n_clusters=5,
    init='k-means++',  # k-means++ seeding of the initial cluster centres
    max_iter=300,      # iteration cap for each individual run
    n_init=50,         # number of runs with different centroid seeds
    random_state=42,   # fixed seed so the labels are reproducible
)
kmeans = KMeans(**kmeans_params)

# scale_values is not defined in this notebook (presumably it comes from the
# zwad wildcard imports). fit_predict returns one cluster label per input row.
kmeans_input = scale_values(anomaly_feature)
clustering_kmeans = kmeans.fit_predict(kmeans_input)

In [11]:
# Pair each anomaly's KMeans label with the expert's tags. merge() is called
# without `on`/`how`, so it inner-joins on the only shared column, 'oid';
# anomalies that are missing from the expert table are dropped.
kmeans_labels = pd.DataFrame({'oid': anomaly_oid, 'cluster': clustering_kmeans})
expert_tags = expert_table[['oid', 'tag', 'tag_detailed']]
# mergesort is stable, so rows keep their merge order inside each cluster.
table_kmeans = kmeans_labels.merge(expert_tags).sort_values(by='cluster', kind='mergesort')

In [12]:
# Show the KMeans cluster labels next to the expert tags, ordered by cluster.
table_kmeans

Unnamed: 0,oid,cluster,tag,tag_detailed
10,695211200075348,0,transient,red star
30,695211400009049,0,uncat,red star
2,695211200009221,1,artefact,bad column
6,695211200020939,1,artefact,bad column
7,695211200008801,1,artefact,bad column
8,695211200015230,1,artefact,bad column
13,695211200008024,1,artefact,bad column
14,695211200009296,1,artefact,bad column
15,695211200009492,1,artefact,bad column
16,695211200032218,1,artefact,bad column


In [13]:
# Number of tagged objects for every (KMeans cluster, expert tag) combination.
kmeans_groups = table_kmeans.groupby(["cluster", "tag"])
kmeans_summary = kmeans_groups.size()
kmeans_summary

cluster  tag       
0        transient      1
         uncat          1
1        artefact      20
2        Cl*            1
         HII region     1
         artefact      11
3        uncat          1
4        Cepheid        1
         RSG            1
         uncat          2
dtype: int64

## DBSCAN

In [14]:
# eps: largest distance at which two samples still count as neighbours.
# min_samples=1: a sample's neighbourhood includes the sample itself, so every
# sample is a core point and nothing is labelled as noise (-1).
dbscan = DBSCAN(min_samples=1, eps=8)

# scale_values is not defined in this notebook (presumably it comes from the
# zwad wildcard imports). fit_predict returns one cluster label per input row.
dbscan_input = scale_values(anomaly_feature)
clustering_dbscan = dbscan.fit_predict(dbscan_input)

In [15]:
# Pair each anomaly's DBSCAN label with the expert's tags. merge() is called
# without `on`/`how`, so it inner-joins on the only shared column, 'oid';
# anomalies that are missing from the expert table are dropped.
dbscan_labels = pd.DataFrame({'oid': anomaly_oid, 'cluster': clustering_dbscan})
expert_tags = expert_table[['oid', 'tag', 'tag_detailed']]
# mergesort is stable, so rows keep their merge order inside each cluster.
table_dbscan = dbscan_labels.merge(expert_tags).sort_values(by='cluster', kind='mergesort')
table_dbscan

Unnamed: 0,oid,cluster,tag,tag_detailed
0,695211400034403,0,artefact,bright star
1,695211400124577,0,artefact,bright star
2,695211200009221,0,artefact,bad column
3,695211400000352,0,artefact,bright star
4,695211400102351,0,artefact,bright star
5,695211400053697,0,artefact,bright star
6,695211200020939,0,artefact,bad column
7,695211200008801,0,artefact,bad column
8,695211200015230,0,artefact,bad column
9,695211400028274,0,artefact,bright star


In [16]:
# Number of tagged objects for every (DBSCAN cluster, expert tag) combination.
dbscan_groups = table_dbscan.groupby(["cluster", "tag"])
dbscan_summary = dbscan_groups.size()
dbscan_summary

cluster  tag       
0        Cl*            1
         HII region     1
         artefact      31
1        transient      1
         uncat          1
2        Cepheid        1
         RSG            1
         uncat          2
3        uncat          1
dtype: int64

## GMM

In [17]:
gmm_params = dict(
    n_components=8,          # eight mixture components, i.e. up to eight clusters
    covariance_type='full',  # every component gets its own full covariance matrix
    max_iter=200,            # cap on the number of EM iterations
    init_params='kmeans',    # initial parameters are derived from a k-means run
    random_state=42,         # fixed seed so the labels are reproducible
)
gmm = GaussianMixture(**gmm_params)

# scale_values is not defined in this notebook (presumably it comes from the
# zwad wildcard imports). fit_predict returns one component label per input row.
gmm_input = scale_values(anomaly_feature)
clustering_gmm = gmm.fit_predict(gmm_input)

In [18]:
# Pair each anomaly's GMM component label with the expert's tags. merge() is
# called without `on`/`how`, so it inner-joins on the only shared column,
# 'oid'; anomalies that are missing from the expert table are dropped.
gmm_labels = pd.DataFrame({'oid': anomaly_oid, 'cluster': clustering_gmm})
expert_tags = expert_table[['oid', 'tag', 'tag_detailed']]
# mergesort is stable, so rows keep their merge order inside each cluster.
table_gmm = gmm_labels.merge(expert_tags).sort_values(by='cluster', kind='mergesort')
table_gmm

Unnamed: 0,oid,cluster,tag,tag_detailed
26,695211200046528,0,Cepheid,
27,695211200018901,0,RSG,
2,695211200009221,1,artefact,bad column
6,695211200020939,1,artefact,bad column
7,695211200008801,1,artefact,bad column
13,695211200008024,1,artefact,bad column
15,695211200009492,1,artefact,bad column
16,695211200032218,1,artefact,bad column
19,695211200020898,1,artefact,bad column
10,695211200075348,2,transient,red star


In [19]:
# Number of tagged objects for every (GMM cluster, expert tag) combination.
gmm_groups = table_gmm.groupby(["cluster", "tag"])
gmm_summary = gmm_groups.size()
gmm_summary

cluster  tag       
0        Cepheid        1
         RSG            1
1        artefact       7
2        transient      1
         uncat          1
3        artefact       9
4        uncat          1
5        Cl*            1
         HII region     1
         artefact       2
6        uncat          2
7        artefact      13
dtype: int64

## Spectral Clustering

In [20]:
# Spectral clustering on a nearest-neighbours affinity graph, six clusters.
# random_state is fixed (same seed as the KMeans and GMM cells) because the
# algorithm has randomly initialised steps; without a seed the cluster labels
# can differ from one run of the notebook to the next.
spectral = SpectralClustering(n_clusters=6,
                              eigen_solver='arpack',
                              affinity="nearest_neighbors",
                              random_state=42)

# scale_values is not defined in this notebook (presumably it comes from the
# zwad wildcard imports). fit_predict returns one cluster label per input row.
clustering_spectral = spectral.fit_predict(scale_values(anomaly_feature))

In [21]:
# Pair each anomaly's spectral-clustering label with the expert's tags.
# merge() is called without `on`/`how`, so it inner-joins on the only shared
# column, 'oid'; anomalies missing from the expert table are dropped.
spectral_labels = pd.DataFrame({'oid': anomaly_oid, 'cluster': clustering_spectral})
expert_tags = expert_table[['oid', 'tag', 'tag_detailed']]
# mergesort is stable, so rows keep their merge order inside each cluster.
table_spectral = spectral_labels.merge(expert_tags).sort_values(by='cluster', kind='mergesort')
table_spectral

Unnamed: 0,oid,cluster,tag,tag_detailed
0,695211400034403,0,artefact,bright star
1,695211400124577,0,artefact,bright star
3,695211400000352,0,artefact,bright star
4,695211400102351,0,artefact,bright star
9,695211400028274,0,artefact,bright star
11,695211400133827,0,artefact,
12,695211400088968,0,artefact,
18,695211400117334,0,artefact,bright star
24,695211400053352,0,Cl*,
32,695211400048384,0,HII region,


In [22]:
# Number of tagged objects for every (spectral cluster, expert tag) combination.
spectral_groups = table_spectral.groupby(["cluster", "tag"])
spectral_summary = spectral_groups.size()
spectral_summary

cluster  tag       
0        Cl*            1
         HII region     1
         artefact       8
         uncat          1
1        Cepheid        1
         RSG            1
         transient      1
         uncat          3
2        artefact       6
3        artefact       3
4        artefact       3
5        artefact      11
dtype: int64

In [23]:
# Side-by-side recap of the per-method cluster/tag counts. Each method name is
# printed on its own line so it is not fused with the Series index header
# ("cluster  tag"), and a blank line separates consecutive methods.
summaries = {
    "BIRCH": birch_summary,
    "KMEANS": kmeans_summary,
    "DBSCAN": dbscan_summary,
    "GMM": gmm_summary,
    "SPECTRAL": spectral_summary,
}
for method, summary in summaries.items():
    print(method)
    print(summary)
    print()

BIRCH cluster  tag       
0        transient      1
         uncat          2
1        artefact      20
2        Cl*            1
         HII region     1
         artefact      11
3        Cepheid        1
         RSG            1
         uncat          2
dtype: int64 
 KMEANS cluster  tag       
0        transient      1
         uncat          1
1        artefact      20
2        Cl*            1
         HII region     1
         artefact      11
3        uncat          1
4        Cepheid        1
         RSG            1
         uncat          2
dtype: int64 
 DBSCAN cluster  tag       
0        Cl*            1
         HII region     1
         artefact      31
1        transient      1
         uncat          1
2        Cepheid        1
         RSG            1
         uncat          2
3        uncat          1
dtype: int64 
 GMM cluster  tag       
0        Cepheid        1
         RSG            1
1        artefact       7
2        transient      1
         uncat     