Skip to content

Commit

Permalink
Improve backend and plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
pablormier committed May 12, 2024
1 parent 3ed5da6 commit 2932c2b
Show file tree
Hide file tree
Showing 15 changed files with 2,330 additions and 482 deletions.
27 changes: 25 additions & 2 deletions corneto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,23 @@
from corneto.utils import Attr, Attributes


def get_version():
import os
import re

here = os.path.abspath(os.path.dirname(__file__))
pyproject_path = os.path.join(here, "..", "pyproject.toml")

with open(pyproject_path, "r") as f:
content = f.read()

# Regex to find the version number
match = re.search(r'^version\s*=\s*"([^"]+)"', content, re.M)
if match:
return match.group(1)
raise RuntimeError("Version not found in pyproject.toml.")


class DeprecatedBackend:
def __init__(self, backend):
self._backend = backend
Expand All @@ -36,7 +53,7 @@ def __getattr__(self, attr):
return getattr(self._backend, attr)


# Wrapping the backend instance with the DeprecatedBackend
# This way of accessing the backend is deprecated
K = DeprecatedBackend(opt)
ops = DeprecatedBackend(opt)

Expand All @@ -56,8 +73,14 @@ def __getattr__(self, attr):

import_sif = Graph.from_sif

__version__ = "1.0.0.dev0"
try:
# Python 3.8 and newer
from importlib.metadata import version
except ImportError:
# Python < 3.8
from importlib_metadata import version

__version__ = version("corneto")

sys.modules.update({f"{__name__}.{m}": globals()[m] for m in ["pl"]})

Expand Down
76 changes: 59 additions & 17 deletions corneto/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,9 +502,10 @@ def plot(self, **kwargs):
from corneto._settings import LOGGER
from corneto._util import supports_html

LOGGER.warning(f"SVG+XML rendering failed: {e}.")
LOGGER.debug(f"SVG+XML rendering failed: {e}.")
# Detect if HTML support
if supports_html():
LOGGER.debug("Falling back to Viz.js rendering.")
from corneto.contrib._util import dot_vizjs_html

class _VizJS:
Expand All @@ -513,8 +514,39 @@ def _repr_html_(self):

return _VizJS()
else:
LOGGER.debug("HTML rendering not supported.")
raise e

def plot_values(
self, vertex_values=None, edge_values=None, vertex_props=None, edge_props=None
):
from corneto._plotting import (
create_graphviz_edge_attributes,
create_graphviz_vertex_attributes,
)

vertex_props = vertex_props or {}
edge_props = edge_props or {}
vertex_drawing_props = None
if vertex_values is not None:
# Check if vertices has an attribute value to do vertices.value
if hasattr(vertex_values, "value"):
vertex_values = vertex_values.value
vertex_drawing_props = create_graphviz_vertex_attributes(
list(self.V), vertex_values=vertex_values, **vertex_props
)
edge_drawing_props = None
if edge_values is not None:
if hasattr(edge_values, "value"):
edge_values = edge_values.value
edge_drawing_props = create_graphviz_edge_attributes(
edge_values=edge_values, **edge_props
)
return self.plot(
custom_edge_attr=edge_drawing_props,
custom_vertex_attr=vertex_drawing_props,
)

def to_graphviz(self, **kwargs):
from corneto._plotting import to_graphviz

Expand Down Expand Up @@ -552,15 +584,18 @@ def from_vertex_incidence(
def save(self, filename: str, compressed: Optional[bool] = True) -> None:
import pickle

if not filename:
raise ValueError("Filename must not be empty.")

if not filename.endswith(".pkl"):
filename += ".pkl"

if compressed:
import gzip
import lzma

if not filename.endswith(".gz"):
filename += ".gz"
with gzip.open(filename, "wb") as f:
if not filename.endswith(".xz"):
filename += ".xz"
with lzma.open(filename, "wb", preset=9) as f:
pickle.dump(self, f)
else:
with open(filename, "wb") as f:
Expand All @@ -573,20 +608,27 @@ def load(filename: str) -> "BaseGraph":
if filename.endswith(".gz"):
import gzip

with open(filename, "rb") as f:
first_two_bytes = f.read(2)
if first_two_bytes != b"\x1f\x8b":
from corneto._settings import LOGGER
opener = gzip.open
elif filename.endswith(".bz2"):
import bz2

LOGGER.warning(
f"""The file {filename} has a .gz extension but does not
appear to be a valid gzip file."""
)
with gzip.open(filename, "rb") as f:
return pickle.load(f)
opener = bz2.open
elif filename.endswith(".xz"):
import lzma

opener = lzma.open
elif filename.endswith(".zip"):
import zipfile

def opener(file, mode="r"):
# Supports only reading the first file in a zip archive
with zipfile.ZipFile(file, "r") as z:
return z.open(z.namelist()[0], mode=mode)
else:
with open(filename, "rb") as f:
return pickle.load(f)
opener = open

with opener(filename, "rb") as f:
return pickle.load(f)

def reachability_analysis(
self,
Expand Down
63 changes: 51 additions & 12 deletions corneto/_plotting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Literal, Optional
from typing import Any, Dict, List, Literal, Optional, Union

import numpy as np

Expand All @@ -22,16 +22,9 @@ def vertex_style(
positive_color: str = "firebrick4",
):
v_values = np.array(P.expr[vertex_var].value)
vertex_attrs = dict()
for vn, v in zip(G.V, v_values):
vertex_attrs[vn] = dict()
if v > 0:
vertex_attrs[vn]["color"] = positive_color
vertex_attrs[vn]["penwidth"] = "2"
elif v < 0:
vertex_attrs[vn]["color"] = negative_color
vertex_attrs[vn]["penwidth"] = "2"
return vertex_attrs
return create_graphviz_vertex_attributes(
G.V, v_values, negative_color=negative_color, positive_color=positive_color
)


def edge_style(
Expand All @@ -42,7 +35,29 @@ def edge_style(
negative_color: str = "dodgerblue4",
positive_color: str = "firebrick4",
):
e_values = np.array(P.expr[edge_var].value)
e_values = P.expr[edge_var].value
if e_values is None:
raise ValueError(
f"""Variable {edge_var} in the problem, but values are None.
Has the problem been solved?"""
)
return create_graphviz_edge_attributes(
e_values,
max_edge_width=max_edge_width,
min_edge_width=min_edge_width,
negative_color=negative_color,
positive_color=positive_color,
)


def create_graphviz_edge_attributes(
edge_values: Union[List, np.ndarray],
max_edge_width: float = 5,
min_edge_width: float = 0.25,
negative_color: str = "dodgerblue4",
positive_color: str = "firebrick4",
):
e_values = np.array(edge_values)
edge_attrs = dict()
for i, v in enumerate(e_values):
if abs(v) > 0:
Expand All @@ -59,6 +74,30 @@ def edge_style(
return edge_attrs


def create_graphviz_vertex_attributes(
graph_vertices: List,
vertex_values: Union[List, np.ndarray],
negative_color: str = "dodgerblue4",
positive_color: str = "firebrick4",
):
v_values = np.array(vertex_values)
if len(v_values) != len(graph_vertices):
raise ValueError(
f"""Length of vertex_values ({len(v_values)}) does not match the number
of vertices ({len(graph_vertices)})"""
)
vertex_attrs = dict()
for vn, v in zip(graph_vertices, v_values):
vertex_attrs[vn] = dict()
if v > 0:
vertex_attrs[vn]["color"] = positive_color
vertex_attrs[vn]["penwidth"] = "2"
elif v < 0:
vertex_attrs[vn]["color"] = negative_color
vertex_attrs[vn]["penwidth"] = "2"
return vertex_attrs


def flow_style(
P,
max_edge_width: float = 5,
Expand Down
Loading

0 comments on commit 2932c2b

Please sign in to comment.