In [None]:
# Imports and environment setup for first-frame whisker tracking.
import argparse
import os, sys
import json
import numpy as np
import time

# WhiskiWrap wraps the whisk tracing / measurement tools (used below via ww.trace_and_measure_chunk)
import WhiskiWrap as ww
from WhiskiWrap import load_whisker_data as lwd
# Add the directory above the notebook to the import path (presumably where the local
# `whiskerpad` and `wwutils` modules live). os.path.abspath('') is the kernel's current
# working directory, assumed here to be the notebook's own folder (Jupyter's default).
sys.path.append(os.path.dirname(os.path.abspath('')))
import whiskerpad as wp
# Check that whisk binaries are executable and update their permissions if necessary
from wwutils.whisk_permissions import update_permissions
update_permissions()

In [None]:
from pathlib import Path
import cv2
import matplotlib.pyplot as plt

# Input location and the video to analyse
input_dir = Path('/data')
video_name = 'sc016_0630_001_30sWhisking.mp4'
base_name = 'sc016_0630_001'
input_file = input_dir / video_name
# To list all the files in the data directory:
# print(list(input_dir.glob('*')))

# Load the video's first frame. Besides giving a frame to inspect, this verifies early
# that the video is readable before any tracking is attempted.
frame_num = 0
cap = cv2.VideoCapture(str(input_file))
try:
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
    ret, frame = cap.read()
finally:
    # Always release the capture handle, even if reading raised
    cap.release()
if not ret:
    # OpenCV does not raise on a missing/unreadable video; read() just returns ret=False, frame=None
    raise IOError(f'Could not read frame {frame_num} from {input_file}')

# Optional: display the first frame / save it as base_name_first_frame.png
# plt.imshow(frame); plt.axis('off'); plt.show()
# cv2.imwrite(str(input_dir / f'{base_name}_first_frame.png'), frame)

#### Run whisker tracking on the first frame

In [None]:
# Tracking settings
splitUp = True  # passed to wp.get_side_image; presumably splits each frame into its two face sides
output_dir = input_dir.joinpath('frame_0_whiskers')  # folder for first-frame images and whisk outputs
nproc = 40  # process count; not used by the single-frame cells shown here

In [None]:
# Split the first frame into its face-side halves
image_halves, image_side, face_side, fp = wp.get_side_image(str(input_file), splitUp)

# Show each image half side by side, titled with its face side
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
for ax_half, half, side in zip(ax, image_halves, face_side):
    ax_half.imshow(half, cmap='gray')
    ax_half.axis('off')
    ax_half.set_title(f"Face side: {side}")
plt.show()

# Save each side as tif, to be traced by whisk
output_dir = input_dir / 'frame_0_whiskers'
output_dir.mkdir(parents=True, exist_ok=True)
for half, side in zip(image_halves, face_side):
    # Lowercase the side name: the tracking and loading cells rebuild these paths from
    # `side_types`, which are lowercased from the whiskerpad JSON.
    output_file = output_dir / f'{base_name}_first_frame_{str(side).lower()}.tif'
    # cv2.imwrite reports failure through its return value instead of raising
    if not cv2.imwrite(str(output_file), half):
        raise IOError(f'Failed to write {output_file}')

In [None]:
# Load whiskerpad json file (created interactively if it does not exist yet)
whiskerpad_file = os.path.join(input_dir, f'whiskerpad_{base_name}.json')

if not os.path.exists(whiskerpad_file):
    # If whiskerpad file does not exist, create it
    print('Creating whiskerpad parameters file.')
    # Pass a plain string path, consistent with how this notebook calls wp.get_side_image
    whiskerpad = wp.Params(str(input_file), splitUp, base_name)
    # Get whiskerpad parameters (this call may also update splitUp)
    whiskerpadParams, splitUp = wp.WhiskerPad.get_whiskerpad_params(whiskerpad)
    # Save whisking parameters to json file
    # NOTE(review): assumes this writes to `whiskerpad_file` above — confirm in whiskerpad
    wp.WhiskerPad.save_whiskerpad_params(whiskerpad, whiskerpadParams)

with open(whiskerpad_file, 'r') as f:
    whiskerpad_params = json.load(f)

# Check that both whiskerpads (left/right or top/bottom) are defined
whiskerpads = whiskerpad_params.get('whiskerpads', [])
if len(whiskerpads) < 2:
    raise ValueError('Missing whiskerpad parameters in whiskerpad json file.')

# Get side types (left / right or top / bottom), lowercased for use in file names
side_types = [pad['FaceSide'].lower() for pad in whiskerpads]

In [None]:
#######################
### Run whisker tracking
########################

# For each face side, trace and measure whiskers on the saved first-frame tif via WhiskiWrap,
# keeping intermediate files (delete_when_done=False) and timing each run.
for side in side_types:
    print(f'Running whisker tracking for {side} face side video')
    tic = time.time()

    image_filename = output_dir / f'{base_name}_first_frame_{side}.tif'
    # Classification options: px2mm (presumably mm per pixel) and the expected whisker count
    classify_opts = dict(px2mm='0.04', n_whiskers='3')
    result_dict = ww.trace_and_measure_chunk(
        image_filename,
        delete_when_done=False,
        face=side,
        classify=classify_opts,
    )

    time_track = time.time() - tic
    print(f'Tracking took {time_track} seconds.')

In [None]:
# Load the traced whiskers for each side from the .whiskers files left by the tracking step.
# Each loaded object maps frame -> {whisker id: segment}; whisker_data maps side -> that object.
whisker_data = {}
for side in side_types:
    print(f'Loading whiskers for {side} face side video')
    whiskers_path = output_dir / f'{base_name}_first_frame_{side}.whiskers'
    whisker_data[side] = ww.wfile_io.Load_Whiskers(str(whiskers_path))

print(f"Whisker data for {len(whisker_data)} sides loaded.")

In [None]:
# Flatten every traced segment of each side into parallel lists of x / y pixel coordinates,
# ordered by frame and then by whisker within a frame (list position, not whisker id).
xpixels, ypixels = {}, {}
for side, whiskers in whisker_data.items():
    segments = [wseg
                for frame_whiskers in whiskers.values()
                for wseg in frame_whiskers.values()]
    xpixels[side] = [wseg.x for wseg in segments]
    ypixels[side] = [wseg.y for wseg in segments]

In [None]:
# Count the traced whisker segments found on each side
n_whiskers = dict((side, len(xpixels[side])) for side in side_types)
print(f"Number of whiskers detected: {n_whiskers}")
# print(f"Whisker IDs for each sides: {whisker_ids}")

In [None]:
# Unique whisker (segment) IDs for each side, pooled across frames; np.unique returns them sorted.
whisker_ids = {
    side: np.unique([
        wseg.id
        for frame_whiskers in whisker_data[side].values()
        for wseg in frame_whiskers.values()
    ])
    for side in side_types
}
print(f"Unique whisker IDs: {whisker_ids}")

# One viridis color per whisker ID, spaced evenly along the colormap
colors = {
    side: plt.cm.viridis(np.linspace(0, 1, num=whisker_ids[side].size))
    for side in side_types
}

In [None]:
# Plot the traced whiskers over each face-side image
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
# Remove space between subplots (negative wspace lets the panels overlap their default gap)
fig.subplots_adjust(wspace=-0.51, hspace=0)

for i, side in enumerate(side_types):
    # Panels are filled right-to-left: side 0 goes in the last axis
    panel = ax[-(i + 1)]
    # NOTE(review): assumes image_halves is ordered like side_types — confirm against wp.get_side_image
    panel.imshow(image_halves[i], cmap='gray')
    # Look segments up by their own ID. xpixels/ypixels are flat, position-ordered lists, so
    # indexing them with an ID only works when the IDs happen to be 0..n-1 in file order.
    # (With several frames, a repeated ID keeps the last frame's segment.)
    segments_by_id = {wseg.id: wseg
                      for frame_whiskers in whisker_data[side].values()
                      for wseg in frame_whiskers.values()}
    for whisker_id, color in zip(whisker_ids[side], colors[side]):
        wseg = segments_by_id[whisker_id]
        panel.plot(wseg.x, wseg.y, color=color)

    panel.axis('off')
    panel.set_title(f"Face side: {side}")
plt.show()

In [None]:
# Load whisker measurements for each side, skipping (with a message) any side whose
# measurements file is missing. Checking per side avoids depending on a `side` variable left
# over from an earlier cell. Initializing wmeas first keeps the summary print from raising NameError.
wmeas = {}
for side in side_types:
    measurement_file = output_dir / f'{base_name}_first_frame_{side}.measurements'
    if not measurement_file.exists():
        print(f'No measurements file found for {side} face side: {measurement_file}')
        continue

    print(f'Loading whisker measurements for {side} face side video')
    whisk_filename = output_dir / f'{base_name}_first_frame_{side}.whiskers'
    # NOTE(review): this reads the .whiskers path, not measurement_file. Later cells use
    # measurement fields ('label', 'length', 'score', 'follicle_x', ...), so confirm that
    # ww.read_whisker_data picks up the matching measurements from this path.
    wmeas[side] = ww.read_whisker_data(str(whisk_filename))

print(f"Whisker measurements for {len(wmeas)} sides loaded.")

In [None]:
# Inspect the measurements for the first face side. side_types may be left/right or
# top/bottom (per the whiskerpad JSON), so don't hard-code 'left'.
wmeas[side_types[0]]

In [None]:
# As a sanity check, for each whisker (on each side), compute its length from the traced
# coordinates and compare it to the 'length' reported in the measurements data.
whisker_lengths, whisker_lengths_meas = {}, {}
for side in side_types:
    whisker_lengths[side] = []
    for frame, frame_whiskers in whisker_data[side].items():
        for whisker_id, wseg in frame_whiskers.items():
            # Arc length along the traced centerline: the sum of distances between consecutive
            # points. The straight base-to-tip distance underestimates curved whiskers.
            seg_x = np.asarray(wseg.x, dtype=float)
            seg_y = np.asarray(wseg.y, dtype=float)
            whisker_length = np.sum(np.hypot(np.diff(seg_x), np.diff(seg_y)))
            whisker_lengths[side].append(whisker_length)

# Compare computed lengths to the lengths in the measurements data
for side in side_types:
    # Measured lengths, sorted by label (the same ordering used for scores and follicles).
    # NOTE(review): whisker_lengths[side] follows segment (file) order, not label order, so this
    # element-wise comparison assumes both orders coincide — consider aligning on segment id.
    # NOTE(review): assumes both lengths are in pixels — confirm the measurement units.
    sorted_indices = np.argsort(wmeas[side]['label'])
    # Store a flat array (not a one-element list) so the difference below stays 1-D
    whisker_lengths_meas[side] = np.asarray(wmeas[side]['length'])[sorted_indices]

    whisker_lengths_diff = np.asarray(whisker_lengths[side]) - whisker_lengths_meas[side]

    print(f"Mean difference in whisker lengths for {side} face side: {np.mean(whisker_lengths_diff)}")
    print(f"Max difference in whisker lengths for {side} face side: {np.max(np.abs(whisker_lengths_diff))}")

In [None]:
# Print computed vs. measured whisker lengths for every side, rather than only for
# whichever `side` variable was left over from an earlier loop.
for side in side_types:
    print(f'{side}:')
    print(np.array(whisker_lengths[side]))
    print(np.array(whisker_lengths_meas[side]))

In [None]:
# Collect each side's whisker scores from the measurements data (wmeas), reordered by label,
# then print them next to the computed whisker lengths.
# whisker_scores[side] is a list holding one score array.
whisker_scores = {}
for side in side_types:
    order = np.argsort(wmeas[side]['label'])
    scores_by_label = np.array(wmeas[side]['score'])[order]
    whisker_scores.setdefault(side, []).append(scores_by_label)

for side in side_types:
    print(f"Whisker lengths and scores for {side} face side:")
    paired = zip(whisker_lengths[side], whisker_scores[side][0])
    for length, score in paired:
        print(f"{length}, {score}")

In [None]:
# Normalize the scores to be between 0 and 1 (min-max per side).
# Accepts the raw form (a list holding one score array) or an already-normalized array, so
# re-running this cell is safe: normalizing [0, 1] data again leaves it unchanged.
for side in side_types:
    scores = whisker_scores[side]
    if isinstance(scores, list):
        scores = scores[0]
    raw_scores = np.asarray(scores, dtype=float)

    if raw_scores.size == 0:
        whisker_scores[side] = raw_scores
        continue

    lo, hi = np.min(raw_scores), np.max(raw_scores)
    if hi > lo:
        whisker_scores[side] = (raw_scores - lo) / (hi - lo)
    else:
        # All scores are equal (e.g. a single whisker), so there is no spread to normalize.
        # Use 1.0 rather than dividing by zero (which would produce NaN alphas when plotting).
        whisker_scores[side] = np.ones_like(raw_scores)

In [None]:
# Get follicle positions for each side from the measurements data, reordered by label
# (the same ordering as the scores), then flatten each side's list into one numpy array.
follicle_x, follicle_y = {}, {}
for side in side_types:
    sorted_indices = np.argsort(wmeas[side]['label'])
    for store, field in ((follicle_x, 'follicle_x'), (follicle_y, 'follicle_y')):
        store.setdefault(side, []).append(np.array(wmeas[side][field])[sorted_indices])

# Collapse each side's list of arrays into one flat array
for side in side_types:
    for store in (follicle_x, follicle_y):
        store[side] = np.concatenate(store[side])

In [None]:
# Display the per-side follicle x-coordinates (flat arrays, in label order)
follicle_x

In [None]:
# Plot whiskers and follicles on each image, with opacity set by each whisker's normalized score
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
# Remove space between subplots
fig.subplots_adjust(wspace=-0.51, hspace=0)


def _score_style(score, color, threshold=0.5):
    """Return the (color, alpha) used to draw a whisker or follicle with a normalized score.

    Scores below `threshold` are drawn in red at alpha 0.1 to flag them. Otherwise the
    whisker's own color is kept and the score itself is the alpha, so scores should be in [0, 1].
    """
    if score < threshold:
        return 'red', 0.1
    return color, score


for i, side in enumerate(side_types):
    # Panels are filled right-to-left: side 0 goes in the last axis
    panel = ax[-(i + 1)]
    panel.imshow(image_halves[i], cmap='gray')
    # Flat view of this side's (normalized) scores, in label order
    scores = np.ravel(whisker_scores[side])

    # Plot the follicles (follicle_x, follicle_y) as circles of the same color as
    # the corresponding whisker, with intensity defined by whisker scores
    for fx, fy, score, color in zip(follicle_x[side], follicle_y[side], scores, colors[side]):
        draw_color, alpha_level = _score_style(score, color)
        panel.scatter(fx, fy, s=100, c=[draw_color], alpha=alpha_level)

    # Plot the whiskers, looking segments up by their own ID. xpixels/ypixels are flat,
    # position-ordered lists, so they cannot be indexed by ID in general.
    segments_by_id = {wseg.id: wseg
                      for frame_whiskers in whisker_data[side].values()
                      for wseg in frame_whiskers.values()}
    # Pair the k-th whisker ID with the k-th score, as the follicle loop does. Previously the alpha
    # reused the last `score` left over from the follicle loop instead of this whisker's own score.
    for whisker_id, color, score in zip(whisker_ids[side], colors[side], scores):
        wseg = segments_by_id[whisker_id]
        draw_color, alpha_level = _score_style(score, color)
        panel.plot(wseg.x, wseg.y, color=draw_color, alpha=alpha_level)

    panel.axis('off')
    panel.set_title(f"Face side: {side}")
plt.show()

In [None]:
# Display the per-side whisker scores (min-max normalized to [0, 1] if the normalization cell has run)
whisker_scores