diff --git a/codes/dataops/augmennt/augmennt/functional.py b/codes/dataops/augmennt/augmennt/functional.py index 8d068f9d..c30fa4df 100644 --- a/codes/dataops/augmennt/augmennt/functional.py +++ b/codes/dataops/augmennt/augmennt/functional.py @@ -15,6 +15,7 @@ import collections import warnings +from ...colors import linear2srgb, srgb2linear from .common import preserve_channel_dim, preserve_shape from .common import _cv2_str2pad, _cv2_str2interpolation @@ -168,13 +169,14 @@ def resize(img, size, interpolation='BILINEAR'): raise TypeError('img should be numpy image. Got {}'.format(type(img))) if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)): raise TypeError('Got inappropriate size arg: {}'.format(size)) - + w, h, = size if isinstance(size, int): # h, w, c = img.shape #this would defeat the purpose of "size" - + if (w <= h and w == size) or (h <= w and h == size): return img + img = srgb2linear(img) if w < h: ow = size oh = int(size * h / w) @@ -184,9 +186,10 @@ def resize(img, size, interpolation='BILINEAR'): ow = int(size * w / h) output = cv2.resize(img, dsize=(ow, oh), interpolation=_cv2_str2interpolation[interpolation]) else: + img = srgb2linear(img) output = cv2.resize(img, dsize=(size[1], size[0]), interpolation=_cv2_str2interpolation[interpolation]) - - return output + + return linear2srgb(output) def scale(*args, **kwargs): diff --git a/codes/dataops/augmentations.py b/codes/dataops/augmentations.py index 8b681c6b..115e5de6 100644 --- a/codes/dataops/augmentations.py +++ b/codes/dataops/augmentations.py @@ -5,7 +5,8 @@ import numpy as np import dataops.common as util -from dataops.common import fix_img_channels, get_image_paths, read_img, np2tensor +from dataops.colors import linear2srgb, srgb2linear +from dataops.common import fix_img_channels, np2tensor from dataops.debug import * from dataops.imresize import resize as imresize # resize # imresize_np @@ -202,20 +203,24 @@ def __call__(self, img:np.ndarray) -> 
# Constants of the sRGB piecewise transfer function, shared by both
# directions of the conversion so they cannot drift apart.
_SRGB_OFFSET = 0.055                # "a" in the usual formulation
_SRGB_LINEAR_SLOPE = 12.92          # slope of the linear segment near black
_SRGB_GAMMA = 2.4                   # exponent of the power segment
_SRGB_DECODE_THRESHOLD = 0.04045    # segment breakpoint on encoded values
_SRGB_ENCODE_THRESHOLD = 0.0031308  # segment breakpoint on linear values


def srgb2linear(img):
    """Convert an sRGB-encoded image to linear RGB.

    Args:
        img: a ``torch.Tensor`` or a numpy array.

    Returns:
        * ``torch.Tensor`` input: a tensor with the sRGB decoding curve
          applied element-wise (values are expected in [0, 1]; no
          clamping or dtype conversion is performed on the result).
        * ``uint8`` numpy input: a new ``float32`` array in [0, 1]
          holding the linear values of ``img / 255``.
        * any other numpy dtype: an unmodified copy, because such arrays
          are assumed to already be linear RGB.

    A new object is always returned; the input is never modified.
    """
    if isinstance(img, torch.Tensor):
        # torch.where evaluates *both* branches, so the power segment is
        # also computed for elements that end up using the linear segment.
        # Clamping the base to the breakpoint keeps that unused result (and
        # its gradient) finite for out-of-range inputs, without changing
        # any selected value.
        base = (img.clamp(min=_SRGB_DECODE_THRESHOLD) + _SRGB_OFFSET) \
            / (1 + _SRGB_OFFSET)
        return torch.where(
            img <= _SRGB_DECODE_THRESHOLD,
            img / _SRGB_LINEAR_SLOPE,
            torch.pow(base, _SRGB_GAMMA))

    if img.dtype == np.uint8:
        encoded = np.float32(img) / 255.0
        return np.where(
            encoded <= _SRGB_DECODE_THRESHOLD,
            encoded / _SRGB_LINEAR_SLOPE,
            np.power(
                (encoded + _SRGB_OFFSET) / (1 + _SRGB_OFFSET), _SRGB_GAMMA))

    # Non-uint8 arrays are treated as already linear: hand back a copy so
    # the "always returns a new array" contract still holds.
    return img.copy()


def linear2srgb(img):
    """Convert a linear RGB image to the sRGB colour space.

    Args:
        img: a ``torch.Tensor`` or a numpy array.

    Returns:
        * ``torch.Tensor`` input: a tensor with the sRGB encoding curve
          applied element-wise (values are expected in [0, 1]; no
          clamping or dtype conversion is performed on the result).
        * ``float32`` / ``float64`` numpy input: a new ``uint8`` array in
          [0, 255]; the input is clipped to [0, 1] before encoding and
          the result is rounded to the nearest integer.
        * any other numpy dtype (e.g. ``uint8``): an unmodified copy,
          because such arrays are assumed to already be sRGB.

    A new object is always returned; the input is never modified.
    """
    if isinstance(img, torch.Tensor):
        # The derivative of x ** (1 / 2.4) is infinite at x == 0 and the
        # value is NaN for x < 0. torch.where still back-propagates a zero
        # gradient through the unselected branch, and 0 * inf == NaN, so
        # black pixels would poison the gradients. Clamping the base to the
        # breakpoint leaves every selected value untouched while keeping
        # the unselected branch finite.
        base = img.clamp(min=_SRGB_ENCODE_THRESHOLD)
        return torch.where(
            img <= _SRGB_ENCODE_THRESHOLD,
            img * _SRGB_LINEAR_SLOPE,
            (1 + _SRGB_OFFSET) * torch.pow(base, 1 / _SRGB_GAMMA)
            - _SRGB_OFFSET)

    if img.dtype == np.float32 or img.dtype == np.float64:
        linear = np.clip(img, 0.0, 1.0)
        encoded = np.where(
            linear <= _SRGB_ENCODE_THRESHOLD,
            linear * _SRGB_LINEAR_SLOPE,
            (1 + _SRGB_OFFSET) * np.power(linear, 1.0 / _SRGB_GAMMA)
            - _SRGB_OFFSET)
        # Scale to the 8-bit range, guard against rounding overshoot, then
        # round to nearest before the integer cast (a bare cast truncates).
        scaled = np.clip(encoded * 255, 0.0, 255)
        return np.around(scaled).astype(np.uint8)

    # Integer arrays are treated as already sRGB-encoded.
    return img.copy()