In [1]:
import torch
import torch.nn as nn
from torchkbnufft import KbNufft, KbNufftAdjoint
import numpy as np
import h5py
from einops import rearrange
import nibabel as nib

In [2]:
class RadialDCLayerSingleCoil(nn.Module):
    def __init__(
        self,
        im_size,
        grid_size,
        lambda_init: float = np.log(np.exp(1) - 1.0) / 1.0,
        learnable: bool = True,
        device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    ):
        super().__init__()
        self.device = device
        self.lambda_ = nn.Parameter(torch.ones(1) * lambda_init, requires_grad=learnable).to(device)

        # Forward NUFFT and adjoint (no smaps)
        self.nufft_op   = KbNufft(im_size=im_size, grid_size=grid_size).to(device)
        self.adjnufft_op = KbNufftAdjoint(im_size=im_size, grid_size=grid_size).to(device)

    def forward(self, x, y_radial, ktraj):
        """
        x: (batch, 1, H, W, 2) real
        y_radial: (n_samples, n_spokes, 2) real
        ktraj: (1, 2, n_samples*n_spokes) float
        """
        # x_c = x.to(dtype=torch.complex64)

        # (1) Forward NUFFT (no smaps => single‐coil)
        A_x = self.nufft_op(x.contiguous(), ktraj.contiguous(), smaps=None, norm='ortho')  # (batch, 1, n_samples*n_spokes, 2)

        # reshape simulated k-space
        A_x = rearrange(A_x, "b c r i -> b c i r ")#.to(dtype)
        A_x = torch.reshape(A_x, (1, 1, 2, 288, 640, 1)).squeeze()
        A_x = rearrange(A_x, 'i sp sam -> sam sp i')

        # (2) Weighted combine
        lambda_c = torch.sigmoid(self.lambda_).type(torch.complex64)
        k_dc = lambda_c * A_x + (1 - lambda_c) * y_radial

        # (3) Adjoint NUFFT
        x_dc = self.adjnufft_op(k_dc, ktraj, smaps=None, norm='ortho')  # (batch, 1, H, W)
        return x_dc
    
    def extra_repr(self):
        return f"lambda (raw)={self.lambda_.item():.4g}, learnable={self.lambda_.requires_grad}"



In [3]:
def get_ktraj(N_spokes, N_time, base_res, device):
    """
    Precompute k-space trajectory for efficiency.
    """
    N_tot_spokes = N_spokes * N_time
    N_samples = base_res * 2

    base_lin = torch.arange(N_samples, dtype=torch.float32).to(device) - base_res
    tau = 0.5 * (1 + 5**0.5)
    base_rad = torch.pi / (1 + tau - 1)
    base_rot = torch.arange(N_tot_spokes, dtype=torch.float32).to(device) * base_rad

    traj_x = torch.cos(base_rot).unsqueeze(1) @ base_lin.unsqueeze(0)
    traj_y = torch.sin(base_rot).unsqueeze(1) @ base_lin.unsqueeze(0)
    traj = torch.stack([traj_x, traj_y], dim=-1) / 2  # Shape: (N_tot_spokes, N_samples, 2)

    # reshape the trajectory to be compatible with torchkbnufft
    traj = traj.reshape(N_time, N_spokes * N_samples, 2)#.transpose(1, 0, 2)
    # traj = rearrange(traj, 't len i -> t i len')
    traj = rearrange(traj, 't len i -> len t i')
    
    # normalize
    # traj /= torch.mean(torch.abs(traj))

    traj = traj*torch.tensor([1, -1]).to(device)

    traj = rearrange(traj, "len t i -> t i len")  # shape: (2, N_TIME, N_SPOKES)

    return traj


In [4]:
# define trajectory
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

N_spokes = 288
N_time = 1
N_samples = 640

ktraj_tensor = get_ktraj(N_spokes, N_time, N_samples // 2, device)
print(ktraj_tensor.shape)

torch.Size([1, 2, 184320])


In [5]:
# load k-space
kspace_path = '/ess/scratch/scratch1/rachelgordon/fastMRI_breast_data/fastMRI_breast_IDS_001_010/fastMRI_breast_001_2.h5'

f = h5py.File(kspace_path, 'r')
original_kspace = f['kspace'][:].T

original_kspace = torch.tensor(original_kspace, dtype=torch.float32)
print(original_kspace.shape)

y_radial = original_kspace[0][0]
print(y_radial.shape)

torch.Size([83, 16, 640, 288, 2])
torch.Size([640, 288, 2])


In [6]:
# load image 

def load_nii(path):
    """
    Load a NIfTI image from the specified path.
    Args:
    - path (str): File path to the NIfTI image.
    """
    nii_image = nib.load(path)
    return nii_image.get_fdata()

ground_truth_dir = "/ess/scratch/scratch1/rachelgordon/complex_fully_sampled/"
patient_id = "fastMRI_breast_001_1"
image_path = f"{ground_truth_dir}{patient_id}/slice_040_frame_000.nii"

image = load_nii(image_path)
image = torch.from_numpy(image)

image = rearrange(image, 'i h w -> h w i').unsqueeze(0).unsqueeze(0)
print(image.shape)


torch.Size([1, 1, 320, 320, 2])


In [7]:
# ---------------------------------------------------------------------
# 2) Choose an image‐domain size that matches how you built the NUFFT operators.
#    For example, if your images are 320×320:
# ---------------------------------------------------------------------
H = W = 320
im_size   = (H, W)        # (coils=1, H, W) for single‐coil
grid_size = (H * 2, W * 2)   # e.g. double‐density grid

In [8]:
# ---------------------------------------------------------------------
# 3) Build a standalone adjoint‐NUFFT operator to compare against the DC layer.
# ---------------------------------------------------------------------
adjnufft_op = KbNufftAdjoint(im_size=im_size, grid_size=grid_size).cuda()

In [9]:
# ---------------------------------------------------------------------
# 4) Instantiate your Radial DC layer, but force λ such that sigmoid(λ)=0.
#    For sigmoid(λ)=0 up to machine precision, you can pick raw λ = -20 (sigmoid(-20)≈2×10^(-9)).
#    Also turn off gradient‐updates by setting requires_grad=False (so λ stays fixed).
# ---------------------------------------------------------------------
raw_lambda_for_zero = -20.0
dc_layer = nn.Sequential()  # a dummy wrapper so we can set .lambda_ manually
dc = getattr(__import__("__main__"), "RadialDCLayerSingleCoil")(
    im_size=im_size,
    grid_size=grid_size,
    lambda_init=raw_lambda_for_zero,
    learnable=False,
    device=torch.device("cuda")
)
dc_layer.add_module("radial_dc", dc)

In [10]:
# ---------------------------------------------------------------------
# 5) Run the “pure adjoint” through your DC layer:
# ---------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
with torch.no_grad():
    x_dc_from_layer = dc_layer.radial_dc(
        x=image.to(device),
        y_radial=y_radial.to(device),
        ktraj=ktraj_tensor.to(device)
    )
    # Direct adjoint:
    x_direct_adjoint = adjnufft_op(y_radial, ktraj_tensor, smaps=None, norm='ortho')


RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/gpfs/data/karczmar-lab/workspaces/rachelgordon/micromamba/envs/recon_mri/lib/python3.11/site-packages/torchkbnufft/_nufft/interp.py", line 582, in sort_data
        data_ret = torch.cat([result[2] for result in results])
    else:
        tm_ret, omega_ret, data_ret = sort_one_batch(tm, omega, data, grid_size)
                                      ~~~~~~~~~~~~~~ <--- HERE

    return tm_ret, omega_ret, data_ret
  File "/gpfs/data/karczmar-lab/workspaces/rachelgordon/micromamba/envs/recon_mri/lib/python3.11/site-packages/torchkbnufft/_nufft/interp.py", line 562, in sort_one_batch
    _, indices = torch.sort(tmp)

    return tm[:, indices], omega[:, indices], data[:, :, indices]
                                              ~~~~~~~~~~~~~~~~~~~ <--- HERE
RuntimeError: CUDA out of memory. Tried to allocate 506.25 GiB. GPU 0 has a total capacity of 39.49 GiB of which 38.95 GiB is free. Including non-PyTorch memory, this process has 548.00 MiB memory in use. Of the allocated memory 29.67 MiB is allocated by PyTorch, and 14.33 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)


In [None]:
# ---------------------------------------------------------------------
# 6) Compare the two outputs:
# ---------------------------------------------------------------------
# Move to CPU & convert to numpy for easy comparison:
x1 = x_dc_from_layer.cpu().numpy()
x2 = x_direct_adjoint.cpu().numpy()

max_abs_diff = np.max(np.abs(x1 - x2))
print(f"Max absolute difference between DC‐layer output and direct adjoint: {max_abs_diff:.3e}")

tol = 1e-5
if max_abs_diff < tol:
    print("✅ PASS: Radial DC layer’s adjoint branch matches pure adjoint‐NUFFT.")
else:
    print("❌ FAIL: outputs differ by more than tolerance.")