# Data exploration

Before running the Jupyter notebook:
- Create a new Conda environment or Python virtual environment (venv)
- Install the requirements with: `pip install -r requirements.txt`
- Download the French spaCy model: `python -m spacy download fr_core_news_sm`

## User inputs
You can modify the variables below to set the path to your training or test data and choose whether to display titles on the graphs.

In [None]:
# Path to the JSON corpus to explore (swap the commented line to analyse the test set instead)
data_path: str = "../data/training_data/20250428_NP_train-evalLLM.json"
# data_path = "../data/test_data/20250516_NP_test_evalLLM.json"
# When True, figures get a title and are saved under an `images_with_title` sub-folder
with_title: bool = False

## 1. Imports

In [None]:
import os
import spacy
import pandas as pd
import seaborn as sns
from typing import List
from pathlib import Path
from datetime import date
from itertools import chain
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from matplotlib.patches import Rectangle
from collections import Counter, defaultdict

### 1.1. Utility functions

In [None]:
def get_color(colorname: str) -> str:
    """Look up a named color of the project palette (used for seaborn/matplotlib plots)

    Args:
        colorname (str): one of 'puce', 'light-orange', 'almond',
            'cambridge-blue' or 'tropical-indigo'

    Returns:
        str: hexadecimal RGBA code, e.g. '#d88c9aff'

    Raises:
        KeyError: if `colorname` is not part of the palette
    """
    palette = dict([
        ("puce", "#d88c9aff"),
        ("light-orange", "#f2d0a9ff"),
        ("almond", "#f1e3d3ff"),
        ("cambridge-blue", "#99c1b9ff"),
        ("tropical-indigo", "#8e7dbeff"),
    ])
    return palette[colorname]


def get_color_for_plotly(colorname: str) -> str:
    """Look up the muted (Plotly) variant of a named palette color

    Args:
        colorname (str): one of 'puce', 'light-orange', 'almond',
            'cambridge-blue' or 'tropical-indigo'

    Returns:
        str: hexadecimal RGB code, e.g. '#a57b88'

    Raises:
        KeyError: if `colorname` is not part of the palette
    """
    names = ("puce", "light-orange", "almond", "cambridge-blue", "tropical-indigo")
    codes = ("#a57b88", "#d4b68e", "#c7b8a4", "#7e9e94", "#6d5e99")
    return dict(zip(names, codes))[colorname]

### 1.2. Folder creation

In [None]:
# Today's date (YYYYMMDD), used to version the stats file.
# Stored under its own name so the imported `datetime.date` class is not
# shadowed by a string (which made re-running this cell fail on `.today()`).
today_str = date.today().strftime("%Y%m%d")

# Output names are prefixed according to the kind of data being explored
if "training_data" in data_path:
    output_prefix = "training_data__"
elif "test_data" in data_path:
    output_prefix = "test_data__"
else:
    output_prefix = ""
folder_path_for_images = f"{output_prefix}images"
stats_filename = f"{output_prefix}stats__{today_str}.csv"

# Figures with and without titles are kept in separate sub-folders
subfolder_path_for_images = "images_with_title" if with_title else "images_without_title"

# Create folders to save images if they do not exist yet
(Path(folder_path_for_images) / subfolder_path_for_images).mkdir(parents=True, exist_ok=True)

## 2. Load raw data

In [None]:
df = pd.read_json(data_path)

# Quick sanity check of the loaded corpus: size, schema and first rows
print(len(df), "documents in", data_path)
print("Columns:", df.columns)
print("Dataframe extract:", df.head(3))

## 3. Dataframe calculations
Add new columns for: character length, token length, number of events, number of entities involved in each event, labels of central elements and labels of associated elements.

In [None]:
def count_tokens(text: str) -> int:
    """Count the spaCy tokens of `text`, ignoring punctuation and whitespace tokens

    Relies on the module-level spaCy pipeline `nlp`, which must be loaded
    before this function is called.

    Args:
        text (str): text to count tokens from

    Returns:
        int: number of non-punctuation, non-space tokens
    """
    return sum(
        1
        for tok in nlp(text)
        if not (tok.is_punct or tok.is_space)
    )

def extract_elements_labels(row: pd.Series, attribute: str) -> List:
    """Collect, for each event of a document, the labels of its elements with a given attribute

    Args:
        row (pd.Series): dataframe row with an 'entities' column (dicts with 'id' and
            'label') and an 'events' column (list of events, each a list of element dicts
            with 'attribute' and 'occurrences')
        attribute (str): 'evt:central_element' or 'evt:associated_element'

    Returns:
        List: one list per matching element, holding the labels of its occurrences
            deduplicated in first-seen order; an occurrence id with no matching entity
            yields None
    """
    label_by_id = {entity['id']: entity['label'] for entity in row['entities']}
    return [
        list(dict.fromkeys(label_by_id.get(occurrence) for occurrence in element['occurrences']))
        for event in row['events']
        for element in event
        if element['attribute'] == attribute
    ]

# French spaCy pipeline, read as a global by `count_tokens`
nlp = spacy.load("fr_core_news_sm")

# Work on a copy so the raw `df` stays untouched
df_copy = df.copy()

# Document length, in characters and in spaCy tokens
df_copy['character_length'] = df_copy['text'].str.len()
df_copy['token_length'] = df_copy['text'].apply(count_tokens)

# Number of events per document, and how many entities (central and associated elements) each event involves
df_copy['events_nb'] = df_copy['events'].apply(len)
df_copy['entities_involved_in_events'] = df_copy['events'].apply(lambda doc_events: list(map(len, doc_events)))

# Labels of the central and associated elements of each event
df_copy['central_elements'] = df_copy.apply(extract_elements_labels, axis=1, args=("evt:central_element",))
df_copy['associated_elements'] = df_copy.apply(extract_elements_labels, axis=1, args=("evt:associated_element",))

# Persist the per-document stats
df_copy.to_csv(stats_filename)

## 4. Plot graphs

### 4.1. Plot character and token distribution

In [None]:
def plot_length(df: pd.DataFrame, plot_name_fr: str, plot_name_en: str, column_name: str):
    """Plot, save and show a histogram of document lengths (characters or tokens)

    The rounded mean length is printed. The figure is saved as
    `<plot_name_en in lower case>_distribution.png` in the image output folder.

    Args:
        df (pd.DataFrame): dataframe to extract data from
        plot_name_fr (str): french plot name ('caractères' or 'tokens'), used in labels and title
        plot_name_en (str): english plot name ('characters' or 'tokens'), used in the filename
        column_name (str): dataframe column containing the length ('character_length' or 'token_length')
    """
    # Set context and theme
    sns.set_context("notebook")
    sns.set_theme(style="ticks")

    # Histogram of the dataframe given as argument (not a global), so the
    # plotted data and the printed mean always come from the same frame
    ax = sns.histplot(
        data=df,
        x=column_name,
        bins=15,
        color=get_color("cambridge-blue"),
        alpha=0.8
    )

    # Mean
    mean_val = df[column_name].mean()
    print(f"Mean: {round(mean_val)}")

    # Titles and labels
    if with_title:
        ax.set_title(f'Distribution des documents en fonction du nombre de {plot_name_fr}', fontsize=16, fontweight='bold')
    ax.set_xlabel(f'Nombre de {plot_name_fr}', fontsize=14)
    ax.set_ylabel('Nombre de documents', fontsize=14)

    # Ticks
    ax.tick_params(axis='both', labelsize=12)
    plt.xticks(rotation=0)

    # Remove top and right spines
    sns.despine()

    # Save and show plot
    plt.tight_layout()
    plt.savefig(os.path.join(folder_path_for_images, subfolder_path_for_images, f"{plot_name_en.lower()}_distribution.png"), dpi=300, bbox_inches='tight')
    plt.show()
    
# Reload the stats saved above (kept available; the plots below use df_copy directly)
df_stats = pd.read_csv(stats_filename)
plot_length(df_copy, plot_name_fr="caractères", plot_name_en="characters", column_name="character_length")
plot_length(df_copy, plot_name_fr="tokens", plot_name_en="tokens", column_name="token_length")

### 4.2. Plot entity labels distribution

In [None]:
def plot_entity_labels_distribution(df: pd.DataFrame):
    """Plot, save and show how often each entity label occurs across all documents

    Prints the number of entities, the number of distinct labels and the raw counts.

    Args:
        df (pd.DataFrame): dataframe to extract data from (needs an 'entities' column
            holding, per document, a list of dicts with a 'label' key)
    """
    # Count every entity label over all documents
    label_counts = Counter(
        entity['label']
        for document_entities in df['entities']
        for entity in document_entities
    )
    print(f"{sum(label_counts.values())} entities")
    print(f"{len(label_counts)} labels")
    print(f"Labels count: {label_counts}")

    # Most frequent label first (ties keep first-seen order)
    ranked = label_counts.most_common()
    label_names = [name for name, _ in ranked]
    frequencies = [freq for _, freq in ranked]

    sns.set_context("notebook")

    # Bar plot
    sns.barplot(x=label_names, y=frequencies, color=get_color("tropical-indigo"))

    # Titles and labels
    if with_title:
        plt.title('Distribution des labels des entités', fontsize=16, fontweight='bold')
    plt.xlabel('Labels des entités', fontsize=14)
    plt.ylabel('Fréquence', fontsize=14)

    # Slanted tick labels so long label names stay readable
    plt.xticks(rotation=45, ha='right')

    # Remove top and right spines
    sns.despine()

    # Save and show plot
    plt.tight_layout()
    plt.savefig(os.path.join(folder_path_for_images, subfolder_path_for_images, "entity_labels_distribution.png"), dpi=300, bbox_inches='tight')
    plt.show()

# Only plotted when the data file is not the test set
if "test_data" not in data_path:
    plot_entity_labels_distribution(df_copy)

#### 4.2.1 Plot entity labels distribution with examples on mouse hover

In [None]:
def plot_entity_labels_distribution_with_examples(df: pd.DataFrame, max_examples: int = 10):
    """Plot the entity label distribution with example mentions shown on mouse hover

    Produces an interactive `entity_labels_distribution_with_examples.html` file in
    the image output folder; nothing is displayed inline.

    Args:
        df (pd.DataFrame): dataframe to extract data from (needs an 'entities' column
            holding, per document, a list of dicts with a 'label' key and optionally a 'text' key)
        max_examples (int): maximum number of distinct (lower-cased) examples listed per label
    """
    # Count labels and collect up to `max_examples` distinct lower-cased mentions per label
    label_counts = Counter()
    label_examples = defaultdict(list)
    for document_entities in df['entities']:
        for entity in document_entities:
            label = entity['label']
            label_counts[label] += 1
            # Missing or None text both become an empty example instead of crashing on .lower()
            text = (entity.get('text') or '').lower()
            examples = label_examples[label]
            if len(examples) < max_examples and text not in examples:
                examples.append(text)

    # Sort labels by frequency (most frequent first; ties keep first-seen order)
    sorted_items = label_counts.most_common()
    labels = [label for label, _ in sorted_items]
    counts = [count for _, count in sorted_items]
    hover_texts = [
        "<br>".join(f"- {example}" for example in label_examples[label]) if label_examples[label] else "No example"
        for label in labels
    ]

    fig = go.Figure(data=[
        go.Bar(
            x=labels,
            y=counts,
            marker_color="#7c6bc2",
            marker_opacity=0.7,
            hovertemplate='<b>%{x}</b><br>Count: %{y}<br><br>Examples:<br>%{customdata}',
            customdata=hover_texts,
            name=""  # prevents "trace 0"
        )
    ])

    fig.update_layout(
        showlegend=False,
        title="Distribution des labels des entités",
        xaxis_title="Labels des entités",
        yaxis_title="Fréquence",
        xaxis_tickangle=-45,
        template="plotly_white",
        hoverlabel=dict(
            bgcolor=get_color_for_plotly("puce"),
            # Single font spec instead of both `font` and `font_size` keys
            font=dict(color="white", size=16),
            bordercolor="rgba(255, 255, 255, 0.3)"
        ),
        title_font=dict(size=20),
        xaxis_title_font=dict(size=20),
        yaxis_title_font=dict(size=20),
        xaxis_tickfont=dict(size=16),
        yaxis_tickfont=dict(size=16)
    )

    # Save as a standalone HTML file (the figure is not shown inline)
    fig.write_html(os.path.join(folder_path_for_images, "entity_labels_distribution_with_examples.html"))

# Only plotted when the data file is not the test set
if "test_data" not in data_path:
    plot_entity_labels_distribution_with_examples(df_copy)

### 4.3. Plot events

#### 4.3.1. Event distribution in documents

In [None]:
def plot_event_distribution_in_documents(df_copy: pd.DataFrame):
    """Plot, save and show how many documents contain a given number of events

    The rounded mean number of events per document is printed.

    Args:
        df_copy (pd.DataFrame): dataframe to extract data from (needs an 'events_nb' column)
    """
    sns.set_context("notebook")
    sns.set_theme(style="whitegrid")

    # Number of documents for each observed event count
    docs_per_event_count = df_copy['events_nb'].value_counts()
    ax = sns.barplot(
        x=docs_per_event_count.index,
        y=docs_per_event_count.values,
        color=get_color("light-orange")
    )

    print(f"Mean: {round(df_copy['events_nb'].mean())}")

    if with_title:
        ax.set_title("Distribution des événements dans les documents", fontsize=16, fontweight='bold')
    ax.set_xlabel("Nombre d'événements", fontsize=14)
    ax.set_ylabel("Nombre de documents", fontsize=14)
    ax.tick_params(axis='both', labelsize=12)

    # Remove top and right spines
    sns.despine()

    # Document counts are integers: no fractional ticks on the y-axis
    ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True))

    plt.tight_layout()
    plt.savefig(os.path.join(folder_path_for_images, subfolder_path_for_images, "events_distribution_in_documents.png"), dpi=300, bbox_inches='tight')
    plt.show()

# Only plotted when the data file is not the test set
if "test_data" not in data_path:
    plot_event_distribution_in_documents(df_copy)

#### 4.3.2. Entities distribution in events

In [None]:
def plot_entities_distribution_in_events(df_copy: pd.DataFrame):
    """Plot, save and show how many entities each event involves

    Each event is a list of element dicts (central and associated elements); its
    length is taken as the number of entities involved. The rounded mean is printed.
    The input dataframe is not modified.

    Args:
        df_copy (pd.DataFrame): dataframe to extract data from (needs an 'events' column)
    """
    # One count per event, flattened over all documents. Computed locally rather
    # than (re)writing a column on the caller's dataframe as a side effect.
    flat_counts = [len(event) for doc_events in df_copy['events'] for event in doc_events]
    if not flat_counts:
        # Nothing to plot; also avoids a ZeroDivisionError when computing the mean
        print("Mean: no events found")
        return

    # Set context and theme
    sns.set_context("notebook")
    sns.set_theme(style="whitegrid")

    # Mean
    mean_count = sum(flat_counts) / len(flat_counts)
    print(f"Mean: {round(mean_count)}")

    # Plot
    ax = sns.countplot(x=flat_counts, color=get_color("light-orange"))

    # Labels and title
    if with_title:
        plt.title("Distribution des entités dans les événements", fontsize=16, fontweight='bold')
    plt.ylabel("Nombre d'événements", fontsize=14)
    plt.xlabel("Nombre d'entités", fontsize=14)

    # Ticks
    ax.tick_params(axis='both', labelsize=12)

    # Remove top and right spines
    sns.despine()

    # Save and show plot
    plt.tight_layout()
    plt.savefig(os.path.join(folder_path_for_images, subfolder_path_for_images, "entities_distribution_in_events.png"), dpi=300, bbox_inches='tight')
    plt.show()

# Only plotted when the data file is not the test set
if "test_data" not in data_path:
    plot_entities_distribution_in_events(df_copy)


#### 4.3.3. Central and associated elements (with labels) distribution in events

In [None]:
def _highlight_bars(ax, bar_labels, highlight_labels):
    """Draw a dashed box around the bars whose label is in `highlight_labels`, with a legend entry

    Args:
        ax: matplotlib axes holding the bar plot
        bar_labels: bar labels, in the same order as the bars in `ax.patches`
        highlight_labels: labels whose bars should be enclosed
    """
    import matplotlib.patches as mpatches

    wanted = set(highlight_labels)
    # Bars are matched to labels by position, so this must run right after the
    # bar plot, before any other patch is added to the axes
    patches = [patch for patch, label in zip(ax.patches, bar_labels) if label in wanted]
    if not patches:
        return

    # Bounding box of the matched bars, slightly widened and raised by 5 units
    x0 = min(p.get_x() for p in patches)
    x1 = max(p.get_x() + p.get_width() for p in patches)
    height = max(p.get_height() for p in patches)
    edgecolor = get_color("puce")
    ax.add_patch(mpatches.Rectangle(
        (x0 - 0.05, 0), x1 - x0 + 0.1, height + 5,
        fill=False, edgecolor=edgecolor, linewidth=2, linestyle='--'
    ))

    # Proxy artist so the dashed rectangle appears in the legend
    legend_handle = mpatches.Patch(
        edgecolor=edgecolor,
        facecolor='none',
        linewidth=2,
        linestyle='--',
        label='Labels problématiques'
    )
    ax.legend(handles=[legend_handle], loc='upper right', fontsize=12)


def plot_central_or_associated_elements_distribution(attribute_column: str, attribute_fr_name: str, df: "pd.DataFrame | None" = None):
    """Plot, save and show the label distribution of central or associated event elements

    Prints the number of elements, the number of distinct labels and the counts table.
    For central elements, the labels 'PATH_REF_TO_DIS', 'DOC_DATE' and 'REL_DATE'
    are boxed as problematic.

    Args:
        attribute_column (str): attribute column 'central_elements' or 'associated_elements'
        attribute_fr_name (str): french attribute name, used in the title
        df (pd.DataFrame | None): dataframe to extract data from; defaults to the
            notebook-level `df_copy`
    """
    if df is None:
        df = df_copy

    # Flatten documents -> elements -> labels into a single list of labels
    all_labels = list(chain.from_iterable(chain.from_iterable(df[attribute_column])))
    # value_counts() skips missing labels (None) by default
    label_counts = pd.Series(all_labels).value_counts().reset_index()
    label_counts.columns = ['label', 'count']
    total_counts = label_counts['count'].sum()
    print(f"For {attribute_column}:\n\t{total_counts} elements\n\t{len(label_counts)} labels")
    print(label_counts)

    # Set context and theme
    sns.set_context("notebook")
    sns.set_theme(style="whitegrid")

    # Bar plot
    ax = sns.barplot(data=label_counts, x='label', y='count', color=get_color("light-orange"))

    # Titles and labels
    if with_title:
        plt.title(f'Distribution des {attribute_fr_name} dans les événements', fontsize=16, fontweight='bold')
    plt.xlabel('Labels', fontsize=14)
    plt.ylabel('Fréquence', fontsize=14)

    # Ticks
    ax.tick_params(axis='both', labelsize=12)
    plt.xticks(rotation=60)

    # Box the labels considered problematic when used as central elements
    if attribute_column == "central_elements":
        _highlight_bars(ax, label_counts['label'], ['PATH_REF_TO_DIS', 'DOC_DATE', 'REL_DATE'])

    # Remove top and right spines
    sns.despine()

    plt.tight_layout()
    plt.savefig(os.path.join(folder_path_for_images, subfolder_path_for_images, f"{attribute_column}_distribution.png"), dpi=300, bbox_inches='tight')
    plt.show()

# Only plotted when the data file is not the test set
if "test_data" not in data_path:
    plot_central_or_associated_elements_distribution("central_elements", "éléments centraux")
    plot_central_or_associated_elements_distribution("associated_elements", "éléments associés")