In [7]:
import torch
def init_random_2d_freqs(head_dim: int, num_heads: int, theta: float = 10.0, rotate: bool = True):
    """Build per-head frequency vectors for the two axes of a 2D rotary embedding.

    A magnitude vector ``1 / theta ** (k / head_dim)`` is computed for
    ``k = 0, 4, 8, ...`` (``head_dim // 4`` entries). For every head one angle
    is chosen -- drawn with ``torch.rand`` and scaled to ``[0, 2*pi)`` when
    ``rotate`` is true, zero otherwise -- and the magnitudes are scaled by
    cos/sin of that angle and of the angle shifted by ``pi/2``.

    Args:
        head_dim: Per-head channel count used to size and scale the magnitudes.
        num_heads: Number of heads; one angle is drawn per head.
        theta: Base of the magnitude decay.
        rotate: Whether to draw a random angle per head.

    Returns:
        Tensor of shape ``[2, num_heads, 2 * (head_dim // 4)]``. Index 0 holds
        the x-axis frequencies and index 1 the y-axis frequencies.
    """
    exponents = torch.arange(0, head_dim, 4)[: (head_dim // 4)].float() / head_dim
    mag = 1 / (theta ** exponents)

    per_head_x, per_head_y = [], []
    for _ in range(num_heads):
        if rotate:
            angles = torch.rand(1) * 2 * torch.pi
        else:
            angles = torch.zeros(1)

        x_parts = [mag * torch.cos(angles), mag * torch.cos(torch.pi/2 + angles)]
        # NOTE(review): the first y term carries an extra pi/12 phase that the
        # matching x term does not have -- confirm this offset is intentional.
        y_parts = [mag * torch.sin(angles + torch.pi/12), mag * torch.sin(torch.pi/2 + angles)]

        per_head_x.append(torch.cat(x_parts, dim=-1))
        per_head_y.append(torch.cat(y_parts, dim=-1))

    axis_x = torch.stack(per_head_x, dim=0)
    axis_y = torch.stack(per_head_y, dim=0)
    return torch.stack([axis_x, axis_y], dim=0)

def init_random_3d_freqs(head_dim: int, num_heads: int, theta: float = 10.0, rotate: bool = True):
    """
    Initialize frequency parameters for 3D rotary embeddings.

    ``num_pairs = head_dim // 6`` magnitudes ``1 / theta ** (k / num_pairs)``
    are computed once. For every head and every axis (x, y, z, in that order)
    an independent angle is chosen -- drawn with ``torch.rand`` and scaled to
    ``[0, 2*pi)`` when ``rotate`` is true, zero otherwise -- and three groups
    are concatenated for that axis:
      - 0 offset:      mag * cos(angle)
      - pi/2 offset:   mag * cos(pi/2 + angle)
      - pi offset:     mag * cos(pi + angle)

    Returns:
        Tensor of shape ``[3, num_heads, 3 * num_pairs]`` ordered (x, y, z)
        along the first dimension.
    """
    # Number of magnitudes per phase group.
    num_pairs = head_dim // 6
    mag = 1 / (theta ** (torch.arange(num_pairs, dtype=torch.float32) / num_pairs))

    def _axis_freqs():
        # One angle per call, so calling this three times per head consumes
        # random numbers in x, y, z order.
        angle = torch.rand(1) * 2 * torch.pi if rotate else torch.zeros(1)
        groups = [
            mag * torch.cos(angle),
            mag * torch.cos(torch.pi/2 + angle),
            mag * torch.cos(torch.pi + angle),
        ]
        return torch.cat(groups, dim=-1)

    per_axis = ([], [], [])
    for _ in range(num_heads):
        head_x, head_y, head_z = _axis_freqs(), _axis_freqs(), _axis_freqs()
        per_axis[0].append(head_x)
        per_axis[1].append(head_y)
        per_axis[2].append(head_z)

    # Stack heads within each axis, then the three axes.
    return torch.stack([torch.stack(heads, dim=0) for heads in per_axis], dim=0)

def compute_cis(freqs: torch.Tensor, t_x: torch.Tensor, t_y: torch.Tensor, t_z: torch.Tensor = None):
    """Turn axis frequencies and token coordinates into unit-magnitude complex rotations.

    For each axis the coordinates are multiplied (outer product via matmul)
    with that axis' frequencies; the per-axis phases are summed and converted
    to ``exp(i * phase)`` with ``torch.polar``.

    Args:
        freqs: Frequencies indexed by axis on the first dimension
            (``freqs[0]`` for x, ``freqs[1]`` for y, ``freqs[2]`` for z).
            Assumed to be ``[num_axes, num_heads, dim]`` -- TODO confirm.
        t_x: Token coordinates along x (assumed 1-D, one entry per token).
        t_y: Token coordinates along y, same length as ``t_x``.
        t_z: Optional token coordinates along z. When ``None`` only the x and
            y phases are used.

    Returns:
        Complex tensor with the shape of the per-axis phase tensor
        (``[num_heads, num_tokens, dim]`` under the assumptions above).
    """
    # Autocast is disabled so the phase products are not computed in float16.
    with torch.amp.autocast('cuda', enabled=False):
        phase_x = t_x.unsqueeze(-1) @ freqs[0].unsqueeze(-2)
        phase_y = t_y.unsqueeze(-1) @ freqs[1].unsqueeze(-2)
        total_phase = phase_x + phase_y
        if t_z is not None:
            phase_z = t_z.unsqueeze(-1) @ freqs[2].unsqueeze(-2)
            total_phase = total_phase + phase_z
        freqs_cis = torch.polar(torch.ones_like(phase_x), total_phase)

    return freqs_cis

def init_t_xy(end_x: int, end_y: int, zero_center=False):
    """Return float32 x and y coordinates for every cell of a flattened 2D grid.

    Cells are enumerated ``0 .. end_x * end_y - 1`` with x varying fastest:
    ``x = index % end_x`` and ``y = floor(index / end_x)``.

    ``zero_center`` is accepted but not used by this implementation.
    """
    flat_index = torch.arange(end_x * end_y, dtype=torch.float32)
    x_coord = flat_index % end_x
    y_coord = torch.div(flat_index, end_x, rounding_mode='floor')
    return x_coord, y_coord

def init_t_xyz(end_x: int, end_y: int, end_z: int, zero_center=False):
    """Return float32 x, y and z coordinates for every cell of a flattened 3D grid.

    Cells are enumerated ``0 .. end_x * end_y * end_z - 1`` with x varying
    fastest, then y, then z.

    ``zero_center`` is accepted but not used by this implementation.
    """
    flat_index = torch.arange(end_x * end_y * end_z, dtype=torch.float32)
    x_coord = flat_index % end_x
    y_coord = (flat_index // end_x) % end_y    # row within the current z-slice
    z_coord = flat_index // (end_x * end_y)    # index of the z-slice
    return x_coord, y_coord, z_coord

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View ``freqs_cis`` with leading singleton dims so it broadcasts against ``x``.

    ``freqs_cis`` must have at least two dimensions and its shape must equal
    the trailing dimensions of ``x`` (for example the last 2, 3 or 4 of them).
    The result has ``x.ndim`` dimensions: ones in front, followed by the shape
    of ``freqs_cis``.

    Args:
        freqs_cis: Tensor whose shape matches the trailing dims of ``x``.
        x: Tensor to broadcast against; must have at least 2 dimensions.

    Returns:
        A view of ``freqs_cis`` with shape ``[1, ..., 1, *freqs_cis.shape]``.

    Raises:
        ValueError: If the shape of ``freqs_cis`` does not match the trailing
            dimensions of ``x``.
    """
    ndim = x.ndim
    assert ndim >= 2, "x must have at least 2 dimensions"

    trailing = freqs_cis.ndim
    # Compare against exactly as many trailing dims of x as freqs_cis has;
    # the rank guard keeps the slice valid when x has fewer dims than freqs_cis.
    if not (2 <= trailing <= ndim) or tuple(freqs_cis.shape) != tuple(x.shape[-trailing:]):
        raise ValueError("freqs_cis shape does not match expected dimensions.")

    shape = [1] * (ndim - trailing) + list(freqs_cis.shape)
    return freqs_cis.view(*shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    """Rotate query and key tensors by complex rotary factors.

    Consecutive pairs along the last dimension of ``xq`` / ``xk`` are viewed
    as complex numbers, multiplied by ``freqs_cis`` (broadcast via
    ``reshape_for_broadcast``) and converted back to interleaved real pairs.

    Args:
        xq: Query tensor; the last dimension must be even. ``flatten(3)``
            restores the input layout only for 4-D inputs
            (presumably ``[batch, heads, tokens, head_dim]`` -- TODO confirm).
        xk: Key tensor with the same layout as ``xq``.
        freqs_cis: Rotation factors whose shape matches the trailing dims of
            the complex view of ``xq``.
            NOTE(review): this is expected to be a complex tensor; a real
            tensor only scales the pairs instead of rotating them.

    Returns:
        Tuple ``(xq_out, xk_out)`` cast back to the dtype/device of ``xq`` and
        ``xk`` respectively.
    """
    # Autocast is disabled and inputs are upcast with .float() so the complex
    # multiply runs in float32.
    with torch.amp.autocast('cuda', enabled=False):
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
        # view_as_real appends a size-2 dim; flatten(3) merges it back into the
        # channel dimension.
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)

In [None]:
import torch
# Read the saved training checkpoint from disk.
# NOTE(review): torch.load unpickles the file, so only open checkpoints from a trusted source.
checkpoint = torch.load('./logs/2025-02-27T02-14-17_swlin_unetr_btcv_no_rope/checkpoints/last.ckpt')

# Use the nested 'state_dict' entry when the checkpoint has one; otherwise the
# loaded object itself is treated as the state dict.
state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

def print_nested_keys(d, parent_key=""):
    """Recursively print the dotted key paths of a nested checkpoint object.

    For a dict, every key is printed as ``parent.key`` (or just ``key`` at the
    top level) and its value is visited recursively. Lists are reported by
    length without descending into them; any other value is reported by type.

    Args:
        d: Object to inspect (typically a dict loaded from a checkpoint).
        parent_key: Dotted path of ``d`` within the enclosing structure; an
            empty string (or ``None``) means ``d`` is the root.
    """
    if isinstance(d, dict):
        has_prefix = parent_key is not None and parent_key != ""
        for key, value in d.items():
            # Always build a string path: a raw falsy key such as 0 would
            # otherwise be mistaken for "no prefix" one level further down.
            full_key = f"{parent_key}.{key}" if has_prefix else str(key)
            print(full_key)
            print_nested_keys(value, full_key)
    elif isinstance(d, list):
        print(f"{parent_key} -> List of length {len(d)}")
    else:
        print(f"{parent_key} -> {type(d)}")

# Print a header, then every dotted key path found in the loaded checkpoint.
print("Checkpoint Structure:")
print_nested_keys(checkpoint)


In [1]:
def compute_brats_metrics(dice_scores):
    """
    Derive Whole Tumor (WT), Tumor Core (TC) and Enhancing Tumor (ET) scores
    from three per-class Dice scores.

    Args:
        dice_scores (list or tensor): Indexable with at least three entries.
            - dice_scores[0] -> Dice_Necrotic (NCR)
            - dice_scores[1] -> Dice_Edema (ED)
            - dice_scores[2] -> Dice_Enhancing Tumor (ET)

    Returns:
        dict: The three input scores plus
              - Dice_WT: mean of all three class scores
              - Dice_TC: mean of the necrotic and enhancing scores
              - Dice_ET: the enhancing score unchanged

    NOTE(review): WT and TC here are plain averages of per-class Dice values.
    Region-based Dice is normally computed on merged label masks, so these
    numbers are presumably only an approximation -- confirm this is intended.
    """
    necrotic, edema, enhancing = dice_scores[0], dice_scores[1], dice_scores[2]

    return {
        "Dice_Necrotic": necrotic,
        "Dice_Edema": edema,
        "Dice_Enhancing": enhancing,
        "Dice_WT": (necrotic + edema + enhancing) / 3,  # all three classes
        "Dice_TC": (necrotic + enhancing) / 2,          # core: necrotic + enhancing
        "Dice_ET": enhancing,
    }

# Example call with three sample per-class scores (necrotic, edema, enhancing).
print(compute_brats_metrics([89.47, 91.14, 86.88]))

{'Dice_Necrotic': 89.47, 'Dice_Edema': 91.14, 'Dice_Enhancing': 86.88, 'Dice_WT': 89.16333333333334, 'Dice_TC': 88.175, 'Dice_ET': 86.88}


In [None]:
# Step 1: Configuration
# init_random_3d_freqs draws random angles; fix the seed so a re-run prints the same values.
torch.manual_seed(0)
head_dim = 12  # Dimensionality per head
num_heads = 2  # Number of attention heads
end_x, end_y, end_z = 4, 4, 4 # Grid size
batch_size = 2  # Example batch size
seq_len = end_x * end_y * end_z # Sequence length (tokens in spatial grid)

# Step 2: Initialize 3D rotary frequencies
freqs = init_random_3d_freqs(head_dim, num_heads, theta=100)
# Step 3: Generate token positions
t_x, t_y, t_z = init_t_xyz(end_x, end_y, end_z)

# Step 4: Compute complex frequency embeddings
freqs_cis = compute_cis(freqs, t_x, t_y, t_z)

# Step 5: Build constant (all-ones) query and key tensors
# Shape: (batch, num_heads, seq_len, head_dim)
xq = torch.ones(batch_size, num_heads, seq_len, head_dim, dtype=torch.float32)
xk = torch.ones(batch_size, num_heads, seq_len, head_dim, dtype=torch.float32)

print(xq.shape, xk.shape, freqs_cis.shape)
# Step 6: Apply rotary embeddings (inputs in bfloat16, rotation factors kept complex)
xq_out, xk_out = apply_rotary_emb(xq.to(torch.bfloat16), xk.to(torch.bfloat16), freqs_cis)
print(xq_out, xk_out)
A = torch.softmax(xq_out @ xk_out.transpose(-1, -2), -1)

# A pure rotation should leave the per-token norms (approximately) unchanged.
orig_norm = torch.norm(xq.to(torch.bfloat16), dim=-1)
rot_norm = torch.norm(xq_out, dim=-1)
print("Original norm stats: ", orig_norm.mean().item(), orig_norm.std().item())
print("Rotated norm stats: ", rot_norm.mean().item(), rot_norm.std().item())
print("xq_out sample:", xq_out[0,0,1,:])
print("xk_out sample:", xk_out[0,0,1,:])
row_sums = A.sum(dim=-1)
print("Row sums (should be 1):", row_sums)
# Step 7: Print results
print(f"Original xq shape: {xq.shape}, xq_out shape: {xq_out.shape}")
print(f"Original xk shape: {xk.shape}, xk_out shape: {xk_out.shape}")

# Check if output shapes match input shapes
assert xq_out.shape == xq.shape, "xq_out shape mismatch!"
assert xk_out.shape == xk.shape, "xk_out shape mismatch!"

print("✅ Pipeline executed successfully!")

In [None]:
# Step 1: Configuration
# init_random_2d_freqs draws random angles; fix the seed so a re-run prints the same values.
torch.manual_seed(0)
head_dim = 4  # Dimensionality per head
num_heads = 2  # Number of attention heads
end_x, end_y = 2, 2 # Grid size
batch_size = 1  # Example batch size
seq_len = end_x * end_y # Sequence length (tokens in spatial grid)

# Step 2: Initialize 2D rotary frequencies
freqs = init_random_2d_freqs(head_dim, num_heads, theta=100)
# Step 3: Generate token positions
t_x, t_y= init_t_xy(end_x, end_y)

# Step 4: Compute complex frequency embeddings
freqs_cis = compute_cis(freqs, t_x, t_y)
print(freqs.shape)

# Step 5: Build constant (all-ones) query and key tensors
# Shape: (batch, num_heads, seq_len, head_dim)
xq = torch.ones(batch_size, num_heads, seq_len, head_dim, dtype=torch.float32)
xk = torch.ones(batch_size, num_heads, seq_len, head_dim, dtype=torch.float32)

print(xq.shape, xk.shape, freqs_cis.shape)
# Step 6: Apply rotary embeddings
# freqs_cis must stay complex: casting it to bfloat16 would discard the
# imaginary part and turn the rotation into a plain scaling by cos(angle).
xq_out, xk_out = apply_rotary_emb(xq.to(torch.bfloat16), xk.to(torch.bfloat16), freqs_cis)
print(xq_out.shape, xk_out.shape, xq.shape, xk.shape)
A = torch.softmax(xq_out @ xk_out.transpose(-1, -2), -1)
for i in range(seq_len):
    print('A', A[0, 0, i, :])

# Step 7: Print results
print(f"Original xq shape: {xq.shape}, xq_out shape: {xq_out.shape}")
print(f"Original xk shape: {xk.shape}, xk_out shape: {xk_out.shape}")

# Check if output shapes match input shapes
assert xq_out.shape == xq.shape, "xq_out shape mismatch!"
assert xk_out.shape == xk.shape, "xk_out shape mismatch!"

print("✅ Pipeline executed successfully!")