# DeepInfant: Infant Cry Classification

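This notebook demonstrates how to train and evaluate the DeepInfant model for infant cry classification. It relies on `train.py` (model, dataset and training loop) and `predict.py` (single-file inference helper) from this repository, and expects the audio data under `Data/v2`.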

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader

from predict import InfantCryPredictor
from train import DeepInfantDataset, DeepInfantModel, train_model

## 1. Data Preparation

In [None]:
# Seed torch's global RNG so DataLoader shuffling and the model weight
# initialisation in the next cell are repeatable under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

DATA_DIR = 'Data/v2'
TRAIN_FRACTION = 0.8
BATCH_SIZE = 32
NUM_WORKERS = 4

# Create dataset
dataset = DeepInfantDataset(DATA_DIR)
assert len(dataset) > 0, f"no samples found under {DATA_DIR!r}"

train_size = int(TRAIN_FRACTION * len(dataset))
val_size = len(dataset) - train_size
# A dedicated generator pins the split even if other code consumes the global RNG.
split_generator = torch.Generator().manual_seed(SEED)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Create data loaders (only the training split is shuffled)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

## 2. Model Training

In [None]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')

# Model, loss function and optimiser.
LEARNING_RATE = 1e-3
model = DeepInfantModel()
model = model.to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Per-epoch metric history, one empty list per series.
history = {key: [] for key in ('train_loss', 'train_acc', 'val_loss', 'val_acc')}

In [None]:
from train import train_model

# Train the model
train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=50, device=device)

## 3. Training History

In [None]:
def plot_training_history(history):
    """Plot per-epoch loss and accuracy curves for the training and validation splits.

    Args:
        history: mapping with the keys 'train_loss', 'val_loss', 'train_acc' and
            'val_acc', each a sequence of per-epoch values.

    Returns:
        None. The figure is displayed with ``plt.show()``. If every series is
        empty, a message is printed instead of drawing blank axes.
    """
    series_keys = ('train_loss', 'val_loss', 'train_acc', 'val_acc')
    if not any(len(history.get(key, [])) for key in series_keys):
        print('Training history is empty - nothing to plot.')
        return

    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

    # Both panels share the same layout: train vs. validation curve for one metric.
    for ax, metric, title in ((loss_ax, 'loss', 'Loss'), (acc_ax, 'acc', 'Accuracy')):
        ax.plot(history[f'train_{metric}'], label='Train')
        ax.plot(history[f'val_{metric}'], label='Validation')
        ax.set(title=title, xlabel='Epoch', ylabel=title)
        ax.legend()

    fig.tight_layout()
    plt.show()


plot_training_history(history)

## 4. Confusion Matrix

In [None]:
def plot_confusion_matrix(model, val_loader, device, class_names=None):
    """Run `model` over `val_loader` and draw the confusion matrix as a heatmap.

    Args:
        model: trained classifier. It is put into eval mode and left there.
        val_loader: iterable of ``(inputs, labels)`` batches.
        device: device the inputs are moved to before the forward pass.
        class_names: optional sequence of display names indexed by class id.
            When given, it fixes the matrix size and supplies the tick labels.
            Otherwise the size is the model's output width and ticks are class ids.

    Returns:
        None. The heatmap is displayed with ``plt.show()``. A message is printed
        instead if the loader yields no samples.
    """
    model.eval()
    all_preds = []
    all_labels = []
    num_outputs = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            # NOTE(review): adds an axis at dim 1 — assumes the model was trained
            # on inputs shaped the same way; confirm against train_model.
            inputs = inputs.unsqueeze(1).to(device)
            outputs = model(inputs)
            num_outputs = outputs.shape[1]
            all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if not all_labels:
        print('Validation loader yielded no samples - nothing to plot.')
        return

    # List every class id explicitly. Otherwise sklearn keeps only the classes it
    # observes, and rows/columns stop lining up with class ids whenever a class is
    # missing from this split.
    num_classes = len(class_names) if class_names is not None else num_outputs
    class_ids = list(range(num_classes))
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    tick_labels = list(class_names) if class_names is not None else class_ids

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=tick_labels, yticklabels=tick_labels)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.show()


plot_confusion_matrix(model, val_loader, device)

## 5. Test Predictions

In [None]:
# Inputs for the single-file demo; point TEST_AUDIO_FILE at a real recording.
# NOTE(review): assumes train.py saves the trained weights to this path — confirm.
MODEL_WEIGHTS = Path('deepinfant.pth')
TEST_AUDIO_FILE = Path("path/to/test/audio.wav")

# Skip, rather than crash Restart & Run All, when either input is missing.
missing_files = [str(path) for path in (MODEL_WEIGHTS, TEST_AUDIO_FILE) if not path.exists()]
if missing_files:
    print(f"Skipping prediction demo; file(s) not found: {', '.join(missing_files)}")
else:
    predictor = InfantCryPredictor(str(MODEL_WEIGHTS))
    label, confidence = predictor.predict(str(TEST_AUDIO_FILE))
    print(f"Prediction: {label} (Confidence: {confidence:.2%})")