In [1]:
import pandas as pd
import numpy as np
from ClusterPlot import ClusterPlot
from DataSetFactory import DataSetFactory
import plotly.graph_objects as go
import plotly.express as px
import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
def vis_2d(_df, _x, _y, _color, _algo):
    """Show an interactive 2D plotly scatter plot of ``_df``.

    Parameters
    ----------
    _df : pandas.DataFrame
        Data to plot.
    _x, _y : str
        Column names used for the horizontal and vertical axes.
    _color : str
        Column name whose values color the points.
    _algo : str
        Name of the algorithm that produced the embedding, used as a title prefix.
        When empty, the title is just "2D Visualization" (no dangling ": " prefix).

    Returns None on purpose: the notebook calls this as the last expression of a
    cell, so returning the figure would make Jupyter render it a second time.
    """
    fig = px.scatter(_df, x=_x, y=_y, color=_color)
    title = f'{_algo}: 2D Visualization' if _algo else '2D Visualization'
    fig.update_layout(title=title)
    fig.show()

In [3]:
# Dataset to embed. Other dataset names previously used with this notebook:
# 'fists_no_overlap', 'cross'.
DATASET_NAME = 'MNIST'
ds = DataSetFactory.get_dataset(DATASET_NAME)

In [None]:
# Fit ClusterPlot on the dataset's feature columns (with their class labels)
# and keep the low-dimensional output for plotting below.
amap = ClusterPlot(
    verbose=True,
    n_intra_anchors=4,
    k=15,
    dim_reduction_algo='mds',
    anchors_method='mean_shift',
    radius_q=0.5,
    n_jobs=4,
)
low_dim = amap.fit_transform(
    ds.df[ds.feature_cols].values,
    ds.df[ds.label_col].values,
)

finding intra class anchors using mean_shift
Number of intra_class anchors (centroids) is 4
Bandwidth label: 0 = 7.255927816114452


In [None]:
# For every intra-class anchor, sample points via amap.random_points_per_cluster
# and tag them with the first element of amap.anchor_to_label_cluster(anchor).
# (Use the anchor index itself as the label instead to color per anchor.)
dfs = [
    pd.DataFrame(amap.random_points_per_cluster(anchor_idx), columns=['x', 'y'])
    .assign(label=amap.anchor_to_label_cluster(anchor_idx)[0])
    for anchor_idx in range(len(amap.intra_class_anchors_labels))
]
tmp_df = pd.concat(dfs)
# Optional: offset the labels so sampled points get colors distinct from the embedding:
# tmp_df['label'] = tmp_df['label'] + 10

In [None]:
# Frame of the low-dimensional output, colored by amap.intra_class_anchors_labels.
# NOTE(review): assigning intra_class_anchors_labels as a column requires it to have
# exactly one entry per row of low_dim — confirm what fit_transform returns.
df = pd.DataFrame(data=low_dim, columns=['x', 'y']).assign(label=amap.intra_class_anchors_labels)
# To overlay the sampled random points, uncomment:
# df = pd.concat([df, tmp_df])
vis_2d(df, 'x', 'y', 'label', '')

In [None]:
# Sanity-check amap.anchors_radius: for every low-dimensional anchor, place one point
# at +radius and one at -radius along each axis, then overlay those points on the
# embedding scatter plot.
low_dim_anchors = amap.low_dim_anchors
radius = np.asarray(amap.anchors_radius)  # scalar or one radius per anchor; broadcasts over rows

# Order: +axis0, -axis0, +axis1, -axis1, ... (same order as the original x/y/z arrays).
offset_arrays = []
for dim in range(amap.n_components):
    for offset in (radius, -radius):
        shifted = low_dim_anchors.copy()
        shifted[:, dim] = low_dim_anchors[:, dim] + offset
        offset_arrays.append(shifted)
radius_points = np.concatenate(offset_arrays)

# Each block of offset points repeats the anchors in order, so labels repeat the same way.
n_points_per_anchor = amap.n_components * 2
labels = list(amap.intra_class_anchors_labels) * n_points_per_anchor
assert len(labels) == len(radius_points), 'anchor labels do not line up with low_dim_anchors rows'

axis_names = (['x', 'y', 'z'] + [f'dim{i}' for i in range(3, amap.n_components)])[:amap.n_components]
anchors_df = pd.DataFrame(radius_points, columns=axis_names)
anchors_df['label'] = labels

# Build a new frame instead of reassigning `df`: re-running this cell must not keep
# appending another copy of the radius points to the embedding frame.
df_with_radius = pd.concat([df, anchors_df])
vis_2d(df_with_radius, 'x', 'y', 'label', '')

In [None]:
# Overview plot of the fitted anchors, drawn by ClusterPlot itself.
amap.anchors_plot(
    alpha=0.3,  # presumably marker transparency — confirm in ClusterPlot.anchors_plot
)

In [None]:
# Annotated heatmap of amap.inter_class_relations. Row/column i is labelled with
# amap.anchor_to_label_cluster(i); labels are computed once and shared by both axes.
_, ax = plt.subplots(figsize=(10, 10))
relation_tick_labels = [
    str(amap.anchor_to_label_cluster(i)) for i in range(amap.inter_class_relations.shape[0])
]
sns.heatmap(amap.inter_class_relations, ax=ax, annot=True, square=True,
            xticklabels=relation_tick_labels, yticklabels=relation_tick_labels)
ax.set_title('Inter-class relations');

In [None]:
# Inspect the low-dimensional inter-class relation matrix computed by ClusterPlot.
amap.inter_class_relations_low_dim

In [None]:
# Annotated heatmap of amap.inter_class_relations_low_dim. Tick labels are sized from
# the matrix actually being plotted (previously they used inter_class_relations.shape,
# which silently mislabels or fails if the two matrices differ in size).
_, ax = plt.subplots(figsize=(10, 10))
low_dim_tick_labels = [
    str(amap.anchor_to_label_cluster(i))
    for i in range(amap.inter_class_relations_low_dim.shape[0])
]
sns.heatmap(amap.inter_class_relations_low_dim, ax=ax, annot=True, square=True,
            xticklabels=low_dim_tick_labels, yticklabels=low_dim_tick_labels)
ax.set_title('Inter-class relations (low-dimensional)');

In [None]:
# Labels of the radius-check points: the anchor labels repeated once per +/- axis offset.
labels

In [None]:
# Expected to equal n_points_per_anchor * len(amap.intra_class_anchors_labels).
len(labels)

In [None]:
# Class label of each intra-class anchor found by ClusterPlot.
amap.intra_class_anchors_labels

In [None]:
# Total number of intra-class anchors across all classes.
len(amap.intra_class_anchors_labels)