Skip to content

Commit

Permalink
move array_to_imagestr function to be part of public API (#2879)
Browse files Browse the repository at this point in the history
* move array_to_imagestr function to be part of public API

* renamed function
  • Loading branch information
emmanuelle committed Nov 17, 2020
1 parent fa9500b commit 9c9b98e
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 76 deletions.
75 changes: 75 additions & 0 deletions packages/python/plotly/_plotly_utils/data_utils.py
@@ -0,0 +1,75 @@
from io import BytesIO
import base64
from .png import Writer, from_array

try:
from PIL import Image

pil_imported = True
except ImportError:
pil_imported = False


def image_array_to_data_uri(img, backend="pil", compression=4, ext="png"):
"""Converts a numpy array of uint8 into a base64 png or jpg string.
Parameters
----------
img: ndarray of uint8
array image
backend: str
'auto', 'pil' or 'pypng'. If 'auto', Pillow is used if installed,
otherwise pypng.
compression: int, between 0 and 9
compression level to be passed to the backend
ext: str, 'png' or 'jpg'
compression format used to generate b64 string
"""
# PIL and pypng error messages are quite obscure so we catch invalid compression values
if compression < 0 or compression > 9:
raise ValueError("compression level must be between 0 and 9.")
alpha = False
if img.ndim == 2:
mode = "L"
elif img.ndim == 3 and img.shape[-1] == 3:
mode = "RGB"
elif img.ndim == 3 and img.shape[-1] == 4:
mode = "RGBA"
alpha = True
else:
raise ValueError("Invalid image shape")
if backend == "auto":
backend = "pil" if pil_imported else "pypng"
if ext != "png" and backend != "pil":
raise ValueError("jpg binary strings are only available with PIL backend")

if backend == "pypng":
ndim = img.ndim
sh = img.shape
if ndim == 3:
img = img.reshape((sh[0], sh[1] * sh[2]))
w = Writer(
sh[1], sh[0], greyscale=(ndim == 2), alpha=alpha, compression=compression
)
img_png = from_array(img, mode=mode)
prefix = "data:image/png;base64,"
with BytesIO() as stream:
w.write(stream, img_png.rows)
base64_string = prefix + base64.b64encode(stream.getvalue()).decode("utf-8")
else: # pil
if not pil_imported:
raise ImportError(
"pillow needs to be installed to use `backend='pil'. Please"
"install pillow or use `backend='pypng'."
)
pil_img = Image.fromarray(img)
if ext == "jpg" or ext == "jpeg":
prefix = "data:image/jpeg;base64,"
ext = "jpeg"
else:
prefix = "data:image/png;base64,"
ext = "png"
with BytesIO() as stream:
pil_img.save(stream, format=ext, compress_level=compression)
base64_string = prefix + base64.b64encode(stream.getvalue()).decode("utf-8")
return base64_string
File renamed without changes.
77 changes: 2 additions & 75 deletions packages/python/plotly/plotly/express/_imshow.py
@@ -1,94 +1,21 @@
import plotly.graph_objs as go
from _plotly_utils.basevalidators import ColorscaleValidator
from ._core import apply_default_cascade
from io import BytesIO
import base64
from .imshow_utils import rescale_intensity, _integer_ranges, _integer_types
import pandas as pd
from .png import Writer, from_array
import numpy as np
from plotly.utils import image_array_to_data_uri

try:
import xarray

xarray_imported = True
except ImportError:
xarray_imported = False
try:
from PIL import Image

pil_imported = True
except ImportError:
pil_imported = False

_float_types = []


def _array_to_b64str(img, backend="pil", compression=4, ext="png"):
"""Converts a numpy array of uint8 into a base64 png string.
Parameters
----------
img: ndarray of uint8
array image
backend: str
'auto', 'pil' or 'pypng'. If 'auto', Pillow is used if installed,
otherwise pypng.
compression: int, between 0 and 9
compression level to be passed to the backend
ext: str, 'png' or 'jpg'
compression format used to generate b64 string
"""
# PIL and pypng error messages are quite obscure so we catch invalid compression values
if compression < 0 or compression > 9:
raise ValueError("compression level must be between 0 and 9.")
alpha = False
if img.ndim == 2:
mode = "L"
elif img.ndim == 3 and img.shape[-1] == 3:
mode = "RGB"
elif img.ndim == 3 and img.shape[-1] == 4:
mode = "RGBA"
alpha = True
else:
raise ValueError("Invalid image shape")
if backend == "auto":
backend = "pil" if pil_imported else "pypng"
if ext != "png" and backend != "pil":
raise ValueError("jpg binary strings are only available with PIL backend")

if backend == "pypng":
ndim = img.ndim
sh = img.shape
if ndim == 3:
img = img.reshape((sh[0], sh[1] * sh[2]))
w = Writer(
sh[1], sh[0], greyscale=(ndim == 2), alpha=alpha, compression=compression
)
img_png = from_array(img, mode=mode)
prefix = "data:image/png;base64,"
with BytesIO() as stream:
w.write(stream, img_png.rows)
base64_string = prefix + base64.b64encode(stream.getvalue()).decode("utf-8")
else: # pil
if not pil_imported:
raise ImportError(
"pillow needs to be installed to use `backend='pil'. Please"
"install pillow or use `backend='pypng'."
)
pil_img = Image.fromarray(img)
if ext == "jpg" or ext == "jpeg":
prefix = "data:image/jpeg;base64,"
ext = "jpeg"
else:
prefix = "data:image/png;base64,"
ext = "png"
with BytesIO() as stream:
pil_img.save(stream, format=ext, compress_level=compression)
base64_string = prefix + base64.b64encode(stream.getvalue()).decode("utf-8")
return base64_string


def _vectorize_zvalue(z, mode="max"):
alpha = 255 if mode == "max" else 0
if z is None:
Expand Down Expand Up @@ -422,7 +349,7 @@ def imshow(
for ch in range(img.shape[-1])
]
)
img_str = _array_to_b64str(
img_str = image_array_to_data_uri(
img_rescaled,
backend=binary_backend,
compression=binary_compression_level,
Expand Down
2 changes: 1 addition & 1 deletion packages/python/plotly/plotly/utils.py
Expand Up @@ -4,7 +4,7 @@
from pprint import PrettyPrinter

from _plotly_utils.utils import *

from _plotly_utils.data_utils import *

# Pretty printing
def _list_repr_elided(v, threshold=200, edgeitems=3, indent=0, width=80):
Expand Down

0 comments on commit 9c9b98e

Please sign in to comment.