# Install dependencies

In [None]:
# Install Dependencies.
# gcsfuse is only needed for the (commented-out) bucket mounts in the next cell:
# register Google's apt repository for this Ubuntu release, trust its signing key, install.
# NOTE(review): apt-key is deprecated (see the apt warning in this cell's output).
!echo "deb https://packages.cloud.google.com/apt gcsfuse-`lsb_release -c -s` main" | sudo tee /etc/apt/sources.list.d/gcsfuse.list
!curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -
!apt -qq update && apt -qq install gcsfuse

# Python packages used below.
# TODO(review): versions are unpinned; pin them (and prefer %pip) for reproducible runs.
!pip install pydicom matplotlib transformers nibabel

# Authenticate.
# Presumably required so the gsutil/gcloud calls below can access the private GCS buckets.
from google.colab import auth
auth.authenticate_user()

deb https://packages.cloud.google.com/apt gcsfuse-jammy main
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  2659  100  2659    0     0  21151      0 --:--:-- --:--:-- --:--:-- 21272
OK
19 packages can be upgraded. Run 'apt list --upgradable' to see them.
[1;33mW: [0mhttps://packages.cloud.google.com/apt/dists/gcsfuse-jammy/InRelease: Key is stored in legacy trusted.gpg keyring (/etc/apt/trusted.gpg), see the DEPRECATION section in apt-key(8) for details.[0m
The following NEW packages will be installed:
  gcsfuse
0 upgraded, 1 newly installed, 0 to remove and 19 not upgraded.
Need to get 5,561 kB of archives.
After this operation, 0 B of additional disk space will be used.
Selecting previously unselected package gcsfuse.
(Reading database ... 120882 files and directories currently installed.)
Preparing to unpack .../gcsfuse_1.2.1_amd64.deb ...
Unpacking gcsfuse (1.2.1) 

In [None]:
##!mkdir "/content/brats"
#!mkdir "/content/upenn"

#!gcsfuse "brats-image-files-eu" "/content/brats"
#!gcsfuse "upenn-gbm-nifti" "/content/upenn"

#!fusermount -u /content/brats

# Define functions

## Preprocessing

In [None]:
import subprocess

def find_scan_and_segm_files_gcs(bucket_path,dataset):
    '''
    List one patient folder in GCS and pick out its four MRI sequences and
    its segmentation file.

    Args:
        bucket_path (string): GCS URL of the patient folder, e.g.
            'gs://my-bucket/some/dir/PATIENT-ID/'. A missing trailing slash
            is added, so the listing never matches sibling folders that
            merely share the same prefix.
        dataset (string): file naming scheme, 'brats' or 'upenn'. Any other
            value matches no files.
    Returns:
        (patient_id, scan_paths): patient_id is the last component of the
        folder path; scan_paths maps 't1', 't1c', 't2', 'flair' and 'seg'
        to the matching GCS URL, or None when no file matched.
        Returns None (after printing the error) if gsutil cannot be run or
        exits with a non-zero status.
    '''
    folder = bucket_path.rstrip('/') + '/'
    patient_id = folder.split('/')[-2]

    # Keys use the BraTS-style names:
    #   t1    <- 't1' in UPenn,   't1n' in BraTS
    #   t1c   <- 't1gd' in UPenn, 't1c' in BraTS
    #   t2    <- 't2' in UPenn,   't2w' in BraTS
    #   flair <- 'flair' in UPenn, 't2f' in BraTS
    #   seg   <- segmentation file
    scan_paths = {'t1': None, 't1c': None, 't2': None, 'flair': None, 'seg': None}

    # Run gsutil without a shell so the URL is passed verbatim; gsutil expands "**" itself.
    try:
        result = subprocess.run(['gsutil', 'ls', '-r', folder + '**'], capture_output=True)
    except FileNotFoundError as exc:
        print("Error:", exc)
        return None

    if result.returncode != 0:
        print("Error:", result.stderr.decode('utf-8'))
        return None

    # BraTS: first matching marker wins (same precedence as an if/elif chain).
    brats_markers = [
        ('t1n.nii.gz', 't1'),
        ('t1c.nii.gz', 't1c'),
        ('t2w.nii.gz', 't2'),
        ('t2f.nii.gz', 'flair'),
        ('seg.nii.gz', 'seg'),
    ]
    # UPenn: structural scans end in "_11_<TYPE>.nii.gz".
    upenn_types = [('flair', 'flair'), ('t1', 't1'), ('t1gd', 't1c'), ('t2', 't2')]

    for file in result.stdout.decode('utf-8').splitlines():
        if dataset == 'brats':
            for marker, key in brats_markers:
                if marker in file:
                    scan_paths[key] = file
                    break
        elif dataset == 'upenn':
            if '_11_segm.nii' in file:
                scan_paths['seg'] = file
            else:
                for upenn_type, key in upenn_types:
                    if file.endswith(f'_11_{upenn_type.upper()}.nii.gz'):
                        scan_paths[key] = file

    return patient_id, scan_paths



In [None]:
import nibabel as nib
import matplotlib.pyplot as plt
import numpy as np

def read_nii(file_name):
    '''
    Load the NIfTI file at `file_name` with nibabel and return its voxel
    data as a numpy array (nibabel's get_fdata, i.e. floating point).
    '''
    volume = nib.load(file_name)
    return volume.get_fdata()

# Choose the slice to embed and a bounding-box mask around the segmented region
def mask_slice(original_data, mask_data):
    '''
    Pick the slice (along the last axis) where the segmentation covers the
    most voxels, and build a rectangular bounding-box mask around the
    segmented region in that slice.

    Args:
        original_data (numpy array): original MRI volume. Not used by the
            computation; kept so existing callers keep working.
        mask_data (numpy array): segmentation volume; slices are taken as
            mask_data[:, :, k].
    Returns:
        (max_slice_index, mask): max_slice_index is the index of the slice
        with the most non-zero voxels (the first one on ties); mask has the
        shape and dtype of that slice, with 1 inside the bounding box of its
        non-zero voxels and 0 elsewhere.
    Raises:
        ValueError: if mask_data contains no non-zero voxels.
    '''
    # Non-zero voxels per slice; np.argmax returns the first maximum, which
    # matches picking the first slice with a strictly larger count.
    counts_per_slice = np.count_nonzero(mask_data, axis=(0, 1))
    if counts_per_slice.size == 0 or counts_per_slice.max() == 0:
        raise ValueError("mask_data contains no non-zero voxels; cannot choose a slice")
    max_slice_index = int(np.argmax(counts_per_slice))

    # Bounding box of the segmented voxels in the chosen slice.
    slice_mask = mask_data[:, :, max_slice_index]
    rows, cols = np.nonzero(slice_mask)
    mask = np.zeros_like(slice_mask)
    mask[rows.min():rows.max() + 1, cols.min():cols.max() + 1] = 1

    return max_slice_index, mask

In [None]:
from PIL import Image
import numpy as np
import pydicom
import pandas as pd
from torchvision import transforms

# Scale a 2-D slice to 8-bit RGB and build the batched model input
def preprocess_image(best_slice):
  '''
  Convert a 2-D slice into an 8-bit RGB PIL image plus a batched tensor.

  The slice is min-max scaled to 0-255 and truncated to uint8, then the
  grey channel is replicated into R, G and B. A constant slice (max == min)
  maps to all zeros instead of dividing by zero (which would produce NaNs).

  Args:
      best_slice (numpy array): 2-D slice of an MRI volume.
  Returns:
      image (PIL.Image): RGB image of the scaled slice.
      image_tensor (torch.Tensor): torchvision ToTensor(image) with a
          leading batch dimension, i.e. shape (1, 3, H, W).
  '''
  min_val = best_slice.min()
  value_range = best_slice.max() - min_val

  if value_range > 0:
    scaled = (best_slice - min_val) / value_range * 255
  else:
    scaled = np.zeros_like(best_slice, dtype=np.float64)
  gray = np.uint8(scaled)

  # Greyscale -> RGB by stacking the same channel three times (H, W, 3).
  rgb = np.stack((gray,) * 3, axis=-1)
  image = Image.fromarray(rgb)

  # Turn into tensor and add one dimension for "batch-size".
  image_tensor = transforms.ToTensor()(image).unsqueeze(0)

  return image, image_tensor



## Models

In [None]:
import requests
import zipfile
import torch
import torch.nn as nn
import torchvision.models as models
import io
import transformers
from transformers import AutoModel, AutoConfig


def _download_and_extract(url, dest_dir):
  '''
  Download the zip archive at `url` and extract it into `dest_dir`.

  Skips the download when `dest_dir/pytorch_model.bin` already exists, so
  re-running the notebook does not fetch the same archive again.
  '''
  import os  # local import: this cell does not import os itself
  if os.path.exists(os.path.join(dest_dir, "pytorch_model.bin")):
    return
  response = requests.get(url, timeout=600)
  # Surface HTTP errors directly instead of a confusing zipfile.BadZipFile.
  response.raise_for_status()
  with zipfile.ZipFile(io.BytesIO(response.content)) as archive:
    archive.extractall(dest_dir)


def _vision_state_dict(weights_path):
  '''
  Load a MedCLIP checkpoint on the CPU and keep only the entries whose key
  contains 'vision_model', with the 'vision_model.model.' prefix removed.
  Key order is preserved.
  '''
  # NOTE(review): torch.load unpickles the downloaded file; only load checkpoints from trusted sources.
  state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
  return {key.replace('vision_model.model.', ''): value
          for key, value in state_dict.items()
          if 'vision_model' in key}


def load_models():
  '''
  Build the two MedCLIP vision encoders with their pretrained weights.

  Returns:
      model_resnet: torchvision ResNet-50 whose `fc` layer is replaced by a
          bias-free linear projection to 512 features, loaded with the
          MedCLIP ResNet checkpoint.
      model_vit: Hugging Face "microsoft/swin-tiny-patch4-window7-224" model
          with an added `projection_head` attribute (bias-free 768 -> 512
          linear), loaded with the MedCLIP ViT checkpoint. The projection
          head is NOT part of the model's forward(); callers apply it.
  '''
  ### Load weights

  # Key handling follows
  # https://github.com/mk-statistics/MedCLIP/blob/main/medclip/modeling_medclip.py line 66-70
  _download_and_extract("https://storage.googleapis.com/pytrial/medclip-pretrained.zip", "RESNET50_wb")
  new_state_dict_resnet = _vision_state_dict("RESNET50_wb/pytorch_model.bin")

  # See in https://github.com/mk-statistics/MedCLIP/blob/main/medclip/modeling_medclip.py line 109-113
  _download_and_extract("https://storage.googleapis.com/pytrial/medclip-vit-pretrained.zip", "ViT_wb")
  new_state_dict_vit = _vision_state_dict("ViT_wb/pytorch_model.bin")
  # This key does not start with 'vision_model.model.', so it keeps its prefix above;
  # map it onto the `projection_head` attribute added to the Swin model below.
  new_state_dict_vit["projection_head.weight"] = new_state_dict_vit.pop("vision_model.projection_head.weight")

  ### Initialize models

  model_resnet = models.resnet50(weights=None)
  # Projection head as in modeling_medclip.py line 51-53: replaces the classifier.
  model_resnet.fc = nn.Linear(model_resnet.fc.in_features, 512, bias=False)
  model_resnet.load_state_dict(new_state_dict_resnet)

  model_vit = AutoModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
  # Projection head as in modeling_medclip.py line 95.
  model_vit.projection_head = nn.Linear(768, 512, bias=False)
  model_vit.load_state_dict(new_state_dict_vit)

  return model_resnet, model_vit


def do_inference(model_name,model,image_tensor):
  '''
  Embed one image with a MedCLIP vision encoder and L2-normalize the result.

  Args:
      model_name (str): "resnet50" (the model's output is the embedding) or
          "vit" (the embedding is model.projection_head applied to the
          model output's "pooler_output").
      model (torch.nn.Module): encoder as built by load_models(); it is
          switched to eval mode.
      image_tensor (torch.Tensor): input batch; the output must contain
          exactly 512 values, i.e. a batch of one image.
  Returns:
      numpy array of shape (512,) with unit L2 norm.
  Raises:
      ValueError: if model_name is not "resnet50" or "vit".
  '''
  if model_name not in ("resnet50", "vit"):
    raise ValueError("Unknown model_name %r; expected 'resnet50' or 'vit'" % (model_name,))

  model.eval()

  with torch.no_grad():
    if model_name == "resnet50":
      vis_output = model(image_tensor)
    else:
      # The projection head is an added attribute, not part of forward(),
      # so it has to be applied explicitly to the pooled output.
      vis_output = model.projection_head(model(image_tensor)["pooler_output"])

    # See in https://github.com/mk-statistics/MedCLIP/blob/main/medclip/modeling_medclip.py line 199
    embedding = vis_output / vis_output.norm(dim=-1, keepdim=True)

  return embedding.numpy().reshape((512,))


In [None]:
# https://github.com/christiansafka/img2vec
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np

class Img2Vec():
    '''
    Turn PIL images into feature vectors read from one layer of a
    torchvision model, captured with a forward hook.
    Adapted from https://github.com/christiansafka/img2vec.
    '''
    # Output width of the avgpool layer for each supported ResNet depth.
    RESNET_OUTPUT_SIZES = {
        'resnet18': 512,
        'resnet34': 512,
        'resnet50': 2048,
        'resnet101': 2048,
        'resnet152': 2048,
    }

    def __init__(self, cuda=False, model='resnet-34', layer='default',
                 layer_output_size=512):
        """ Img2Vec
        :param cuda: If set to True, will run forward pass on GPU
        :param model: String name of requested model
        :param layer: String or Int depending on model.  See more docs: https://github.com/christiansafka/img2vec.git
        :param layer_output_size: Int depicting the output size of the requested layer
            (replaced by the known size when layer == 'default')
        """
        self.device = torch.device("cuda" if cuda else "cpu")
        self.layer_output_size = layer_output_size
        self.model_name = model

        self.model, self.extraction_layer = self._get_model_and_layer(model, layer)

        self.model = self.model.to(self.device)

        self.model.eval()

        # get_vec applies these in order: resize to 224x224, to tensor,
        # then normalize with the ImageNet mean/std.
        self.scaler = transforms.Resize((224, 224))
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])
        self.to_tensor = transforms.ToTensor()

    def get_vec(self, img, tensor=False):
        """ Get vector embedding from PIL image
        :param img: PIL Image or list of PIL Images
        :param tensor: If True, return the raw FloatTensor instead of a numpy array;
            its shape is (N, size) for alexnet and (N, size, 1, 1) otherwise,
            with N = 1 for a single image
        :returns: Numpy ndarray of shape (size,) for one image, (N, size) for a list
        """
        is_batch = isinstance(img, list)
        images = img if is_batch else [img]

        batch = torch.stack(
            [self.normalize(self.to_tensor(self.scaler(im))) for im in images]
        ).to(self.device)

        if self.model_name == 'alexnet':
            my_embedding = torch.zeros(len(images), self.layer_output_size)
        else:
            my_embedding = torch.zeros(len(images), self.layer_output_size, 1, 1)

        def copy_data(module, inputs, output):
            my_embedding.copy_(output.data)

        hook = self.extraction_layer.register_forward_hook(copy_data)
        try:
            # Only the hooked activation is needed, so skip building the autograd graph.
            with torch.no_grad():
                self.model(batch)
        finally:
            # Always detach the hook, otherwise a failed forward pass leaves it
            # registered and later calls would write into a stale tensor.
            hook.remove()

        if tensor:
            return my_embedding

        if self.model_name == 'alexnet':
            vectors = my_embedding.numpy()
        else:
            vectors = my_embedding.numpy()[:, :, 0, 0]
        return vectors if is_batch else vectors[0]

    def _get_model_and_layer(self, model_name, layer):
        """ Internal method for getting layer from model
        :param model_name: model name such as 'resnet-18'
        :param layer: layer as a string for resnet-18 or int for alexnet
        :returns: pytorch model, selected layer
        :raises KeyError: for an unsupported model name
        """

        if model_name.startswith('resnet') and not model_name.startswith('resnet-'):
            # torchvision names such as 'resnet50'
            model = getattr(models, model_name)(pretrained=True)
            if layer == 'default':
                layer = model._modules.get('avgpool')
                self.layer_output_size = self.RESNET_OUTPUT_SIZES[model_name]
            else:
                layer = model._modules.get(layer)
            return model, layer
        elif model_name == 'resnet-34':
            model = models.resnet34(pretrained=True)
            if layer == 'default':
                layer = model._modules.get('avgpool')
                self.layer_output_size = 512
            else:
                layer = model._modules.get(layer)

            return model, layer

        elif model_name == 'alexnet':
            model = models.alexnet(pretrained=True)
            if layer == 'default':
                layer = model.classifier[-2]
                self.layer_output_size = 4096
            else:
                layer = model.classifier[-layer]

            return model, layer

        else:
            raise KeyError('Model %s was not found' % model_name)

# Main-Function

## Initialize Models

In [None]:
# Build every encoder once; the main loop below reuses these globals.
# MedCLIP vision encoders (ResNet-50 and Swin-Tiny) with weights fetched by load_models().
model_resnet, model_vit = load_models()
# ImageNet-pretrained torchvision ResNets, read out at their avgpool layer by Img2Vec.
img2vec50, img2vec152 = Img2Vec(model='resnet50'), Img2Vec(model='resnet152')

config.json:   0%|          | 0.00/71.8k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/113M [00:00<?, ?B/s]

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:02<00:00, 37.3MB/s]
Downloading: "https://download.pytorch.org/models/resnet152-394f9c45.pth" to /root/.cache/torch/hub/checkpoints/resnet152-394f9c45.pth
100%|██████████| 230M/230M [00:02<00:00, 100MB/s]


## Get the list of patient folders to iterate through

In [None]:
import subprocess

# Other listings used with this notebook (set `dataset` / `diagnose` to match):
#!gsutil ls -d gs://brats-image-files-eu/ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData/*-000/ #BRATS_GLI # /ASNR-MICCAI-BraTS2023-MET-Challenge-TrainingData/ #ASNR-MICCAI-BraTS2023-PED-Challenge-TrainingData/ BraTS-MEN-Train/
#!gsutil ls -d gs://brats-image-files-eu/ASNR-MICCAI-BraTS2023-MET-Challenge-TrainingData/*-000/ ##!gsutil ls -d gs://upenn-gbm-nifti/PKG-UPENN-GBM-NIfTI/UPENN-GBM/NIfTI-files/images_structural/*_11/ # UPENN GLIO
#cmd = f'gsutil ls gs://upenn-gbm-nifti/PKG-UPENN-GBM-NIfTI/UPENN-GBM/NIfTI-files/images_structural/'

# One line per entry under the prefix (the patient folders, possibly plus the
# prefix itself, which the main loop skips). Runs without a shell and raises
# CalledProcessError on failure; otherwise a failed listing would leave `out`
# empty and the main loop would silently do nothing.
listing = subprocess.run(
    # append '*-000/' to the URL to list only the "-000" folders
    ["gsutil", "ls", "gs://brats-image-files-eu/ASNR-MICCAI-BraTS2023-MET-Challenge-TrainingData/"],
    capture_output=True,
    encoding="utf-8",
    check=True,
)
out = listing.stdout.splitlines()
dataset = "brats"
diagnose = "met"


## Run the inference, upload the reference images to GCS and store the embeddings in a dict

In [None]:
from PIL import Image
import os
import hashlib
import uuid
import torch
import time
import pickle
import shutil
import subprocess

# --- Run configuration ---------------------------------------------------------
# Listing entry for the prefix itself (not a patient folder); skipped below.
LISTING_PREFIX = "gs://brats-image-files-eu/ASNR-MICCAI-BraTS2023-MET-Challenge-TrainingData/"
# Stop after the entry with this index in `out` has been handled (inclusive).
LAST_INDEX = 100
EMBEDDING_FILE = "100_ASNR-MICCAI-BraTS2023-MET-Challenge-TrainingData_embedding.pickle"
EMBEDDING_DEST = "gs://picture_embeddings/combined_embeddings"
MODALITIES = ["t1", "t1c", "t2", "flair"]

# Each embedder receives (PIL image, batched tensor) from preprocess_image and
# returns one vector; dict order is the order the embeddings are computed and stored.
EMBEDDERS = {
  "RESNET50_MEDCLIP": lambda image, tensor: do_inference("resnet50", model_resnet, tensor),
  "VIT_MEDCLIP": lambda image, tensor: do_inference("vit", model_vit, tensor),
  "RESNET50_IMAGENET": lambda image, tensor: img2vec50.get_vec(image),
  "RESNET152_IMAGENET": lambda image, tensor: img2vec152.get_vec(image),
}


def _run(cmd):
  '''Run a CLI command (list of args) without a shell; raise CalledProcessError on failure.'''
  subprocess.run(cmd, check=True)


def _local_path(gcs_url):
  '''Path of a downloaded object relative to the working directory: "<folder>/<file name>".'''
  return "/".join(gcs_url.split("/")[-2:])


def _embed_patient(pat_path):
  '''
  Download one patient folder and embed the chosen slice of every modality
  (full slice and bounding-box-masked slice) with all EMBEDDERS. Upload the
  PNG previews into the patient's bucket folder and remove the local copy.

  Returns the embedding_dict entry for the patient, or None when the patient
  is skipped because its listing failed or a required file is missing.
  '''
  lookup = find_scan_and_segm_files_gcs(pat_path, dataset)
  if lookup is None:
    print("Skipping " + pat_path + ": could not list its files")
    return None
  patid, files = lookup
  missing = [key for key in MODALITIES + ["seg"] if files.get(key) is None]
  if missing:
    print("Skipping " + pat_path + ": missing " + ", ".join(missing))
    return None

  entry = {"pat_id": patid, "bucket_url": pat_path, "origin_ds": dataset, "diagnose": diagnose}

  # Download the patient folder into ./<patid>/
  _run(["gcloud", "storage", "cp", "-r", pat_path, "."])
  try:
    # The T1 volume is read once and reused for its own slice below.
    t1_volume = read_nii(_local_path(files["t1"]))
    slice_index, mask = mask_slice(t1_volume, read_nii(_local_path(files["seg"])))

    embeddings = {"slice_id": slice_index}
    embeddings.update({name: {} for name in EMBEDDERS})

    image_path = patid + "/png/"
    os.makedirs(image_path, exist_ok=True)

    for modality in MODALITIES:
      volume = t1_volume if modality == "t1" else read_nii(_local_path(files[modality]))
      pic = volume[:, :, slice_index]

      # Preprocess the full slice and the slice restricted to the tumour bounding box.
      image, image_tensor = preprocess_image(pic)
      seg_image, seg_image_tensor = preprocess_image(pic * mask)

      image.save(image_path + patid + "_" + modality + "_sliced.png")
      seg_image.save(image_path + patid + "_seg_" + modality + "_sliced.png")

      for name, embed in EMBEDDERS.items():
        embeddings[name][modality] = embed(image, image_tensor)
        embeddings[name]["seg_" + modality] = embed(seg_image, seg_image_tensor)

    entry["embeddings"] = embeddings

    # Upload ./<patid>/png/ into the patient's folder in the bucket.
    _run(["gsutil", "-m", "cp", "-R", image_path, pat_path])
  finally:
    # Remove the local copy even when a step above failed.
    shutil.rmtree(patid, ignore_errors=True)

  return entry


embedding_dict = {}
skipped = []

start = time.time()
### MAIN
for counter, pat_path in enumerate(out):
  if not pat_path or pat_path == LISTING_PREFIX:
    continue

  entry = _embed_patient(pat_path)
  if entry is None:
    skipped.append(pat_path)
  else:
    embedding_dict[str(uuid.uuid4())] = entry

  print("Counter: " + str(counter))
  if counter >= LAST_INDEX:
    break

# Save after the loop so results are written even when `out` has fewer than
# LAST_INDEX + 1 entries, and upload exactly the file that was written.
with open(EMBEDDING_FILE, 'wb') as handle:
  pickle.dump(embedding_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
_run(["gsutil", "-m", "cp", EMBEDDING_FILE, EMBEDDING_DEST])

if skipped:
  print("Skipped " + str(len(skipped)) + " patient folder(s): " + ", ".join(skipped))

end = time.time()
print(end - start)