In [1]:
import wr_score_util as wu 
import datetime
import numpy as np
import cv2
from tqdm import tqdm
from google.genai import types
from google import genai
import os
import pandas as pd
import re
from time import sleep
import imageio.v3 as iio
import typing 
from pydantic import BaseModel, Field
from typing import List, Literal
import json
import time

In [2]:
def adjust_brightness_contrast(frames: np.ndarray, alpha: float = 1.2, beta: int = 30) -> np.ndarray:
    """
    Apply a linear brightness/contrast transform to every pixel of a frame stack.

    Each value is mapped to ``value * alpha + beta``, limited to the range
    0..255 and returned as ``uint8``.

    Args:
        frames (np.ndarray): Input array of shape (frames, height, width, channels).
        alpha (float): Contrast gain (1.0 leaves contrast unchanged).
        beta (int): Brightness offset (0 leaves brightness unchanged).

    Returns:
        np.ndarray: A new ``uint8`` array with the same shape as ``frames``.
    """
    # Work in float32 so the gain/offset cannot wrap around in integer arithmetic.
    as_float = frames.astype(np.float32)
    rescaled = as_float * alpha + beta

    # Saturate instead of overflowing, then go back to 8-bit pixels.
    bounded = np.clip(rescaled, 0, 255)
    return bounded.astype(np.uint8)


In [3]:
# Schema for one entry of `TrialAnalysis.timeline`.
# NOTE(review): pydantic puts the class docstring and every Field description
# into the generated JSON schema, so editing those strings presumably changes
# what the model is asked for — confirm before rewording them.
class TimelineEvent(BaseModel):
    """Represents a single, classified event in the trial timeline."""
    # Plain string; the 'start - end' format is only requested through the
    # description text and is not validated here.
    time_range: str = Field(description="The start and end time of the event in 'start_time - end_time' format, e.g., '0.5s - 1.2s'.")
    
    # Short label; the 4-word limit is likewise only stated in the description.
    event_classification: str = Field(
        description="A brief classification of the event (maximum 4 words), e.g., 'Paw reaches for spout' or 'Tongue licks paw'."
    )
    
    # Longer free-text account of the same time span.
    event_description: str = Field(
        description="A concise, objective, and detailed description of what was visually observed during the time range."
    )

# Define the main schema for the entire trial analysis
# This class is passed as `response_schema` in the generation config, and the
# scoring loops read the parsed result through `response.parsed`.
# NOTE(review): the class docstring and Field descriptions are part of the
# generated JSON schema; treat them as prompt text, not as ordinary comments.
class TrialAnalysis(BaseModel):
    """
    Provides a structured analysis of a mouse's behavior in a water-reaching trial.
    All fields must be based on clear, indisputable visual evidence only.
    """
    # NOTE(review): this description asks about tongue-to-paw contact, whereas
    # prompt `a1` asks about tongue-to-water-drop contact — confirm which is intended.
    tongue_contact: bool = Field(description="Did the mouse's tongue make physical contact with its paw after the paw touched the water drop?")
    water_drop_stable: bool = Field(description="Did the water drop remain fully attached to the spout before any interaction?")
    water_spilled: bool = Field(description="Did the mouse drop or splash any part of the water drop during retrieval?")
    # Integer constrained to 0..100 by ge/le. The scoring loops do not copy this
    # field into its own column; it is only kept inside `raw_answers`.
    percentage_consumed: int = Field(description="What percentage of the water drop did the mouse drink? Must be 0 if no contact was made.", ge=0, le=100)
    # Restricted to exactly these three labels.
    outcome_classification: Literal["Successful", "Partial Success", "Failed"] = Field(description="Classification of the trial's outcome.")
    justification: str = Field(description="Brief, evidence-based justification for the outcome classification.")
    # Like percentage_consumed, the timeline is only preserved inside `raw_answers`.
    timeline: List[TimelineEvent] = Field(description="A chronological list of key, observable events.")

In [4]:
def resize_frames(frames, scale=0.5):
    """
    Resize a stack of frames (N, H, W, C) by a given scale factor.

    Args:
        frames (np.ndarray): 4-D array laid out as (N, H, W, C).
        scale (float): Multiplier applied to height and width; the new sizes
            are truncated to integers.

    Returns:
        np.ndarray: Array of shape (N, int(H * scale), int(W * scale), C) with
        the same dtype as ``frames``.

    Raises:
        ValueError: If ``frames`` is not 4-D, or if ``scale`` would produce a
            frame with zero (or negative) height or width.
    """
    N, H, W, C = frames.shape
    new_H, new_W = int(H * scale), int(W * scale)
    # Fail early with a readable message instead of an opaque OpenCV error.
    if new_H < 1 or new_W < 1:
        raise ValueError(
            f"scale={scale} shrinks {H}x{W} frames to {new_H}x{new_W}; "
            "the output must be at least 1x1"
        )
    resized = np.empty((N, new_H, new_W, C), dtype=frames.dtype)
    for i in range(N):
        frame = cv2.resize(frames[i], (new_W, new_H), interpolation=cv2.INTER_AREA)
        # cv2.resize returns a 2-D array for single-channel input (the channel
        # axis is dropped), so restore the (H, W, C) layout before storing.
        resized[i] = frame.reshape(new_H, new_W, C)
    return resized

def read_event(path):
    """
    Load ``eventlog.txt`` from a session folder and add empty result columns.

    Args:
        path: Folder that contains ``eventlog.txt`` (read with ``pd.read_csv``).

    Returns:
        pd.DataFrame: The event log plus six all-missing ``object`` columns
        that the scoring loop fills in later.
    """
    log_file = os.path.join(path, 'eventlog.txt')
    event_log = pd.read_csv(log_file)

    # Pre-create the result columns as 'object' so that strings, booleans or
    # JSON can be written into them later without dtype warnings.
    result_columns = (
        'raw_answers',
        'tongue_contact',
        'water_drop_stable',
        'water_spilled',
        'outcome_classification',
        'justification',
    )
    for column in result_columns:
        event_log[column] = pd.Series(dtype='object')

    return event_log


In [None]:
# Sampling settings forwarded to types.GenerateContentConfig below.
temperature = 0
top_p = 0.9

# Crop windows and clip boundaries handed to wu.extract_vid2score by the
# scoring loops. NOTE(review): presumably (min, max) pixel ranges and
# start/end in seconds — confirm against wr_score_util.
ycrop = (0,740)
xcrop = (210,930)
# NOTE(review): keep in sync with the fps value passed to wu.extract_vid2score.
fps = 54
start = 1.25
end = 4
# Scratch video written per trial, uploaded to the API and removed afterwards.
temp_path = 'temp.mp4'
# NOTE(review): not referenced by any cell visible in this notebook.
def_args= dict(min_dist=2, max_dist=4, resize_factor=1, num_workers=12)
#sys_instruc = ("You are a research assistant specializing in behavioral neuroscience. "
#               "Your job is to carefully and objectively score the outcome of a mouse doing water reaching trials. "
#               "Pay absolute attention to the water drop. Your entire analysis must be based on whether the drop is successfully retrieved and drank by the mouse."
#               "Focus exclusively on following: spout, water drop, mouse paws, and mouse mouth. Ignore all other background movement"
#               "In each video, there is only one drop given."
#               "Provide yes or no only as answers if requested."
#               "Be very brief and concise with you response")
# System instruction sent with every request.
# Adjacent string literals are concatenated with no separator, so every piece
# ends with an explicit space; without it sentences fuse together
# (e.g. "trial.Your analysis"), which is how this prompt was previously sent.
# NOTE(review): fixing the spacing/typos changes the prompt text, so results
# may not be bit-for-bit comparable with CSVs produced by the old prompt.
# NOTE(review): the parenthetical after "lick the spout directly." reads like
# it belongs to the right-paw sentence above it; left in place on purpose.
sys_instruc = (
    "You are a research assistant tasked with objectively scoring a mouse's behavior in a water-reaching trial. "
    "Your analysis must be based only on clear, indisputable visual evidence. Pay extremely close attention to the following elements: "
    "1. The Water Drop: Note its formation, its position on the spout, and any change to its shape or location. "
    "2. The Mouse's Tongue (very obvious and similar to that of a dog's): Look for the moment the tongue is extended or if the tongue physically touches the water. "
    "3. The Mouse's Paws: Observe if the paws lift off the surface and move towards the spout. Mainly the mouse's right paw (the mouse's perspective). "
    "Most importantly, the mouse is head fixed meaning it cannot move its head and its tongue cannot lick the spout directly. (the mouse's perspective and the one closest to the metal spout when the paws are at rest) "
    "The water drop can only be consumed by its right paw reaching the water drop on the spout and bringing it close to its mouth for it to lick. "
    "If the mouse's right paw did not come near to the waterdrop or moved at all, it couldn't have drank any of the water drop. "
    "The water drop is always delivered within the first 30-40 seconds of the video. In some cases, the water drop might not stick to the spout for the mouse to touch. "
    "Do not infer success or failure; only report the physical interactions you see. In each video, only one drop is delivered. "
    "If no change or movement (right paw) is detected for a specific feature (e.g., paw, tongue), you must state 'No change was observed.' "
    "Do not repeat the question in your responses. "
    "Do not describe an event unless there is undeniable visual evidence."
)


# User prompts. The scoring loops send a1..a5 together with the uploaded video
# as the `contents` of a single generate_content call.
a1 = "Did the mouse's tongue make physical contact with the water drop? (Answer: Yes or No)"
a2 = "Did the water drop remain fully attached to the spout before the mouse touched it? (Answer: Yes or No)"
a3 = 'Did the mouse drop or splashed some of the water when retrieving the water to drink ? (Yes or No)'
a4 = "What percentage of the water drop did the mouse drink? (If no physical contact was made between the tongue and the drop, the answer must be 0%)."
# NOTE(review): the thresholds below do not cover every value — 40%..50% is
# assigned to no class, and exactly 70% falls under "Partial Success" only.
# Decide the intended boundaries before relying on outcome_classification.
a5 = """
Based on the provided video of a head-fixed mouse performing a water-reaching trial, classify the outcome as one of the following:
Successful: The mouse retrieved most of the water (more than 70%).
Partial Success: The mouse retrieved some of the water (between 50% and 70%).
Failed: The mouse retrieved little or none of the water (less than 40%).
Clearly state your classification and briefly justify your choice using only observable evidence from the video.
"""
# Flatten the triple-quoted block into one line (each newline becomes a space).
a5 = a5.replace('\n', ' ')
# NOTE(review): a6 is defined but not included in the `contents` list of the
# scoring loops in this notebook.
a6 = "Provide a timeline of key events. Use the following specific format and describe only what you see ( Time start - Time end: Event)."



# SECURITY: API keys must never be stored in the notebook (code or comments).
# The keys that used to be written here have been exposed through the file and
# should be revoked/rotated. Export the key in the environment before starting
# the kernel, e.g. `export GEMINI_API_KEY=...`.
api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY") or None
if api_key is None:
    # Do not abort here so the rest of the configuration can still be loaded;
    # creating the client / calling the API will fail until a key is provided.
    print("WARNING: no API key found; set GEMINI_API_KEY (or GOOGLE_API_KEY) in the environment.")
MODEL_ID = "gemini-2.5-pro"

# Generation settings shared by every request.
# (An earlier plain-text `config` used to be built first and was immediately
# overwritten by this one; only the effective configuration is kept.)
config = types.GenerateContentConfig(
    temperature=temperature,
    top_p=top_p,
    # Structured output: ask for JSON that follows the TrialAnalysis schema so
    # the scoring loops can read it through `response.parsed`.
    response_mime_type="application/json",
    response_schema=TrialAnalysis,
    system_instruction=sys_instruc,
)

# Folder that receives one `<session basename>.csv` per processed session.
save_path = '/mnt/team/TM_Lab/Tony/wr_new/data_used/tta_gcamp8s/gemini_predictions_full_sys'


In [6]:
def paths_to_dict_unique(file_path):
    """
    Read a text file of paths (one per line) into a ``{basename: path}`` dict.

    Blank lines are skipped and surrounding whitespace is stripped. Trailing
    path separators are removed before the basename is taken: otherwise
    ``os.path.basename('/a/b/')`` is ``''`` and every such entry would collide
    on the empty key (and later ``os.path.basename(path)`` calls on the stored
    value would be empty as well).

    Args:
        file_path: Path of the text file to read.

    Returns:
        dict: Basename -> path (without trailing separator). If duplicate
        basenames exist, the last one is kept.
    """
    separators = os.sep + (os.altsep or '')
    result = {}
    with open(file_path, 'r') as f:
        for line in f:
            full_path = line.strip()
            if not full_path:
                continue
            # Keep the original text if stripping would leave nothing (e.g. "/").
            full_path = full_path.rstrip(separators) or full_path
            result[os.path.basename(full_path)] = full_path
    return result
# Listing file with one path per line; being a relative path, it is resolved
# against the kernel's current working directory.
file_paths = 'vid_paths.txt'
# Basename -> full path lookup that the folder-selection cells filter from.
path_dict = paths_to_dict_unique(file_paths)

In [None]:
# NOTE(review): this cell repeats the definition from the previous cell; the
# later definition is the one in effect. Consider deleting one of the two.
def paths_to_dict_unique(file_path):
    """
    Read a text file of paths (one per line) into a ``{basename: path}`` dict.

    Blank lines are skipped and surrounding whitespace is stripped. Trailing
    path separators are removed before the basename is taken: otherwise
    ``os.path.basename('/a/b/')`` is ``''`` and every such entry would collide
    on the empty key (and later ``os.path.basename(path)`` calls on the stored
    value would be empty as well).

    Args:
        file_path: Path of the text file to read.

    Returns:
        dict: Basename -> path (without trailing separator). If duplicate
        basenames exist, the last one is kept.
    """
    separators = os.sep + (os.altsep or '')
    result = {}
    with open(file_path, 'r') as f:
        for line in f:
            full_path = line.strip()
            if not full_path:
                continue
            # Keep the original text if stripping would leave nothing (e.g. "/").
            full_path = full_path.rstrip(separators) or full_path
            result[os.path.basename(full_path)] = full_path
    return result
# Same two statements as in the previous cell: re-reads the listing file
# (relative to the current working directory) and rebuilds the lookup.
file_paths = 'vid_paths.txt'
path_dict = paths_to_dict_unique(file_paths)

In [7]:
# API client used by the scoring loops for file upload/status/delete and for
# generate_content; authenticates with `api_key` from the configuration cell.
client = genai.Client(api_key=api_key)

In [8]:
# Tags that select which sessions from `path_dict` are processed by the next cell.
sys_config1 = ['FU', 'FS', 'FJ']

#folders = ['/mnt/team/TM_Lab/Tony/wr_new/data_used/thy1_gcamp/FU/R2/FU_R2_2024-06-29_1','/mnt/team/TM_Lab/Tony/wr_new/data_used/thy1_gcamp/FU/R2/FU_R2_2024-07-18_1',]
#folders = ['/mnt/team/TM_Lab/Tony/wr_new/data_used/thy1_gcamp/FJ/R3/FJ_R3_2024-06-29_1','/mnt/team/TM_Lab/Tony/wr_new/data_used/thy1_gcamp/FU/R2/FU_R2_2024-06-29_1',
#           '/mnt/team/TM_Lab/Tony/wr_new/data_used/thy1_gcamp/FS/L2/FS_L2_2024-07-31_1','/mnt/team/TM_Lab/Tony/wr_new/data_used/thy1_gcamp/FJ/L2/FJ_L2_2024-07-29_1',
#           '/mnt/team/TM_Lab/Tony/wr_new/data_used/thy1_gcamp/FJ/L3/FJ_L3_2024-06-29_1','/mnt/team/TM_Lab/Tony/wr_new/data_used/thy1_gcamp/FS/L2/FS_L2_2024-06-28_1']

# Keep the full path of every entry whose basename contains one of the tags.
# NOTE(review): this is a substring test anywhere in the basename, not a
# prefix match — confirm no unrelated session name contains these letters.
folders = [
    folder_path
    for folder_name, folder_path in path_dict.items()
    if any(tag in folder_name for tag in sys_config1)
]

In [9]:
# Score every unscored trial (rows whose 'outcome' is missing) of each selected
# session with the model and write one CSV per session into `save_path`.
#
# For each trial: extract + crop the clip, write it to `temp_path`, upload it,
# wait until the service has processed it, ask a1..a5 with the structured
# `config`, then copy the parsed answers into the event log.
#
# NOTE(review): the later loop for the 'K'/'AZ' sessions is a copy of this one
# that differs only in `cam_use` and an extra resize step; consider moving the
# shared body into one function with those two parameters.

# The CSV is written at the very end of a long run; make sure its folder exists.
os.makedirs(save_path, exist_ok=True)

for path in folders:
    event_log = read_event(path)
    total = event_log['outcome'].isna().sum()
    basename = os.path.basename(path)
    save_name = os.path.join(save_path, f"{basename}.csv")
    pbar = tqdm(total=total, desc=f"Processing trials in {basename}", leave=True, position=0)
    try:
        for idx, row in event_log.iterrows():
            # Only trials without a manual outcome are sent to the model.
            if not pd.isna(row['outcome']):
                continue

            video_file = None  # set once the upload succeeded, used for cleanup
            try:
                trial_path = os.path.join(path, f"Trial_{int(row['trial'])}")
                frames = wu.extract_vid2score(trial_path, start=start, end=end,
                                              xcrop=xcrop, ycrop=ycrop, fps=fps,
                                              cam_use='brain')
                wu.write_video_opencv(frames, temp_path, fps=1)

                start_time = time.time()
                video_file = client.files.upload(file=temp_path)
                # Poll until the service has finished ingesting the video.
                while video_file.state.name == "PROCESSING":
                    sleep(1)
                    video_file = client.files.get(name=video_file.name)
                if video_file.state.name == "FAILED":
                    raise RuntimeError(f"video processing failed for {trial_path}")
                upload_time = time.time()

                response = client.models.generate_content(
                    model=f"models/{MODEL_ID}",
                    contents=[video_file, a1, a2, a3, a4, a5],
                    config=config,
                )
                api_call_time = time.time()

                # Store the raw material first so it survives a parsing failure.
                event_log.loc[idx, 'raw_answers'] = str(response.text)
                event_log.loc[idx, 'upload_time'] = upload_time - start_time
                event_log.loc[idx, 'api_time'] = api_call_time - upload_time
                event_log.loc[idx, 'full_response'] = json.dumps(response.to_json_dict())

                analysis = response.parsed
                if analysis is None:
                    # The reply could not be parsed into TrialAnalysis.
                    raise ValueError("response could not be parsed into TrialAnalysis")
                answers = json.loads(analysis.model_dump_json())
                for column in ('tongue_contact', 'water_drop_stable', 'water_spilled',
                               'outcome_classification', 'justification'):
                    event_log.loc[idx, column] = answers[column]
            except Exception as e:
                # Best effort: report the trial and continue with the next one.
                print('wrong trial check')
                print(idx)
                print(e)
            finally:
                # Always release the remote upload and the local scratch file,
                # including when extraction, upload or the API call failed.
                if video_file is not None:
                    try:
                        client.files.delete(name=video_file.name)
                    except Exception as cleanup_error:
                        print(f"could not delete uploaded file for trial index {idx}: {cleanup_error}")
                if os.path.exists(temp_path):
                    os.remove(temp_path)
                pbar.update(1)
    except Exception as e:
        print(f"Error occured at {basename}")
        print(e)
    finally:
        pbar.close()
        # Save whatever was collected, including after an error or an interrupt.
        event_log.to_csv(save_name, index=False)
    

Processing trials in FJ_R3_2024-07-15_1:   0%|          | 0/6 [00:00<?, ?it/s]

Processing trials in FJ_R3_2024-07-15_1: 100%|██████████| 6/6 [03:53<00:00, 38.96s/it]
Processing trials in FJ_L3_2024-07-15_1:  55%|█████▌    | 27/49 [18:04<13:59, 38.14s/it]

In [None]:
# Second batch: overwrites `sys_config1` / `folders` from the earlier selection cell.
sys_config1 = ['K', 'AZ']
# NOTE(review): substring test anywhere in the basename — the single letter 'K'
# matches any session name containing a capital K, not only the 'K' cage.
folders = [
    folder_path
    for folder_name, folder_path in path_dict.items()
    if any(tag in folder_name for tag in sys_config1)
]
#folders = ['/mnt/team/TM_Lab/Tony/wr_new/data_used/tta_gcamp8s/cage_az_rig4/l3/AZ_L3_2024-11-22_1','/mnt/team/TM_Lab/Tony/wr_new/data_used/tta_gcamp8s/cage_k_rig2/r2/K_R2_2025-01-14_1',
#           '/mnt/team/TM_Lab/Tony/wr_new/data_used/tta_gcamp8s/cage_az_rig4/l3/AZ_L3_2024-12-16_1']
#folders = [f for f in folders if f"{os.path.basename(f)}.csv" not in os.listdir(save_path)]

In [None]:
# Score every unscored trial (rows whose 'outcome' is missing) of each selected
# session with the model and write one CSV per session into `save_path`.
#
# Same procedure as the earlier scoring loop, except that frames come from
# cam_use='B' and are downscaled to 70% before being written to `temp_path`.
#
# NOTE(review): the two scoring loops are copies that differ only in `cam_use`
# and the resize step; consider moving the shared body into one function.

# The CSV is written at the very end of a long run; make sure its folder exists.
os.makedirs(save_path, exist_ok=True)

for path in folders:
    event_log = read_event(path)
    total = event_log['outcome'].isna().sum()
    basename = os.path.basename(path)
    save_name = os.path.join(save_path, f"{basename}.csv")
    pbar = tqdm(total=total, desc=f"Processing trials in {basename}", leave=True, position=0)
    try:
        for idx, row in event_log.iterrows():
            # Only trials without a manual outcome are sent to the model.
            if not pd.isna(row['outcome']):
                continue

            video_file = None  # set once the upload succeeded, used for cleanup
            try:
                trial_path = os.path.join(path, f"Trial_{int(row['trial'])}")
                frames = wu.extract_vid2score(trial_path, start=start, end=end,
                                              xcrop=xcrop, ycrop=ycrop, fps=fps,
                                              cam_use='B')
                frames = resize_frames(frames, scale=0.7)
                wu.write_video_opencv(frames, temp_path, fps=1)

                start_time = time.time()
                video_file = client.files.upload(file=temp_path)
                # Poll until the service has finished ingesting the video.
                while video_file.state.name == "PROCESSING":
                    sleep(1)
                    video_file = client.files.get(name=video_file.name)
                if video_file.state.name == "FAILED":
                    raise RuntimeError(f"video processing failed for {trial_path}")
                upload_time = time.time()

                response = client.models.generate_content(
                    model=f"models/{MODEL_ID}",
                    contents=[video_file, a1, a2, a3, a4, a5],
                    config=config,
                )
                api_call_time = time.time()

                # Store the raw material first so it survives a parsing failure.
                event_log.loc[idx, 'raw_answers'] = str(response.text)
                event_log.loc[idx, 'upload_time'] = upload_time - start_time
                event_log.loc[idx, 'api_time'] = api_call_time - upload_time
                event_log.loc[idx, 'full_response'] = json.dumps(response.to_json_dict())

                analysis = response.parsed
                if analysis is None:
                    # The reply could not be parsed into TrialAnalysis.
                    raise ValueError("response could not be parsed into TrialAnalysis")
                answers = json.loads(analysis.model_dump_json())
                for column in ('tongue_contact', 'water_drop_stable', 'water_spilled',
                               'outcome_classification', 'justification'):
                    event_log.loc[idx, column] = answers[column]
            except Exception as e:
                # Best effort: report the trial and continue with the next one.
                print('wrong trial check')
                print(idx)
                print(e)
            finally:
                # Always release the remote upload and the local scratch file,
                # including when extraction, upload or the API call failed.
                if video_file is not None:
                    try:
                        client.files.delete(name=video_file.name)
                    except Exception as cleanup_error:
                        print(f"could not delete uploaded file for trial index {idx}: {cleanup_error}")
                if os.path.exists(temp_path):
                    os.remove(temp_path)
                pbar.update(1)
    except Exception as e:
        print(f"Error occured at {basename}")
        print(e)
    finally:
        pbar.close()
        # Save whatever was collected, including after an error or an interrupt.
        event_log.to_csv(save_name, index=False)
    

Processing trials in FS_L2_2024-06-28_1: 100%|██████████| 86/86 [5:03:19<00:00, 211.63s/it]


Processing trials in AZ_L3_2024-11-22_1: 100%|██████████| 50/50 [18:49<00:00, 22.58s/it]
Processing trials in K_R2_2025-01-14_1: 100%|██████████| 74/74 [27:13<00:00, 22.07s/it]
Processing trials in AZ_L3_2024-12-16_1: 100%|██████████| 26/26 [10:15<00:00, 23.19s/it]