In [15]:
import torch
import torch.nn as nn



import torch
import torch.nn as nn
import torch.nn.functional as F
from monai.networks.blocks import TransformerBlock
import torch
import torch.nn as nn
from monai.networks.blocks import TransformerBlock

class ShiftedWindowCrossAttention(nn.Module):
    def __init__(self, in_channels, num_heads=8, window_size=4, shift_size=2, dropout_rate=0.1):
        """
        Shifted Window Cross Attention on Channel Dimension
        Args:
            in_channels: Input channels
            num_heads: Number of attention heads
            window_size: Size of each window in channel dimension
            shift_size: Size to shift the windows
            dropout_rate: Dropout rate
        """
        super(ShiftedWindowCrossAttention, self).__init__()
        self.in_channels = in_channels
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        
        # Change hidden_size to match the flattened spatial dimensions
        self.attn = TransformerBlock(
            hidden_size=in_channels,  # Changed back to in_channels
            num_heads=num_heads,
            mlp_dim=in_channels * 4,
            qkv_bias=False,
            with_cross_attention=True,
            use_flash_attention=True,
            dropout_rate=dropout_rate,
        )

    def forward(self, q_feat, k_feat):
        """
        Args:
            q_feat: Features from modality 1 (B, C, D, H, W)
            k_feat: Features from modality 2 (B, C, D, H, W)
        Returns:
            Output features (B, C, D, H, W)
        """
        B, C, D, H, W = q_feat.shape
        
        # Create windows along channel dimension
        q_windows = q_feat.view(B, -1, self.window_size, D, H, W)  # (B, num_windows, window_size, D, H, W)
        k_windows = k_feat.view(B, -1, self.window_size, D, H, W)
        
        # Apply cyclic shift along channel dimension
        q_shifted = torch.roll(q_windows, shifts=-self.shift_size, dims=2)
        k_shifted = torch.roll(k_windows, shifts=-self.shift_size, dims=2)
        
        # Reshape for attention
        q = q_shifted.flatten(3)  # (B, num_windows, window_size, D*H*W)
        k = k_shifted.flatten(3)
        
        # Process each window
        outputs = []
        for i in range(q.size(1)):
            # Reshape to match TransformerBlock expectations
            q_window = q[:, i].transpose(1, 2)  # (B, D*H*W, window_size)
            k_window = k[:, i].transpose(1, 2)  # (B, D*H*W, window_size)
            
            # Project to match hidden_size
            q_window = F.linear(q_window, torch.eye(self.window_size, device=q_window.device))
            k_window = F.linear(k_window, torch.eye(self.window_size, device=k_window.device))
            
            out_window = self.attn(q_window, k_window)  # Process each window separately
            outputs.append(out_window.transpose(1, 2))  # Transpose back
        
        out = torch.stack(outputs, dim=1)  # (B, num_windows, window_size, D*H*W)
        
        # Reshape back and reverse shift
        out = out.view(B, -1, self.window_size, D, H, W)
        out = torch.roll(out, shifts=self.shift_size, dims=2)
        out = out.reshape(B, C, D, H, W)
        
        return out

channels = 1024
input = torch.randn(1,1024,3,3,3)
groups = 4

split_x1 = torch.chunk(input, groups, dim=1) # [B, C/groups, D, H, W]
split_x2 = torch.chunk(input, groups, dim=1)
cross_attn = ShiftedWindowCrossAttention(
    in_channels=1024 // groups, 
    num_heads=2, 
    window_size=64,  # Window size in channel dimension
    shift_size=32,   # Shift size
    dropout_rate=0.1
)
outputs = []
for q, k in zip(split_x1, split_x2):
    outputs.append(cross_attn(q, k))
output = torch.cat(outputs, dim=1)  # (B, C, D, H, W)
print(output.shape)  # 应为 (1, 1024, 3, 3, 3)

RuntimeError: Given normalized_shape=[256], expected input with shape [*, 256], but got input of size[1, 27, 64]

In [25]:
import os
from pathlib import Path
check_dir = '/Users/keyi/Desktop/test_src/fold_0'
save_dir = "/".join(check_dir.split("/")) + "/PRED"
os.makedirs(save_dir, exist_ok=True)
path = os.path.join(save_dir, "CHUM-007_PRED.nii.gz")
print(path)

def get_unique_filename(base_path: str, filename: str) -> str:
    """
    Generate a unique filename by adding a counter if the file already exists.
    
    Args:
        base_path: Directory path where the file will be saved
        filename: Original filename with extension
        
    Returns:
        str: A unique filename that doesn't exist in the base_path
        
    Example:
        >>> get_unique_filename('/path/to/dir', 'image.nii.gz')
        'image.nii.gz'  # If file doesn't exist
        'image_1.nii.gz'  # If 'image.nii.gz' exists
    """
    file_path = Path(base_path) / filename
    
    # Return original filename if it doesn't exist
    if not file_path.exists():
        return filename
        
    # Split filename into name and extension
    stem = file_path.stem  # Gets filename without extension
    suffix = file_path.suffix  # Gets extension including the dot
    
    # Handle double extensions (e.g., .nii.gz)
    if stem.endswith('.nii') and suffix == '.gz':
        stem = stem[:-4]  # Remove '.nii'
        suffix = '.nii.gz'
    
    # Try incremental numbers until a unique filename is found
    counter = 1
    while (Path(base_path) / f"{stem}_{counter}{suffix}").exists():
        counter += 1
        
    return f"{stem}_{counter}{suffix}"
save_dir = Path('/Users/keyi/Desktop/test_src/fold_0') / "PRED"
filename = get_unique_filename(save_dir, "PRED_CHUM-007.nii.gz")
with open(os.path.join(save_dir, filename), 'w') as f:
    f.write(str("ok"))
    

/Users/keyi/Desktop/test_src/fold_0/PRED/CHUM-007_PRED.nii.gz


In [20]:
import json
src = "FDG-PET-CT-Lesions/PETCT_0dbf2c2731/10-27-2005-NA-PET-CT Ganzkoerper  primaer mit KM-07954/CTres.nii.gz"
person_uid = src.split("/")[1]
print(person_uid)
json_dir = '/Users/keyi/Desktop/DL_template/_assets/split_json/AutoPET_Cluster'
for file in os.listdir(json_dir):
    if file.endswith(".json"):
        with open(os.path.join(json_dir, file), 'r') as f:
            data = json.load(f)
            for item in data["VALIDATION"]:
                person_uid = item["CTRES"].split("/")[0]
                item["PERSON_UID"] = person_uid
        
            
        with open(os.path.join(json_dir, file), 'w') as f:
            json.dump(data, f, indent=4)

PETCT_0dbf2c2731


In [28]:
save_dir = Path('/home/kit/anthropomatik/ew2572/trained_old/2_channels/baseline/mednext/fold_0/fold=0-epoch=299-step=26700-Dice=0.9258.ckpt').parent
print(save_dir)

/home/kit/anthropomatik/ew2572/trained_old/2_channels/baseline/mednext/fold_0


In [16]:
import numpy as np
volume = np.prod([2.04, 2.04, 3.0][1:4]) / 1000
print(volume)

import monai.transforms as mt
pet_file = '/Users/keyi/Desktop/MA/AutoPet_Anatomy/FDG-PET-CT-Lesions/08-10-2001-NA-PET-CT Ganzkoerper  primaer mit KM-23662/SUV.nii.gz'
reference = mt.LoadImage()(pet_file)
print(reference.shape)
print(reference.meta["pixdim"])
print(np.prod([2.04, 2.04, 3.0])/1000)
# volume = np.prod(reference.meta["pixdim"][1:4]) / 1000
print(volume)

0.0061200000000000004
torch.Size([400, 400, 284])
[-1.       2.03642  2.03642  3.       1.       1.       1.       1.     ]
0.0124848
0.0061200000000000004


In [15]:
class ShiftedGroupAttention(nn.Module):
    def __init__(self, in_channels, num_heads=4, groups=4):
        super().__init__()
        self.groups = groups
        self.channels_per_group = in_channels // groups
        
        self.window_attention = TransformerBlock(
            hidden_size=self.channels_per_group,
            num_heads=num_heads,
            mlp_dim=self.channels_per_group * 4,
            qkv_bias=True,
            with_cross_attention=True,
            use_flash_attention=True,
            dropout_rate=0.1,
        )
        
        self.shift_sizes = [1, 2]
        
    def forward(self, x1, x2):
        B, C, D, H, W = x1.shape
        outputs = []
        
  
        x1_groups = torch.chunk(x1, self.groups, dim=1)  # [B, C/groups, D, H, W]
        x2_groups = torch.chunk(x2, self.groups, dim=1)
        
    
        for g_idx in range(self.groups):
            q = x1_groups[g_idx].flatten(2).transpose(1, 2)  # [B, D*H*W, C/groups]
            k = x2_groups[g_idx].flatten(2).transpose(1, 2)
            out = self.window_attention(q, k).transpose(1, 2).view(B, -1, D, H, W)
            outputs.append(out)
        

        for shift_size in self.shift_sizes:
  
            x1_shift = torch.roll(x1, shifts=self.channels_per_group * shift_size, dims=1)
            x2_shift = torch.roll(x2, shifts=self.channels_per_group * shift_size, dims=1)
            
            x1_shift_groups = torch.chunk(x1_shift, self.groups, dim=1)
            x2_shift_groups = torch.chunk(x2_shift, self.groups, dim=1)
            
            for g_idx in range(self.groups):
                q = x1_shift_groups[g_idx].flatten(2).transpose(1, 2)
                k = x2_shift_groups[g_idx].flatten(2).transpose(1, 2)
                out = self.window_attention(q, k).transpose(1, 2).view(B, -1, D, H, W)
                outputs.append(out)
        
    
        outputs = torch.stack([torch.cat(outputs[i:i+self.groups], dim=1) 
                             for i in range(0, len(outputs), self.groups)], dim=0)
        return outputs.mean(0)  # [B, C, D, H, W]

x1 = torch.randn(12, 1024, 3, 3, 3)  
x2 = torch.randn(12, 1024, 3, 3, 3)
out = ShiftedGroupAttention(1024, 4, 4)(x1, x2)
print(out.shape)

torch.Size([12, 1024, 3, 3, 3])
