# Imports

In [1]:
from collections import Counter, defaultdict
import community as community_louvain
from expand import process_flight_sequences
import math
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
from pyvis.network import Network

# Function definitions

In [2]:
def get_ngrams(sequence, n):
    """Return the sliding windows of length ``n`` over ``sequence``.

    Args:
        sequence: Sliceable sequence to scan (e.g. a list of phase ids).
        n (int): Window length.

    Returns:
        An empty list when ``n`` is 0; otherwise a lazy ``zip`` iterator that
        yields every run of ``n`` consecutive items as a tuple.
    """
    if n == 0:
        return []
    # One view per offset 0..n-1; zip stops at the shortest view, i.e. right
    # after the last complete window.
    shifted_views = (sequence[offset:] for offset in range(n))
    return zip(*shifted_views)


def find_significant_motifs(flights_data, k, z_threshold=1.96, nb_phases=None):
    """
    Identifies statistically significant (over-represented) k-motifs.

    For k > 2 the expected probability of a motif x_1...x_k is estimated from
    the observed frequencies of its sub-patterns:
    p_exp = p(x_1..x_{k-1}) * p(x_2..x_k) / p(x_2..x_{k-1}).
    For k == 2 a uniform model over ``nb_phases ** 2`` possible pairs is used.
    A motif is kept when its z-score is strictly greater than ``z_threshold``.

    Args:
        flights_data (list of lists): Each inner list is a sequence of flight phase integers.
        k (int): The length of the motif to analyze (e.g., 3). Must be >= 2.
        z_threshold (float): Z-score cutoff (1.96 for 95% confidence).
        nb_phases (int, optional): Number of unique flight phases (required for k=2).

    Returns:
        pd.DataFrame: Table with columns motif, count, expected_count, p_obs,
        p_exp and z_score, sorted by decreasing z_score. When no motif passes
        the threshold the frame is empty but still carries these columns, so
        callers can select or sort on them safely.

    Raises:
        ValueError: If k < 2, or if k == 2 and nb_phases is None.
    """
    if k < 2:
        # k == 1 has no prefix/overlap sub-patterns and used to fail with a
        # ZeroDivisionError further down; reject it up front instead.
        raise ValueError("k must be at least 2 to compute motif significance.")
    if k == 2 and nb_phases is None:
        raise ValueError("For k=2, nb_phases must be provided to calculate expected probabilities.")

    result_columns = ["motif", "count", "expected_count", "p_obs", "p_exp", "z_score"]

    # 1. Count frequencies for k, k-1, and k-2 patterns
    counts_k = Counter()        # Counts for x_1...x_k
    counts_prefix = Counter()   # Counts for (k-1)-grams (prefix and suffix lookups)
    counts_overlap = Counter()  # Counts for (k-2)-grams (the shared middle part)

    for flight in flights_data:
        # flight is a list of ints, e.g., [10, 20, 30, 40]
        if len(flight) < k:
            continue

        # Counter.update consumes the n-gram iterator directly; no need to
        # materialise an intermediate list per flight.
        counts_k.update(get_ngrams(flight, k))

        # For k == 2 the expected probability comes from nb_phases, so the
        # lower-order counts are not needed.
        if k > 2:
            counts_prefix.update(get_ngrams(flight, k - 1))
            counts_overlap.update(get_ngrams(flight, k - 2))

    # Total number of substrings of each length, for probability normalization
    total_k = counts_k.total()
    total_prefix = counts_prefix.total()
    total_overlap = counts_overlap.total()

    # 2. Calculate Probabilities and Z-scores
    results = []

    for motif, count in counts_k.items():
        # Observed probability of the full motif among all k-grams
        p_obs = count / total_k

        if k == 2:
            # Expected Probability: p_exp(AB) = poss(AB) / poss(XX)
            p_exp = 1 / nb_phases**2
            # NOTE(review): the trial count used here is total_k / 2 while
            # `count` is tallied over all total_k overlapping windows --
            # confirm this matches the reference formula.
            n_trials = total_k / 2
        else:
            # motif is a tuple like (A, B, C)
            prefix = motif[:-1]      # (A, B)
            suffix = motif[1:]       # (B, C)
            overlap = motif[1:-1]    # (B)

            # p_exp(ABC) = p_obs(AB) * p_obs(BC) / p_obs(B)
            prob_prefix = counts_prefix[prefix] / total_prefix
            prob_suffix = counts_prefix[suffix] / total_prefix
            prob_overlap = counts_overlap[overlap] / total_overlap
            if prob_overlap == 0:
                continue  # Avoid division by zero
            p_exp = (prob_prefix * prob_suffix) / prob_overlap
            n_trials = total_k

        expected_count = n_trials * p_exp

        # Binomial approximation: variance = N * p * (1 - p). The estimated
        # p_exp is a ratio of frequencies and is not guaranteed to stay <= 1,
        # so guard the square root against a negative variance.
        variance = n_trials * p_exp * (1 - p_exp)
        sigma = math.sqrt(variance) if variance > 0 else 0.0

        if sigma == 0:
            z_score = 0  # Degenerate variance: treat the motif as not significant
        else:
            z_score = (count - expected_count) / sigma

        if z_score > z_threshold:
            results.append({
                "motif": motif,
                "count": count,
                "expected_count": expected_count,
                "p_obs": p_obs,
                "p_exp": p_exp,
                "z_score": z_score
            })

    # 3. Format Output
    # Passing the column list keeps the schema stable even when `results` is empty.
    df = pd.DataFrame(results, columns=result_columns)
    if not df.empty:
        df = df.sort_values(by="z_score", ascending=False)

    return df


def build_motif_edgelist(flights_data, significant_motifs_df, k, z_threshold=1.96):
    """
    Constructs the weighted motif co-occurrence network as a sparse edge list.

    An ordered pair (X, Y) is counted each time motif X occurs in a flight and
    motif Y starts at least ``k`` positions later in the same flight (i.e. the
    two occurrences do not overlap). Each observed pair gets a z-score against
    the expectation 0.5 * p(X) * p(Y) * normalization_sum, using
    sigma = sqrt(expected) (Poisson approximation).

    Args:
        flights_data (list of lists): Each inner list is a sequence of flight phase integers.
        significant_motifs_df (pd.DataFrame): Table with at least the columns
            'motif' (tuple of length k) and 'p_obs' (observed probability).
        k (int): Motif length.
        z_threshold (float): Only edges with z-score strictly above this are kept.

    Returns:
        pd.DataFrame: Columns 'Source', 'Target', 'Weight' (the z-score),
        sorted by decreasing weight. Empty (with the same columns) when there
        are no motifs or no co-occurrences.
    """
    edge_columns = ['Source', 'Target', 'Weight']

    # An empty motif table (possibly without any columns) means there is
    # nothing to connect; return the empty schema instead of raising KeyError.
    if significant_motifs_df.empty or 'motif' not in significant_motifs_df.columns:
        return pd.DataFrame(columns=edge_columns)

    # 1. Setup Nodes and Indices
    valid_motifs = significant_motifs_df['motif'].tolist()
    # Map motif tuple -> integer ID (row position) for efficient processing
    motif_to_idx = {motif: i for i, motif in enumerate(valid_motifs)}
    # Row position -> probability, aligned with valid_motifs
    prob_array = significant_motifs_df['p_obs'].to_numpy(dtype=float)

    # 2./3. Single pass: accumulate the normalization constant (Eq. 6 sum)
    # and sparsely count observed co-occurrences {(source_idx, target_idx): count}
    normalization_sum = 0
    observed_counts = defaultdict(int)

    for flight in flights_data:
        l_s = len(flight)
        # A flight shorter than 2k cannot hold two non-overlapping k-motifs
        if l_s < 2 * k:
            continue

        normalization_sum += (l_s - 2 * k + 1) * (l_s - 2 * k + 2)

        # Occurrences of significant motifs, in increasing start position
        instances = []
        for start in range(l_s - k + 1):
            motif_id = motif_to_idx.get(tuple(flight[start:start + k]))
            if motif_id is not None:
                instances.append((start, motif_id))

        # Count ordered pairs (O(M^2) where M is motif occurrences per flight)
        for pos, (start_x, id_x) in enumerate(instances):
            for start_y, id_y in instances[pos + 1:]:
                # Non-overlapping constraint
                if start_y >= start_x + k:
                    observed_counts[(id_x, id_y)] += 1

    # 4. Vectorized Z-Score Calculation (Only on observed edges)
    if not observed_counts:
        return pd.DataFrame(columns=edge_columns)

    # Split the (source, target) keys into two parallel index arrays
    sources, targets = (np.array(side) for side in zip(*observed_counts))
    obs_counts = np.array(list(observed_counts.values()))

    # Exp(X, Y) = 0.5 * p(X) * p(Y) * normalization_sum
    expected_counts = 0.5 * prob_array[sources] * prob_array[targets] * normalization_sum

    # Poisson approximation for the standard deviation: sigma = sqrt(expected)
    sigma = np.sqrt(expected_counts)

    with np.errstate(divide='ignore', invalid='ignore'):
        z_scores = (obs_counts - expected_counts) / sigma
        # Replace NaN/inf produced by a zero sigma with finite numbers
        z_scores = np.nan_to_num(z_scores)

    # 5. Filter and Format Output
    mask = z_scores > z_threshold

    edge_df = pd.DataFrame({
        'Source': [valid_motifs[i] for i in sources[mask]],
        'Target': [valid_motifs[i] for i in targets[mask]],
        'Weight': z_scores[mask]
    })

    # Sort by significance
    return edge_df.sort_values(by='Weight', ascending=False)


def plot_static_graph(G):
    """Draw the motif network with matplotlib, edge width scaled by weight.

    Edge weights are read from the 'weight' edge attribute, falling back to
    'Weight' (the attribute name produced when the graph is built from the
    edge list's 'Weight' column) and finally to 1.0. Widths are min-max
    scaled into the range 0.5 to 4.5; if all weights are equal every edge is
    drawn with width 1.0. A graph without edges is drawn with nodes only.

    Args:
        G: networkx graph whose nodes are motifs.

    Returns:
        None. The figure is shown via ``plt.show()``.
    """
    plt.figure(figsize=(12, 12))

    # 1. Calculate Layout (Spring layout positions nodes based on connections)
    pos = nx.spring_layout(G, k=0.5, seed=42)  # k regulates distance between nodes

    # 2. Extract Weights for styling (accept either attribute spelling)
    weights = [
        data.get('weight', data.get('Weight', 1.0))
        for _, _, data in G.edges(data=True)
    ]

    # Normalize weights for visualization (thickness between 0.5 and 4.5).
    # Guard both the empty case (min/max would raise) and the constant case
    # (division by zero).
    if weights:
        lowest, highest = min(weights), max(weights)
        if highest > lowest:
            span = highest - lowest
            width = [(w - lowest) / span * 4 + 0.5 for w in weights]
        else:
            width = [1.0 for _ in weights]
    else:
        width = []

    # 3. Draw the Network
    # Nodes
    nx.draw_networkx_nodes(G, pos, node_size=700, node_color='lightblue')

    # Edges (width varies by Z-score); skipped when there is nothing to draw
    if width:
        nx.draw_networkx_edges(G, pos, width=width, edge_color='gray',
                               arrowstyle='->', arrowsize=20)

    # Labels (Motif names)
    nx.draw_networkx_labels(G, pos, font_size=10, font_family="sans-serif")

    plt.title("Flight Phase K-Motif Network (Weighted by Z-Score)")
    plt.axis('off')
    plt.show()


def plot_interactive_graph(G, filename="motif_network.html"):
    """Render the motif network as an interactive PyVis HTML page.

    Communities found by the Louvain method are written to each node's
    'group' attribute on ``G`` itself (PyVis uses it for coloring). Node ids
    that are neither ``str`` nor ``int`` (e.g. motif tuples) are converted to
    their string form in a relabeled copy before being handed to PyVis; ``G``
    keeps its original node ids.

    Args:
        G: networkx graph to display.
        filename (str): Output HTML file.

    Returns:
        str: ``filename``, after the page has been written and shown.
    """
    # Initialize PyVis network
    net = Network(height="750px", width="100%", notebook=True, directed=True)

    # Optional: Detect Communities (Louvain method) to color nodes
    # NOTE(review): `community` is imported at module load, so an ImportError
    # is unlikely to surface here; the handler is kept for backward compatibility.
    try:
        partition = community_louvain.best_partition(G.to_undirected())
        # Add partition info to node attributes for coloring
        for node, group_id in partition.items():
            G.nodes[node]['group'] = group_id
    except ImportError:
        print("Community detection skipped (install 'python-louvain' for colors)")

    # PyVis only accepts str or int node ids, while the motif graph uses
    # tuples. Relabel into a copy so the caller's graph is left untouched.
    pyvis_ids = {
        node: node if isinstance(node, (str, int)) else str(node)
        for node in G.nodes
    }
    displayable = nx.relabel_nodes(G, pyvis_ids, copy=True)

    # Convert NetworkX graph to PyVis
    net.from_nx(displayable)

    # Customizing physics for better separation (optional)
    net.set_options("""
    var options = {
      "physics": {
        "forceAtlas2Based": {
          "gravitationalConstant": -50,
          "centralGravity": 0.01,
          "springLength": 100,
          "springConstant": 0.08
        },
        "maxVelocity": 50,
        "solver": "forceAtlas2Based",
        "timestep": 0.35,
        "stabilization": { "enabled": true }
      }
    }
    """)

    # Save and show
    net.show(filename)
    return filename

# Run

In [3]:
# Load the semicolon-separated records (first column as index), drop the
# 'session' column, and order rows by session and then by start frame.
df = (
    pd.read_csv("PIE_data_with_context.csv", sep=";", header=0, index_col=0)
    .drop(columns=['session'])
    .sort_values(by=['F_SESSION', 'F_START_FRAME'], ascending=[True, True])
)
df.head()

Unnamed: 0,F_SESSION,F_START_FRAME,F_END_FRAME,F_DURATION,FIRST_WORD_INDEX,SECOND_WORD_INDEX,THIRD_WORD_INDEX,k_aircraft,k_operator,k_mission
879696,3130311,332,375,00:00:22.000,3.0,10.0,285.0,46852,20,-2
879697,3130311,376,441,00:00:33.000,3.0,75.0,281.0,46852,20,-2
879698,3130311,442,741,00:02:30.000,3.0,10.0,8.0,46852,20,-2
879699,3130311,742,760,00:00:09.500,3.0,28.0,39.0,46852,20,-2
879700,3130311,761,764,00:00:02.000,3.0,10.0,8.0,46852,20,-2


In [None]:
# A row's phase is the triple of its three word-index columns.
first, second, third = (
    df[column]
    for column in ('FIRST_WORD_INDEX', 'SECOND_WORD_INDEX', 'THIRD_WORD_INDEX')
)
df['phase'] = list(zip(first, second, third))

# Number the distinct phases in the order unique() returns them and keep
# both lookup directions (index -> phase and phase -> index).
idx_to_phase = dict(enumerate(df['phase'].unique()))
phase_to_idx = {phase: idx for idx, phase in idx_to_phase.items()}
df['phase_idx'] = df['phase'].map(phase_to_idx)

In [5]:
sessions = df['F_SESSION'].unique()
# df = process_flight_sequences(df, md=1)

# One phase-index sequence per session, built in a single groupby pass rather
# than re-filtering the whole frame once per session. sort=False keeps the
# groups in order of first appearance (the same order as `sessions`) and rows
# keep their order within each group.
# NOTE: rows whose F_SESSION is NaN are skipped by groupby; previously they
# produced an empty sequence.
flights = [
    phases.to_list()
    for _, phases in df.groupby('F_SESSION', sort=False)['phase_idx']
]

In [6]:
# Motif length under study
k = 3
# nb_phases is left at its default (None); it is only required when k == 2
df_significant = find_significant_motifs(flights, k, z_threshold=1.96)
print(f"Total significant motifs found: {df_significant.shape[0]}")
# Display the table ordered from highest to lowest z-score
df_significant.sort_values("z_score", ascending=False)

Total significant motifs found: 655731


Unnamed: 0,motif,count,expected_count,p_obs,p_exp,z_score
455803,"(13516, 7, 7577)",1,0.000050,4.046003e-07,2.027562e-11,141.255128
588400,"(25488, 7, 18445)",1,0.000050,4.046003e-07,2.027562e-11,141.255128
564180,"(18302, 7, 26733)",1,0.000050,4.046003e-07,2.027562e-11,141.255128
542379,"(966, 7, 5402)",1,0.000050,4.046003e-07,2.027562e-11,141.255128
266921,"(8668, 1097, 1319)",1,0.000085,4.046003e-07,3.425041e-11,108.678455
...,...,...,...,...,...,...
10196,"(1311, 1336, 7258)",1,0.176507,4.046003e-07,7.141481e-08,1.960103
159631,"(6035, 10190, 10190)",5,2.135592,2.023002e-06,8.640612e-07,1.960090
208200,"(8411, 8410, 10723)",5,2.135592,2.023002e-06,8.640612e-07,1.960090
21957,"(2638, 2638, 1476)",5,2.135592,2.023002e-06,8.640612e-07,1.960090


In [7]:
# Show only the significant motifs observed more than once
df_significant.loc[df_significant['count'] > 1]

Unnamed: 0,motif,count,expected_count,p_obs,p_exp,z_score
69,"(116, 7, 116)",2816,669.935785,0.001139,2.710562e-04,82.924907
62542,"(11817, 2, 16023)",14,0.064764,0.000006,2.620365e-08,54.757868
764,"(192, 2, 6)",1328,420.512429,0.000537,1.701395e-04,44.257636
164,"(2, 2, 2)",2935,1328.055177,0.001188,5.373315e-04,44.107213
1,"(2, 3, 2)",2906,1354.917851,0.001176,5.482002e-04,42.149982
...,...,...,...,...,...,...
134995,"(7269, 8615, 8615)",3,1.020179,0.000001,4.127647e-07,1.960144
159631,"(6035, 10190, 10190)",5,2.135592,0.000002,8.640612e-07,1.960090
208200,"(8411, 8410, 10723)",5,2.135592,0.000002,8.640612e-07,1.960090
21957,"(2638, 2638, 1476)",5,2.135592,0.000002,8.640612e-07,1.960090


In [None]:
# Derive the significant motif-to-motif edges and report the table's dimensions
edge_df = build_motif_edgelist(
    flights,
    df_significant,
    k=k,
    z_threshold=1.96,
)
print(edge_df.shape)

In [None]:
# Build a directed graph from the edge list; each row's 'Weight' value is
# stored on the corresponding edge under the attribute name 'Weight'.
G = nx.from_pandas_edgelist(
    edge_df, 'Source', 'Target', edge_attr='Weight', create_using=nx.DiGraph
)

# Then proceed with PyVis or Matplotlib as before

print(f"Nodes: {G.number_of_nodes()}, Edges: {G.number_of_edges()}")