In [1]:
pip install torch 

Collecting torch
  Obtaining dependency information for torch from https://files.pythonhosted.org/packages/5a/6a/775b93d6888c31f1f1fc457e4f5cc89f0984412d5dcdef792b8f2aa6e812/torch-2.4.1-cp311-cp311-win_amd64.whl.metadata
  Using cached torch-2.4.1-cp311-cp311-win_amd64.whl.metadata (27 kB)
Collecting typing-extensions>=4.8.0 (from torch)
  Obtaining dependency information for typing-extensions>=4.8.0 from https://files.pythonhosted.org/packages/26/9f/ad63fc0248c5379346306f8668cda6e2e2e9c95e01216d2b8ffd9ff037d0/typing_extensions-4.12.2-py3-none-any.whl.metadata
  Using cached typing_extensions-4.12.2-py3-none-any.whl.metadata (3.0 kB)
Downloading torch-2.4.1-cp311-cp311-win_amd64.whl (199.4 MB)
   ---------------------------------------- 0.0/199.4 MB ? eta -:--:--
   ---------------------------------------- 0.0/199.4 MB ? eta -:--:--
   ---------------------------------------- 0.0/199.4 MB ? eta -:--:--
   ---------------------------------------- 0.0/199.4 MB ? eta -:--:--
   ----------

In [2]:
pip install torchvision 

Collecting torchvision
  Obtaining dependency information for torchvision from https://files.pythonhosted.org/packages/f8/69/dc769cf54df8e828c0b8957b4521f35178f5bd4cc5b8fbe8a37ffd89a27c/torchvision-0.19.1-cp311-cp311-win_amd64.whl.metadata
  Using cached torchvision-0.19.1-cp311-cp311-win_amd64.whl.metadata (6.1 kB)
Downloading torchvision-0.19.1-cp311-cp311-win_amd64.whl (1.3 MB)
   ---------------------------------------- 0.0/1.3 MB ? eta -:--:--
   ---------------------------------------- 0.0/1.3 MB ? eta -:--:--
   ---------------------------------------- 0.0/1.3 MB ? eta -:--:--
   ---------------------------------------- 0.0/1.3 MB ? eta -:--:--
    --------------------------------------- 0.0/1.3 MB 119.1 kB/s eta 0:00:11
    --------------------------------------- 0.0/1.3 MB 119.1 kB/s eta 0:00:11
   - -------------------------------------- 0.0/1.3 MB 115.5 kB/s eta 0:00:11
   - -------------------------------------- 0.0/1.3 MB 115.5 kB/s eta 0:00:11
   - -----------------------

In [3]:
pip install Pillow 

Note: you may need to restart the kernel to use updated packages.


In [4]:
pip install matplotlib 

Note: you may need to restart the kernel to use updated packages.


In [1]:

from model import Unet
from pnp import pnp_admm
from utils import conv2d_from_kernel, compute_psnr, ImagenetDataset, myplot

import torch
import torch.nn.functional as F

import PIL.Image as Image


device = 'cuda'

## Download plug-in denoiser

In [3]:
pip install gdown 

Collecting gdown
  Obtaining dependency information for gdown from https://files.pythonhosted.org/packages/54/70/e07c381e6488a77094f04c85c9caf1c8008cdc30778f7019bc52e5285ef0/gdown-5.2.0-py3-none-any.whl.metadata
  Downloading gdown-5.2.0-py3-none-any.whl.metadata (5.8 kB)
Downloading gdown-5.2.0-py3-none-any.whl (18 kB)
Installing collected packages: gdown
Successfully installed gdown-5.2.0
Note: you may need to restart the kernel to use updated packages.


In [5]:

import gdown
#url = 'https://drive.google.com/file/d/1FFuauq-PUjY_kG3iiiHfDpHcG4Srl8mQ/view?usp=sharing'
#output = "denoiser.pth"
#gdown.download(url, output, quiet=False,fuzzy=True)

model = Unet(3, 3, chans=64).to(device)
model.load_state_dict(torch.load('denoiser.pth', map_location=device))
print('#Parameters:', sum(p.numel() for p in model.parameters() if p.requires_grad))

AssertionError: Torch not compiled with CUDA enabled

## Plug and play examples
#### Load test image

In [None]:
# Get sample
test_image = Image.open('figs/test_image.png').convert("RGB")
test_image = ImagenetDataset([]).test_transform(test_image)
# test_image = dataset[1231]['target'] # from dataset
channels, h, w = test_image.shape
test_image = test_image.unsqueeze(0).to(device) 

#### Motion deblur

In [None]:
# Motion deblur
kernel_size = 21
kernel_motion_blur = torch.ones((1, kernel_size))
forward, forward_adjoint = conv2d_from_kernel(kernel_motion_blur, channels, device)

# Run plug and play
y = forward(test_image)
with torch.no_grad():
    model.eval()
    x = pnp_admm(y, forward, forward_adjoint, model)
    x = x.clip(0,1)

# Plot
print('PSNR [dB]: {:.2f}'.format(compute_psnr(x, test_image)))
myplot(F.pad(y, (kernel_size//2, kernel_size//2)), x, test_image) 

#### Inpainting

In [None]:
# Inpainting
mask = torch.rand(1,1,h,w).to(device)
mask = mask < 0.2

def forward(x):
    return x*mask
forward_adjoint = forward

# Run plug and play
y = forward(test_image)
with torch.no_grad():
    model.eval()
    x = pnp_admm(y, forward, forward_adjoint, model, num_iter=100)
    x = x.clip(0,1)

# Plot
print('PSNR [dB]: {:.2f}'.format(compute_psnr(x, test_image)))
myplot(y, x, test_image)

#### Super-resolution

In [None]:
# Super-resolultion
kernel_size = 4
kernel_downsampling = torch.ones((kernel_size, kernel_size))
forward, forward_adjoint = conv2d_from_kernel(kernel_downsampling, channels, device, stride=kernel_size)

# Run plug and play
y = forward(test_image)
with torch.no_grad():
    model.eval()
    x = pnp_admm(y, forward, forward_adjoint, model, num_iter=100, max_cgiter=30, cg_tol=1e-4)
    x = x.clip(0,1)

# Plot
print('PSNR [dB]: {:.2f}'.format(compute_psnr(x, test_image)))
myplot(y, x, test_image) 