In [20]:
from collections import defaultdict
from pathlib import Path
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from loguru import logger

from pyvis.network import Network

In [2]:
# Load the input CSV into a DataFrame; the path is relative to the notebook's working directory.
df = pd.read_csv("./response_1746808556206.csv")

In [8]:
# Instantiate the graph service.
# NOTE(review): GraphService is defined in a cell further down this notebook, so on a fresh
# kernel that cell has to be executed before this one.
gg = GraphService()

[32m2025-05-09 13:51:50.344[0m | [1mINFO    [0m | [36m__main__[0m:[36m__init__[0m:[36m14[0m - [1mInitialized Graph service.[0m


In [12]:
# Build the directed graph from df, using 20 as the approximate node budget.
gg.create_graph(dataframe=df, limit_nodes=20)

[32m2025-05-09 13:52:59.879[0m | [1mINFO    [0m | [36m__main__[0m:[36mcreate_graph[0m:[36m103[0m - [1mCreating graph from dataframe.[0m
[32m2025-05-09 13:52:59.880[0m | [1mINFO    [0m | [36m__main__[0m:[36mcreate_graph[0m:[36m108[0m - [1mPreprocessing to identify target and batedor plates.[0m
[32m2025-05-09 13:52:59.882[0m | [1mINFO    [0m | [36m__main__[0m:[36mcreate_graph[0m:[36m112[0m - [1mIterating over dataframe rows to add nodes and edges.[0m
[32m2025-05-09 13:52:59.899[0m | [1mINFO    [0m | [36m__main__[0m:[36mcreate_graph[0m:[36m123[0m - [1mRemoving isolated nodes.[0m
[32m2025-05-09 13:52:59.900[0m | [1mINFO    [0m | [36m__main__[0m:[36mcreate_graph[0m:[36m127[0m - [1mLimiting nodes in graph.[0m
[32m2025-05-09 13:52:59.902[0m | [1mINFO    [0m | [36m__main__[0m:[36mcreate_graph[0m:[36m132[0m - [1mGraph created.[0m


In [15]:
# Keep a direct reference to the graph stored on the service by create_graph.
G = gg.G

In [22]:
# NOTE(review): create_interactive_graph is not defined in any cell shown in this notebook —
# presumably it comes from a removed cell; confirm it is available on a fresh kernel.
filename = "grafo_1.html"
net = create_interactive_graph(G=G, filename=filename)

In [23]:
# Show the interactive graph using the HTML file name chosen above
# (assumes net is a pyvis Network — TODO confirm, its factory is not shown here).
net.show(filename)

grafo_1.html


In [11]:
class GraphService:
    """
    Service for graph operations.

    Builds a directed graph of "batedor" -> "target" plates from a DataFrame and
    exports it either as an interactive pyvis network or as a static PNG image.

    Attributes:
        G: Directed graph built by ``create_graph`` (``None`` until it is called).
        dataframe: DataFrame passed to ``create_graph`` (``None`` until it is called).
    """

    def __init__(self):
        self.G: nx.DiGraph | None = None
        self.dataframe: pd.DataFrame | None = None
        logger.info("Initialized Graph service.")

    def _require_graph(self) -> nx.DiGraph:
        """
        Returns the current graph, failing with a clear message when none exists.

        Raises:
            ValueError: If ``create_graph`` has not been called yet.
        """
        if self.G is None:
            raise ValueError("Graph has not been created. Call create_graph() first.")
        return self.G

    def __limit_nodes_in_graph(self, G: nx.DiGraph, max_nodes: int = 30) -> nx.DiGraph:
        """
        Limits the number of nodes in the graph to improve visualization and avoid overlapping.

        The graph is modified in place and also returned.

        Args:
            G (nx.DiGraph): The directed graph to be limited.
            max_nodes (int, default=30): Approximate reference value for the maximum number of nodes,
                not an exact limit. The algorithm prioritizes keeping nodes with higher weight
                (relevance) and may include a different number of nodes depending on the weight
                distribution in the graph.

        Returns:
            nx.DiGraph: The graph with a limited number of nodes.
        """
        # Tally entries per weight level. Batedors contribute one entry per outgoing edge
        # (keyed by that edge's weight); targets are all tallied under weight 0.
        weight_counts = defaultdict(lambda: {"target": 0, "batedor": 0})
        for node, data in G.nodes(data=True):
            node_type = data.get("type")
            if node_type == "batedor":
                for _, _, edge_data in G.edges(node, data=True):
                    weight_counts[edge_data.get("weight", 0)]["batedor"] += 1
            elif node_type == "target":
                weight_counts[0]["target"] += 1

        # Walk the weights from highest to lowest and stop at the first weight whose
        # cumulative tally reaches max_nodes; that weight becomes the cut-off.
        min_weight = 1  # cut-off used when the tally never reaches max_nodes
        total_nodes = 0
        for weight in sorted(weight_counts, reverse=True):
            counts = weight_counts[weight]
            total_nodes += counts["target"] + counts["batedor"]
            if total_nodes >= max_nodes:
                min_weight = weight
                break

        # Drop every batedor that has no outgoing edge with weight >= min_weight.
        nodes_to_remove = [
            node
            for node, data in G.nodes(data=True)
            if data.get("type") == "batedor"
            and not any(
                edge_data.get("weight", 0) >= min_weight
                for _, _, edge_data in G.edges(node, data=True)
            )
        ]
        G.remove_nodes_from(nodes_to_remove)

        # Removing batedors may leave nodes without any edge; drop those as well.
        G.remove_nodes_from(list(nx.isolates(G)))

        return G

    def create_graph(self, dataframe: pd.DataFrame, limit_nodes: int = 30) -> None:
        """
        Builds the directed graph from a DataFrame and stores it in ``self.G``.

        Every row adds (or updates) a node for its 'placa', carrying all row values as
        node attributes plus a 'type' attribute ("target" or "batedor"). Rows whose
        'target' value is falsy also add an edge 'placa' -> 'placa_target' weighted by
        the row's 'weight'.

        Args:
            dataframe: pandas DataFrame with the data of the plates. The columns read here
                are 'placa', 'placa_target', 'target' and 'weight'; any other column is
                simply copied into the node attributes.
            limit_nodes: Approximate threshold for the maximum number of nodes in the graph.
                        The algorithm will select the nodes with the highest weight until it reaches
                        or approaches this number, prioritizing more relevant connections.
                        It is not an exact limit, but a reference value to control the density of the graph.
        """
        logger.info("Creating graph from dataframe.")
        self.dataframe = dataframe

        G = nx.DiGraph()

        logger.info("Preprocessing to identify target and batedor plates.")
        # Plates that appear in at least one row flagged as target.
        target_plates = set(dataframe.loc[dataframe["target"] == True, "placa"])  # noqa: E712

        logger.info("Iterating over dataframe rows to add nodes and edges.")
        for _, row in dataframe.iterrows():
            plate = row["placa"]

            # Merge the computed type into the attribute dict instead of passing it as a
            # separate keyword, so a DataFrame column named "type" cannot trigger a
            # duplicate-keyword TypeError; the computed type always wins.
            attrs = row.to_dict()
            attrs["type"] = "target" if plate in target_plates else "batedor"
            G.add_node(plate, **attrs)

            if not row["target"]:  # Adds edges only for batedors
                G.add_edge(plate, row["placa_target"], weight=row["weight"])

        logger.info("Removing isolated nodes.")
        G.remove_nodes_from(list(nx.isolates(G)))

        logger.info("Limiting nodes in graph.")
        G = self.__limit_nodes_in_graph(G, max_nodes=limit_nodes)

        self.G = G
        logger.info("Graph created.")

    def to_html(self) -> Network:
        """
        Builds an interactive pyvis network from the current graph.

        Target nodes are drawn in red and every other node in blue; edges carry their
        weight as visual value and as tooltip. Nothing is written to disk here: the
        caller is responsible for showing or saving the returned network.

        Returns:
            Network: The configured pyvis network.

        Raises:
            ValueError: If ``create_graph`` has not been called yet.
        """
        graph = self._require_graph()

        net = Network(
            height="800px",
            width="100%",
            notebook=True,
            directed=True,
            cdn_resources="remote",
        )

        node_size = 20  # same size for every node; only the colour encodes the type
        for node, data in graph.nodes(data=True):
            color = "red" if data.get("type") == "target" else "blue"
            net.add_node(node, label=node, color=color, size=node_size)

        for source, target, data in graph.edges(data=True):
            weight = data.get("weight", 1)  # default weight 1 when the edge has none
            net.add_edge(
                source, target, value=weight, color="black", title=f"Peso: {weight}"
            )

        # NOTE(review): the tuning values below sit under "forceAtlas2Based" while the
        # selected solver is "barnesHut", and "wind" is given at the top level rather
        # than inside "physics" — presumably both are ignored by vis-network; confirm
        # the intended solver before changing these options.
        net.set_options(
            """
        {
        "physics": {
            "forceAtlas2Based": {
            "theta": 0.5,
            "gravitationalConstant": -50,
            "centralGravity": 0.01,
            "springLength": 100,
            "springConstant": 0.08,
            "damping": 0.4,
            "avoidOverlap": 1
            },
            "maxVelocity": 50,
            "minVelocity": 0.75,
            "solver": "barnesHut",
            "timestep": 0.5
        },
        "wind":{
            "x":0,
            "y":0
        },
        "edges": {
            "smooth": {
            "type": "dynamic"
            }
        },
        "interaction": {
            "hover": true  
        }
        }
        """
        )

        return net

    def _compute_layout(self, graph: nx.DiGraph) -> dict:
        """
        Computes the node positions used by ``to_png``.

        Batedors are placed on a regular polygon, targets are clustered next to the
        batedors pointing at them, and any node left over gets a spring layout.

        Args:
            graph: The graph whose nodes must be positioned.

        Returns:
            dict: Mapping of node -> position (array-like with x and y).
        """
        logger.info("Preprocessing to identify target and batedor plates.")
        batedores = [
            node for node, data in graph.nodes(data=True) if data.get("type") == "batedor"
        ]
        batedor_set = set(batedores)  # constant-time membership checks below

        pos = {}

        logger.info("Positioning batedors in a regular polygon.")
        # 1. Positions the batedors in a regular polygon
        if batedores:
            radius = 5  # polygon radius
            angles = np.linspace(0, 2 * np.pi, len(batedores), endpoint=False)
            for batedor, angle in zip(batedores, angles):
                theta = angle - np.pi / 2  # Rotates to have a node at the top
                pos[batedor] = np.array([radius * np.cos(theta), radius * np.sin(theta)])

        logger.info("Grouping targets by connected batedors.")
        # 2. Groups targets by the (sorted) tuple of batedors that point at them
        connection_groups = defaultdict(list)
        for node, data in graph.nodes(data=True):
            if data.get("type") == "target":
                connected = tuple(
                    sorted(p for p in graph.predecessors(node) if p in batedor_set)
                )
                connection_groups[connected].append(node)

        logger.info("Positioning each group dynamically.")
        # 3. Positions each group dynamically
        # The centre of the batedor ring is the same for every group, so compute it once.
        ring_center = (
            np.mean([pos[b] for b in batedores], axis=0) if batedores else None
        )
        for connected_bats, nodes in connection_groups.items():
            if not connected_bats:
                # Targets without any batedor predecessor fall through to the spring layout.
                continue

            # Centroid of the connected batedors
            centroid = np.mean([pos[b] for b in connected_bats], axis=0)

            # Push the group away from the ring centre, along the centre -> centroid direction
            vec_from_center = centroid - ring_center
            norm = np.linalg.norm(vec_from_center)
            direction = vec_from_center / norm if norm > 0 else np.array([1, 0])

            # Distance grows with the number of connected batedors
            base_radius = 2 + len(connected_bats)
            base_pos = centroid + direction * base_radius

            # Spread the group's nodes on a small circle around the base position
            num_nodes = len(nodes)
            node_radius = 1.5 + 0.2 * num_nodes
            angles = np.linspace(0, 2 * np.pi, num_nodes, endpoint=False)
            for node, angle in zip(nodes, angles):
                pos[node] = base_pos + np.array([np.cos(angle), np.sin(angle)]) * node_radius

        logger.info("Applying spring layout for remaining nodes.")
        # 4. Applies spring layout for remaining nodes
        unpositioned = list(set(graph.nodes()) - set(pos.keys()))
        if unpositioned:
            pos.update(
                nx.spring_layout(
                    graph.subgraph(unpositioned), k=150, iterations=500, seed=42
                )
            )

        return pos

    def to_png(self, file_dir: Path | str = "./", file_name: str = "grafo.png") -> Path:
        """
        Converts the graph to a PNG image.

        The output directory is created when it does not exist, and the matplotlib
        figure used for rendering is closed once the file has been written.

        Args:
            file_dir: Directory to save the PNG file.
            file_name: Name of the PNG file.

        Returns:
            Path: Path to the PNG file.

        Raises:
            ValueError: If ``create_graph`` has not been called yet, or if ``file_name``
                is not a string.
            TypeError: If ``file_dir`` cannot be interpreted as a path.
        """
        logger.info("Converting graph to PNG.")
        graph = self._require_graph()

        # Normalise file_dir to a Path before joining it with the file name
        try:
            file_dir = Path(file_dir)
        except TypeError as exc:
            raise TypeError(
                "file_dir must be instance of pathlib.Path or a string"
            ) from exc

        if not isinstance(file_name, str):
            raise ValueError("file_name must be a string")

        file_dir.mkdir(parents=True, exist_ok=True)
        file_path = file_dir / file_name

        fig = plt.figure(figsize=(16, 12))
        try:
            pos = self._compute_layout(graph)

            # Targets are drawn larger and in red, everything else smaller and in blue
            node_colors = []
            node_sizes = []
            for node in graph.nodes():
                if graph.nodes[node].get("type") == "target":
                    node_colors.append("red")
                    node_sizes.append(400)
                else:
                    node_colors.append("blue")
                    node_sizes.append(250)

            logger.info("Normalizing edge widths.")
            # Same default weight as to_html for edges without a weight attribute
            edge_weights = [d.get("weight", 1) for _, _, d in graph.edges(data=True)]

            if edge_weights and len(set(edge_weights)) > 1:
                # Scale widths linearly into the [1, 5] range
                min_w, max_w = min(edge_weights), max(edge_weights)
                edge_widths = [
                    1 + 4 * (w - min_w) / (max_w - min_w) for w in edge_weights
                ]
            else:
                # No edges, or all weights equal: a single uniform width
                edge_widths = [1]

            logger.info("Drawing the graph.")
            nx.draw_networkx_nodes(
                graph,
                pos,
                node_color=node_colors,
                node_size=node_sizes,
                alpha=0.9,
                linewidths=2,
            )

            logger.info("Drawing edges.")
            nx.draw_networkx_edges(
                graph,
                pos,
                edge_color="black",
                width=edge_widths,
                arrows=True,
                arrowstyle="-|>",
                arrowsize=15,
                connectionstyle="arc3,rad=0.15",
                alpha=0.7,
            )

            logger.info("Drawing node labels.")
            # Labels are drawn slightly below each node
            for node, (x, y) in pos.items():
                plt.text(
                    x,
                    y - 0.5,
                    node,
                    fontsize=9,
                    ha="center",
                    va="center",
                    color="black",
                    bbox=dict(facecolor="white", edgecolor="none", alpha=0.7),
                )

            # Final adjustments
            plt.axis("off")
            plt.tight_layout()
            plt.savefig(file_path, dpi=300, bbox_inches="tight")
        finally:
            # Release the figure so repeated exports do not accumulate open figures
            plt.close(fig)

        logger.info("Graph converted to PNG")

        return file_path