Do you plan to release your training code for the raw image dataset (SIDD)? #7
Sorry for my late reply. Here is the code for training Neighbor2Neighbor on the SIDD raw data:
import glob
import os

import numpy as np
import torch
import torch.utils.data as data
from scipy.io import loadmat

# Counter used to give every call of get_generator() a new seed.
operation_seed_counter = 0


def get_generator(device):
    # Not shown in the original comment: assumed to mirror the helper of the
    # same name in the repository's train.py (a freshly seeded generator per
    # call); the device argument is added here so it also works on CPU.
    global operation_seed_counter
    operation_seed_counter += 1
    generator = torch.Generator(device=device)
    generator.manual_seed(operation_seed_counter)
    return generator


def space_to_depth(x, block_size):
    # (N, C, H, W) -> (N, C * block_size**2, H // block_size, W // block_size)
    n, c, h, w = x.size()
    unfolded_x = torch.nn.functional.unfold(x, block_size, stride=block_size)
    return unfolded_x.view(n, c * block_size**2, h // block_size,
                           w // block_size)


def depth_to_space(x, block_size):
    # inverse of space_to_depth
    return torch.nn.functional.pixel_shuffle(x, block_size)


def generate_mask_pair(img):
    # prepare flat masks over N x H/2 x W/2 cells x 4 positions per 2x2 cell
    # (the same masks are applied to every channel)
    n, c, h, w = img.shape
    mask1 = torch.zeros(size=(n * h // 2 * w // 2 * 4, ),
                        dtype=torch.bool,
                        device=img.device)
    mask2 = torch.zeros(size=(n * h // 2 * w // 2 * 4, ),
                        dtype=torch.bool,
                        device=img.device)
    # prepare random mask pairs: each row is an ordered pair of adjacent
    # positions in a 2x2 cell (0: top-left, 1: top-right, 2: bottom-left,
    # 3: bottom-right)
    idx_pair = torch.tensor(
        [[0, 1], [0, 2], [1, 3], [2, 3], [1, 0], [2, 0], [3, 1], [3, 2]],
        dtype=torch.int64,
        device=img.device)
    rd_idx = torch.zeros(size=(n * h // 2 * w // 2, ),
                         dtype=torch.int64,
                         device=img.device)
    # pick one of the 8 pairs for every cell
    torch.randint(low=0,
                  high=8,
                  size=(n * h // 2 * w // 2, ),
                  generator=get_generator(img.device),
                  out=rd_idx)
    rd_pair_idx = idx_pair[rd_idx]
    # offset each cell's positions to its slot in the flat mask
    rd_pair_idx += torch.arange(start=0,
                                end=n * h // 2 * w // 2 * 4,
                                step=4,
                                dtype=torch.int64,
                                device=img.device).reshape(-1, 1)
    # get masks
    mask1[rd_pair_idx[:, 0]] = 1
    mask2[rd_pair_idx[:, 1]] = 1
    return mask1, mask2


def generate_subimages(img, mask):
    # keep the one masked pixel of every 2x2 cell -> (N, C, H/2, W/2)
    n, c, h, w = img.shape
    subimage = torch.zeros(n,
                           c,
                           h // 2,
                           w // 2,
                           dtype=img.dtype,
                           layout=img.layout,
                           device=img.device)
    # per channel
    for i in range(c):
        img_per_channel = space_to_depth(img[:, i:i + 1, :, :], block_size=2)
        img_per_channel = img_per_channel.permute(0, 2, 3, 1).reshape(-1)
        subimage[:, i:i + 1, :, :] = img_per_channel[mask].reshape(
            n, h // 2, w // 2, 1).permute(0, 3, 1, 2)
    return subimage


class DataLoader_SIDD_Medium_Raw(data.Dataset):
    def __init__(self, data_dir):
        super().__init__()
        self.data_dir = data_dir
        # get images path
        self.train_fns = glob.glob(os.path.join(self.data_dir, "*"))
        self.train_fns.sort()
        print('fetch {} samples for training'.format(len(self.train_fns)))

    def __getitem__(self, index):
        # fetch image (stored under key "x" in each .mat file)
        fn = self.train_fns[index]
        im = loadmat(fn)["x"]
        # add a leading channel axis
        im = im[np.newaxis, :, :]
        im = torch.from_numpy(im)
        return im

    def __len__(self):
        return len(self.train_fns)


def get_SIDD_validation(dataset_dir):
    # noisy and ground-truth raw validation blocks, each of shape
    # (num_img, num_block, H, W)
    val_data_dict = loadmat(
        os.path.join(dataset_dir, "ValidationNoisyBlocksRaw.mat"))
    val_data_noisy = val_data_dict['ValidationNoisyBlocksRaw']
    val_data_dict = loadmat(
        os.path.join(dataset_dir, 'ValidationGtBlocksRaw.mat'))
    val_data_gt = val_data_dict['ValidationGtBlocksRaw']
    num_img, num_block, _, _ = val_data_gt.shape
    return num_img, num_block, val_data_noisy, val_data_gt
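The helpers above only build the two sub-images. For reference, the regularized objective described in the Neighbor2Neighbor paper can be written as a standalone function. This is a sketch, not code from the repository: the argument names, the use of a mean-squared error, and the name `gamma` for the regularizer weight are our choices. `out_sub1` is the network output on the first noisy sub-image, `noisy_sub2` the second noisy sub-image, and `den_sub1`/`den_sub2` are the two sub-images of the full denoised image, which are treated as constants (no gradient).

```python
import torch


def neighbor2neighbor_loss(out_sub1, noisy_sub2, den_sub1, den_sub2, gamma):
    """Reconstruction term plus gamma-weighted regularizer (sketch).

    out_sub1:   f(g1(y)), network output for the first noisy sub-image
    noisy_sub2: g2(y), the second noisy sub-image
    den_sub1:   g1(f(y)), first sub-image of the full denoised image
    den_sub2:   g2(f(y)), second sub-image of the full denoised image
    gamma:      weight of the regularizer (hypothetical name)
    """
    # The full-image denoising pass is a constant target: stop its gradient.
    den_sub1 = den_sub1.detach()
    den_sub2 = den_sub2.detach()
    diff = out_sub1 - noisy_sub2
    exp_diff = den_sub1 - den_sub2
    loss_rec = torch.mean(diff ** 2)
    loss_reg = torch.mean((diff - exp_diff) ** 2)
    return loss_rec + gamma * loss_reg
```

With the helpers above, the inputs would come from `mask1, mask2 = generate_mask_pair(noisy)`, `generate_subimages(noisy, mask1)` and `generate_subimages(noisy, mask2)`, plus, under `torch.no_grad()`, the same two calls applied to `network(noisy)`, where `network` stands for the denoiser being trained.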
I still have a small question: is it correct to generate the sub-images on the packed 4-channel raw images?
Yes, on the packed 4-channel raw images.
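A small synthetic check illustrates why the sub-sampler is applied to the packed image. In a Bayer mosaic the four pixels of each 2x2 cell sit under different color filters, so neighbors picked directly on the mosaic always come from two different CFA sites; after packing, each channel holds a single site, so every neighbor pair shares a color. The sketch below is a compact re-implementation of the same neighbor sub-sampling idea (same pair table as `generate_mask_pair`), not the repository's code. Packing uses `torch.nn.functional.pixel_unshuffle`, which orders channels the same way as `space_to_depth(x, 2)`, and the site labels are synthetic.

```python
import torch
import torch.nn.functional as F

# Same adjacent-position pairs inside a 2x2 cell as generate_mask_pair
# (0: top-left, 1: top-right, 2: bottom-left, 3: bottom-right).
PAIRS = torch.tensor([[0, 1], [0, 2], [1, 3], [2, 3],
                      [1, 0], [2, 0], [3, 1], [3, 2]])


def neighbor_subsample(img, generator=None):
    """Return two (N, C, H/2, W/2) sub-images of adjacent pixels (sketch)."""
    n, c, h, w = img.shape
    h2, w2 = h // 2, w // 2
    # (N, C*4, H/2, W/2) -> (N, C, 4, H/2, W/2): the 4 positions of each cell
    cells = F.pixel_unshuffle(img, 2).reshape(n, c, 4, h2, w2)
    # one random pair per cell, shared by all channels
    choice = torch.randint(0, 8, (n, 1, 1, h2, w2), generator=generator)
    pair = PAIRS[choice]  # (N, 1, 1, H/2, W/2, 2)
    idx1 = pair[..., 0].expand(n, c, 1, h2, w2)
    idx2 = pair[..., 1].expand(n, c, 1, h2, w2)
    sub1 = cells.gather(2, idx1).squeeze(2)
    sub2 = cells.gather(2, idx2).squeeze(2)
    return sub1, sub2


# Synthetic 8x8 mosaic whose value is the CFA site of each pixel:
# 0 = (even row, even col), 1 = (even, odd), 2 = (odd, even), 3 = (odd, odd).
h, w = 8, 8
site = ((torch.arange(h).view(h, 1) % 2) * 2
        + torch.arange(w).view(1, w) % 2).float()
raw = site.repeat(2, 1, 1, 1)       # (N=2, 1, 8, 8) Bayer mosaic
packed = F.pixel_unshuffle(raw, 2)  # (2, 4, 4, 4), one CFA site per channel

g = torch.Generator().manual_seed(0)
raw_sub1, raw_sub2 = neighbor_subsample(raw, g)
pk_sub1, pk_sub2 = neighbor_subsample(packed, g)

# On the mosaic, every pair mixes two different CFA sites: prints 0.0
print("mosaic pairs from same site:", (raw_sub1 == raw_sub2).float().mean().item())
# On the packed image, every pair comes from the same site: prints 1.0
print("packed pairs from same site:", (pk_sub1 == pk_sub2).float().mean().item())
```

Applied to the unpacked mosaic, the two sub-images would therefore be built from different color filters, so the target sub-image would not share the input's signal.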
Hello, I have also been unable to reproduce the reported 51.06 dB on SIDD raw-RGB. I built directly on the source code provided by the author, but the PSNR only reaches 46.7 dB. This is my code:
for epoch in range(1, opt.n_epoch + 1):
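When comparing against a reported number such as 51.06 dB, it is worth checking how PSNR is computed as well as how the model is trained. The following is one way to score a denoiser on the arrays returned by `get_SIDD_validation`, assuming the blocks are normalized to [0, 1], that predictions are clipped to that range, and that the score is the mean of per-block PSNRs. All three are assumptions to verify against the official evaluation, and `denoise_fn` is a hypothetical callable wrapping the trained network.

```python
import numpy as np


def psnr(pred, gt, data_range=1.0):
    """PSNR in dB between two arrays with values in [0, data_range]."""
    mse = np.mean((pred.astype(np.float64) - gt.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")
    return 10.0 * np.log10(data_range ** 2 / mse)


def mean_block_psnr(denoise_fn, val_noisy, val_gt, data_range=1.0):
    """Average per-block PSNR over arrays shaped (num_img, num_block, H, W)."""
    num_img, num_block = val_gt.shape[:2]
    scores = []
    for i in range(num_img):
        for b in range(num_block):
            # clip the prediction to the valid range before scoring (assumption)
            out = np.clip(denoise_fn(val_noisy[i, b]), 0.0, data_range)
            scores.append(psnr(out, val_gt[i, b], data_range))
    return float(np.mean(scores))
```

Usage would be `num_img, num_block, noisy, gt = get_SIDD_validation(dataset_dir)` followed by `mean_block_psnr(denoise_fn, noisy, gt)`. Averaging per-block PSNRs and computing one PSNR from the pooled MSE give different numbers, so the convention should match the one behind the reported result.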
Thanks for your great work.
I applied the neighbor sub-sampler to the packed 4-channel raw images but cannot reproduce the results.
Did I do anything wrong?
Here is my data-processing code.