- Takes as input the (768, 768, 64) images from `dataset-1`.
- It takes 4^3 average across the 2D, and average across the depth
- This results in a (192, 192, 16) volume.
- I have a image that is of shape (1, 192, 192, 16). Divide the image and the labels into 32 patches: (1, 8*24, 8*24, 16) → (24, 24, 1024) → (576, 1024). 
- Pass the image into a transformer architecture with 6 layers, which will return (576, 1) tensor, which corresponds to probabilities of pixel occuring. You should also pass the “positional embeddings”, which should be of dimension 1024.

# 4^3 3D average 


In [1]:
import numpy as np
import torch
import torch.nn as nn

def avg_3d(volume):
    # Convert the numpy array to a PyTorch tensor
    volume_tensor = torch.tensor(volume, dtype=torch.float32)

    # Add batch and channel dimensions to the tensor
    volume_tensor = volume_tensor.permute(0, 3, 1, 2)  # Reorder dimensions to (batch, channels, height, width)

    # Create the 3D average pooling layer with the appropriate kernel size and stride values
    avg_pool = nn.AvgPool3d(kernel_size=4, stride=4, padding=0)

    # Apply the average pooling layer to the input tensor
    with torch.no_grad():
        filtered_volume_tensor = avg_pool(volume_tensor)

    # Convert the output tensor back to a numpy array
    filtered_volume = filtered_volume_tensor.permute(0, 2, 3, 1)  # Reorder dimensions back to (batch, height, width, channels)
    
    return filtered_volume


# Example volume
volume = np.random.rand(1, 768, 768, 64)
volume_avgd = avg_3d(volume)
# %time avg_3d(volume) # 85 ms!

In [2]:
volume_avgd.shape

torch.Size([1, 192, 192, 16])

# Take random 3 channels

In [3]:
volume_avgd_new = volume_avgd.permute(3, 1, 2, 0)
# take 3 random numbers between 0 and 16, without replacement
random_numbers = np.random.choice(16, 3, replace=False)
volume_avgd_new = volume_avgd_new[random_numbers, :, :, :]
# Repermute
volume_avgd_new = volume_avgd_new.permute(3, 1, 2, 0)
volume_avgd_new.shape


torch.Size([1, 192, 192, 3])

# Reshape

In [4]:
import numpy as np


def reshape_img(image):
    # Calculate the size of each patch
    patch_size = 24

    B, H, W, C = image.shape
    image = image.reshape(B, H // patch_size, patch_size, W // patch_size, patch_size, C) # (B, 8, 24, 8, 24, 16)
    image = image.permute(0, 2, 4, 1, 3, 5) # (B, 24, 24, 8, 8, 16)
    image = image.reshape(B, patch_size, patch_size, -1) # (B, 24, 24, 1024)
    image = image.reshape(B, -1, 1024) # (B, 576, 1024)
    
    return image

# Example image
image = np.random.rand(2, 192, 192, 16)
image = torch.tensor(image, dtype=torch.float32)


image = reshape_img(image)

print(image.shape)


torch.Size([2, 576, 1024])


# Metric competition


In [58]:
def f0point5_score(output, target):
    # Flatten the output and target tensors
    output = output.view(-1)
    target = target.view(-1)
    
    # Convert the output to binary values 0 and 1
    output = (output > 0.5).float()
    
    # Calculate the precision and recall
    tp = torch.sum(output * target)
    fp = torch.sum(output * (1 - target))
    fn = torch.sum((1 - output) * target)
    precision = tp / (tp + fp + 1e-7)
    recall = tp / (tp + fn + 1e-7)
    
    # Calculate the F0.5 score
    beta = 0.5
    f0point5 = (1 + beta**2) * precision * recall / (beta**2 * precision + recall + 1e-7)
    return f0point5

# Model: ResNet50Seg

In [62]:
import torch.nn as nn
import torchvision.models as models

class ResNet50Seg(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, class_one_weight=1):
        super(ResNet50Seg, self).__init__()
        
        self.class_one_weight = class_one_weight
        
        # Load the pre-trained ResNet50 model
        self.resnet50 = models.resnet50()
        
        # Replace the final layer to output 1024 channel instead of 1000
        self.resnet50.fc = nn.Linear(2048, 1024)
        
        
    def forward(self, x, targets=None):
        x = self.resnet50(x)
        # Apply sigmoid
        x = torch.sigmoid(x)
        
        out = {}
        
        if targets is None:
            loss = None
            precision = None
            accuracy = None
            f0point5 = None
        else:
            # Calculate the loss
            
            # Calculate class weights based on the imbalance
            class_weights = torch.tensor([1.0, self.class_one_weight]) # weight 1 for class 0, weight 5 for class 1

            # Instantiate the loss function
            loss_function = nn.BCEWithLogitsLoss(pos_weight=class_weights[1], reduction='mean')
            
            # Flatten the targets tensor
            targets = targets.reshape(-1, 1024)
            
            # Calculate the loss
            loss = loss_function(x, targets)
            
            # Calculate the accuracy: number of correctly predicted pixel / total number of pixels
            # Convert the predictions to binary values 0 and 1
            predictions = (x > 0.5).float()
            # Calculate the accuracy
            accuracy = (predictions == targets).float().mean()
            
            # Calculate the precision
            tp = torch.sum(predictions * targets)   
            fp = torch.sum(predictions * (1 - targets))
            fn = torch.sum((1 - predictions) * targets)
            precision = tp / (tp + fp + 1e-7)
            
            # Calculate F0.5 score
            f0point5 = f0point5_score(predictions, targets) 
        
        out = {
            "loss": loss,
            "logits": x,
            "accuracy": accuracy,
            "precision": precision,
            "f0point5": f0point5
        }
        
        return out


In [139]:
model = ResNet50Seg(in_channels=3, out_channels=1)
x = torch.randn(1, 3, 192, 192)  # input image
output = model(x)
output['logits'].shape

torch.Size([1, 1024])

In [34]:
# number of 1e6 parameters
sum(p.numel() for p in model.parameters()) / 1e6

25.606208

# Loss function

In [35]:
# Calculate class weights based on the imbalance
class_weights = torch.tensor([1.0, 5.0]) # weight 1 for class 0, weight 5 for class 1

# Instantiate the loss function
loss_function = nn.BCEWithLogitsLoss(pos_weight=class_weights[1], reduction='none')

# Example output and targets
out = [[0.1], [0.2], [0.9], [0.9], [0]]
tar = [[1], [1], [0], [1], [0]]
output = torch.tensor([out], dtype=torch.float32)
targets = torch.tensor([tar], dtype=torch.float32)
loss_function(output, targets)

tensor([[[3.2220],
         [2.9907],
         [1.2412],
         [1.7058],
         [0.6931]]])

In [36]:
# Calculate class weights based on the imbalance
class_weights = torch.tensor([1.0, 5.0]) # weight 1 for class 0, weight 5 for class 1

# Instantiate the loss function
loss_function = nn.BCEWithLogitsLoss(pos_weight=class_weights[1], reduction='mean')

# Example output and targets
out = [[0.1], [0.2], [0.9], [0.9], [0]]
tar = [[1], [1], [0], [1], [0]]
output = torch.tensor([out], dtype=torch.float32)
targets = torch.tensor([tar], dtype=torch.float32)
loss_function(output, targets)

tensor(1.9705)

# Model with loss test

In [37]:
model = ResNet50Seg(in_channels=3, out_channels=1)
x = torch.randn(16, 3, 192, 192)  # input image
targets = torch.randn(16, 1024)  # input image

# Normalize x to be between 0 and 1
x = torch.sigmoid(x)

# Clip targets to 0 and 1
targets = torch.clamp(targets, 0, 1)


output = model(x, targets)
output['logits'].shape, output['loss'], output['accuracy']

(torch.Size([16, 1024]),
 tensor(0.8191, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
 tensor(0.3298))

In [38]:
# Expected value for the BCE loss if there is 1024 elements

# Instantiate the loss function
loss_function = nn.BCEWithLogitsLoss(pos_weight=class_weights[1], reduction='mean')

y_true = torch.randn(1024, 1024) 
y_pred = torch.randn(1024, 1024) 

loss = loss_function(y_pred, y_true)
loss

tensor(0.8055)

In [39]:
# Expected value for the BCE loss if there is 1024 elements that were all correctly classified

# Instantiate the loss function
loss_function = nn.BCEWithLogitsLoss(pos_weight=class_weights[1], reduction='mean')

y_pred = torch.randn(1024*16, 1024) 

loss = loss_function(y_pred, y_pred)
loss

tensor(-2.1928)

# Overfit on 1 batch of random numbers

In [40]:
from tqdm import tqdm

In [66]:


model = ResNet50Seg(in_channels=3, out_channels=1, class_one_weight=5)
x = torch.randn(4, 3, 192, 192)  # input image
# targets = random integers of shape (4, 1024) between 0 and 1
targets = torch.randint(0, 2, (4, 1024)).float()
# Make 80% of the targets to be 0
targets[:, 0:819] = 0


# Normalize x to be between 0 and 1
x = torch.sigmoid(x)

# Clip targets to 0 and 1
targets = torch.clamp(targets, 0, 1)

# optimizer 
# optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)

n_iters = 450
for i in tqdm(range(n_iters)):
    
    # Forward pass
    output = model(x, targets)
    loss = output['loss']
    precision = output['precision']
    accuracy = output['accuracy']
    f0point5 = output['f0point5']
    
    # Backward pass
    loss.backward()
    
    # Update weights
    optimizer.step()
    
    if i % 10 == 0:
        print(f"Loss: {loss.item():.4f}, Accuracy: {accuracy.item():.4f}, Precision: {precision.item():.4f}, F0.5: {f0point5.item():.4f}")


  0%|          | 1/450 [00:00<02:41,  2.78it/s]

Loss: 1.1170, Accuracy: 0.5149, Precision: 0.1084, F0.5: 0.1288


  2%|▏         | 11/450 [00:03<02:03,  3.57it/s]

Loss: 1.0715, Accuracy: 0.6597, Precision: 0.1951, F0.5: 0.2289


  5%|▍         | 21/450 [00:06<01:58,  3.62it/s]

Loss: 1.0357, Accuracy: 0.7729, Precision: 0.2938, F0.5: 0.3386


  7%|▋         | 31/450 [00:08<01:58,  3.53it/s]

Loss: 0.9999, Accuracy: 0.8662, Precision: 0.4292, F0.5: 0.4812


  9%|▉         | 41/450 [00:11<01:55,  3.53it/s]

Loss: 0.9654, Accuracy: 0.9170, Precision: 0.5525, F0.5: 0.6062


 11%|█▏        | 51/450 [00:14<01:52,  3.56it/s]

Loss: 0.9338, Accuracy: 0.9387, Precision: 0.6254, F0.5: 0.6760


 14%|█▎        | 61/450 [00:17<01:49,  3.56it/s]

Loss: 0.9062, Accuracy: 0.9487, Precision: 0.6661, F0.5: 0.7138


 16%|█▌        | 71/450 [00:20<02:03,  3.07it/s]

Loss: 0.8832, Accuracy: 0.9500, Precision: 0.6715, F0.5: 0.7187


 18%|█▊        | 81/450 [00:23<01:49,  3.36it/s]

Loss: 0.8647, Accuracy: 0.9502, Precision: 0.6726, F0.5: 0.7197


 20%|██        | 91/450 [00:26<01:50,  3.26it/s]

Loss: 0.8500, Accuracy: 0.9500, Precision: 0.6715, F0.5: 0.7187


 22%|██▏       | 101/450 [00:29<01:45,  3.32it/s]

Loss: 0.8385, Accuracy: 0.9502, Precision: 0.6726, F0.5: 0.7197


 25%|██▍       | 111/450 [00:32<01:41,  3.36it/s]

Loss: 0.8299, Accuracy: 0.9502, Precision: 0.6726, F0.5: 0.7197


 27%|██▋       | 121/450 [00:35<01:35,  3.44it/s]

Loss: 0.8234, Accuracy: 0.9502, Precision: 0.6726, F0.5: 0.7197


 29%|██▉       | 131/450 [00:38<01:34,  3.39it/s]

Loss: 0.8183, Accuracy: 0.9521, Precision: 0.6813, F0.5: 0.7277


 31%|███▏      | 141/450 [00:41<01:30,  3.40it/s]

Loss: 0.8143, Accuracy: 0.9529, Precision: 0.6846, F0.5: 0.7307


 34%|███▎      | 151/450 [00:44<01:26,  3.47it/s]

Loss: 0.8110, Accuracy: 0.9565, Precision: 0.7018, F0.5: 0.7463


 36%|███▌      | 161/450 [00:47<01:28,  3.26it/s]

Loss: 0.8081, Accuracy: 0.9607, Precision: 0.7224, F0.5: 0.7649


 38%|███▊      | 171/450 [00:50<01:18,  3.54it/s]

Loss: 0.8056, Accuracy: 0.9648, Precision: 0.7442, F0.5: 0.7844


 40%|████      | 181/450 [00:53<01:15,  3.55it/s]

Loss: 0.8034, Accuracy: 0.9675, Precision: 0.7591, F0.5: 0.7975


 42%|████▏     | 191/450 [00:55<01:14,  3.48it/s]

Loss: 0.8014, Accuracy: 0.9705, Precision: 0.7759, F0.5: 0.8123


 45%|████▍     | 201/450 [00:58<01:16,  3.26it/s]

Loss: 0.7997, Accuracy: 0.9736, Precision: 0.7951, F0.5: 0.8290


 47%|████▋     | 211/450 [01:01<01:08,  3.47it/s]

Loss: 0.7982, Accuracy: 0.9756, Precision: 0.8073, F0.5: 0.8397


 49%|████▉     | 221/450 [01:05<01:21,  2.80it/s]

Loss: 0.7968, Accuracy: 0.9780, Precision: 0.8232, F0.5: 0.8534


 51%|█████▏    | 231/450 [01:08<01:07,  3.26it/s]

Loss: 0.7956, Accuracy: 0.9797, Precision: 0.8347, F0.5: 0.8632


 54%|█████▎    | 241/450 [01:11<01:00,  3.44it/s]

Loss: 0.7946, Accuracy: 0.9819, Precision: 0.8499, F0.5: 0.8762


 56%|█████▌    | 251/450 [01:14<00:59,  3.33it/s]

Loss: 0.7938, Accuracy: 0.9827, Precision: 0.8551, F0.5: 0.8806


 58%|█████▊    | 261/450 [01:17<00:54,  3.47it/s]

Loss: 0.7931, Accuracy: 0.9836, Precision: 0.8621, F0.5: 0.8866


 60%|██████    | 271/450 [01:20<00:51,  3.48it/s]

Loss: 0.7925, Accuracy: 0.9841, Precision: 0.8657, F0.5: 0.8896


 62%|██████▏   | 281/450 [01:22<00:48,  3.51it/s]

Loss: 0.7920, Accuracy: 0.9851, Precision: 0.8729, F0.5: 0.8957


 65%|██████▍   | 291/450 [01:25<00:45,  3.50it/s]

Loss: 0.7915, Accuracy: 0.9858, Precision: 0.8784, F0.5: 0.9003


 67%|██████▋   | 301/450 [01:29<00:52,  2.82it/s]

Loss: 0.7911, Accuracy: 0.9866, Precision: 0.8840, F0.5: 0.9050


 69%|██████▉   | 311/450 [01:32<00:41,  3.31it/s]

Loss: 0.7908, Accuracy: 0.9871, Precision: 0.8877, F0.5: 0.9081


 71%|███████▏  | 321/450 [01:35<00:38,  3.33it/s]

Loss: 0.7905, Accuracy: 0.9873, Precision: 0.8896, F0.5: 0.9097


 74%|███████▎  | 331/450 [01:38<00:35,  3.33it/s]

Loss: 0.7902, Accuracy: 0.9880, Precision: 0.8953, F0.5: 0.9144


 76%|███████▌  | 341/450 [01:41<00:32,  3.31it/s]

Loss: 0.7900, Accuracy: 0.9885, Precision: 0.8991, F0.5: 0.9177


 78%|███████▊  | 351/450 [01:44<00:31,  3.13it/s]

Loss: 0.7898, Accuracy: 0.9885, Precision: 0.8991, F0.5: 0.9177


 80%|████████  | 361/450 [01:47<00:25,  3.47it/s]

Loss: 0.7896, Accuracy: 0.9888, Precision: 0.9011, F0.5: 0.9193


 82%|████████▏ | 371/450 [01:50<00:22,  3.44it/s]

Loss: 0.7894, Accuracy: 0.9893, Precision: 0.9050, F0.5: 0.9225


 85%|████████▍ | 381/450 [01:53<00:19,  3.50it/s]

Loss: 0.7892, Accuracy: 0.9895, Precision: 0.9069, F0.5: 0.9241


 87%|████████▋ | 391/450 [01:56<00:17,  3.46it/s]

Loss: 0.7891, Accuracy: 0.9895, Precision: 0.9069, F0.5: 0.9241


 89%|████████▉ | 401/450 [01:59<00:14,  3.41it/s]

Loss: 0.7889, Accuracy: 0.9895, Precision: 0.9069, F0.5: 0.9241


 91%|█████████▏| 411/450 [02:02<00:11,  3.45it/s]

Loss: 0.7888, Accuracy: 0.9897, Precision: 0.9089, F0.5: 0.9258


 94%|█████████▎| 421/450 [02:05<00:08,  3.49it/s]

Loss: 0.7887, Accuracy: 0.9905, Precision: 0.9148, F0.5: 0.9307


 96%|█████████▌| 431/450 [02:08<00:05,  3.52it/s]

Loss: 0.7886, Accuracy: 0.9905, Precision: 0.9148, F0.5: 0.9307


 98%|█████████▊| 441/450 [02:10<00:02,  3.51it/s]

Loss: 0.7885, Accuracy: 0.9905, Precision: 0.9148, F0.5: 0.9307


100%|██████████| 450/450 [02:13<00:00,  3.37it/s]


# Data Loader

In [196]:
from torch.nn.functional import avg_pool2d
from torch.utils.data import Dataset, DataLoader
import os


class ImageSegmentationDataset(Dataset):
    def __init__(self, root, mode='train', device='cpu', cache_refresh_interval=None, cache_n_images=64):
        
        assert mode in ['train', 'test'], "mode must be either 'train' or 'test'"
        
        self.root = root
        self.mode = mode
        self.cache_refresh_interval = cache_refresh_interval
        self.cache_n_images = cache_n_images
        self._load_data()
        
    def __len__(self):
        return len(self.volume_images)
    
    # def __getitem__(self, item):
        
    #     volume = np.load(self.volume_images[item])
    #     label = np.load(self.label_images[item])
        
    #     # convert to torch tensors
    #     volume = torch.tensor(volume, dtype=torch.float32)
    #     label = torch.tensor(label, dtype=torch.float32)
        
    #     # send to device
    #     volume = volume.to(self.device)
    #     label = label.to(self.device)
        
    #     # unsqueeze the channel dimension
    #     label = label.unsqueeze(-1)
        
        
    #     # take one random number from 0 to 20, 20 to 35, 35 to 50, 50 to 64
    #     idx_1 = np.random.randint(0, 20)
    #     idx_2 = np.random.randint(20, 35)
    #     idx_3 = np.random.randint(35, 50)
    #     indices = [idx_1, idx_2, idx_3]
    #     volume = volume[:, :, indices]
        
    #     # Apply 2D average pooling to the volume and label tensors
    #     # Average each channel separately
    #     volume = volume.permute(2, 0, 1)  # Permute dimensions for avg_pool2d
    #     volume = avg_pool2d(volume, kernel_size=4, stride=4)
    #     volume = volume.permute(1, 2, 0)  # Permute dimensions back
        
    #     label = label.permute(2, 0, 1)  # Permute dimensions for avg_pool2d
    #     label = avg_pool2d(label, kernel_size=4, stride=4)
    #     label = label.permute(1, 2, 0)  # Permute dimensions back
        
    #     return {
    #         "image": volume,
    #         "targets": label
    #     }
       
    def _load_data(self):
        # Get the volume paths
        self.volume_images = os.listdir(os.path.join(self.root, self.mode, 'volume'))
        self.volume_images = [os.path.join(self.root, self.mode, 'volume', image) for image in self.volume_images]
        
        # Get label paths
        self.label_images = os.listdir(os.path.join(self.root, self.mode, 'label'))
        self.label_images = [os.path.join(self.root, self.mode, 'label', image) for image in self.label_images]
        
        # take self.cache_n_images random images from the dataset. They cannot be repeated
        if self.cache_n_images is not None and self.cache_n_images < len(self.volume_images):
            self.volume_images = np.random.choice(self.volume_images, size=self.cache_n_images, replace=False)
            self.label_images = np.random.choice(self.label_images, size=self.cache_n_images, replace=False)
        
        # Load the data into memory
        self.cached_data = []
        for volume_image, label_image in zip(self.volume_images, self.label_images):
            volume = np.load(volume_image)
            label = np.load(label_image)
            
            # convert to torch tensors
            volume = torch.tensor(volume, dtype=torch.float32)
            label = torch.tensor(label, dtype=torch.float32)
            
            # unsqueeze the channel dimension
            label = label.unsqueeze(-1)
            
            self.cached_data.append({
                "image": volume,
                "targets": label
            }) 
    
    def __len__(self):
        return len(self.cached_data)
    
    def __getitem__(self, item):
        return self.cached_data[item]
    
    def refresh_cache(self):
        self._load_data()

In [205]:

# Define the dataset and dataloader
root = "../../datasets/dataset-1/"

# Mac m1 device
device = torch.device("mps")

batch_size = 16

dataset = ImageSegmentationDataset(root=root, mode='train', device=device, cache_refresh_interval=160, cache_n_images=batch_size)

In [206]:
# dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, pin_memory=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [207]:
# one batch
batch = next(iter(dataloader))
batch['image'].shape, batch['targets'].shape

(torch.Size([5, 768, 768, 64]), torch.Size([5, 768, 768, 1]))

In [208]:
# Let's check how fast the dataloader is

import time

start = time.time()
for i in range(100):
    batch = next(iter(dataloader))
    # Refresh the cache 
    if i % dataset.cache_refresh_interval == 0:
        dataset.refresh_cache()
    
end = time.time()

print(f"Time taken for 100 iterations: {end - start:.4f} seconds")

Time taken for 100 iterations: 11.6429 seconds


# 1