Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

Already on GitHub? Sign in to your account

Radviz #1566

Merged
merged 3 commits into from Jul 11, 2012
Jump to file or symbol
Failed to load files and symbols.
+93 −0
Split
@@ -341,3 +341,29 @@ confidence band.
@savefig autocorrelation_plot.png width=6in
autocorrelation_plot(data)
+
+RadViz
+~~~~~~
+
+RadViz is a way of visualizing multi-variate data. It is based on a simple
+spring tension minimization algorithm. Basically you set up a bunch of points in
+a plane. In our case they are equally spaced on a unit circle. Each point
+represents a single attribute. You then pretend that each sample in the data set
+is attached to each of these points by a spring, the stiffness of which is
+proportional to the numerical value of that attribute (they are normalized to
+unit interval). The point in the plane, where our sample settles to (where the
+forces acting on our sample are at an equilibrium) is where a dot representing
+our sample will be drawn. Depending on which class that sample belongs it will
+be colored differently.
+
+.. ipython:: python
+
+ from pandas import read_csv
+ from pandas.tools.plotting import radviz
+
+ data = read_csv('data/iris.data')
+
+ plt.figure()
+
+ @savefig radviz.png width=6in
+ radviz(data, 'Name')
@@ -315,6 +315,14 @@ def test_andrews_curves(self):
_check_plot_works(andrews_curves, df, 'Name')
@slow
+ def test_radviz(self):
+ from pandas import read_csv
+ from pandas.tools.plotting import radviz
+ path = os.path.join(curpath(), 'data/iris.csv')
+ df = read_csv(path)
+ _check_plot_works(radviz, df, 'Name')
+
+ @slow
def test_plot_int_columns(self):
df = DataFrame(np.random.randn(100, 4)).cumsum()
_check_plot_works(df.plot, legend=True)
View
@@ -147,6 +147,65 @@ def _get_marker_compat(marker):
return 'o'
return marker
+def radviz(frame, class_column, ax=None, **kwds):
+ """RadViz - a multivariate data visualization algorithm
+
+ Parameters:
+ -----------
+ frame: DataFrame object
+ class_column: Column name that contains information about class membership
+ ax: Matplotlib axis object, optional
+ kwds: Matplotlib scatter method keyword arguments, optional
+
+ Returns:
+ --------
+ ax: Matplotlib axis object
+ """
+ import matplotlib.pyplot as plt
+ import matplotlib.patches as patches
+ import matplotlib.text as text
+ import random
+ def random_color(column):
+ random.seed(column)
+ return [random.random() for _ in range(3)]
+ def normalize(series):
+ a = min(series)
+ b = max(series)
+ return (series - a) / (b - a)
+ column_names = [column_name for column_name in frame.columns if column_name != class_column]
+ columns = [normalize(frame[column_name]) for column_name in column_names]
+ if ax == None:
+ ax = plt.gca(xlim=[-1, 1], ylim=[-1, 1])
+ classes = set(frame[class_column])
+ to_plot = {}
+ for class_ in classes:
+ to_plot[class_] = [[], []]
+ n = len(frame.columns) - 1
+ s = np.array([(np.cos(t), np.sin(t)) for t in [2.0 * np.pi * (i / float(n)) for i in range(n)]])
+ for i in range(len(frame)):
+ row = np.array([column[i] for column in columns])
+ row_ = np.repeat(np.expand_dims(row, axis=1), 2, axis=1)
+ y = (s * row_).sum(axis=0) / row.sum()
+ class_name = frame[class_column][i]
+ to_plot[class_name][0].append(y[0])
+ to_plot[class_name][1].append(y[1])
+ for class_ in classes:
+ ax.scatter(to_plot[class_][0], to_plot[class_][1], color=random_color(class_), label=str(class_), **kwds)
+ ax.add_patch(patches.Circle((0.0, 0.0), radius=1.0, facecolor='none'))
+ for xy, name in zip(s, column_names):
+ ax.add_patch(patches.Circle(xy, radius=0.025, facecolor='gray'))
+ if xy[0] < 0.0 and xy[1] < 0.0:
+ ax.text(xy[0] - 0.025, xy[1] - 0.025, name, ha='right', va='top', size='small')
+ elif xy[0] < 0.0 and xy[1] >= 0.0:
+ ax.text(xy[0] - 0.025, xy[1] + 0.025, name, ha='right', va='bottom', size='small')
+ elif xy[0] >= 0.0 and xy[1] < 0.0:
+ ax.text(xy[0] + 0.025, xy[1] - 0.025, name, ha='left', va='top', size='small')
+ elif xy[0] >= 0.0 and xy[1] >= 0.0:
+ ax.text(xy[0] + 0.025, xy[1] + 0.025, name, ha='left', va='bottom', size='small')
+ ax.legend(loc='upper right')
+ ax.axis('equal')
+ return ax
+
def andrews_curves(data, class_column, ax=None, samples=200):
"""
Parameters: