In [134]:
import os
import json
import pandas as pd
import numpy as np
import xarray as xr

In [27]:
# WIDGETS
import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed, interact_manual
from IPython.display import display

In [122]:
# plot
import matplotlib.pyplot as plt
import bg_mpl_stylesheet
from bg_mpl_stylesheet.bg_mpl_stylesheet import bg_mpl_style
# Apply the bg_mpl_stylesheet style to every figure and switch matplotlib into
# interactive mode so figures are redrawn as the widget callbacks update them.
plt.style.use(bg_mpl_style)
plt.ion()

<matplotlib.pyplot._IonContext at 0x7fbcba116bd0>

In [125]:
def get_x_fitting_range(x, y, xrange: list):
    """Crop ``x`` and every curve in ``y`` to the closed interval given by ``xrange``.

    Parameters
    ----------
    x : numpy.ndarray
        1-D grid the curves are sampled on.
    y : sequence of array_like
        Curves (rows) sampled on ``x``; each must have the same length as ``x``.
    xrange : list
        ``[lower, upper]``. A lower bound of ``"min"`` (or ``None``) means
        ``x.min()``; an upper bound of ``"max"`` (or ``None``) means ``x.max()``.
        Any other value is converted with ``float``.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        The cropped grid and a 2-D array holding one cropped row per curve.
    """
    lower, upper = xrange[0], xrange[1]
    # "min"/"max" (or None) select the corresponding end of the data grid.
    xmin = float(x.min()) if lower is None or lower == "min" else float(lower)
    xmax = float(x.max()) if upper is None or upper == "max" else float(upper)

    mask = (x >= xmin) & (x <= xmax)
    return x[mask], np.array([np.asarray(row)[mask] for row in y])

def xy_dim_labels(xy_lable: str):
    """Map a two-letter data-type code to matplotlib axis labels.

    The first letter selects the y quantity (``i``, ``s``, ``f``, ``g``) and the
    second the x quantity (``q`` or ``r``), e.g. ``"gr"`` or ``"fq"``.
    Unrecognised letters are returned unchanged.

    Returns
    -------
    (str, str)
        ``(x_label, y_label)`` -- note the x label comes first.
    """
    # Raw strings: "\m" and "\A" are not valid Python escape sequences and emit
    # DeprecationWarning / SyntaxWarning (3.12+) in ordinary literals. The
    # resulting label text is unchanged.
    y_labels = {
        'i': 'I (a.u.)',
        # NOTE(review): S(Q) is conventionally dimensionless -- confirm this unit.
        's': r'S ($\mathrm{\AA^{-1}})$',
        'f': r'F ($\mathrm{\AA^{-1}})$',
        'g': r'G ($\mathrm{\AA^{-2}})$',
    }
    x_labels = {
        'q': r'Q ($\mathrm{\AA^{-1}})$',
        'r': r'r ($\mathrm{\AA})$',
    }
    code_y, code_x = xy_lable[0], xy_lable[1]
    return x_labels.get(code_x, code_x), y_labels.get(code_y, code_y)

def dump_da_json(da, path):
    """Write ``da.to_dict()`` to ``path`` as indented JSON, overwriting the file.

    Parameters
    ----------
    da : xarray.DataArray
        Any object whose ``to_dict()`` returns a JSON-serializable dict.
    path : str or os.PathLike
        Destination file. Missing parent directories are created.
    """
    da_dict = da.to_dict()
    parent = os.path.dirname(os.fspath(path))
    if parent:
        os.makedirs(parent, exist_ok=True)
    # Write-only mode is all that is needed; the file is never read back here.
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(da_dict, f, sort_keys=False, indent=2)

In [126]:
def _setup_fig():
    """Return ``(fig, ax)`` for the shared figure labelled ``'plot'``, emptied for redrawing.

    ``clear=True`` clears an existing figure with that label before the new axes
    are added. Only that figure is affected, unlike ``plt.clf()``, which clears
    whichever figure happens to be current.
    """
    fig, ax = plt.subplots(num='plot', clear=True)
    return fig, ax

def _setup_files(files):
    if isinstance(files, str):
        files = [files]
    return files

def _collect_data(files, xmin, xmax):
    """Load each file and crop all of its curves to ``[xmin, xmax]``.

    Each file is read with ``np.loadtxt``. Column 0 is the x grid and the
    remaining columns are curves, which the plotting helpers index as
    (data, diff, fit).

    Returns
    -------
    (numpy.ndarray, list of numpy.ndarray)
        The cropped x grid of the *last* file (all files are assumed to share
        one grid) and one 2-D array of cropped curves per file.

    Raises
    ------
    ValueError
        If ``files`` yields no file names.
    """
    x = None
    ys = []
    for fname in files:
        dat = np.loadtxt(fname).T
        x, y = get_x_fitting_range(dat[0], dat[1:], [xmin, xmax])
        ys.append(y)
    if not ys:
        # Without this, an empty selection surfaced as an UnboundLocalError on x.
        raise ValueError("_collect_data: no files given")
    return x, ys

def _collect_feature_data(files):
    ys = list()
    for fname in files: 
        dat = np.loadtxt(fname).T
        if XLABELS:
            x = XLABELS
        else:
            x = dat[0]
        y = dat[1]
        ys.append(y)
    return x, ys

def get_rw(y1, y2):
    """Return the residual R_w between data ``y1`` and calculation ``y2``, rounded to 3 decimals.

    R_w = sqrt( sum((y1 - y2)**2) / sum(y1**2) ), i.e. normalised by the
    *data* (unit weights), as in the usual PDF-fit definition. The previous
    code divided by the calculated curve instead, contradicting y1 = data.

    Parameters
    ----------
    y1 : array_like
        Observed data.
    y2 : array_like
        Calculated curve, broadcast-compatible with ``y1``.

    Returns
    -------
    float
        R_w rounded to 3 decimals, or ``nan`` when ``y1`` is identically zero.
    """
    y_obs = np.asarray(y1, dtype=float)
    y_calc = np.asarray(y2, dtype=float)
    denom = np.sum(y_obs ** 2)
    if denom == 0:
        # R_w is undefined for an all-zero data set.
        return float('nan')
    return round(float(np.sqrt(np.sum((y_obs - y_calc) ** 2) / denom)), 3)

In [127]:
def _plot(fig, ax, x, y, yshift, data_type, fname):
    """Draw one fit result into ``ax``.

    The plot shows the data as open circles, the fit as a line, and the
    difference curve with its own baseline offset by ``yshift``.

    ``y`` holds at least three rows ordered (data, diff, fit); extra rows are
    ignored. ``data_type`` is the two-letter code understood by
    ``xy_dim_labels``.
    """
    data_row, diff_row, fit_row = y[0:3]
    baseline = np.zeros(len(x))

    # axes decoration
    ax.set_title(f"Fit:  {fname}")
    label_x, label_y = xy_dim_labels(data_type)
    ax.set_xlabel(label_x)
    ax.set_ylabel(label_y)
    ax.set_xlim(x.min(), x.max())

    # zero line and measured data
    ax.plot(x, baseline, alpha=0.5, c='C4')
    ax.plot(x, data_row, alpha=0.6, marker='o', markeredgewidth=1,
            fillstyle='none', markersize=7, c="C0", label='data')

    # calculated curve
    ax.plot(x, fit_row, alpha=0.8, c="C1", label='fit')

    # shifted baseline and difference curve
    ax.plot(x, baseline + yshift, alpha=0.7, c='C4')
    ax.plot(x, diff_row + yshift, alpha=0.8, c="C2", label='diff')

    ax.legend()
    fig.canvas.flush_events()

def _plot_compare(fig, ax, x, ys, yshift, data_type, files, draw_type, add_diff, diff_shift, mult_scaling, same_color, add_legends):
    """Overlay the same curve type from several files on one axes.

    Parameters
    ----------
    fig, ax : matplotlib Figure and Axes to draw into.
    x : numpy.ndarray
        Shared x grid.
    ys : list of numpy.ndarray
        One array per file whose rows are ordered (data, diff, fit).
        The arrays are not modified.
    yshift : float
        Vertical offset added cumulatively to each successive curve.
    data_type : str
        Two-letter code passed to ``xy_dim_labels`` for the axis labels.
    files : list of str
        Legend label for each entry of ``ys``.
    draw_type : str
        'data', 'diff' or 'fit' -- which row to draw from every file. Any
        other value passes the whole per-file array to ``ax.plot``.
    add_diff : bool
        With exactly two files, also draw ``first - mult_scaling * second``
        shifted by ``diff_shift``.
    diff_shift : float
        Vertical offset of the difference curve and its baseline.
    mult_scaling : float
        Factor applied to the second file's curve only.
    same_color : bool
        Draw every curve in 'C0' instead of cycling C0, C1, ...
    add_legends : bool
        Show the legend.
    """
    row = {'data': 0, 'diff': 1, 'fit': 2}.get(draw_type)

    ax.plot(x, np.zeros(len(x)), alpha=0.5, c='C4')  # zero baseline
    dim_x, dim_y = xy_dim_labels(data_type)
    ax.set_xlabel(dim_x)
    ax.set_ylabel(dim_y)
    ax.set_xlim(x.min(), x.max())
    ax.set_title(f'{draw_type}')

    offset = 0
    diff = None
    for i, (y, fname) in enumerate(zip(ys, files)):
        curve = y if row is None else y[row]
        if i == 0:
            # Float copy: diff must not alias a row of ys[0]; updating an alias
            # in place would overwrite the caller's data (and fail for int arrays).
            diff = np.array(curve, dtype=float)
        elif i == 1:
            curve = curve * mult_scaling
            diff = diff - curve
        color = 'C0' if same_color else f'C{i}'
        ax.plot(x, curve + offset, alpha=0.8, label=fname, c=color)
        offset += yshift

    if add_diff and len(files) == 2 and diff is not None:
        ax.plot(x, np.zeros(len(x)) + diff_shift, alpha=0.5, c='C4')
        ax.plot(x, diff + diff_shift, alpha=0.8, label='diff')

    if add_legends:
        ax.legend()

    fig.canvas.flush_events()

    
def _scatter_features(fig, ax, x, ys, files, marker1, marker2, linewidth):
    """Plot one line per feature curve in ``ys`` against ``x``.

    Axis labels and the title come from the notebook globals ``DIMX``, ``DIMY``
    and ``TITLE``.

    - Exactly two curves: ``marker1`` and ``marker2`` are used.
    - Otherwise: curve ``i`` uses matplotlib's integer marker ``i``, cycled.
    """
    ax.set_xlabel(DIMX)
    ax.set_ylabel(DIMY)
    ax.set_title(f'{TITLE}')
    is_pair = len(ys) == 2  # loop-invariant
    for i, (y, fname) in enumerate(zip(ys, files)):
        if is_pair:
            marker = marker1 if i == 0 else marker2
        else:
            # matplotlib defines integer markers only for 0-11 (tick/caret styles);
            # wrap so more than 12 curves no longer raise "Unrecognized marker style".
            marker = i % 12
        ax.plot(x, y, alpha=0.8, lw=linewidth, marker=marker, label=fname)

    ax.legend()
    fig.canvas.flush_events()

In [128]:
def draw(fname, data_type:str, xmin:float, xmax:float, yshift:float):
    """Widget callback: redraw the fit view for ``fname``.

    When ``fname`` is a list, only the first file's curves are plotted.
    """
    file_list = _setup_files(fname)
    fig, ax = _setup_fig()
    grid, curves_per_file = _collect_data(file_list, xmin, xmax)
    _plot(fig, ax, grid, curves_per_file[0], yshift, data_type, fname)
    
def draw_compare_data(compare_all: bool, file1:str, file2:str, draw_type: str, add_diff:bool,
                      data_type:str, xmin:float, xmax:float, yshift:float, diff_shift:float, mult_scaling:float,
                      same_color: bool, add_legends: bool):
    """Widget callback: overlay one curve type for all of ``FILES`` or just ``file1``/``file2``."""
    every_file = FILES
    selected = every_file if compare_all else [file1, file2]
    fig, ax = _setup_fig()
    grid, curves_per_file = _collect_data(selected, xmin, xmax)
    _plot_compare(fig, ax, grid, curves_per_file, yshift, data_type, selected, draw_type,
                  add_diff, diff_shift, mult_scaling, same_color, add_legends)
    
def draw_features(compare_all: bool, file1:str, file2:str, marker1, marker2, linewidth):
    """Widget callback: plot feature curves for all of ``FILES`` or just ``file1``/``file2``."""
    every_file = FILES
    selected = every_file if compare_all else [file1, file2]
    fig, ax = _setup_fig()
    grid, curves = _collect_feature_data(selected)
    _scatter_features(fig, ax, grid, curves, selected, marker1, marker2, linewidth)

In [129]:
def set_dirpath(dirpath):
    """Change the working directory to ``dirpath``, expanding a leading ``~``.

    Data files are later opened by bare file name (``np.loadtxt(fname)``), so
    they resolve relative to this directory. An empty string leaves the working
    directory unchanged.
    """
    os.chdir(os.path.abspath(os.path.expanduser(dirpath)))

In [130]:
# File picker. Only the selected file *names* are used afterwards (see FILES);
# accept='' puts no restriction on the file extension.
path = widgets.FileUpload(multiple=True, accept='')
w_path = widgets.HBox(children=[widgets.Label("files :"), path])

In [131]:
# Text box + "Run Interact" button: set_dirpath chdirs into the typed directory,
# so file names collected below are loaded relative to it.
interact_manual(set_dirpath, dirpath='')
display(w_path)

interactive(children=(Text(value='', description='dirpath'), Button(description='Run Interact', style=ButtonSt…

HBox(children=(Label(value='files :'), FileUpload(value={}, description='Upload', multiple=True)))

In [132]:
# global: names of the files selected in the upload widget. Only the names are
# used -- the files themselves are re-read with np.loadtxt from the current
# working directory (see set_dirpath), not from the uploaded bytes.
_uploaded = path.value
if isinstance(_uploaded, dict):
    # ipywidgets 7: {filename: {...metadata/content...}}
    FILES = list(_uploaded.keys())
else:
    # ipywidgets 8: sequence of dict-like entries carrying a 'name' key
    FILES = [item['name'] for item in _uploaded]
print(os.getcwd())
print(FILES)

/home/yr2369/dev/pdfgui_plotter/myplotter/sim_pair_clustering
['0.5', '0.6', '0.7', '0.8', '0.9', '1.0', '1.1', '1.2', '1.3', '1.4', '1.5']


In [74]:
# x grid of the first selected file; its extent sets the xmin/xmax slider limits below.
dat = np.loadtxt(FILES[0]).transpose()
x = dat[0]

In [75]:
# Single-file fit viewer. interactive() builds one control per keyword:
# - lists become dropdowns;
# - (min, max) tuples become float sliders, and (min, max, step) also sets the step.
# The x sliders are padded by 10% of the first file's x.min() / x.max().
interactive_draw = interactive(draw, 
                               fname = FILES, 
                               data_type = ['gr', 'fq', 'sq', 'iq'], 
                               xmin=(x.min()-x.min()/10, x.max()), 
                               xmax = (x.min(),x.max()+x.max()/10), 
                               yshift=(-5,5,0.1))

In [76]:
# Render figures in separate Qt windows (requires a Qt binding in the kernel's environment).
%matplotlib qt

In [77]:
# Alternative backend: uncomment to render figures inline instead of in Qt windows.
# %matplotlib inline

In [78]:
# Display the single-file fit viewer.
interactive_draw

interactive(children=(Dropdown(description='fname', options=('0.5', '0.6', '0.7', '0.8', '0.9', '1.0', '1.1', …

In [79]:
# Comparison viewer: overlays one curve type ('data'/'fit'/'diff') for file1 and
# file2 (or for every file when compare_all is ticked).
# - mult_scaling scales the second file's curve only.
# - add_diff draws (first - scaled second) at diff_shift when exactly two files are shown.
interactive_draw_compare_data = interactive(draw_compare_data, 
                                            compare_all = False,
                                            file1 = FILES,
                                            file2 = FILES,
                                            draw_type = ['data', 'fit', 'diff'],
                                            mult_scaling = (0, 2, 0.01),
                                            add_diff = False,
                                            data_type = ['gr', 'fq', 'sq', 'iq'], 
                                            xmin=(x.min()-x.min()/10, x.max()), 
                                            xmax = (x.min(),x.max()+x.max()/10), 
                                            yshift=(-5,5,0.1), 
                                            diff_shift = (-5,5,0.1),
                                           same_color = False, 
                                           add_legends = False)

In [80]:
# Display the comparison viewer.
interactive_draw_compare_data

interactive(children=(Checkbox(value=False, description='compare_all'), Dropdown(description='file1', options=…

In [81]:
# Feature-plot configuration read as globals by _collect_feature_data/_scatter_features.
# XLABELS: shared x values 4, 6, ..., 60 (29 points). It must have one entry per
# row of each feature file -- TODO confirm against the data.
XLABELS = np.arange(4,62,2).tolist()
# Axis labels and title used by _scatter_features.
DIMX = 'x'
DIMY = 'y'
TITLE = 'test'
# Marker choices offered for marker1/marker2 when exactly two files are compared.
markers = ['o', '.', 'v','<', 'X', 's', 'p', 'h', '8']
interactive_draw_features = interactive(draw_features, 
                                            compare_all = False,
                                            file1 = FILES,
                                            file2 = FILES,
                                            marker1 = markers,
                                            marker2 = markers,
                                            linewidth = [0,1,2,3,4,5])

In [82]:
# Display the feature-plot viewer.
interactive_draw_features

interactive(children=(Checkbox(value=False, description='compare_all'), Dropdown(description='file1', options=…

In [208]:
# Serialize each file's fitted curve, cropped to [xmin, xmax], as an xarray
# DataArray JSON dict (via dump_da_json) into for_minipipes/.
# Distinct names (r_fit, ys_fit, out_path) keep this cell from rebinding the
# notebook globals `x` (slider ranges) and `path` (the FileUpload widget).
# NOTE(review): the DataArray is a 2 x N stack of [r, fit] with dims ('r', 'g');
# kept unchanged for the downstream reader -- confirm that layout is what it
# expects rather than fit values indexed by an 'r' coordinate.
xmin, xmax = (2, 4)
OUT_DIR = 'for_minipipes'
os.makedirs(OUT_DIR, exist_ok=True)
r_fit, ys_fit = _collect_data(FILES, xmin, xmax)
for y, f in zip(ys_fit, FILES):
    fit = y[2]  # rows are (data, diff, fit)
    da = xr.DataArray(np.array([r_fit, fit]), dims=('r', 'g'),
                      attrs={'config': {'xmin': xmin, 'xmax': xmax, 'path': os.path.abspath(f)}})
    out_path = os.path.join(OUT_DIR, "xrda_00-" + f.replace('.', '') + ".gr")
    dump_da_json(da, out_path)

In [206]:
# Inspect the last DataArray built by the serialization loop.
da