In [1]:
import torch
import lpips
import numpy as np

def _rescale_to_lpips_range(volume, min_val, max_val):
    """Linearly rescale *volume* so the lower bound maps to -1 and the upper to +1.

    Each bound that is ``None`` falls back to the volume's own minimum /
    maximum, so the two bounds can be supplied independently.

    Raises:
        ValueError: if both bounds are given explicitly and are equal.
    """
    lo = volume.min() if min_val is None else min_val
    hi = volume.max() if max_val is None else max_val
    span = hi - lo
    if span == 0:
        if min_val is not None and max_val is not None:
            raise ValueError("max_val and min_val must differ to define an intensity range.")
        # Constant volume: there is no contrast to stretch. Return the
        # mid-range value instead of the NaNs that 0/0 would produce.
        return np.zeros_like(volume)
    rescaled = (volume - lo) / span * 2 - 1
    # NumPy-scalar bounds (e.g. np.float64) may promote the result; the
    # tensors handed to the network must stay float32.
    return rescaled.astype(np.float32, copy=False)


def compute_lpips_3d(prediction, ground_truth, max_val=None, min_val=None, net_type='vgg'):
    """Return the mean slice-wise LPIPS distance between two 3D volumes.

    Both volumes are rescaled to [-1, 1], then every 2D slice along the
    third axis is compared with an LPIPS model and the per-slice scores
    are averaged.

    Args:
        prediction: 3D array-like; slices are taken along axis 2.
        ground_truth: 3D array-like with the same shape as ``prediction``.
        max_val: Intensity mapped to +1. When ``None``, each volume's own
            maximum is used.
        min_val: Intensity mapped to -1. When ``None``, each volume's own
            minimum is used.
        net_type: Backbone name forwarded to ``lpips.LPIPS(net=...)``.

    Returns:
        The mean of the per-slice LPIPS scores (result of ``np.mean``).

    Raises:
        AssertionError: if the two volumes have different shapes.
        ValueError: if the volumes are not 3D, are empty, or if
            ``max_val`` and ``min_val`` are both given and equal.
    """
    # Convert to float32 (also accepts lists and other array-likes).
    prediction = np.asarray(prediction, dtype=np.float32)
    ground_truth = np.asarray(ground_truth, dtype=np.float32)

    # Raised explicitly (not via `assert`) so the check survives `python -O`
    # while keeping the exception type existing callers may rely on.
    if prediction.shape != ground_truth.shape:
        raise AssertionError("Shape mismatch between prediction and ground truth!")
    if prediction.ndim != 3:
        raise ValueError(
            f"Expected 3D volumes, got arrays with {prediction.ndim} dimension(s)."
        )
    if prediction.size == 0:
        raise ValueError("Cannot compute LPIPS on empty volumes.")

    # Normalize to [-1, 1] range as required by LPIPS. This happens before
    # the model is built so bad bounds fail fast, without loading weights.
    prediction = _rescale_to_lpips_range(prediction, min_val, max_val)
    ground_truth = _rescale_to_lpips_range(ground_truth, min_val, max_val)

    # Resolve the device once and place the model there a single time.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loss_fn = lpips.LPIPS(net=net_type).to(device)

    lpips_scores = []
    # Pure inference: no autograd graph is needed for the scores.
    with torch.no_grad():
        for i in range(prediction.shape[2]):
            # Each slice becomes a [1, 1, H, W] tensor on the model's device.
            pred_tensor = torch.tensor(prediction[:, :, i], device=device)[None, None]
            gt_tensor = torch.tensor(ground_truth[:, :, i], device=device)[None, None]
            lpips_scores.append(loss_fn(pred_tensor, gt_tensor).item())

    # Average LPIPS score across all slices.
    return np.mean(lpips_scores)

# Demo: score two independent random 128x128x32 volumes against each other.
# The first draw becomes the prediction, the second the ground truth.
prediction, ground_truth = (np.random.rand(128, 128, 32) for _ in range(2))

# No explicit intensity bounds are passed, so each volume is rescaled
# using its own min/max inside compute_lpips_3d.
lpips_score = compute_lpips_3d(prediction=prediction, ground_truth=ground_truth)
print(f"LPIPS Score: {lpips_score:.4f}")


Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:42<00:00, 12.9MB/s] 


Loading model from: /usr/local/lib/python3.8/dist-packages/lpips/weights/v0.1/vgg.pth


  self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)


LPIPS Score: 0.3063
