In [None]:
import models
import torch
import SimpleITK as sitk
import glob
import utils
from utils_clip.simple_tokenizer import SimpleTokenizer
import numpy as np
import os
from itertools import product
from CLIP.model import CLIP
from utils_clip import load_config_file
import time
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
# Select CUDA when available, otherwise CPU.
# NOTE(review): `device` is not used in the setup below; the model is moved with `.cuda()` unconditionally.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# checkpoint_path = 'checkpoint_CLIP.pt'
MODEL_CONFIG_PATH = 'CLIP/model_config.yaml'
model_config = load_config_file(MODEL_CONFIG_PATH)

# Build the CLIP model from the RN50 section of the config file.
tokenizer = SimpleTokenizer()
model_params = dict(model_config.RN50)
# Convert the configured vision layers to a tuple before passing them to CLIP.
model_params['vision_layers'] = tuple(model_params['vision_layers'])
model_params['vision_patch_size'] = None
model = CLIP(**model_params)
# NOTE(review): checkpoint loading is commented out, so `model` keeps whatever
# weights CLIP(...) initialises with - confirm this is intended before relying
# on the text embeddings it produces.
# checkpoint = torch.load(checkpoint_path)
# state_dict = checkpoint['model_state_dict']

# model.load_state_dict(state_dict)
# Move the model to the GPU and switch it to evaluation mode.
model = model.cuda()
model.eval()


def tokenize(texts, tokenizer, context_length=90):
    """Turn one string or a list of strings into a zero-padded tensor of token ids.

    Every text is encoded with ``tokenizer.encode`` and wrapped between the
    ids stored under ``"<|startoftext|>"`` and ``"<|endoftext|>"`` in
    ``tokenizer.encoder``.

    Args:
        texts: A single string or a list of strings.
        tokenizer: Object providing an ``encoder`` mapping and an ``encode`` method.
        context_length: Number of columns of the returned tensor.

    Returns:
        A ``torch.long`` tensor of shape ``(len(texts), context_length)``;
        unused trailing positions are 0.

    Raises:
        RuntimeError: If an encoded text (including the two marker ids) is
            longer than ``context_length``.
    """
    batch = [texts] if isinstance(texts, str) else texts

    start_id = tokenizer.encoder["<|startoftext|>"]
    end_id = tokenizer.encoder["<|endoftext|>"]

    # Encode every text first; length validation happens afterwards.
    encoded = []
    for text in batch:
        encoded.append([start_id] + tokenizer.encode(text) + [end_id])

    padded = torch.zeros(len(encoded), context_length, dtype=torch.long)
    for row, ids in enumerate(encoded):
        n_ids = len(ids)
        if n_ids > context_length:
            raise RuntimeError(f"Input {batch[row]} is too long for context length {context_length}")
        padded[row, :n_ids] = torch.tensor(ids)
    return padded

def img_pad(img, target_shape):
    """Bring ``img`` to ``target_shape`` by zero-padding and centre-cropping.

    Axes shorter than the target are padded with zeros at their trailing end
    only; axes longer than the target are cropped symmetrically around the
    centre. Axes that already match are left untouched.

    Args:
        img: Array to resize.
        target_shape: Desired extent for each of the leading axes.

    Returns:
        The padded and/or cropped array.
    """
    n_axes = len(target_shape)

    # Padding goes after the data, never before it.
    pad_widths = [(0, max(0, target_shape[ax] - img.shape[ax])) for ax in range(n_axes)]
    padded = np.pad(img, pad_widths, mode='constant', constant_values=0)

    window = []
    for have, want in zip(padded.shape, target_shape):
        if have > want:
            window.append(slice((have - want) // 2, (have + want) // 2))
        else:
            window.append(slice(None))
    return padded[tuple(window)]

def calculate_patch_index(target_size, patch_size, overlap_ratio=0.25):
    """Enumerate the start corners of overlapping 3D patches covering a volume.

    Along each axis, starts are placed every ``int(patch * (1 - overlap_ratio))``
    voxels, and a final start is added so the last patch ends exactly at the
    volume border. The stride is computed per axis, so non-cubic patches keep
    the requested overlap on every axis (for cubic patches the result is the
    same as using one shared stride).

    Args:
        target_size: Extent of the volume along its three axes.
        patch_size: Extent of a patch along the same three axes.
        overlap_ratio: Fraction of a patch shared with its neighbour.

    Returns:
        List of ``(start0, start1, start2)`` tuples, ordered as the Cartesian
        product of the per-axis start lists.

    Raises:
        ValueError: If the overlap leaves a stride smaller than one voxel.
    """
    per_axis_starts = []
    for axis in range(3):
        extent = target_size[axis]
        patch = patch_size[axis]

        stride = int(patch * (1 - overlap_ratio))
        if stride < 1:
            raise ValueError(
                f"overlap_ratio={overlap_ratio} gives a stride of {stride} for patch size "
                f"{patch} on axis {axis}; the stride must be at least 1"
            )

        # Last admissible start; clamped so a volume smaller than the patch
        # yields a single start at 0 instead of a negative index.
        last_start = max(extent - patch, 0)
        starts = list(range(0, last_start, stride))
        starts.append(last_start)
        per_axis_starts.append(starts)

    return list(product(*per_axis_starts))

def patch_slicer(img_vol_0, overlap_ratio, crop_size, scale0, scale1, scale2):
    """Cut a 3D volume into overlapping patches and record their output boxes.

    Start corners come from ``calculate_patch_index``. Each patch is converted
    to a float tensor with one extra leading dimension. For every patch a
    six-element box ``[a0, b0, a1, b1, a2, b2]`` is stored, where ``a`` is the
    start index and ``b = a + int(crop_size[axis] * scale_axis)``.

    Args:
        img_vol_0: 3D array to slice.
        overlap_ratio: Overlap fraction forwarded to ``calculate_patch_index``.
        crop_size: Patch extent along the three axes.
        scale0, scale1, scale2: Per-axis factors applied to ``crop_size`` when
            computing the end of each recorded box.

    Returns:
        ``(scan_patches, patch_idx)``: the list of patch tensors and the list
        of boxes, in the same order.
    """
    dim0, dim1, dim2 = img_vol_0.shape
    corners = calculate_patch_index((dim0, dim1, dim2), crop_size, overlap_ratio)

    # Box extents depend only on crop size and scale, not on the patch position.
    extent0 = int(crop_size[0] * scale0)
    extent1 = int(crop_size[1] * scale1)
    extent2 = int(crop_size[2] * scale2)

    scan_patches, patch_idx = [], []
    for corner in corners:
        a0, a1, a2 = corner[0], corner[1], corner[2]
        block = img_vol_0[a0:a0 + crop_size[0], a1:a1 + crop_size[1], a2:a2 + crop_size[2]]
        scan_patches.append(torch.tensor(block).float().unsqueeze(0))
        patch_idx.append([
            int(a0), int(a0) + extent0,
            int(a1), int(a1) + extent1,
            int(a2), int(a2) + extent2,
        ])
    return scan_patches, patch_idx

class PatchDataset(Dataset):
    """Dataset that serves items from an in-memory sequence of patches."""

    def __init__(self, patches):
        """Store a reference to ``patches``; no copy or conversion is made."""
        self.patches = patches

    def __len__(self):
        """Return the number of stored patches."""
        n_patches = len(self.patches)
        return n_patches

    def __getitem__(self, idx):
        """Return the patch at position ``idx`` exactly as it was stored."""
        patch = self.patches[idx]
        return patch

def _get_pred(model, dataloader, coord_hr, seq_tgt, crop_size, out_dir):
    """Run ``model.generation`` on every batch of patches and collect the results.

    Works for any batch size: the model output of each batch is reshaped to
    ``(batchsize, *crop_size)`` and split into one tensor per sample. With a
    batch size of 1 this yields the same list as reshaping directly to
    ``crop_size``.

    Args:
        model: Object exposing ``eval()`` and
            ``generation(patch, coords, prompt)``.
        dataloader: Iterable yielding patch tensors with the batch on dim 0.
        coord_hr: Coordinate tensor, tiled per sample via ``repeat(batchsize, 1, 1)``.
        seq_tgt: Prompt embedding, tiled per sample via ``repeat(batchsize, 1)``.
        crop_size: Shape each per-sample prediction is reshaped to.
            NOTE(review): assumes the model emits ``prod(crop_size)`` values
            per sample - confirm against ``model.generation``.
        out_dir: Unused; kept so existing call sites keep working.

    Returns:
        List with one tensor of shape ``crop_size`` per input patch, in
        dataloader order. The tensors are not moved off the device the model
        produced them on.
    """
    model.eval()
    results = []
    with torch.no_grad():
        # Wrapping the dataloader itself lets tqdm read its length and show
        # a proper progress bar with a total.
        for batch in tqdm(dataloader):
            batchsize = batch.size(0)
            input_patch = batch.cuda()
            batch_coord_hr = coord_hr.repeat(batchsize, 1, 1)
            tgt_prompt = seq_tgt.repeat(batchsize, 1)
            pred_0_1_patch = model.generation(input_patch, batch_coord_hr, tgt_prompt.cuda().float())
            # Keep the batch dimension while reshaping, then split per sample.
            pred_0_1_patch = pred_0_1_patch.reshape(batchsize, *crop_size)
            results.extend(pred_0_1_patch.unbind(0))
    return results

# Above: input pre-processing (the original 3D volume is divided into patches)
#####################################################################
# Below: the generated patches are integrated into a unified 3D image



    
# Force the 'spawn' start method for torch multiprocessing, overriding any
# previously configured method.
torch.multiprocessing.set_start_method('spawn', force=True)

# Metric accumulators; nothing in this cell fills them.
psnr_0_1_list = []
psnr_1_0_list = []
ssim_0_1_list = []
ssim_1_0_list = []

# model_pth = './save/checkpoint.pth'
model_pth = './checkpoint.pth'
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
model_img = models.make(torch.load(model_pth)['model_G'], load_sd=True).cuda()

img_path_0 = './Experimental_data/image'
img_path_1 = './Experimental_data/image' # Only used to provide the target image spacing; not strictly necessary, the spacing can also be set manually
img_list_0 = sorted(os.listdir(img_path_0))
img_list_1 = sorted(os.listdir(img_path_1))

# Read the text prompt describing the target modality.
prompt_M1 = './Experimental_data/test_HCPD_T2w_prompt.txt'
with open(prompt_M1) as f1:
    lines_M1 = f1.readlines()

# Load the source volume as a numpy array and clip its intensities.
img_0 = sitk.ReadImage(os.path.join(img_path_0, 'test_HCPD_T1w.nii.gz'))
img_0_spacing = img_0.GetSpacing()
img_vol_0 = sitk.GetArrayFromImage(img_0)
H, W, D = img_vol_0.shape
# The target shape equals the current shape, so this call leaves the volume
# unchanged; change `target_shape` here to pad or crop the input.
img_vol_0 = img_pad(img_vol_0, target_shape=(H, W, D))
img_vol_0 = utils.percentile_clip(img_vol_0)

# Coordinates of the output grid the model is queried on.
coord_size = [60, 60, 60]
coord_hr = utils.make_coord(coord_size, flatten=True)
# as_tensor avoids the "copy construct from a tensor" warning that
# torch.tensor() emits when make_coord already returns a tensor.
coord_hr = torch.as_tensor(coord_hr).cuda().float()

# Keep only the description that follows the first ':' of the prompt line.
# str.strip(chars) removes a *set of characters* from both ends rather than a
# prefix, so the label is split off with partition() instead.
prompt_line = lines_M1[0].replace('"', '')
_prompt_label, _prompt_sep, _prompt_body = prompt_line.partition(':')
# Without a ':' there is no label to drop; use the whole line.
text_tgt = _prompt_body if _prompt_sep else prompt_line
seq_tgt = tokenize(text_tgt, tokenizer).cuda()
with torch.no_grad():
    seq_tgt = model.encode_text(seq_tgt)

# Slice the volume into 60^3 patches with 50% overlap.
crop_size = (60, 60, 60)
scale0 = coord_size[0] / crop_size[0]
scale1 = coord_size[1] / crop_size[1]
scale2 = coord_size[2] / crop_size[2]
patches, _ = patch_slicer(img_vol_0, 0.5, crop_size, scale0, scale1, scale2)
dataset = PatchDataset(patches)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, pin_memory=True)

out_dir = "./save_img/nii/"
# exist_ok=True already makes this a no-op when the directory exists.
os.makedirs(out_dir, exist_ok=True)

# Predict every patch, then bring the results back to the host as numpy arrays.
pred_0_1 = _get_pred(model_img, dataloader, coord_hr, seq_tgt, crop_size, out_dir)
pred_0_1 = [x.cpu().numpy() for x in pred_0_1]

  from .autonotebook import tqdm as notebook_tqdm
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  coord_hr = torch.tensor(coord_hr).cuda().float()
441it [06:44,  1.09it/s]


In [3]:
def create_weight_mask(patch_shape, overlap_ratio=0.5, sigma_factor=0.15):
    """
    Build a 3D Gaussian blending mask for one patch.

    Voxels near the patch centre get weights close to 1 and the weight decays
    with the normalised radial distance, so overlapping patches blend
    smoothly instead of showing seams.

    Args:
        patch_shape: Tuple of (depth, height, width) for the patch
        overlap_ratio: Amount of overlap between patches; a smaller value
            widens the Gaussian
        sigma_factor: Scales the width of the Gaussian falloff

    Returns:
        3D numpy array with the same shape as the patch, values in (0, 1]
    """
    def _normalised_offset(coords, extent):
        # Signed distance from the axis centre, in units of half the extent.
        return (coords - (extent - 1) / 2) / (extent / 2)

    grid_z, grid_y, grid_x = np.indices(patch_shape)
    off_z = _normalised_offset(grid_z, patch_shape[0])
    off_y = _normalised_offset(grid_y, patch_shape[1])
    off_x = _normalised_offset(grid_x, patch_shape[2])

    sigma = sigma_factor * (1 + 1/overlap_ratio)
    radial = np.sqrt(off_z**2 + off_y**2 + off_x**2)

    gaussian_denominator = 2 * sigma**2
    return np.exp(-(radial**2) / gaussian_denominator)


def stitch_patches(patch_vol: list, patch_indices, output_shape, overlap_ratio=0.5):
    """
    Stitch together a set of overlapping patches into a single volume.

    Each patch is multiplied by a Gaussian weight mask and accumulated into
    the output; the sum is then divided by the accumulated weights so that
    overlapping regions become a weighted average.

    Args:
        patch_vol: List of np.ndarray
        patch_indices: List of [z_start, z_end, y_start, y_end, x_start, x_end] for each patch
        output_shape: Tuple of (depth, height, width) for the final volume
        overlap_ratio: Amount of overlap between adjacent patches

    Returns:
        Stitched volume as a float32 numpy array of shape ``output_shape``
    """
    # Initialize output volume and weight accumulator
    output_volume = np.zeros(output_shape, dtype=np.float32)
    weight_accumulator = np.zeros(output_shape, dtype=np.float32)

    print(f"Stitching {len(patch_vol)} patches into output shape {output_shape}...")

    # The mask depends only on the patch shape (overlap_ratio is fixed for the
    # call), so build it once per distinct shape instead of once per patch.
    mask_by_shape = {}

    for patch_data, patch_idx in tqdm(zip(patch_vol, patch_indices), total=len(patch_vol)):
        z_start, z_end, y_start, y_end, x_start, x_end = patch_idx

        # Ensure patches don't exceed output volume dimensions
        z_end = min(z_end, output_shape[0])
        y_end = min(y_end, output_shape[1])
        x_end = min(x_end, output_shape[2])

        patch_shape = (z_end - z_start, y_end - y_start, x_end - x_start)

        # Trim the patch when its box was clipped at the volume border
        if patch_data.shape != patch_shape:
            print(f"Reshape patch if needed {patch_data.shape}->{patch_shape}")
            patch_data = patch_data[:patch_shape[0], :patch_shape[1], :patch_shape[2]]

        weights = mask_by_shape.get(patch_shape)
        if weights is None:
            weights = create_weight_mask(patch_shape, overlap_ratio)
            mask_by_shape[patch_shape] = weights

        # Add weighted patch to output
        region = (slice(z_start, z_end), slice(y_start, y_end), slice(x_start, x_end))
        output_volume[region] += patch_data * weights
        weight_accumulator[region] += weights

    # Normalise only where at least one patch contributed (avoids division by zero)
    mask = weight_accumulator > 0
    output_volume[mask] = output_volume[mask] / weight_accumulator[mask]

    return output_volume


def save_stitched_volume(output_volume, reference_file, output_path):
    """
    Write a numpy volume to disk with the geometry of a reference image.

    Spacing, origin and direction are taken from ``reference_file`` and
    applied to the image built from ``output_volume`` before it is written.

    Args:
        output_volume: The stitched volume as numpy array
        reference_file: Path to the image whose spacing/origin/direction are copied
        output_path: Destination path of the written image
    """
    # Collect the geometry of the reference image first
    ref_img = sitk.ReadImage(reference_file)
    spacing = ref_img.GetSpacing()
    origin = ref_img.GetOrigin()
    direction = ref_img.GetDirection()

    # Wrap the array in an ITK image and stamp the geometry onto it
    stitched_img = sitk.GetImageFromArray(output_volume)
    stitched_img.SetSpacing(spacing)
    stitched_img.SetOrigin(origin)
    stitched_img.SetDirection(direction)

    sitk.WriteImage(stitched_img, output_path)
    print(f"Saved stitched volume to {output_path}")


def stitch(patch_vol: list, original_img_path=None, overlap_ratio=0.5, output_shape=None):
    """
    Rebuild a full volume from patches by recomputing their positions.

    Patch start positions are regenerated with ``calculate_patch_index`` from
    the output shape, the shape of the first patch and ``overlap_ratio``, then
    the patches are blended with ``stitch_patches``. The positions only line
    up with the patches if these values match the ones used when slicing.

    Args:
        patch_vol: List of patch arrays, in the order the slicer produced them.
            Must be non-empty (the first patch defines the patch size).
        original_img_path: Optional path to an image whose array shape is used
            as the output shape. Takes precedence over ``output_shape``.
        overlap_ratio: Overlap fraction used when the patches were cut.
        output_shape: Output shape to use when ``original_img_path`` is not
            given. Falls back to the placeholder ``(256, 256, 256)`` when
            neither is provided.

    Returns:
        Stitched volume as a numpy array
    """
    # The first patch defines the patch size used to regenerate the grid
    crop_size = patch_vol[0].shape

    if original_img_path:
        original_img = sitk.ReadImage(original_img_path)
        output_shape = sitk.GetArrayFromImage(original_img).shape
    elif output_shape is None:
        # Placeholder kept for backward compatibility; pass output_shape to
        # supply the real dimensions.
        output_shape = (256, 256, 256)

    # Regenerate [z_start, z_end, y_start, y_end, x_start, x_end] per patch,
    # clipping each box to the output volume.
    patch_indices = []
    for z_start, y_start, x_start in calculate_patch_index(output_shape, crop_size, overlap_ratio):
        patch_indices.append([
            z_start, min(z_start + crop_size[0], output_shape[0]),
            y_start, min(y_start + crop_size[1], output_shape[1]),
            x_start, min(x_start + crop_size[2], output_shape[2]),
        ])

    # Check if we have enough patches
    if len(patch_indices) != len(patch_vol):
        print(f"Warning: Number of patch files ({len(patch_vol)}) does not match calculated indices ({len(patch_indices)})")
        print("Using the minimum of the two")
        # Compute the common length once, before either list is truncated.
        usable = min(len(patch_vol), len(patch_indices))
        patch_vol = patch_vol[:usable]
        patch_indices = patch_indices[:usable]

    # Stitch patches together
    return stitch_patches(patch_vol, patch_indices, output_shape, overlap_ratio)
   
# The T2w file is used twice: stitch() reads its array shape as the output
# shape, and save_stitched_volume() copies its spacing, origin and direction.
# NOTE(review): the patches were cut from the T1w volume - confirm that both
# files share the same voxel grid, otherwise the recomputed patch positions
# will not match the patches.
reference_file = os.path.join(img_path_1, 'test_HCPD_T2w.nii.gz')
stitched_volume = stitch(pred_0_1, reference_file)

# Save the stitched volume
output_path = os.path.join(out_dir, 'stitched_volume.nii.gz')
save_stitched_volume(stitched_volume, reference_file, output_path)

Stitching 441 patches into output shape (227, 272, 227)...


100%|██████████| 441/441 [00:02<00:00, 213.30it/s]


Saved stitched volume to ./save_img/nii/stitched_volume.nii.gz
