In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import sys
import os
from pathlib import Path

scripts_dir = Path().resolve()
parent_dir = scripts_dir.parent
sys.path.insert(0, str(parent_dir))
import Functions.HMM as HMM
import Functions.kinematics as kinematics
import Functions.patch as patch

In [2]:
aeon_dir = scripts_dir.parent.parent/'aeon_mecha'/'aeon'
sys.path.insert(0, str(aeon_dir))
import aeon
import aeon.io.api as api
from aeon.io import reader, video
from aeon.schema.dataset import exp02, exp01
from aeon.analysis.utils import visits, distancetravelled
from aeon.schema.schemas import social02

In [3]:
roots = [Path("/ceph/aeon/aeon/code/scratchpad/sleap/multi_point_tracking/multi_animal_CameraTop/predictions_social02/AEON3/analyses"),Path("/ceph/aeon/aeon/code/scratchpad/sleap/multi_point_tracking/multi_animal_CameraTop/predictions_social02/AEON4/analyses")]
if not np.all([path.exists() for path in roots]):
    print("Cannot find root paths. Check path names or connection.")

# Define the root path
root = Path("/ceph/aeon/aeon/code/scratchpad/sleap/multi_point_tracking/multi_animal_CameraTop/predictions_social02/AEON3/analyses")

# List all files in the directory
files = list(root.glob("*"))

# Print the list of files
'''for file in files:
    print(file)'''
    
from datetime import datetime

# Extract datetimes and create list of tuples
file_dates = []
for file in files:
    datetime_str = str(file).split('/')[-1].split('_')[1] 
    datetime_obj = datetime.strptime(datetime_str, '%Y-%m-%dT%H-%M-%S')
    file_dates.append((datetime_obj, file))

# Sort the list of tuples by the datetime
file_dates.sort()

# Extract the sorted paths
sorted_paths = [path for _, path in file_dates] 

import h5py
with h5py.File("/ceph/aeon/aeon/code/scratchpad/sleap/multi_point_tracking/multi_animal_CameraTop/predictions_social02/AEON3/analyses/CameraTop_2024-01-31T11-00-00_full_pose.analysis.h5", 'r') as f:
    occupancy_matrix = f['track_occupancy'][:]
    tracks_matrix = f['tracks'][:]
    point_scores = f['point_scores'][:]
    nodes_name = f['node_names'][:]

print(occupancy_matrix.shape)
print(tracks_matrix.shape)

(93451, 1)
(1, 2, 8, 93451)


In [42]:
root = '/ceph/aeon/aeon/data/raw/AEON3/social0.2/'
start_time = pd.Timestamp('2024-02-09 16:07:32')
end_time = pd.Timestamp('2024-02-09 19:00:00')
metadata = aeon.load(
    root, social02.Metadata, start=start_time, end=end_time
)["metadata"].iloc[0]
patch_loc = [(int(point.X), int(point.Y)) for point in metadata.ActiveRegion.Patch1Region.ArrayOfPoint]
print(np.mean(patch_loc, axis = 0))
patch_loc = [(int(point.X), int(point.Y)) for point in metadata.ActiveRegion.Patch2Region.ArrayOfPoint]
print(np.mean(patch_loc, axis = 0))
patch_loc = [(int(point.X), int(point.Y)) for point in metadata.ActiveRegion.Patch3Region.ArrayOfPoint]
print(np.mean(patch_loc, axis = 0))

[910.25 544.  ]
[613.75 724.  ]
[604.5  375.75]


In [52]:
root = [Path("/ceph/aeon/aeon/data/raw/AEON2/experiment0.2")]

pellets_patch1 = api.load(root, exp02.Patch1.DeliverPellet, start=start_time, end=end_time)

In [2]:
title = 'ShortSession'+str(0)

mouse_pos = pd.read_parquet('../Data/MousePos/' + title + 'mousepos.parquet', engine='pyarrow')
states = np.load('../Data/HMMStates/' + title+'States_Unit.npy', allow_pickle=True)

Visits_Patch1 = patch.Visits(mouse_pos, patch = 'Patch1', pre_period_seconds = 30)
Visits_Patch2 = patch.Visits(mouse_pos, patch = 'Patch2', pre_period_seconds = 30)



In [4]:
Visits_Patch1

Unnamed: 0,start,end,distance,duration,speed,acceleration,entry,patch,pellet
0,2022-03-15 12:54:50.460000038,2022-03-15 12:55:01.069983959,16.534254,10.609983,180.077883,643.121477,6.384096,Patch1,0
1,2022-03-15 13:17:50.716000080,2022-03-15 13:18:10.699999809,42.480055,19.983999,107.955604,449.413439,11.430816,Patch1,1
2,2022-03-15 13:35:00.572000027,2022-03-15 13:35:30.280000210,173.970176,29.708,60.104508,345.706909,9.981376,Patch1,2
3,2022-03-15 13:38:36.024000168,2022-03-15 13:39:03.149983883,184.388075,27.125983,57.84411,418.23796,14.53232,Patch1,2
4,2022-03-15 13:48:06.889984131,2022-03-15 13:48:18.369984150,34.47372,11.48,42.321009,238.915853,6.494624,Patch1,0
5,2022-03-15 14:10:52.337984085,2022-03-15 14:11:28.116000175,247.9908,35.778016,258.482398,780.889663,4.674848,Patch1,3
6,2022-03-15 14:38:06.653984070,2022-03-15 14:38:54.136000156,295.912217,47.482016,102.104289,369.693911,10.181375,Patch1,3
7,2022-03-15 14:39:11.845983982,2022-03-15 14:39:39.913983822,193.992915,28.067999,10.839531,142.793829,75.373375,Patch1,2


In [22]:
Patches = [Visits_Patch1, Visits_Patch2]
Visits = pd.concat(Patches, ignore_index=True)
Visits = Visits.sort_values(by='start',ignore_index=True)  

Visits['last_pellets_self'] = 0
Visits['last_pellets_other'] = 0
Visits['interval'] = 0
    
for i in range(1,len(Visits)):
    start, end = Visits.start[i], Visits.end[i]
    last_end = Visits.end[i-1]
    Visits.loc[i, 'interval'] = (start - last_end).total_seconds()
    
    self_patch, other_patch = False, False
    self_pellet, other_pellet = 0, 0
    for j in range(i-1, -1, -1):
        if self_patch and other_patch: break
        if Visits.patch[j] != Visits.patch[i] and other_patch == False: 
            other_pellet = Visits.pellet[j]
            other_patch = True
        if Visits.patch[j] == Visits.patch[i] and self_patch == False:
            self_pellet = Visits.pellet[j]
            self_patch = True
    Visits.loc[i, 'last_pellets_self'] = self_pellet
    Visits.loc[i, 'last_pellets_other'] = other_pellet

In [24]:
VISIT = pd.read_parquet('../Data/RegressionPatchVisits/VISIT.parquet', engine='pyarrow')
VISIT

Unnamed: 0,start,end,distance,duration,speed,acceleration,entry,patch,pellet,last_pellets_self,last_pellets_other,interval,interc
1,2022-03-15 12:46:04.092000008,2022-03-15 12:46:16.117983818,41.599496,12.025983,143.290260,427.553219,14.519648,Patch2,0.0,0,0,327.708000,1
2,2022-03-15 12:51:27.592000008,2022-03-15 12:51:57.584000111,131.307566,29.992000,129.750754,522.784048,11.817440,Patch2,2.0,0,0,311.474016,1
3,2022-03-15 12:52:21.641983986,2022-03-15 12:53:02.728000164,205.733186,41.086016,63.896369,270.957639,2.967072,Patch2,2.0,2,0,24.057983,1
4,2022-03-15 12:54:50.460000038,2022-03-15 12:55:01.069983959,16.534254,10.609983,180.077883,643.121477,6.384096,Patch1,0.0,0,2,107.731999,1
5,2022-03-15 12:55:18.141983986,2022-03-15 12:56:16.579999924,297.303623,58.438015,96.581978,387.980349,6.865920,Patch2,3.0,2,0,17.072000,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...
331,2022-08-16 15:12:09.101984024,2022-08-16 15:13:47.645984173,584.899622,98.544000,95.306665,360.186225,5.875264,Patch1,6.0,1,0,169.263999,1
332,2022-08-16 15:19:23.139999866,2022-08-16 15:20:16.693984032,289.355583,53.553984,95.489309,305.837305,7.410463,Patch1,3.0,6,0,335.494015,1
333,2022-08-16 15:29:03.940000057,2022-08-16 15:29:42.808000088,196.835555,38.868000,21.115280,219.046458,39.807104,Patch1,2.0,3,0,527.246016,1
334,2022-08-16 15:29:53.515999793,2022-08-16 15:31:01.113984108,389.368031,67.597984,4.250241,72.267359,89.383103,Patch1,4.0,2,0,10.707999,1


In [83]:
start = VISIT.loc[1, 'start']
print(start)

1   2022-03-15 12:46:04.092000008
1   2022-03-15 13:17:50.716000080
Name: start, dtype: datetime64[ns]
