JIT compatible way to convert base64 encoded image to a Tensor #6878

Open
anjali-chadha opened this issue Nov 1, 2022 · 3 comments

Comments

@anjali-chadha

Hello -
I have an image which is encoded as a base64 string. I want to take this string as an input and convert it to a Tensor.

How can I achieve this in a way that is TorchScript-compatible?

So far, I have tried a two-step process: 1) convert the base64-encoded image to a PIL Image using the Pillow library, 2) convert the PIL Image to a Tensor using the transforms provided in torchvision.
However, this approach is not TorchScript-compatible.

Any recommendations on how to do this in a TorchScript-compatible manner?

Thank you!

@anjali-chadha anjali-chadha changed the title JIT compatible way to convert base64 encoded image string to a Tensor JIT compatible way to convert base64 encoded image to a Tensor Nov 1, 2022
@mthrok
Contributor

mthrok commented Nov 7, 2022

The best I can think of is to use the underlying implementation of torchaudio's StreamReader.
It is not a public API, and the base64 data has to be decoded separately, outside of TorchScript.

import base64

import torch
from torchaudio.io import StreamReader


@torch.jit.script
def decode_png(data):
    s = torch.classes.torchaudio.ffmpeg_StreamReaderTensor(data, "png_pipe", None, 8046)
    s.add_video_stream(
        s.find_best_video_stream(),  # stream_index
        -1,  # frames_per_chunk
        3,  # buffer_chunk_size
        "format=pix_fmts=rgb24",  # filter_desc
        None,  # decoder
        None,  # decoder_option
        None,  # hw_accel
    )
    s.process_all_packets()
    img, = s.pop_chunks()
    return img


if __name__ == '__main__':

    # From https://codepen.io/jamiekane/pen/YayWOa
    data = b"iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBapySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnxBwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXrCDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQDry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPsgxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96CutRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOMOVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWquaZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYSUb3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6EhOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oWVeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmHrwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66PfyuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UNz8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII="

    data = base64.b64decode(data)

    data = torch.frombuffer(data, dtype=torch.uint8)
    img = decode_png(data)

    import matplotlib.pyplot as plt

    plt.imshow(img[0].permute(1, 2, 0))
    plt.show()

@YosuaMichael
Contributor

YosuaMichael commented Nov 17, 2022

Similar to @mthrok's approach, we can do this with the torchvision.io.image API for decoding images. However, the base64 string still needs to be decoded separately (I can't find a TorchScript-compatible way to decode base64).

import base64
import torch
from torchvision.io import image, ImageReadMode

raw_image_b64 = b"iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBapySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnxBwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXrCDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQDry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPsgxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96CutRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOMOVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWquaZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYSUb3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6EhOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oWVeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmHrwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66PfyuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UNz8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII="
raw_image_bytes = base64.b64decode(raw_image_b64)
raw_image_tensor = torch.frombuffer(raw_image_bytes, dtype=torch.uint8)

# TorchScript-compiled version of the decode_image function
decode_image_script = torch.jit.script(image.decode_image)

decoded_image = decode_image_script(raw_image_tensor, mode=ImageReadMode.RGB)

import matplotlib.pyplot as plt

plt.imshow(decoded_image.permute(1, 2, 0))
plt.show()
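
Both replies above leave the base64 step to eager Python. As a rough sketch of that pre-processing step, the helper below strips an optional data-URI prefix, decodes strictly, and checks the PNG signature before the bytes are handed to a scripted decoder. The names `b64_image_to_bytes` and `looks_like_png` are hypothetical, and the data-URI prefix is an assumption about how such strings often arrive; neither comes from torchvision or torchaudio.

```python
import base64
import binascii

# Every PNG file starts with these 8 bytes.
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def b64_image_to_bytes(payload: str) -> bytes:
    """Decode a base64 image string to raw bytes (hypothetical helper)."""
    # Assumption: the string may carry a prefix such as "data:image/png;base64,".
    if payload.startswith("data:"):
        _, _, payload = payload.partition(",")
    try:
        # validate=True rejects characters outside the base64 alphabet
        # instead of silently discarding them.
        return base64.b64decode(payload, validate=True)
    except binascii.Error as exc:
        raise ValueError("payload is not valid base64") from exc


def looks_like_png(raw: bytes) -> bool:
    """Cheap sanity check before passing the bytes to an image decoder."""
    return raw.startswith(PNG_SIGNATURE)
```

The returned bytes can then be wrapped with `torch.frombuffer(..., dtype=torch.uint8)` exactly as in the snippets above.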

@vadimkantorov

vadimkantorov commented Dec 8, 2022

I wonder if there's a simple way to translate SSE-vectorized base64 decoding into tensor ops: http://www.alfredklomp.com/programming/sse-base64/, https://r-libre.teluq.ca/1362/1/base64.pdf. This could probably be a feature request for core PyTorch as well...
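
To make the idea concrete, here is a minimal sketch of base64 decoding written only with tensor ops, so that it can be compiled with `torch.jit.script`. It is not the SSE algorithm from the links above, just a straightforward element-wise mapping. The function name `b64decode_tensor` is hypothetical, and the sketch assumes the standard alphabet, `=` padding, no whitespace, an input length that is a multiple of 4, and an input that is already a 1-D uint8 tensor of ASCII codes. It does not validate its input.

```python
import base64

import torch


@torch.jit.script
def b64decode_tensor(encoded: torch.Tensor) -> torch.Tensor:
    """Decode padded, standard-alphabet base64 given as a 1-D uint8 tensor of ASCII codes."""
    c = encoded.to(torch.int64)

    # Map each ASCII code to its 6-bit value; '=' and unknown characters map to 0.
    sextet = torch.zeros_like(c)
    sextet = torch.where((c >= 65) & (c <= 90), c - 65, sextet)    # 'A'-'Z' -> 0..25
    sextet = torch.where((c >= 97) & (c <= 122), c - 71, sextet)   # 'a'-'z' -> 26..51
    sextet = torch.where((c >= 48) & (c <= 57), c + 4, sextet)     # '0'-'9' -> 52..61
    sextet = torch.where(c == 43, torch.full_like(c, 62), sextet)  # '+' -> 62
    sextet = torch.where(c == 47, torch.full_like(c, 63), sextet)  # '/' -> 63

    # Each '=' stands for one byte that must be dropped from the output.
    n_pad = int((c == 61).sum())

    # Pack every group of four 6-bit values into one 24-bit word.
    quads = sextet.reshape(-1, 4)
    word = quads[:, 0] * 262144 + quads[:, 1] * 4096 + quads[:, 2] * 64 + quads[:, 3]

    # Split each 24-bit word into three bytes.
    b0 = torch.div(word, 65536, rounding_mode="floor")
    b1 = torch.remainder(torch.div(word, 256, rounding_mode="floor"), 256)
    b2 = torch.remainder(word, 256)

    out = torch.stack([b0, b1, b2], dim=1).reshape(-1).to(torch.uint8)
    return out[: out.numel() - n_pad]


if __name__ == "__main__":
    payload = b"hello TorchScript"
    # bytearray gives torch.frombuffer a writable buffer.
    encoded = torch.frombuffer(bytearray(base64.b64encode(payload)), dtype=torch.uint8)
    decoded = b64decode_tensor(encoded)
    assert bytes(decoded.tolist()) == payload
```

In principle the resulting uint8 tensor could be fed to a scripted image decoder such as the `decode_image` call shown earlier in this thread, which would keep the whole pipeline inside TorchScript once the base64 text is available as a tensor.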
