<a href="https://colab.research.google.com/github/safaiat/CT_EXP/blob/main/MBIR_Full_Safaiat/A_matrix.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import math
import numpy as np
import matplotlib.pyplot as plt

In [2]:
import math

def build_params():


  imgparams = {


  "Nx": 256,                 # pixels in X
  "Ny": 256,                 # pixels in Y
  "Nz": 1,                   # slices
  "FirstSliceNumber": 0,     #
  "Deltaxy": 1.0,            # pixel size, mm
  "DeltaZ": 1.0,             # mm
  "ROIRadius": 128.0

  }

  sinoparams = {
      "NViews": 64,             # number of projection angles
      "NChannels": 512,          # detector bins
      "NSlices": 1,              # slices
      "FirstSliceNumber": 0,     # slice index
      "DeltaChannel": 1.0,       # mm,
      "CenterOffset": 0.0,       # detector center shift in bins
      "DeltaSlice": 1.0,         # mm, slice spacing (dummy=1 for 2D)
      "ViewAngles": []


  }


  step_size = math.pi / sinoparams["NViews"]

  for i in range(sinoparams["NViews"]):
    sinoparams["ViewAngles"].append(i * step_size)



  return imgparams, sinoparams




In [3]:
imgparams, sinoparams = build_params()
LEN_PIX =511

In [4]:
def pixel_center_from_index(j, imgparams):

    Nx = imgparams["Nx"]
    Ny = imgparams["Ny"]
    dxy = imgparams["Deltaxy"]

    # 1) Unflatten j -> (jx, jy), that means convert j into x,y
    jx = j % Nx
    jy = j // Nx

    # 2) Convert to physical coordinates (mm), origin at image center
    x = (jx - (Nx - 1) / 2.0) * dxy
    y = (jy - (Ny - 1) / 2.0) * dxy

    return x, y





In [5]:
def detector_centers(sinoparams):

    N = sinoparams["NChannels"]
    d = sinoparams["DeltaChannel"]
    off = sinoparams["CenterOffset"]


    t_centers = [(i - ((N - 1) / 2.0) - off) * d for i in range(N)]
    return t_centers

In [6]:
import torch, math

def ComputePixelProfile3DParallel(sinoparams, imgparams, LEN_PIX, device=None, dtype=None):


    # ---- device/dtype defaults ----
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if dtype is None:
        dtype = torch.float32

    # ---- allocate output; mirror "get_img(LEN_PIX, NViews, sizeof(float))" ----
    NViews = int(sinoparams["NViews"])
    pix_prof = torch.zeros((NViews, int(LEN_PIX)), dtype=dtype, device=device)

    # ---- copy C variable names/logic exactly ----
    pi = math.pi                                   # PI
    rc = math.sin(pi/4.0)                          # sin(pi/4)
    DeltaPix = float(imgparams["Deltaxy"])         # pixel size
    # NOTE: we read angles from sinoparams["ViewAngles"] like the C code

    for i in range(NViews):
        ang = float(sinoparams["ViewAngles"][i])   # ang = sinoparams->ViewAngles[i]

        # while (ang >= pi/2) ang -= pi/2; while (ang < 0) ang += pi/2;
        while ang >= pi/2.0:
            ang -= pi/2.0
        while ang < 0.0:
            ang += pi/2.0

        # if (ang <= pi/4) maxval = DeltaPix/cos(ang); else maxval = DeltaPix/cos(pi/2-ang);
        if ang <= pi/4.0:
            maxval = DeltaPix / max(math.cos(ang), 1e-12)
        else:
            maxval = DeltaPix / max(math.cos(pi/2.0 - ang), 1e-12)

        # d1 = rc*cos(pi/4 - ang);  d2 = rc*fabs(sin(pi/4 - ang));
        d1 = rc * math.cos(pi/4.0 - ang)
        d2 = rc * abs(math.sin(pi/4.0 - ang))

        # this is for the centering to the pixel profile which is [0,2] later it will be adjusted but the grid is 2 unit long.
        t_1 = 1.0 - d1
        t_2 = 1.0 - d2
        t_3 = 1.0 + d2
        t_4 = 1.0 + d1

        # for (j=0; j<LEN_PIX; j++) { t = 2.0*j/(float)LEN_PIX; ... }
        for j in range(LEN_PIX):
            t = 2.0 * j / float(LEN_PIX)

            if (t <= t_1) or (t > t_4):
                val = 0.0
            elif t <= t_2:
                denom = (t_2 - t_1) if (t_2 - t_1) != 0.0 else 1e-12
                val = maxval * (t - t_1) / denom
            elif t <= t_3:
                val = maxval
            else:
                denom = (t_4 - t_3) if (t_4 - t_3) != 0.0 else 1e-12
                val = maxval * (t_4 - t) / denom

            pix_prof[i, j] = val

    return pix_prof


In [7]:
#this detects where the center of the pixel casts shadow in the detector pixel in mm

def channel_window_for_pixel(x, y, theta, DeltaPix, t0, DeltaChannel, NChannels):

    s0 = y*math.cos(theta) - x*math.sin(theta)  # pixel projection center (mm)

    # Support of the footprint is within s0 ± DeltaPix
    t_min = s0 - DeltaPix
    t_max = s0 + DeltaPix

    # Convert to bin indices; include half-bin extent with ±0.5
    i_min = math.ceil((t_min - t0) / DeltaChannel - 0.5)
    i_max = math.floor((t_max - t0) / DeltaChannel + 0.5)

    # Clamp to valid range
    i_min = max(i_min, 0)
    i_max = min(i_max, NChannels - 1)
    #i_min and i_max are the maximum and minimum shadow range and and s0 is the center of the pixel shadow on the detector bin

    return i_min, i_max, s0


In [8]:
#this converts the mm into to exact detector pixel number
def profile_index_for_channel(t_i, s0, DeltaPix, LEN_PIX):

    # ((t_i - s0) + DeltaPix) / (2*DeltaPix) maps [-Δpix, +Δpix] -> [0, 1]
    # Multiply by (LEN_PIX - 1) to get index range, then round to nearest
    u_float = ((t_i - s0) + DeltaPix) * (LEN_PIX - 1) / (2.0 * DeltaPix)
    u = int(math.floor(u_float + 0.5))
    # Clip
    if u < 0:
        u = 0
    elif u >= LEN_PIX:
        u = LEN_PIX - 1
    return u


In [9]:

#Given the view (angle) and detector element position (mapped to u_idx), return the trapezoid footprint weight for that pixel.
def sample_profile(pix_prof, view_idx, u_idx):
    return float(pix_prof[view_idx, u_idx].item())


In [10]:
#creates sparese matreix column
def compute_sysmatrix_column(j, imgparams, sinoparams, pix_prof, t_centers, LEN_PIX):

    Nx = imgparams["Nx"]; Ny = imgparams["Ny"]; DeltaPix = imgparams["Deltaxy"]
    NViews = sinoparams["NViews"]; NChannels = sinoparams["NChannels"]
    DeltaChannel = sinoparams["DeltaChannel"]
    ViewAngles = sinoparams["ViewAngles"]
    t0 = t_centers[0]

    # 1) pixel center
    x, y = pixel_center_from_index(j, imgparams)

    rows, vals = [], []

    # 2) loop views
    for p in range(NViews):
        theta = ViewAngles[p]

        # window of potentially nonzero bins + s0
        i_min, i_max, s0 = channel_window_for_pixel(
            x, y, theta, DeltaPix, t0, DeltaChannel, NChannels
        )
        if i_min > i_max:
            continue

        # 3) loop only bins in the window
        for i in range(i_min, i_max + 1):
            u = profile_index_for_channel(t_centers[i], s0, DeltaPix, LEN_PIX)
            w = sample_profile(pix_prof, p, u)  # simple mode

            if w > 0.0:
                rows.append(p * NChannels + i)
                vals.append(w)

    return rows, vals


In [11]:
from tqdm import tqdm   # add this at the top of your file

# assembles A-matrix
def assemble_sysmatrix(imgparams, sinoparams, LEN_PIX):

    # Precompute detector centers and per-view profiles
    t_centers = detector_centers(sinoparams)
    pix_prof = ComputePixelProfile3DParallel(sinoparams, imgparams, LEN_PIX=LEN_PIX)

    Nx = imgparams["Nx"]; Ny = imgparams["Ny"]
    Ncols = Nx * Ny
    columns = [None] * Ncols

    # wrap range(Ncols) with tqdm to track loop progress
    for j in tqdm(range(Ncols), desc="Building system matrix", unit="col"):
        rows, vals = compute_sysmatrix_column(
            j, imgparams, sinoparams, pix_prof, t_centers, LEN_PIX
        )
        columns[j] = {"row_idx": rows, "val": vals}

    return {"Ncolumns": Ncols, "columns": columns}



In [12]:
import numpy as np
import json
import os

def save_sysmatrix(A, imgparams, sinoparams, path):


    Ncols = A["Ncolumns"]

    indptr = np.zeros(Ncols + 1, dtype=np.int64)
    indices = []
    data = []

    nnz = 0
    for j in range(Ncols):
        rows = A["columns"][j]["row_idx"]
        vals = A["columns"][j]["val"]

        indptr[j] = nnz
        indices.extend(rows)
        data.extend(vals)
        nnz += len(rows)

    indptr[Ncols] = nnz

    # Save CSC arrays
    np.savez_compressed(
        path + "_sysmat.npz",
        indptr=indptr,
        indices=np.array(indices, dtype=np.int32),
        data=np.array(data, dtype=np.float32)
    )

    # Save metadata
    meta = {
        "imgparams": imgparams,
        "sinoparams": {
            "NViews": sinoparams["NViews"],
            "NChannels": sinoparams["NChannels"],
            "DeltaChannel": sinoparams["DeltaChannel"],
            "CenterOffset": sinoparams["CenterOffset"],
            "ViewAngles": sinoparams["ViewAngles"]  # may be big
        },
        "row_encoding": "row = view * NChannels + channel",
        "Ncolumns": Ncols,
        "Nrows": sinoparams["NViews"] * sinoparams["NChannels"]
    }

    with open(path + "_meta.json", "w") as f:
        json.dump(meta, f, indent=2)

    print(f"Saved matrix with {nnz} nonzeros to {path}_sysmat.npz")
    print(f"Metadata written to {path}_meta.json")


In [13]:
# Assemble A
A = assemble_sysmatrix(imgparams, sinoparams, LEN_PIX)

# Save to disk
save_sysmatrix(A, imgparams, sinoparams, "my_sysmatrix")


Building system matrix: 100%|██████████| 65536/65536 [01:58<00:00, 555.05col/s]


Saved matrix with 5347526 nonzeros to my_sysmatrix_sysmat.npz
Metadata written to my_sysmatrix_meta.json
