- Takes as input the (768, 768, 64) images from `dataset-1`.
- It takes a 4^3 average: a 4×4 average across the 2D plane and a 4× average across the depth.
- This results in a (192, 192, 16) volume.
- I have an image of shape (1, 192, 192, 16). Divide the image and the labels into 32 patches: (1, 8*24, 8*24, 16) → (24, 24, 1024) → (576, 1024), i.e. 576 tokens of 1024 values each. 
- Pass the image into a transformer architecture with 6 layers, which will return a (576, 1) tensor corresponding to the probability of a pixel occurring at each position. You should also pass the “positional embeddings”, which should be of dimension 1024.

# 4^3 3D average 


In [46]:
import numpy as np
import torch
import torch.nn as nn

def avg_3d(volume):
    """Downsample a channels-last volume by averaging 4x4x4 blocks.

    The input is indexed as ``(batch, height, width, depth)``; the permutes
    below require exactly four axes. Height, width and depth are each reduced
    by a factor of 4, so ``(1, 768, 768, 64)`` becomes ``(1, 192, 192, 16)``.

    Args:
        volume: Array-like accepted by ``torch.tensor`` (e.g. a NumPy array).

    Returns:
        A float32 NumPy array laid out as ``(batch, height, width, depth)``.
    """
    pool = nn.AvgPool3d(kernel_size=4, stride=4, padding=0)

    # Move depth in front of the spatial axes: (batch, depth, height, width).
    # No axis is added; AvgPool3d pools over the last three axes of this 4-D
    # tensor and carries the leading axis through unchanged.
    depth_first = torch.tensor(volume, dtype=torch.float32).permute(0, 3, 1, 2)

    # Pure preprocessing, so no autograd graph is needed.
    with torch.no_grad():
        pooled = pool(depth_first)

    # Restore the channels-last layout before handing back a NumPy array.
    return pooled.permute(0, 2, 3, 1).numpy()


# Example volume: one random (batch, height, width, depth) volume at the
# (768, 768, 64) input resolution, pooled by avg_3d. Both names are kept as
# notebook globals so the next cell can inspect the pooled shape.
volume = np.random.rand(1, 768, 768, 64)
volume_avgd = avg_3d(volume)
# %time avg_3d(volume) # 85 ms!

In [47]:
volume_avgd.shape

(1, 192, 192, 16)

# Reshape

In [75]:
import numpy as np


def reshape_img(image, patch_size=24):
    """Rearrange a channels-last volume into a sequence of token vectors.

    Each spatial axis is split into ``(size // patch_size, patch_size)``. The
    two ``patch_size`` axes form the token grid and every remaining axis is
    flattened into the feature vector, so the result has
    ``patch_size * patch_size`` tokens of
    ``(H // patch_size) * (W // patch_size) * C`` features each. With the
    default ``patch_size=24`` a ``(B, 192, 192, 16)`` input yields
    ``(B, 576, 1024)``.

    NOTE(review): because the *inner* factor of each split is used as the
    token index, one token gathers pixels that lie ``patch_size`` apart rather
    than a contiguous tile. Confirm this sampling pattern is intended.

    Args:
        image: ``torch.Tensor`` of shape ``(B, H, W, C)``.
        patch_size: Number of tokens along each spatial axis. Must be positive
            and divide both ``H`` and ``W``.

    Returns:
        Tensor of shape ``(B, patch_size ** 2, features)``.

    Raises:
        ValueError: If ``patch_size`` is not positive or does not divide the
            height and width of ``image``.
    """
    B, H, W, C = image.shape

    if patch_size <= 0 or H % patch_size or W % patch_size:
        raise ValueError(
            f"patch_size={patch_size} must be positive and divide H={H} and W={W}"
        )

    rows = H // patch_size
    cols = W // patch_size

    image = image.reshape(B, rows, patch_size, cols, patch_size, C)  # (B, 8, 24, 8, 24, 16) by default
    image = image.permute(0, 2, 4, 1, 3, 5)  # (B, 24, 24, 8, 8, 16) by default

    # The feature width is derived from the input instead of a fixed 1024, so
    # other resolutions or channel counts are grouped correctly.
    return image.reshape(B, patch_size * patch_size, rows * cols * C)  # (B, 576, 1024) by default

# Example: a random batch of two pooled volumes, converted to a float32 tensor
# and rearranged into token sequences in one step.
image = reshape_img(
    torch.tensor(np.random.rand(2, 192, 192, 16), dtype=torch.float32)
)

print(image.shape)


torch.Size([2, 576, 1024])


# Model

In [97]:
import torch
import torch.nn as nn

# Function to generate positional encodings
def generate_positional_encoding(seq_len, d_pos_enc):
    """Build sinusoidal positional encodings.

    Even feature indices hold ``sin(position * frequency)`` and odd indices
    hold ``cos(position * frequency)``, with frequencies decaying
    geometrically from 1 towards 1/10000.

    Args:
        seq_len: Number of positions to encode.
        d_pos_enc: Width of each encoding vector. May be odd, in which case
            the last frequency only contributes its sine component.

    Returns:
        Float tensor of shape ``(1, seq_len, d_pos_enc)``.
    """
    position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_pos_enc, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_pos_enc))
    angles = position * div_term  # (seq_len, ceil(d_pos_enc / 2))

    pos_enc = torch.zeros((1, seq_len, d_pos_enc))
    pos_enc[0, :, 0::2] = torch.sin(angles)
    # There are only d_pos_enc // 2 odd columns; trimming the angles keeps odd
    # widths from failing with a shape mismatch.
    pos_enc[0, :, 1::2] = torch.cos(angles[:, : d_pos_enc // 2])
    return pos_enc

class CustomTransformer(nn.Module):
    """Transformer encoder that predicts one probability per input token.

    Sinusoidal positional encodings are concatenated to the token features
    (not added), so the encoder operates on ``d_model + d_pos_enc`` features.

    Args:
        d_model: Width of the input token features.
        d_pos_enc: Width of the positional encoding appended to each token.
        nhead: Number of attention heads; must divide ``d_model + d_pos_enc``.
        num_layers: Number of stacked encoder layers.
        seq_len: Maximum number of tokens per sample (default 576).
    """

    def __init__(self, d_model, d_pos_enc, nhead, num_layers, seq_len=576):
        super().__init__()

        # A non-persistent buffer follows the module across .to()/.cuda() but
        # stays out of state_dict, so checkpoint keys are unchanged.
        self.register_buffer(
            "positional_encoding",
            generate_positional_encoding(seq_len, d_pos_enc),
            persistent=False,
        )

        d_total = d_model + d_pos_enc

        # batch_first=True because forward() receives (batch, seq, feature).
        # With the default (seq, batch, feature) layout, attention would mix
        # samples of the batch instead of tokens of one sample.
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_total, nhead=nhead, batch_first=True)
        # NOTE(review): nn.TransformerEncoder works on its own copies of the
        # layer, so the weights under self.encoder_layer appear unused in
        # forward(); the attribute is kept so existing state_dict keys remain.
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.final_linear = nn.Linear(d_total, 1)

    def forward(self, x, targets=None):
        """Score every token and optionally compute the training loss.

        Args:
            x: Tensor of shape ``(B, S, d_model)`` with ``S <= seq_len``.
            targets: Optional tensor of shape ``(B, S, 1)`` with values in
                ``[0, 1]``.

        Returns:
            ``(output, loss)`` where ``output`` holds sigmoid probabilities of
            shape ``(B, S, 1)`` and ``loss`` is the mean binary cross-entropy
            computed on the logits, or ``None`` when ``targets`` is ``None``.
        """
        batch_size, seq_len = x.shape[0], x.shape[1]

        # Broadcast the shared encoding over the batch without touching the
        # stored buffer, so repeated calls with any batch size keep working.
        pos_enc = self.positional_encoding[:, :seq_len].expand(batch_size, -1, -1)

        # Concatenate positional embeddings along the channel (embedding) dimension
        x_with_pos_enc = torch.cat((x, pos_enc), dim=-1)

        # Pass the input tensor with positional encodings to the transformer encoder
        encoded = self.transformer_encoder(x_with_pos_enc)

        logits = self.final_linear(encoded)  # (B, S, 1)

        # Squash the logits into per-token probabilities
        output = torch.sigmoid(logits)  # (B, S, 1)

        loss = None
        if targets is not None:
            # BCEWithLogitsLoss applies the sigmoid itself, so it is fed the
            # raw logits rather than `output`.
            loss_function = nn.BCEWithLogitsLoss()
            loss = loss_function(logits, targets.to(dtype=logits.dtype))

        return output, loss

# Instantiate the custom transformer model: 1024 input features plus a
# 128-wide positional encoding give 1152 encoder features (144 per head).
model = CustomTransformer(d_model=1024, d_pos_enc=128, nhead=8, num_layers=2)

# Example input tensor: B samples of 576 tokens with 1024 features each,
# plus one random target value per token.
B = 2
input_tensor = torch.rand((B, 576, 1024))
targets = torch.rand((B, 576, 1))


# Pass the input tensor through the custom transformer model.
# `targets` is not passed here, so the returned `loss` is None.
output, loss = model(input_tensor)

print(output.shape)


torch.Size([2, 576, 1])


# Loss function

In [114]:
# Calculate class weights based on the imbalance
class_weights = torch.tensor([1.0, 5.0]) # weight 1 for class 0, weight 5 for class 1

# Instantiate the loss function.
# Only the class-1 weight is used: pos_weight scales the loss contribution of
# positive targets, while negative targets keep weight 1.
loss_function = nn.BCEWithLogitsLoss(pos_weight=class_weights[1])

# Example output and targets, each of shape (1, 4, 1).
# NOTE: BCEWithLogitsLoss applies the sigmoid internally, so the values in
# `out` are interpreted as raw logits, not as probabilities.
out = [[0.1], [0.2], [0.9], [0.9]]
tar = [[1], [1], [0], [1]]
output = torch.tensor([out], dtype=torch.float32)
targets = torch.tensor([tar], dtype=torch.float32)
loss_function(output, targets)

tensor(2.2899)

In [117]:
# Example output and targets with only negative labels: pos_weight has no
# effect here because it only scales the loss of positive targets.
# As above, `out` is interpreted as raw logits by BCEWithLogitsLoss.
out = [[0.1], [0.2], [0.9], [0.1]]
tar = [[0], [0], [0], [0]]
output = torch.tensor([out], dtype=torch.float32)
targets = torch.tensor([tar], dtype=torch.float32)
loss_function(output, targets)

tensor(0.8820)