# Using Ring Deconvolution Microscopy (RDM)

This notebook walks you through the basics of rdmpy.

Interested in some background on these methods? Check out the [paper](https://arxiv.org/abs/2206.08928).

The data we use in this notebook comes from the [UCLA Miniscope](http://miniscope.org/index.php/Main_Page).

Refer to the other notebooks in the repository for replication of the experiments in the paper.


## Setup

Here we import the packages and helper functions needed for the demo, including the star of the show: rdmpy.

In [None]:
%load_ext autoreload
%autoreload 2
!export CUDA_VISIBLE_DEVICES=0,1,2,3 # REPLACE this line according to your GPUs
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2,3" # REPLACE this line according to your GPUs

import rdmpy

# here are some basics we will need for the demo
import torch
import numpy as np
import matplotlib.pyplot as plt

from skimage import io
from PIL import Image

# Fill this in according to your computer, we highly recommend using a GPU.
# We needed ~20GB of GPU memory to run at full resolution (1024 x 1024).
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
    
    
print('Using ' +str(device) + ' for computation')

def center_crop(img, des_shape, center=None):
    """Crop a window of exactly ``des_shape`` around ``center``.

    Args:
        img: array indexed as ``img[row, col, ...]``; only the first two axes
            are cropped.
        des_shape: ``(height, width)`` of the desired crop.
        center: ``(row, col)`` of the window center. Defaults to
            ``(img.shape[0] // 2, img.shape[1] // 2)``.

    Returns:
        The slice ``img[up:up + height, left:left + width]`` with
        ``up = center[0] - height // 2`` and ``left = center[1] - width // 2``.

    No bounds checking is performed; the window is expected to fit in ``img``.
    """
    if center is None:
        center = (img.shape[0] // 2, img.shape[1] // 2)
    height, width = des_shape[0], des_shape[1]
    up = center[0] - height // 2
    left = center[1] - width // 2
    # Use start + size for the end index. An upper bound of
    # center + np.round(size / 2) rounds half to even, so sizes such as
    # 5 or 9 would come out one pixel short.
    return img[up:up + height, left:left + width]

def crop(img, c):
    """Trim ``c`` pixels from each edge of the first two axes of ``img``.

    Args:
        img: array indexed as ``img[row, col, ...]``.
        c: border width to remove, in pixels (expected ``c >= 0``).

    Returns:
        ``img[c:rows - c, c:cols - c]``.
    """
    # Explicit end indices instead of ``-c``: with c == 0 the slice ``0:-0``
    # equals ``0:0`` and would return an empty array.
    rows, cols = img.shape[0], img.shape[1]
    return img[c:rows - c, c:cols - c]


## One-time Calibration

The first step in the RDM pipeline is to calibrate your microscope. This is very similar to measuring a point spread function for deconvolution. The only difference is that you now take a single image of randomly placed point sources. We visualize one such calibration image from the Miniscope in the cell below.

In [None]:
# Side length in pixels of the image to deblur (presumably a dim x dim square).
# Pick this according to how big of an image you want to deblur; reduce it if
# you run into memory errors.
dim: int = 512

In [None]:
import tifffile

# Original demo calibration image (relative path test_images/). Uncomment to
# use it instead of the stack below. It is no longer read unconditionally: it
# was immediately overwritten, and the read crashed this cell whenever
# test_images/ was missing.
#calibration_image = np.array(Image.open('test_images/calibration_image.tif'))

# Alternative sources / background subtraction tried previously:
#calibration_image = np.array(Image.open('/media/al/hs_results/20250701_AI/20250706_results/20250425_0gan_single_reg_hs_100_points_512/test_latest/images/hs_gen_99.tif'))
#background = np.array(Image.open('/media/al/Extreme SSD/20250425_results/results/20250705/results/20250425_0gan_single_reg_hs_100_points/test_latest/average_stack.tif'))
#calibration_image = calibration_image - 1 * background
#calibration_image = np.clip(calibration_image, 0, None)  # Remove negative values

# Channel of the calibration stack to calibrate on; the stack is indexed
# below as [channel, row, col].
calibration_channel = 20

calibration_image = tifffile.imread('/media/al/20250701_AI/hs_results/20250706_results/20250425_0gan_single_reg_hs_100_points_512/test_latest/images/hs_gen_99.tif')
calibration_image = calibration_image[calibration_channel, :, :]

print('Calibration image shape: ', calibration_image.shape)

plt.imshow(calibration_image, cmap='gray')
plt.show()

To calibrate, just input this image into the `calibrate` function. This will give back Seidel coefficients and synthetic PSFs. These characterize the system and can be directly used to blur/deblur any image taken from the microscope.

In [None]:
# sys_params = {"NA": 0.5}
# Fit Seidel coefficients to the calibration image. Use the `dim` chosen above
# (not a hard-coded 512) so calibration matches the resolution used for the
# PSFs and reconstructions that follow.
# show_psfs=True is kept on purpose: it reportedly avoids a 'psfs_gt' error
# inside rdmpy.
seidel_coeffs = rdmpy.calibrate_rdm(calibration_image, dim=dim, device=device, show_psfs=True, get_psfs=False)

## Simulating blur with ring convolution
Now that we have calibrated the system, we can simulate the spatially-varying blur using ring convolution. Normally this process would take over an hour, even for a 512x512 image! Check out how fast it is with ring convolution. If you want to speed it up even more or reduce memory, try setting patch_size to something like 8 or 16.

In [None]:
# Synthesize PSF data for the ring-convolution model (model='lri') from the
# fitted Seidel coefficients. Optionally pass patch_size=16 (or 8) to speed
# things up or reduce memory.
psf_data = rdmpy.get_rdm_psfs(
    seidel_coeffs,
    dim=dim,
    model='lri',
    device=device,
)


In [None]:
# Destination file for the calibration results computed above.
# NOTE(review): the loading cell reads
# '.../20250701_AI/hs_results/20250706_results/.../99point_stack.pt', while this
# writes '.../20250701_AI/20250706_results/.../99point_stack_ch20.pt' (no
# 'hs_results', different file name). Confirm which location is intended.
path = (
    '/media/al/20250701_AI/20250706_results/'
    '20250425_0gan_single_reg_hs_100_points_512/test_latest/99point_stack_ch20.pt'
)

# Save the variables.
def save_variables(seidel_coeffs, psf_data, path):
    """Persist calibration results to ``path`` with ``torch.save``.

    Writes a dict with the keys ``'seidel_coeffs'`` and ``'psf_data'``, the
    format ``load_variables`` reads back. Missing parent directories are
    created first, because ``torch.save`` fails when the destination folder
    does not exist.

    Args:
        seidel_coeffs: fitted Seidel coefficients from calibration.
        psf_data: PSF data to store alongside them.
        path: destination file path.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save({'seidel_coeffs': seidel_coeffs, 'psf_data': psf_data}, path)

# Example use: write the coefficients and PSF data computed above to `path`.
save_variables(seidel_coeffs=seidel_coeffs, psf_data=psf_data, path=path)

In [None]:
# Repeat the setup imports (presumably so the cells from here on can be run
# after a kernel restart, using the calibration saved above instead of
# recomputing it).
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import tifffile
import torch
from PIL import Image
from skimage import io, color, img_as_ubyte
from skimage.transform import resize

import rdmpy

# Fill this in according to your computer, we highly recommend using a GPU.
# We needed ~20GB of GPU memory to run at full resolution.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

print('Using ' + str(device) + ' for computation')

# Load the variables.
def load_variables(path, map_location=None):
    """Load calibration results written by ``save_variables``.

    Args:
        path: file produced by ``torch.save`` containing the keys
            ``'seidel_coeffs'`` and ``'psf_data'``.
        map_location: forwarded to ``torch.load``. When ``None`` and CUDA is
            unavailable, tensors are mapped to the CPU, so files saved from a
            GPU session still load on a CPU-only machine.

    Returns:
        ``(seidel_coeffs, psf_data)``.

    Raises:
        KeyError: if either key is missing. For example, the per-channel
            ``*_calibration.pt`` files store ``'center_psf'`` rather than
            ``'psf_data'``.
    """
    if map_location is None and not torch.cuda.is_available():
        map_location = 'cpu'
    loaded_data = torch.load(path, map_location=map_location)
    return loaded_data['seidel_coeffs'], loaded_data['psf_data']

# Reload the calibration results saved earlier with save_variables.
path = '/media/al/20250701_AI/hs_results/20250706_results/20250425_0gan_single_reg_hs_100_points_512/test_latest/99point_stack.pt'
seidel_coeffs, psf_data = load_variables(path)

# Image size for the reconstructions below; presumably this must match the
# size the loaded PSF data was generated at.
dim = 512

print(psf_data.shape)

In [None]:
# Here is a test image
#measurement = plt.imread('test_images/baboon.png')
# Alternative PNG source. It used to be read here and then immediately
# overwritten by the TIF stack below, so it is left commented out.
#measurement = plt.imread('/media/al/Extreme SSD/20250701_usaf/20250425_0gan_single_reg_hs_usaf/test_latest/images/hs_gen_0-0020.png')
measurement = tifffile.imread('/media/al/Extreme SSD/20250701_usaf/20250425_0gan_single_reg_hs_usaf/test_latest/images/hs_gen_0.tif')
print('Measurement shape: ', measurement.shape)

# NOTE(review): this uses channel 10, while the calibration image above used
# channel 20 and the output filenames below contain "ch20". Confirm which
# channel is intended.
measurement_channel = 10
measurement = measurement[measurement_channel, :, :]

# Scale to a peak of 1. Without the guard, an all-zero channel would become
# all NaN.
peak = measurement.max()
if peak > 0:
    measurement = measurement / peak
else:
    print('Warning: selected measurement channel is all zeros; not normalizing')
    measurement = measurement.astype(np.float64)
plt.imshow(measurement, cmap='gray')
plt.show()

## Deblurring time: Ring deconvolution
Now that we have a blurry, noisy measurement and the Seidel PSFs from the calibration procedure, we can run our main algorithm Ring deconvolution! 


In [None]:
# Border (pixels) trimmed from the displayed result to hide edge artifacts
# caused by the finite PSF size.
c = 10

# Deblur with ring deconvolution (optionally pass patch_size=16).
rd_recon = rdmpy.ring_deconvolve(
    measurement,
    psf_data,
    iters=150,
    device=device,
    process=False,
)
# Alternative settings tried previously:
#rd_recon = rdmpy.ring_deconvolve(measurement, psf_data, iters=200, lr=5e-2, tv_reg=1e-11, l2_reg=1e-6, opt_params={"upper_projection": True}, process=True, hot_pixel=True, device=device)

# The full reconstruction is saved; only the on-screen preview is cropped.
rd_out_path = '/media/al/Extreme SSD/20250701_usaf/deconvolution_20250706/rd_100_ch20_0.png'
plt.imsave(rd_out_path, rd_recon, cmap='gray')
plt.imshow(crop(rd_recon, c), cmap='gray')
plt.show()

## Faster alternative: DeepRD

In [None]:
# DeepRD: an even faster alternative. It takes the Seidel coefficients
# directly rather than the synthesized PSF data.
deeprd_recon = rdmpy.deeprd(measurement, seidel_coeffs, device=device, process=False)

deeprd_out_path = '/media/al/Extreme SSD/20250701_usaf/deconvolution_20250706/deeprd_100_ch20_0.png'
plt.imsave(deeprd_out_path, deeprd_recon, cmap='gray')
# `c` is the display border defined in the ring-deconvolution cell.
plt.imshow(crop(deeprd_recon, c), cmap='gray')
plt.show()

## Alternative methods: Seidel deconvolution
Sometimes the system is sufficiently spatially invariant. Even so, RDM can help by providing a noiseless synthetic Seidel PSF for standard deconvolution. There is also a blind version!

In [None]:
# Spatially-invariant model (model='lsi'): a single center PSF synthesized
# from the fitted Seidel coefficients.
center_psf = rdmpy.get_rdm_psfs(seidel_coeffs, dim, model='lsi', device=device)

# Standard deconvolution of the measurement with that single PSF.
deconv = rdmpy.deconvolve(measurement, center_psf, device=device, process=False)

wiener_out_path = '/media/al/Extreme SSD/20250701_usaf/deconvolution_20250706/wiener_100_ch20_newer.png'
plt.imsave(wiener_out_path, deconv, cmap='gray')
plt.imshow(deconv, cmap='gray')
plt.show()

## Batch Processing: Generate Calibration Files for All 30 Channels

This code block will process calibration images to generate calibration.pt files for all 30 channels. Each channel will be calibrated separately and saved as a .pt file.

In [None]:
import glob
import os
from pathlib import Path

import numpy as np
import tifffile
import torch

import rdmpy


# Configuration
calibration_folder = '/media/al/20250701_AI/hs_results/20250706_results/20250425_0gan_single_reg_hs_100_points_512/test_latest/images/'
calibration_pattern = 'hs_gen_99.tif'  # Pattern for calibration files
output_folder = '/media/al/20250701_AI/calibration_files/'
dim = 512
max_channels = 30  # Calibrate at most this many channels per stack
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Create output directory
os.makedirs(output_folder, exist_ok=True)

# Find calibration files (sorted so the processing order is reproducible)
calibration_files = sorted(glob.glob(os.path.join(calibration_folder, calibration_pattern)))
print(f"Found {len(calibration_files)} calibration files")

# Process each calibration file
for cal_file in calibration_files:
    print(f"\nProcessing: {cal_file}")

    # Load the calibration image stack, indexed below as [channel, row, col]
    calibration_stack = tifffile.imread(cal_file)
    print(f"Calibration stack shape: {calibration_stack.shape}")
    if calibration_stack.ndim == 2:
        # A single 2D image: treat it as a one-channel stack instead of
        # misreading its rows as channels (which raised an IndexError).
        calibration_stack = calibration_stack[np.newaxis]

    # The output name prefix depends only on the file, so compute it once.
    base_name = Path(cal_file).stem

    # Process the first `max_channels` channels
    num_channels = min(max_channels, calibration_stack.shape[0])

    for channel in range(num_channels):
        print(f"Processing channel {channel + 1}/{num_channels}")

        # Unique filename for this channel
        output_file = os.path.join(output_folder, f"{base_name}_ch{channel:02d}_calibration.pt")

        # Skip if already exists
        if os.path.exists(output_file):
            print(f"  Channel {channel} already exists, skipping...")
            continue

        # Extract single channel
        calibration_image = calibration_stack[channel, :, :]

        try:
            # Calibrate the system for this channel
            # Note: Using show_psfs=True to avoid the 'psfs_gt' error in rdmpy
            seidel_coeffs = rdmpy.calibrate_rdm(calibration_image, dim=dim, device=device, show_psfs=True, get_psfs=False)

            # Get the center PSF for spatially invariant deconvolution
            center_psf = rdmpy.get_rdm_psfs(seidel_coeffs, dim, model='lsi', device=device)

            # Save calibration data
            torch.save({
                'seidel_coeffs': seidel_coeffs,
                'center_psf': center_psf,
                'channel': channel,
                'dim': dim,
                'calibration_file': cal_file,
            }, output_file)

            print(f"  Saved calibration for channel {channel} to {output_file}")

        except Exception as e:
            # Best effort: report and move on so one bad channel does not
            # abort the whole batch.
            print(f"  Error processing channel {channel}: {e}")

print("\nBatch calibration processing completed!")

## Batch Processing: Wiener Deconvolution for TIF Stacks

This code block will process a folder of TIF stack images, performing spatially invariant Wiener deconvolution on each channel using the corresponding calibration file.

In [None]:
import glob
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import tifffile
import torch

import rdmpy

# Configuration
# Previously used measurement folders (uncomment one to switch):
# measurement_folder = '/media/al/Extreme SSD/20250701_usaf/20250425_0gan_single_reg_hs_usaf/test_latest/images'
# measurement_folder = '/media/al/Extreme SSD/20250425_results/results/20250704/results/20250425_0gan_single_reg_hs_3/test_latest/images'
# measurement_folder = '/media/al/Extreme SSD/20250701_usaf/lower_res/results/20250425_0gan_single_reg_hs_usaf/20250425_0gan_single_reg_hs/test_latest/images'
# measurement_folder = '/media/al/Extreme SSD/20250701_usaf/20250708_results/20250425_0gan_single_reg_hs_usaf/test_latest/images'
measurement_folder = '/media/al/Extreme SSD/20250425_results/results/misc_test_dataset_layernorm/results/20250425_0gan_single_reg_hs_suberine/test_latest/images'
calibration_folder = '/media/al/20250701_AI/calibration_files/'  # Where calibration .pt files are stored
output_folder = '/media/al/Extreme SSD/20250425_results/results/misc_test_dataset_layernorm/results/20250425_0gan_single_reg_hs_suberine/test_latest/deconvolved_images'
measurement_pattern = '*.tif'  # Pattern for measurement files
max_channels = 30  # Deconvolve at most this many channels per stack
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Create output directory
os.makedirs(output_folder, exist_ok=True)

# Find measurement files (sorted so the processing order is reproducible)
measurement_files = sorted(glob.glob(os.path.join(measurement_folder, measurement_pattern)))
print(f"Found {len(measurement_files)} measurement files")

# Calibration files do not change between measurement files, so each one is
# loaded at most once. Only successful loads are cached.
psf_cache = {}

# Process each measurement file
for meas_file in measurement_files:
    print(f"\nProcessing: {meas_file}")

    # Load the measurement image stack, indexed below as [channel, row, col]
    measurement_stack = tifffile.imread(meas_file)
    print(f"Measurement stack shape: {measurement_stack.shape}")
    if measurement_stack.ndim == 2:
        # A single 2D image: treat it as a one-channel stack.
        measurement_stack = measurement_stack[np.newaxis]

    # Normalize the whole stack once by its global maximum (this used to be
    # recomputed inside the channel loop). Guard against an all-zero stack,
    # which would otherwise become all NaN.
    stack_max = measurement_stack.max()
    if stack_max > 0:
        measurement_stack = measurement_stack / stack_max
    else:
        print("  Warning: measurement stack is all zeros; not normalizing")
        measurement_stack = measurement_stack.astype(np.float64)

    # Process the first `max_channels` channels
    num_channels = min(max_channels, measurement_stack.shape[0])

    # Float output buffer. np.zeros_like on an integer TIF stack would give an
    # integer array and silently truncate the normalized results to 0.
    deconv_stack = np.zeros((num_channels,) + measurement_stack.shape[1:], dtype=np.float32)

    for channel in range(num_channels):
        print(f"Processing channel {channel + 1}/{num_channels}")

        # Extract single channel
        measurement_image = measurement_stack[channel, :, :]

        # Find corresponding calibration file
        # Assuming calibration files are named like: hs_gen_99_ch00_calibration.pt
        cal_file = os.path.join(calibration_folder, f"hs_gen_99_ch{channel:02d}_calibration.pt")
        #cal_file = os.path.join(calibration_folder, f"hs_gen_99_ch{20}_calibration.pt")

        if not os.path.exists(cal_file):
            print(f"  Calibration file not found for channel {channel}: {cal_file}")
            # Use original image if no calibration available
            deconv_stack[channel] = measurement_image
            continue

        try:
            # Load calibration data (cached across measurement files)
            center_psf = psf_cache.get(cal_file)
            if center_psf is None:
                cal_data = torch.load(cal_file, map_location=device)
                center_psf = cal_data['center_psf']
                psf_cache[cal_file] = center_psf

            # Perform Wiener deconvolution
            deconv_result = rdmpy.deconvolve(measurement_image, center_psf, device=device, process=False)

            # Store result
            deconv_stack[channel] = deconv_result.cpu().numpy() if torch.is_tensor(deconv_result) else deconv_result

            print(f"  Successfully deconvolved channel {channel}")

        except Exception as e:
            # Best effort: keep going, falling back to the input image.
            print(f"  Error processing channel {channel}: {e}")
            deconv_stack[channel] = measurement_image

    # Save deconvolved stack
    base_name = Path(meas_file).stem
    output_file = os.path.join(output_folder, f"{base_name}_deconvolved.tif")

    # Save as 32-bit float to preserve dynamic range
    tifffile.imwrite(output_file, deconv_stack.astype(np.float32))
    print(f"Saved deconvolved stack to: {output_file}")

    # Optional: Save individual channels as PNG for visualization
    png_folder = os.path.join(output_folder, f"{base_name}_channels")
    os.makedirs(png_folder, exist_ok=True)

    for channel in range(num_channels):
        png_file = os.path.join(png_folder, f"channel_{channel:02d}.png")
        plt.imsave(png_file, deconv_stack[channel], cmap='gray')

    print(f"Saved individual channel PNGs to: {png_folder}")

print("\nBatch deconvolution processing completed!")

In [None]:
import glob
import os
from pathlib import Path

import numpy as np
import tifffile
from matplotlib import cm
from PIL import Image

# Configuration - matching visualize_models_misc.py exactly
# Display range: values are clipped to [clip_min, clip_max] before colormapping.
# Adjust clip_max to match your data range (0.03, 0.2, 0.3, etc.).
clip_min, clip_max = 0.0, 0.6
# Channels 700, 600, 500 nm (adjust as needed)
selected_channels = [0, 12, 23]

# Folders to process - both original and intensity preserved deconvolution results
folders_to_process = [
    '/media/al/Extreme SSD/20250425_results/results/misc_test_dataset_layernorm/results/20250425_0gan_single_reg_hs_suberine/test_latest/deconvolved_images',
]

def apply_viridis_colormap_and_clipping(img_data, output_dir, base_name, method_suffix="",
                                        *, channels=None, vmin=None, vmax=None):
    """Apply viridis colormap and clipping exactly like visualize_models_misc.py.

    Each selected channel is clipped to ``[vmin, vmax]``, scaled to 0-255 as
    uint8, mapped through ``cm.viridis`` and saved as an RGB PNG named
    ``{base_name}_ch{channel}_clip[{vmin},{vmax}]{method_suffix}.png`` in
    ``output_dir``.

    Args:
        img_data: stack laid out as ``[channels, height, width]``.
        output_dir: existing directory to write the PNGs into.
        base_name: filename prefix.
        method_suffix: appended to each filename (e.g. ``"_original"``).
        channels: channel indices to export; defaults to the notebook-level
            ``selected_channels``. Indices past the last channel are skipped.
        vmin, vmax: clipping range; default to the notebook-level
            ``clip_min`` / ``clip_max``.

    Returns:
        True if the stack was processed. False if it is not a 3D stack with
        more than 3 channels, if ``vmax <= vmin``, or if an error occurred.
    """
    # Fall back to the notebook-level configuration for anything not given.
    channels = selected_channels if channels is None else channels
    vmin = clip_min if vmin is None else vmin
    vmax = clip_max if vmax is None else vmax
    try:
        # Check if image has enough channels for hyperspectral processing
        if not (img_data.ndim == 3 and img_data.shape[0] > 3):
            print(f"Not enough channels for hyperspectral processing (shape: {img_data.shape})")
            return False
        if vmax <= vmin:
            # Would otherwise divide by zero in the normalization below.
            raise ValueError(f"clip range must satisfy vmin < vmax, got [{vmin}, {vmax}]")

        for channel in channels:
            if channel >= img_data.shape[0]:
                continue
            channel_data = img_data[channel, :, :]

            # NaNs would make the uint8 cast undefined; map them to vmin.
            channel_data = np.nan_to_num(channel_data, nan=vmin)

            # Normalize data using min/max clipping (EXACT same as visualize_models_misc.py)
            clipped_data = np.clip(channel_data, vmin, vmax)
            channel_norm = ((clipped_data - vmin) / (vmax - vmin) * 255).astype(np.uint8)

            # Apply viridis colormap. Integer input indexes the colormap's
            # lookup table directly.
            colored_img = cm.viridis(channel_norm)
            colored_img = (colored_img[:, :, :3] * 255).astype(np.uint8)

            # Save the image with same naming convention
            file_name = f"{base_name}_ch{channel}_clip[{vmin},{vmax}]{method_suffix}.png"
            Image.fromarray(colored_img).save(os.path.join(output_dir, file_name))
            print(f"Saved visualization: {file_name}")
        return True

    except Exception as e:
        print(f"Error applying colormap: {str(e)}")
        return False

# Process each configured deconvolution result folder
for folder_path in folders_to_process:
    if not os.path.exists(folder_path):
        print(f"Folder not found: {folder_path}")
        continue

    print(f"\nProcessing folder: {folder_path}")

    # Determine method suffix based on folder name
    method_suffix = "_intensity_preserved" if "intensity_preserved" in folder_path else "_original"

    # Create visualization output directory
    vis_output_dir = os.path.join(folder_path, "viridis_visualizations")
    os.makedirs(vis_output_dir, exist_ok=True)

    # Find all deconvolved TIF files (sorted for a reproducible order)
    tif_files = sorted(glob.glob(os.path.join(folder_path, "*_deconvolved.tif")))
    if not tif_files:
        print(f"No *_deconvolved.tif files found in {folder_path}")

    for tif_file in tif_files:
        print(f"Visualizing: {tif_file}")

        try:
            # Load the deconvolved image stack
            img_data = tifffile.imread(tif_file)
            print(f"Loaded image with shape: {img_data.shape}")

            # Extract base name
            base_name = Path(tif_file).stem.replace("_deconvolved", "")

            # Apply viridis colormap and clipping (same as visualize_models_misc.py)
            apply_viridis_colormap_and_clipping(img_data, vis_output_dir, base_name, method_suffix)

        except Exception as e:
            print(f"Error processing {tif_file}: {str(e)}")

# The summary reflects however many folders are configured (previously it
# always claimed "both" output directories).
print("\nVisualization processing completed!")
print("Results saved to the 'viridis_visualizations' subfolder of each processed deconvolution output directory")
print(f"Using clipping range: [{clip_min}, {clip_max}]")
print(f"Selected channels: {selected_channels}")
print("Colormap: viridis (same as visualize_models_misc.py)")

## Alternate methods: Blind deconvolution
By co-optimizing the spherical Seidel coefficient with the reconstruction, we can perform calibration-free deconvolution.

In [None]:
# Blind deconvolution uses neither the PSF data nor the Seidel coefficients:
# it only needs the blurry measurement.
deconv_blind = rdmpy.blind(measurement, device=device, process=True, balance=3e-4)

# Compare with the first fitted Seidel coefficient from calibration
# (presumably the spherical term the blind method co-optimizes).
print('Pretty close to the fitted coefficient: ', seidel_coeffs[0].cpu())

blind_out_path = '/media/al/Extreme SSD/20250701_usaf/deconvolution_20250706/blind_3e-4.png'
plt.imsave(blind_out_path, deconv_blind, cmap='gray')
plt.imshow(crop(deconv_blind, c), cmap='gray')
plt.show()


## Batch Processing: Blind Deconvolution for TIF Stacks

In [None]:
import glob
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import tifffile
import torch

import rdmpy

# Configuration
measurement_folder = '/media/al/Extreme SSD/20250701_usaf/20250425_0gan_single_reg_hs_usaf/test_latest/images'  # Update this path
output_folder = '/media/al/Extreme SSD/20250701_usaf/deconvolution_results/batch_blind'
measurement_pattern = '*.tif'  # Pattern for measurement files
max_channels = 30  # Deconvolve at most this many channels per stack
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Blind deconvolution parameters
balance = 3e-4  # Adjust this parameter as needed

# Create output directory
os.makedirs(output_folder, exist_ok=True)

# Find measurement files (sorted so the processing order is reproducible)
measurement_files = sorted(glob.glob(os.path.join(measurement_folder, measurement_pattern)))
print(f"Found {len(measurement_files)} measurement files")

# Process each measurement file
for meas_file in measurement_files:
    print(f"\nProcessing: {meas_file}")

    # Load the measurement image stack, indexed below as [channel, row, col]
    measurement_stack = tifffile.imread(meas_file)
    print(f"Measurement stack shape: {measurement_stack.shape}")
    if measurement_stack.ndim == 2:
        # A single 2D image: treat it as a one-channel stack.
        measurement_stack = measurement_stack[np.newaxis]

    # Process the first `max_channels` channels
    num_channels = min(max_channels, measurement_stack.shape[0])

    # Float output buffer. np.zeros_like on an integer TIF stack would give an
    # integer array and silently truncate the normalized results to 0.
    deconv_stack = np.zeros((num_channels,) + measurement_stack.shape[1:], dtype=np.float32)

    for channel in range(num_channels):
        print(f"Processing channel {channel + 1}/{num_channels}")

        # Extract single channel and normalize it to a peak of 1. Without the
        # guard, an all-zero channel would become all NaN.
        measurement_image = measurement_stack[channel, :, :]
        channel_max = measurement_image.max()
        if channel_max > 0:
            measurement_image = measurement_image / channel_max
        else:
            print(f"  Channel {channel} is all zeros; not normalizing")
            measurement_image = measurement_image.astype(np.float64)

        try:
            # Perform blind deconvolution
            deconv_result = rdmpy.blind(measurement_image, device=device, process=True, balance=balance)

            # Store result
            deconv_stack[channel] = deconv_result.cpu().numpy() if torch.is_tensor(deconv_result) else deconv_result

            print(f"  Successfully deconvolved channel {channel}")

        except Exception as e:
            # Best effort: keep going, falling back to the input image.
            print(f"  Error processing channel {channel}: {e}")
            deconv_stack[channel] = measurement_image

    # Save deconvolved stack
    base_name = Path(meas_file).stem
    output_file = os.path.join(output_folder, f"{base_name}_blind_deconvolved.tif")

    # Save as 32-bit float to preserve dynamic range
    tifffile.imwrite(output_file, deconv_stack.astype(np.float32))
    print(f"Saved blind deconvolved stack to: {output_file}")

    # Optional: Save individual channels as PNG for visualization
    png_folder = os.path.join(output_folder, f"{base_name}_blind_channels")
    os.makedirs(png_folder, exist_ok=True)

    for channel in range(num_channels):
        png_file = os.path.join(png_folder, f"channel_{channel:02d}.png")
        plt.imsave(png_file, deconv_stack[channel], cmap='gray')

    print(f"Saved individual channel PNGs to: {png_folder}")

print("\nBatch blind deconvolution processing completed!")