In [50]:
from ipywidgets import Button, HBox, VBox, Checkbox, FileUpload, Label, Output, Layout
import ipywidgets as widgets
from pycalphad import Database, variables as v
from pycalphad.core.utils import filter_phases
from IPython.display import Markdown
import nbformat
from nbconvert.preprocessors import ExecutePreprocessor
from nbconvert import HTMLExporter

class State:
    """Notebook-wide mutable state shared by the widget callbacks.

    The attributes below are class-level defaults; the callbacks overwrite
    them on the single ``global_state`` instance.
    """

    # Database parsed from the most recently uploaded TDB file (None until then).
    dbf = None
    # File name the uploaded database was saved under (None until then).
    db_filename = None
    # (label, value) pairs offered by the condition and axis dropdowns.
    selectable_conditions = [('', '')]

# Shared notebook state plus the bordered output area used for progress
# messages and error output from the callbacks.
global_state = State()
out = Output(layout={'border': '1px solid black'})
# NOTE(review): not referenced anywhere in this cell — confirm it is used elsewhere.
calc_result = Output()


generate = Button(description='Generate Code', layout=Layout(width='250px'))
# ``accept`` follows the HTML <input accept> syntax, in which a file extension
# is written with a leading period ('.tdb'). A glob such as '*.tdb' is not a
# valid specifier, so the browser's file dialog would not filter on it.
uploader = FileUpload(description='Choose Database', multiple=False, accept='.tdb',
                      layout=Layout(width='250px'))
# Options are filled in by handle_upload() from the uploaded database.
elements = widgets.SelectMultiple(
    options=[],
    value=[],
    description='Elements',
    disabled=False
)
# Options are filled in by update_phases_and_set_conditions() from the
# selected elements.
phases = widgets.SelectMultiple(
    options=[],
    value=[],
    description='Phases',
    disabled=False
)

# Container holding one condition row per child.
all_conditions = VBox([])

def add_condition(b):
    """Insert a new condition row directly after the row that owns button *b*.

    Once more than one row exists, any row may be deleted, so every row's
    delete button is re-enabled afterwards.

    Parameters
    ----------
    b : ipywidgets.Button
        The clicked "+" button; its ``parent`` attribute is its row.
    """
    inserted_row = make_conditions_input()
    rows = list(all_conditions.children)
    insert_at = rows.index(b.parent) + 1
    rows.insert(insert_at, inserted_row)
    all_conditions.children = tuple(rows)
    for row in all_conditions.children:
        row.del_cond_button.disabled = False

def del_condition(b):
    """Remove the condition row that owns button *b*, never removing the last row.

    When a single row remains, its delete button is disabled.

    Parameters
    ----------
    b : ipywidgets.Button
        The clicked "-" button; its ``parent`` attribute is its row.
    """
    if len(all_conditions.children) <= 1:
        return
    with out:
        idx = all_conditions.children.index(b.parent)
        all_conditions.children[idx].close()
        remaining = list(all_conditions.children)
        del remaining[idx]
        all_conditions.children = tuple(remaining)
    if len(all_conditions.children) == 1:
        all_conditions.children[0].del_cond_button.disabled = True

def toggle_range(change):
    """Switch a condition row between a single value and a (start, stop, step) range.

    When the row's range toggle is on, ``conditions_value`` is replaced by an
    ``HBox`` of three ``FloatText`` inputs exposing ``get_range()``. When the
    toggle is off, a fresh single ``FloatText`` is used. The replaced input is
    closed after it has been detached from the row, so its widget model is not
    leaked.

    Parameters
    ----------
    change : dict
        traitlets change notification. ``change['owner']`` is the toggle
        button, whose ``parent`` attribute is its row. Notifications for
        traits other than ``value`` are ignored.
    """
    # Only flips of the toggle's value matter. An observer registered without
    # ``names`` also receives every other trait change.
    if change.get('name') != 'value':
        return
    with out:
        caller_widget_idx = all_conditions.children.index(change['owner'].parent)
        use_range = change['owner'].value
        conditions_input = all_conditions.children[caller_widget_idx]
        replaced_value = conditions_input.conditions_value
        if use_range:
            min_value = widgets.FloatText(layout=Layout(width='70px'))
            max_value = widgets.FloatText(layout=Layout(width='70px'))
            step_size = widgets.FloatText(layout=Layout(width='70px'))
            range_widget = HBox([Label('From'), min_value, Label('To'), max_value, Label('Step Size'), step_size])
            # generate_code() recognises range rows by this attribute.
            range_widget.get_range = lambda: (min_value.value, max_value.value, step_size.value)
            conditions_input.conditions_value = range_widget
            conditions_input.children = [conditions_input.conditions_selector,
                                         conditions_input.conditions_value,
                                         conditions_input.range_check,
                                         conditions_input.add_cond_button,
                                         conditions_input.del_cond_button]
        else:
            conditions_input.conditions_value = widgets.FloatText(layout=Layout(width='150px'))
            conditions_input.children = [conditions_input.conditions_selector,
                                         conditions_input.range_check,
                                         conditions_input.conditions_value,
                                         conditions_input.add_cond_button,
                                         conditions_input.del_cond_button]
        # The old input is no longer displayed. Close it, and any inputs nested
        # inside a range widget, so its widget model is released.
        for child in getattr(replaced_value, 'children', ()):
            child.close()
        replaced_value.close()

def make_conditions_input():
    """Create one condition row for the Conditions tab.

    Returns
    -------
    HBox
        A row containing a condition dropdown, a range toggle, a value input
        and add/delete buttons. The children are also stored as attributes
        (``conditions_selector``, ``range_check``, ``conditions_value``,
        ``add_cond_button``, ``del_cond_button``). The buttons and the toggle
        get a ``parent`` attribute pointing back at the row, so the callbacks
        can find the row inside ``all_conditions``.
    """
    add_cond_button = widgets.Button(icon='plus-square',
                                     layout=Layout(width='40px'))
    del_cond_button = widgets.Button(icon='minus-square',
                                     layout=Layout(width='40px'))
    range_check = widgets.ToggleButton(icon='arrows-h',
                                       value=False, layout=Layout(width='40px'))
    conditions_selector = widgets.Dropdown(
        options=global_state.selectable_conditions,
        value='',
        description='Condition:',
        disabled=False,
        layout=Layout(width='150px')
    )
    conditions_value = widgets.FloatText(layout=Layout(width='150px'))
    conditions_input = HBox([conditions_selector,
                             range_check,
                             conditions_value,
                             add_cond_button,
                             del_cond_button
                             ])
    conditions_input.conditions_selector = conditions_selector
    conditions_input.conditions_value = conditions_value
    conditions_input.add_cond_button = add_cond_button
    conditions_input.del_cond_button = del_cond_button
    conditions_input.range_check = range_check
    add_cond_button.parent = conditions_input
    del_cond_button.parent = conditions_input
    range_check.parent = conditions_input
    add_cond_button.on_click(add_condition)
    del_cond_button.on_click(del_condition)
    # Only react to the toggle's value. Without ``names`` the observer would
    # fire, and rebuild the row, for every trait change on the button.
    range_check.observe(toggle_range, names='value')
    return conditions_input

# The Conditions tab starts with a single row. Its delete button stays
# disabled until a second row exists.
all_conditions.children = (make_conditions_input(), )
all_conditions.children[0].del_cond_button.disabled = True

# Quantities that can only be plotted on the Y axis, never imposed as conditions.
_plot_only_outputs = [('GM       - Gibbs Energy', 'GM(*)'), ('NP       - Phase Amount', 'NP(*)')]

x_axis_selector = widgets.Dropdown(
    options=global_state.selectable_conditions,
    value='',
    description='X-Axis:',
    disabled=False,
    layout=Layout(width='300px'),
)
y_axis_selector = widgets.Dropdown(
    options=_plot_only_outputs + global_state.selectable_conditions,
    value='',
    description='Y-Axis:',
    disabled=False,
    layout=Layout(width='300px'),
)

# Output tab: both axis choices followed by the "Generate Code" button. The
# selectors are also attached as attributes so callbacks can reach them.
output_selector = widgets.VBox([x_axis_selector, y_axis_selector, generate])
output_selector.x_axis_selector = x_axis_selector
output_selector.y_axis_selector = y_axis_selector

def handle_upload(change):
    """Load the TDB file chosen in ``uploader`` and offer its elements.

    ``change['new']`` is the uploader's value: a mapping from file name to a
    dict whose ``'content'`` entry holds the raw bytes. The last entry is
    used. The database is parsed and saved under the uploaded file name, so
    the generated code can reload it by name. Only after both steps succeed is
    it stored in ``global_state`` and used to populate ``elements``.

    Decoding, parsing and writing all happen inside ``with out:``, so their
    errors are routed to the ``out`` widget. On failure the previously loaded
    database and file name are left untouched.

    Parameters
    ----------
    change : dict
        traitlets change notification for ``uploader.value``.
    """
    uploaded = change['new']
    # The value may be empty (e.g. after a reset); there is nothing to load then.
    if not uploaded:
        return
    *_, (fname, f) = uploaded.items()
    out.clear_output()
    with out:
        print('Loading...')
        tdb_content = f['content'].decode('utf-8')
        dbf = Database(tdb_content)
        dbf.to_file(fname, if_exists='overwrite')
        # Commit the new state only once the database is parsed and saved.
        global_state.dbf = dbf
        global_state.db_filename = fname
        out.clear_output()
        elements.options = sorted(dbf.elements)

def update_phases_and_set_conditions(change):
    """React to a new element selection.

    This function:

    * refreshes the phase list for the selected species;
    * rebuilds the selectable conditions: N, P and T, plus one mole fraction
      and one chemical potential per species with atoms;
    * adds condition rows until there are at least (such species + 2), the
      number of conditions required by the Gibbs phase rule;
    * pushes the new options into every condition dropdown and both axis
      selectors.

    Parameters
    ----------
    change : dict
        traitlets change notification for ``elements.value``.
    """
    with out:
        species = sorted(map(v.Species, change['new']))
        real_species = [sp for sp in species if sp.number_of_atoms > 0]
        out.clear_output()
        phases.options = filter_phases(global_state.dbf, species)

        cond_options = [('', ''),
                        ('N       - Total Moles', v.N),
                        ('P       - Pressure (Pa)', v.P),
                        ('T       - Temperature (K)', v.T)]
        cond_options += [(f'X({sp.name}) - Mole Fraction {sp.name}', f'X(\'{sp.name}\')')
                         for sp in real_species]
        cond_options += [(f'MU({sp.name}) - Chemical Potential {sp.name}', f'MU(\'{sp.name}\')')
                         for sp in real_species]
        global_state.selectable_conditions = cond_options

        # Add rows if fewer conditions are present than the Gibbs phase rule needs.
        missing_rows = len(real_species) + 2 - len(all_conditions.children)
        if missing_rows > 0:
            all_conditions.children += tuple(make_conditions_input() for _ in range(missing_rows))

        for row in all_conditions.children:
            row.conditions_selector.disabled = False
            row.conditions_selector.options = global_state.selectable_conditions
        output_selector.x_axis_selector.options = global_state.selectable_conditions
        output_selector.y_axis_selector.options = [('GM       - Gibbs Energy', 'GM(*)'), ('NP       - Phase Amount', 'NP(*)')] + global_state.selectable_conditions

# Template for a plain equilibrium() calculation, filled with str.format by
# generate_code(). Placeholders: {dbf} (a quoted file-name literal), {comps}
# and {phases} (list literals) and {conds} (a dict literal keyed by v.*
# state variables).
eq_code_template = """from pycalphad import Database, equilibrium, variables as v

dbf = Database({dbf})
comps = {comps}
phases = {phases}
conds = {conds}
eq = equilibrium(dbf, comps, phases, conds)
print(eq)
"""

# Template for a binplot() phase diagram. generate_code() uses it when the Y
# axis is T and the X axis is a mole fraction. It is filled with str.format,
# so the literal dict braces in plot_kwargs are doubled ({{ }}). The
# %matplotlib magic means the generated code is meant to run in a notebook.
bin_code_template = """%matplotlib inline
import matplotlib.pyplot as plt
from pycalphad import Database, binplot
import pycalphad.variables as v

# Load database and choose the phases that will be considered
dbf = Database({dbf})
my_phases = {phases}

# Create a matplotlib Figure object and get the active Axes
fig = plt.figure(figsize=(9,6))
axes = fig.gca()

# Compute the phase diagram and plot it on the existing axes
binplot(dbf, {comps}, my_phases, {conds}, plot_kwargs={{'ax': axes}})

plt.show()
"""

# Template for a per-phase Gibbs-energy scatter plot built with calculate().
# It is filled with str.format by generate_code():
#   {conds}            keyword arguments such as "T=..., P=..." (N/T/P only)
#   {xarr}, {yarr}     Python expressions over `result` (the calculate() output)
#   {xlabel}, {ylabel} text inserted inside single-quoted axis labels
# NOTE(review): when {xarr} is result['T'] (a coordinate) and {yarr} is
# result.GM, confirm that ax.scatter receives arrays of matching size.
gm_code_template = """%matplotlib inline
import matplotlib.pyplot as plt
from pycalphad import Database, calculate, variables as v
from pycalphad.plot.utils import phase_legend
import numpy as np

# Load database and choose the phases that will be plotted
dbf = Database({dbf})
my_phases = {phases}

# Get the colors that map phase names to colors in the legend
legend_handles, color_dict = phase_legend(my_phases)

fig = plt.figure(figsize=(9,6))
ax = fig.gca()

# Loop over phases, calculate the Gibbs energy, and scatter plot
for phase_name in my_phases:
    result = calculate(dbf, {comps}, phase_name, {conds}, output='GM')
    ax.scatter({xarr}, {yarr}, marker='.', s=5, color=color_dict[phase_name])

# Format the plot
ax.set_xlabel('{xlabel}')
ax.set_ylabel('{ylabel}')

ax.legend(handles=legend_handles, loc='center left', bbox_to_anchor=(1, 0.6))
plt.show()
"""

@out.capture()
def generate_code(b):
    """Render pycalphad code for the current selections and display it.

    The template is chosen from the Output tab's axis selectors:

    * Y = T and X = a mole fraction -> ``binplot`` phase diagram;
    * Y = GM(*)                     -> ``calculate`` Gibbs-energy scatter plot
      (only N, T and P conditions, passed as keyword arguments);
    * anything else                 -> a plain ``equilibrium`` call.

    The generated code is also written as a single code cell to
    ``Untitled.ipynb`` in the working directory, then shown as Markdown in
    ``out``.

    Parameters
    ----------
    b : ipywidgets.Button
        The clicked button (unused).

    Raises
    ------
    ValueError
        If no database has been uploaded yet, or if a GM plot is requested
        against an X axis that is neither T nor a mole fraction.
    """
    db_filename = global_state.db_filename
    if db_filename is None:
        raise ValueError('No database loaded: upload a TDB file first.')
    # repr() yields a correctly quoted and escaped Python string literal. Naive
    # quote concatenation breaks on file names that contain quotes.
    dbf_literal = repr(db_filename)
    comps = sorted(elements.value)
    ph = sorted(phases.value)

    def condition_value(inp):
        # Range rows expose get_range() -> (start, stop, step); others hold a float.
        if hasattr(inp.conditions_value, 'get_range'):
            return str(inp.conditions_value.get_range())
        return str(inp.conditions_value.value)

    # (name, value) for every row that has a condition chosen. Rows still on
    # the blank entry would otherwise produce invalid code such as "v.: 0.0".
    chosen = []
    for inp in all_conditions.children:
        selected = inp.conditions_selector.value
        if selected is None or str(selected) == '':
            continue
        chosen.append((str(selected), condition_value(inp)))

    # Dict-literal form used by the equilibrium and binplot templates.
    conds = '{' + ', '.join(f'v.{name}: {value}' for name, value in chosen) + '}'

    output_x = str(output_selector.x_axis_selector.value)
    output_y = str(output_selector.y_axis_selector.value)

    if output_y == 'T' and output_x.startswith('X('):
        result = bin_code_template.format(dbf=dbf_literal, comps=comps,
                                          phases=ph, conds=conds)
    elif output_y == 'GM(*)':
        # calculate() only takes N, T and P, as keyword arguments.
        gm_conds = ', '.join(f'{name}={value}' for name, value in chosen
                             if name in ('N', 'T', 'P'))
        if output_x == 'T':
            xarr = 'result[\'T\']'
            xlabel = 'Temperature (K)'
        elif output_x.startswith('X('):
            # Values look like "X('AL')". Strip the wrapper instead of slicing
            # a fixed width, so one-letter elements such as C or N also work.
            desired_comp = output_x[len("X('"):-len("')")]
            xarr = 'result.X.sel(component=\'{0}\')'.format(desired_comp)
            xlabel = 'X({0})'.format(desired_comp)
        else:
            raise ValueError('Unknown plot type')
        result = gm_code_template.format(xarr=xarr, xlabel=xlabel,
                                         yarr='result.GM', ylabel='GM',
                                         dbf=dbf_literal, comps=comps,
                                         phases=ph, conds=gm_conds)
    else:
        result = eq_code_template.format(dbf=dbf_literal, comps=comps,
                                         phases=ph, conds=conds)

    nb = nbformat.v4.new_notebook()
    nb['cells'] = [nbformat.v4.new_code_cell(result)]
    nbformat.write(nb, 'Untitled.ipynb')
    out.clear_output()
    display(Markdown("""```python\n""" + result + """```"""))

# Connect each widget to its callback.
generate.on_click(generate_code)
uploader.observe(handle_upload, names='value')
elements.observe(update_phases_and_set_conditions, names='value')

# Conditions and output selection are shown as two tabs.
tabs = widgets.Tab(children=[all_conditions, output_selector])
for _tab_index, _tab_title in enumerate(('Conditions', 'Output')):
    tabs.set_title(_tab_index, _tab_title)

for _shown_widget in (uploader, elements, phases, tabs):
    display(_shown_widget)
# Bare trailing expression: the notebook renders the output area last.
out

FileUpload(value={}, accept='*.tdb', description='Choose Database', layout=Layout(width='250px'))

SelectMultiple(description='Elements', options=(), value=())

SelectMultiple(description='Phases', options=(), value=())

Tab(children=(VBox(children=(HBox(children=(Dropdown(description='Condition:', layout=Layout(width='150px'), o…

Output(layout=Layout(border='1px solid black'))