In [1]:
import glob
import random

import deepinv as dinv
import imageio.v3 as iio
import numpy as np
import torch
from benchopt import BaseDataset, safe_import_context
from deepinv.physics import Denoising, GaussianNoise

In [2]:
# --- Data location ---
image_path = "./../data/images/BSD/train/"  # directory globbed for input images
extension = ".png"  # file suffix appended to the glob pattern

# --- Experiment settings ---
n_samples = 10  # number of images taken from the shuffled file list
sigma = 0.1  # value passed to GaussianNoise when building the forward model
random_state = 27  # seed handed to the random-number generators
device = "cpu"  # torch device on which the image tensor is created

In [3]:
# Scratch check: render an integer zero-padded to three digits
it = 22
pad_fmt = "{:03d}"
print(pad_fmt.format(it))

022


In [4]:
# Seed every random-number generator this notebook can touch so that the
# file shuffle (`random`) and the sampled noise (`torch`) are repeatable.
# numpy's global RNG is seeded as well so any numpy-based sampling added
# later is covered by the same `random_state`.
random.seed(random_state)
np.random.seed(random_state)
# Kept as the last expression: the cell displays the returned torch Generator.
torch.manual_seed(random_state)

<torch._C.Generator at 0x28646546f70>

In [5]:
# Collect the image files matching `extension` under `image_path`.
# glob returns paths in arbitrary (filesystem-dependent) order, so sort them
# first: only then does the seeded shuffle give the same order everywhere.
file_list = sorted(glob.glob(image_path + "*" + extension))
random.shuffle(file_list)

In [6]:
# Load the first `n_samples` shuffled images and rescale each one to [0, 1].
if len(file_list) < n_samples:
    # Fail early with a readable message instead of an IndexError mid-loop.
    raise ValueError(
        "Requested {} images but only {} files match '{}'".format(
            n_samples, len(file_list), image_path + "*" + extension
        )
    )

gt_img_list = []
for img_path in file_list[:n_samples]:
    gt_img = np.asarray(iio.imread(img_path))

    # Min-max scaling to [0, 1]
    img_min, img_max = gt_img.min(), gt_img.max()
    if img_max > img_min:
        gt_img = (gt_img - img_min) / (img_max - img_min)
    else:
        # Constant image: the range is zero, so min-max scaling would be 0/0
        # and fill the sample with NaNs. Map it to all zeros instead.
        gt_img = np.zeros(gt_img.shape, dtype=np.float64)

    gt_img_list.append(gt_img)

In [7]:
# Stack the per-image arrays along a new leading axis, then copy the result
# into a float32 tensor living on `device`.
gt_stack = np.array(gt_img_list)
x = torch.tensor(gt_stack, dtype=torch.float32, device=device)


In [24]:
# Insert a singleton channel axis: (N, H, W) -> (N, 1, H, W).
# Guarded on the rank so that re-running this cell does not keep adding axes.
if x.ndim == 3:
    x = x[:, None, :, :]


In [25]:
# Forward model: a denoising operator whose noise is Gaussian, built from `sigma`
gaussian_noise = GaussianNoise(sigma)
physics = Denoising(noise=gaussian_noise)

# Observations: apply the forward model to the ground-truth batch
y = physics(x)

In [26]:
# Bundle ground truth, observations and forward model under string keys
{"x": x, "y": y, "physics": physics}

{'x': tensor([[[[0.9957, 0.9957, 0.9957,  ..., 0.1082, 0.1991, 0.2900],
           [1.0000, 0.9957, 0.9957,  ..., 0.1385, 0.1169, 0.1732],
           [1.0000, 1.0000, 1.0000,  ..., 0.1082, 0.1515, 0.1169],
           ...,
           [0.8312, 0.8658, 0.9048,  ..., 0.2294, 0.3463, 0.3680],
           [0.8528, 0.8571, 0.8658,  ..., 0.2684, 0.3680, 0.2900],
           [0.9221, 0.9394, 0.9870,  ..., 0.2814, 0.2814, 0.2554]]],
 
 
         [[[0.6452, 0.6452, 0.6406,  ..., 0.6129, 0.6083, 0.6083],
           [0.6406, 0.6406, 0.6359,  ..., 0.6129, 0.6129, 0.6129],
           [0.6452, 0.6406, 0.6313,  ..., 0.6221, 0.6175, 0.6175],
           ...,
           [0.0276, 0.0276, 0.0276,  ..., 0.0783, 0.0922, 0.0876],
           [0.0276, 0.0276, 0.0230,  ..., 0.1336, 0.1705, 0.1336],
           [0.0276, 0.0276, 0.0276,  ..., 0.1659, 0.2166, 0.1843]]],
 
 
         [[[0.4416, 0.4569, 0.4569,  ..., 0.1827, 0.1878, 0.1827],
           [0.4569, 0.4721, 0.4721,  ..., 0.1929, 0.1827, 0.1777],
           [0

In [27]:
# Dimensions of the ground-truth batch
x.size()

torch.Size([10, 1, 180, 180])

In [28]:
# Dimensions of the observation batch
y.size()

torch.Size([10, 1, 180, 180])

In [29]:
# Show the forward operator's repr (its module structure)
physics

Denoising(
  (noise_model): GaussianNoise()
)

In [30]:
# L2 data-fidelity term, configured with the sigma stored on the forward
# model's noise model.
# (`dinv` is imported with the other dependencies in the first cell.)
likelihood = dinv.optim.L2(sigma=physics.noise_model.sigma)

In [31]:
# Evaluate the L2 likelihood term on the clean/noisy pair under `physics`
likelihood(x, y, physics)

tensor([16294.7803, 16295.4180, 16017.9814, 16102.9668, 16456.5098, 15978.1602,
        16197.1289, 16365.1260, 16134.1201, 16148.5723])

In [33]:
# Score-based prior wrapping a single-channel DnCNN denoiser created with the
# "download_lipschitz" pretrained option. The denoiser is placed on the
# notebook-wide `device` setting rather than a hard-coded "cpu".
denoiser = dinv.models.DnCNN(
    pretrained="download_lipschitz",
    in_channels=1,
    out_channels=1,
    device=device,
)
prior = dinv.optim.ScorePrior(denoiser=denoiser)

print(prior)

print(x.shape)

# Evaluate the prior's gradient at x, passing the forward model's noise sigma
# as the second argument.
prior.grad(x, physics.noise_model.sigma)

ScorePrior(
  (denoiser): DnCNN(
    (in_conv): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv_list): ModuleList(
      (0-17): 18 x Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (out_conv): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (nl_list): ModuleList(
      (0-18): 19 x ReLU()
    )
  )
)
torch.Size([10, 1, 180, 180])


tensor([[[[-1.2280e-01, -1.1879e-01,  9.8234e-02,  ..., -2.8033e-01,
            2.3569e-01,  5.6019e-01],
          [ 1.3737e-01, -1.2459e-01,  4.2325e-02,  ...,  4.5015e-01,
           -6.7659e-01,  3.3575e-01],
          [ 1.5595e-01,  2.5843e-01,  3.3464e-01,  ..., -6.9703e-01,
            4.6601e-01, -5.0129e-01],
          ...,
          [-5.3582e-01, -4.1361e-01,  1.2338e-02,  ..., -6.9946e-01,
            2.9644e-01,  5.7030e-01],
          [-1.6011e-01,  1.9491e-02, -8.2583e-01,  ..., -3.8725e-02,
            8.0482e-01, -5.4622e-01],
          [-6.5970e-02, -1.3337e-01,  6.7149e-01,  ..., -9.8485e-02,
           -4.8998e-02, -1.0541e-01]]],


        [[[ 1.5827e-01,  3.8451e-01,  3.4815e-02,  ...,  5.5313e-02,
           -1.4835e-01, -3.2859e-01],
          [-1.9144e-01,  6.4850e-02, -1.2977e-01,  ..., -2.3425e-01,
           -2.1195e-02, -2.0820e-01],
          [ 1.3366e-01, -2.4992e-02, -6.7228e-01,  ...,  1.0828e-01,
           -2.2799e-01, -2.1515e-01],
          ...,
   