In [1]:
from random import random, seed
from warnings import filterwarnings
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchkbnufft as tkbn
from skimage.data import shepp_logan_phantom

SEED = 1

# Seed every RNG the notebook draws from. Seeding only Python's `random`
# module (as before) left the trajectory and test image -- which come from
# `np.random` in the next cell -- different on every run.
seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Intended to hide floor-divide deprecation warnings.
# NOTE(review): with no category/message filter this ignores *all* warnings.
filterwarnings("ignore")

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
nf = 40000            # number of non-Cartesian k-space samples
nx = 64               # image side length along each of the 3 axes
ncoil = 1
im_size = (nx, nx, nx)

# Uniformly random 3-D trajectory scaled to [-pi, pi), kept in float32.
coords = 2*np.pi*np.random.rand(3, nf).astype(np.float32) - np.pi

# Random complex test image: the real part is drawn before the imaginary part,
# matching the original draw order so a seeded RNG gives the same image.
image_shape = (ncoil,) + im_size
real_part = np.random.rand(*image_shape).astype(np.float32)
imag_part = np.random.rand(*image_shape).astype(np.float32)
image = (real_part + 1j*imag_part).astype(np.complex64)

In [3]:
# Convert the k-space trajectory and test image to tensors on `device`.
# torch.as_tensor returns its argument unchanged when it is already a tensor
# on `device`, and the reshape pins the image to (1, ncoil, *im_size) -- a
# leading singleton axis (presumably the batch axis torchkbnufft expects) in
# front of the coil and spatial axes. Re-running this cell is therefore a
# no-op, whereas the previous in-place `unsqueeze_(0)` added another axis on
# every re-run.
coords = torch.as_tensor(coords, device=device)
print('coords shape: {}'.format(coords.shape))
image = torch.as_tensor(image, device=device).reshape(1, ncoil, *im_size)
print('image shape: {}'.format(image.shape))

coords shape: torch.Size([3, 40000])
image shape: torch.Size([1, 1, 64, 64, 64])


In [5]:
# Precompute the Toeplitz kernel for `coords` with "ortho" FFT normalization.
# No density-compensation weights are passed, so the kernel is built
# without density compensation.
normal_kernel = tkbn.calc_toeplitz_kernel(coords, im_size, norm="ortho")

# Density-compensation weights for the same trajectory (not used by the
# kernel above).
dcomp = tkbn.calc_density_compensation_function(ktraj=coords, im_size=im_size)

# Operator that applies a precomputed Toeplitz kernel to an image.
toep_ob = tkbn.ToepNufft()

In [8]:
# Shapes of the trajectory and the density-compensation weights.
for shaped in (coords, dcomp):
    print(shaped.shape)

# Largest magnitude of the kernel's real and imaginary parts.
for label, part in (('normal real: ', torch.real), ('normal imag: ', torch.imag)):
    print(label, part(normal_kernel).abs().max())


torch.Size([3, 400000])
torch.Size([1, 1, 400000])
normal real:  tensor(1.9877, device='cuda:0')
normal imag:  tensor(3.0999e-07, device='cuda:0')


In [6]:
# Time the forward NUFFT (image -> k-space) on `device`.
# `nufft_ob` was never defined in this notebook, so this cell raised a
# NameError on a fresh kernel. Build the forward operator here and move it to
# `device` so it matches `image` and `coords`.
nufft_ob = tkbn.KbNufft(im_size=im_size).to(device)

# CUDA work is launched asynchronously, so synchronize around the timed region
# to measure the actual computation. torch.cuda.synchronize() fails when no
# CUDA device is present, so only call it on the GPU path.
sync = torch.cuda.synchronize if device.type == "cuda" else (lambda: None)

# Single call (likely includes one-time warm-up costs).
sync()
start = time.perf_counter()
kdata = nufft_ob(image, coords)
sync()
end = time.perf_counter()
print(end - start)

# Total wall time for `n_repeats` calls. A fresh random scalar is added to the
# image each time, presumably so every call sees a different input.
n_repeats = 10
sync()
start = time.perf_counter()
for _ in range(n_repeats):
    kdata = nufft_ob(image + random(), coords)
sync()
end = time.perf_counter()
print(end - start)


0.18611788749694824
1.6487987041473389


In [10]:
# Time the Toeplitz operator `toep_ob` applied to `image` with the
# precomputed `normal_kernel` ("ortho" normalization, matching how the
# kernel was built). The previous "# calculate k-space data" comments were
# copied from the forward-NUFFT cell and did not describe this step.

# CUDA work is launched asynchronously, so synchronize around the timed region
# to measure the actual computation. torch.cuda.synchronize() fails when no
# CUDA device is present, so only call it on the GPU path.
sync = torch.cuda.synchronize if device.type == "cuda" else (lambda: None)

# Single call (likely includes one-time warm-up costs).
sync()
start = time.perf_counter()
back = toep_ob(image, normal_kernel, norm="ortho")
sync()
end = time.perf_counter()
print(end - start)

# Total wall time for `n_repeats` calls. A fresh random scalar is added to the
# image each time, presumably so every call sees a different input.
n_repeats = 100
sync()
start = time.perf_counter()
for _ in range(n_repeats):
    back = toep_ob(image + random(), normal_kernel, norm="ortho")
sync()
end = time.perf_counter()
print(end - start)


0.3881187438964844
3.834625720977783
