# In [None]:
import dash
from dash import dcc, html
import plotly.graph_objects as go
from dash.dependencies import Input, Output, State
import numpy as np
import json


def load_data(json_path):
    """Read the heatmap cells from a UTF-8 JSON file.

    The file is iterated as a sequence of objects. Each object is turned into
    one cell dict with the keys ``word``, ``capitalized_word``, ``value``,
    ``group``, ``group_name`` and ``similar_sources_found``.

    Args:
        json_path: Path of the JSON file to read.

    Returns:
        A list of cell dicts, in the same order as the entries in the file.
        ``group_name`` holds the English and Finnish display names of the
        entry's group, or empty strings when the group id is not one of the
        known ids. ``similar_sources_found`` defaults to an empty dict.
    """
    # Display names of the known groups, keyed by group id and language code
    localized_groups = {
        "1": {"en": "Jobs and Investments", "fi": "Työpaikat ja Investoinnit"},
        "2": {"en": "Jobs", "fi": "Työpaikat"},
        "3": {"en": "Investments", "fi": "Investoinnit"},
    }

    with open(json_path, "r", encoding="utf-8") as source:
        entries = json.load(source)

    cells = []
    for entry in entries:
        word = entry.get("word")
        capitalized = word.capitalize()
        group = entry.get("group")
        # Unknown group ids fall back to empty display names
        names = localized_groups.get(group, {})
        cells.append({
            "word": word,
            "capitalized_word": capitalized,
            "value": entry.get("value"),
            "group": group,
            "group_name": {
                "en": names.get("en", ""),
                "fi": names.get("fi", ""),
            },
            "similar_sources_found": entry.get("similar_sources_found", {}),
        })
    return cells

# Full, unfiltered list of heatmap cells, loaded once at import time.
# "data.json" is a relative path, so it is resolved against the current working directory.
original_data = load_data("data.json")


def load_color_config(json_path):
    """Parse the colour configuration stored in a UTF-8 JSON file.

    Args:
        json_path: Path of the JSON file to read.

    Returns:
        The decoded JSON content, unchanged. ``get_color`` iterates it as a
        sequence of rule objects.
    """
    with open(json_path, "r", encoding="utf-8") as config_file:
        rules = json.load(config_file)
    return rules

# Colour rules used by get_color, loaded once at import time.
# "colorConfig.json" is a relative path, so it is resolved against the current working directory.
color_config = load_color_config("colorConfig.json")

def generate_hex_coordinates(radius):
    """List the axial (q, r) coordinates of a hexagon-shaped grid.

    Args:
        radius: Number of rings around the centre cell (0 gives only the centre).

    Returns:
        A list of ``(q, r)`` tuples ordered by ascending ``q`` and, within the
        same ``q``, by ascending ``r``.
    """
    return [
        (q, r)
        for q in range(-radius, radius + 1)
        # Clamp r so that q, r and the implied third axis (-q - r) all stay within the radius
        for r in range(max(-radius, -q - radius), min(radius, -q + radius) + 1)
    ]

def axial_to_pixel(q, r, size):
    """Convert axial hex coordinates to the Cartesian centre of the cell.

    Args:
        q: Axial column coordinate.
        r: Axial row coordinate.
        size: Scale factor; ``hexagon_vertices`` uses the same value as the
            centre-to-corner distance.

    Returns:
        The ``(x, y)`` position of the cell centre.
    """
    # Each step in q moves 1.5 * size horizontally and half a row vertically
    return (
        size * 3/2 * q,
        size * np.sqrt(3) * (r + q / 2),
    )

def hexagon_vertices(center_x, center_y, size):
    """Compute the six corner points of a hexagon.

    Args:
        center_x: X coordinate of the hexagon centre.
        center_y: Y coordinate of the hexagon centre.
        size: Distance from the centre to each corner.

    Returns:
        A tuple ``(x, y)`` of NumPy arrays with six entries each. The first
        corner lies at angle 0 and the rest follow in 60-degree steps.
    """
    # Six evenly spaced angles in [0, 2*pi); the closing angle 2*pi is left out
    corner_angles = np.linspace(0, 2 * np.pi, 6, endpoint=False)
    corners_x = center_x + size * np.cos(corner_angles)
    corners_y = center_y + size * np.sin(corner_angles)
    return corners_x, corners_y

def get_color(value, group, is_disabled=False):
    """Pick the fill colour of a cell from the module-level ``color_config``.

    The rules are checked in order. The first rule wins whose ``targetGroup``
    equals ``group`` (or is the wildcard ``"*"``) and whose inclusive
    ``min``..``max`` range contains ``value``. A rule without ``max`` has no
    upper bound.

    Args:
        value: Value of the cell.
        group: Group id of the cell.
        is_disabled: When true, the rule's ``disabledColor`` is returned
            instead of its ``color``.

    Returns:
        The colour string of the matching rule, or a translucent white when
        no rule matches.
    """
    color_key = "disabledColor" if is_disabled else "color"
    for rule in color_config:
        target = rule["targetGroup"]
        if target != group and target != "*":
            continue
        if rule["min"] <= value <= rule.get("max", float('inf')):
            return rule[color_key]
    return "rgba(255, 255, 255, 0.8)"

# User-facing strings for each supported language, keyed by language code.
# The figure callback indexes this dict with the value of the language
# dropdown to get the figure title and the prefixes of the hover-text lines.
translations = {
    "en": {
        "title": "Relevant Manufacturing Industry Topics",
        "word_label": "Word",
        "value_label": "Value",
        "group_name_label": "Group Name",
    },
    "fi": {
        "title": "Tärkeitä Valmistavan Teollisuuden Aiheita",
        "word_label": "Sana",
        "value_label": "Arvo",
        "group_name_label": "Ryhmän Nimi",
    }
}


#Initialize Dash app
app = dash.Dash(__name__)

# Page layout: a toolbar row (language dropdown, reset button, fullscreen
# link), a store holding the currently clicked word, and the heatmap graph.
app.layout = html.Div([
    html.Div([
        # Language of the figure title and hover labels
        dcc.Dropdown(
            id='language-selector',
            options=[
                {'label': 'English', 'value': 'en'},
                {'label': 'Suomi', 'value': 'fi'}
            ],
            value='en',
            clearable=False,
            style={'width': '200px'}
        ),
        # Clears the clicked-word selection so the full map is shown again
        html.Button("Reset map", id="reset-button", n_clicks=0),
        html.A(
            "Open Fullscreen",
            # Link to the app's own root instead of a hard-coded host and port,
            # so it keeps working on another port, host or path prefix.
            href=app.get_relative_path("/"),
            target="_blank",
            # Dash style dictionaries use camelCased CSS property names
            style={
                "display": "inline-block",
                "padding": "10px",
                "background": "#007BFF",
                "color": "white",
                "textDecoration": "none",
                "fontSize": "16px",
                "borderRadius": "8px",
                "width": "200px",
            }
        )
    ], style={"display": "flex", "flexDirection": "row", "alignItems": "center", "marginBottom": "10px", "width": "100%", "justifyContent": "space-between"}),
    dcc.Store(id='clicked-word-store'),  #Store for clicked word
    dcc.Graph(
        id='heatmap-graph',
        # Reset hoverData to None when the pointer leaves a cell
        clear_on_unhover=True
    ),
], style={"display": "flex", "flexDirection": "column", "alignItems": "center"})

@app.callback(
    Output('clicked-word-store', 'data'),
    [
        Input('heatmap-graph', 'clickData'),
        Input('reset-button', 'n_clicks'),
    ],
    [State('clicked-word-store', 'data')],  # previously stored word
)
def update_stored_word(clickData, _, stored_clicked_word):
    """Maintain the word of the currently selected heatmap cell.

    Args:
        clickData: Click event data of the heatmap graph.
        _: Click count of the reset button (unused; only the trigger matters).
        stored_clicked_word: Word currently held in the store, if any.

    Returns:
        ``None`` when the reset button fired or the already-selected word was
        clicked again (toggle off); the clicked word when a different cell was
        clicked; otherwise the stored word unchanged.
    """
    trigger = dash.callback_context.triggered[0]['prop_id']

    # Reset button: drop the selection
    if trigger == 'reset-button.n_clicks':
        return None

    # Anything other than a real cell click keeps the current selection
    if trigger != 'heatmap-graph.clickData' or not clickData:
        return stored_clicked_word

    clicked_word = clickData['points'][0]['customdata'][0]
    # Clicking the selected word a second time clears the selection
    return None if clicked_word == stored_clicked_word else clicked_word


def _cells_for_selection(cells, selected_word):
    """Return the cells to draw for the current click selection.

    Without a selection every cell is returned. Otherwise only the selected
    cell and the cells whose word appears in its ``similar_sources_found``
    are kept, in their original order.

    Returns:
        A new list of cell dicts, or ``None`` when ``selected_word`` does not
        match any cell.
    """
    if not selected_word:
        return list(cells)
    selected = next((cell for cell in cells if cell['word'] == selected_word), None)
    if selected is None:
        return None
    related = selected['similar_sources_found']
    # NOTE(review): this lookup is case-sensitive, while the hover highlight in
    # the callback below compares lower-cased words — confirm which is intended.
    return [
        cell for cell in cells
        if cell['word'] == selected['word'] or cell['word'] in related
    ]


@app.callback(
    Output('heatmap-graph', 'figure'),
    Input('heatmap-graph', 'hoverData'),
    Input('language-selector', 'value'),
    Input('clicked-word-store', 'data'),  # re-run whenever the selected word changes
)
def on_cell_hover_or_language_change(hoverData, selected_language, stored_clicked_word):
    """Rebuild the hexagon heatmap figure.

    Runs when a cell is hovered or un-hovered, the language changes, or the
    stored clicked word changes.

    Args:
        hoverData: Hover event data of the graph, or ``None`` when nothing is
            hovered.
        selected_language: Language code used to index ``translations``.
        stored_clicked_word: Word of the selected cell, or ``None``. When set,
            only that cell and its similar cells are drawn.

    Returns:
        A ``plotly.graph_objects.Figure`` with one filled scatter trace and
        one text annotation per drawn cell.
    """
    data = _cells_for_selection(original_data, stored_clicked_word)
    if data is None:
        # The stored word matches no cell; show the full map rather than crash.
        stored_clicked_word = None
        data = list(original_data)

    # Grid geometry
    radius = max(6, int(len(data) ** 0.5))  # number of rings of the hexagon grid
    # Hexagon circumradius in axis units; the on-screen size also depends on
    # the figure width/height below because no axis ranges are set.
    size = 1
    width = height = min(2000, 100 * radius)

    # Fill the grid from the centre outwards: sort slots by hex distance
    # (the sum of the absolute cube coordinates is twice that distance).
    ordered_slots = sorted(
        generate_hex_coordinates(radius),
        key=lambda coord: abs(coord[0]) + abs(coord[1]) + abs(-coord[0] - coord[1]),
    )
    # Only the slots that actually receive a cell need pixel geometry
    occupied_slots = ordered_slots[:len(data)]

    # Hover highlighting: every trace is one cell, so curveNumber is the cell index
    hovered_cell_index = -1
    cells_to_keep = set()
    hover_triggered = dash.callback_context.triggered[0]['prop_id'] == 'heatmap-graph.hoverData'
    if hover_triggered and hoverData:
        candidate = hoverData['points'][0]['curveNumber']
        # Ignore hover events that refer to a trace of a previous, larger figure
        if 0 <= candidate < len(data):
            hovered_cell_index = candidate
            hovered_cell = data[candidate]
            cells_to_keep = set(hovered_cell['similar_sources_found'])
            cells_to_keep.add(hovered_cell['word'].lower())

    labels = translations[selected_language]
    scatter_traces, annotations = [], []
    for cell, (q, r) in zip(data, occupied_slots):
        x_center, y_center = axial_to_pixel(q, r, size)
        x_hex, y_hex = hexagon_vertices(x_center, y_center, size)

        # While a cell is hovered, all cells unrelated to it are greyed out
        is_disabled = hovered_cell_index != -1 and cell['word'].lower() not in cells_to_keep
        fillcolor = get_color(cell['value'], cell['group'], is_disabled)

        word = cell['word']
        short_text = word[:11] + "..." if len(word) > 11 else word

        scatter_traces.append(go.Scatter(
            # Repeat the first vertex to close the outline
            x=x_hex.tolist() + [x_hex[0]],
            y=y_hex.tolist() + [y_hex[0]],
            mode="lines",
            fill="toself",
            fillcolor=fillcolor,
            line=dict(color="rgb(10, 10, 10)", width=1),
            customdata=[word],  # read back by the click callback
            hoverinfo="text",
            text=(
                f"{labels['word_label']}: {cell['capitalized_word']}<br>"
                f"{labels['value_label']}: {cell['value']}<br>"
                f"{labels['group_name_label']}: {cell['group_name'][selected_language]}"
            ),
            showlegend=False
        ))

        annotations.append(dict(
            x=x_center,
            y=y_center,
            text=f"{short_text}<br>{cell['value']}",
            showarrow=False,
            font=dict(size=12, color="rgb(10, 10, 10)"),
            xanchor="center",
            yanchor="middle"
        ))

    title = labels['title']
    if stored_clicked_word:
        title += f" - {stored_clicked_word.capitalize()}"

    fig = go.Figure(data=scatter_traces)
    fig.update_layout(
        xaxis=dict(showgrid=False, zeroline=False, visible=False),
        # Lock the aspect ratio so the hexagons are not stretched
        yaxis=dict(showgrid=False, zeroline=False, visible=False, scaleanchor="x", scaleratio=1),
        annotations=annotations,
        width=width,
        height=height,
        margin=dict(l=50, r=50, t=50, b=50),
        hovermode="closest",
        paper_bgcolor="rgb(240, 240, 240)",
        plot_bgcolor="rgb(240, 240, 240)",
        title=title
    )
    return fig

# Start the Dash server only when this file is executed directly (not when it is imported)
if __name__ == '__main__':
    app.run()
