In [58]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../../')

from main import load_and_prepare_sessions
from processing.session_sampling import MiceAnalysis
from analysis.timepoint_analysis import sample_signals_and_metrics, sample_low_and_high_signals
from config import all_brain_regions, all_event_types, all_metrics
from itertools import product
import numpy as np
from utils import mouse_br_events_count

# Uniform (boxcar) averaging kernel: `window_size` equal weights that sum to 1.
window_size = 5
window = np.full(window_size, 1.0 / window_size)

# Load the sessions from the Baseline directory (relative to this notebook) and
# wrap them in a MiceAnalysis, whose `mice_dict` is iterated by later cells.
BASELINE_DIR = "../../../Baseline"
sessions = load_and_prepare_sessions(
    BASELINE_DIR,
    load_from_pickle=True,
    remove_bad_signal_sessions=True,
)
mouse_analyser = MiceAnalysis(sessions)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [59]:
# Generate all aggregated signals.
#
# Result:
#   all_event_signals -- 2-D array, one row per sampled signal window
#   labels            -- list of (mouse_id, brain_region, event) tuples, one per row
EVENT_TYPES = ['hit', 'mistake', 'miss', 'cor_reject', 'reward_collect']
# Column range kept from every sampled signal (100 columns).
# NOTE(review): presumably a window around the event onset -- confirm against
# sample_signals_and_metrics.
SIGNAL_START, SIGNAL_END = 150, 250

all_event_signals = []
labels = []

for mouse in mouse_analyser.mice_dict.values():
    for brain_region, event in product(all_brain_regions, EVENT_TYPES):
        mouse_signals = []
        for session in mouse.sessions:
            # Skip sessions that have no entry for this (region, event) pair.
            if session.signal_info.get((brain_region, event)) is None:
                continue
            # Only the first element of the returned value is used here.
            signals = sample_signals_and_metrics([session], event, brain_region)[0]
            mouse_signals.append(signals[:, SIGNAL_START:SIGNAL_END])
        if not mouse_signals:
            continue
        mouse_signals = np.vstack(mouse_signals)
        all_event_signals.append(mouse_signals)
        # One label tuple per stacked row so that labels stay aligned with the signals.
        labels.extend([(mouse.mouse_id, brain_region, event)] * len(mouse_signals))

if not all_event_signals:
    raise ValueError("No signals were collected for any (mouse, brain region, event) combination")
all_event_signals = np.vstack(all_event_signals)

In [61]:
# Sanity check: number of label tuples; should match the number of rows in all_event_signals.
len(labels)

20690

In [62]:
# Sanity check: number of stacked signal rows; should match len(labels).
len(all_event_signals)

20690

In [56]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

In [63]:
import pickle

# Persist the stacked signals and their labels as two pickle files in the
# current working directory.
for pkl_name, payload in (
    ('all_event_signals.pkl', all_event_signals),
    ('labels.pkl', labels),
):
    with open(pkl_name, 'wb') as f:
        pickle.dump(payload, f)

In [None]:
# Transpose the list of (mouse_id, brain_region, event) tuples into three
# parallel arrays, one per label component.
mouse_labels, br_labels, event_labels = (
    np.array(column) for column in zip(*labels)
)

In [None]:
# Map brain-region and event names to integer class ids. Each encoder is kept
# so the ids can be turned back into names later.
br_label_encoder, event_label_encoder = LabelEncoder(), LabelEncoder()

br_labels_encoded = br_label_encoder.fit_transform(br_labels)
event_labels_encoded = event_label_encoder.fit_transform(event_labels)

In [None]:
# Split at the mouse level (40% of the mice held out, fixed seed) rather than
# at the sample level, so a given mouse ends up in only one of the two sets.
unique_mouse_labels = np.unique(mouse_labels)
split_options = dict(test_size=0.4, random_state=42)
train_mice, test_mice = train_test_split(unique_mouse_labels, **split_options)

In [None]:
# Boolean row masks: True where the sample's mouse belongs to the given group.
train_mask, test_mask = (
    np.isin(mouse_labels, mouse_group) for mouse_group in (train_mice, test_mice)
)

In [None]:
# Select the signal rows belonging to each split.
all_event_signals_train, all_event_signals_test = (
    all_event_signals[train_mask],
    all_event_signals[test_mask],
)

In [None]:
# Apply the same row masks to both sets of encoded labels so they stay aligned
# with the split signals.
br_labels_train, br_labels_test = br_labels_encoded[train_mask], br_labels_encoded[test_mask]
event_labels_train, event_labels_test = event_labels_encoded[train_mask], event_labels_encoded[test_mask]

In [None]:
# Report the shape of every split array, training arrays first.
print("Shapes after adjustment:")
for array_name, array in (
    ("all_event_signals_train", all_event_signals_train),
    ("br_labels_train", br_labels_train),
    ("event_labels_train", event_labels_train),
    ("all_event_signals_test", all_event_signals_test),
    ("br_labels_test", br_labels_test),
    ("event_labels_test", event_labels_test),
):
    print(f"{array_name} shape: {array.shape}")

In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset

# Requires the split arrays (signals plus both encoded label sets, for train
# and test) to be defined already.

# Signals become float32 tensors.
all_event_signals_train, all_event_signals_test = (
    torch.tensor(signal_rows, dtype=torch.float32)
    for signal_rows in (all_event_signals_train, all_event_signals_test)
)

# Class ids become int64 tensors.
br_labels_train, event_labels_train, br_labels_test, event_labels_test = (
    torch.tensor(class_ids, dtype=torch.long)
    for class_ids in (br_labels_train, event_labels_train, br_labels_test, event_labels_test)
)

# Each dataset item is a (signal, br_label, event_label) triple.
train_dataset = TensorDataset(all_event_signals_train, br_labels_train, event_labels_train)
test_dataset = TensorDataset(all_event_signals_test, br_labels_test, event_labels_test)

# Only the training batches are shuffled; test order stays fixed.
batch_size = 64
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)

In [None]:
import torch.nn as nn

class LSTMModel(nn.Module):
    """Stacked LSTM encoder with two linear classification heads.

    Both heads read the LSTM output at the last time step of the sequence.

    Args:
        input_size: Number of features per time step.
        hidden_size: Size of the LSTM hidden state.
        br_output_size: Number of logits produced by the ``fc_br`` head.
        event_output_size: Number of logits produced by the ``fc_event`` head.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, br_output_size, event_output_size, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc_br = nn.Linear(hidden_size, br_output_size)
        self.fc_event = nn.Linear(hidden_size, event_output_size)

    def forward(self, x):
        """Run the sequence through the LSTM and classify from its last step.

        Args:
            x: Tensor of shape (batch, seq_len, input_size) -- the LSTM is
                built with ``batch_first=True``.

        Returns:
            Tuple ``(br_out, event_out)`` of logits with shapes
            (batch, br_output_size) and (batch, event_output_size).
        """
        # nn.LSTM zero-initialises the hidden and cell states when none are
        # passed, using the input's device and dtype. Building them by hand
        # with torch.zeros would pin them to the default dtype.
        out, _ = self.lstm(x)

        # Both heads share the representation of the last time step.
        last_step = out[:, -1, :]
        return self.fc_br(last_step), self.fc_event(last_step)

# Hyperparameters
input_size = 1  # One feature per time step
hidden_size = 128
# Size each head from the encoder's full class list, not from the labels that
# happen to occur in the training split. The class ids were assigned over the
# whole dataset, so a class missing from the training mice would otherwise
# leave the head too small for the ids it has to handle.
br_output_size = len(br_label_encoder.classes_)
event_output_size = len(event_label_encoder.classes_)
num_layers = 2  # Two-layer LSTM

# Fix the torch RNG so weight initialisation (and later batch shuffling) is
# reproducible across runs.
torch.manual_seed(42)

# Initialize the model, loss functions, and optimizer
model = LSTMModel(input_size, hidden_size, br_output_size, event_output_size, num_layers)
criterion_br = nn.CrossEntropyLoss()
criterion_event = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
num_epochs = 20
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)  # Move model to the configured device


def evaluate_heads(model, loader, device):
    """Compute the accuracy of both classification heads over a data loader.

    Args:
        model: Module returning ``(br_logits, event_logits)`` for a batch.
        loader: Iterable yielding ``(signals, br_targets, event_targets)``
            batches; signals get a trailing feature axis added before the
            forward pass.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(br_accuracy, event_accuracy)`` in percent. Both are NaN when
        the loader yields no samples.
    """
    model.eval()
    br_hits = event_hits = seen = 0
    with torch.no_grad():
        for signals, br_targets, event_targets in loader:
            signals = signals.to(device).unsqueeze(-1)
            br_targets = br_targets.to(device)
            event_targets = event_targets.to(device)

            br_outputs, event_outputs = model(signals)

            seen += br_targets.size(0)
            br_hits += (br_outputs.argmax(dim=1) == br_targets).sum().item()
            event_hits += (event_outputs.argmax(dim=1) == event_targets).sum().item()
    if seen == 0:
        return float('nan'), float('nan')
    return 100 * br_hits / seen, 100 * event_hits / seen


for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    running_loss = 0.0
    train_br_correct = 0
    train_event_correct = 0
    train_total = 0

    # The batch variables are named *_targets so they do not overwrite the
    # notebook-level `br_labels` / `event_labels` arrays.
    for signals, br_targets, event_targets in train_loader:
        # Move tensors to the configured device; add the feature axis the LSTM expects
        signals = signals.to(device).unsqueeze(-1)
        br_targets = br_targets.to(device)
        event_targets = event_targets.to(device)

        # Forward pass
        br_outputs, event_outputs = model(signals)
        br_loss = criterion_br(br_outputs, br_targets)
        event_loss = criterion_event(event_outputs, event_targets)
        # Optimise both heads. With the brain-region loss alone the event head
        # would never receive a gradient, and its reported accuracy would only
        # reflect its random initialisation.
        loss = br_loss + event_loss

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Accumulate training accuracy
        train_total += br_targets.size(0)
        train_br_correct += (br_outputs.argmax(dim=1) == br_targets).sum().item()
        train_event_correct += (event_outputs.argmax(dim=1) == event_targets).sum().item()

    train_br_accuracy = 100 * train_br_correct / train_total
    train_event_accuracy = 100 * train_event_correct / train_total

    # Evaluate on the held-out split after every epoch
    br_accuracy, event_accuracy = evaluate_heads(model, test_loader, device)

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}, '
          f'Train Accuracy for br_labels: {train_br_accuracy:.2f}%, '
          f'Train Accuracy for event_labels: {train_event_accuracy:.2f}%, '
          f'Test Accuracy for br_labels: {br_accuracy:.2f}%, '
          f'Test Accuracy for event_labels: {event_accuracy:.2f}%')

# Final evaluation on the test set after training
br_accuracy, event_accuracy = evaluate_heads(model, test_loader, device)
print(f'Final Test Accuracy for br_labels: {br_accuracy:.2f}%')
print(f'Final Test Accuracy for event_labels: {event_accuracy:.2f}%')
