In [14]:
import io
import os
import sys
src_path = os.path.split(os.getcwd())[0]
sys.path.insert(0, src_path)

import re
import glob
from tqdm import tqdm
import zipfile
import itertools
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from training.datasets import CellPainting
import clip.helpers



def group_samples(indir):
    dirlist = glob.glob(os.path.join(indir, "*"))           # Loads all directories from the indir

    basenames = [os.path.basename(d) for d in dirlist]      # Extracts basenames for directories
 #   print(f"group_samples() -> basenames of directory : {basenames}")

    # ME : They will be grouped and sorted based on the plates, and in our case first 7 characters are the PLATE_ID
    plate_groups = [list(g) for _, g in itertools.groupby(sorted(basenames), lambda x: x[0:8])]
 #   print(f"plate_groups : {plate_groups}")

    fullpath_groups = []
    basenames_groups = []

    order = [2,3,5,1,4]                      # Defining the order in which channels will be grouped

    for g in plate_groups:                   # Preparing of full paths
     #   print(f"\tg={g}\n")
        fullpath_group = []
        basenames_group = []
        for f in g:
      #      print(f"\t\tf={f}\n")
            fullpath_group.append(os.path.join(indir, f))
            basenames_group.append(f)
        fullpath_groups.append(fullpath_group)
        basenames_groups.append(basenames_group)

 #   print(f"\nfullpath_group={fullpath_group}\nbasenames_group={basenames_group}\n")
    
    sample_list = []

    for i, plate in enumerate(fullpath_groups):
        plate_id = basenames_groups[i][0][0:7]
        print(f"\nplate_id = {plate_id}")

        plate_files = []
        for channel in plate:
            print(f"\tchannel = {channel}")
            z = zipfile.ZipFile(channel)     # Open .zip files and extracts the .png files
            file_list = z.namelist()
            for f in file_list:
                if f.endswith(".tiff"):
                    plate_files.append(f)
              #   print(f"\t\tfile = {f}")

        
        #plate_files = [os.path.join(dirname, f) for f in plate_files]
        
        # Groups .png files by substring of their names and sorts the files based on the predefined order.
  #      sample_groups = [list(g) for _, g in itertools.groupby(sorted(plate_files, key=lambda x: x[-49:-43]), lambda x: x[-49:-43])]

        sample_groups = [list(g) for _, g in itertools.groupby(
            sorted(plate_files, key=lambda x: x.split('_')[-3]),     # Extract the last character before .png
            lambda x: x.split('_')[-3]                               # Use the last character before .png as the group key
        )]

   #     print(f"BEFORE FOR\n{sample_groups}\n\n\n")

        for g in sample_groups:
       #     ordered_group = [x for _, x in sorted(zip(order, g))]
            ordered_group = sorted(g, key=lambda x: x.split('_')[1])
            sample_list.append(ordered_group)

  #  print(f"Sample list : \n\t{sample_list}")
    
    return sample_list                        # Returns a list of images grouped and ordered based on PLATE_ID and channels


def process_sample(imglst, indir, outdir="."):      #imglst : group of images
  #  print(f"process_sample()\n{imglst}\n\n")

    sample = np.zeros((2500, 2500, 5), dtype=np.uint8) #520x696

    refimg = imglst[0]

    refimg = os.path.basename(imglst[0])
    
   # pattern = re.compile(".*(?P<plate>\d{5})\-(?P<channel>\w*).*\/.*\_(?P<well>\w\d{2})\_\w(?P<sample>\d).*")
#    pattern = re.compile(r"img_(?P<plate>\w+)_acqid_\d+_(?P<well>\w\d{2})_site_(?P<site>\d+)_merged_channel_\d+\.png")
 #   pattern = re.compile(r"img_(?P<plate>\w+)_acqid_\d+_(?P<well>\w\d{2})_site_(?P<site>\d+)_merged_channel_(?P<channel>\d+)\.png")      # Worked (before changing to .tiff)
    
    pattern = re.compile(r"(?P<plate>\w+)_(?P<well>\w\d{2})_s(?P<site>\d+)_.*_(?P<channel>\d+)_nm_Ex.tiff")

    print(f"\n{refimg}\n")
    ref_matches = pattern.match(refimg)
    plate, well, site = ref_matches["plate"], ref_matches["well"], ref_matches["site"]
    well = well.upper()
  
    sampleID = "-".join([plate, well, site])
    
    print(f"Regex\n\tplate={plate}\twell={well}\tsite={site}\t\tsample_ID={sampleID}\n")


    filenames, channels = {}, {}

    for i, imgfile in enumerate(imglst):
    #    print(f"imgfile={imglst}")
        
        dirname = os.path.dirname(imgfile)
        basename = os.path.basename(imgfile)
        base, ext = os.path.splitext(basename)

    #    print(f"dirname={dirname}\tbasename={basename}\tbase={base}\text={ext}\n")

        wavelength = base.split('_')[-3]
        wavelength_to_channel = {
            '405': 1,
            '488': 2,
            '561': 3,
            '638': 4,
            '730': 5
        }
        channel_id = wavelength_to_channel.get(wavelength, "Unknown")
        
       # dirname='P102785_channel_' + str(base.split('_')[-3][0])
        dirname='P102785_channel_' + str(channel_id)
        zipname = os.path.join(indir, dirname+".zip")

        z = zipfile.ZipFile(zipname)
        data = z.read(imgfile)
        dataenc = io.BytesIO(data)

        arr = img_to_numpy(dataenc)
    #    print(f"\n\nimgfile={imgfile}\n\timage_to_numpy shape of array = {arr.shape}")
        scaled_arr = process_image(arr)
        
 #       print(f"imgfile = {basename}")
        sample[:,:,i] = scaled_arr

        matches = pattern.match(basename)
        channel = matches["channel"]

        channels[i] = channel
        filenames[channel] = base

  #  outfile = str(plate)+"-"+str(well)+"-"+str(sampleID)
  #  outpath = os.path.join(outdir, outfile)
    outpath = os.path.join(outdir, sampleID)
    
    print(f"outpath = {outpath}")
    np.savez(outpath, sample=sample, channels=channels, filenames=filenames)
     
    return

def img_to_numpy(file):
    img = Image.open(file)
 #   img = img.convert('L')
    arr = np.array(img)

 #   print("\n\nImg_to_numpy")
 #   plt.imshow(img, cmap='gray')
 #   plt.show()
    
    return arr
    
def process_image(arr):
    threshold = illumination_threshold(arr)
    scaled_img = sixteen_to_eight_bit(arr, threshold)
    return scaled_img

# Calculates a threshold value to remove the highest percentage of pixels from an image.
def illumination_threshold(arr, perc=0.0028):
    """ Return threshold value to not display a percentage of highest pixels"""

    perc = perc/100

    h = arr.shape[0]             # Extracting number of pixels for height and width of an image.
    w = arr.shape[1]

    # Calculating number (n) of pixels to delete
    total_pixels = h * w         
    n_pixels = total_pixels * perc
    n_pixels = int(np.around(n_pixels))
    
    # Finding the value of the pixel with the highest value (as stated in the paper it will be used to ...)
    flat_inds = np.argpartition(arr, -n_pixels, axis=None)[-n_pixels:]
    inds = np.array(np.unravel_index(flat_inds, arr.shape)).T

#    print(f"array shape = {arr.shape}")
 #   print(f"\nillumination_threshold() : \n\tinds={inds}\n")
    
    max_values = [arr[i, j] for i, j in inds]
 #   max_values = [arr[tuple(ind)] for ind in inds]

    threshold = min(max_values)

    return threshold

def sixteen_to_eight_bit(arr, display_max, display_min=0):
    threshold_image = ((arr.astype(float) - display_min) * (arr > display_min))

    scaled_image = (threshold_image * (256. / (display_max - display_min)))
    scaled_image[scaled_image > 255] = 255

    scaled_image = scaled_image.astype(np.uint8)

  #  print(scaled_image.min(), scaled_image.max())
    
 #   print("\n\n\n\nHEeeeeeeeeeeeeeeeeeeeeeeeeeeeeee")
  #  plt.imshow(threshold_image, cmap='gray')
   # plt.show()
    
    return scaled_image



if __name__ == '__main__':
    indir = "/share/data/analyses/silvija/RT/data_cloome/our_images/preprocessing_all/channels_tiff_DMSO_zip"         # Changed from bigger to all
    outdir = "/share/data/analyses/silvija/RT/data_cloome/our_images/preprocessing_all/channels_tiff_DMSO_npz"
    n_cpus = 1  #60

 #   index_file = "/share/data/analyses/silvija/RT/data_cloome/our_images/metadata_P102785.csv"                  # Used for bigger
    index_file = "/share/data/analyses/silvija/RT/data_cloome/our_images/preprocessing_all/metadata_P102785_DMSO.csv"
 #   input_imgs = "/share/data/analyses/silvija/RT/data_cloome/our_images/preprocessing"
    input_mols = "/share/data/analyses/silvija/RT/data_cloome/our_images/morgan_chiral_fps.hdf5"
    batchsize = 32

    sample_groups = group_samples(indir)
  #  print(f"\nsample_groups\n{np.array(sample_groups)}")
    
    sample_groups_trans = [list(i) for i in zip(*sample_groups)]
  #  sample_groups_prefix = [['/' + elem for elem in row] for row in sample_groups_trans]
 #   print(f"\nsample_groups_transposed\n{np.array(sample_groups_trans)}")
  
    result = clip.helpers.parallelize(process_sample, sample_groups_trans, n_cpus, indir=indir, outdir=outdir)
 #   print(result)

#     dataloader = get_dataloader(index_file, input_imgs, batchsize)
#     mean, std = get_mean_std(dataloader, stats_file)


plate_id = P102785
	channel = /share/data/analyses/silvija/RT/data_cloome/our_images/preprocessing_all/channels_tiff_DMSO_zip/P102785_channel_1.zip
	channel = /share/data/analyses/silvija/RT/data_cloome/our_images/preprocessing_all/channels_tiff_DMSO_zip/P102785_channel_2.zip
	channel = /share/data/analyses/silvija/RT/data_cloome/our_images/preprocessing_all/channels_tiff_DMSO_zip/P102785_channel_3.zip
	channel = /share/data/analyses/silvija/RT/data_cloome/our_images/preprocessing_all/channels_tiff_DMSO_zip/P102785_channel_4.zip
	channel = /share/data/analyses/silvija/RT/data_cloome/our_images/preprocessing_all/channels_tiff_DMSO_zip/P102785_channel_5.zip

P102785_A02_s1_x0_y0_Fluorescence_405_nm_Ex.tiff

Regex
	plate=P102785	well=A02	site=1		sample_ID=P102785-A02-1

outpath = /share/data/analyses/silvija/RT/data_cloome/our_images/preprocessing_all/channels_tiff_DMSO_npz/P102785-A02-1

P102785_A02_s2_x1_y0_Fluorescence_405_nm_Ex.tiff

Regex
	plate=P102785	well=A02	site=2		sample_ID=P1

In [23]:
import numpy as np

# Load the .npz file
data = np.load('/share/data/analyses/silvija/RT/data_cloome/our_images/preprocessing_all/channels_tiff_npz/P102785-P04-9.npz', allow_pickle=True)

# Check what arrays are stored in the .npz file
print(data.files)

# Accessing each array by its name
for key in data.files:
    print(f"{key}: {data[key]}")


['sample', 'channels', 'filenames']
sample: [[[ 2  4 14 19  3]
  [ 2  4 15 18  2]
  [ 2  4 14 19  3]
  ...
  [88 40 35 35 17]
  [87 43 33 33 15]
  [85 42 35 33 16]]

 [[ 2  4 14 17  3]
  [ 2  4 15 18  2]
  [ 2  4 14 18  2]
  ...
  [86 42 36 33 15]
  [86 40 32 31 15]
  [86 39 35 30 15]]

 [[ 2  4 13 19  3]
  [ 2  4 15 20  3]
  [ 2  4 16 18  3]
  ...
  [86 40 35 32 14]
  [86 39 33 31 14]
  [85 39 35 33 13]]

 ...

 [[ 7  7 30 22  9]
  [ 8  7 30 22 10]
  [ 9  7 28 21 10]
  ...
  [ 3  6 18 24  5]
  [ 3  7 16 21  3]
  [ 3  4 18 25  4]]

 [[ 9  7 30 21  9]
  [ 8  6 31 21  9]
  [ 8  7 29 20  8]
  ...
  [ 3  4 19 24  4]
  [ 3  4 18 24  4]
  [ 3  4 18 24  4]]

 [[ 9  7 26 19  8]
  [ 8  6 26 20  7]
  [ 9  7 28 20  9]
  ...
  [ 2  4 17 24  4]
  [ 3  4 17 23  3]
  [ 3  3 18 24  3]]]
channels: {0: '405', 1: '488', 2: '561', 3: '638', 4: '730'}
filenames: {'405': 'P102785_P04_s9_x2_y2_Fluorescence_405_nm_Ex', '488': 'P102785_P04_s9_x2_y2_Fluorescence_488_nm_Ex', '561': 'P102785_P04_s9_x2_y2_Fluoresc

In [16]:
sample = data['sample']
channels = data['channels']
filenames = data['filenames']

# Print the loaded arrays
print("Sample shape:", sample.shape)
print("Channels:", channels)
print("Filenames:", filenames)


Sample shape: (2500, 2500, 5)
Channels: {0: '405', 1: '488', 2: '561', 3: '638', 4: '730'}
Filenames: {'405': 'P102785_P04_s9_x2_y2_Fluorescence_405_nm_Ex', '488': 'P102785_P04_s9_x2_y2_Fluorescence_488_nm_Ex', '561': 'P102785_P04_s9_x2_y2_Fluorescence_561_nm_Ex', '638': 'P102785_P04_s9_x2_y2_Fluorescence_638_nm_Ex', '730': 'P102785_P04_s9_x2_y2_Fluorescence_730_nm_Ex'}
