Skip to content

Commit

Permalink
added lineplots onclick
Browse files Browse the repository at this point in the history
  • Loading branch information
Chilipp committed Apr 5, 2020
1 parent 0e8f563 commit 4bec4c8
Showing 1 changed file with 100 additions and 6 deletions.
106 changes: 100 additions & 6 deletions psy_view/ds_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ def __init__(self, ds, *args, **kwargs):

self.disable_navigation()

self.cids = {}

def clear_table(self):
self.dimension_table.clear()
self.dimension_table.setColumnCount(5)
Expand Down Expand Up @@ -409,6 +411,11 @@ def plotmethod_widget(self):
if self.plot_tabs.tabText(i) == label), None)
return self.plot_tabs.widget(i)

@property
def plotmethod_widgets(self):
return dict(zip(self.plotmethods, map(self.plot_tabs.widget,
range(self.plot_tabs.count()))))

_sp = None

def get_sp(self):
Expand Down Expand Up @@ -481,7 +488,6 @@ def make_plot(self):
new_v = self.variable
fmts = {}
if self.sp:
old_v = self.data.name
if not set(self.data.dims) <= set(self.ds[new_v].dims):
self.close_sp()
else:
Expand All @@ -501,9 +507,45 @@ def make_plot(self):
self.ani = None
self.sp = self.plot(
name=self.variable, **self.plot_options)
cid = self.plotter.ax.figure.canvas.mpl_connect(
'button_press_event', self.display_line)
self.cids[self.plotmethod] = cid
self.sp.show()
self.enable_navigation()

def display_line(self, event):
if not event.inaxes:
return
else:
sl = None
for widget in map(self.plot_tabs.widget,
range(self.plot_tabs.count())):
if widget.sp and event.inaxes == widget.plotter.ax:
sl = widget.get_slice(event.xdata, event.ydata)
break
variable = widget.data.name
raw_data = widget.data.psy.base.psy[variable]
if (sl is None or widget.plotmethod not in ['mapplot', 'plot2d'] or
raw_data.ndim == 2):
return
self.plotmethod = 'lineplot'
linewidget = self.plotmethod_widget
xdim = linewidget.xdim
if xdim is None:
xdim = self.dimension_checkbox.currentText()

if not linewidget.sp or (linewidget.xdim and
linewidget.xdim not in raw_data.dims):
with self.silence_variable_buttons():
for v, btn in self.variable_buttons.items():
btn.setChecked(v == variable)
self.make_plot()
linewidget.xdim = xdim
else:
linewidget.xdim = xdim
linewidget.add_line(variable, **sl)


def close_sp(self):
self.sp.close(True, True, True)
self.sp = None
Expand Down Expand Up @@ -695,6 +737,9 @@ def trigger_reset(self):
standardize_dims=False)[self.sp[0].psy.arr_name]
self.reset.emit(self.plotmethod)

def get_slice(self, x, y):
return None


class MapPlotWidget(PlotMethodWidget):

Expand Down Expand Up @@ -805,6 +850,42 @@ def get_fmts(self, var, init=False):
def refresh(self):
self.setEnabled(bool(self.sp))

def transform(self, x, y):
import cartopy.crs as ccrs
x, y = self.plotter.transform.projection.transform_point(
x, y, self.plotter.ax.projection)
# shift if necessary
if isinstance(self.plotter.transform.projection, ccrs.PlateCarree):
coord = self.plotter.plot.xcoord
if coord.min() >= 0 and x < 0:
x -= 360
elif coord.max() <= 180 and x > 180:
x -= 360
return x, y

def get_slice(self, x, y):
import numpy as np
data = self.data.psy.base.psy[self.data.name]
x, y = self.transform(x, y)
fmto = self.plotter.plot

xcoord = fmto.xcoord
ycoord = fmto.ycoord
if fmto.decoder.is_unstructured(fmto.raw_data) or xcoord.ndim == 2:
xy = xcoord.values.ravel() + 1j * ycoord.values.ravel()
dist = np.abs(xy - (x + 1j * y))
imin = np.nanargmin(dist)
if xcoord.ndim == 2:
ncols = data.shape[-2]
return dict(zip(data.dims[-2:],
[imin // ncols, imin % ncols]))
else:
return {data.dims[-1]: imin}
else:
x = xcoord.indexes[xcoord.name].get_loc(x, method='nearest')
y = ycoord.indexes[ycoord.name].get_loc(y, method='nearest')
return dict(zip(data.dims[-2:], [y, x]))


class Plot2DWidget(MapPlotWidget):

Expand All @@ -825,6 +906,9 @@ def edit_labels(self):
LabelDialog.update_project(
self.sp, 'figtitle', 'title', 'clabel', 'xlabel', 'ylabel')

def transform(self, x, y):
return x, y


class LinePlotWidget(PlotMethodWidget):

Expand Down Expand Up @@ -853,6 +937,15 @@ def setup_buttons(self):
"Labels", self.edit_labels,
"Edit title, x-label, legendlabels, etc.", self.formatoptions_box)

@property
def xdim(self):
return self.combo_dims.currentText()

@xdim.setter
def xdim(self, xdim):
if xdim != self.combo_dims.currentText():
self.combo_dims.setCurrentText(xdim)

@property
def data(self):
data = super().data
Expand All @@ -861,9 +954,9 @@ def data(self):
else:
return data[self.combo_lines.currentIndex()]

def add_line(self, name=None):
def add_line(self, name=None, **sl):
ds = self.data.psy.base
xdim = self.combo_dims.currentText()
xdim = self.xdim
if name is None:
name, ok = QtWidgets.QInputDialog.getItem(
self, 'New line', 'Select a variable',
Expand All @@ -872,8 +965,9 @@ def add_line(self, name=None):
if not ok:
return
arr = ds.psy[name]
sl = {key: val for key, val in self.data.psy.idims.items()
if key in arr.dims}
for key, val in self.data.psy.idims.items():
if key in arr.dims:
sl.setdefault(key, val)
for dim in arr.dims:
if dim != xdim:
sl.setdefault(dim, 0)
Expand All @@ -898,7 +992,7 @@ def item_texts(self):

def init_dims(self, var):
ret = {}
xdim = self.combo_dims.currentText() or var.dims[0]
xdim = self.xdim or var.dims[0]
if self.array_info:
arr_names = {}
for arrname, d in self.array_info.items():
Expand Down

0 comments on commit 4bec4c8

Please sign in to comment.