# rgb_lab_formulation_pytorch.py
# NOTE(review): this header replaces scraped GitHub page artifacts (file
# banner and line-number gutter) that made the module invalid Python.
import cv2
import torch
import numpy as np
from scipy import misc
from PIL import Image
def preprocess_lab(lab):
    """Split a Lab image of shape (H, W, 3) into normalized channels.

    L is mapped from [0, 100] to [-1, 1]; a and b are mapped from roughly
    [-110, 110] to roughly [-1, 1].

    Returns a list [L_chan, a_chan, b_chan] of (H, W) tensors.
    """
    lightness, chroma_a, chroma_b = torch.unbind(lab, dim=2)
    normalized = [
        lightness / 50.0 - 1.0,  # [0, 100] -> [-1, 1]
        chroma_a / 110.0,        # ~[-110, 110] -> ~[-1, 1]
        chroma_b / 110.0,
    ]
    return normalized
def deprocess_lab(L_chan, a_chan, b_chan):
    """Inverse of preprocess_lab: recombine normalized channels into Lab.

    L goes from [-1, 1] back to [0, 100]; a and b are scaled back to
    roughly [-110, 110]. Channels are stacked along dim=2.

    TODO: when deprocessing a *batch* of images the stack axis would be 3
    instead of 2 (individual images are processed, batches deprocessed).
    """
    denormalized = [
        (L_chan + 1) / 2.0 * 100.0,  # [-1, 1] -> [0, 100]
        a_chan * 110.0,
        b_chan * 110.0,
    ]
    return torch.stack(denormalized, dim=2)
def rgb_to_lab(srgb):
    """Convert an sRGB image with values in [0, 1] to CIE Lab.

    Args:
        srgb: tensor of shape (..., 3) with sRGB values in [0, 1].

    Returns:
        float32 tensor of the same shape, (L, a, b) in the last dimension:
        L in [0, 100], a/b roughly in [-110, 110].

    Fix: device now follows the input instead of hard-coded ``.cuda()``
    (which crashed on CPU-only machines), and pixels are cast to float32
    once so ``torch.mm`` never sees mixed dtypes (e.g. float64 input from
    numpy against float32 matrices).
    """
    device = srgb.device
    srgb_pixels = torch.reshape(srgb, [-1, 3]).float()

    # Undo sRGB gamma: linear segment below 0.04045, power curve above.
    linear_mask = (srgb_pixels <= 0.04045).float()
    exponential_mask = (srgb_pixels > 0.04045).float()
    rgb_pixels = (srgb_pixels / 12.92) * linear_mask \
        + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask

    # Linear RGB -> XYZ; rows are the R/G/B contributions to X/Y/Z columns.
    rgb_to_xyz = torch.tensor([
        #    X         Y         Z
        [0.412453, 0.212671, 0.019334],  # R
        [0.357580, 0.715160, 0.119193],  # G
        [0.180423, 0.072169, 0.950227],  # B
    ], dtype=torch.float32, device=device)
    xyz_pixels = torch.mm(rgb_pixels, rgb_to_xyz)

    # Normalize by the D65 white point, then apply the Lab transfer curve.
    white_scale = torch.tensor([1 / 0.950456, 1.0, 1 / 1.088754],
                               dtype=torch.float32, device=device)
    xyz_normalized_pixels = torch.mul(xyz_pixels, white_scale)
    epsilon = 6.0 / 29.0
    linear_mask = (xyz_normalized_pixels <= epsilon ** 3).float()
    exponential_mask = (xyz_normalized_pixels > epsilon ** 3).float()
    # The tiny offset keeps the cube root finite/differentiable at zero.
    fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon ** 2) + 4.0 / 29.0) * linear_mask \
        + ((xyz_normalized_pixels + 0.000001) ** (1.0 / 3.0)) * exponential_mask

    # (fx, fy, fz) -> (L, a, b): L = 116*fy - 16, a = 500*(fx-fy), b = 200*(fy-fz).
    fxfyfz_to_lab = torch.tensor([
        #   l        a       b
        [0.0,     500.0,    0.0],   # fx
        [116.0,  -500.0,  200.0],   # fy
        [0.0,       0.0, -200.0],   # fz
    ], dtype=torch.float32, device=device)
    lab_offset = torch.tensor([-16.0, 0.0, 0.0], dtype=torch.float32, device=device)
    lab_pixels = torch.mm(fxfyfz_pixels, fxfyfz_to_lab) + lab_offset
    return torch.reshape(lab_pixels, srgb.shape)
def lab_to_rgb(lab):
    """Convert a CIE Lab image back to sRGB with values in [0, 1].

    Args:
        lab: tensor of shape (..., 3) with (L, a, b) in the last dimension.

    Returns:
        float32 tensor of the same shape with sRGB values clipped to [0, 1].

    Fix: device now follows the input instead of hard-coded ``.cuda()``
    (which crashed on CPU-only machines), and pixels are cast to float32
    once so ``torch.mm`` never sees mixed dtypes.
    """
    device = lab.device
    lab_pixels = torch.reshape(lab, [-1, 3]).float()

    # (L, a, b) -> (fx, fy, fz): fy = (L+16)/116, fx = fy + a/500, fz = fy - b/200.
    lab_to_fxfyfz = torch.tensor([
        #     fx        fy        fz
        [1 / 116.0, 1 / 116.0, 1 / 116.0],  # l
        [1 / 500.0,       0.0,       0.0],  # a
        [0.0,             0.0, -1 / 200.0], # b
    ], dtype=torch.float32, device=device)
    lab_offset = torch.tensor([16.0, 0.0, 0.0], dtype=torch.float32, device=device)
    fxfyfz_pixels = torch.mm(lab_pixels + lab_offset, lab_to_fxfyfz)

    # Invert the Lab transfer curve back to normalized XYZ.
    epsilon = 6.0 / 29.0
    linear_mask = (fxfyfz_pixels <= epsilon).float()
    exponential_mask = (fxfyfz_pixels > epsilon).float()
    # The tiny offset mirrors the forward conversion's epsilon guard.
    xyz_pixels = (3 * epsilon ** 2 * (fxfyfz_pixels - 4 / 29.0)) * linear_mask \
        + ((fxfyfz_pixels + 0.000001) ** 3) * exponential_mask

    # Denormalize for the D65 white point.
    white_point = torch.tensor([0.950456, 1.0, 1.088754],
                               dtype=torch.float32, device=device)
    xyz_pixels = torch.mul(xyz_pixels, white_point)

    # XYZ -> linear RGB.
    xyz_to_rgb = torch.tensor([
        #     r           g           b
        [3.2404542, -0.9692660,  0.0556434],  # x
        [-1.5371385, 1.8760108, -0.2040259],  # y
        [-0.4985314, 0.0415560,  1.0572252],  # z
    ], dtype=torch.float32, device=device)
    rgb_pixels = torch.mm(xyz_pixels, xyz_to_rgb)

    # Clip to [0, 1] so a slightly negative number can't break the gamma step.
    rgb_pixels = torch.clamp(rgb_pixels, 0.0, 1.0)

    # Apply sRGB gamma: linear segment below 0.0031308, power curve above.
    linear_mask = (rgb_pixels <= 0.0031308).float()
    exponential_mask = (rgb_pixels > 0.0031308).float()
    srgb_pixels = (rgb_pixels * 12.92) * linear_mask \
        + (((rgb_pixels + 0.000001) ** (1 / 2.4) * 1.055) - 0.055) * exponential_mask
    return torch.reshape(srgb_pixels, lab.shape)
# test
'''
img = cv2.imread('data/test_rgb.jpg',1)/ 255.0
img = img[:, :, (2, 1, 0)]
#img = misc.imread('data/test_rgb.jpg')/255.0
img = torch.from_numpy(img).cuda()
lab = rgb_to_lab(img)
L_chan, a_chan, b_chan = preprocess_lab(lab)
lab = deprocess_lab(L_chan, a_chan, b_chan)
true_image = lab_to_rgb(lab)
true_image = np.round(true_image.cpu()* 255.0)
true_image = np.uint8(true_image)
#np.save('torch.npy',np.array(img.cpu()))
#conv_img = Image.fromarray(true_image, 'RGB')
#conv_img.save('converted_test_pytorch.jpg')
true_image = true_image[:, :, (2, 1, 0)]
cv2.imwrite('pytorch.jpg',true_image)
#import pdb; pdb.set_trace()
'''