- Takes as input the (768, 768, 64) images from `dataset-1`.
- It averages over non-overlapping 4×4×4 blocks, i.e. a factor of 4 along each of the two spatial axes and a factor of 4 along the depth axis.
- This results in a (192, 192, 16) volume.
- I have an image of shape (1, 192, 192, 16). Divide the image and the labels into 576 patches: (1, 8*24, 8*24, 16) → (24, 24, 1024) → (576, 1024). Pass the image into a transformer architecture with 6 layers, which will return a (576, 1) tensor corresponding to the probabilities of a pixel occurring. You should also pass the "positional embeddings", which should be of dimension 1024.

# 4^3 3D average 


In [46]:
import numpy as np
import torch
import torch.nn as nn

def avg_3d(volume, kernel_size=4):
    """Down-sample a batch of volumes by non-overlapping 3D block averaging.

    Args:
        volume: Array-like of shape (batch, height, width, depth), e.g. a
            numpy array of shape (1, 768, 768, 64).
        kernel_size: Edge length of the averaging cube; it is also used as
            the stride, so the blocks do not overlap. Defaults to 4. A
            3-tuple is passed straight to ``nn.AvgPool3d`` and is therefore
            ordered (depth, height, width).

    Returns:
        A float32 numpy array of shape
        (batch, height // k, width // k, depth // k). Trailing elements that
        do not fill a complete block are dropped (no padding is applied).
    """
    # as_tensor avoids an extra copy when the input is already float32.
    volume_tensor = torch.as_tensor(volume, dtype=torch.float32)

    # (batch, H, W, D) -> (batch, D, H, W) -> (batch, 1, D, H, W).
    # The explicit singleton channel axis gives AvgPool3d the batched 5D
    # layout it documents, instead of relying on the 4D input being read as
    # an unbatched (channels, D, H, W) volume.
    volume_tensor = volume_tensor.permute(0, 3, 1, 2).unsqueeze(1)

    # stride == kernel_size -> non-overlapping blocks.
    avg_pool = nn.AvgPool3d(kernel_size=kernel_size, stride=kernel_size, padding=0)

    # Pure inference: no autograd graph is needed for a fixed averaging filter.
    with torch.no_grad():
        pooled = avg_pool(volume_tensor)

    # (batch, 1, D', H', W') -> (batch, D', H', W') -> (batch, H', W', D').
    filtered_volume = pooled.squeeze(1).permute(0, 2, 3, 1).numpy()

    return filtered_volume


# Example volume: one random (batch, height, width, depth) = (1, 768, 768, 64)
# array, used to check that avg_3d runs end to end.
volume = np.random.rand(1, 768, 768, 64)
# Averaged result; its shape is inspected in the next cell.
volume_avgd = avg_3d(volume)
# Timing note from the notebook (IPython magic, kept commented out so the
# file stays plain Python): `%time avg_3d(volume)` reported about 85 ms.
# %time avg_3d(volume) # 85 ms!

In [47]:
volume_avgd.shape

(1, 192, 192, 16)

# Reshape

In [74]:
import numpy as np

# Example batch of two down-sampled volumes, laid out as (B, H, W, C).
image = np.random.rand(2, 192, 192, 16)
image = torch.tensor(image, dtype=torch.float32)

# Each spatial axis is split into blocks of this many pixels; the token grid
# ends up patch_size x patch_size.
patch_size = 24

B, H, W, C = image.shape
# Number of blocks along each spatial axis (8 and 8 for the example above).
grid_h, grid_w = H // patch_size, W // patch_size

# Split every spatial index as `block * patch_size + offset`.
image = image.reshape(B, grid_h, patch_size, grid_w, patch_size, C)  # (B, 8, 24, 8, 24, 16)

# Move the two within-block offsets to the front so they index the tokens.
# NOTE(review): each token therefore gathers the pixels found at one offset
# in every 24x24 block (an 8x8 grid spaced 24 apart), not a contiguous 8x8
# patch. If contiguous patches were intended, reshape to (B, 24, 8, 24, 8, C)
# and permute(0, 1, 3, 2, 4, 5) instead - confirm against the model design.
image = image.permute(0, 2, 4, 1, 3, 5)  # (B, 24, 24, 8, 8, 16)

# Flatten offsets into tokens and (block_h, block_w, C) into features. Both
# sizes are derived from the shape rather than hard-coding 1024, so a change
# of patch_size or depth cannot silently regroup elements across tokens.
image = image.reshape(B, patch_size * patch_size, grid_h * grid_w * C)  # (B, 576, 1024)


print(image.shape)


torch.Size([2, 576, 1024])
