# Selecting Videos for Single-Model Focus: Analyzing the First n Hours Across Multiple Day-t Time Thresholds

## Imports and configuration

In [None]:
#!/usr/bin/env python3
"""
combined_gradcam_pipeline.py

STEP 1: Analysis & Selection
    - Reads summary CSVs.
    - Filters for Correct Prediction + High Confidence + Peak within N hours, on at least 2 of the analysed days.
    - Saves analysis CSVs to OUTPUT_DIR_ANALYSIS.

STEP 2: Processing
    - Takes the selected dishes from Step 1.
    - Copies frames to DEST_ROOT.
    - Computes Farneback Optical Flow overlays.
    - Runs Grad-CAM for the specific days.
"""

import os
import sys

# --- 0. SUPPRESS TENSORFLOW NOISE ---
# Set here, before the Grad-CAM runner (and presumably TensorFlow) is imported
# further down; these variables are read at TensorFlow import time.
os.environ.update({
    'TF_CPP_MIN_LOG_LEVEL': '3',
    'TF_ENABLE_ONEDNN_OPTS': '0',
})

import glob
import shutil
import traceback
import importlib.util
import numpy as np
import pandas as pd
import cv2
from typing import Optional, List, Dict, Set

# =============================================================================
#                                CONFIGURATION
# =============================================================================

# --- Paths ---
# Original dataset root: each video is a folder of frames located at
# <ORIGINAL_DATA_ROOT>/blasto/<dish_well> or <ORIGINAL_DATA_ROOT>/no_blasto/<dish_well>
ORIGINAL_DATA_ROOT = "/home/phd2/Scrivania/CorsoData/blastocisti"
# Per-day summary CSVs written by the previous pipeline (Step 1 input)
SUMMARY_BASE = "/home/phd2/Scrivania/CorsoRepo/cellPIV/_04_test/GRADCAM_batch_outputs_stratified"
# Single-video Grad-CAM runner; must define run_single_video()
GRADCAM_SCRIPT = "/home/phd2/Scrivania/CorsoRepo/cellPIV/_04_test/test_single_gradcam.py"
# Project root, prepended to sys.path at the end of this cell
REPO_PARENT = "/home/phd2/Scrivania/CorsoRepo/cellPIV"

# --- Outputs ---
# Step 1: selection / analysis CSVs
OUTPUT_DIR_ANALYSIS = "/home/phd2/Scrivania/CorsoRepo/cellPIV/_05_GRADCAM/output_analysis"
# Step 2: copied frames, optical-flow overlays and Grad-CAM maps
DEST_ROOT = "/home/phd2/Documenti/embryo/embryo_to_send_gradcam"
DST_CORRECT_BLASTO, DST_CORRECT_NOBLASTO, DST_INCORRECT_BLASTO, DST_INCORRECT_NOBLASTO = (
    os.path.join(DEST_ROOT, sub_dir)
    for sub_dir in ("correct_blasto", "correct_no_blasto", "incorrect_blasto", "incorrect_no_blasto")
)

# --- Parameters ---
MODELS = ["ConvTran", "LSTMFCN"]
DAYS_TO_RUN = [1, 3, 5]
PEAK_COLS = [f"peak{i}_h" for i in range(1, 6)]   # peak1_h .. peak5_h
MAX_HOURS = 5.0          # selection window: a peak must occur within the first MAX_HOURS hours

LIGHT_RUN = True         # True -> Step 2 handles at most MAX_PER_CLASS videos per class and model
MAX_PER_CLASS = 3

# --- Confidence thresholds on prob_class1 (from Script 1 logic) ---
# A class-0 (no blasto) prediction is "high confidence" when prob < CONF_LOW,
# a class-1 (blasto) prediction when prob > CONF_HIGH.
CONF_LOW_DAY1 = 0.3
CONF_HIGH_DAY1 = 0.7
CONF_LOW_DAY3 = 0.3
CONF_HIGH_DAY3 = 0.7
CONF_LOW_DAY5 = 0.1
CONF_HIGH_DAY5 = 0.9
DEFAULT_CONF_LOW = 0.5
DEFAULT_CONF_HIGH = 0.5

# Presumably thresholds meant for incorrectly classified videos.
# NOTE(review): none of these *_DA?_WRONG names is used by the code shown here,
# and CONF_HIGH_DA3_WRONG equals CONF_LOW_DA3_WRONG (0.3), unlike days 1 and 5;
# possibly meant to be 0.7. Confirm before relying on them.
CONF_LOW_DA1_WRONG = 0.3
CONF_HIGH_DA1_WRONG = 0.7
CONF_LOW_DA3_WRONG = 0.3
CONF_HIGH_DA3_WRONG = 0.3
CONF_LOW_DA5_WRONG = 0.2
CONF_HIGH_DA5_WRONG = 0.8
DEFAULT_CONF_LOW_WRONG = 0.5
DEFAULT_CONF_HIGH_WRONG = 0.5

# --- Farneback optical-flow parameters (passed to cv2.calcOpticalFlowFarneback) ---
FARNEBACK_PARAMS = {
    "pyr_scale": 0.5,
    "levels": 4,
    "winsize": 25,
    "iterations": 3,
    "poly_n": 5,
    "poly_sigma": 1.2,
    "flags": 0,
}
FRAME_EXTS = [".jpg", ".jpeg", ".png", ".tif", ".tiff"]

# --- Make project modules importable ---
if REPO_PARENT not in sys.path:
    sys.path.insert(0, REPO_PARENT)


## Step 1 helpers

In [83]:
# =============================================================================
#                        STEP 1 HELPERS: ANALYSIS & SELECTION
# =============================================================================

def find_summary_csv(day: int, model: str, base_dir: str) -> Optional[str]:
    """Locate the per-day prediction summary CSV for *model*.

    First looks for the canonical file
    ``<base_dir>/day_<day>/<model>/summary_prediction_day<day>_<model>.csv``;
    if it is absent, falls back to any ``*summary*day*<model>*.csv`` in the
    same folder.

    Returns:
        The CSV path, or None when nothing matches. When several fallback files
        match, the lexicographically first is returned so the choice does not
        depend on the filesystem's directory order (``glob`` is unordered).
    """
    model_dir = os.path.join(base_dir, f"day_{day}", model)
    expected = os.path.join(model_dir, f"summary_prediction_day{day}_{model}.csv")
    if os.path.exists(expected):
        return expected
    # Escape the literal parts so characters like '[' in paths are not wildcards.
    pattern = os.path.join(glob.escape(model_dir), f"*summary*day*{glob.escape(model)}*.csv")
    matches = sorted(glob.glob(pattern))
    return matches[0] if matches else None

def has_peak_within_first_hours(row: pd.Series, max_hours: float,
                                peak_cols: Optional[List[str]] = None) -> bool:
    """Return True if any peak-time column of *row* is <= *max_hours*.

    Args:
        row: One summary row; cells may be strings, numbers or NaN.
        max_hours: Inclusive upper bound, in hours.
        peak_cols: Columns to inspect; defaults to the module-level PEAK_COLS.
            Columns absent from *row* are ignored.

    Empty, NaN or non-numeric cells are skipped instead of raising. Only
    conversion errors are caught, so e.g. KeyboardInterrupt still propagates.
    """
    cols = PEAK_COLS if peak_cols is None else peak_cols
    for c in cols:
        if c not in row:
            continue
        try:
            v = row[c]
            if pd.isna(v) or v == "":
                continue
            if float(v) <= max_hours:
                return True
        except (TypeError, ValueError):
            continue  # unparseable cell -> treat as "no peak here"
    return False

def compute_pred_correct_and_confidence(m: pd.Series, day=None):
    """Evaluate correctness and confidence of one model prediction.

    Args:
        m: A summary row with ``prob_class1``, ``pred_label`` and ``true_label``
           fields (string or numeric). Missing, NaN or unparseable values are
           treated as unknown.
        day: 1, 3 or 5 selects the matching CONF_LOW_DAY*/CONF_HIGH_DAY*
           thresholds; any other value uses DEFAULT_CONF_LOW/DEFAULT_CONF_HIGH.

    Returns:
        tuple ``(pred_correct, high_confident, prob_class1, pred_label, true_label)``
        where prob_class1 is NaN and the labels are None when unknown.
        A class-0 prediction is confident when prob < low threshold, a class-1
        prediction when prob > high threshold; other labels are never confident.
    """
    def _parse(key, convert):
        # Only conversion failures are swallowed (the previous bare `except:`
        # also hid KeyboardInterrupt/SystemExit).
        raw = m.get(key)
        try:
            if not pd.notna(raw):
                return None
            return convert(raw)
        except (TypeError, ValueError, OverflowError):
            return None

    def _to_label(value):
        # Labels may be serialised as "1.0"
        return int(float(value))

    prob = _parse("prob_class1", float)
    if prob is None:
        prob = np.nan
    pred_label = _parse("pred_label", _to_label)
    true_label = _parse("true_label", _to_label)

    pred_correct = (pred_label is not None and true_label is not None and pred_label == true_label)

    high_confident = False
    if pred_label is not None and not np.isnan(prob):
        per_day = {
            1: (CONF_LOW_DAY1, CONF_HIGH_DAY1),
            3: (CONF_LOW_DAY3, CONF_HIGH_DAY3),
            5: (CONF_LOW_DAY5, CONF_HIGH_DAY5),
        }
        if day in [1, 3, 5]:
            c_low, c_high = per_day[day]
        else:
            c_low, c_high = DEFAULT_CONF_LOW, DEFAULT_CONF_HIGH

        if pred_label == 0:
            high_confident = (prob < c_low)
        elif pred_label == 1:
            high_confident = (prob > c_high)

    return pred_correct, high_confident, prob, pred_label, true_label

def read_summary_and_flag(file_path: str, max_hours: float, day=None) -> pd.DataFrame:
    """Load a per-day summary CSV (all columns as str) and add Step-1 flags.

    Added columns:
      - ``peak_within_<max_hours>h``: result of has_peak_within_first_hours
      - ``pred_correct``, ``high_confident``, ``prob_class1_f``,
        ``pred_label_i``, ``true_label_i``: the 5-tuple returned by
        compute_pred_correct_and_confidence, in that order
      - ``meets_criteria``: peak AND correct AND confident

    A header-only CSV yields an empty frame that still carries all flag
    columns. ``DataFrame.apply(..., result_type='expand')`` returns a plain
    copy for empty frames, so indexing its columns 0..4 would raise KeyError.
    That is why the rows are evaluated explicitly here.
    """
    df = pd.read_csv(file_path, dtype=str)
    peak_col = f"peak_within_{max_hours}h"

    peak_flags: List[bool] = []
    evaluations = []
    for _, row in df.iterrows():
        peak_flags.append(has_peak_within_first_hours(row, max_hours))
        evaluations.append(compute_pred_correct_and_confidence(row, day=day))

    df[peak_col] = peak_flags

    out_cols = ["pred_correct", "high_confident", "prob_class1_f", "pred_label_i", "true_label_i"]
    # Transpose the list of 5-tuples into 5 columns (empty columns if no rows).
    columns = list(zip(*evaluations)) if evaluations else [()] * len(out_cols)
    for name, values in zip(out_cols, columns):
        df[name] = list(values)

    # Combined criteria; astype keeps a boolean mask even for empty frames.
    df["meets_criteria"] = (df[peak_col] & df["pred_correct"] & df["high_confident"]).astype(bool)
    return df

def _dishes_meeting_min_days(day_sets: Dict[int, Set[str]], n_required: int) -> List[str]:
    """Return, sorted, the dishes present in at least *n_required* per-day sets."""
    from collections import Counter

    # Counting memberships is equivalent to the union of the intersections over
    # every n_required-sized combination of days.
    counts = Counter(dish for dishes in day_sets.values() for dish in dishes)
    return sorted(dish for dish, n_days in counts.items() if n_days >= n_required)


def _build_analysis_table(available_days: List[int],
                          day_dfs: Dict[int, pd.DataFrame],
                          selected: Set[str]) -> pd.DataFrame:
    """Build the per-dish debug table: one row per dish seen on any available day.

    Per day: ``day<d>_meets``, ``day<d>_details`` (peak/correct/confident flags,
    or "Missing") and ``day<d>_true``; then ``n_days_met`` and
    ``selected_final``, which is True exactly for the dishes in *selected*.
    """
    peak_col = f"peak_within_{MAX_HOURS}h"
    union_dishes = set()
    for d in available_days:
        union_dishes |= set(day_dfs[d]["dish_well"].dropna().unique())

    rows = []
    for dish in sorted(union_dishes):
        row = {"dish_well": dish}
        n_days_met = 0
        for d in available_days:
            df = day_dfs[d]
            matched = df[df["dish_well"] == dish]
            if matched.empty:
                row[f"day{d}_meets"] = False
                row[f"day{d}_details"] = "Missing"
                continue
            m = matched.iloc[0]
            meets = bool(m["meets_criteria"])
            n_days_met += int(meets)
            row[f"day{d}_meets"] = meets
            # Detailed debug string
            row[f"day{d}_details"] = (f"Pk:{m[peak_col]}|"
                                      f"Corr:{m['pred_correct']}|"
                                      f"Conf:{m['high_confident']}")
            row[f"day{d}_true"] = m['true_label_i']
        row["n_days_met"] = n_days_met
        row["selected_final"] = dish in selected
        rows.append(row)
    return pd.DataFrame(rows)


def run_step1_selection(n_required: int = 2) -> Dict[str, List[str]]:
    """
    Select, per model, the dishes that meet the Step-1 criteria on enough days.

    A dish meets the criteria on a day when that day's prediction is correct,
    confident (see compute_pred_correct_and_confidence) and a peak falls within
    MAX_HOURS. It is selected when this holds on at least *n_required* of the
    days whose summary CSV was found. The default is 2, as in the original
    notebook.

    Writes to OUTPUT_DIR_ANALYSIS:
      - ``<model>_analysis_summary.csv``: per-dish, per-day breakdown whose
        ``selected_final`` column matches the actual selection
      - ``<model>_selected_list.csv``: the selected dishes
      - ``both_models_selected.csv``: dishes selected for every model in MODELS
    Models with fewer than *n_required* available days get an empty selection
    and no per-model CSVs.

    Returns:
        {model: sorted list of selected dish_well ids}

    Raises:
        ValueError: if n_required < 1.
    """
    if n_required < 1:
        raise ValueError(f"n_required must be >= 1, got {n_required}")

    print("\n" + "="*50)
    print("STEP 1: ANALYSIS & SELECTION")
    print("="*50)

    os.makedirs(OUTPUT_DIR_ANALYSIS, exist_ok=True)

    model_day_sets: Dict[str, Dict[int, Set[str]]] = {m: {} for m in MODELS}
    model_day_dfs: Dict[str, Dict[int, pd.DataFrame]] = {m: {} for m in MODELS}

    # 1. Read and flag every available (model, day) summary
    for model in MODELS:
        for d in DAYS_TO_RUN:
            csvp = find_summary_csv(d, model, SUMMARY_BASE)
            if csvp is None:
                print(f"  [Warn] No summary for {model} day {d}")
                continue
            df = read_summary_and_flag(csvp, MAX_HOURS, day=d)
            model_day_dfs[model][d] = df

            # Set of dishes meeting criteria for this specific day
            valid_dishes = set(df.loc[df["meets_criteria"], "dish_well"].dropna().unique().tolist())
            model_day_sets[model][d] = valid_dishes
            print(f"  [Info] {model} day {d}: {len(df)} rows, {len(valid_dishes)} dishes meet criteria")

    # 2. Per model: keep dishes meeting the criteria on >= n_required days
    model_selected_final: Dict[str, List[str]] = {}
    for model in MODELS:
        available_days = sorted(model_day_sets[model].keys())
        if not available_days:
            model_selected_final[model] = []
            continue
        if len(available_days) < n_required:
            print(f"  [Warn] {model} has only {len(available_days)} available days, less than required {n_required}. Skipping.")
            model_selected_final[model] = []
            continue

        print(f"  [Info] {model} has {len(available_days)} available days, requiring intersection of {n_required} days.")
        selected = _dishes_meeting_min_days(model_day_sets[model], n_required)
        model_selected_final[model] = selected

        # 3. Summary table, consistent with the selection made above
        table = _build_analysis_table(available_days, model_day_dfs[model], set(selected))
        out_name = f"{model}_analysis_summary.csv"
        table.to_csv(os.path.join(OUTPUT_DIR_ANALYSIS, out_name), index=False)

        # Save simple list
        pd.DataFrame(selected, columns=["dish_well"]).to_csv(
            os.path.join(OUTPUT_DIR_ANALYSIS, f"{model}_selected_list.csv"), index=False
        )
        print(f"  > {model}: Selected {len(selected)} dishes. Saved analysis to {out_name}")

    # 4. Dishes selected for every model (works for any number of models)
    per_model_sets = [set(model_selected_final.get(m, [])) for m in MODELS]
    both = sorted(set.intersection(*per_model_sets)) if per_model_sets else []
    pd.DataFrame(both, columns=["dish_well"]).to_csv(
        os.path.join(OUTPUT_DIR_ANALYSIS, "both_models_selected.csv"), index=False
    )
    print(f"  > Both Models Intersection: {len(both)} dishes.")

    return model_selected_final


## STEP 1 --> SELECTION

In [84]:
# 1. Run Step 1 (analysis & selection). The result maps each model name to
#    its sorted list of selected dish_well ids and is the input of Step 2.
selected_map = run_step1_selection()


STEP 1: ANALYSIS & SELECTION
  [Info] ConvTran day 1: 762 rows, 36 dishes meet criteria
  [Info] ConvTran day 3: 762 rows, 175 dishes meet criteria
  [Info] ConvTran day 5: 762 rows, 118 dishes meet criteria
  [Info] LSTMFCN day 1: 762 rows, 6 dishes meet criteria
  [Info] LSTMFCN day 3: 762 rows, 182 dishes meet criteria
  [Info] LSTMFCN day 5: 762 rows, 143 dishes meet criteria
  [Info] ConvTran has 3 available days, requiring intersection of 2 days.
  > ConvTran: Selected 59 dishes. Saved analysis to ConvTran_analysis_summary.csv
  [Info] LSTMFCN has 3 available days, requiring intersection of 2 days.
  > LSTMFCN: Selected 49 dishes. Saved analysis to LSTMFCN_analysis_summary.csv
  > Both Models Intersection: 9 dishes.


## Step 2 helpers

In [85]:
# =============================================================================
#                        STEP 2 HELPERS: PROCESSING
# =============================================================================
import re

def import_run_single_gradcam(script_path: str):
    """Load *script_path* as module ``test_single_gradcam`` and return its
    ``run_single_video`` callable.

    The script's directory is added to sys.path first so that its sibling
    imports (e.g. _modelAdapter) resolve.

    Raises:
        ImportError: if no import spec/loader can be built for the path (e.g.
            a non-.py suffix) or the module has no ``run_single_video``.
        Any exception raised while executing the script (e.g.
        FileNotFoundError for a missing file) propagates unchanged.
    """
    module_name = "test_single_gradcam"

    script_dir = os.path.dirname(script_path)
    if script_dir not in sys.path:
        sys.path.insert(0, script_dir)

    spec = importlib.util.spec_from_file_location(module_name, script_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot build an import spec for {script_path}")
    mod = importlib.util.module_from_spec(spec)

    # Register before executing, as in the importlib "import a source file
    # directly" recipe, so code that looks the module up in sys.modules while
    # it runs (e.g. dataclasses) works. Undo the registration on failure.
    previous = sys.modules.get(module_name)
    sys.modules[module_name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise

    if hasattr(mod, "run_single_video"):
        return mod.run_single_video
    raise ImportError("run_single_video not found in " + script_path)

# Import the Grad-CAM runner right away. If that fails, keep going with
# run_single_video_func = None so that Step 2 skips only the Grad-CAM stage.
try:
    _loaded_runner = import_run_single_gradcam(GRADCAM_SCRIPT)
except Exception:
    run_single_video_func = None
    print("WARNING: Could not import Grad-CAM script. Processing will skip Grad-CAM step.")
    traceback.print_exc()
else:
    run_single_video_func = _loaded_runner

def natural_keys(text):
    """
    Sort key for human ("natural") order: alist.sort(key=natural_keys) puts
    "Image_2.jpg" before "Image_10.jpg".

    ``re.split`` with a capturing group always places the captured digit runs
    at odd indices, so the key alternates str, int, str, ... for every input.
    Keys of different strings therefore stay comparable. Using the index
    instead of ``str.isdigit`` also avoids int() failing on characters such
    as '²', which isdigit() accepts but \\d (and int) do not.
    """
    parts = re.split(r'(\d+)', text)
    return [int(part) if i % 2 else part for i, part in enumerate(parts)]

def find_frames(dish_dir):
    """Return the image files directly inside *dish_dir* whose extension is in
    FRAME_EXTS, in natural order (frame_2 before frame_10)."""
    found = [
        path
        for ext in FRAME_EXTS
        for path in glob.glob(os.path.join(dish_dir, f"*{ext}"))
    ]
    return sorted(found, key=natural_keys)

def deduce_true_label(dish_name):
    """Infer the ground-truth label from where *dish_name* lives in the dataset.

    Returns 1 if ORIGINAL_DATA_ROOT/blasto/<dish_name> is a directory, else 0
    if ORIGINAL_DATA_ROOT/no_blasto/<dish_name> is, else None. "blasto" is
    checked first.
    """
    for class_dir, label in (("blasto", 1), ("no_blasto", 0)):
        if os.path.isdir(os.path.join(ORIGINAL_DATA_ROOT, class_dir, dish_name)):
            return label
    return None

def copy_frames(src_dir, dst_dir):
    """Copy every frame of *src_dir* (as found by find_frames) into *dst_dir*.

    Returns:
        False when *src_dir* holds no frames, in which case *dst_dir* is NOT
        created. True after all frames were copied (metadata preserved via
        shutil.copy2).

    The destination is created only once frames are known to exist.
    run_step2_processing skips copying when the plain_video folder already
    exists, so an empty folder left by a failed copy would otherwise block
    every later retry.
    """
    frames = find_frames(src_dir)
    if not frames:
        return False
    os.makedirs(dst_dir, exist_ok=True)
    for f in frames:
        shutil.copy2(f, os.path.join(dst_dir, os.path.basename(f)))
    return True

def compute_optical_flow(frames_dir, out_dir, *, step=16, scale=3,
                         color=(255, 255, 255), min_magnitude=0.8):
    """
    Draw Farneback dense optical flow between consecutive frames as arrows.

    For each frame after the first in *frames_dir* (natural order, see
    find_frames), flow from the previous readable frame is computed with
    FARNEBACK_PARAMS and sampled on a regular grid. Vectors longer than
    *min_magnitude* pixels are drawn on a copy of the frame, and the result is
    saved as ``out_dir/flow_<frame name>``.

    Args:
        frames_dir: Folder containing the source frames.
        out_dir: Output folder (created if missing).
        step: Grid spacing, in pixels, between sampled vectors.
        scale: Visual length multiplier for arrows (does not affect filtering).
        color: Arrow colour as passed to cv2.arrowedLine (frames are read
            with cv2.imread, i.e. BGR).
        min_magnitude: Vectors with magnitude <= this many pixels are not drawn.

    No overlays are written when there are fewer than two frames or the first
    frame cannot be read. Unreadable later frames are skipped. A frame whose
    size differs from its predecessor restarts the chain instead of aborting
    the whole video, since Farneback requires equally sized images.
    """
    os.makedirs(out_dir, exist_ok=True)
    frames = find_frames(frames_dir)
    if len(frames) < 2:
        return

    prev = cv2.imread(frames[0], cv2.IMREAD_COLOR)
    if prev is None:
        return
    prev_gray = cv2.cvtColor(prev, cv2.COLOR_BGR2GRAY)

    fb = FARNEBACK_PARAMS
    print(f"   Computing Vector Optical Flow for {len(frames)} frames...")

    for cur_path in frames[1:]:
        cur = cv2.imread(cur_path, cv2.IMREAD_COLOR)
        if cur is None:
            continue
        cur_gray = cv2.cvtColor(cur, cv2.COLOR_BGR2GRAY)

        if cur_gray.shape != prev_gray.shape:
            print(f"     [Warn] Frame size changed at {os.path.basename(cur_path)}; skipping this pair")
            prev_gray = cur_gray
            continue

        flow = cv2.calcOpticalFlowFarneback(
            prev_gray, cur_gray, None,
            pyr_scale=float(fb['pyr_scale']),
            levels=int(fb['levels']),
            winsize=int(fb['winsize']),
            iterations=int(fb['iterations']),
            poly_n=int(fb['poly_n']),
            poly_sigma=float(fb['poly_sigma']),
            flags=int(fb['flags']),
        )

        # Sample the dense field on a grid centred in each step x step cell.
        h, w = cur_gray.shape
        ys, xs = np.mgrid[step // 2:h:step, step // 2:w:step].reshape(2, -1).astype(int)
        fx, fy = flow[ys, xs].T

        # Keep only significant motion (magnitude in pixels).
        keep = np.hypot(fx, fy) > min_magnitude
        xs, ys, fx, fy = xs[keep], ys[keep], fx[keep], fy[keep]

        # (start, end) point pairs. np.rint rounds to nearest for negative
        # coordinates too, whereas int(v + 0.5) truncates towards zero.
        segments = np.rint(
            np.stack([xs, ys, xs + fx * scale, ys + fy * scale], axis=1)
        ).astype(np.int32).reshape(-1, 2, 2)

        vis = cur.copy()
        for (x1, y1), (x2, y2) in segments:
            cv2.arrowedLine(vis, (int(x1), int(y1)), (int(x2), int(y2)),
                            color, thickness=1, tipLength=0.5)

        out_fname = os.path.join(out_dir, f"flow_{os.path.basename(cur_path)}")
        if not cv2.imwrite(out_fname, vis):
            print(f"     [Warn] Could not write {out_fname}")

        prev_gray = cur_gray

def _gradcam_done(day_dir: str) -> bool:
    """Return True if *day_dir* already contains a PNG (searched recursively)."""
    pattern = os.path.join(glob.escape(day_dir), "**", "*.png")
    return bool(glob.glob(pattern, recursive=True))


def _light_run_subset(dishes: List[str]) -> List[str]:
    """Keep, in input order, at most MAX_PER_CLASS dishes per true label.

    Dishes whose label cannot be deduced are dropped.
    """
    taken = {0: 0, 1: 0}
    subset = []
    for d in dishes:
        lbl = deduce_true_label(d)
        if lbl in taken and taken[lbl] < MAX_PER_CLASS:
            subset.append(d)
            taken[lbl] += 1
    return subset


def _process_dish(model_name: str, dish: str) -> None:
    """Copy frames, draw optical flow and run Grad-CAM for one dish.

    Stages whose output already exists are skipped, so a rerun only redoes
    missing or failed work.
    """
    tlabel = deduce_true_label(dish)
    if tlabel is None:
        print(f"   [Skip] Could not find source for {dish}")
        return

    label_str = "blasto" if tlabel == 1 else "no_blasto"
    dest_cat_folder = DST_CORRECT_BLASTO if tlabel == 1 else DST_CORRECT_NOBLASTO

    # Layout: <dest_cat_folder>/<model>/<dish>/{plain_video, overlay_farneback, gradcam_maps}
    dish_root = os.path.join(dest_cat_folder, model_name, dish)
    plain_video_dir = os.path.join(dish_root, "plain_video")
    overlay_dir = os.path.join(dish_root, "overlay_farneback")
    gradcam_base = os.path.join(dish_root, "gradcam_maps")

    # A. Copy frames, also when an earlier run left the folder empty.
    src_path = os.path.join(ORIGINAL_DATA_ROOT, label_str, dish)
    if not find_frames(plain_video_dir):
        print(f"   [{dish}] Copying frames...")
        if not copy_frames(src_path, plain_video_dir):
            print(f"     Failed to copy frames for {dish}")
            return

    # B. Optical flow. On failure, drop the partial output so the next run retries.
    if not os.path.exists(overlay_dir):
        print(f"   [{dish}] Computing Optical Flow...")
        try:
            compute_optical_flow(plain_video_dir, overlay_dir)
        except Exception as e:
            print(f"     Flow Error: {e}")
            shutil.rmtree(overlay_dir, ignore_errors=True)

    # C. Grad-CAM
    if not run_single_video_func:
        print(f"   [{dish}] Skipping Grad-CAM (function import failed).")
        return

    print(f"   [{dish}] Running Grad-CAM for days {DAYS_TO_RUN}...")
    for day in DAYS_TO_RUN:
        # With use_model_in_result_path=False, results are written under
        # gradcam_maps/day_<d>/ (the model is already part of dish_root), not
        # under day_<d>/<model>/. Search that folder recursively.
        day_dir = os.path.join(gradcam_base, f"day_{day}")
        if _gradcam_done(day_dir):
            continue

        try:
            run_single_video_func(
                day=day,
                specific_video=dish,
                output_base=gradcam_base,
                models_to_run=[model_name],
                use_model_in_result_path=False,
                create_dir_with_data_name=False
            )
        except Exception as e:
            print(f"     GradCAM Error (Day {day}): {e}")


def run_step2_processing(selection_map: Dict[str, List[str]]):
    """
    Process the Step-1 selection: for every (model, dish) pair, copy the frames,
    draw Farneback optical-flow overlays and run Grad-CAM for DAYS_TO_RUN.

    Args:
        selection_map: {model_name: [dish_well, ...]} as returned by
            run_step1_selection.

    When LIGHT_RUN is True, at most MAX_PER_CLASS dishes per true label are
    processed for each model.
    """
    print("\n" + "="*50)
    print("STEP 2: PROCESSING (Copy, GradCAM, Optical Flow)")
    print("="*50)

    for model_name, dishes in selection_map.items():
        print(f"\n--- Processing Model: {model_name} ({len(dishes)} candidates) ---")

        if LIGHT_RUN:
            dishes_to_proc = _light_run_subset(dishes)
            print(f"   [Light Run] Reduced to {len(dishes_to_proc)} videos.")
        else:
            dishes_to_proc = dishes

        for dish in dishes_to_proc:
            _process_dish(model_name, dish)


## STEP 2 --> PROCESSING

In [86]:
# 2. Run Step 2 (copy frames, optical flow, Grad-CAM) on the Step-1 selection
run_step2_processing(selected_map)

print(f"\nDONE. Analysis saved to: {OUTPUT_DIR_ANALYSIS}")
print(f"Processed assets saved to: {DEST_ROOT}")


STEP 2: PROCESSING (Copy, GradCAM, Optical Flow)

--- Processing Model: ConvTran (59 candidates) ---
   [D2013.06.22_S0792_I141_3] Copying frames...
   [D2013.06.22_S0792_I141_3] Computing Optical Flow...
   Computing Vector Optical Flow for 571 frames...
   [D2013.06.22_S0792_I141_3] Running Grad-CAM for days [1, 3, 5]...

=== Single video Grad-CAM: day 1 — D2013.06.22_S0792_I141_3 ===
Loading test data from: /home/phd2/Scrivania/CorsoRepo/cellPIV/datasets/Farneback/subsets/Normalized_sum_mean_mag_1Days_test.csv

--- Model: ConvTran ---
Wrote named modules to: /home/phd2/Documenti/embryo/embryo_to_send_gradcam/correct_no_blasto/ConvTran/D2013.06.22_S0792_I141_3/gradcam_maps/day_1/named_modules.txt
Detected Conv2d layer 'embed_layer.0' with in_channels==1; using 2D sample (C,T).
use_conv2d=False for model ConvTran
Using target layer: embed_layer.3

                                      Executing SignalGrad-CAM                                      

AVAILABLE EXPLAINERS:
 - Explainer i

# Show some examples