diff --git a/psy_view/ds_widget.py b/psy_view/ds_widget.py index d162c02..0997bf6 100644 --- a/psy_view/ds_widget.py +++ b/psy_view/ds_widget.py @@ -152,6 +152,8 @@ def __init__(self, ds, *args, **kwargs): self.disable_navigation() + self.cids = {} + def clear_table(self): self.dimension_table.clear() self.dimension_table.setColumnCount(5) @@ -409,6 +411,11 @@ def plotmethod_widget(self): if self.plot_tabs.tabText(i) == label), None) return self.plot_tabs.widget(i) + @property + def plotmethod_widgets(self): + return dict(zip(self.plotmethods, map(self.plot_tabs.widget, + range(self.plot_tabs.count())))) + _sp = None def get_sp(self): @@ -481,7 +488,6 @@ def make_plot(self): new_v = self.variable fmts = {} if self.sp: - old_v = self.data.name if not set(self.data.dims) <= set(self.ds[new_v].dims): self.close_sp() else: @@ -501,9 +507,45 @@ def make_plot(self): self.ani = None self.sp = self.plot( name=self.variable, **self.plot_options) + cid = self.plotter.ax.figure.canvas.mpl_connect( + 'button_press_event', self.display_line) + self.cids[self.plotmethod] = cid self.sp.show() self.enable_navigation() + def display_line(self, event): + if not event.inaxes: + return + else: + sl = None + for widget in map(self.plot_tabs.widget, + range(self.plot_tabs.count())): + if widget.sp and event.inaxes == widget.plotter.ax: + sl = widget.get_slice(event.xdata, event.ydata) + break + variable = widget.data.name + raw_data = widget.data.psy.base.psy[variable] + if (sl is None or widget.plotmethod not in ['mapplot', 'plot2d'] or + raw_data.ndim == 2): + return + self.plotmethod = 'lineplot' + linewidget = self.plotmethod_widget + xdim = linewidget.xdim + if xdim is None: + xdim = self.dimension_checkbox.currentText() + + if not linewidget.sp or (linewidget.xdim and + linewidget.xdim not in raw_data.dims): + with self.silence_variable_buttons(): + for v, btn in self.variable_buttons.items(): + btn.setChecked(v == variable) + self.make_plot() + linewidget.xdim = xdim + else: + linewidget.xdim = xdim + linewidget.add_line(variable, **sl) + + def close_sp(self): self.sp.close(True, True, True) self.sp = None @@ -695,6 +737,9 @@ def trigger_reset(self): standardize_dims=False)[self.sp[0].psy.arr_name] self.reset.emit(self.plotmethod) + def get_slice(self, x, y): + return None + class MapPlotWidget(PlotMethodWidget): @@ -805,6 +850,42 @@ def get_fmts(self, var, init=False): def refresh(self): self.setEnabled(bool(self.sp)) + def transform(self, x, y): + import cartopy.crs as ccrs + x, y = self.plotter.transform.projection.transform_point( + x, y, self.plotter.ax.projection) + # shift if necessary + if isinstance(self.plotter.transform.projection, ccrs.PlateCarree): + coord = self.plotter.plot.xcoord + if coord.min() >= 0 and x < 0: + x -= 360 + elif coord.max() <= 180 and x > 180: + x -= 360 + return x, y + + def get_slice(self, x, y): + import numpy as np + data = self.data.psy.base.psy[self.data.name] + x, y = self.transform(x, y) + fmto = self.plotter.plot + + xcoord = fmto.xcoord + ycoord = fmto.ycoord + if fmto.decoder.is_unstructured(fmto.raw_data) or xcoord.ndim == 2: + xy = xcoord.values.ravel() + 1j * ycoord.values.ravel() + dist = np.abs(xy - (x + 1j * y)) + imin = np.nanargmin(dist) + if xcoord.ndim == 2: + ncols = data.shape[-2] + return dict(zip(data.dims[-2:], + [imin // ncols, imin % ncols])) + else: + return {data.dims[-1]: imin} + else: + x = xcoord.indexes[xcoord.name].get_loc(x, method='nearest') + y = ycoord.indexes[ycoord.name].get_loc(y, method='nearest') + return dict(zip(data.dims[-2:], [y, x])) + class Plot2DWidget(MapPlotWidget): @@ -825,6 +906,9 @@ def edit_labels(self): LabelDialog.update_project( self.sp, 'figtitle', 'title', 'clabel', 'xlabel', 'ylabel') + def transform(self, x, y): + return x, y + class LinePlotWidget(PlotMethodWidget): @@ -853,6 +937,15 @@ def setup_buttons(self): "Labels", self.edit_labels, "Edit title, x-label, legendlabels, etc.", self.formatoptions_box) + @property + def xdim(self): + return self.combo_dims.currentText() + + @xdim.setter + def xdim(self, xdim): + if xdim != self.combo_dims.currentText(): + self.combo_dims.setCurrentText(xdim) + @property def data(self): data = super().data @@ -861,9 +954,9 @@ def data(self): else: return data[self.combo_lines.currentIndex()] - def add_line(self, name=None): + def add_line(self, name=None, **sl): ds = self.data.psy.base - xdim = self.combo_dims.currentText() + xdim = self.xdim if name is None: name, ok = QtWidgets.QInputDialog.getItem( self, 'New line', 'Select a variable', @@ -872,8 +965,9 @@ def add_line(self, name=None): if not ok: return arr = ds.psy[name] - sl = {key: val for key, val in self.data.psy.idims.items() - if key in arr.dims} + for key, val in self.data.psy.idims.items(): + if key in arr.dims: + sl.setdefault(key, val) for dim in arr.dims: if dim != xdim: sl.setdefault(dim, 0) @@ -898,7 +992,7 @@ def item_texts(self): def init_dims(self, var): ret = {} - xdim = self.combo_dims.currentText() or var.dims[0] + xdim = self.xdim or var.dims[0] if self.array_info: arr_names = {} for arrname, d in self.array_info.items():