# Bacteria in shell

In [None]:
import numpy as np
import pandas as pd
from scipy import io, ndimage as ndi, signal
from skimage import io as skio, color, filters, segmentation, feature, measure, morphology, transform, draw
from skimage.filters import median as skimage_filters_median, threshold_otsu, threshold_yen, threshold_multiotsu, gaussian
from skimage.io import imread
from skimage.measure import regionprops, regionprops_table, label
from skimage.morphology import erosion, dilation, square, disk
from skimage.segmentation import clear_border, watershed
from skimage.feature import peak_local_max
from skimage.draw import circle_perimeter
from skimage.color import label2rgb
import random
from datetime import datetime, timedelta, timezone
import pytz
import os
import math
import time
import glob
from PIL import Image
from tqdm import tqdm
import gc
import re
import platform
from skimage.io import imread as skio_imread

# --- User configuration (edit these two paths) ---
# Both values are placeholders and must be replaced before running.
DATA_DIR = r"path/to/data_directory"       # folder containing your image folders
OUTPUT_DIR = r"path/to/output_directory"   # where CSV/Summary and exported crops will be saved

# -------------------- Metadata --------------------
# These values are written into the header of every per-sample summary file;
# exp_id is additionally stamped into every row of the result tables.
CodeName = "Bacteria_in_shell"
Version = "1.4.20250405_shellThicknessErode"
Python_version = platform.python_version()  # interpreter version, recorded for reproducibility
exp_id = "XB803"

# -------------------- IO paths --------------------
# NOTE: the makedirs calls below run at import/execution time and create the
# output tree under OUTPUT_DIR if it does not exist yet.
base_dir = DATA_DIR
# Only sub-folders of base_dir whose names start with this prefix are analysed.
folder_name_startwith = "folder_prefix"
save_path_fig = os.path.join(OUTPUT_DIR, "exports")
# Trailing separator is deliberate: CSV/summary filenames are built by plain
# string concatenation onto this path.
save_path_data_sheet = os.path.join(OUTPUT_DIR, "results") + os.sep
os.makedirs(save_path_fig, exist_ok=True)
os.makedirs(save_path_data_sheet, exist_ok=True)

# Cropped images go here (single folder)
cropped_out_dir = os.path.join(save_path_fig, f"{folder_name_startwith}_cropped")
os.makedirs(cropped_out_dir, exist_ok=True)

# -------------------- Presets --------------------
# Per-channel lists below are indexed in the same order as `channel_name`.
image_export = 0  # legacy matplotlib export; keep 0 because we now use fast crop export
only_one_image = 0  # non-zero: process only the field (f) numbers listed in only_one_image_id
only_one_image_id = list(range(0, 1))
export_in_svg = 0  # not referenced by the code in this section

background_correction = 1  # 1 = apply background subtraction for DAPI; 0 = skip

channel_num = 4
num_steps = 4  # number of image planes kept per channel (plane 0 is the raw image)
channel_name = ["DAPI", "FITC-agarose", "AF647-alginate", "PC"]
label_channel = 1          # labels come from FITC-agarose
intensity_channel = 0      # measure intensity from DAPI
analysis_list = [1, 1, 0, 0]          # 1 = read and analyse this channel
binarization_list = [0, 1, 0, 0]      # 1 = run the full binarization pipeline on this channel
threshold_list = [0, threshold_otsu, 0, 0]  # threshold function handed to basic_binarization (0 = none)
signal_threshold_list = [0, 0, 0, 0]  # fixed cutoff used by basic_binarization when binarization is off
gaussian_sigma_list = [1, 1, 1, 1]    # Gaussian smoothing sigma applied before thresholding
ero = 5  # side length of the square footprint used for erosion in basic_binarization
dil = 5  # side length of the square footprint used for dilation in basic_binarization
watershed_list = [0, 1, 0, 0]         # 1 = split touching objects with apply_watershed
watershed_footprint = (3, 3)          # shape of the all-ones footprint used to find distance peaks
circularity_threshold_list = [0, 0.9, 0, 0]  # minimum 4*pi*area/perimeter^2 for a label to be kept
area_threshold_list = [0, 7000, 0, 0]        # minimum label area (pixels)
upper_area_threshold_list = [100000, 10000, 100000, 100000]  # maximum label area (pixels)
# Properties requested from regionprops_table for every retained label.
regionprops_tuple = (
    "label", "centroid", "area", "intensity_mean",
    "major_axis_length", "minor_axis_length", "solidity"
)

# Run timestamp in Asia/Tokyo local time (minute resolution); it is embedded in
# every output filename. The aware time is requested directly because
# `datetime.utcnow()` is deprecated since Python 3.12; the resulting string is
# identical to the former utcnow -> UTC -> Tokyo conversion chain.
timestamp_short = datetime.now(pytz.timezone("Asia/Tokyo")).strftime("%Y%m%d_%H%M")

# ---- Fast conditional CROP export settings ----
# A shell is exported as a crop only when its mean DAPI intensity (measured
# inside the shell label) exceeds INTENSITY_CUTOFF.
INTENSITY_CUTOFF = 84.53        # threshold for inside-shell DAPI mean
CROP_EXPORT = True             # turn on fast crop exports
CROP_PAD = 10                  # pixels of padding around each shell crop
MAX_CROPS_PER_FIELD = 50       # cap per f-id to avoid too many files
MAX_CROP_WIDTH = 800           # downscale wide crops; set None to keep original
JPEG_QUALITY = 85              # 70–90 is a good balance

# ---- NEW: shell thickness erosion settings ----
# 445 pixels correspond to 275 µm
UM_PER_PIXEL = 275.0 / 445.0   # ≈ 0.618 µm / pixel
# Thickness (µm) eroded away from the shell mask before measuring. It is
# converted to whole pixels (rounded up) and used as the radius of a disk
# footprint. Set 0.0 to disable the additional erosion.
SHELL_EROSION_UM = 6.1         # e.g. 5.0; set 0.0 to disable additional erosion

# We need raw images for these channels even if analysis_list says 0
# 0:DAPI, 1:FITC-agarose, 3:PC
REQUIRED_RAW_FOR_EXPORT = {0, 1, 3}

# -------------------- Helpers --------------------
def filter_labels_by_circularity(label_img, circularity_threshold=0.0):
    """Return a boolean mask covering only the sufficiently round labels.

    Circularity is computed per region as ``4 * pi * area / perimeter**2``.
    A region whose measured perimeter is not positive is given a perimeter
    of 1 so the ratio stays defined.

    Args:
        label_img: Integer label image accepted by ``regionprops``.
        circularity_threshold: Minimum circularity (inclusive) to keep a label.

    Returns:
        Boolean array, same shape as ``label_img``, True on kept labels.
    """
    def _circularity(region):
        boundary = region.perimeter
        if not boundary > 0:
            boundary = 1
        return (4 * np.pi * region.area) / (boundary ** 2)

    round_enough = [
        region.label
        for region in regionprops(label_img)
        if _circularity(region) >= circularity_threshold
    ]
    return np.isin(label_img, round_enough)

def filter_labels_by_area(label_img, area_min=0, area_max=np.inf):
    """Return a boolean mask of the labels whose pixel count is within range.

    The area of a label is its number of pixels. Counting is done in a single
    vectorized ``np.bincount`` pass rather than by building one region object
    per label and testing membership afterwards.

    Args:
        label_img: Label image; values are cast to int32 and values <= 0 are
            treated as background.
        area_min: Smallest area (inclusive) to keep.
        area_max: Largest area (inclusive) to keep.

    Returns:
        Boolean array, same shape as ``label_img``, True on kept labels.
    """
    labels = np.asarray(label_img).astype(np.int32)
    foreground = labels > 0
    mask = np.zeros(labels.shape, dtype=bool)
    if not foreground.any():
        # Nothing labelled (or empty input): nothing can be selected.
        return mask

    fg_labels = labels[foreground]
    # pixel_counts[k] == number of pixels carrying label k.
    pixel_counts = np.bincount(fg_labels)
    keep_label = (pixel_counts >= area_min) & (pixel_counts <= area_max)
    # Look up the keep/discard decision for every foreground pixel at once.
    mask[foreground] = keep_label[fg_labels]
    return mask

def basic_binarization(img, gaussian_sigma, threshold_method, binarization_apply, signal_threshold):
    """Turn an intensity image into a boolean object mask.

    With ``binarization_apply == 0`` the image is simply compared against
    ``signal_threshold``. Otherwise the image is median-filtered, Gaussian
    smoothed, thresholded with ``threshold_method``, eroded then dilated with
    square footprints (module-level ``ero`` / ``dil``), hole-filled, and
    finally stripped of objects touching the image border.

    Args:
        img: Input intensity image.
        gaussian_sigma: Sigma for the Gaussian smoothing step.
        threshold_method: Callable mapping the smoothed image to a threshold.
        binarization_apply: 0 to use the fixed cutoff only, otherwise run the
            full pipeline.
        signal_threshold: Fixed cutoff used when ``binarization_apply == 0``.

    Returns:
        Boolean mask with the same shape as ``img``.
    """
    if binarization_apply == 0:
        above_cutoff = img >= signal_threshold
        return above_cutoff.astype(bool)

    smoothed = gaussian(skimage_filters_median(img), sigma=gaussian_sigma)
    mask = smoothed >= threshold_method(smoothed)
    # Erode then dilate to drop small specks before filling holes.
    mask = dilation(erosion(mask, square(ero)), square(dil))
    mask = clear_border(ndi.binary_fill_holes(mask))
    return mask.astype(bool)

def apply_watershed(img):
    """Split touching objects in a binary mask with a distance-based watershed.

    Seeds are the local maxima of the Euclidean distance transform, searched
    with an all-ones footprint of shape ``watershed_footprint`` (module-level
    setting) inside the foreground only.

    Args:
        img: Binary mask (non-zero = foreground).

    Returns:
        int32 label image; background is 0.
    """
    foreground = img.astype(bool)
    edt = ndi.distance_transform_edt(img)

    peaks = peak_local_max(
        edt,
        footprint=np.ones(watershed_footprint),
        labels=foreground,
    )
    seed_mask = np.zeros_like(edt, dtype=bool)
    if peaks.size > 0:
        # peaks is (n_peaks, ndim); transpose to index all coordinates at once.
        seed_mask[tuple(peaks.T)] = True
    seeds, _ = ndi.label(seed_mask)

    # Flood from the seeds over the inverted distance map, limited to the mask.
    return watershed(-edt, seeds, mask=foreground).astype(np.int32)

def subtract_median_background(dapi_img, fitc_binarized):
    """Subtract the median background level from a DAPI image.

    The background is every pixel where ``fitc_binarized`` is zero/False
    (i.e. outside the FITC-agarose mask). Negative results are clipped to 0.

    Args:
        dapi_img: Intensity image to correct.
        fitc_binarized: Mask of the same shape; non-zero marks foreground.

    Returns:
        Tuple ``(corrected, median_bkg)``. ``corrected`` is always a new
        float array (the input is never modified or returned as-is);
        ``median_bkg`` is 0.0 when there are no background pixels.
    """
    dapi_img = np.asarray(dapi_img)
    # background = DAPI outside FITC-agarose mask
    background_values = dapi_img[np.logical_not(fitc_binarized)]
    # astype() copies, so the caller's array is left untouched in both branches.
    corrected = dapi_img.astype(float)
    if background_values.size == 0:
        return corrected, 0.0
    median_bkg = np.median(background_values)
    corrected -= median_bkg
    np.maximum(corrected, 0, out=corrected)
    return corrected, median_bkg

def rename_df_columns(
    df, label_channel, intensity_channel, channel_name, background_correction
):
    """Prefix region-property columns with the mask channel name, in place.

    Shape/position columns become ``"<label channel>-mask_<property>"``. The
    ``intensity_mean`` column becomes
    ``"<label channel>-mask_<intensity channel>_mean_intensity_corrected"``
    when ``background_correction == 1`` and ``..._raw`` otherwise. Columns not
    listed here are left alone.

    Args:
        df: DataFrame to rename; it is modified in place.
        label_channel: Index into ``channel_name`` of the mask channel.
        intensity_channel: Index into ``channel_name`` of the measured channel.
        channel_name: Sequence of channel names.
        background_correction: 1 if intensities were background-corrected.

    Returns:
        The same ``df`` object, for call chaining.
    """
    mask_tag = f"{channel_name[label_channel]}-mask"
    intensity_kind = "corrected" if background_correction == 1 else "raw"

    prefixed_properties = (
        "label",
        "area",
        "centroid-0",
        "centroid-1",
        "major_axis_length",
        "minor_axis_length",
        "solidity",
    )
    column_map = {prop: f"{mask_tag}_{prop}" for prop in prefixed_properties}
    column_map["intensity_mean"] = (
        f"{mask_tag}_{channel_name[intensity_channel]}_mean_intensity_{intensity_kind}"
    )

    df.rename(columns=column_map, inplace=True)
    return df

def _normalize_to_uint8(img):
    # robust percentile stretch to uint8 for display
    vmin, vmax = np.percentile(img, [1, 99])
    if vmax <= vmin:
        vmax = vmin + 1e-6
    img_n = np.clip((img - vmin) / (vmax - vmin), 0, 1)
    return (img_n * 255).astype(np.uint8)

def _square_bbox_with_pad(bbox, pad, h, w):
    # bbox: (minr, minc, maxr, maxc) -> square bbox with padding and clipping
    minr, minc, maxr, maxc = bbox
    hr = maxr - minr
    wr = maxc - minc
    side = max(hr, wr)
    cr = (minr + maxr) // 2
    cc = (minc + maxc) // 2
    half = side // 2
    r0 = cr - half
    r1 = r0 + side
    c0 = cc - half
    c1 = c0 + side
    r0 -= pad; r1 += pad; c0 -= pad; c1 += pad
    r0 = max(0, r0); c0 = max(0, c0)
    r1 = min(h, r1); c1 = min(w, c1)
    return int(r0), int(c0), int(r1), int(c1)

# -------------------- Discover folders --------------------
# Sub-folders of base_dir whose names start with the configured prefix,
# in alphabetical order. Plain files with a matching name are ignored.
folder_names = sorted(
    entry
    for entry in os.listdir(base_dir)
    if entry.startswith(folder_name_startwith)
    and os.path.isdir(os.path.join(base_dir, entry))
)

# -------------------- Main loop --------------------
def _write_summary(summary_path, lines):
    """Append ``lines`` to the text file ``summary_path``, one entry per line."""
    with open(summary_path, 'a') as f:
        f.writelines(line + '\n' for line in lines)


def _downscale_to_max_width(img_u8, max_width):
    """Shrink ``img_u8`` (bilinear) so it is at most ``max_width`` columns wide.

    The image is returned unchanged when ``max_width`` is None or the image is
    already narrow enough. The aspect ratio is preserved.
    """
    if max_width is None or img_u8.shape[1] <= max_width:
        return img_u8
    scale = max_width / img_u8.shape[1]
    new_h = max(1, int(img_u8.shape[0] * scale))
    return np.array(Image.fromarray(img_u8).resize((max_width, new_h), Image.BILINEAR))


# --- Filename parsing (loop-invariant, so built once) ---
# filenames like: "..._A01fXXdYY.tif"  (XX = 0..199; YY in {0,1,3,4})
f_d_regex = re.compile(r"_A\d{2}f(?P<f>\d{1,3})d(?P<d>[0134])(?:\.tif|_)", re.IGNORECASE)

# Map each d to your channel index order:
# d=0 -> channel 0 (DAPI)
# d=1 -> channel 1 (FITC-agarose)
# d=3 -> channel 2 (AF647-alginate)
# d=4 -> channel 3 (PC)
d_to_channel = {0: 0, 1: 1, 3: 2, 4: 3}
channel_to_d = {v: k for k, v in d_to_channel.items()}
required_ds = set(d_to_channel)

# --- Output column names (loop-invariant) ---
# The intensity column suffix must follow the same rule as rename_df_columns():
# "corrected" when background correction is on, "raw" otherwise. Hard-coding
# "corrected" made the lookup fail with KeyError when correction was disabled.
_mask_tag = f"{channel_name[label_channel]}-mask"
_intensity_kind = "corrected" if background_correction == 1 else "raw"
intensity_col = f"{_mask_tag}_{channel_name[intensity_channel]}_mean_intensity_{_intensity_kind}"
lbl_col = f"{_mask_tag}_label"
bkg_col = f"{_mask_tag}_{channel_name[intensity_channel]}_background_median"

for folder_name in folder_names:
    # sample_id is the folder name re-assembled from its "<mouse>_Day<N>" parts.
    mouse_info, day_marker, day_suffix = folder_name.rpartition("_Day")
    if day_marker:
        day_info = "Day" + day_suffix
        sample_id = f"{mouse_info}_{day_info}"
    else:
        # No "_Day" marker in the name: use the folder name itself as the
        # sample ID instead of failing with IndexError.
        mouse_info, day_info = folder_name, ""
        sample_id = folder_name
    main_path = os.path.join(base_dir, folder_name)
    summary_filename = f"{save_path_data_sheet}{sample_id}_{timestamp_short}_Summary.txt"

    tif_files = sorted(f for f in os.listdir(main_path) if f.lower().endswith('.tif'))

    summary_lines = [
        "",
        f"Code: {CodeName}_Ver. {Version}",
        f"Python_version: {Python_version}",
        f"Date: {datetime.now(pytz.timezone('Asia/Tokyo')).strftime('%Y-%m-%d %H:%M:%S')}",
        f"Experiment ID: {exp_id}",
        f"Sample ID: {sample_id}",
        f"File Path: {main_path}",
        f"Number of TIFF Files: {len(tif_files)}",
        f"Number of Channels per Field: {channel_num}",
        f"Total Steps (Raw & Processed Images): {num_steps}",
        f"Channel Names: {', '.join(channel_name)}",
        f"Gaussian Sigma Values: {gaussian_sigma_list}",
        f"Erosion Kernel Size: {ero}",
        f"Dilation Kernel Size: {dil}",
        f"Watershed Application per Channel: {watershed_list}",
        f"Lower Area Thresholds per Channel: {area_threshold_list}",
        f"Upper Area Thresholds per Channel: {upper_area_threshold_list}",
        f"Circularity Threshold: {circularity_threshold_list}",
        f"Binarization Settings per Channel: {binarization_list}",
        f"Binarization Methods: {threshold_list}",
        f"Signal Thresholds per Channel: {signal_threshold_list}",
        f"Labels for Region Properties: {channel_name[label_channel]}",
        f"Extended Properties for Region Analysis: {', '.join(regionprops_tuple)}",
        f"INTENSITY_CUTOFF (inside shells): {INTENSITY_CUTOFF}",
        f"SHELL_EROSION_UM: {SHELL_EROSION_UM}",
        f"UM_PER_PIXEL: {UM_PER_PIXEL}",
        "",
        "List of TIFF Files:"
    ]
    summary_lines.extend(tif_files)
    summary_lines.append("")

    if not tif_files:
        print(f"No TIFF files found in {main_path}. Skipping...")
        _write_summary(summary_filename, summary_lines)
        continue

    # --- Parse and group by f-id: files_by_f[f][d] -> filename ---
    files_by_f = {}
    for fname in tif_files:
        m = f_d_regex.search(fname)
        if not m:
            continue
        f_id = int(m.group("f"))
        d_id = int(m.group("d"))
        files_by_f.setdefault(f_id, {})[d_id] = fname

    sorted_f_ids = sorted(files_by_f)
    if not sorted_f_ids:
        print(f"[{sample_id}] No files matched the f/d pattern. Check filenames.")
        _write_summary(summary_filename, summary_lines)
        continue

    # Which f-ids to process
    if only_one_image:
        target_f_ids = [i for i in only_one_image_id if i in files_by_f]
        if not target_f_ids:
            print(f"[{sample_id}] Requested only_one_image_id not found. Skipping sample.")
            _write_summary(summary_filename, summary_lines)
            continue
    else:
        target_f_ids = sorted_f_ids

    # -------- Analysis over f-ids --------
    # Per-field tables are collected and concatenated once after the loop;
    # concatenating inside the loop re-copies all earlier rows every time.
    field_frames = []

    for image_pos_id in target_f_ids:  # image_id equals the "f" number
        file_dict = files_by_f[image_pos_id]

        missing = required_ds - set(file_dict)
        if missing:
            print(f"[{sample_id}] f={image_pos_id}: missing d={sorted(missing)}, skipping this set.")
            continue

        # Read the DAPI file once to learn the image shape. A failure here is
        # handled like any other read error: report it and skip this field.
        first_file_for_shape = file_dict[0]
        try:
            sample_img_shape = imread(os.path.join(main_path, first_file_for_shape), plugin='pil').shape
        except Exception as e:
            print(f"ERROR reading file {first_file_for_shape} for f={image_pos_id}: {e}")
            continue

        # imgs[channel, step]: step 0 = raw, 1 = binary mask, 2 = labels,
        # 3 = filtered (and optionally eroded) labels.
        imgs = np.zeros((channel_num, num_steps, *sample_img_shape), dtype=np.float32)

        # 1) Read all channels for this f-id (by channel index 0..3)
        read_error = False
        for j in range(channel_num):
            need_this = analysis_list[j] or (CROP_EXPORT and j in REQUIRED_RAW_FOR_EXPORT)
            if not need_this:
                continue
            file_to_read = file_dict[channel_to_d[j]]
            try:
                imgs[j, 0] = imread(os.path.join(main_path, file_to_read), plugin='pil')
            except Exception as e:
                # Best effort: one unreadable file invalidates only this field.
                print(f"ERROR reading file {file_to_read} for f={image_pos_id}: {e}")
                read_error = True
                break

        if read_error:
            continue

        # 2) Binarize the FITC channel (label_channel == 1)
        fitc_pipeline = analysis_list[1] and binarization_list[1]
        if fitc_pipeline:
            imgs[1, 1] = basic_binarization(
                imgs[1, 0],
                gaussian_sigma_list[1],
                threshold_list[1],
                binarization_list[1],
                signal_threshold_list[1]
            )
        else:
            imgs[1, 1] = imgs[1, 0] > 0

        # 3) Background subtraction on DAPI (using FITC mask)
        median_bkg = 0.0
        if background_correction == 1 and analysis_list[0]:
            corrected, median_bkg = subtract_median_background(imgs[0, 0], imgs[1, 1])
            imgs[0, 0] = corrected

        # 4) Label on FITC + filtering (area first, then circularity)
        if fitc_pipeline:
            if watershed_list[1]:
                imgs[1, 2] = apply_watershed(imgs[1, 1])
            else:
                imgs[1, 2], _ = ndi.label(imgs[1, 1])
            filtered_by_area = filter_labels_by_area(
                imgs[1, 2],
                area_min=area_threshold_list[1],
                area_max=upper_area_threshold_list[1]
            )
            label_img_area_filtered, _ = ndi.label(filtered_by_area.astype(bool))
            round_mask = filter_labels_by_circularity(
                label_img_area_filtered,
                circularity_threshold=circularity_threshold_list[1]
            ).astype(bool)
            imgs[1, 3], _ = ndi.label(round_mask)
        else:
            imgs[1, 3] = np.zeros_like(imgs[1, 1])

        # ==== NEW: erode shell labels by physical thickness (µm) ====
        if SHELL_EROSION_UM > 0:
            erode_pixels = int(math.ceil(SHELL_EROSION_UM / UM_PER_PIXEL))
            if erode_pixels > 0:
                shell_mask_eroded = erosion(imgs[1, 3] > 0, disk(erode_pixels))
                # Re-label after erosion
                imgs[1, 3], _ = ndi.label(shell_mask_eroded.astype(bool))

        # 5) Regionprops (on FITC-selected labels; intensity from DAPI)
        label_img = imgs[1, 3].astype(int)
        df = pd.DataFrame(regionprops_table(
            label_img,
            intensity_image=imgs[0, 0],  # background-corrected if enabled
            properties=regionprops_tuple
        ))

        df = rename_df_columns(
            df,
            label_channel,
            intensity_channel,
            channel_name=channel_name,
            background_correction=background_correction
        )

        # add DAPI background median
        df[bkg_col] = median_bkg

        # Stamp IDs
        df.insert(0, "image_id", image_pos_id)  # f number
        df.insert(0, "sample_id", sample_id)
        df.insert(0, "exp_id", exp_id)

        field_frames.append(df)

        # ---- FAST CROPPED EXPORTS (no matplotlib) ----
        if CROP_EXPORT:
            hits = df[df[intensity_col] > INTENSITY_CUTOFF]
            if not hits.empty:
                H, W = label_img.shape

                # Build bbox map
                bbox_map = {p.label: p.bbox for p in regionprops(label_img)}

                # Select hot labels and cap number
                hot_labels = [
                    int(x) for x in hits[lbl_col].tolist() if int(x) in bbox_map
                ][:MAX_CROPS_PER_FIELD]

                # Channels to export; PC is written only when its raw plane
                # contains any non-zero pixel (checked once per field).
                export_planes = [("DAPI", imgs[0, 0]), ("FITC", imgs[1, 0])]
                if imgs[3, 0].any():
                    export_planes.append(("PC", imgs[3, 0]))

                for lbl_int in hot_labels:
                    r0, c0, r1, c1 = _square_bbox_with_pad(bbox_map[lbl_int], CROP_PAD, H, W)
                    base = os.path.join(
                        cropped_out_dir,
                        f"{sample_id}_f{image_pos_id:03d}_lbl{lbl_int}_{timestamp_short}"
                    )
                    for tag, plane in export_planes:
                        # Normalize to 8-bit, then optionally downscale.
                        crop_u8 = _downscale_to_max_width(
                            _normalize_to_uint8(plane[r0:r1, c0:c1]), MAX_CROP_WIDTH
                        )
                        Image.fromarray(crop_u8).save(
                            f"{base}_{tag}.jpg", format="JPEG", quality=JPEG_QUALITY, optimize=True
                        )

        # (Optional) legacy matplotlib export — keep disabled for speed

    # Save DataFrame for this folder
    if field_frames:
        df_all = pd.concat(field_frames, ignore_index=True).fillna(0)
    else:
        df_all = pd.DataFrame()

    if not df_all.empty:
        csv_filename = f"{save_path_data_sheet}{sample_id}_{timestamp_short}.csv"
        df_all.to_csv(csv_filename, index=False)
        print(f"[{sample_id}] Saved: {csv_filename}")
    else:
        print(f"[{sample_id}] No data rows to save.")

    # Save summary
    _write_summary(summary_filename, summary_lines)
    print(f"[{sample_id}] Summary written: {summary_filename}")