/
hinton.py
96 lines (76 loc) · 3.32 KB
/
hinton.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import collections
from matplotlib import transforms
from matplotlib import ticker
# Public API of this module: only `hinton` is exported by a star-import.
__all__ = ['hinton']
# TODO: Add yutils.mpl._coll to mpltools and use that for square collection.
class SquareCollection(collections.RegularPolyCollection):
    """Collection of squares.

    Implemented as a 4-sided `RegularPolyCollection` rotated by pi/4.
    All keyword arguments are forwarded to the parent constructor.
    """

    def __init__(self, **kwargs):
        quarter_turn = np.pi/4.
        super(SquareCollection, self).__init__(4, rotation=quarter_turn,
                                               **kwargs)

    def get_transform(self):
        """Return the affine scaling applied to the squares.

        The x and y factors are ``72 / dpi`` multiplied by the ratio of the
        axes bounding-box extent to the view-limit extent along that
        direction.  The method reads ``self.axes`` and its figure, so the
        collection has to be attached to an axes before this is called.
        """
        axes = self.axes
        points_per_pixel = 72.0 / axes.figure.dpi
        bbox = axes.bbox
        view = axes.viewLim
        # Keep the multiply-then-divide order so results match bit-for-bit.
        x_factor = points_per_pixel * bbox.width / view.width
        y_factor = points_per_pixel * bbox.height / view.height
        scaling = transforms.Affine2D()
        return scaling.scale(x_factor, y_factor)
def _make_squares(ax, xy, areas, color):
    """Build a SquareCollection of `color` squares at data positions `xy`."""
    style = dict(sizes=areas, offsets=xy, facecolor=color, edgecolor=color)
    try:
        # Spelling of the offset-transform keyword in current matplotlib.
        return SquareCollection(offset_transform=ax.transData, **style)
    except (AttributeError, TypeError):
        # Older matplotlib rejects the new keyword as an unknown property
        # and only understands the original `transOffset` spelling.
        return SquareCollection(transOffset=ax.transData, **style)


def hinton(inarray, max_value=None, use_default_ticks=True):
    """Plot Hinton diagram for visualizing the values of a 2D array.

    Plot representation of an array with positive and negative values
    represented by white and black squares, respectively. The size of each
    square represents the magnitude of each value. Zero entries draw nothing.

    Unlike the hinton demo in the matplotlib gallery [1]_, this implementation
    uses a RegularPolyCollection to draw squares, which is much more efficient
    than drawing individual Rectangles.

    The diagram is drawn on the current axes (``plt.gca()``), whose
    background is set to gray. Nothing is returned.

    .. note::
        This function inverts the y-axis to match the origin for arrays.

    .. [1] http://matplotlib.sourceforge.net/examples/api/hinton_demo.html

    Parameters
    ----------
    inarray : array
        Two-dimensional array to plot (anything `np.asarray` accepts).
    max_value : float
        Any *absolute* value larger than `max_value` will be represented by a
        unit square. If None, the smallest power of two that is not less
        than the largest absolute value in `inarray` is used (1.0 when the
        array contains only zeros).
    use_default_ticks : boolean
        If True, install an `IndexLocator` as the major tick locator on both
        axes. Pass False to leave the tick locators untouched so they can be
        configured outside this function.
    """
    ax = plt.gca()
    # `set_axis_bgcolor` no longer exists in current matplotlib; prefer its
    # replacement and fall back only on releases that predate it.
    if hasattr(ax, 'set_facecolor'):
        ax.set_facecolor('gray')
    else:
        ax.set_axis_bgcolor('gray')

    # make sure we're working with a numpy array, not a numpy matrix
    inarray = np.asarray(inarray)
    height, width = inarray.shape

    if max_value is None:
        peak = np.max(np.abs(inarray))
        # Guard the all-zero case: log2(0) is -inf, which would make
        # max_value 0 and turn every value into NaN below.
        max_value = 2 ** np.ceil(np.log2(peak)) if peak > 0 else 1.0
    values = np.clip(inarray / max_value, -1, 1)
    rows, cols = np.mgrid[:height, :width]

    for mask, color in ((values > 0, 'white'), (values < 0, 'black')):
        if not mask.any():
            continue
        # Explicit (N, 2) array of (x, y) centers; a bare zip() is a
        # one-shot iterator on Python 3 and is not valid as offsets.
        xy = np.column_stack((cols[mask], rows[mask]))
        # Collection sizes are areas of the circumscribing circle; pi/2 is
        # that area for a square of unit side.
        circle_areas = np.pi / 2 * np.abs(values[mask])
        squares = _make_squares(ax, xy, circle_areas, color)
        ax.add_collection(squares, autolim=True)

    ax.axis('scaled')
    # set data limits instead of using xlim, ylim.
    ax.set_xlim(-0.5, width-0.5)
    # Reversed limits put row 0 at the top, matching array orientation.
    ax.set_ylim(height-0.5, -0.5)

    if use_default_ticks:
        ax.xaxis.set_major_locator(IndexLocator())
        ax.yaxis.set_major_locator(IndexLocator())
class IndexLocator(ticker.Locator):
    """Locator placing ticks at evenly spaced positions starting from 0.

    Parameters
    ----------
    max_ticks : int
        Threshold used to pick the tick spacing: the spacing is 1 while the
        upper end of the axis data interval is below `max_ticks`, otherwise
        it is that upper end divided by `max_ticks`, rounded up.
    """

    def __init__(self, max_ticks=10):
        self.max_ticks = max_ticks

    def __call__(self):
        """Return the locations of the ticks."""
        # Only the upper end of the data interval is used; ticks start at 0.
        _, upper = self.axis.get_data_interval()
        step = 1 if upper < self.max_ticks else np.ceil(upper / self.max_ticks)
        locations = np.arange(0, upper, step)
        return self.raise_if_exceeds(locations)