|
|
@@ -147,6 +147,65 @@ def _get_marker_compat(marker): |
|
|
return 'o'
|
|
|
return marker
|
|
|
|
|
|
def radviz(frame, class_column, ax=None, **kwds):
    """RadViz - a multivariate data visualization algorithm

    Every column other than ``class_column`` is min-max normalized to
    [0, 1] and given an anchor point; the anchors are spaced evenly on the
    unit circle.  Each row is drawn at the average of the anchor positions
    weighted by that row's normalized values, and rows are grouped and
    coloured by their value in ``class_column``.

    Parameters:
    -----------
    frame: DataFrame object
    class_column: Column name that contains information about class membership
    ax: Matplotlib axis object, optional
        When omitted, the current axes are used and their limits are set
        to [-1, 1] on both axes.
    kwds: Matplotlib scatter method keyword arguments, optional

    Returns:
    --------
    ax: Matplotlib axis object
    """
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    import random

    def random_color(label):
        # Seed a private generator so the colour is reproducible per class
        # label without disturbing the caller's global ``random`` state.
        seed = label
        if not (seed is None or
                isinstance(seed, (int, float, str, bytes, bytearray))):
            # Recent Python versions only accept the types above as seeds.
            seed = repr(seed)
        rng = random.Random(seed)
        return [rng.random() for _ in range(3)]

    column_names = [name for name in frame.columns if name != class_column]
    n = len(column_names)
    n_rows = len(frame)

    # Pull the data out positionally as a float matrix (rows x columns) so
    # that neither the frame's index labels nor integer dtypes affect the
    # arithmetic below.
    values = np.array([np.asarray(frame[name], dtype=float)
                       for name in column_names], dtype=float)
    values = values.reshape(n, n_rows).T

    # Min-max normalize each column.  A constant column has zero span; it
    # is mapped to all zeros (no pull towards its anchor) instead of NaN.
    lo = values.min(axis=0)
    span = values.max(axis=0) - lo
    weights = (values - lo) / np.where(span == 0, 1.0, span)

    # Anchor positions: n points evenly spaced on the unit circle.
    angles = 2.0 * np.pi * (np.arange(n) / float(max(n, 1)))
    s = np.column_stack((np.cos(angles), np.sin(angles)))

    # Weighted average of the anchors for every row at once.  Rows whose
    # weights are all zero have no defined average and are placed at the
    # centre of the circle.
    totals = weights.sum(axis=1)
    safe_totals = np.where(totals == 0, 1.0, totals)
    points = np.dot(weights, s) / safe_totals[:, np.newaxis]

    # Group the points by class, remembering the order in which classes
    # first appear so the legend order is deterministic.
    labels = frame[class_column].tolist()
    class_order = []
    to_plot = {}
    for label, (px, py) in zip(labels, points):
        if label not in to_plot:
            to_plot[label] = ([], [])
            class_order.append(label)
        xs, ys = to_plot[label]
        xs.append(px)
        ys.append(py)

    if ax is None:
        ax = plt.gca()
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)

    for label in class_order:
        xs, ys = to_plot[label]
        ax.scatter(xs, ys, color=random_color(label), label=str(label),
                   **kwds)

    # Outline of the unit circle the anchors sit on.
    ax.add_patch(patches.Circle((0.0, 0.0), radius=1.0, facecolor='none'))

    for (x, y), name in zip(s, column_names):
        ax.add_patch(patches.Circle((x, y), radius=0.025, facecolor='gray'))
        # Push each label away from the centre so it does not overlap the
        # anchor marker: left/right by the sign of x, below/above by y.
        dx, ha = (-0.025, 'right') if x < 0.0 else (0.025, 'left')
        dy, va = (-0.025, 'top') if y < 0.0 else (0.025, 'bottom')
        ax.text(x + dx, y + dy, name, ha=ha, va=va, size='small')

    ax.legend(loc='upper right')
    ax.axis('equal')
    return ax
|
|
|
+
|
|
|
def andrews_curves(data, class_column, ax=None, samples=200):
|
|
|
"""
|
|
|
Parameters:
|
|
|
|