# NOMAD Visualization Module

This notebook provides reusable visualization UI and logic for NOMAD sample data. It can be imported into other dashboard notebooks.

## Usage

```python
# Import the visualization module
%run 'nomad_visualization.ipynb'

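# The following functions are now available:
# - create_visualization_tab(data_state): Creates and returns the visualization UI
# - apply_date_filter, get_effective_authors, create_author_distribution_plot,
#   create_time_series_plot: the helper functions the UI is built from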
```

## Requirements

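- The following packages must be installed: plotly, pandas, ipywidgets, and IPython
- The data_state dictionary should contain:
  - df: DataFrame with sample data; it needs the columns `upload_date`, `upload_id`, and either `author_name` or `main_author`
  - attributions: Dictionary of attribution overrides, keyed by upload ID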

In [None]:
# Import required libraries
import ipywidgets as widgets
from ipywidgets import HBox, VBox, Button, Label
from IPython.display import display, clear_output
import pandas as pd
import plotly.graph_objects as go
from datetime import datetime

## Visualization Functions

In [None]:
def apply_date_filter(df, start_date, end_date):
    """Restrict the dataframe to rows whose upload date lies in a date range.

    Each bound is optional and is applied independently, so the range may be
    open-ended on either side.

    Args:
        df: DataFrame with an 'upload_date' column that can be compared
            against ``pd.Timestamp`` values
        start_date: Inclusive lower bound, or None for no lower bound
        end_date: Inclusive upper bound, or None for no upper bound. A plain
            ``datetime.date`` (no time component) covers that whole calendar
            day; any other value is converted with ``pd.Timestamp`` and used
            as an exact instant.

    Returns:
        DataFrame: ``df`` itself when neither bound is given, otherwise the
        rows that satisfy every supplied bound
    """
    # Local import: only needed to tell a plain date from a datetime.
    import datetime as _dt

    if start_date is None and end_date is None:
        return df

    upload_dates = df['upload_date']
    mask = None

    if start_date is not None:
        mask = upload_dates >= pd.Timestamp(start_date)

    if end_date is not None:
        upper = pd.Timestamp(end_date)
        is_plain_date = (isinstance(end_date, _dt.date)
                         and not isinstance(end_date, _dt.datetime))
        if is_plain_date:
            # A bare date converts to midnight; comparing with <= would drop
            # everything uploaded later that day, so use the next midnight
            # as an exclusive bound instead.
            within_upper = upload_dates < upper + pd.Timedelta(days=1)
        else:
            within_upper = upload_dates <= upper
        mask = within_upper if mask is None else mask & within_upper

    return df[mask]

def get_effective_authors(df, attributions):
    """Calculate effective authors using original data and overrides.

    The 'effective_author' column is written onto ``df`` itself (the input is
    modified in place) and the same object is returned.

    Args:
        df: DataFrame with an 'author_name' or 'main_author' column; an
            'upload_id' column is read when overrides are applied
        attributions: Dictionary mapping upload IDs to attribution info
            dictionaries; may be empty or None

    Returns:
        DataFrame: ``df`` with the effective_author column added
    """
    # Fields that may carry the override author, in priority order
    # (new-format and old-format keys interleaved as before).
    override_keys = (
        'author_display_name',  # New format
        'main_author_name',     # Old format
        'author_id',            # New format
        'main_author',          # Old format
    )

    # Use original author as default
    author_col = 'author_name' if 'author_name' in df.columns else 'main_author'
    df['effective_author'] = df[author_col]

    if not attributions:
        return df

    # Resolve one override name per upload. The first NON-EMPTY field wins,
    # so a blank higher-priority field no longer hides a usable lower one.
    overrides = {}
    for upload_id, attr_info in attributions.items():
        override_author = next(
            (attr_info[key] for key in override_keys if attr_info.get(key)),
            None,
        )
        if override_author:
            overrides[upload_id] = override_author

    if overrides:
        # Single pass over the frame instead of one scan per attribution;
        # uploads without an override keep their original author.
        mapped = df['upload_id'].map(overrides)
        df['effective_author'] = mapped.where(mapped.notna(), df['effective_author'])
    return df

def create_author_distribution_plot(df):
    """Build a bar chart of the number of samples per effective author.

    Args:
        df: DataFrame containing an 'effective_author' column

    Returns:
        plotly.graph_objects.Figure: One bar per author, ordered as returned
        by ``value_counts()``, each labelled with its count
    """
    samples_per_author = df['effective_author'].value_counts()
    counts = samples_per_author.values

    # The count doubles as bar height and as the label drawn on the bar.
    bars = go.Bar(
        x=samples_per_author.index,
        y=counts,
        text=counts,
        textposition='auto',
    )
    figure = go.Figure(data=[bars])

    layout_options = {
        'title': 'Samples by Author (Using Override Authors)',
        'xaxis_title': 'Author',
        'yaxis_title': 'Number of Samples',
        'height': 700,
        'xaxis': {'tickangle': 45},
        'margin': {'b': 100},
    }
    figure.update_layout(**layout_options)

    return figure

def create_time_series_plot(df, time_grouping='Monthly', plot_type='Stacked Bars'):
    """Build a bar chart of sample counts per time period, split by author.

    Args:
        df: DataFrame with 'upload_date' and 'effective_author' columns
        time_grouping: One of 'Daily', 'Weekly', 'Monthly', 'Yearly';
            any other value is treated as 'Monthly'
        plot_type: 'Stacked Bars' for stacked bars with a total label above
            each stack; any other value produces grouped bars

    Returns:
        plotly.graph_objects.Figure: Time series plot
    """
    # Translate the UI label into a pandas frequency alias.
    frequency_aliases = {'Daily': 'D', 'Weekly': 'W', 'Monthly': 'ME', 'Yearly': 'YE'}
    period_alias = frequency_aliases.get(time_grouping, 'ME')

    # Rows: time periods, columns: authors, cells: number of samples.
    period_grouper = pd.Grouper(key='upload_date', freq=period_alias)
    counts = (
        df.groupby([period_grouper, 'effective_author'])
        .size()
        .unstack(fill_value=0)
    )

    stacked = plot_type == 'Stacked Bars'
    if stacked:
        bar_text_position = 'inside'
        bar_mode = 'stack'
        chart_title = 'Samples Over Time (Stacked by Author, Using Override Authors)'
    else:
        bar_text_position = 'auto'
        bar_mode = 'group'
        chart_title = 'Samples Over Time (Grouped by Author, Using Override Authors)'

    figure = go.Figure()

    # One bar trace per author; only the label placement differs by mode.
    for author_name in counts.columns:
        author_counts = counts[author_name]
        figure.add_trace(go.Bar(
            name=author_name,
            x=counts.index,
            y=author_counts,
            text=author_counts,
            textposition=bar_text_position,
        ))

    if stacked:
        # Text-only trace that prints the total above each stack.
        period_totals = counts.sum(axis=1)
        figure.add_trace(go.Scatter(
            x=counts.index,
            y=period_totals,
            mode='text',
            text=period_totals,
            textposition='top center',
            showlegend=False,
            textfont=dict(size=12),
        ))

    figure.update_layout(barmode=bar_mode)
    figure.update_layout(
        title=chart_title,
        xaxis_title='Date',
        yaxis_title='Number of Samples',
        height=700,
        showlegend=True,
        legend_title_text='Author',
    )

    return figure

## Visualization Tab Component

In [None]:
def create_visualization_tab(data_state):
    """Create the visualization tab with interactive visualizations.

    The tab shows an author distribution bar chart and a time series chart.
    Both are redrawn when the button is clicked or when any date, grouping,
    or plot-type control changes. ``data_state`` is read again on every
    redraw, so data placed in it after creation is picked up.

    Args:
        data_state: Dictionary containing:
            - df: DataFrame with sample data
            - attributions: Dictionary of attribution overrides

    Returns:
        widgets.VBox: The visualization UI component
    """
    # Date range selector - initially None to show all data
    start_date = widgets.DatePicker(
        description='Start Date:',
        value=None,
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='300px')
    )

    end_date = widgets.DatePicker(
        description='End Date:',
        value=None,
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='300px')
    )

    # Time series options
    time_grouping = widgets.RadioButtons(
        options=['Daily', 'Weekly', 'Monthly', 'Yearly'],
        value='Monthly',
        description='Group by:',
        style={'description_width': 'initial'}
    )

    plot_type = widgets.RadioButtons(
        options=['Stacked Bars', 'Grouped Bars'],
        value='Stacked Bars',
        description='Plot type:',
        style={'description_width': 'initial'}
    )

    # Show visualization button
    show_viz_button = widgets.Button(
        description='Show Visualizations',
        button_style='primary',
        icon='chart-bar',
        layout=widgets.Layout(width='200px')
    )

    # Create containers for each plot type
    author_plot_container = widgets.Output(
        layout=widgets.Layout(
            width='100%',
            min_height='700px'
        )
    )

    time_series_container = widgets.Output(
        layout=widgets.Layout(
            # Leave room for the options box shown beside the time series
            width='calc(100% - 290px)',
            min_height='700px'
        )
    )

    def show_message(html):
        """Replace both plots with a single HTML message."""
        # The message goes into a container that is part of the returned
        # layout; an Output that is never displayed would hide it.
        with author_plot_container:
            clear_output(wait=True)
            display(widgets.HTML(html))
        # No replacement output follows here, so clear immediately rather
        # than with wait=True, which would leave the old plot on screen.
        with time_series_container:
            clear_output()

    def update_visualizations(*args):
        """Update both visualizations based on current settings"""
        source_df = data_state.get('df')
        if source_df is None or source_df.empty:
            show_message("<p>No data available. Please fetch data first.</p>")
            return

        # Prepare the data (work on a copy so data_state['df'] is untouched)
        df = source_df.copy()
        df['upload_date'] = pd.to_datetime(df['upload_date'])

        # Apply date filter
        df = apply_date_filter(df, start_date.value, end_date.value)

        if df.empty:
            show_message("<p>No samples found in the selected date range.</p>")
            return

        # Calculate effective authors
        df = get_effective_authors(df, data_state.get('attributions', {}))

        # Create and display author distribution plot
        with author_plot_container:
            clear_output(wait=True)
            fig1 = create_author_distribution_plot(df)
            fig1.show()

        # Create and display time series plot
        with time_series_container:
            clear_output(wait=True)
            fig2 = create_time_series_plot(df, time_grouping.value, plot_type.value)
            fig2.show()

    # Set up event handlers
    start_date.observe(update_visualizations, 'value')
    end_date.observe(update_visualizations, 'value')
    time_grouping.observe(update_visualizations, 'value')
    plot_type.observe(update_visualizations, 'value')
    show_viz_button.on_click(update_visualizations)

    # Time series controls box
    time_series_controls = widgets.VBox([
        widgets.HTML("<h4>Time Series Options:</h4>"),
        time_grouping,
        plot_type
    ], layout=widgets.Layout(
        margin='10px 20px',
        padding='15px',
        border='1px solid #ddd',
        border_radius='5px',
        width='250px',
        align_items='flex-start'
    ))

    # Global controls with better spacing
    global_controls = widgets.VBox([
        widgets.HTML("<h3>Global Controls</h3>"),
        show_viz_button,
        widgets.HBox([start_date, end_date], layout=widgets.Layout(margin='10px 0'))
    ], layout=widgets.Layout(
        margin='0 0 30px 0',
        padding='15px',
        border='1px solid #ddd',
        border_radius='5px'
    ))

    # Create a divider
    divider = widgets.HTML("<hr style='border: none; border-top: 1px solid #ddd; margin: 20px 0;'>")

    # Create layout for time series with controls
    time_series_layout = widgets.HBox([
        time_series_container,
        time_series_controls
    ], layout=widgets.Layout(
        margin='20px 0',
        width='100%',
        align_items='flex-start'
    ))

    # Combine widgets into form with improved spacing
    viz_ui = widgets.VBox([
        widgets.HTML("<h2>Sample Visualizations</h2>"),
        global_controls,
        author_plot_container,
        divider,
        time_series_layout
    ], layout=widgets.Layout(
        margin='20px',
        width='100%'
    ))

    # Initialize if data is already present
    if data_state.get('df') is not None and not data_state['df'].empty:
        update_visualizations()

    return viz_ui

## Example Usage

Here's an example of how to use this visualization module:

In [None]:
# This cell demonstrates how to use the visualization module from another
# notebook. It is not meant to be executed in this notebook directly: the
# example is wrapped in a string literal so that none of it is run here.
# `your_dataframe` and `load_attributions()` are placeholders that the
# importing notebook has to provide.

'''
# Import the visualization module
%run './nomad_visualization.ipynb'

# Initialize data state
data_state = {
    'df': your_dataframe,
    'attributions': load_attributions()
}

# Create the visualization tab
viz_tab = create_visualization_tab(data_state)

# Display the tab
display(viz_tab)
'''