In [None]:
from torch.nn.functional import pad
import torch.nn.functional as F
import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import time

# Process-wide environment overrides for this kernel:
#  - KMP_DUPLICATE_LIB_OK=TRUE lets execution continue when more than one copy
#    of the Intel OpenMP runtime gets loaded.
#    NOTE(review): presumably torch and numpy/MKL each bring one here; the
#    OpenMP runtime itself describes this as an unsafe workaround -- confirm
#    it is still needed.
#  - HDF5_USE_FILE_LOCKING=FALSE turns off HDF5 file locking.
os.environ.update({
    'KMP_DUPLICATE_LIB_OK': 'TRUE',
    'HDF5_USE_FILE_LOCKING': 'FALSE',
})

In [None]:
# parameters
simulation_num = 1375
layers = 8
attenuation = 0.1
zero_pads = 22
X, Y = 16, 16
upsampling_ratio = 16

# load circuit data from circuitFaker
train_ims = np.load('../TomoMAEv2/ic-64x64x8.npy')[:simulation_num, :X, :Y, :]
print(np.shape(train_ims))
# everything below assumes exactly this shape; fail early with a clear message
# rather than with a broadcasting/reshape error further down
assert train_ims.shape == (simulation_num, X, Y, layers), (
    f"expected {(simulation_num, X, Y, layers)}, got {train_ims.shape}")

# stack 3D circuits to 2D layer-wise circuits in layer-major order:
# row i*simulation_num + n holds layer i of circuit n.
# Done before upsampling while the array is still small; float64 matches the
# zero-filled float64 buffer the values used to be copied into.
layer_stack = np.moveaxis(train_ims, -1, 0).reshape(layers * simulation_num, X, Y).astype(np.float64)

# nearest-neighbour upsampling: each pixel becomes an
# upsampling_ratio x upsampling_ratio tile (replaces four nested Python loops)
train_ims_stack = np.repeat(np.repeat(layer_stack, upsampling_ratio, axis=1),
                            upsampling_ratio, axis=2)

# permute the order into a separate array so train_ims_stack stays layer-major
# for the preview cell (train_ims_stack[i*simulation_num] is layer i of the
# first circuit); same seed and same call, hence the same permutation
np.random.seed(0)
train_ims_shuffled = np.random.permutation(train_ims_stack)

# scale the binary circuit values to the physical attenuation coefficient
train_ims = (attenuation * train_ims_shuffled).astype(np.float32)

# zero padding around the X and Y plane, then add a singleton axis -> (N, 1, H, W)
train_ims = np.pad(train_ims, ((0, 0), (zero_pads, zero_pads), (zero_pads, zero_pads)))
train_ims = np.expand_dims(train_ims, axis=1)
print(np.shape(train_ims))

In [None]:
# Preview: one panel per layer showing row layer_idx*simulation_num of
# train_ims_stack, colour range pinned to [0, 1].
# The figure size is set through rcParams, i.e. globally for later figures too.
plt.rcParams['figure.figsize'] = [40*2, 10*2]
fig, panel_grid = plt.subplots(1, layers, squeeze=False)
fig.subplots_adjust(hspace=.05, wspace=.05)

for layer_idx, panel in enumerate(panel_grid[0]):
    layer_image = panel.imshow(train_ims_stack[layer_idx * simulation_num, :, :])
    layer_image.set_clim(0, 1)

In [None]:
def Projection2D_batch(data3d, batch_size=100):
    """
    Generate 2D cone-beam projections of a volume, batching over the angles.

    The geometry comes from module-level names set up elsewhere in the
    notebook: ``deg`` / ``nProj`` (projection angles in degrees and their
    count), ``xs``, ``ys`` (voxel-centre coordinates), ``us``, ``vs``
    (detector-pixel coordinates), ``dx``, ``dy``, ``dz``, ``nx``, ``ny``,
    ``nz``, ``DSD``, ``DSO`` and ``GPU``.

    Parameters
    ----------
    data3d : torch.Tensor
        Volume of shape (nz, nx, ny). ``requires_grad_()`` is called on it in
        place so callers can back-propagate through the projector.
        NOTE(review): the tensor plumbing requires nx == ny.
    batch_size : int
        Maximum number of angles handled per ``grid_sample`` call.

    Returns
    -------
    torch.Tensor
        Projections of shape (nProj, len(us), len(vs)), on cuda:0 when ``GPU``
        is set, otherwise on ``data3d``'s device.
    """
    # part of the existing contract: the input tracks gradients afterwards
    data3d.requires_grad_()

    # every grid lives on the same device as the volume it samples; with GPU
    # set, the volume is copied there too (autograd links the copy to data3d)
    device = torch.device('cuda:0') if GPU else data3d.device
    volume = data3d.to(device)

    # ---- angle-independent geometry: built once, not once per batch ----
    # detector-plane coordinates, each (nv, nu)
    uu, vv = torch.meshgrid(us, vs)
    uu, vv = uu.T, vv.T
    # object-plane coordinates, each (ny, nx)
    xx, yy = torch.meshgrid(xs, ys)
    xx, yy = xx.T, yy.T

    # per-slice cone-beam scale factor (ys + DSO) / DSD, shaped (ny, 1, 1) to
    # broadcast against the (nv, nu) detector grid
    ratio = ((ys + DSO) / DSD).unsqueeze(-1).unsqueeze(-1)

    # detector pixels mapped onto every slice, normalised to grid_sample's
    # [-1, 1] range; pupv is (ny, nv, nu, 2)
    pu = (uu * ratio / dx / nx * 2).to(torch.float32)
    pv = (vv * ratio / dz / nz * 2).to(torch.float32)
    pupv = torch.stack((pu, pv), -1).to(device)

    # per-pixel weight: dy scaled by sqrt(DSD^2 + u^2 + v^2) / DSD, (nv, nu)
    dist = (torch.sqrt(DSD**2 + uu**2 + vv**2) / DSD * dy).to(device)

    # ceil(nProj / batch_size) batches, so an exact multiple no longer adds a
    # trailing empty batch; keep at least one so torch.cat always gets an input
    n_batches = max(1, -(-nProj // batch_size))

    proj_tot = []
    for b in range(n_batches):
        angle_rad = deg[b * batch_size:(b + 1) * batch_size] / 360 * 2 * np.pi
        angle_rad = angle_rad.unsqueeze(-1).unsqueeze(-1)  # (batch, 1, 1)

        # rotated sampling grid normalised to [-1, 1]; rxry is (batch, ny, nx, 2)
        rx = (xx * torch.cos(angle_rad) - yy * torch.sin(angle_rad)) / dx / nx * 2
        ry = (xx * torch.sin(angle_rad) + yy * torch.cos(angle_rad)) / dy / ny * 2
        rxry = torch.stack((rx, ry), -1).to(device)

        # bilinear rotation of the volume for every angle in the batch;
        # expand() adds the batch dimension as a view, without copying
        rotated = F.grid_sample(volume.unsqueeze(0).expand(len(angle_rad), nz, nx, ny), rxry,
                                mode='bilinear', padding_mode='zeros', align_corners=True)

        # (batch, nz, ny, nx) -> (nx, batch, nz, ny): the last axis becomes the
        # grid_sample batch so each of its slices is resampled with its own
        # scaled detector grid from pupv
        temp = F.grid_sample(rotated.permute(3, 0, 1, 2), pupv,
                             mode='bilinear', padding_mode='zeros', align_corners=True)

        # sum over the slices -> (batch, nv, nu), transpose -> (batch, nu, nv),
        # then apply the per-pixel weight
        proj2d = torch.sum(temp, dim=[0]).permute(0, 2, 1)
        proj_tot.append(proj2d * dist.T)

    # concatenate along the angle dimension
    return torch.cat(proj_tot, dim=0)

In [None]:
# using GPU?
GPU = False

# define number of object voxels
nx, ny, nz = 300, 300, 1

# object real size in mm
sx, sy, sz = 8, 8, 16/128

# real detector pixel density
nu, nv = 256, 1

# detector real size in mm
su, sv = 25, 25/256

# single voxel size
dx, dy, dz, du, dv = sx/nx, sy/ny, sz/nz, su/nu, sv/nv

# Geometry calculation: voxel / detector-pixel centre coordinates in mm
xs = torch.arange(-(nx -1) / 2, nx / 2, 1) * dx
ys = torch.arange(-(ny -1) / 2, ny / 2, 1) * dy
zs = torch.arange(-(nz -1) / 2, nz / 2, 1) * dz

us = torch.arange(-(nu -1) / 2, nu / 2, 1) * du
vs = torch.arange(-(nv -1) / 2, nv / 2, 1) * dv

# source to detector, and source to object axis distances in mm
DSD, DSO = 230.11+342.17, 230.11

# projection angles in degrees -- identical for every object, so built once.
# Projection2D_batch reads deg / nProj as module-level names.
direction = 1 # clock wise/ counter clockwise
start = 0.
end = 360.
step = 360/1601
# NOTE(review): float arange, so the angle count is ceil((end - step - start) / step);
# the output file name suggests 1600 -- consider asserting nProj once confirmed
deg = direction * torch.arange(start, end-step, step)
nProj = len(deg)

tot_projections = []

# forward projection only: Projection2D_batch marks its input as requiring
# grad, so without no_grad autograd would record a graph for every object
with torch.no_grad():
    for obj in tqdm(train_ims):
        projections = Projection2D_batch(torch.from_numpy(obj.astype(np.float32)).to('cpu'),
                                         batch_size=1600).to('cpu').detach().numpy()
        tot_projections.append(projections)

# the per-object arrays are already float32; copy=False avoids a second full copy
tot_projections = np.array(tot_projections).astype(np.float32, copy=False)
print(np.shape(tot_projections))
np.save('tot_proj_ic_1600.npy', tot_projections)

In [None]:
# Take the last 1000 noise-free projection sets (presumably the held-out test
# split -- TODO confirm) and simulate photon-counting noise at several doses.
noise_free = np.load('tot_proj_ic_1600.npy')[-1000:]
print("shape", np.shape(noise_free))
print("noise_free", np.max(noise_free), np.min(noise_free))

# Dedicated, seeded generator: the noise no longer depends on how much of the
# global NumPy RNG earlier cells happened to consume, so re-running this cell
# reproduces the same files.
NOISE_SEED = 0
rng = np.random.default_rng(NOISE_SEED)

PHOTON_COUNTS = [32, 40, 50, 64, 80, 100, 128, 160, 200, 256, 320, 400, 500,
                 640, 800, 1000, 1280, 1600, 2000]

for photon in PHOTON_COUNTS:
    # Beer-Lambert: expected photon count behind each line integral
    measured = photon * np.exp(-noise_free)
    print("measured photons", np.max(measured), np.min(measured))

    noisy = rng.poisson(measured)
    print("noisy photons", np.max(noisy), np.min(noisy))

    # back to line integrals: -log(counts / photon); zero-count pixels stay 0.
    # The zero-initialised out= buffer makes that explicit (np.log with where=
    # but no out= leaves the skipped entries uninitialised).
    # NOTE(review): zero counts mean very strong attenuation, yet they map to 0
    # (no attenuation) -- confirm this is intended.
    log_ratio = np.zeros(noisy.shape)
    np.log(noisy / photon, out=log_ratio, where=noisy > 0)
    convert_back = -log_ratio
    print("noisy measurements", np.max(convert_back), np.min(convert_back))

    np.save('tot_proj_test_ic_1600_photon_' + str(photon) + '.npy', convert_back.astype(np.float32))