# Data Loading Functions Example

This notebook demonstrates all the data loading functions available in the ethopy_analysis package. Each function is explained with its purpose, parameters, and example usage.

In [1]:
# Import all necessary modules
from ethopy_analysis.data.loaders import (
    get_sessions,
    get_trials,
    get_trial_states,
    get_trial_experiment,
    get_trial_behavior,
    get_trial_stimulus,
    get_trial_licks,
    get_trial_proximities,
    get_session_classes,
    get_session_duration,
    get_session_task
)
from ethopy_analysis.data.analysis import (
    get_performance,
    trials_per_session,
    session_summary
)
from ethopy_analysis.data.utils import get_setup
from ethopy_analysis.db.schemas import get_schema, get_all_schemas
from ethopy_analysis.config.settings import load_config

# Apply styling
from ethopy_analysis.config.styles import Style
Style().apply()



## 1. Configuration and Setup Functions

These functions handle configuration loading and database setup.

In [2]:
# load_config() - Load configuration from file or use defaults
# Parameters: config_path (str/Path, optional), display_path (bool, optional; prints the path of the loaded file)
# Returns: Configuration dictionary
config = load_config(display_path=True)
print("Configuration loaded:")
print(f"Database host: {config.get('database', {}).get('host', 'Not configured')}")
print(f"Available schemas: {list(config.get('database', {}).get('schemas', {}).keys())}")

Configuration loaded from: /Users/alexandros/Documents/GitHub/ethopy_analysis/ethopy_config.json
Configuration loaded:
Database host: database.eflab.org:3306
Available schemas: ['experiment', 'stimulus', 'behavior']
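
The chained `.get(..., {})` calls in the cell above fall back to a default at each level instead of raising `KeyError` when a key is missing. A minimal sketch of that behaviour on plain dictionaries follows; the `describe_config` helper and both sample dictionaries are hypothetical stand-ins, not output of `load_config()`.

```python
def describe_config(config):
    """Return (host, schema_names) using the same chained .get() lookups."""
    database = config.get("database", {})
    host = database.get("host", "Not configured")
    schema_names = list(database.get("schemas", {}).keys())
    return host, schema_names


# Hypothetical configuration shaped like the printed output above;
# the schema values are placeholders.
configured = {
    "database": {
        "host": "database.eflab.org:3306",
        "schemas": {"experiment": "exp_db", "stimulus": "stim_db", "behavior": "beh_db"},
    }
}
empty = {}

print(describe_config(configured))
print(describe_config(empty))
```

With an empty dictionary the helper reports `Not configured` and an empty schema list rather than failing.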


In [3]:
# get_setup() - Retrieve animal_id and session for a given setup
# Parameters: setup (str)
# Returns: Tuple[int, int] - (animal_id, session)
animal_id, session = get_setup("ef-rp13")
print(f"Setup 'ef-rp13' current animal_id: {animal_id}, session: {session}")

[2025-07-16 11:13:09,662][INFO]: DataJoint 0.14.4 connected to eflab@database.eflab.org:3306


Setup 'ef-rp13' current animal_id: 290, session: 46


## 2. Database Schema Functions

These functions provide access to DataJoint database schemas.

In [4]:
# get_schema() - Get a specific schema by name
# Parameters: schema_name (str: 'experiment', 'behavior', 'stimulus'), config (Dict, optional)
# Returns: DataJoint virtual module for the schema
experiment = get_schema('experiment')
behavior = get_schema('behavior')
stimulus = get_schema('stimulus')

print("Individual schemas loaded:")
print(f"Experiment schema: {type(experiment)}")
print(f"Behavior schema: {type(behavior)}")
print(f"Stimulus schema: {type(stimulus)}")

Individual schemas loaded:
Experiment schema: <class 'datajoint.schemas.VirtualModule'>
Behavior schema: <class 'datajoint.schemas.VirtualModule'>
Stimulus schema: <class 'datajoint.schemas.VirtualModule'>


In [5]:
# get_all_schemas() - Get all three schemas at once
# Parameters: config (Dict, optional)
# Returns: Dict with keys 'experiment', 'behavior', 'stimulus'
schemas = get_all_schemas()
print("All schemas loaded:")
for schema_name, schema_module in schemas.items():
    print(f"{schema_name}: {type(schema_module)}")

All schemas loaded:
experiment: <class 'datajoint.schemas.VirtualModule'>
stimulus: <class 'datajoint.schemas.VirtualModule'>
behavior: <class 'datajoint.schemas.VirtualModule'>


## 3. Session-Level Data Loading

These functions load session-level information and metadata.

In [6]:
# get_sessions() - Get sessions for an animal within a date range
# Parameters: animal_id (int), from_date (str), to_date (str), format (str), min_trials (int)
# Returns: Session DataFrame or DataJoint expression
sessions = get_sessions(animal_id, min_trials=20)
print(f"Sessions for animal {animal_id} (min 20 trials):")
print(sessions.head())
print(f"\nTotal sessions: {len(sessions)}")
print(f"Session range: {sessions['session'].min()} to {sessions['session'].max()}")

Sessions for animal 290 (min 20 trials):
   animal_id  session trials_count user_name      setup        session_tmst  \
0        290        1           30       bot  master-02 2025-03-20 12:31:32   
1        290        2           49       bot  master-02 2025-03-20 12:44:24   
2        290        3           70       bot  master-02 2025-03-21 12:43:51   
3        290        4           19       bot  master-02 2025-03-23 12:46:31   
4        290        5           31       bot  master-02 2025-03-23 13:05:35   

  experiment_type  logger_tmst  
0       FreeWater          0.0  
1        Approach          0.0  
2        Approach          0.0  
3        Approach          0.0  
4        Approach          0.0  

Total sessions: 40
Session range: 1 to 46


In [7]:
# get_sessions() with date filtering
# Restrict the result to sessions between from_date and to_date
sessions_filtered = get_sessions(animal_id, from_date="2025-01-01", to_date="2025-12-31")
print(f"Sessions for animal {animal_id} in 2025:")
print(sessions_filtered.head())
print(f"Total sessions in 2025: {len(sessions_filtered)}")

Sessions for animal 290 in 2025:
   animal_id  session user_name      setup        session_tmst  \
0        290        1       bot  master-02 2025-03-20 12:31:32   
1        290        2       bot  master-02 2025-03-20 12:44:24   
2        290        3       bot  master-02 2025-03-21 12:43:51   
3        290        4       bot  master-02 2025-03-23 12:46:31   
4        290        5       bot  master-02 2025-03-23 13:05:35   

  experiment_type  logger_tmst  
0       FreeWater          0.0  
1        Approach          0.0  
2        Approach          0.0  
3        Approach          0.0  
4        Approach          0.0  
Total sessions in 2025: 45


In [8]:
# get_session_classes() - Retrieve session information and experimental classes
# Parameters: animal_id (int), session (int)
# Returns: DataFrame with session info and class combinations
session_classes = get_session_classes(animal_id, session)
print(f"Session classes for animal {animal_id}, session {session}:")
print(session_classes)

Session classes for animal 290, session 46:
   animal_id  session user_name    setup        session_tmst experiment_type  \
0        290       46       bot  ef-rp13 2025-07-16 10:42:14       MatchPort   

   logger_tmst stimulus_class behavior_class experiment_class  
0          0.0          Panda      MultiPort        MatchPort  


In [9]:
# get_session_duration() - Calculate session duration
# Parameters: animal_id (int), session (int)
# Returns: Formatted duration string or None
duration = get_session_duration(animal_id, session)
print(f"Session {session} duration: {duration}")

Session 46 duration: 30.99 minutes (1859.6 seconds)
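
Because `get_session_duration()` returns a formatted string (or `None`), using the duration numerically requires parsing it. The sketch below assumes the string always follows the single pattern shown in the output above; `duration_seconds` is a hypothetical helper, not part of ethopy_analysis.

```python
import re


def duration_seconds(duration):
    """Extract the seconds value from a string like '30.99 minutes (1859.6 seconds)'."""
    if duration is None:
        return None
    match = re.search(r"\(([\d.]+) seconds\)", duration)
    # Return None when the string does not follow the assumed pattern.
    return float(match.group(1)) if match else None


print(duration_seconds("30.99 minutes (1859.6 seconds)"))  # 1859.6
print(duration_seconds(None))                              # None
```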


In [10]:
# get_session_task() - Retrieve task configuration file
# Parameters: animal_id (int), session (int), save_file (bool)
# Returns: Tuple of (filename, git_hash)
filename, git_hash = get_session_task(animal_id, session, save_file=False)
print(f"Task file: {filename}")
print(f"Git hash: {git_hash}")

Task file: 2object_detection_visual-dif0-1.py
Git hash: b06989f


## 4. Trial-Level Data Loading

These functions load detailed trial-level data for analysis.

In [11]:
# get_trials() - Retrieve trial data for a specific session
# Parameters: animal_id (int), session (int), format (str), remove_abort (bool)
# Returns: Trial DataFrame or DataJoint expression
trials = get_trials(animal_id, session)
print(f"Trials for animal {animal_id}, session {session}:")
print(trials.head())
print(f"\nTotal trials: {len(trials)}")
print(f"Trial columns: {list(trials.columns)}")

Trials for animal 290, session 46:
   animal_id  session  trial_idx                 cond_hash   time
0        290       46          1  eUmhCs/lwH98hg5l7TW2og==   3938
1        290       46          2  eUmhCs/lwH98hg5l7TW2og==  14670
2        290       46          3  eUmhCs/lwH98hg5l7TW2og==  16004
3        290       46          4  eUmhCs/lwH98hg5l7TW2og==  35750
4        290       46          5  eUmhCs/lwH98hg5l7TW2og==  40543

Total trials: 236
Trial columns: ['animal_id', 'session', 'trial_idx', 'cond_hash', 'time']


In [12]:
# get_trials() with abort removal
# Remove aborted trials from the dataset
trials_no_abort = get_trials(animal_id, session, remove_abort=True)
print(f"Trials without aborts: {len(trials_no_abort)} (vs {len(trials)} with aborts)")
print("Trial outcomes:")
print(trials_no_abort)

Trials without aborts: 96 (vs 236 with aborts)
Trial outcomes:
    animal_id  session  trial_idx                 cond_hash     time
0         290       46          7  eUmhCs/lwH98hg5l7TW2og==    54064
1         290       46         10  eUmhCs/lwH98hg5l7TW2og==    66283
2         290       46         11  eUmhCs/lwH98hg5l7TW2og==    68984
3         290       46         12  eUmhCs/lwH98hg5l7TW2og==    72403
4         290       46         14  eUmhCs/lwH98hg5l7TW2og==    82410
..        ...      ...        ...                       ...      ...
91        290       46        229  1vyELoFSThi/L7d86pydeQ==  1539816
92        290       46        230  1vyELoFSThi/L7d86pydeQ==  1546771
93        290       46        231  1vyELoFSThi/L7d86pydeQ==  1561832
94        290       46        235  1vyELoFSThi/L7d86pydeQ==  1853413
95        290       46        236  1vyELoFSThi/L7d86pydeQ==  1857283

[96 rows x 5 columns]
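
The notebook does not state how `remove_abort=True` decides which trials to drop. One plausible reading, sketched here as an assumption, is that a trial counts as aborted when it has an `Abort` entry in the trial states. The miniature DataFrames are hypothetical and reuse only the column names from the outputs.

```python
import pandas as pd

# Hypothetical miniature versions of the get_trials() and
# get_trial_states() DataFrames (same column names as the outputs).
trials = pd.DataFrame({"trial_idx": [1, 2, 3], "time": [3938, 14670, 16004]})
states = pd.DataFrame({
    "trial_idx": [1, 1, 1, 2, 2, 2, 3, 3, 3],
    "state": ["PreTrial", "Trial", "Abort",
              "PreTrial", "Trial", "Reward",
              "PreTrial", "Trial", "Punish"],
})

# Assumption: a trial is aborted if any of its states is "Abort".
aborted = states.loc[states["state"] == "Abort", "trial_idx"].unique()
kept = trials[~trials["trial_idx"].isin(aborted)].reset_index(drop=True)
print(kept["trial_idx"].tolist())  # [2, 3]
```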


In [13]:
# get_trial_states() - Retrieve trial state onset data
# Parameters: animal_id (int), session (int), format (str)
# Returns: Trial states DataFrame with state onset times
trial_states = get_trial_states(animal_id, session)
print(f"Trial states for animal {animal_id}, session {session}:")
print(trial_states.head())
print(f"\nUnique states: {trial_states['state'].unique()}")
print(f"Total state events: {len(trial_states)}")

Trial states for animal 290, session 46:
   animal_id  session  trial_idx   time       state
0        290       46          1   3947    PreTrial
1        290       46          1  13357       Trial
2        290       46          1  13662       Abort
3        290       46          1  14162  InterTrial
4        290       46          2  14671    PreTrial

Unique states: ['PreTrial' 'Trial' 'Abort' 'InterTrial' 'Reward' 'Punish']
Total state events: 943
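
Since each row records a state onset, the time spent in a state can be approximated as the gap to the next onset within the same trial. A minimal sketch using the five rows printed above; the `duration` column is an illustrative addition and is expressed in the same units as `time`.

```python
import pandas as pd

# First five rows of the trial_states output above.
states = pd.DataFrame({
    "trial_idx": [1, 1, 1, 1, 2],
    "time": [3947, 13357, 13662, 14162, 14671],
    "state": ["PreTrial", "Trial", "Abort", "InterTrial", "PreTrial"],
})

states = states.sort_values(["trial_idx", "time"])
# Time until the next state onset in the same trial; the last state of
# each trial has no successor and stays NaN.
states["duration"] = states.groupby("trial_idx")["time"].shift(-1) - states["time"]
print(states)
```

The last state of every trial is left as `NaN` rather than borrowing the first onset of the following trial.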


In [14]:
# get_trial_experiment() - Retrieve trial experiment condition data
# Parameters: animal_id (int), session (int), format (str)
# Returns: Trial experiment conditions DataFrame
trial_experiment = get_trial_experiment(animal_id, session)
print(f"Trial experiment data for animal {animal_id}, session {session}:")
print(trial_experiment.head())
print(f"\nExperiment columns: {list(trial_experiment.columns)}")

Trial experiment data for animal 290, session 46:
   animal_id  session  trial_idx                 cond_hash   time  \
0        290       46          1  eUmhCs/lwH98hg5l7TW2og==   3938   
1        290       46          2  eUmhCs/lwH98hg5l7TW2og==  14670   
2        290       46          3  eUmhCs/lwH98hg5l7TW2og==  16004   
3        290       46          4  eUmhCs/lwH98hg5l7TW2og==  35750   
4        290       46          5  eUmhCs/lwH98hg5l7TW2og==  40543   

  stimulus_class behavior_class experiment_class trial_selection  max_reward  \
0          Panda      MultiPort        MatchPort       staircase        1200   
1          Panda      MultiPort        MatchPort       staircase        1200   
2          Panda      MultiPort        MatchPort       staircase        1200   
3          Panda      MultiPort        MatchPort       staircase        1200   
4          Panda      MultiPort        MatchPort       staircase        1200   

   ...  intertrial_duration  trial_duration  reward_du

In [15]:
# get_trial_behavior() - Retrieve trial behavior condition data
# Parameters: animal_id (int), session (int), format (str)
# Returns: Trial behavior conditions DataFrame
trial_behavior = get_trial_behavior(animal_id, session)
print(f"Trial behavior data for animal {animal_id}, session {session}:")
print(trial_behavior.head())
print(f"\nBehavior columns: {list(trial_behavior.columns)}")

Trial behavior data for animal 290, session 46:
   animal_id  session  trial_idx                  beh_hash   time  \
0        290       46          1  YakvXLaMHiQxJ888h16xug==   3947   
1        290       46          2  YakvXLaMHiQxJ888h16xug==  14671   
2        290       46          3  wUu0uK7wjIAhezgPkkddDg==  16005   
3        290       46          4  YakvXLaMHiQxJ888h16xug==  35752   
4        290       46          5  YakvXLaMHiQxJ888h16xug==  40544   

   response_port  reward_port  reward_amount reward_type  
0              1            1            6.0       water  
1              1            1            6.0       water  
2              2            2            6.0       water  
3              1            1            6.0       water  
4              1            1            6.0       water  

Behavior columns: ['animal_id', 'session', 'trial_idx', 'beh_hash', 'time', 'response_port', 'reward_port', 'reward_amount', 'reward_type']


In [16]:
# get_trial_stimulus() - Retrieve trial stimulus condition data
# Parameters: animal_id (int), session (int), stim_class (str), format (str)
# Returns: Trial stimulus conditions DataFrame
trial_stimulus = get_trial_stimulus(animal_id, session)
print(f"Trial stimulus data for animal {animal_id}, session {session}:")
print(trial_stimulus.head())
print(f"\nStimulus columns: {list(trial_stimulus.columns)}")

Trial stimulus data for animal 290, session 46:
   animal_id  session  trial_idx period                 stim_hash  light_idx  \
0        290       46          1         YQSNG18SbwfntIQTmyxqcw==          1   
1        290       46          1         YQSNG18SbwfntIQTmyxqcw==          2   
2        290       46          2         YQSNG18SbwfntIQTmyxqcw==          1   
3        290       46          2         YQSNG18SbwfntIQTmyxqcw==          2   
4        290       46          3         lzWHNOwmeHDD5URqKDyUbQ==          1   

   obj_id  start_time end_time           light_color  ... obj_mag  \
0       2       13367    13661  [0.7, 0.7, 0.7, 1.0]  ...     0.5   
1       2       13367    13661  [0.2, 0.2, 0.2, 1.0]  ...     0.5   
2       2       14778    14996  [0.7, 0.7, 0.7, 1.0]  ...     0.5   
3       2       14778    14996  [0.2, 0.2, 0.2, 1.0]  ...     0.5   
4       2       34527    34743  [0.7, 0.7, 0.7, 1.0]  ...     0.5   

                                             obj_rot obj

## 5. Behavioral Event Data Loading

These functions load specific behavioral events like licks and proximity sensor data.

In [17]:
# get_trial_licks() - Retrieve all licks of a session
# Parameters: animal_id (int), session (int), format (str)
# Returns: Lick events DataFrame
trial_licks = get_trial_licks(animal_id, session)
print(f"Trial licks for animal {animal_id}, session {session}:")
print(trial_licks.head())
print(f"\nTotal lick events: {len(trial_licks)}")
print(f"Lick columns: {list(trial_licks.columns)}")
print(f"Licks per port: {trial_licks['port'].value_counts()}")

Trial licks for animal 290, session 46:
   animal_id  session  trial_idx  port   time
0        290       46          3     1  19211
1        290       46          3     1  19373
2        290       46          4     1  36913
3        290       46          4     1  37102
4        290       46          6     2  53557

Total lick events: 1167
Lick columns: ['animal_id', 'session', 'trial_idx', 'port', 'time']
Licks per port: port
1    698
2    469
Name: count, dtype: int64
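
Lick times become easier to interpret relative to the start of their trial. The sketch below merges lick rows with trial rows copied from the outputs above and computes the latency of the first lick per trial. It assumes that the `time` columns of both tables share the same clock and units, which the notebook does not state explicitly.

```python
import pandas as pd

# Rows copied from the get_trials() and get_trial_licks() outputs above.
trials = pd.DataFrame({"trial_idx": [3, 4], "time": [16004, 35750]})
licks = pd.DataFrame({
    "trial_idx": [3, 3, 4, 4],
    "port": [1, 1, 1, 1],
    "time": [19211, 19373, 36913, 37102],
})

# Both tables have a "time" column, so the suffixes keep them apart.
merged = licks.merge(trials, on="trial_idx", suffixes=("_lick", "_trial"))
merged["latency"] = merged["time_lick"] - merged["time_trial"]
first_lick = merged.groupby("trial_idx")["latency"].min()
print(first_lick.to_dict())  # {3: 3207, 4: 1163}
```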


In [18]:
# get_trial_proximities() - Retrieve proximity sensor data
# Parameters: animal_id (int), session (int), ports (List), format (str)
# Returns: Proximity data DataFrame
trial_proximities = get_trial_proximities(animal_id, session)
print(f"Trial proximities for animal {animal_id}, session {session}:")
print(trial_proximities.head())
print(f"\nTotal proximity events: {len(trial_proximities)}")
print(f"Proximity columns: {list(trial_proximities.columns)}")
print(f"Proximity per port: {trial_proximities['port'].value_counts()}")

Trial proximities for animal 290, session 46:
   animal_id  session  trial_idx  port   time  in_position
0        290       46          1     3  11665            1
1        290       46          1     3  11666            1
2        290       46          1     3  11764            0
3        290       46          1     3  11765            1
4        290       46          1     3  11766            1

Total proximity events: 2045
Proximity columns: ['animal_id', 'session', 'trial_idx', 'port', 'time', 'in_position']
Proximity per port: port
3    2045
Name: count, dtype: int64


In [19]:
# get_trial_proximities() with port filtering
# Filter proximity data for specific ports
proximity_filtered = get_trial_proximities(animal_id, session, ports=[1, 2, 3])
print(f"Proximity data for ports [1, 2, 3]: {len(proximity_filtered)} events")
print(f"Events per port: {proximity_filtered['port'].value_counts()}")

Proximity data for ports [1, 2, 3]: 2045 events
Events per port: port
3    2045
Name: count, dtype: int64


## 6. Analysis and Performance Functions

These functions compute performance metrics and provide session analysis.

In [20]:
# get_performance() - Calculate performance as the ratio of rewarded trials to all decisive (rewarded or punished) trials
# Parameters: animal_id (int), session (int), trials (List[int])
# Returns: Performance ratio (0-1) or None
performance = get_performance(animal_id, session)
print(f"Performance for animal {animal_id}, session {session}: {performance:.3f}")

# Calculate performance for specific trials
trial_subset = list(range(1, 101))  # First 100 trials
performance_subset = get_performance(animal_id, session, trials=trial_subset)
print(f"Performance for first 100 trials: {performance_subset:.3f}")

Performance for animal 290, session 46: 0.625
Performance for first 100 trials: 0.865
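
The ratio described above can be sketched from state labels alone. This is an illustration under stated assumptions, not the package's implementation: `performance_from_states` is a hypothetical helper, a trial is treated as rewarded or punished when it has a `Reward` or `Punish` state, and the sample pairs are made up.

```python
def performance_from_states(states, trials=None):
    """Rewarded trials / (rewarded + punished) trials, or None if there are none."""
    rewarded, punished = set(), set()
    for trial_idx, state in states:
        # Optionally restrict the calculation to a subset of trials.
        if trials is not None and trial_idx not in trials:
            continue
        if state == "Reward":
            rewarded.add(trial_idx)
        elif state == "Punish":
            punished.add(trial_idx)
    decisive = len(rewarded) + len(punished)
    return len(rewarded) / decisive if decisive else None


# Hypothetical (trial_idx, state) pairs; aborted trial 1 is not decisive.
states = [(1, "Abort"), (2, "Reward"), (3, "Punish"), (4, "Reward"), (5, "Reward")]
print(performance_from_states(states))              # 0.75
print(performance_from_states(states, trials=[1]))  # None
```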


In [None]:
# trials_per_session() - Returns the number of trials per session
# Parameters: animal_id (int), min_trials (int), format (str)
# Returns: DataFrame with trials_count column
session_trial_counts = trials_per_session(animal_id, min_trials=10)
print(f"Trial counts per session for animal {animal_id}:")
print(session_trial_counts.head())
print("\nSummary statistics:")
print(session_trial_counts['trials_count'].describe())

Trial counts per session for animal 290:
   animal_id  session trials_count
0        290        1           30
1        290        2           49
2        290        3           70
3        290        4           19
4        290        5           31

Summary statistics:
count      40
unique     36
top       231
freq        2
Name: trials_count, dtype: int64
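
The `count / unique / top / freq` rows in the summary above are what `describe()` reports for a non-numeric column, which suggests `trials_count` is not stored with a numeric dtype here. A minimal sketch of converting it with `pd.to_numeric` to get a numeric summary; the object dtype in the sample frame is an assumption made to reproduce that behaviour.

```python
import pandas as pd

# First five rows of the output above, stored as Python objects to
# mimic a non-numeric trials_count column (an assumption about the dtype).
counts = pd.DataFrame({
    "session": [1, 2, 3, 4, 5],
    "trials_count": pd.Series([30, 49, 70, 19, 31], dtype="object"),
})

counts["trials_count"] = pd.to_numeric(counts["trials_count"])
summary = counts["trials_count"].describe()
print(summary[["count", "mean", "min", "max"]])
```

After the conversion, `describe()` returns the usual mean, standard deviation, and quantiles.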


In [22]:
# session_summary() - Print comprehensive summary of a session
# Parameters: animal_id (int), session (int)
# Returns: None (prints summary)
print(f"Comprehensive session summary for animal {animal_id}, session {session}:")
print("=" * 60)
session_summary(animal_id, session)

Comprehensive session summary for animal 290, session 46:
Animal id: 290, session: 46
User name: bot
Setup: ef-rp13
Session start: 2025-07-16 10:42:14
Session duration: 31.16 minutes (1869.7 seconds)

Experiment:  MatchPort
Stimulus:  Panda
Behavior:  MultiPort

Task filename: 2object_detection_visual-dif0-1.py
Git hash: b06989f

Session performance: 0.625
Number of trials: 236


## 7. Complete Data Loading Example

This example shows how to load all available data for a comprehensive analysis.

In [None]:
# Load all available data for a session
print(f"Loading complete dataset for animal {animal_id}, session {session}...")
print("=" * 60)

# Session-level data
sessions_data = get_sessions(animal_id)
session_info = get_session_classes(animal_id, session)
session_dur = get_session_duration(animal_id, session)
task_file, git_hash = get_session_task(animal_id, session, save_file=False)

# Trial-level data
trials_data = get_trials(animal_id, session)
states_data = get_trial_states(animal_id, session)
experiment_data = get_trial_experiment(animal_id, session)
behavior_data = get_trial_behavior(animal_id, session)
stimulus_data = get_trial_stimulus(animal_id, session)

# Behavioral events
licks_data = get_trial_licks(animal_id, session)
proximities_data = get_trial_proximities(animal_id, session)

# Performance metrics
performance_score = get_performance(animal_id, session)

print("Data loading complete!")
print(f"Sessions available: {len(sessions_data)}")
print(f"Trials in session {session}: {len(trials_data)}")
print(f"State events: {len(states_data)}")
print(f"Lick events: {len(licks_data)}")
print(f"Proximity events: {len(proximities_data)}")
print(f"Session performance: {performance_score:.3f}")
print(f"Session duration: {session_dur}")
print(f"Task file: {task_file}")

Loading complete dataset for animal 290, session 46...
Data loading complete!
Sessions available: 45
Trials in session 46: 237
State events: 945
Lick events: 1178
Proximity events: 2049
Session performance: 0.629
Session duration: 31.17 minutes (1870.2 seconds)
Task file: 2object_detection_visual-dif0-1.py


## 8. Data Format Options

Most functions accept a `format` argument and can return either a pandas DataFrame (`"df"`, the default) or a DataJoint expression (`"dj"`).

In [24]:
# Compare DataFrame vs DataJoint formats
print("Data format comparison:")
print("=" * 40)

# DataFrame format (default)
trials_df = get_trials(animal_id, session, format="df")
print(f"DataFrame format: {type(trials_df)}")
print(f"Shape: {trials_df.shape}")
print(f"Columns: {list(trials_df.columns)}")

# DataJoint format
trials_dj = get_trials(animal_id, session, format="dj")
print(f"\nDataJoint format: {type(trials_dj)}")
print(f"Length: {len(trials_dj)}")

# Convert DataJoint to DataFrame if needed
trials_from_dj = trials_dj.fetch(format="frame").reset_index()
print(f"\nConverted back to DataFrame: {type(trials_from_dj)}")
print(f"Shape: {trials_from_dj.shape}")

Data format comparison:
DataFrame format: <class 'pandas.core.frame.DataFrame'>
Shape: (238, 5)
Columns: ['animal_id', 'session', 'trial_idx', 'cond_hash', 'time']

DataJoint format: <class 'datajoint.schemas.Trial'>
Length: 238

Converted back to DataFrame: <class 'pandas.core.frame.DataFrame'>
Shape: (238, 5)


## 9. Error Handling and Edge Cases

Examples of how functions handle common edge cases.

In [25]:
# Handle cases with no data
print("Testing edge cases:")
print("=" * 30)

# Request performance for a trial index that does not exist, so there are
# no decisive (reward/punish) trials to score; get_performance() then returns None
try:
    perf_result = get_performance(animal_id, session, trials=[999999])  # Non-existent trial
    print(f"Performance for non-existent trial: {perf_result}")
except Exception as e:
    print(f"Expected error for non-existent trial: {type(e).__name__}")

# Try to get data for future date range
future_sessions = get_sessions(animal_id, from_date="2030-01-01", to_date="2030-12-31")
print(f"Sessions in future date range: {len(future_sessions)}")

Testing edge cases:
Performance for non-existent trial: None
Sessions in future date range: 0
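
Earlier cells format the performance with `:.3f`, which raises `TypeError` when the value is `None`. A small guard avoids that; `format_performance` is a hypothetical helper shown only as a sketch.

```python
def format_performance(value):
    """Format a performance ratio, tolerating None (no decisive trials)."""
    return "n/a" if value is None else f"{value:.3f}"


print(format_performance(0.625))  # 0.625
print(format_performance(None))   # n/a
```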


## Summary

This notebook demonstrated all the data loading functions available in the ethopy_analysis package:

### Configuration & Setup
- `load_config()` - Load configuration settings
- `get_database_config()` - Get database configuration (not demonstrated in this notebook)
- `get_setup()` - Get animal/session from setup ID

### Database Schemas
- `get_schema()` - Get individual schema
- `get_all_schemas()` - Get all schemas at once

### Session Data
- `get_sessions()` - Get sessions with filtering
- `get_session_classes()` - Get session metadata
- `get_session_duration()` - Get session duration
- `get_session_task()` - Get task file info

### Trial Data
- `get_trials()` - Get trial data
- `get_trial_states()` - Get state onset times
- `get_trial_experiment()` - Get experiment conditions
- `get_trial_behavior()` - Get behavior conditions
- `get_trial_stimulus()` - Get stimulus conditions

### Behavioral Events
- `get_trial_licks()` - Get lick events
- `get_trial_proximities()` - Get proximity events

### Analysis Functions
- `get_performance()` - Calculate performance metrics
- `trials_per_session()` - Count trials per session
- `session_summary()` - Print session summary

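Most of these functions accept a `format` argument to return either a pandas DataFrame or a DataJoint expression. In the edge cases tested above, `get_performance()` returned `None` and `get_sessions()` returned an empty result instead of raising an error.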