<a href="https://colab.research.google.com/github/xdu006/CellSAM_XL/blob/main/CellSAM_XL_Gradio.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [None]:
#Setup

!pip install gradio
!pip install -q git+https://github.com/xdu006/cellSAM_test.git

import cellSAM
from cellSAM import segment_cellular_image, get_model
import numpy as np
import torch
from scipy.ndimage import binary_dilation
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import tifffile as tf
import numpy as np
import pandas as pd

import gradio as gr

# Select the compute device: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")
print(device)

# FUNCTIONS

## Visualization

In [9]:
#visualizes all channels by converting each channel into a color in RGB format
def visualize_channels(channels):
  """Blend a stack of single-channel images into one RGB composite.

  Each channel is tinted with one colour from a fixed palette and added to a
  running RGB image. The running image is clipped to [0, 1] after every
  channel, so overlapping bright regions saturate instead of overflowing.

  Args:
    channels: iterable of 2D arrays (one per channel), all of the same
      height and width. Intensities are expected to already be in [0, 1]
      -- presumably guaranteed by the caller's normalisation; confirm.

  Returns:
    A float64 array of shape (H, W, 3) with values in [0, 1], or None when
    `channels` is empty.
  """

  rgb: None | np.ndarray = None

  #choose color pallet
  colors = [(255, 0, 0), (0, 255, 0), (160,32,240), (0, 0, 255), (255, 0, 255), (0, 255, 255)]
  #alternative pallets:
  #colors = [(255, 0, 0), (0, 255, 0), (255, 255, 0), (0, 0, 255), (255, 0, 255), (0, 255, 255)]
  #colors = [(255, 255, 0), (255, 255, 255), (160,32,240), (0, 255, 255), (255, 0, 255), (0, 255, 0)]

  #for every channel
  for channel_index, channel in enumerate(channels):
    channel = np.asarray(channel)

    # Cycle through the palette so images with more channels than colours
    # still render instead of raising IndexError.
    rgb_multiplier = np.array(colors[channel_index % len(colors)]) / 255

    # Create the base RGB array lazily, once the image size is known
    if rgb is None: rgb = np.zeros((channel.shape[0], channel.shape[1], 3), dtype="float64")

    # Broadcast the (H, W) channel against the 3 colour weights and accumulate
    rgb += channel[:, :, np.newaxis] * rgb_multiplier

    # Keep the running composite inside the displayable [0, 1] range
    rgb = np.clip(rgb, 0, 1).astype("float64")

  return rgb # return the RGB image for display

## Load and Process Image

In [7]:
# reads tiff files from given file name
def read_and_enhance_file (file_name):

  raw_tiff_file = tf.imread(file_name).astype(np.float64) #read file

  #Check input file shape (ensure that channel is last)
  match raw_tiff_file.ndim:
    case 0: message = "Loading image... No data detected, please check tiff file loaded"            ### consider adding a case none for the ndim check...
    case 1: message = "Loading image... Image Loaded: File only had one dimention. Are you sure you loaded a 2D image...?"
    case 2: message = "Loading image... Image Loaded: 2 dimension image detected."
    case 3:
      message = "Image Loaded: 3 dimension image detected. Converted to HWC format."
      if raw_tiff_file.shape[0] < raw_tiff_file.shape[1]: #always convert to # H, W, C
        print(f"Channel shape is {raw_tiff_file.shape}. Performing convertion to (H,W,C) format. ")
        raw_tiff_file = np.transpose(raw_tiff_file, (1, 2, 0))
        print(f"Channel shape is now {raw_tiff_file.shape}.")
    case 4: message = "Image Loaded: 4 dimension image detected... Currently we do not support 3D images."
    case _: message = "Image Loaded: Unknown image format."

  #performe image enhancements
  original_image = np.copy(raw_tiff_file) #copy to ensure that original data is unaltered
  channel_scaled = np.copy(raw_tiff_file)
  for c in range(raw_tiff_file.shape[2]): channel_scaled[:,:,c] = raw_tiff_file[:,:,c]/max(raw_tiff_file[:,:,c].flatten()) #channel dependent scaling based on maxiumin in each channel
  scaled_gamma15 = np.copy(channel_scaled)
  scaled_gamma15 = np.power(channel_scaled, 1/1.5) #gamma 1.5 scaling
  raw_tiff_file =  scaled_gamma15 #

  message = f"Image loaded. Image shape {raw_tiff_file.shape}, in HWC format" # success message
  global processed_img
  processed_img = raw_tiff_file # update global variable

  return message, visualize_channels(np.transpose(raw_tiff_file, (2,0,1)))

## Helper Functions

In [8]:
### Tiling helper functions

#calculate appropriate dimensions for cropping based on cell size
def get_crop_dimentions(cell_size):
  """Return the tile side length (in pixels) for a given cell diameter.

  Args:
    cell_size: approximate cell diameter in pixels.

  Returns:
    The cell diameter multiplied by the target image-to-cell-size ratio,
    rounded to the nearest integer.
  """
  #Target Image to Cell Size Ratio: KNOWN VALUE based on our experiments
  image_to_cell_ratio = 17.63
  crop_side = round(cell_size * image_to_cell_ratio)
  return int(crop_side)

#uses inputted crop dimensions to crop and return a list of tiles
def get_tiles(im, yDim: int, xDim: int, saveFolder=None, id=None, ijmeta=None):
  """Split an (H, W, C) image into yDim-by-xDim tiles, zero-padding edge tiles.

  Args:
    im: 3D array indexed as [row, column, channel].
    yDim: tile height in pixels.
    xDim: tile width in pixels.
    saveFolder, id, ijmeta: accepted but not used by this function.

  Returns:
    A list of tiles in row-major order (top-to-bottom, then left-to-right).
    Tiles that run past the image border are copied into a float64 zero
    array of the full tile size. If the image height is <= yDim or its width
    is <= xDim, the image itself is returned unchanged as the only element.
  """
  height, width = im.shape[0], im.shape[1]

  # NOTE(review): an image that is small on only ONE axis is also returned
  # whole here -- confirm that is intended.
  if height <= yDim or width <= xDim:
    return [im]

  tiles = []
  for top in range(0, height, yDim):
    for left in range(0, width, xDim):
      tile = im[top:top+yDim, left:left+xDim, :]
      is_full_size = tile.shape[0] == yDim and tile.shape[1] == xDim
      if not is_full_size:
        #edge tile: place it in the top-left corner of a zero canvas
        canvas = np.zeros(shape=(yDim, xDim, im.shape[2]), dtype=np.float64)
        canvas[:tile.shape[0], :tile.shape[1], :] = tile
        tile = canvas
      tiles.append(tile)
  return tiles


### Data extraction helper functions

#returns the number of cells represented by the size (dim1) of the bounding boxes list
def get_num_cells(bb):
  """Count detections: the length of the first axis of `bb`, or 0 when `bb` is None."""
  return 0 if bb is None else bb.shape[0]

#calculate and return centroids for each bounding box
def get_centroids(bounding_boxes, img_scale_dim, return_scaled_BB=False):
  """Compute the centre point of every bounding box, rescaled to image pixels.

  Box coordinates are divided by 1024 and multiplied by `img_scale_dim`
  (1024 is presumably the model's working resolution -- confirm).

  Args:
    bounding_boxes: rows of (x1, y1, x2, y2). May be a tensor exposing
      `.cpu().detach().numpy()`, any array-like, or None for "no boxes".
    img_scale_dim: side length, in pixels, of the image the boxes belong to.
    return_scaled_BB: when True, also return the rescaled boxes.

  Returns:
    A list of (x, y) centroid tuples; or, when `return_scaled_BB` is True,
    the tuple (centroids, scaled_boxes). With no boxes the list is empty and
    scaled_boxes is an empty (0, 4) array, so the return shape is the same
    whether or not any boxes were found.
  """
  if bounding_boxes is None:
    centroids = []
    # Return a pair here too, so two-value unpacking by the caller still works.
    if return_scaled_BB: return centroids, np.empty((0, 4))
    return centroids

  if hasattr(bounding_boxes, "detach"):
    boxes = bounding_boxes.cpu().detach().numpy()
  else:
    boxes = np.asarray(bounding_boxes)

  bb = boxes/1024*img_scale_dim   #x1, y1, x2, y2
  centroids = [((x1+x2)/2, (y1+y2)/2) for x1, y1, x2, y2 in bb]
  if return_scaled_BB: return centroids, bb
  return centroids

## Run CellSAM

In [10]:
# runs cellSAM on specified channels for all tiles
# compiles and returns a pandas dataframe with detailed extracted quantification data
# as well as a summary dataframe with only summed cell counts per channel

def Run_CellSAM (channels, cell_size):
  """Tile the loaded image, segment the chosen channels, and summarise the counts.

  Reads the module-level `processed_img` (set by `read_and_enhance_file`).

  Args:
    channels: 1-based channel number, or list of channel numbers, to count.
    cell_size: approximate cell diameter in pixels; determines the tile size.

  Returns:
    (SUMMARY_ARRAY, SUMMARY_SMALL, fig):
      SUMMARY_ARRAY -- one row per (channel, tile) with columns
        'Channel', 'Tile', 'Count', 'Centroid_Coords'.
      SUMMARY_SMALL -- one row per channel with 'Channel' and
        'Total_Cell_Count'.
      fig -- matplotlib figure showing every tile with the detected
        centroids overlaid.

  Raises:
    gr.Error: if no image has been loaded yet.
  """

  if processed_img is None:
    raise gr.Error("No image loaded. Please upload an image in the Upload File tab first.")

  #Crop file to appropriate dimensions
  cd = get_crop_dimentions(cell_size)
  tiles = get_tiles(im=processed_img, yDim=cd, xDim=cd)

  #Ensure lists are expected format
  print(f"Dimentions for tiles: {cd}")
  if isinstance(tiles, list): print(f"Tiles is a list of size {len(tiles)}!")
  else: tiles = [tiles]
  if channels is None: channels = [] # nothing selected
  if isinstance(channels, list): print(f"Channels is a list of size {len(channels)}!")
  else: channels = [channels]

  rows = []                 # one dict per (channel, tile); becomes SUMMARY_ARRAY
  centroids_lookup = {}     # (tile number, channel) -> centroid list, reused for plotting
  cells_per_channel = []

  for channel in channels: #for each chosen channel

    channel_total = 0
    for numtile, img in enumerate(tiles, start=1):
      #run cellSAM on the tile; channel numbers are 1-based in the UI
      result = segment_cellular_image(img[:,:,(channel-1)], device=str(device))

      #result[2] is handed to the bounding-box helpers
      count = get_num_cells(result[2])
      centroids = get_centroids(result[2], cd)
      rows.append({'Channel': channel, 'Tile': numtile, 'Count': count, 'Centroid_Coords': centroids})
      centroids_lookup[(numtile, channel)] = centroids
      channel_total += count

    #total number of cells for this channel
    cells_per_channel.append(channel_total)
    print(f"Total number of cells in channel {channel}: {cells_per_channel[-1]}")

  # Build the frames once from plain rows; avoids the private
  # DataFrame._append and positional indexing into a labelled Series.
  SUMMARY_ARRAY = pd.DataFrame(rows, columns=['Channel', 'Tile', 'Count', 'Centroid_Coords'])
  SUMMARY_SMALL = pd.DataFrame({'Channel': channels, 'Total_Cell_Count': cells_per_channel})

  #visualization with dots
  fig = plt.figure(figsize=(40,40))

  # 3x3 grid as before, growing only when there are more than 9 tiles
  grid = max(3, int(np.ceil(np.sqrt(len(tiles)))))

  #specify markers and colors for each channel
  #palettes are cycled if more than 5 channels are counted at once
  markers = ['o', '^', 's', 'x', 'p']
  colors = ['red', 'orange', 'white', 'yellow', 'pink']

  #for every tile
  for i, tile in enumerate(tiles):

    #show tile (intensities doubled to brighten the preview)
    ax = fig.add_subplot(grid, grid, i+1)
    ax.imshow(visualize_channels(np.transpose(tile, (2,0,1)))*2, zorder=1)

    #draw dots
    for j, channel in enumerate(channels):
      color = colors[j % len(colors)]
      marker = markers[j % len(markers)]
      print(f"Visualize{i}, {channel}, {color}, {marker}, xorder {j+2}")
      centroids = centroids_lookup.get((i+1, channel), [])
      if centroids:
        #one scatter call per channel instead of one per point
        xs, ys = zip(*centroids)
        ax.scatter(xs, ys, c=color, marker=marker, zorder=j+2)

  return SUMMARY_ARRAY, SUMMARY_SMALL, fig


# MAIN

In [None]:
#global variable workaround: holds the most recently loaded and enhanced image
#so that the Gradio callbacks below can share it
processed_img = None

with gr.Blocks() as demo:

  # Begin with a title and description for the page
  gr.Markdown(
      """
      # Welcome to CellSAM_XL
      CellSAM XL is a tool to automate with quantitative analysis of large flourescent images.
      Input: TIFF images.
      Output: Visualization of identified cells. CSV of summary statistics (cell count and cell centroids) will be generated.
      """)

  # The first tab is for file uploads, processing, and channel separation
  with gr.Tab("Upload File"):

    # Message
    gr.Markdown("""
        First, upload a tif image. Once uploaded, the image will be processed and displayed on the right hand side seperated by channel and as a composite image. Please use refer to these images to help determine which channels contain your cells of interest.
        Please note: while multi-channel input is accepted, we do not currently support z-stacks images or videos.
        """)

    with gr.Row(): # create row to contain two columns

      # file input and system message fields
      with gr.Column(): #first column (left)
        input_img = gr.File()
        channels_msg = gr.Textbox(label='System Message')

      # output field for visualizations (display channels and composite image)
      with gr.Column(): #second column

        #individual channels
        with gr.Row():
          @gr.render(inputs=channels_msg) #re-render whenever the system message changes (i.e. after a new upload)
          def show_channels(channels_msg):
            if processed_img is not None:
              for c in range(processed_img.shape[2]): #dynamically generate one image field per channel
                gr.Image(processed_img[:,:,c], label=f"Channel {c+1}")  #display and label each channel appropriately

        #also display the composite image underneath the channels
        visualization = gr.Image(label='Composite Image')

        #event listener for image upload
        input_img.upload(fn=read_and_enhance_file, inputs=input_img, outputs=[channels_msg, visualization])


  #actual image processing tab
  with gr.Tab("Process Image"):

    # Message
    gr.Markdown("""
        In this part of the CellSAM XL pipeline, we enter some parameters to help the model perform well. Please select the channels that contains the cells that you wish to count and provide a rough diameter of your cells in pixles.
        """)


    with gr.Row():

      #input parameters for CellSAM_XL Pipeline
      with gr.Column():

        #set up channel selection field
        channels_to_count = gr.CheckboxGroup(label="Select Channels to Count", info="To show channel options, please first input an image in the Upload Files tab.")

        #Update channel selection field dynamically every time channels_msg is changed (new upload).
        #The listener is registered with inputs=None, so the callback must take no arguments.
        def update_channelslist():
          #no valid image loaded: offer no channels rather than failing on None.shape
          if processed_img is None: return gr.CheckboxGroup(label="Select Channels to Count", choices=[])
          return gr.CheckboxGroup(label="Select Channels to Count", choices=list(range(1,processed_img.shape[2]+1)))
        channels_msg.change(update_channelslist, inputs=None, outputs=channels_to_count)

        #set up cell size input field
        cell_size = gr.Number(label="Cell Diameter (px)", value=14.75)
        start_button = gr.Button("Process Selected Channel(s)")

      #output field for results
      with gr.Column():
        visual = gr.Plot() #visualization field
        small_df = gr.Dataframe(label="Summary") #Short Summary dataframe
        large_df = gr.Dataframe(label="Full Dataframe") #Full Summary Dataframe

      #event listener: run the pipeline when the button is clicked
      start_button.click(fn=Run_CellSAM, inputs=[channels_to_count, cell_size], outputs=[large_df, small_df, visual])

demo.launch(share=True)