Would you plan to release your train code on raw image dataset(SIDD)? #7

Closed
madfff opened this issue Sep 1, 2021 · 4 comments

madfff commented Sep 1, 2021

Thanks for your great work.
I applied the neighbor sub-sampler to the packed 4-channel raw images but cannot reproduce the results.
Did I do anything wrong?
Here is my data-processing code.

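import os
import random

import h5py
import numpy as np
import torch
from torch.utils import data

# BayerUnifyAug (used below) is an external helper that is not shown in this
# post; it is assumed to be importable in the poster's environment.

NOISY_PATH = ['_NOISY_RAW_010.MAT','_NOISY_RAW_011.MAT']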

MODEL_BAYER = {'GP':'BGGR','IP':'RGGB','S6':'GRBG','N6':'BGGR','G4':'BGGR'}

TARGET_PATTERN = 'RGGB'

class Dataset(data.Dataset):
    def __init__(self, path, crop_size, is_train=True):
        super(Dataset, self).__init__()
        self.crop_size = crop_size
        self.is_train = is_train
        
        if self.is_train:
            self.unify_mode = 'crop'
        else:
            self.unify_mode = 'pad'
        self.file_lists = []
        folder_names = os.listdir(os.path.join(path,'Data'))
        
        for folder_name in folder_names:
            scene_instance_number,scene_number,_ = meta_read(folder_name)
            for file_path in NOISY_PATH:
                self.file_lists.append(os.path.join(path, 'Data', folder_name, scene_instance_number+file_path))
    
    def __getitem__(self, index):
        noisy_path = self.file_lists[index]
        gt_path = noisy_path.replace('NOISY', 'GT')
        _,_,bayer_pattern = meta_read(noisy_path.split('/')[-2])
        
        noisy = h5py_loadmat(noisy_path)
        noisy = BayerUnifyAug.bayer_unify(noisy, bayer_pattern, TARGET_PATTERN, self.unify_mode)
        
        gt = h5py_loadmat(gt_path)
        gt = BayerUnifyAug.bayer_unify(gt, bayer_pattern, TARGET_PATTERN, self.unify_mode)
        
        if self.is_train:
            augment = np.random.rand(3) > 0.5
            noisy = BayerUnifyAug.bayer_aug(noisy, augment[0], augment[1], augment[2], TARGET_PATTERN)
            gt = BayerUnifyAug.bayer_aug(gt, augment[0], augment[1], augment[2], TARGET_PATTERN)

        noisy = pack_raw_np(noisy[:,:,None])
        gt = pack_raw_np(gt[:,:,None])
        
        if self.crop_size[0] != 0 and self.crop_size[1] != 0:
            H, W, _ = noisy.shape
            rnd_h = random.randint(0, max(0, H - self.crop_size[0]))
            rnd_w = random.randint(0, max(0, W - self.crop_size[1]))
            
            noisy = noisy[rnd_h:rnd_h + self.crop_size[0], rnd_w:rnd_w + self.crop_size[1], :]
            gt = gt[rnd_h:rnd_h + self.crop_size[0], rnd_w:rnd_w + self.crop_size[1], :]
        
        noisy = torch.from_numpy(noisy.transpose(2, 0, 1))
        gt = torch.from_numpy(gt.transpose(2, 0, 1))
        return noisy, gt, bayer_pattern
    
    def __len__(self):
        return len(self.file_lists)

def meta_read(info):
    info = info.split('_')
    scene_instance_number       = info[0]
    scene_number                = info[1]
    smartphone_code             = info[2]
    #ISO_level                   = info[3]
    #shutter_speed               = info[4]
    #illuminant_temperature      = info[5]
    #illuminant_brightness_code  = info[6]

    return scene_instance_number,scene_number,MODEL_BAYER[smartphone_code]

def pack_raw_np(im):
    img_shape = im.shape
    H = img_shape[0]
    W = img_shape[1]
    ## R G G B
    out = np.concatenate((im[0:H:2,0:W:2,:], 
                       im[0:H:2,1:W:2,:],
                       im[1:H:2,0:W:2,:],
                       im[1:H:2,1:W:2,:]), axis=2)
    return out

def h5py_loadmat(file_path:str):
    with h5py.File(file_path, 'r') as f:
        return np.array(f.get('x'),dtype=np.float32)
@TaoHuang2018
Owner

Sorry for the late reply. For training Neighbor2Neighbor on the SIDD dataset:

  1. The code for data preparation is in the newly uploaded file dataset_tool_raw.py. Ideally, the camera device differences in the SIDD Medium Set should be taken into account. However, since the device info is not provided in the SIDD validation set, we simply ignore the device differences when preparing the training set.
  2. The training code for the SIDD dataset is similar to that for the ImageNet validation set, except for some operations specific to raw images. The training scheme is also similar: the number of epochs is 100 and the learning rate decay schedule is the same, except that we use a smaller $\gamma=1$. Here is some code for processing raw images.
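import glob
import os

import numpy as np
import torch
from scipy.io import loadmat
from torch.utils import data

# NOTE: get_generator(), used in generate_mask_pair below, is not defined in
# this snippet. It is assumed to return a torch.Generator on the same device
# as img (presumably as in the repository's training script).

def space_to_depth(x, block_size):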
    n, c, h, w = x.size()
    unfolded_x = torch.nn.functional.unfold(x, block_size, stride=block_size)
    return unfolded_x.view(n, c * block_size**2, h // block_size,
                           w // block_size)

def depth_to_space(x, block_size):
    return torch.nn.functional.pixel_shuffle(x, block_size)

def generate_mask_pair(img):
    # prepare masks (N x C x H/2 x W/2)
    n, c, h, w = img.shape
    mask1 = torch.zeros(size=(n * h // 2 * w // 2 * 4, ),
                        dtype=torch.bool,
                        device=img.device)
    mask2 = torch.zeros(size=(n * h // 2 * w // 2 * 4, ),
                        dtype=torch.bool,
                        device=img.device)
    # prepare random mask pairs
    idx_pair = torch.tensor(
        [[0, 1], [0, 2], [1, 3], [2, 3], [1, 0], [2, 0], [3, 1], [3, 2]],
        dtype=torch.int64,
        device=img.device)
    rd_idx = torch.zeros(size=(n * h // 2 * w // 2, ),
                         dtype=torch.int64,
                         device=img.device)
    torch.randint(low=0,
                  high=8,
                  size=(n * h // 2 * w // 2, ),
                  generator=get_generator(),
                  out=rd_idx)
    rd_pair_idx = idx_pair[rd_idx]
    rd_pair_idx += torch.arange(start=0,
                                end=n * h // 2 * w // 2 * 4,
                                step=4,
                                dtype=torch.int64,
                                device=img.device).reshape(-1, 1)
    # get masks
    mask1[rd_pair_idx[:, 0]] = 1
    mask2[rd_pair_idx[:, 1]] = 1
    return mask1, mask2

def generate_subimages(img, mask):
    n, c, h, w = img.shape
    subimage = torch.zeros(n,
                           c,
                           h // 2,
                           w // 2,
                           dtype=img.dtype,
                           layout=img.layout,
                           device=img.device)
    # per channel
    for i in range(c):
        img_per_channel = space_to_depth(img[:, i:i + 1, :, :], block_size=2)
        img_per_channel = img_per_channel.permute(0, 2, 3, 1).reshape(-1)
        subimage[:, i:i + 1, :, :] = img_per_channel[mask].reshape(
            n, h // 2, w // 2, 1).permute(0, 3, 1, 2)
    return subimage

class DataLoader_SIDD_Medium_Raw(data.Dataset):
    def __init__(self, data_dir):
        super(DataLoader_SIDD_Medium_Raw, self).__init__()
        self.data_dir = data_dir
        # get images path
        self.train_fns = glob.glob(os.path.join(self.data_dir, "*"))
        self.train_fns.sort()
        print('fetch {} samples for training'.format(len(self.train_fns)))
    def __getitem__(self, index):
        # fetch image
        fn = self.train_fns[index]
        im = loadmat(fn)["x"]
        im = im[np.newaxis, :, :]
        im = torch.from_numpy(im)
        return im
    def __len__(self):
        return len(self.train_fns)

def get_SIDD_validation(dataset_dir):
    val_data_dict = loadmat(
        os.path.join(dataset_dir, "ValidationNoisyBlocksRaw.mat"))
    val_data_noisy = val_data_dict['ValidationNoisyBlocksRaw']
    val_data_dict = loadmat(
        os.path.join(dataset_dir, 'ValidationGtBlocksRaw.mat'))
    val_data_gt = val_data_dict['ValidationGtBlocksRaw']
    num_img, num_block, _, _ = val_data_gt.shape
    return num_img, num_block, val_data_noisy, val_data_gt
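
As a quick sanity check of the packing helpers above, the following minimal sketch packs a single-channel mosaic into 4 channels and unpacks it again. The two helper functions are copied from the snippet above; the toy tensor `raw` and its 8x8 size are assumptions chosen only for illustration.

```python
import torch
import torch.nn.functional as F


def space_to_depth(x, block_size):
    # Same packing as in the snippet above: (N, C, H, W) -> (N, C*b*b, H/b, W/b).
    n, c, h, w = x.size()
    unfolded_x = F.unfold(x, block_size, stride=block_size)
    return unfolded_x.view(n, c * block_size**2, h // block_size,
                           w // block_size)


def depth_to_space(x, block_size):
    return F.pixel_shuffle(x, block_size)


# Hypothetical single-channel mosaic; every value is unique so that pixel
# positions can be tracked through packing and unpacking.
raw = torch.arange(2 * 1 * 8 * 8, dtype=torch.float32).reshape(2, 1, 8, 8)

packed = space_to_depth(raw, block_size=2)
restored = depth_to_space(packed, block_size=2)

# Packed channel k = 2*i + j holds the mosaic pixels at offset (i, j) inside
# each 2x2 block, i.e. the same order as pack_raw_np in the first post.
offsets_match = all(
    torch.equal(packed[:, 2 * i + j], raw[:, 0, i::2, j::2])
    for i in range(2) for j in range(2))

print(packed.shape)                 # torch.Size([2, 4, 4, 4])
print(torch.equal(restored, raw))   # True
print(offsets_match)                # True
```

With this channel order, `depth_to_space` exactly undoes `space_to_depth`, which is what the validation code further down the thread relies on when it unpacks the network output.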

@madfff
Author

madfff commented Sep 22, 2021

I still have a small question. Is it right to generate the sub-images on the packed 4-channel raw images?
Thank you for your patience.

@TaoHuang2018
Owner

TaoHuang2018 commented Sep 22, 2021

Yes, on the packed 4-channel raw images.
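
One way to see why the packed representation matters is the NumPy sketch below. It is a simplified re-implementation written for illustration, not the repository's code: `pack_rggb`, `neighbor_subsample` and the 8x8 label mosaic are hypothetical names and data, and an RGGB layout is assumed. It sub-samples a mosaic of colour labels once after packing and once directly.

```python
import numpy as np

# The eight ordered pairs of adjacent positions inside a 2x2 cell.
PAIRS = np.array([[0, 1], [0, 2], [1, 3], [2, 3],
                  [1, 0], [2, 0], [3, 1], [3, 2]])


def pack_rggb(mosaic):
    """(H, W) mosaic -> (H/2, W/2, 4), channel order (0,0), (0,1), (1,0), (1,1)."""
    return np.stack([mosaic[0::2, 0::2], mosaic[0::2, 1::2],
                     mosaic[1::2, 0::2], mosaic[1::2, 1::2]], axis=-1)


def neighbor_subsample(img, rng):
    """Pick two adjacent pixels from every 2x2 cell of an (H, W, C) array.

    The same random choice is shared by all channels.
    """
    h, w, c = img.shape
    # cells[a, b, 2*i + j] is the pixel at offset (i, j) of cell (a, b).
    cells = (img.reshape(h // 2, 2, w // 2, 2, c)
                .transpose(0, 2, 1, 3, 4)
                .reshape(h // 2, w // 2, 4, c))
    choice = PAIRS[rng.integers(0, 8, size=(h // 2, w // 2))]
    rows, cols = np.indices((h // 2, w // 2))
    return cells[rows, cols, choice[..., 0]], cells[rows, cols, choice[..., 1]]


# Colour labels of an 8x8 RGGB mosaic: 0 = R, 1 = G, 2 = B.
labels = np.ones((8, 8), dtype=np.int64)
labels[0::2, 0::2] = 0
labels[1::2, 1::2] = 2

rng = np.random.default_rng(0)

# Sub-sampling the packed array: each channel of both sub-images has one colour.
sub1, sub2 = neighbor_subsample(pack_rggb(labels), rng)
same_colour_packed = bool(np.all(sub1 == sub2))

# Sub-sampling the unpacked mosaic: every pair mixes two different colours.
raw1, raw2 = neighbor_subsample(labels[:, :, None], rng)
same_colour_raw = bool(np.any(raw1 == raw2))

print(sub1.shape, same_colour_packed)  # (2, 2, 4) True
print(raw1.shape, same_colour_raw)     # (4, 4, 1) False
```

In the packed layout, the two sub-images compare pixels of the same colour plane, whereas sub-sampling the unpacked mosaic would pair an R or B pixel with a G pixel.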

@zejinwang

I still have a small question. Is it right to generate the sub-images on the packed 4-channel raw images? Thank you for your patience.

Hello, I also cannot reproduce the reported 51.06 dB on SIDD raw-RGB. I built directly on the source code provided by the author, and the PSNR only reaches 46.7 dB. This is my code:

for epoch in range(1, opt.n_epoch + 1):
    cnt = 0

    for param_group in optimizer.param_groups:
        current_lr = param_group['lr']
    print("LearningRate of Epoch {} = {}".format(epoch, current_lr))

    network.train()
    for iteration, noisy in enumerate(TrainingLoader):
        st = time.time()
        noisy = noisy.cuda()
        # pack raw data
        noisy = space_to_depth(noisy, 2)

        optimizer.zero_grad()

        mask1, mask2 = generate_mask_pair(noisy)
        noisy_sub1 = generate_subimages(noisy, mask1)
        noisy_sub2 = generate_subimages(noisy, mask2)
        with torch.no_grad():
            noisy_denoised = network(noisy)
        noisy_sub1_denoised = generate_subimages(noisy_denoised, mask1)
        noisy_sub2_denoised = generate_subimages(noisy_denoised, mask2)

        noisy_output = network(noisy_sub1)
        noisy_target = noisy_sub2
        Lambda = epoch / opt.n_epoch * opt.increase_ratio
        diff = noisy_output - noisy_target
        exp_diff = noisy_sub1_denoised - noisy_sub2_denoised

        loss1 = torch.mean(diff**2)
        loss2 = Lambda * torch.mean((diff - exp_diff)**2)
        loss_all = opt.Lambda1 * loss1 + opt.Lambda2 * loss2

        loss_all.backward()
        optimizer.step()
        print(
            '{:04d} {:05d} Loss1={:.6f}, Lambda={}, Loss2={:.6f}, Loss_Full={:.6f}, Time={:.4f}'
            .format(epoch, iteration, np.mean(loss1.item()), Lambda,
                    np.mean(loss2.item()), np.mean(loss_all.item()),
                    time.time() - st))

    scheduler.step()

    if epoch % opt.n_snapshot == 0 or epoch == opt.n_epoch:
        network.eval()
        # save checkpoint
        checkpoint(network, epoch, "model")
        # validation
        save_model_path = os.path.join(opt.save_model_path, opt.log_name,
                                       systime)
        validation_path = os.path.join(save_model_path, "validation")
        os.makedirs(validation_path, exist_ok=True)
        np.random.seed(101)

        for valid_name, valid_data in valid_dict.items():
            psnr_result = []
            ssim_result = []
            num_img, num_block, valid_noisy, valid_gt = valid_data
            for idx in range(num_img):
                for idy in range(num_block):
                    im = valid_gt[idx, idy][:, :, np.newaxis]
                    noisy_im = valid_noisy[idx, idy][:, :, np.newaxis]

                    origin255 = im.copy() * 255.0
                    origin255 = origin255.astype(np.uint8)
                    noisy255 = noisy_im.copy() * 255.0
                    noisy255 = noisy255.astype(np.uint8)
                    # padding to square
                    H = noisy_im.shape[0]
                    W = noisy_im.shape[1]
                    val_size = (max(H, W) + 31) // 32 * 32
                    noisy_im = np.pad(
                        noisy_im,
                        [[0, val_size - H], [0, val_size - W], [0, 0]],
                        'reflect')
                    transformer = transforms.Compose([transforms.ToTensor()])
                    noisy_im = transformer(noisy_im)
                    noisy_im = torch.unsqueeze(noisy_im, 0)
                    noisy_im = noisy_im.cuda()
                    # pack raw data
                    noisy_im = space_to_depth(noisy_im, block_size=2)
                    with torch.no_grad():
                        prediction = network(noisy_im)
                        # unpack raw data
                        prediction = depth_to_space(prediction, block_size=2)
                        prediction = prediction[:, :, :H, :W]
                    prediction = prediction.permute(0, 2, 3, 1)
                    prediction = prediction.cpu().data.clamp(0, 1).numpy()
                    prediction = prediction.squeeze(0)
                    pred255 = np.clip(prediction * 255.0 + 0.5, 0,
                                        255).astype(np.uint8)
                    # calculate psnr
                    cur_psnr = calculate_psnr(origin255.astype(np.float32),
                                                pred255.astype(np.float32))
                    psnr_result.append(cur_psnr)
                    cur_ssim = calculate_ssim(origin255.astype(np.float32),
                                                pred255.astype(np.float32))
                    ssim_result.append(cur_ssim)

                    # visualization
                    save_path = os.path.join(
                        validation_path,
                        "{}_{:03d}-{:03d}-{:03d}_clean.png".format(
                            valid_name, idx, idy, epoch))
                    Image.fromarray(origin255.squeeze()).save(save_path)
                    save_path = os.path.join(
                        validation_path,
                        "{}_{:03d}-{:03d}-{:03d}_noisy.png".format(
                            valid_name, idx, idy, epoch))
                    Image.fromarray(noisy255.squeeze()).save(save_path)

                    save_path = os.path.join(
                        validation_path,
                        "{}_{:03d}-{:03d}-{:03d}_denoised.png".format(
                            valid_name, idx, idy, epoch))
                    Image.fromarray(pred255.squeeze()).save(save_path)

            psnr_result = np.array(psnr_result)
            avg_psnr = np.mean(psnr_result)
            avg_ssim = np.mean(ssim_result)
            log_path = os.path.join(validation_path,
                                    "A_log_{}.csv".format(valid_name))
            with open(log_path, "a") as f:
                f.writelines("{},{},{}\n".format(epoch, avg_psnr, avg_ssim))
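
For reference, the loss computed inside the training loop above can be written compactly as follows. The symbols are notation introduced here, not taken from the thread: $y$ is the packed noisy batch, $f_\theta$ the network, and $g_1, g_2$ the two sub-samplers defined by `mask1` and `mask2`. The names $\lambda_1$, $\lambda_2$, `n_epoch` and `increase_ratio` stand for the `opt.*` options in the code.

```latex
\begin{aligned}
d &= f_\theta\bigl(g_1(y)\bigr) - g_2(y), \\
e &= g_1\bigl(f_\theta(y)\bigr) - g_2\bigl(f_\theta(y)\bigr)
     \quad \text{(computed under \texttt{torch.no\_grad()})}, \\
\Lambda &= \frac{\text{epoch}}{\text{n\_epoch}} \cdot \text{increase\_ratio}, \\
\mathcal{L} &= \lambda_1 \,\operatorname{mean}\bigl(d^2\bigr)
             + \lambda_2 \,\Lambda\, \operatorname{mean}\bigl((d - e)^2\bigr).
\end{aligned}
```

If `opt.increase_ratio` plays the role of the $\gamma$ mentioned earlier in the thread, then the value passed for it would be worth checking when comparing against the reported numbers; the thread itself does not state this mapping, so treat it as an assumption.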
