In [39]:
import argparse
import glob
import multiprocessing as mp
import os
import io
import sys

import tempfile
import time
import warnings

import cv2
import numpy as np
import tqdm
import torch
from PIL import Image
import zstd

from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.utils.logger import setup_logger

from mask_former import add_mask_former_config
from demo.predictor import VisualizationDemo

In [40]:
CONFIG = "configs/mapillary-vistas-65-v2/maskformer_panoptic_swin_base_transfer.yaml"
INPUT = ["exs/*.jpg"]
OUTPUT = "exs/lol"
WEIGHTS = "output/model_final.pth"
QUANTIZATION = 40

In [41]:
from mask_former.data.datasets.register_mapillary_vistas import MAPILLARY_VISTAS_SEM_SEG_CATEGORIES
categories = [category["name"] for category in MAPILLARY_VISTAS_SEM_SEG_CATEGORIES]

In [42]:
def setup_cfg(args):
    # load config from file and command-line arguments
    cfg = get_cfg()
    add_deeplab_config(cfg)
    add_mask_former_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    return cfg

class Args:
    def __init__(self, config, input, output, opts):
        self.config_file, self.input, self.output, self.opts = config, input, output, opts

In [43]:
args = Args(CONFIG, INPUT, OUTPUT, opts=["MODEL.WEIGHTS", WEIGHTS])
setup_logger(name="fvcore")
logger = setup_logger()

In [44]:
demo = VisualizationDemo(setup_cfg(args))

[32m[10/28 23:47:16 fvcore.common.checkpoint]: [0mLoading checkpoint from output/model_final.pth


In [49]:
input_images = glob.glob(os.path.expanduser(args.input[0]))

for path in input_images:
    start_time = time.time()
    img = read_image(path, format="BGR")
    print("\n", path, round(time.time()-start_time, 3))
    
    start_time = time.time()
    predictions, segmented = demo.run_on_image(img, only_prediction=True)
    print(segmented.dtype, round(time.time()-start_time, 3))
    
    
    start_time = time.time()
    img = Image.fromarray(np.moveaxis(segmented.numpy(), 0, -1))
    img_byte_arr = io.BytesIO()
    img.save(img_byte_arr, format="PNG")
    img_byte_arr = img_byte_arr.getvalue()
    print(sys.getsizeof(img_byte_arr), round(time.time()-start_time, 3))
    
    start_time = time.time()
    quantized = (predictions["sem_seg"] * QUANTIZATION).to(torch.uint8)
    a = quantized.cpu().numpy().copy()
    quantized[:, 1:, :] = torch.diff(quantized, dim=1)
    quantized = quantized.cpu().numpy()
    stream = io.BytesIO()
    np.save(stream, quantized)
    binary = stream.getvalue()
    print(len(binary), round(time.time()-start_time, 3))
    
    compressed = zstd.compress(binary, 7)
    print(len(compressed), round(time.time()-start_time, 3))
    
    quantized = np.cumsum(quantized, axis=1, dtype=np.uint8)
    print(np.allclose(quantized, a))


 exs/a1.jpg 0.008
torch.uint8 0.1
9504 0.01
53248128 0.052
118572 0.071
True

 exs/a2.jpg 0.008
torch.uint8 0.091
9687 0.009
53248128 0.049
108603 0.068
True

 exs/a3.jpg 0.011
torch.uint8 0.092
15310 0.009
53248128 0.051
197946 0.074
True

 exs/a4.jpg 0.008
torch.uint8 0.09
14241 0.009
53248128 0.051
207969 0.074
True

 exs/k1.jpg 0.005
torch.uint8 0.093
8018 0.005
34078848 0.034
88369 0.052
True

 exs/k2.jpg 0.005
torch.uint8 0.093
7545 0.005
34078848 0.034
80105 0.05
True

 exs/k3.jpg 0.005
torch.uint8 0.094
11708 0.006
34078848 0.035
147653 0.055
True

 exs/k4.jpg 0.006
torch.uint8 0.102
10867 0.006
34078848 0.033
148876 0.052


KeyboardInterrupt: 

In [9]:
"""
    rounded = (predictions["sem_seg"] * 40).to(torch.uint8)
    rounded2 = rounded.cpu().numpy()
    rounded[:, 1:, :] = torch.diff(rounded, axis=1)
    rounded = rounded.cpu().numpy()
    
    t = time.time()
    stream = io.BytesIO()
    np.save(stream, rounded)
    binary = stream.getvalue()
    compressed = zstd.compress(binary, 12)
    print(len(compressed), round(time.time()-t, 3))
    
    binary = zstd.decompress(compressed)
    stream = io.BytesIO(binary)
    rounded = np.load(stream, allow_pickle=True)
    print(rounded.shape)
    rounded = np.cumsum(rounded, axis=1, dtype=np.uint8)
    print(np.allclose(rounded, rounded2))
"""

'\n    rounded = (predictions["sem_seg"] * 40).to(torch.uint8)\n    rounded2 = rounded.cpu().numpy()\n    rounded[:, 1:, :] = torch.diff(rounded, axis=1)\n    rounded = rounded.cpu().numpy()\n    \n    t = time.time()\n    stream = io.BytesIO()\n    np.save(stream, rounded)\n    binary = stream.getvalue()\n    compressed = zstd.compress(binary, 12)\n    print(len(compressed), round(time.time()-t, 3))\n    \n    binary = zstd.decompress(compressed)\n    stream = io.BytesIO(binary)\n    rounded = np.load(stream, allow_pickle=True)\n    print(rounded.shape)\n    rounded = np.cumsum(rounded, axis=1, dtype=np.uint8)\n    print(np.allclose(rounded, rounded2))\n'

In [201]:
    """
    t = time.time()
    rounded = (predictions["sem_seg"] * 40).to(torch.uint8).cpu().numpy()
    size = 0
    for a in rounded:
        img = Image.fromarray(a)
        # img = img.resize((320, 640), Image.BILINEAR)
        img_byte_arr = io.BytesIO()
        img.save(img_byte_arr, format='PNG')
        img_byte_arr = img_byte_arr.getvalue()
        #print(sys.getsizeof(img_byte_arr))
        if sys.getsizeof(img_byte_arr) < 908:
            continue
        size += sys.getsizeof(img_byte_arr)
    print(size, round(time.time()-t, 2))
    """

924301
