In [1]:
# # Clone xvr repo and move into the directory
# !git clone https://github.com/eigenvivek/xvr.git
# %cd xvr

# # Install PyTorch for CUDA 11.8 (Colab T4 GPU)
# !pip install --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# # Install xvr and dependencies
# !pip install git+https://github.com/eigenvivek/xvr.git
!pip install wandb nibabel torchio tqdm





[notice] A new release of pip is available: 23.0.1 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
import torch
import torchvision
import torchaudio
import nibabel as nib
import torchio
import tqdm
import wandb

# Check CUDA availability and GPU info
print("Torch CUDA available:", torch.cuda.is_available())
print("Torch version:", torch.__version__)
if torch.cuda.is_available():
    print("GPU Name:", torch.cuda.get_device_name(0))
    print("CUDA Version:", torch.version.cuda)

# Check xvr import
try:
    import xvr
    print("✅ xvr imported successfully")
except ImportError as e:
    print("❌ xvr import failed:", e)


Torch CUDA available: True
Torch version: 2.5.1+cu121
GPU Name: NVIDIA RTX 6000 Ada Generation
CUDA Version: 12.1
✅ xvr imported successfully


In [3]:
import os

# Set a value
os.environ["WANDB_API_KEY"] = "524773117955f64051f5e2bd2ecef2642a677a34"

In [4]:
# Fix wandb login inside xvr training script (do this only once)
!sed -i 's/wandb.login(key=os.environ\["WANDB_API_KEY"\])/wandb.login()/' /usr/local/lib/python3.*/dist-packages/xvr/commands/train.py

'sed' is not recognized as an internal or external command,
operable program or batch file.


In [5]:
import sys
import os

# Add the path to xvr/src so you can import xvr.*
sys.path.append(os.path.join(os.getcwd(), 'xvr', 'src'))


In [6]:
import sys
import os

# Step 1: Add xvr/src to Python path so imports work
sys.path.append(os.path.join(os.getcwd(), 'xvr', 'src'))

# Step 2: Patch wandb.login line in local train.py
# train_py = os.path.join(os.getcwd(), 'xvr', 'src', 'xvr', 'commands', 'train.py')
# finetune_py = os.path.join(os.getcwd(), 'xvr', 'src', 'xvr', 'commands', 'finetune.py')
register_py = os.path.join(os.getcwd(), 'xvr', 'src', 'xvr', 'commands', 'register.py')

if os.path.exists(register_py):
    with open(register_py, 'r') as f:
        lines = f.readlines()

    with open(register_py, 'w') as f:
        for line in lines:
            f.write(line.replace(
                'wandb.login(key=os.environ["WANDB_API_KEY"])',
                'wandb.login()'
            ))

    print(f"✅ Patched wandb.login() in: {register_py}")
else:
    print("❌ register.py not found. Check the path.")


✅ Patched wandb.login() in: C:\Users\aksha\Robossis\xvr\src\xvr\commands\register.py


In [11]:
import pydicom

# Load your C-arm DICOM file
filename = "C_Arm/proximal.dcm"  # Update this path if needed
ds = pydicom.dcmread(filename)

# Set the missing attribute
ds.DistanceSourceToDetector = 1000  # Replace 1000 with your actual sdd value in mm

# Optional: make sure PixelSpacing is present and valid
if "PixelSpacing" not in ds:
    ds.PixelSpacing = [0.316, 0.316]  # based on your delx

# Save the updated DICOM
ds.save_as(filename)
print("DICOM updated with SDD and saved.")


DICOM updated with SDD and saved.


In [12]:
import subprocess

# --- Fix DICOM image first ---
import pydicom

# Load and update DICOM metadata
dcm_path = "C_Arm/distal.dcm"
ds = pydicom.dcmread(dcm_path)

# Forcefully add or overwrite key fields
ds.DistanceSourceToDetector = 1000             # in mm
ds.PixelSpacing = [0.316, 0.316]               # in mm (horizontal, vertical)

# Save to a new file to avoid corrupting the original
fixed_dcm_path = "C_Arm/distal.dcm"
ds.save_as(fixed_dcm_path)
print("✅ DICOM updated and saved as:", fixed_dcm_path)

# --- Run registration ---
input_path = "SE000002.nii"                    # CT volume path
output_path = "output_register"                # Where results will go
ckpt_path = "output_trained/RXVR1.1_best.pth"          # Finetuned model

command = [
    "xvr", "register", "model", fixed_dcm_path,
    "--volume", input_path,
    "--ckptpath", ckpt_path,
    "--outpath", output_path,

    "--reverse_x_axis",
    "--saveimg",
    "--crop", "0",
    "--scales", "8",
    "--parameterization", "euler_angles",
    "--convention", "ZXY",
    
    # KEY: Try boosting translation learning rate
    "--lr_rot", "0.01",
    "--lr_xyz", "3.0",                 # 🚀 More freedom in translation

    "--patience", "10",
    "--threshold", "0.0001",
    "--max_n_itrs", "800",            # 🔁 More optimization time
    "--max_n_plateaus", "3",
    "--renderer", "trilinear",
    "--verbose", "2"
]


# Handle Unicode tqdm.write → fallback to UTF-8 print
import os
os.environ["PYTHONIOENCODING"] = "utf-8"  # Ensures subprocess uses UTF-8

# Run registration and print output
import io

# Use binary mode and wrap with UTF-8 reader
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

# Explicitly decode as UTF-8 to avoid Unicode issues
with io.TextIOWrapper(process.stdout, encoding="utf-8", errors="replace") as stdout:
    for line in stdout:
        print(line, end='')


✅ DICOM updated and saved as: C_Arm/distal.dcm

Registering C_Arm\distal.dcm ...

Stage 1:   0%|                                                              | 0/800 [00:00<?, ?it/s]
Stage 1:   0%|                                                 | 0/800 [00:00<?, ?it/s, ncc = 0.030]
                                                                                                    

Stage 1:   0%|                                                 | 0/800 [00:00<?, ?it/s, ncc = 0.030]
Stage 1:   0%|                                         | 1/800 [00:00<02:29,  5.33it/s, ncc = 0.030]
Stage 1:   0%|                                         | 1/800 [00:00<02:29,  5.33it/s, ncc = 0.047]
Stage 1:   0%|                                         | 1/800 [00:00<02:29,  5.33it/s, ncc = 0.054]
Stage 1:   0%|                                         | 1/800 [00:00<02:29,  5.33it/s, ncc = 0.070]
Stage 1:   0%|                                         | 1/800 [00:00<02:29,  5.33it/s, ncc = 0.072]
Stage 1:

In [14]:
# import os
# import subprocess
# import torch
# from PIL import Image
# import torchvision.transforms.functional as TF
# import pydicom
# import shutil # For creating directories safely

# # Define paths
# DICOM_PATH = "C_Arm/distal.dcm"
# FIXED_DICOM_PATH = DICOM_PATH # Overwrite the same file
# CT_VOLUME_PATH = "SE000002.nii"
# OUTPUT_DIR = "output_register"
# CHECKPOINT_PATH = "output_trained/RXVR1.1_best.pth"

# # Ensure output directory exists
# os.makedirs(os.path.join(OUTPUT_DIR, "distal"), exist_ok=True)

# # ========== 1. Fix DICOM Metadata ==========
# print("========== Step 1: Fixing DICOM Metadata ==========")
# try:
#     ds = pydicom.dcmread(DICOM_PATH)
#     ds.DistanceSourceToDetector = 1000  # mm
#     ds.PixelSpacing = [0.316, 0.316]    # mm
#     ds.save_as(FIXED_DICOM_PATH)
#     print(f"✅ DICOM updated and saved as: {FIXED_DICOM_PATH}")
# except Exception as e:
#     print(f"❌ Error fixing DICOM metadata: {e}")
#     # Exit or handle error appropriately if DICOM fix is critical
#     exit()

# # ========== 2. Run Registration ==========
# print("\n========== Step 2: Running Registration ==========")
# registration_command = [
#     "xvr", "register", "model", FIXED_DICOM_PATH,
#     "--volume", CT_VOLUME_PATH,
#     "--ckptpath", CHECKPOINT_PATH,
#     "--outpath", OUTPUT_DIR,
#     "--reverse_x_axis",
#     "--saveimg",
#     "--crop", "0",
#     "--scales", "8",
#     "--parameterization", "euler_angles",
#     "--convention", "ZXY",
#     "--lr_rot", "0.01",
#     "--lr_xyz", "3.0",
#     "--patience", "10",
#     "--threshold", "0.0001",
#     "--max_n_itrs", "800",
#     "--max_n_plateaus", "3",
#     "--renderer", "trilinear",
#     "--verbose", "2"
# ]

# os.environ["PYTHONIOENCODING"] = "utf-8"

# try:
#     process = subprocess.Popen(
#         registration_command,
#         stdout=subprocess.PIPE,
#         stderr=subprocess.STDOUT,
#         text=True,
#         encoding="utf-8", # Specify encoding for consistent output
#         errors="ignore" # Ignore errors for characters that can't be decoded
#     )
#     print("--- xvR Registration Output ---")
#     for line in process.stdout:
#         print(line, end='')
#     process.wait() # Wait for the process to complete
#     if process.returncode == 0:
#         print("✅ xvR Registration completed successfully.")
#     else:
#         print(f"❌ xvR Registration failed with exit code {process.returncode}.")
#         # Handle registration failure, e.g., exit
#         exit()
# except FileNotFoundError:
#     print(f"❌ Error: 'xvr' command not found. Make sure xvR is installed and in your system's PATH.")
#     exit()
# except Exception as e:
#     print(f"❌ An unexpected error occurred during registration: {e}")
#     exit()

# # ========== 3. Read Pose ==========
# print("\n========== Step 3: Reading Pose ==========")
# pose_file_path = os.path.join(OUTPUT_DIR, "distal", "parameters.pt")
# try:
#     pose = torch.load(pose_file_path)
#     print("🧊 Original Final Pose Matrix:\n", pose["final_pose"][0])
#     print("✅ Pose loaded successfully.")
# except FileNotFoundError:
#     print(f"❌ Error: Pose file not found at {pose_file_path}. Registration might have failed or not generated it.")
#     exit()
# except Exception as e:
#     print(f"❌ Error loading pose: {e}")
#     exit()

# # ========== 4. Visual Shift of the Final Image Downward (Post-processing for display) ==========
# # This does NOT affect the 3D registration. It's a pure 2D image manipulation.
# print("\n========== Step 4: Applying Visual Shift to Final Image ==========")
# original_img_path = os.path.join(OUTPUT_DIR, "distal", "final_img.png")
# try:
#     img = Image.open(original_img_path).convert("L") # Ensure grayscale
#     width, height = img.size

#     # --- Option 1: Pad at the top (original image content fully retained, image size increases) ---
#     top_padding_safe = 40 # pixels
#     padded_img_safe = Image.new("L", (width, height + top_padding_safe), color=0) # Black background
#     padded_img_safe.paste(img, (0, top_padding_safe))
#     safe_corrected_img_path = os.path.join(OUTPUT_DIR, "distal", "final_img_corrected_safe.png")
#     padded_img_safe.save(safe_corrected_img_path)
#     print(f"✅ Saved 'final_img_corrected_safe.png' (shifted down with padding at top).")

#     # --- Option 2: Shift down within original dimensions (top becomes black, bottom content cropped) ---
#     # This is what you likely intended for 'final_img_corrected.png'
#     shift_amount_visual = 40 # pixels
#     if shift_amount_visual >= height:
#         print("⚠️ Warning: Shift amount is too large, image will be entirely black.")
#         shifted_img_corrected = Image.new("L", (width, height), color=0)
#     else:
#         # Create a new blank image
#         shifted_img_corrected = Image.new("L", (width, height), color=0)
#         # Paste the original image's top part into the new image, shifted down
#         # The region to paste is from (0, 0) to (width, height - shift_amount_visual)
#         # It will be pasted starting at (0, shift_amount_visual)
#         cropped_region = img.crop((0, 0, width, height - shift_amount_visual))
#         shifted_img_corrected.paste(cropped_region, (0, shift_amount_visual))

#     final_corrected_img_path = os.path.join(OUTPUT_DIR, "distal", "final_img_corrected.png")
#     shifted_img_corrected.save(final_corrected_img_path)
#     print(f"✅ Saved 'final_img_corrected.png' (visually shifted down, cropped at bottom).")

# except FileNotFoundError:
#     print(f"❌ Error: Original image 'final_img.png' not found at {original_img_path}. Registration might not have generated it.")
# except Exception as e:
#     print(f"❌ Error processing images: {e}")

# # ========== 5. Adjust Pose Translation (for 3D correction) ==========
# print("\n========== Step 5: Adjusting Pose Translation ==========")
# pose_corrected_path = os.path.join(OUTPUT_DIR, "distal", "parameters_corrected.pt")

# try:
#     # IMPORTANT: Deep copy the pose if you intend to keep the original pose intact
#     # If you only want to modify the loaded 'pose' object, a deep copy is not strictly necessary
#     # but good practice if you plan further operations on the original pose.
#     # For this script, direct modification of 'pose' is fine.
    
#     # Adjust Y translation (ty) by +10.0mm.
#     # Note: The visual effect in the 2D image depends on your coordinate system convention.
#     # If positive Y in 3D maps to positive Y (downwards) in 2D image space, this will make the object appear lower.
#     # If positive Y in 3D maps to negative Y (upwards) in 2D image space, this will make the object appear higher.
#     pose["final_pose"][0, 1, 3] += 10.0
#     torch.save(pose, pose_corrected_path)
#     print("✅ Saved corrected pose to 'parameters_corrected.pt'.")
#     print("🧊 Corrected Final Pose Matrix:\n", pose["final_pose"][0])

# except Exception as e:
#     print(f"❌ Error adjusting or saving corrected pose: {e}")

# # ========== 6. (OPTIONAL) Re-render with Corrected Pose ==========
# print("\n========== Step 6: (OPTIONAL) Re-rendering with Corrected Pose ==========")
# print("To see the visual effect of the 'parameters_corrected.pt' on a new X-ray image,")
# print("you would typically need to run an 'xvr render' command or a similar function,")
# print("passing the CT volume, the original DICOM/camera parameters, and the 'parameters_corrected.pt'.")
# print("Example (hypothetical, replace with actual xvr render command if available):")
# print(f"  xvr render --volume {CT_VOLUME_PATH} --dicom {FIXED_DICOM_PATH} --pose {pose_corrected_path} --outimg {os.path.join(OUTPUT_DIR, 'distal', 'final_img_re_rendered.png')}")
# print("Please consult the xvR documentation for the exact command to render an image from a given pose.")