In [1]:
import os
import numpy as np
import nibabel as nib
from skimage.transform import resize, radon
import tomopy
from tqdm import tqdm

# -------------------------------
# Parameters
# -------------------------------
N = 128                    # resize target per slice (N x N)
pixel_size = 1e-4          # pixel size in meters
energy_keV = 20.0          # X-ray energy in keV (also passed to TomoPy phase retrieval)
# Derive the wavelength from the energy, lambda[m] = hc / E = 1.23984193e-9 keV*m / E[keV],
# so the propagation step (wavelength) and the retrieval step (energy_keV) always
# describe the same beam. 20 keV -> ~6.2e-11 m (0.062 nm).
wavelength = 1.23984193e-9 / energy_keV
k = 2 * np.pi / wavelength
beta = 1e-10                # reduced to avoid overflow
delta = 5e-8                # reduced to avoid overflow
distances = [0.01, 0.001, 0.1]  # detector distances in meters (this order = channel order in X)
num_angles = 180
theta = np.linspace(0., 180., num_angles, endpoint=False)

data_dir = "./ct_scans"
# os.listdir returns entries in arbitrary order; sort so the slice order of the
# saved dataset (and hence any index-based split) is reproducible.
nii_files = sorted(f for f in os.listdir(data_dir) if f.endswith((".nii", ".nii.gz")))
print(f"Trovati {len(nii_files)} file NIfTI.")

# -------------------------------
# Phase contrast propagation
# -------------------------------
def propagate_phase_contrast(img, z, pixel_size, wavelength, delta, beta):
    """Simulate free-space phase-contrast propagation of a 2D slice.

    The slice is turned into a complex transmission function
    ``exp(-k*beta*img) * exp(-1j*k*delta*img)`` (k = 2*pi/wavelength). That
    function is propagated over a distance ``z`` with the angular-spectrum
    transfer function, and the detected intensity ``|wave|**2`` is returned.

    Parameters
    ----------
    img : 2D array_like
        Input slice. Any (rows, cols) shape is accepted.
    z : float
        Propagation distance in meters.
    pixel_size : float
        Pixel size in meters.
    wavelength : float
        X-ray wavelength in meters.
    delta, beta : float
        Refractive-index decrement and absorption index.

    Returns
    -------
    ndarray of float, same shape as ``img``
        Intensity at the detector plane.
    """
    img = np.asarray(img, dtype=np.float32)
    # Use the wavelength argument, not a module-level k, so both factors of the
    # transmission and the transfer function see the same wavelength.
    k = 2 * np.pi / wavelength
    trans = np.exp(-k * beta * img) * np.exp(-1j * k * delta * img)

    # Frequency grids follow the actual image shape (rows -> fy, cols -> fx).
    n_rows, n_cols = img.shape
    fx = np.fft.fftfreq(n_cols, pixel_size)
    fy = np.fft.fftfreq(n_rows, pixel_size)
    FX, FY = np.meshgrid(fx, fy, indexing='xy')  # shape (n_rows, n_cols)
    q = (wavelength * FX) ** 2 + (wavelength * FY) ** 2
    sqrt_term = np.sqrt((1.0 - q).astype(np.complex128))
    # exp(1j*k*z*sqrt(1-q)) == exp(1j*k*z) * exp(1j*k*z*(sqrt(1-q) - 1)).
    # The constant factor exp(1j*k*z) (~1e10 rad here) does not change |wave|**2,
    # so it is dropped. sqrt(1-q) - 1 is written as -q / (1 + sqrt(1-q)) to
    # avoid the cancellation of computing 1 - q with q ~ 1e-13. The identity
    # also holds for evanescent components (q > 1), so their decay is unchanged.
    H = np.exp(1j * k * z * (-q / (1.0 + sqrt_term)))

    wave = np.fft.ifft2(np.fft.fft2(trans) * H)
    I_phase = np.abs(wave) ** 2
    return I_phase

def phase_retrieve_image(I_phase, pixel_size, z, delta, beta, energy_keV):
    """Recover a phase image from one propagated intensity image with TomoPy.

    The 2D image is wrapped as a one-projection float32 stack (always a fresh
    copy, so ``I_phase`` itself is never handed to TomoPy). Lengths are
    converted from meters to centimeters and passed to
    ``tomopy.prep.phase.retrieve_phase``, with ``beta / delta`` as its
    ``alpha`` parameter and ``pad=True``.

    Parameters
    ----------
    I_phase : 2D array
        Intensity image recorded at distance ``z``.
    pixel_size : float
        Pixel size in meters.
    z : float
        Propagation distance in meters.
    delta, beta : float
        Refractive-index decrement and absorption index; only their ratio is used.
    energy_keV : float
        X-ray energy in keV.

    Returns
    -------
    2D ndarray
        The single retrieved image from TomoPy's output stack.
    """
    meters_to_cm = 100.0
    # np.array (not asarray) always copies, matching the copy made by .astype().
    projection_stack = np.array(I_phase, dtype=np.float32)[np.newaxis, ...]
    retrieved_stack = tomopy.prep.phase.retrieve_phase(
        projection_stack,
        pixel_size=pixel_size * meters_to_cm,
        dist=z * meters_to_cm,
        energy=energy_keV,
        alpha=beta / delta,
        pad=True,
    )
    return retrieved_stack[0]

# -------------------------------
# Dataset generation
# -------------------------------
X = []  # sinograms: one (detector_pixels, num_angles, num_distances) float32 array per slice
Y = []  # targets: one normalized (N, N) float32 slice per sinogram

for nii_file in tqdm(nii_files, desc="Processing CT volumes"):
    path = os.path.join(data_dir, nii_file)
    # Reorient with the NIfTI affine to the closest canonical (RAS+) axis order.
    # Axis 2 is then the superior-inferior axis, and vol[:, :, i] is an axial slice.
    # Guessing the slice axis from the array shape is unreliable. For example, a
    # 512x512x300 volume has fewer slices than in-plane pixels, and a
    # "transpose unless axis 2 is the largest" rule turns it into sagittal slices.
    img_nii = nib.as_closest_canonical(nib.load(path))
    vol = img_nii.get_fdata()

    # Loop over axial slices
    for i in range(vol.shape[2]):
        slice_orig = vol[:, :, i]

        # Resize to N x N
        slice_resized = resize(slice_orig, (N, N), preserve_range=True)

        # Clip to a typical HU window, then min-max normalize to [0, 1]
        slice_clipped = np.clip(slice_resized, -1000, 3000)
        lo, hi = slice_clipped.min(), slice_clipped.max()
        if hi > lo:
            slice_norm = (slice_clipped - lo) / (hi - lo)
        else:
            # Constant slice (e.g. uniformly air): 0/0 would give NaN, which
            # propagates through the sinograms into the training loss.
            slice_norm = np.zeros_like(slice_clipped)
        slice_norm = slice_norm.astype(np.float32)

        # Simulate propagation + phase retrieval for each detector distance
        phase_stack = []
        for z in distances:
            I_phase = propagate_phase_contrast(slice_norm, z, pixel_size, wavelength, delta, beta)
            retrieved = phase_retrieve_image(I_phase, pixel_size, z, delta, beta, energy_keV)
            phase_stack.append(retrieved)

        multi_stack = np.stack(phase_stack, axis=-1)  # (N, N, num_distances)

        # One sinogram per distance channel
        sinograms = [radon(multi_stack[:, :, j], theta=theta, circle=False) for j in range(len(distances))]
        # Cast to float32 now: the final np.array(X, dtype=np.float32) would do
        # it anyway. Keeping the float64 radon output for every slice in the
        # list roughly doubles peak memory.
        sinograms = np.stack(sinograms, axis=-1).astype(np.float32)  # (detector_pixels, num_angles, num_distances)

        X.append(sinograms)
        Y.append(slice_norm)  # normalized target

# Stack the per-slice lists into dense float32 arrays:
#   X -> (num_slices, detector_pixels, num_angles, num_distances)
#   Y -> (num_slices, N, N)
X = np.asarray(X, dtype=np.float32)
Y = np.asarray(Y, dtype=np.float32)

print("Dataset creato:")
print("X (sinogrammi):", X.shape)
print("Y (slices):", Y.shape)

# -------------------------------
# Save to disk: compressed .npz holding the arrays "X" and "Y"
# -------------------------------
np.savez_compressed("phase_contrast_dataset.npz", **{"X": X, "Y": Y})
print("Dataset salvato in 'phase_contrast_dataset.npz'")


1
Trovati 20 file NIfTI.


Processing CT volumes: 100%|████████████████████████████████████████████████████████| 20/20 [3:57:01<00:00, 711.08s/it]


Dataset creato:
X (sinogrammi): (11420, 182, 180, 3)
Y (slices): (11420, 128, 128)
Dataset salvato in 'phase_contrast_dataset.npz'


In [1]:
# ==========================================
# Train neural network on phase-contrast sinograms
# ==========================================
import numpy as np
import os
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras import layers, models, callbacks, optimizers
from tqdm import tqdm

# -------------------------------
# Parameters
# -------------------------------
np.random.seed(42)
tf.random.set_seed(42)

DATA_FILE = "phase_contrast_dataset.npz"   # dataset produced by the generation cell above
MODEL_OUT = "sino2img_cnn.h5"
BATCH_SIZE = 8
EPOCHS = 50
VAL_SPLIT = 0.15

# -------------------------------
# Load dataset
# -------------------------------
if not os.path.exists(DATA_FILE):
    raise FileNotFoundError(f"{DATA_FILE} non trovato! Assicurati di aver eseguito il codice di generazione prima.")

# Use the NpzFile as a context manager so its file handle is closed. Indexing
# it reads each array fully into memory, so X and Y remain valid afterwards.
with np.load(DATA_FILE) as data:
    X = data["X"]
    Y = data["Y"]

print(f"Loaded dataset shapes: X={X.shape}, Y={Y.shape}")

# Add a trailing channel dimension to Y if missing: (n, H, W) -> (n, H, W, 1)
if Y.ndim == 3:
    Y = Y[..., np.newaxis]
    print("Y reshaped to:", Y.shape)

input_shape = X.shape[1:]  # (detector_pixels, num_angles, num_distances)
print("Input shape:", input_shape)

# -------------------------------
# Normalize
# -------------------------------
def normalize_X(X):
    """Min-max normalize each sample of ``X`` independently to [0, 1].

    All channels of ``X[i]`` share one min/max, so the relative scale between
    channels of a sample is preserved.

    Parameters
    ----------
    X : ndarray, shape (num_samples, ...)
        Stack of samples; the first axis indexes samples.

    Returns
    -------
    ndarray of float32, same shape as ``X``
        Normalized copy. A constant sample (max == min) becomes all zeros, so
        every finite output lies in [0, 1]. A sample whose min/max is NaN is
        copied through unchanged, so the bad sample stays visible.
    """
    Xn = np.empty_like(X, dtype=np.float32)
    print("Normalizing X...")
    # Looping per sample keeps temporaries to one sample's size (X can be several GB).
    for i in tqdm(range(X.shape[0])):
        arr = X[i]
        mn, mx = arr.min(), arr.max()
        if mx > mn:
            Xn[i] = (arr - mn) / (mx - mn)
        elif mx == mn:
            # A flat sample has no contrast; copying the raw constant through
            # could leave values outside [0, 1].
            Xn[i] = 0.0
        else:
            # min/max are NaN (comparisons are False): keep the NaNs visible.
            Xn[i] = arr
    return Xn

# Targets come from the generation cell as min-max normalized slices; clipping
# to [0, 1] keeps them inside the model's sigmoid output range.
Y = Y.astype(np.float32).clip(0, 1)
# Per-sample min-max normalization of the sinograms.
Xn = normalize_X(X)

Loaded dataset shapes: X=(11420, 182, 180, 3), Y=(11420, 128, 128)
Y reshaped to: (11420, 128, 128, 1)
Input shape: (182, 180, 3)
Normalizing X...


100%|██████████| 11420/11420 [00:04<00:00, 2438.45it/s]


In [None]:
# Split the arrays prepared by the previous cell: the normalized sinograms Xn and
# the channel-expanded, clipped targets Y. Do not re-load the .npz here. That
# would discard both steps: training would see raw sinograms, and a 3-D Y breaks
# the `Y_train.shape[1:]` unpack in the model cell.
X_train, X_val, Y_train, Y_val = train_test_split(Xn, Y, test_size=VAL_SPLIT, random_state=42)
print("Train shapes:", X_train.shape, Y_train.shape)
print("Val shapes:", X_val.shape, Y_val.shape)

In [None]:
# -------------------------------
# Model definition
# -------------------------------
detector_pixels, num_angles, num_channels = input_shape
# Target image size; [1:3] works whether or not Y_train has a trailing channel axis.
output_H, output_W = Y_train.shape[1:3]

def build_model(input_shape, base_filters=32, output_shape=None):
    """Build a small convolutional encoder-decoder mapping sinograms to images.

    Two 2x max-pooling stages are followed by two 2x upsampling stages. Without
    ``output_shape``, the spatial output size is therefore 4 * (size // 4) per
    axis (e.g. a 182x180 input gives 180x180: 182 -> 91 -> 45 -> 90 -> 180).

    Parameters
    ----------
    input_shape : tuple
        (detector_pixels, num_angles, num_channels) of one sinogram stack.
    base_filters : int
        Filters in the first stage; doubled at each downsampling stage.
    output_shape : tuple of int, optional
        (height, width) of the target images. When given, a bilinear Resizing
        layer maps the decoder features to this size before the last conv
        stage. The output then matches targets whose size differs from the
        sinogram size.

    Returns
    -------
    keras Model
        Single-channel output with sigmoid activation (values in (0, 1)).
    """
    inp = layers.Input(shape=input_shape)
    x = layers.Conv2D(base_filters, 3, padding="same", activation="relu")(inp)
    x = layers.Conv2D(base_filters, 3, padding="same", activation="relu")(x)
    x = layers.MaxPooling2D(2)(x)

    x = layers.Conv2D(base_filters*2, 3, padding="same", activation="relu")(x)
    x = layers.Conv2D(base_filters*2, 3, padding="same", activation="relu")(x)
    x = layers.MaxPooling2D(2)(x)

    x = layers.Conv2D(base_filters*4, 3, padding="same", activation="relu")(x)
    x = layers.Conv2D(base_filters*4, 3, padding="same", activation="relu")(x)

    x = layers.UpSampling2D(2)(x)
    x = layers.Conv2D(base_filters*2, 3, padding="same", activation="relu")(x)
    x = layers.Conv2D(base_filters*2, 3, padding="same", activation="relu")(x)

    x = layers.UpSampling2D(2)(x)
    if output_shape is not None:
        # Resample to the target resolution; the following convs refine at that size.
        x = layers.Resizing(int(output_shape[0]), int(output_shape[1]), interpolation="bilinear")(x)
    x = layers.Conv2D(base_filters, 3, padding="same", activation="relu")(x)
    x = layers.Conv2D(base_filters, 3, padding="same", activation="relu")(x)

    out = layers.Conv2D(1, 1, padding="same", activation="sigmoid")(x)
    return models.Model(inp, out)

# Without output_shape the model emits 180x180 maps for 182x180 sinograms. Those
# cannot be compared with the 128x128 targets by the MSE loss.
model = build_model(input_shape, output_shape=(output_H, output_W))
model.summary()

# -------------------------------
# Compile
# -------------------------------
model.compile(optimizer=optimizers.Adam(1e-4), loss="mse", metrics=["mae"])
cb_list = [
    callbacks.ModelCheckpoint(MODEL_OUT, save_best_only=True, monitor="val_loss"),
    callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=5, verbose=1),
    callbacks.EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True, verbose=1)
]


Loaded dataset shapes: X=(11420, 182, 180, 3), Y=(11420, 128, 128)
Y reshaped to: (11420, 128, 128, 1)
Input shape: (182, 180, 3)
Normalizing X...


100%|██████████| 11420/11420 [00:04<00:00, 2727.40it/s]


In [None]:
# -------------------------------
# Train
# -------------------------------
fit_kwargs = dict(
    validation_data=(X_val, Y_val),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=cb_list,
    verbose=2,
)
history = model.fit(X_train, Y_train, **fit_kwargs)

# NOTE(review): this overwrites the best checkpoint written by ModelCheckpoint
# with the model's current weights. EarlyStopping(restore_best_weights=True)
# restores the best weights when it stops early. Whether it also restores them
# when all EPOCHS run depends on the Keras version; confirm before trusting MODEL_OUT.
model.save(MODEL_OUT)
print("✅ Training finished. Model saved at", MODEL_OUT)

# -------------------------------
# Evaluate and visualize
# -------------------------------
print("Evaluating model...")
# model.predict already iterates over X_val in batches of BATCH_SIZE. One call
# replaces the Python loop that invoked predict once per batch.
val_preds = model.predict(X_val, batch_size=BATCH_SIZE, verbose=1)

# Fail loudly instead of letting NumPy broadcast mismatched shapes into a bogus metric.
assert val_preds.shape == Y_val.shape, f"prediction shape {val_preds.shape} != target shape {Y_val.shape}"

mse = np.mean((val_preds - Y_val)**2)
mae = np.mean(np.abs(val_preds - Y_val))
print(f"Validation MSE={mse:.6f}, MAE={mae:.6f}")

# -------------------------------
# Plot sample reconstructions
# -------------------------------
n_show = min(6, len(X_val))
idxs = np.random.choice(len(X_val), n_show, replace=False)

if n_show == 0:
    print("No validation samples to plot.")
else:
    # squeeze=False keeps `axes` 2-D even when n_show == 1, so axes[r, c] always works.
    fig, axes = plt.subplots(n_show, 4, figsize=(12, 3*n_show), squeeze=False)
    for r, i in enumerate(idxs):
        sino = X_val[i, :, :, 0]
        gt = Y_val[i, :, :, 0]
        pred = val_preds[i, :, :, 0]
        err = np.abs(gt - pred)

        # (title, image, imshow kwargs) for the four columns of this row
        panels = [
            ("Sinogram (ch0)", sino.T, {"aspect": "auto"}),
            ("Ground Truth", gt, {"cmap": "gray"}),
            ("NN Reconstruction", pred, {"cmap": "gray"}),
            ("Error", err, {"cmap": "inferno"}),
        ]
        for c, (title, image, imshow_kwargs) in enumerate(panels):
            ax = axes[r, c]
            im = ax.imshow(image, **imshow_kwargs)
            ax.set_title(title)
            ax.axis("off")
        # Only the error map gets a colorbar; `im` is the last panel's image.
        fig.colorbar(im, ax=axes[r, 3], fraction=0.046, pad=0.04)

    plt.tight_layout()
    plt.show()

# -------------------------------
# Plot training history
# -------------------------------
fig_hist, ax_hist = plt.subplots(figsize=(8, 4))
for history_key, curve_label in (("loss", "Train Loss"), ("val_loss", "Val Loss")):
    ax_hist.plot(history.history[history_key], label=curve_label)
ax_hist.set_yscale("log")
ax_hist.set_xlabel("Epoch")
ax_hist.set_ylabel("MSE")
ax_hist.set_title("Training Curve (log scale)")
ax_hist.legend()
plt.show()
