In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import numpy as np
import pandas as pd

In [None]:
class ABRWavesDataset(Dataset):
    """Dataset of waveform / peak-mask pairs stored as rows of a CSV file.

    Each row must provide three columns:
      * 'waveform'      - space-separated numbers wrapped in one leading and
                          one trailing character (e.g. "[0.1 0.2 ...]").
      * 'mask'          - same text format; the position of its maximum is
                          used as the regression target.
      * 'sample_weight' - per-sample weight, returned unchanged.
    """

    def __init__(self, csv_file):
        # Keep the whole table in memory; rows are parsed lazily on access.
        self.df = pd.read_csv(csv_file)

    def __len__(self):
        # One sample per CSV row.
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Return (waveform, target, weight) for row ``idx``.

        waveform: float32 tensor with a leading channel dimension, shape (1, L).
        target:   0-dim float32 tensor holding the index of the mask maximum.
        weight:   the row's 'sample_weight' value as stored in the dataframe.
        """
        row = self.df.iloc[idx]

        # Strip the enclosing brackets, then parse the whitespace-separated values.
        signal = np.fromstring(row['waveform'][1:-1], dtype=float, sep=' ')
        peak_mask = np.fromstring(row['mask'][1:-1], dtype=float, sep=' ')

        # Indexing with None prepends the channel axis expected by Conv1d.
        waveform = torch.tensor(signal, dtype=torch.float32)[None, :]

        # The target is the position of the largest mask value (the peak).
        target = torch.tensor(np.argmax(peak_mask), dtype=torch.float32)

        return waveform, target, row['sample_weight']

In [None]:
# Define CNN Model with tunable hyperparameters
class CNN(nn.Module):
    """Two-stage 1-D CNN that regresses a single scalar from a 1-channel signal.

    Architecture: (Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(2) -> Dropout) x 2,
    followed by Linear(·, 128) -> ReLU -> Dropout -> Linear(128, 1).

    Args:
        filter1: number of output channels of the first convolution.
        filter2: number of output channels of the second convolution.
        dropout1: dropout probability after the first conv stage.
        dropout2: dropout probability after the second conv stage.
        dropout_fc: dropout probability after the first fully connected layer.
        input_length: number of samples per input signal. Defaults to 244,
            which reproduces the previously hard-coded ``filter2 * 61``
            flattened size (244 // 4 == 61).

    Raises:
        ValueError: if ``input_length`` is too short to survive both pooling
            stages (fewer than 4 samples).
    """

    def __init__(self, filter1, filter2, dropout1, dropout2, dropout_fc, input_length=244):
        super().__init__()

        # The convolutions preserve length (kernel 3, padding 1); each of the
        # two max-pools halves it with flooring, so the final length is L // 4.
        pooled_length = input_length // 4
        if pooled_length < 1:
            raise ValueError(
                f"input_length must be at least 4 to survive two pooling stages, got {input_length}"
            )

        self.conv1 = nn.Conv1d(in_channels=1, out_channels=filter1, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv1d(in_channels=filter1, out_channels=filter2, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(filter2 * pooled_length, 128)
        self.fc2 = nn.Linear(128, 1)
        self.dropout1 = nn.Dropout(dropout1)
        self.dropout2 = nn.Dropout(dropout2)
        self.dropout_fc = nn.Dropout(dropout_fc)
        self.batch_norm1 = nn.BatchNorm1d(filter1)
        self.batch_norm2 = nn.BatchNorm1d(filter2)

    def forward(self, x):
        """Map a batch of signals of shape (N, 1, input_length) to shape (N, 1)."""
        x = self.pool(nn.functional.relu(self.batch_norm1(self.conv1(x))))
        x = self.dropout1(x)
        x = self.pool(nn.functional.relu(self.batch_norm2(self.conv2(x))))
        x = self.dropout2(x)
        # Flatten everything except the batch axis. Unlike view(-1, in_features),
        # this cannot silently fold extra samples into the batch dimension: an
        # input of the wrong length fails loudly in fc1 instead.
        x = x.flatten(start_dim=1)
        x = nn.functional.relu(self.fc1(x))
        x = self.dropout_fc(x)
        x = self.fc2(x)
        return x

In [None]:
from sklearn.model_selection import train_test_split
import pandas as pd


def add_lab_balanced_weights(df):
    """Return a copy of ``df`` with a 'sample_weight' column that balances labs.

    Each row is weighted inversely to the number of rows from its lab, then
    the weights are rescaled so they sum to ``len(df)`` (mean weight of 1).
    Working on an explicit copy avoids assigning into a filtered view of the
    combined dataframe (pandas SettingWithCopyWarning).
    """
    weighted = df.copy()
    lab_counts = weighted['lab'].value_counts()
    inverse_weights = {lab: 1.0 / count for lab, count in lab_counts.items()}
    weight_sum = sum(inverse_weights[lab] for lab in weighted['lab'])
    normalized_weights = {
        lab: (w / weight_sum) * len(weighted) for lab, w in inverse_weights.items()
    }
    weighted['sample_weight'] = weighted['lab'].map(normalized_weights)
    return weighted


# Load the combined CSV file
combined_csv_file = 'wave_mask_pairs.csv'
combined_df = pd.read_csv(combined_csv_file)

# Label each row with its lab of origin and its mouse identifier.
# The filename length distinguishes Marcotti Lab (<=3 characters) from Manor Lab (>3);
# the mouse identifier is the first whitespace-separated token of the filename.
combined_df['lab'] = combined_df['filename'].apply(lambda x: 'Marcotti Lab' if len(x) <= 3 else 'Manor Lab')
combined_df['mouse'] = combined_df['filename'].apply(lambda x: x.split()[0])

# Split by mouse (not by row) so no animal appears in more than one set:
# first 80/20 into train+val / test, then 80/20 again into train / val.
unique_mice = combined_df['mouse'].unique()
train_mice, test_mice = train_test_split(unique_mice, test_size=0.2, random_state=42)
train_mice, val_mice = train_test_split(train_mice, test_size=0.2, random_state=42)

# Build each subset and compute its lab-balancing weights from its own lab counts.
train_df = add_lab_balanced_weights(combined_df[combined_df['mouse'].isin(train_mice)])
val_df = add_lab_balanced_weights(combined_df[combined_df['mouse'].isin(val_mice)])
test_df = add_lab_balanced_weights(combined_df[combined_df['mouse'].isin(test_mice)])

# Every row belongs to exactly one mouse, and every mouse to exactly one split.
assert len(train_df) + len(val_df) + len(test_df) == len(combined_df)

# Save the train, validation, and test sets to separate CSV files
train_df.to_csv('train_wave_mask_pairs.csv', index=False)
val_df.to_csv('val_wave_mask_pairs.csv', index=False)
test_df.to_csv('test_wave_mask_pairs.csv', index=False)

# Display the sizes of the train, validation, and test sets
print(f"Train set size: {len(train_df)}")
print(f"Validation set size: {len(val_df)}")
print(f"Test set size: {len(test_df)}")
print(f"Total size: {len(train_df) + len(val_df) + len(test_df)}")

In [None]:
from torch.utils.data import DataLoader

# Batch size shared by all loaders. It is defined before the loaders and
# actually passed to them; previously a `batch_size = 32` assignment came
# after the loaders, which hard-coded 16, so the variable had no effect.
# The value is 16 to keep the batch size the loaders were really using.
batch_size = 16

# Load datasets for training, validation, and test
train_dataset = ABRWavesDataset('train_wave_mask_pairs.csv')
val_dataset = ABRWavesDataset('val_wave_mask_pairs.csv')
test_dataset = ABRWavesDataset('test_wave_mask_pairs.csv')

# Create DataLoaders for each set; only the training data is shuffled.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Initialize model, loss function, and optimizer.
# CNN has no default hyperparameters, so a bare CNN() raises TypeError and
# stops a Restart & Run All. The values here are the ones the training
# configuration uses (filter1=128, filter2=32, dropouts 0.5 / 0.3 / 0.1).
model = CNN(filter1=128, filter2=32, dropout1=0.5, dropout2=0.3, dropout_fc=0.1)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
filter1 = 128
filter2 = 32
dropout1 = 0.5
dropout2 = 0.3
dropout_fc = 0.1
lr = 1e-3
weight_decay = 1e-5
optimizer_name = "Adam"
patience = 25  # epochs without validation improvement before stopping


def run_epoch(model, dataloader, criterion, device, optimizer=None):
    """Run one pass over ``dataloader`` and return (loss, acc_5, acc_10).

    When ``optimizer`` is given the model is put in training mode and updated
    after every batch; otherwise it is evaluated with gradients disabled.

    Returns:
        loss:   mean over batches of the sample-weighted MSE.
        acc_5:  percentage of samples predicted within 5 points of the target.
        acc_10: percentage of samples predicted within 10 points of the target.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum, correct_5, correct_10, n_samples = 0.0, 0, 0, 0
    with torch.set_grad_enabled(training):
        for waveform, target, weight in dataloader:
            waveform = waveform.to(device)
            target = target.to(device)
            # The default collate yields float64 weights; match the loss dtype.
            weight = weight.to(device=device, dtype=torch.float32)

            if training:
                optimizer.zero_grad()

            # squeeze(1) drops only the output-feature axis, so a final batch
            # of size 1 keeps shape (1,) and still lines up with `target`.
            outputs = model(waveform).squeeze(1)

            # criterion returns one loss per sample, so each sample is scaled
            # by its own weight before averaging.
            weighted_loss = (criterion(outputs, target) * weight).mean()

            if training:
                weighted_loss.backward()
                optimizer.step()

            loss_sum += weighted_loss.item()

            abs_error = (outputs.detach() - target).abs()
            correct_5 += (abs_error <= 5).sum().item()
            correct_10 += (abs_error <= 10).sum().item()
            n_samples += waveform.size(0)

    return (
        loss_sum / len(dataloader),
        (correct_5 / n_samples) * 100,
        (correct_10 / n_samples) * 100,
    )


# Model initialization
model = CNN(filter1, filter2, dropout1, dropout2, dropout_fc).to(device)

if optimizer_name == "Adam":
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
else:
    optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay, momentum=0.9)

# reduction='none' keeps per-sample losses. With the default 'mean' reduction
# the loss is already a scalar when it is multiplied by the weights, so the
# per-sample weighting would have no per-sample effect.
criterion = nn.MSELoss(reduction='none')

best_val_loss = float("inf")
best_model_state = None
early_stop_counter = 0

for epoch in range(1000):  # Early stopping will terminate if needed
    avg_train_loss, accuracy_5, accuracy_10 = run_epoch(
        model, train_dataloader, criterion, device, optimizer=optimizer
    )
    avg_val_loss, val_accuracy_5, val_accuracy_10 = run_epoch(
        model, val_dataloader, criterion, device
    )

    if (epoch + 1) % 10 == 0:
        print(f'Epoch {epoch + 1}: Accuracy within 5 points: {val_accuracy_5:.2f}%, Accuracy within 10 points: {val_accuracy_10:.2f}%')

    # Early stopping check
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        # state_dict() returns references to the live parameter tensors, which
        # later optimizer steps overwrite in place; clone to take a snapshot.
        best_model_state = {
            name: tensor.detach().clone() for name, tensor in model.state_dict().items()
        }
        early_stop_counter = 0
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            print(f"Early stopping at epoch {epoch + 1}")
            break

# Restore best weights. No snapshot exists if the validation loss never
# improved (e.g. it was NaN from the first epoch).
if best_model_state is not None:
    model.load_state_dict(best_model_state)


In [None]:
from sklearn.metrics import mean_absolute_error, mean_squared_error

# Predictions and targets are sample indices ("points"). The report below
# equates 10 points with 100/244 ms, i.e. one point is 10/244 ms.
MS_PER_POINT = 10 / 244
TOLERANCE_MS = 100 / 244  # 10 points expressed in milliseconds

# Lists of per-sample predictions and ground-truth values, in milliseconds
predictions = []
ground_truths = []

# Set the model to evaluation mode
model.eval()

# Run inputs on whatever device the model lives on (CPU or GPU).
eval_device = next(model.parameters()).device

# Disable gradient calculation
with torch.no_grad():
    for waveform, target, _ in test_dataloader:
        # Forward pass
        outputs = model(waveform.to(eval_device))

        # reshape(-1) always yields a 1-D array, even for a final batch of
        # size 1 (a bare squeeze() would give a 0-d array that extend() cannot
        # iterate). cpu() is required before numpy() when running on a GPU.
        predictions.extend(outputs.reshape(-1).cpu().numpy() * MS_PER_POINT)
        ground_truths.extend(target.reshape(-1).cpu().numpy() * MS_PER_POINT)

# Calculate Mean Absolute Error (MAE)
mae = mean_absolute_error(ground_truths, predictions)
print(f'Mean Absolute Error (MAE): {mae:.10f}')

# Calculate Root Mean Squared Error (RMSE).
# The `squared=False` argument was deprecated in scikit-learn 1.4 and removed
# in 1.6; taking the square root explicitly works on every version.
rmse = np.sqrt(mean_squared_error(ground_truths, predictions))
print(f'Root Mean Squared Error (RMSE): {rmse:.10f}')

# Fraction of predictions within 10 points of the ground truth
abs_errors = np.abs(np.array(predictions) - np.array(ground_truths))
accuracy_within_10 = np.mean(abs_errors <= TOLERANCE_MS)
print(f'Accuracy within {round(100/244,2)} ms or 10 points: {accuracy_within_10*100}')