In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.cm import ScalarMappable
from matplotlib import patches
import ipywidgets as widgets
# Global plot styling applied to every figure created in this notebook.
sns.set_style('whitegrid')
# Make matplotlib shorthand color codes ('b', 'g', 'r', ...) use seaborn's palette.
sns.set_color_codes()
# Large fonts / line widths for every subsequent plot.
sns.set_context('poster')
# Interactive backend: DrawRectangle (below) needs live mouse events and canvas blitting.
# NOTE(review): `notebook` (nbagg) targets the classic Notebook UI; JupyterLab needs
# `%matplotlib widget` (ipympl) instead — confirm the target environment.
%matplotlib notebook

In [2]:
# Load the flower measurements; later cells read the 'sepal length',
# 'sepal width' and 'species' columns from this frame.
data = pd.read_csv(filepath_or_buffer='./hw_2_data/flowers.csv')

In [3]:
def plot_features(feature_list, color_by):
    """Draw a lower-triangular scatter-plot matrix of ``feature_list``.

    Panel (i, j), for j <= i, plots feature j on x against feature i on y, so
    the diagonal shows each feature against itself. Only the left column keeps
    y tick labels and only the bottom row keeps x tick labels.

    Parameters
    ----------
    feature_list : list of pandas.Series
        Features to plot; each Series' ``name`` is used as its axis label.
        Assumed to be row-aligned with ``color_by`` (not checked here).
    color_by : pandas.Series
        Per-row values used to color the points. Numeric, non-boolean data is
        mapped through the continuous 'viridis' colormap and gets a colorbar;
        anything else is treated as categorical, gets one 'husl' color per
        distinct value, and a legend titled with ``color_by.name``.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure.
    """
    # Works for every int/float width (float32, int32, uint8, ...), not just the
    # platform-default `int`/`float` dtypes. Bools stay categorical, as before.
    is_continuous = (pd.api.types.is_numeric_dtype(color_by)
                     and not pd.api.types.is_bool_dtype(color_by))

    if is_continuous:
        colormap = ScalarMappable(cmap='viridis')
        colormap.set_array(color_by)  # lets fig.colorbar() know the data range
        colors = colormap.to_rgba(color_by)
    else:
        categories = color_by.unique()
        palette = sns.color_palette(palette='husl', n_colors=len(categories))
        colormap = dict(zip(categories, palette))
        colors = [colormap[value] for value in color_by]
        handles = [patches.Patch(color=colormap[cat]) for cat in categories]

    n_features = len(feature_list)
    fig = plt.figure(figsize=(3 * n_features, 3 * n_features))
    axes = []
    for i, y_feature in enumerate(feature_list):
        for j, x_feature in enumerate(feature_list[:i + 1]):
            ax = plt.subplot(n_features, n_features, n_features * i + j + 1)
            # Series.min/max skip NaN; the builtin min/max do not.
            ax.set_xlim(x_feature.min(), x_feature.max())
            ax.set_ylim(y_feature.min(), y_feature.max())
            ax.scatter(x_feature, y_feature, color=colors, linewidth=0)
            if j == 0:
                ax.set_ylabel(y_feature.name)
            else:
                ax.tick_params(labelleft=False)
            if i == n_features - 1:
                ax.set_xlabel(x_feature.name)
            else:
                ax.tick_params(labelbottom=False)
            axes.append(ax)

    if is_continuous:
        fig.colorbar(colormap, ax=axes, label=color_by.name)
    else:
        fig.legend(handles=handles, labels=list(categories),
                   bbox_to_anchor=(0.9, 0.9), title=color_by.name)
    return fig

In [92]:
class DrawRectangle(object):
    """Let the user drag out a translucent rectangle on one axes of ``fig``.

    A mouse press inside an axes starts a rectangle anchored at the press
    point. Dragging inside that same axes resizes it, redrawn via blitting.
    Releasing finishes it. Only the latest rectangle is kept: a new press
    removes the previous one.

    Attributes (coordinates are in data units of ``start_axis``):
        x0, y0: corner where the mouse was pressed.
        x1, y1: opposite corner from the last drag/release event; None until
            the mouse has moved after the current press.
        start_axis: axes holding the current rectangle.
        rect: the current ``matplotlib.patches.Rectangle``, or None.
        cids: callback ids from ``mpl_connect``, consumed by ``disconnect``.
        is_pressed: True between a handled press and the next release.
    """

    def __init__(self, fig):
        self.fig = fig
        self.x0 = None
        self.x1 = None
        self.y0 = None
        self.y1 = None
        self.start_axis = None
        self.rect = None
        self.cids = []
        self.is_pressed = False
        # Pixel snapshot of start_axis without the rectangle, restored before each blit.
        self.background = None

    def connect(self):
        """Register press / motion / release handlers on the figure canvas."""
        canvas = self.fig.canvas
        self.cids.append(canvas.mpl_connect('button_press_event', self.on_click))
        self.cids.append(canvas.mpl_connect('motion_notify_event', self.on_motion))
        self.cids.append(canvas.mpl_connect('button_release_event', self.on_release))

    def on_click(self, event):
        """Start a new rectangle at the press location; presses outside any axes are ignored."""
        if event.inaxes is None:
            # No axes means no data coordinates (xdata/ydata are None).
            return
        if self.rect is not None:
            self.rect.remove()
        self.x0 = event.xdata
        self.y0 = event.ydata
        self.x1 = None
        self.y1 = None
        self.start_axis = event.inaxes
        # animated=True keeps the patch out of regular draws, so the snapshot below is clean.
        self.rect = patches.Rectangle((self.x0, self.y0), 0, 0, alpha=0.3, animated=True)
        self.start_axis.add_patch(self.rect)
        self.is_pressed = True

        # Full redraw first so pixels of the removed rectangle are gone before snapshotting.
        self.fig.canvas.draw()
        self.background = self.fig.canvas.copy_from_bbox(self.start_axis.bbox)

    def _update_rect(self, event):
        """Stretch the rectangle to the event position and blit it over the saved background."""
        self.x1 = event.xdata
        self.y1 = event.ydata
        self.rect.set_width(self.x1 - self.x0)
        self.rect.set_height(self.y1 - self.y0)
        self.fig.canvas.restore_region(self.background)
        self.start_axis.draw_artist(self.rect)
        self.fig.canvas.blit(self.start_axis.bbox)

    def on_motion(self, event):
        """Resize the rectangle while dragging inside the axes where it started."""
        if self.is_pressed and event.inaxes is self.start_axis:
            self._update_rect(event)

    def on_release(self, event):
        """Finish the rectangle; a release outside the start axes keeps the last dragged size."""
        if self.is_pressed and event.inaxes is self.start_axis:
            self._update_rect(event)
        if self.is_pressed:
            # Turn it into a normal artist so later full redraws (zoom, resize) keep it.
            self.rect.set_animated(False)
            self.fig.canvas.draw_idle()
        self.is_pressed = False

    def disconnect(self):
        """Remove every handler registered by ``connect``."""
        for cid in self.cids:
            self.fig.canvas.mpl_disconnect(cid)
        self.cids = []

In [93]:
# Scatter-plot matrix of the two sepal measurements, colored by species,
# with an interactive rectangle-drawing tool attached to the figure.
feature_list = [data[column] for column in ('sepal length', 'sepal width')]
# NOTE(review): `w` is displayed but nothing in this cell writes to it yet.
w = widgets.HTML()
fig = plot_features(feature_list, color_by=data['species'])
r = DrawRectangle(fig)
r.connect()
w

<IPython.core.display.Javascript object>