# Selecting Videos for Single-Model Focus: Analyzing the First n Hours Across All t Time Thresholds

## Imports & Configuration

In [None]:
#!/usr/bin/env python3
"""
STEP 1: Analysis & Selection
STEP 2: Processing (Copy, Fancy Optical Flow, Grad-CAM)
"""

import os
import sys

# --- 0. SUPPRESS TENSORFLOW NOISE ---
# Set before the remaining imports below. No TensorFlow import is visible in
# this file; presumably it is pulled in later by the Grad-CAM script and these
# variables must already be in place by then -- TODO confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

import glob
import shutil
import traceback
import importlib.util
import numpy as np
import pandas as pd
import cv2
import re
from typing import Optional, List, Dict, Set
from itertools import combinations
import matplotlib.pyplot as plt # Optional, for inline viewing if desired

# =============================================================================
#                                CONFIGURATION
# =============================================================================

# --- Paths ---
# Where original dataset is (frames inside dish_well subfolders)
# deduce_true_label() looks for <root>/blasto/<dish> and <root>/no_blasto/<dish>.
ORIGINAL_DATA_ROOT = "/home/phd2/Scrivania/CorsoData/blastocisti"

# Where the per-day summaries were written by the previous pipeline (Input for Step 1)
# find_summary_csv() expects <base>/day_<d>/<model>/summary_prediction_day<d>_<model>.csv
SUMMARY_BASE = "/home/phd2/Scrivania/CorsoRepo/cellPIV/_04_test/GRADCAM_batch_outputs_stratified"

# Path to your single-video gradcam runner
# The script must expose a `run_single_video` attribute (see import_run_single_gradcam).
GRADCAM_SCRIPT = "/home/phd2/Scrivania/CorsoRepo/cellPIV/_04_test/test_single_gradcam.py"

# Project Root (for python path)
REPO_PARENT = "/home/phd2/Scrivania/CorsoRepo/cellPIV"

# --- Outputs ---
# Step 1 Output: Where to save the selection CSV lists
OUTPUT_DIR_ANALYSIS = "/home/phd2/Scrivania/CorsoRepo/cellPIV/_05_GRADCAM/output_analysis"

# Step 2 Output: Where to save the images, overlays, and gradcam maps
DEST_ROOT = "/home/phd2/Documenti/embryo/embryo_to_send_gradcam"

# Destination Subfolders
# "correct"/"incorrect" = prediction outcome; "blasto"/"no_blasto" = true label
# (chosen in run_step2_processing from the category and deduce_true_label()).
DST_CORRECT_BLASTO = os.path.join(DEST_ROOT, "correct_blasto")
DST_CORRECT_NOBLASTO = os.path.join(DEST_ROOT, "correct_no_blasto")
DST_INCORRECT_BLASTO = os.path.join(DEST_ROOT, "incorrect_blasto")
DST_INCORRECT_NOBLASTO = os.path.join(DEST_ROOT, "incorrect_no_blasto")

# --- Parameters ---
MODELS = ["ConvTran", "LSTMFCN"]
DAYS_TO_RUN = [1, 3, 5]             
PEAK_COLS = ["peak1_h", "peak2_h", "peak3_h", "peak4_h", "peak5_h"]  # summary columns scanned for peak times
MAX_HOURS = 5.0                     # Selection criteria: "first n hours" (a peak value <= MAX_HOURS qualifies)
LIGHT_RUN = False                    # If True, cap the number of videos processed in Step 2
# Cap applied in run_step2_processing when LIGHT_RUN is True: at most this many
# videos per TRUE label (blasto / no_blasto) within each model x category list.
MAX_PER_CLASS = 3

# --- Confidence Thresholds ---
# Used by check_confidence_logic(): a class-0 prediction is "confident" when
# prob < LOW, a class-1 prediction when prob > HIGH (strict comparisons).
# Correct Cases
CONF_LOW_DAY1 = 0.3;    CONF_HIGH_DAY1 = 0.7
CONF_LOW_DAY3 = 0.3;    CONF_HIGH_DAY3 = 0.7
CONF_LOW_DAY5 = 0.1;    CONF_HIGH_DAY5 = 0.9
# Fallback for any day other than 1, 3 or 5.
DEFAULT_CONF_LOW = 0.5; DEFAULT_CONF_HIGH = 0.5

# Wrong Cases (High Confidence Errors)
# NOTE: these names are spelled "DA1/DA3/DA5" (no "Y"); check_confidence_logic()
# references them with this exact spelling, so renaming requires updating both.
CONF_LOW_DA1_WRONG = 0.3;    CONF_HIGH_DA1_WRONG = 0.7
# NOTE(review): HIGH equals LOW (0.3) here, whereas days 1 and 5 use HIGH = 1 - LOW.
# With 0.3 any wrong class-1 prediction with prob > 0.3 qualifies. Possibly a typo
# for 0.7 -- confirm before relying on the day-3 "wrong" selection.
CONF_LOW_DA3_WRONG = 0.3;    CONF_HIGH_DA3_WRONG = 0.3
CONF_LOW_DA5_WRONG = 0.2;    CONF_HIGH_DA5_WRONG = 0.8
# Fallback for any day other than 1, 3 or 5.
DEFAULT_CONF_LOW_WRONG = 0.5; DEFAULT_CONF_HIGH_WRONG = 0.5

# --- Farneback Optical Flow Params ---
# Forwarded key by key to cv2.calcOpticalFlowFarneback in compute_optical_flow_fancy().
FARNEBACK_PARAMS = dict(
    pyr_scale=0.5, levels=4, winsize=25, iterations=3,
    poly_n=5, poly_sigma=1.2, flags=0
)
# Extensions globbed by find_frames(); matching is done with glob patterns "*<ext>".
FRAME_EXTS = [".jpg", ".jpeg", ".png", ".tif", ".tiff"]

# --- Sys Path Setup ---
# Make the repository root importable (prepended once).
if REPO_PARENT not in sys.path:
    sys.path.insert(0, REPO_PARENT)

print("Configuration loaded.")

Configuration loaded.


## Step 1 Helpers (Analysis Logic)

In [19]:
# =============================================================================
#                        STEP 1 HELPERS: ANALYSIS & SELECTION
# =============================================================================

def find_summary_csv(day: int, model: str, base_dir: str) -> Optional[str]:
    """Locate the prediction-summary CSV of one model for one day.

    The canonical location
    ``<base_dir>/day_<day>/<model>/summary_prediction_day<day>_<model>.csv``
    is tried first. If it does not exist, the same folder is searched for any
    file matching ``*summary*day*<model>*.csv``.

    Args:
        day: Day index used in the folder and file name.
        model: Model name used in the folder and file name.
        base_dir: Root folder containing the ``day_<day>`` sub-folders.

    Returns:
        The path of the summary CSV, or ``None`` when nothing matches. When
        several fallback candidates exist, the lexicographically smallest path
        is returned so the choice is reproducible between runs.
    """
    model_dir = os.path.join(base_dir, f"day_{day}", model)
    expected = os.path.join(model_dir, f"summary_prediction_day{day}_{model}.csv")
    if os.path.exists(expected):
        return expected

    # Fallback search. The directory and the model name are escaped so that
    # glob metacharacters in them ("[", "]", "*", "?") are matched literally.
    pattern = os.path.join(
        glob.escape(model_dir), f"*summary*day*{glob.escape(model)}*.csv"
    )
    # glob.glob() returns matches in arbitrary order; sort for determinism.
    matches = sorted(glob.glob(pattern))
    return matches[0] if matches else None

def has_peak_within_first_hours(row: pd.Series, max_hours: float,
                                peak_cols: Optional[List[str]] = None) -> bool:
    """Tell whether any peak column of ``row`` holds a value <= ``max_hours``.

    Args:
        row: One summary row (anything supporting ``in`` and ``[]`` by column
            name, e.g. a ``pd.Series`` or a dict).
        max_hours: Inclusive upper bound for a peak to count.
        peak_cols: Column names to inspect. Defaults to the module-level
            ``PEAK_COLS`` when omitted.

    Returns:
        ``True`` as soon as one inspected column parses as a number that is
        ``<= max_hours``; ``False`` otherwise. Missing columns, NaN/None,
        empty strings and non-numeric values are skipped.
    """
    columns = PEAK_COLS if peak_cols is None else peak_cols
    for col in columns:
        if col not in row:
            continue
        value = row[col]
        try:
            if pd.isna(value) or value == "":
                continue
            if float(value) <= max_hours:
                return True
        except (TypeError, ValueError, OverflowError):
            # Unparseable cell: ignore it, but let KeyboardInterrupt and
            # SystemExit propagate (the previous bare ``except`` hid them).
            continue
    return False

def check_confidence_logic(day, prob, pred_label, true_label):
    """
    Classify one prediction as confidently right and/or confidently wrong.

    The thresholds depend on ``day`` (1, 3 or 5; anything else uses the
    DEFAULT_* values). A class-0 prediction is confident when ``prob`` is
    strictly below the LOW threshold, a class-1 prediction when ``prob`` is
    strictly above the HIGH threshold. Correct predictions use the regular
    thresholds, wrong predictions use the *_WRONG thresholds.

    Returns flags: (is_correct_and_confident, is_wrong_and_confident).
    Both are False when any input is NaN/None or the predicted label is
    neither 0 nor 1.
    """
    if any(pd.isna(value) for value in (prob, pred_label, true_label)):
        return False, False

    predicted = int(pred_label)
    actual = int(true_label)

    # (correct_low, correct_high, wrong_low, wrong_high) per supported day.
    per_day = (
        (1, (CONF_LOW_DAY1, CONF_HIGH_DAY1, CONF_LOW_DA1_WRONG, CONF_HIGH_DA1_WRONG)),
        (3, (CONF_LOW_DAY3, CONF_HIGH_DAY3, CONF_LOW_DA3_WRONG, CONF_HIGH_DA3_WRONG)),
        (5, (CONF_LOW_DAY5, CONF_HIGH_DAY5, CONF_LOW_DA5_WRONG, CONF_HIGH_DA5_WRONG)),
    )
    fallback = (DEFAULT_CONF_LOW, DEFAULT_CONF_HIGH,
                DEFAULT_CONF_LOW_WRONG, DEFAULT_CONF_HIGH_WRONG)
    c_low, c_high, w_low, w_high = next(
        (limits for known_day, limits in per_day if day == known_day), fallback
    )

    agrees = predicted == actual
    # Pick the threshold pair that applies to this outcome.
    low, high = (c_low, c_high) if agrees else (w_low, w_high)

    if predicted == 0:
        # Confidently "no blasto": probability of class 1 must be very low.
        confident = prob < low
    elif predicted == 1:
        # Confidently "blasto": probability of class 1 must be very high.
        confident = prob > high
    else:
        confident = False

    if agrees:
        return confident, False
    return False, confident

def read_summary_and_flag(file_path: str, max_hours: float, day=None) -> pd.DataFrame:
    """Load a summary CSV and append the selection flags used by Step 1.

    All columns are read as strings. The following boolean columns are added:

    * ``peak_within_<max_hours>h`` -- result of ``has_peak_within_first_hours``.
    * ``flag_correct_conf`` / ``flag_wrong_conf`` -- result of
      ``check_confidence_logic`` on ``prob_class1``, ``pred_label`` and
      ``true_label`` (both False when those cells are missing or not numeric).
    * ``meets_correct`` / ``meets_wrong`` -- peak flag AND the matching
      confidence flag.

    Args:
        file_path: Path of the summary CSV.
        max_hours: Upper bound (in hours) forwarded to the peak check.
        day: Day index forwarded to ``check_confidence_logic``.

    Returns:
        The DataFrame with the extra columns. A CSV containing only a header
        yields an empty frame that still carries all flag columns.
    """
    df = pd.read_csv(file_path, dtype=str)
    peak_col = f"peak_within_{max_hours}h"

    def eval_row(r):
        try:
            p = float(r.get("prob_class1"))
            pl = float(r.get("pred_label"))
            tl = float(r.get("true_label"))
            return check_confidence_logic(day, p, pl, tl)
        except (TypeError, ValueError, OverflowError):
            # Missing column (None), non-numeric text, or an infinite label.
            return False, False

    # Plain row loop instead of DataFrame.apply(..., result_type='expand'):
    # on an empty frame apply() returns a copy of the frame, so the expanded
    # columns 0/1 would not exist and the lookup would raise KeyError.
    peak_flags, correct_flags, wrong_flags = [], [], []
    for _, row in df.iterrows():
        peak_flags.append(has_peak_within_first_hours(row, max_hours))
        is_correct_conf, is_wrong_conf = eval_row(row)
        correct_flags.append(is_correct_conf)
        wrong_flags.append(is_wrong_conf)

    # 1. Peak
    df[peak_col] = pd.Series(peak_flags, index=df.index, dtype=bool)

    # 2. Confidence Evaluation
    df["flag_correct_conf"] = pd.Series(correct_flags, index=df.index, dtype=bool)
    df["flag_wrong_conf"] = pd.Series(wrong_flags, index=df.index, dtype=bool)

    # 3. Combined Criteria
    # Must have Peak AND (CorrectConf OR WrongConf)
    df["meets_correct"] = df[peak_col] & df["flag_correct_conf"]
    df["meets_wrong"] = df[peak_col] & df["flag_wrong_conf"]

    return df

def intersect_dishes_across_days(model_day_sets, n_required=2):
    """
    Returns the sorted list of dishes that appear in at least n_required days.

    Args:
        model_day_sets: Mapping ``day -> collection of dish names`` selected
            for that day.
        n_required: Minimum number of distinct days a dish must appear in.

    Returns:
        Sorted list of qualifying dish names. Empty when fewer than
        ``n_required`` days are available. With ``n_required <= 0`` every dish
        present in any day qualifies.
    """
    # Count, for each dish, the number of days that contain it. This equals
    # the union of the intersections over every n_required-sized combination
    # of days, but needs one pass instead of one intersection per combination.
    days_per_dish = {}
    for dishes in model_day_sets.values():
        for dish in set(dishes):  # set(): count a dish once per day
            days_per_dish[dish] = days_per_dish.get(dish, 0) + 1

    return sorted(
        dish for dish, n_days in days_per_dish.items() if n_days >= n_required
    )

print("Step 1 Helpers defined.")

Step 1 Helpers defined.


## Run Step 1 (Analysis Execution)

In [21]:
def run_step1_selection():
    """
    Run the Step 1 analysis for every model in MODELS and day in DAYS_TO_RUN.

    For each (model, day) the summary CSV is flagged, and the dishes meeting
    the "correct" and "wrong" criteria are collected. Dishes qualifying on at
    least two days are kept and written to
    ``<OUTPUT_DIR_ANALYSIS>/<model>_selected_CORRECT.csv`` and
    ``<model>_selected_WRONG.csv``.

    Produces a dict: { model: { 'correct': [dish1, ...], 'wrong': [dish2, ...] } }
    """
    rule = "=" * 50
    print("\n" + rule)
    print("STEP 1: ANALYSIS & SELECTION")
    print(rule)

    os.makedirs(OUTPUT_DIR_ANALYSIS, exist_ok=True)

    # model -> day -> set of dish names
    correct_by_model = {name: {} for name in MODELS}
    wrong_by_model = {name: {} for name in MODELS}

    for model in MODELS:
        for day in DAYS_TO_RUN:
            summary_path = find_summary_csv(day, model, SUMMARY_BASE)
            if summary_path is None:
                print(f"  [Warn] No summary for {model} day {day}")
                continue

            flagged = read_summary_and_flag(summary_path, MAX_HOURS, day=day)

            correct_rows = flagged.loc[flagged["meets_correct"], "dish_well"]
            wrong_rows = flagged.loc[flagged["meets_wrong"], "dish_well"]
            confident_right = set(correct_rows.dropna().unique())
            confident_wrong = set(wrong_rows.dropna().unique())

            correct_by_model[model][day] = confident_right
            wrong_by_model[model][day] = confident_wrong

            print(f"  [Info] {model} Day {day}: Found {len(confident_right)} Correct-Conf, {len(confident_wrong)} Wrong-Conf")

    # Keep dishes selected on at least two days and persist the lists.
    final_selection = {}
    for model in MODELS:
        kept_correct = intersect_dishes_across_days(correct_by_model[model], n_required=2)
        kept_wrong = intersect_dishes_across_days(wrong_by_model[model], n_required=2)
        final_selection[model] = {'correct': kept_correct, 'wrong': kept_wrong}

        correct_csv = os.path.join(OUTPUT_DIR_ANALYSIS, f"{model}_selected_CORRECT.csv")
        wrong_csv = os.path.join(OUTPUT_DIR_ANALYSIS, f"{model}_selected_WRONG.csv")
        pd.DataFrame(kept_correct, columns=["dish_well"]).to_csv(correct_csv, index=False)
        pd.DataFrame(kept_wrong, columns=["dish_well"]).to_csv(wrong_csv, index=False)

        print(f"  > {model} Final: {len(kept_correct)} Correct, {len(kept_wrong)} Wrong.")

    return final_selection

# EXECUTE STEP 1
# final_map has the shape {model: {'correct': [...], 'wrong': [...]}} and is
# the input passed to run_step2_processing() further down.
final_map = run_step1_selection()


STEP 1: ANALYSIS & SELECTION
  [Info] ConvTran Day 1: Found 36 Correct-Conf, 14 Wrong-Conf
  [Info] ConvTran Day 3: Found 175 Correct-Conf, 117 Wrong-Conf
  [Info] ConvTran Day 5: Found 118 Correct-Conf, 40 Wrong-Conf
  [Info] LSTMFCN Day 1: Found 6 Correct-Conf, 3 Wrong-Conf
  [Info] LSTMFCN Day 3: Found 182 Correct-Conf, 158 Wrong-Conf
  [Info] LSTMFCN Day 5: Found 143 Correct-Conf, 33 Wrong-Conf
  > ConvTran Final: 59 Correct, 30 Wrong.
  > LSTMFCN Final: 49 Correct, 8 Wrong.


## Step 2 Helpers (Processing)

In [22]:
# =============================================================================
#                        STEP 2 HELPERS: PROCESSING
# =============================================================================

# --- Import GradCAM Wrapper ---
def import_run_single_gradcam(script_path: str):
    """Load the Grad-CAM runner script and return its ``run_single_video``.

    The script's folder is prepended to ``sys.path`` (once) so that the script
    can resolve its own sibling imports, then the file is executed as a module
    named ``test_single_gradcam``.

    Args:
        script_path: Path of the Python file exposing ``run_single_video``.

    Returns:
        The ``run_single_video`` attribute of the loaded module, or ``None``
        when the script cannot be imported or does not define it. The reason
        is printed instead of being silently discarded, so a broken dependency
        in the script can be diagnosed.
    """
    script_dir = os.path.dirname(script_path)
    if script_dir not in sys.path:
        sys.path.insert(0, script_dir)

    try:
        spec = importlib.util.spec_from_file_location("test_single_gradcam", script_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"cannot build an import spec for {script_path!r}")
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
    except Exception as exc:
        # Best effort: the caller falls back to skipping Grad-CAM.
        print(f"[Warn] Could not load Grad-CAM script {script_path!r}: "
              f"{type(exc).__name__}: {exc}")
        return None

    if not hasattr(mod, "run_single_video"):
        print(f"[Warn] Grad-CAM script {script_path!r} does not define 'run_single_video'.")
        return None
    return mod.run_single_video

# Module-level handle to the Grad-CAM runner; None when the import failed.
# run_step2_processing() checks it and skips the Grad-CAM step in that case.
run_single_video_func = import_run_single_gradcam(GRADCAM_SCRIPT)
if not run_single_video_func:
    print("WARNING: Could not import Grad-CAM script. Processing will skip Grad-CAM step.")

# --- File Utils ---
def natural_keys(text):
    """Build a sort key that orders embedded numbers numerically.

    The text is split on runs of digits; digit runs become ints and everything
    else stays a string, e.g. ``"img10.png" -> ["img", 10, ".png"]``.
    """
    pieces = re.split(r'(\d+)', text)
    return [int(piece) if piece.isdigit() else piece for piece in pieces]

def find_frames(dish_dir):
    """List the frame images directly inside ``dish_dir`` in natural order.

    Files are collected once per extension in FRAME_EXTS (glob ``*<ext>``) and
    then sorted with ``natural_keys`` applied to the full path.
    """
    candidates = [
        path
        for ext in FRAME_EXTS
        for path in glob.glob(os.path.join(dish_dir, f"*{ext}"))
    ]
    return sorted(candidates, key=natural_keys)

def deduce_true_label(dish_name):
    """Infer the ground-truth label of a dish from its dataset folder.

    Returns 1 when ``<ORIGINAL_DATA_ROOT>/blasto/<dish_name>`` is a directory,
    0 when ``<ORIGINAL_DATA_ROOT>/no_blasto/<dish_name>`` is, otherwise None.
    The ``blasto`` folder is checked first.
    """
    for class_folder, label in (("blasto", 1), ("no_blasto", 0)):
        candidate = os.path.join(ORIGINAL_DATA_ROOT, class_folder, dish_name)
        if os.path.isdir(candidate):
            return label
    return None

def copy_frames(src_dir, dst_dir):
    """Copy every frame image found in ``src_dir`` into ``dst_dir``.

    Args:
        src_dir: Folder scanned with ``find_frames``.
        dst_dir: Destination folder; created only when there is something to
            copy.

    Returns:
        True when at least one frame was found and all frames were copied,
        False when ``src_dir`` contains no frames.
    """
    frames = find_frames(src_dir)
    if not frames:
        # Do not create dst_dir here: an empty destination folder would make
        # callers that test "does the folder exist?" treat the failed copy as
        # already done on the next run.
        return False

    os.makedirs(dst_dir, exist_ok=True)
    for frame_path in frames:
        shutil.copy2(frame_path, os.path.join(dst_dir, os.path.basename(frame_path)))
    return True

# --- Visualization Helpers ---

def draw_fancy_arrow(img, pt1, pt2, color, thickness=1, tip_ratio=0.2):
    """
    Draw an arrow from pt1 to pt2 on img: a line shaft plus a filled
    triangular head, both rendered with cv2.LINE_AA.

    The head length is ``tip_ratio`` times the arrow length and its two back
    corners are spread 30 degrees (pi/6) either side of the arrow direction.
    The image is modified in place; nothing is returned.
    """
    tail = np.array(pt1, dtype=np.float64)
    head = np.array(pt2, dtype=np.float64)

    delta = head - tail
    heading = np.arctan2(delta[1], delta[0])

    # Head size scales with the arrow length.
    tip_len = np.linalg.norm(delta) * tip_ratio
    half_spread = np.pi / 6

    # Triangle = the tip itself plus one back corner on each side of the shaft.
    vertices = [[head[0], head[1]]]
    for side_angle in (heading - half_spread, heading + half_spread):
        vertices.append([
            head[0] - tip_len * np.cos(side_angle),
            head[1] - tip_len * np.sin(side_angle),
        ])
    triangle_cnt = np.array(vertices, dtype=np.int32)

    # Shaft first; the filled head is drawn over its end.
    shaft_start = (int(tail[0]), int(tail[1]))
    shaft_end = (int(head[0]), int(head[1]))
    cv2.line(img, shaft_start, shaft_end, color, thickness, lineType=cv2.LINE_AA)

    cv2.fillPoly(img, [triangle_cnt], color, lineType=cv2.LINE_AA)


def compute_optical_flow_fancy(frames_dir, out_dir):
    """
    Computes Dense (Farneback) Optical Flow between consecutive frames of
    frames_dir and writes one visualisation per frame into out_dir.

    Visuals: 
      1. Dense Heatmap (Inferno) for magnitude blended with image.
      2. Uniform grid of custom 'Fancy' arrows.

    Args:
        frames_dir: Folder whose frames are listed with find_frames().
        out_dir: Output folder (created if missing). Each image is saved as
            "flow_<name of the current frame file>".

    Returns:
        None. Returns early without writing anything when fewer than two
        frames are found or the first frame cannot be read. Later frames that
        cannot be read are skipped, so the next readable frame is compared
        against the last readable one.
    """
    os.makedirs(out_dir, exist_ok=True)
    frames = find_frames(frames_dir)
    if len(frames) < 2: return

    prev = cv2.imread(frames[0])
    if prev is None: return
    prev_gray = cv2.cvtColor(prev, cv2.COLOR_BGR2GRAY)
    
    # --- Visualization Config ---
    arrow_step = 20            # Grid spacing
    arrow_scale = 3.0          # How much to magnify the movement length
    # NOTE: alpha_heatmap is currently unused; the active blend below uses
    # fixed weights (1.0 / 0.3). It is only referenced by the commented-out line.
    alpha_heatmap = 0.5        # Transparency of flow map
    
    print(f"   Computing Fancy Optical Flow for {len(frames)} frames...")
    
    for i in range(1, len(frames)):
        cur_path = frames[i]
        cur = cv2.imread(cur_path)
        if cur is None: continue  # unreadable frame: skip it, keep prev_gray as is
        cur_gray = cv2.cvtColor(cur, cv2.COLOR_BGR2GRAY)
        
        # 1. Calc Flow (previous readable frame -> current frame)
        flow = cv2.calcOpticalFlowFarneback(prev_gray, cur_gray, None,
                                            pyr_scale=FARNEBACK_PARAMS['pyr_scale'],
                                            levels=int(FARNEBACK_PARAMS['levels']),
                                            winsize=int(FARNEBACK_PARAMS['winsize']),
                                            iterations=int(FARNEBACK_PARAMS['iterations']),
                                            poly_n=int(FARNEBACK_PARAMS['poly_n']),
                                            poly_sigma=float(FARNEBACK_PARAMS['poly_sigma']),
                                            flags=int(FARNEBACK_PARAMS['flags']))
        
        # 2. Dense Heatmap
        # Magnitude is min-max normalised per frame, so colours are not
        # comparable between frames.
        mag, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1])
        norm_mag = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
        norm_mag = norm_mag.astype(np.uint8)
        heatmap = cv2.applyColorMap(norm_mag, cv2.COLORMAP_INFERNO)
        
        # Grayscale frame at full weight plus the heatmap at weight 0.3.
        cur_gray_bgr = cv2.cvtColor(cur_gray, cv2.COLOR_GRAY2BGR)
        overlay = cv2.addWeighted(cur_gray_bgr, 1.0, heatmap, 0.3, 0)
        # Alternative blend driven by alpha_heatmap (disabled):
        # overlay = cv2.addWeighted(cur_gray_bgr, 1.0 - alpha_heatmap, heatmap, alpha_heatmap, 0)
        
        # 3. Draw Fancy Arrows
        # Sample grid: one point every arrow_step pixels, offset by half a step.
        h, w = cur_gray.shape
        y, x = np.mgrid[arrow_step//2:h:arrow_step, arrow_step//2:w:arrow_step].reshape(2,-1).astype(int)
        
        # Flow is indexed [row, col]; component 0 is used as the x displacement
        # and component 1 as the y displacement.
        fx, fy = flow[y, x].T
        
        # Calculate Endpoints
        # x, y are starting points.
        # target_x, target_y are endpoints based on flow.
        target_x = x + fx * arrow_scale
        target_y = y + fy * arrow_scale
        
        # Flatten for iteration
        p1s = np.stack([x, y], axis=-1).reshape(-1, 2)
        p2s = np.stack([target_x, target_y], axis=-1).reshape(-1, 2)
        
        # Calculate brightness for each arrow based on magnitude
        # (relative to the largest sampled magnitude in this frame).
        arrow_mags = np.sqrt(fx**2 + fy**2).reshape(-1)
        max_mag = arrow_mags.max()
        if max_mag > 0:
            arrow_brightness = (arrow_mags / max_mag) * 255
        else:
            # No motion at any sampled point: all zeros.
            arrow_brightness = arrow_mags

        for p1, p2, val in zip(p1s, p2s, arrow_brightness):
            # Brighter = Faster. Clip lowest to 50 so slow arrows are still slightly visible (dark grey)
            c_val = int(np.clip(val + 50, 50, 255))
            color = (c_val, c_val, c_val) 
            
            # Use custom drawer
            # tip_ratio=0.25 gives a nice substantial triangle
            draw_fancy_arrow(overlay, p1, p2, color, thickness=1, tip_ratio=0.25)

        # Save
        out_fname = os.path.join(out_dir, f"flow_{os.path.basename(cur_path)}")
        cv2.imwrite(out_fname, overlay)
        
        # The current frame becomes the reference for the next iteration.
        prev_gray = cur_gray

print("Step 2 Helpers defined.")

Step 2 Helpers defined.


## Run Step 2 (Processing Execution)

In [23]:
def run_step2_processing(selection_map: Dict[str, Dict[str, List[str]]]):
    """
    Process every dish listed in the Step 1 selection map.

    For each model and category ('correct' or anything else, treated as
    incorrect) the dish frames are copied to
    ``<destination>/<model>/<dish>/plain_video``, the optical-flow overlays are
    written to ``overlay_farneback`` and Grad-CAM is run per day into
    ``gradcam_maps``. Steps whose output folder already exists are skipped.
    When LIGHT_RUN is set, at most MAX_PER_CLASS dishes per true label are
    processed for each model/category list.
    """
    rule = "=" * 50
    print("\n" + rule)
    print("STEP 2: PROCESSING (Copy, GradCAM, Fancy Optical Flow)")
    print(rule)

    for model_name, cat_dict in selection_map.items():
        for category, dishes in cat_dict.items():  # category is 'correct' or 'wrong'
            if not dishes:
                continue

            print(f"\n--- Model: {model_name} | Type: {category.upper()} ({len(dishes)} candidates) ---")

            # 1. Optional cap: keep the first MAX_PER_CLASS dishes of each true label.
            if not LIGHT_RUN:
                dishes_to_proc = dishes
            else:
                dishes_to_proc = []
                taken_per_label = {0: 0, 1: 0}
                for candidate in dishes:
                    label = deduce_true_label(candidate)
                    if label in taken_per_label and taken_per_label[label] < MAX_PER_CLASS:
                        dishes_to_proc.append(candidate)
                        taken_per_label[label] += 1
                print(f"   [Light Run] Reduced to {len(dishes_to_proc)} videos.")

            # 2. Per-dish processing.
            for dish in dishes_to_proc:
                tlabel = deduce_true_label(dish)
                if tlabel is None:
                    print(f"   [Skip] Could not find source for {dish}")
                    continue

                is_blasto = (tlabel == 1)

                # Destination depends on prediction outcome and true label.
                if category == 'correct':
                    blasto_dst, noblasto_dst = DST_CORRECT_BLASTO, DST_CORRECT_NOBLASTO
                else:
                    blasto_dst, noblasto_dst = DST_INCORRECT_BLASTO, DST_INCORRECT_NOBLASTO
                dest_cat_folder = blasto_dst if is_blasto else noblasto_dst

                dish_root = os.path.join(dest_cat_folder, model_name, dish)
                plain_video_dir = os.path.join(dish_root, "plain_video")
                overlay_dir = os.path.join(dish_root, "overlay_farneback")
                gradcam_base = os.path.join(dish_root, "gradcam_maps")

                # A. Copy the raw frames (only if not already present).
                class_folder = "blasto" if is_blasto else "no_blasto"
                src_path = os.path.join(ORIGINAL_DATA_ROOT, class_folder, dish)
                if not os.path.exists(plain_video_dir):
                    print(f"   [{dish}] Copying frames...")
                    copied = copy_frames(src_path, plain_video_dir)
                    if not copied:
                        print(f"     Failed to copy frames for {dish}")
                        continue

                # B. Optical-flow overlays (only if not already present).
                if not os.path.exists(overlay_dir):
                    print(f"   [{dish}] Computing Optical Flow...")
                    try:
                        compute_optical_flow_fancy(plain_video_dir, overlay_dir)
                    except Exception as e:
                        print(f"     Flow Error: {e}")

                # C. Grad-CAM, one run per configured day.
                if not run_single_video_func:
                    print(f"   [{dish}] Skipping Grad-CAM (function import failed).")
                    continue

                print(f"   [{dish}] Running Grad-CAM for days {DAYS_TO_RUN}...")
                for day in DAYS_TO_RUN:
                    # Heuristic: a day folder holding at least one PNG counts as done.
                    day_dir = os.path.join(gradcam_base, f"day_{day}", model_name)
                    already_done = (
                        os.path.exists(day_dir)
                        and len(glob.glob(os.path.join(day_dir, "*.png"))) > 0
                    )
                    if already_done:
                        continue

                    try:
                        run_single_video_func(
                            day=day,
                            specific_video=dish,
                            output_base=gradcam_base,
                            models_to_run=[model_name],
                            use_model_in_result_path=False,
                            create_dir_with_data_name=False
                        )
                    except Exception as e:
                        print(f"     GradCAM Error (Day {day}): {e}")

# EXECUTE STEP 2
# Consumes the selection produced by Step 1 (final_map); outputs go under DEST_ROOT.
run_step2_processing(final_map)
print("\nDONE.")


STEP 2: PROCESSING (Copy, GradCAM, Fancy Optical Flow)

--- Model: ConvTran | Type: CORRECT (59 candidates) ---
   [D2013.06.22_S0792_I141_3] Copying frames...
   [D2013.06.22_S0792_I141_3] Computing Optical Flow...
   Computing Fancy Optical Flow for 571 frames...
   [D2013.06.22_S0792_I141_3] Running Grad-CAM for days [1, 3, 5]...

=== Single video Grad-CAM: day 1 — D2013.06.22_S0792_I141_3 ===
Loading test data from: /home/phd2/Scrivania/CorsoRepo/cellPIV/datasets/Farneback/subsets/Normalized_sum_mean_mag_1Days_test.csv

--- Model: ConvTran ---
Wrote named modules to: /home/phd2/Documenti/embryo/embryo_to_send_gradcam/correct_no_blasto/ConvTran/D2013.06.22_S0792_I141_3/gradcam_maps/day_1/named_modules.txt
Detected Conv2d layer 'embed_layer.0' with in_channels==1; using 2D sample (C,T).
use_conv2d=False for model ConvTran
Using target layer: embed_layer.3

                                      Executing SignalGrad-CAM                                      

AVAILABLE EXPLAINERS:
 - E