In [10]:
pip install torch


Note: you may need to restart the kernel to use updated packages.


In [50]:
import os
import torch
import torchaudio
from torchaudio.transforms import MelSpectrogram, Resample
from sklearn.model_selection import train_test_split
import pandas as pd
import torch.nn.functional as F

# Define constants
SAMPLE_RATE = 16000  # Target sample rate for all audio files
N_MELS = 128         # Number of Mel filterbanks for spectrograms
DATA_DIR = "en"      # Path to the Common Voice English dataset
TARGET_LENGTH = 500  # Fixed length for the time dimension

def preprocess_audio(file_path, sample_rate=SAMPLE_RATE, target_length=TARGET_LENGTH):
    """
    Load and preprocess audio by resampling, converting to mel-spectrogram, 
    and ensuring a single channel with fixed dimensions.
    """
    waveform, orig_sample_rate = torchaudio.load(file_path)
    # Convert to mono (single channel)
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample to target sample rate if needed
    if orig_sample_rate != sample_rate:
        resampler = Resample(orig_freq=orig_sample_rate, new_freq=sample_rate)
        waveform = resampler(waveform)

    # Generate Mel-Spectrogram
    mel_transform = MelSpectrogram(sample_rate=sample_rate, n_mels=N_MELS)
    mel_spectrogram = mel_transform(waveform)  # Shape: [1, freq, time]

    # Pad or truncate the spectrogram to the target length
    if mel_spectrogram.size(2) < target_length:
        padding = target_length - mel_spectrogram.size(2)
        mel_spectrogram = F.pad(mel_spectrogram, (0, padding))  # Pad along the time dimension
    else:
        mel_spectrogram = mel_spectrogram[:, :, :target_length]  # Truncate if longer than target_length
    
    return mel_spectrogram


def load_metadata(data_dir):
    """
    Load metadata CSV containing audio file paths and corresponding transcriptions.
    """
    metadata_path = os.path.join(data_dir, "validated.tsv")
    metadata = pd.read_csv(metadata_path, sep="\t")
    # Keep only necessary columns
    metadata = metadata[["path", "sentence"]]
    return metadata


def prepare_dataset(data_dir, sample_rate=SAMPLE_RATE, test_size=0.2, val_size=0.1):
    """
    Load the dataset, preprocess audio, and split into train/val/test sets.
    """
    # Load metadata
    metadata = load_metadata(data_dir)
    # Prepend data directory to file paths
    metadata["path"] = metadata["path"].apply(lambda x: os.path.join(data_dir, "clips", x))
    
    # Preprocess audio files and create feature-label pairs
    features, labels = [], []
    for _, row in metadata.iterrows():
        try:
            mel_spectrogram = preprocess_audio(row["path"], sample_rate)
            features.append(mel_spectrogram)
            labels.append(row["sentence"])
        except Exception as e:
            print(f"Error processing file {row['path']}: {e}")
    
    # Split into train, validation, and test sets
    train_features, test_features, train_labels, test_labels = train_test_split(
        features, labels, test_size=test_size, random_state=42
    )
    train_features, val_features, train_labels, val_labels = train_test_split(
        train_features, train_labels, test_size=val_size, random_state=42
    )
    
    print(f"Dataset split: Train: {len(train_features)}, Val: {len(val_features)}, Test: {len(test_features)}")
    return (train_features, train_labels), (val_features, val_labels), (test_features, test_labels)


# Main
if __name__ == "__main__":
    # Prepare dataset
    print("Preparing dataset...")
    (train_data, train_labels), (val_data, val_labels), (test_data, test_labels) = prepare_dataset(DATA_DIR)
    
    # Save splits for later use
    torch.save({"features": train_data, "labels": train_labels}, "train_data.pt")
    torch.save({"features": val_data, "labels": val_labels}, "val_data.pt")
    torch.save({"features": test_data, "labels": test_labels}, "test_data.pt")
    
    print("Dataset preparation completed.")


Preparing dataset...




Dataset split: Train: 2453, Val: 273, Test: 682
Dataset preparation completed.


In [20]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torchvision

# Define MobileNetV2 + LSTM Model
class MobileNetLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes):
        super(MobileNetLSTM, self).__init__()
        
        # MobileNetV2 backbone
        self.mobilenet = torchvision.models.mobilenet_v2(pretrained=True).features
        self.mobilenet.eval()  # Set to evaluation mode

        # Assuming the output feature size after MobileNet is 1280 (channels)
        # Adjusting the LSTM input size
        self.input_dim = 1280  # This should be the last dimension of MobileNet output
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        
        # Fully connected layer after LSTM
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        # If input has 1 channel (grayscale), convert to 3 channels (RGB)
        if x.size(1) == 1:
            x = x.repeat(1, 3, 1, 1)  # Repeat the single channel to 3 channels
    
        # Pass through MobileNet
        x = self.mobilenet(x)
        x = x.mean(dim=2)  # Reducing the height dimension (averaging)
    
        # Debug: Print shape before passing to LSTM
        #print(f"Shape after MobileNet: {x.shape}")
        
        # Flatten the output to match the LSTM input size (batch_size, seq_len, feature_dim)
        x = x.view(x.size(0), -1, self.input_dim)  # Flatten to (batch_size, seq_len, 1280)
    
        # Debug: Print shape after reshaping for LSTM
        #print(f"Shape before LSTM: {x.shape}")
    
        # Pass through LSTM
        x, _ = self.lstm(x)
    
        # Pass through fully connected layer
        x = self.fc(x[:, -1, :])  # Take the last output of the LSTM sequence
        
        return x




# Training Function
def train_model(model, train_loader, val_loader, criterion, optimizer, epochs=10):
    """
    Train MobileNetV2 + LSTM model.
    """
    device = torch.device("cpu")
    model.to(device)

    for epoch in range(epochs):
        model.train()
        train_loss, train_correct = 0.0, 0
        for features, labels in train_loader:
            features, labels = features.to(device), labels.to(device)
            # Check label values
            assert labels.min() >= 0, f"Labels contain negative values: {labels.min()}"
            assert labels.max() < num_classes, f"Labels contain values greater than or equal to {num_classes}: {labels.max()}"
        
            optimizer.zero_grad()
            outputs = model(features)
            # Assert correct shape of outputs and labels
            assert outputs.size(0) == labels.size(0), f"Mismatch between batch size of outputs and labels: {outputs.size(0)} vs {labels.size(0)}"
            assert outputs.size(1) == num_classes, f"Mismatch between output size and number of classes: {outputs.size(1)} vs {num_classes}"
            
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            outputs = model(features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * features.size(0)
            _, preds = torch.max(outputs, 1)
            train_correct += (preds == labels).sum().item()

        # Validation
        model.eval()
        val_loss, val_correct = 0.0, 0
        with torch.no_grad():
            for features, labels in val_loader:
                features, labels = features.to(device), labels.to(device)
                outputs = model(features)

                # Assert labels are in range
                assert labels.min() >= 0 and labels.max() < num_classes, f"Labels must be between 0 and {num_classes - 1}. Found labels: {labels.min()} - {labels.max()}"
                
                loss = criterion(outputs, labels)

                val_loss += loss.item() * features.size(0)
                _, preds = torch.max(outputs, 1)
                val_correct += (preds == labels).sum().item()

        print(
            f"Epoch {epoch + 1}/{epochs}, "
            f"Train Loss: {train_loss / len(train_loader.dataset):.4f}, "
            f"Train Accuracy: {train_correct / len(train_loader.dataset):.4f}, "
            f"Val Loss: {val_loss / len(val_loader.dataset):.4f}, "
            f"Val Accuracy: {val_correct / len(val_loader.dataset):.4f}"
        )
    model.summary()
    print("Training completed.")
    return model
    
# Main Pipeline
if __name__ == "__main__":
    # Load preprocessed data
    train_data = torch.load("train_data.pt")
    val_data = torch.load("val_data.pt")
    
    #print("Train Data Keys:", train_data.keys())
    #print("Train Data Features Shape:", train_data["features"][0].shape if "features" in train_data else "Missing")
    #print("Train Data Labels:", train_data["labels"][:5] if "labels" in train_data else "Missing")
    
    #print("Validation Data Keys:", val_data.keys())
    #print("Validation Data Features Shape:", val_data["features"][0].shape if "features" in val_data else "Missing")
    #print("Validation Data Labels:", val_data["labels"][:5] if "labels" in val_data else "Missing")

        
    # Map labels to numerical indices if they are strings
    if isinstance(train_data["labels"][0], str):
        all_labels = train_data["labels"] + val_data["labels"]
        label_mapping = {label: idx for idx, label in enumerate(set(all_labels))}
        train_labels = torch.tensor([label_mapping[label] for label in train_data["labels"]])
        val_labels = torch.tensor([label_mapping[label] for label in val_data["labels"]])
    else:
        train_labels = torch.tensor(train_data["labels"])
        val_labels = torch.tensor(val_data["labels"])
    
    train_features = torch.stack(train_data["features"])  # Convert list to tensor
    val_features = torch.stack(val_data["features"])  # Convert list to tensor
    
    # Now you can proceed with training
    print("Train Features Shape:", train_features.shape)
    print("Train Labels Shape:", train_labels.shape)

    # Create DataLoaders
    train_dataset = TensorDataset(train_features, train_labels)
    val_dataset = TensorDataset(val_features, val_labels)
    
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
    
    # Define device
    device = torch.device("cpu")
    print("Using device:", device)
    print(f"Unique labels: {torch.unique(labels)}")
    # Example usage:
    input_dim = 1280  # MobileNet output channels
    hidden_dim = 128  # LSTM hidden size
    # Combine all unique labels from both datasets
    all_labels = torch.cat([train_labels, val_labels])
    unique_labels = torch.unique(all_labels)
    label_mapping = {label.item(): idx for idx, label in enumerate(unique_labels)}
    
    # Map train and validation labels
    train_labels = torch.tensor([label_mapping[label.item()] for label in train_labels])
    val_labels = torch.tensor([label_mapping[label.item()] for label in val_labels])
    
    # Update num_classes based on unique labels
    num_classes = len(unique_labels)
    print("Updated num_classes:", num_classes)

    for batch_idx, (features, labels) in enumerate(train_loader):
        # Print unique labels in the batch for debugging
        print(f"Batch {batch_idx+1} labels: {torch.unique(labels)}")
        
        # Check if labels are in the correct range
        assert labels.min() >= 0, f"Labels contain negative values: {labels.min()}"
        assert labels.max() < num_classes, f"Labels contain values greater than or equal to {num_classes}: {labels.max()}"

    model = MobileNetLSTM(input_dim=input_dim, hidden_dim=hidden_dim, num_classes=num_classes)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
        
    for param in model.parameters():
        if param.grad is not None:
            if torch.isnan(param.grad).any():
                print("NaN detected in gradients!")
            if torch.isinf(param.grad).any():
                print("Inf detected in gradients!")

    # Train the model
    trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, epochs=10)
    for param in model.parameters():
        if param.grad is not None:
            if torch.isnan(param.grad).any():
                print("NaN detected in gradients!")
            if torch.isinf(param.grad).any():
                print("Inf detected in gradients!")

    print("Model device:", next(model.parameters()).device)
    print("Input device:", features.device)
    
    # Save the trained model
    torch.save(trained_model.state_dict(), "mobilenet_lstm.pth")
    converted_labels = [label_mapping[label] for label in original_labels]

    torch.save({"features": features, "labels": converted_labels}, "train_data.pt")


  train_data = torch.load("train_data.pt")
  val_data = torch.load("val_data.pt")


Train Features Shape: torch.Size([2453, 1, 128, 500])
Train Labels Shape: torch.Size([2453])
Using device: cpu
Unique labels: tensor([  23,   53,   76,   97,  242,  482,  489,  734,  751,  831, 1060, 1316,
        1596, 1630, 1665, 1765, 2056, 2175, 2176, 2306, 2494])
Updated num_classes: 2725
Batch 1 labels: tensor([ 109,  295,  311,  428,  459,  524,  536,  748,  819,  822,  899, 1001,
        1104, 1368, 1372, 1523, 1649, 1776, 1857, 1871, 1926, 1947, 1970, 1979,
        2071, 2332, 2443, 2472, 2562, 2585, 2638, 2656])
Batch 2 labels: tensor([  47,  148,  220,  423,  436,  505,  622,  629,  728,  814,  836, 1080,
        1285, 1595, 1596, 1599, 1668, 1810, 1902, 1958, 2118, 2128, 2191, 2303,
        2389, 2470, 2506, 2537, 2538, 2632, 2670, 2708])
Batch 3 labels: tensor([ 380,  439,  474,  494,  725,  745,  990, 1013, 1128, 1178, 1279, 1315,
        1366, 1405, 1410, 1451, 1527, 1608, 1650, 1701, 1786, 2041, 2188, 2301,
        2306, 2329, 2415, 2430, 2445, 2454, 2478, 2576])
Batch 

In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import numpy as np
from sklearn.model_selection import train_test_split
from collections import defaultdict

# Define the Few-Shot Learner combining MAML and Prototypical Networks
class FewShotLearner(nn.Module):
    def __init__(self, base_model, feature_dim):
        super(FewShotLearner, self).__init__()
        self.base_model = base_model  # MobileNetV2 + LSTM backbone
        self.feature_dim = feature_dim
        self.protonet_fc = nn.Linear(feature_dim, feature_dim)

    def forward(self, x):
        # Extract features
        features = self.base_model(x)
        # Prototype network uses transformed embeddings
        return self.protonet_fc(features)

    def get_embeddings(self, x):
        # Directly return base model features
        return self.base_model(x)

# Helper functions for Prototypical Networks
def compute_prototypes(support_features, support_labels):
    """
    Compute class prototypes as the mean feature vector for each class in the support set.
    """
    prototypes = defaultdict(list)
    for feature, label in zip(support_features, support_labels):
        prototypes[label.item()].append(feature)
    return {k: torch.stack(v).mean(0) for k, v in prototypes.items()}

def prototypical_loss(query_features, query_labels, prototypes):
    """
    Compute the Prototypical Networks loss using distances to prototypes.
    """
    dists = torch.stack([
        torch.cdist(query_features, prototype.unsqueeze(0)) for prototype in prototypes.values()
    ], dim=1).squeeze(2)
    labels = torch.tensor(list(prototypes.keys()), device=query_features.device)
    return nn.CrossEntropyLoss()(dists, query_labels)

# MAML Training Loop
def maml_train(model, task_loader, optimizer, meta_lr=0.001, inner_steps=1, inner_lr=0.01):
    """
    Train using MAML on sampled tasks.
    """
    meta_optimizer = optim.Adam(model.parameters(), lr=meta_lr)
    for task in task_loader:
        support, query = task
        support_features, support_labels = support
        query_features, query_labels = query

        # Clone model for inner-loop updates
        maml_model = FewShotLearner(model.base_model, model.feature_dim)
        maml_model.load_state_dict(model.state_dict())
        maml_optimizer = optim.SGD(maml_model.parameters(), lr=inner_lr)

        # Inner loop updates
        maml_model.train()
        for _ in range(inner_steps):
            maml_optimizer.zero_grad()
            support_preds = maml_model(support_features)
            loss = nn.CrossEntropyLoss()(support_preds, support_labels)
            loss.backward()
            maml_optimizer.step()

        # Meta-update
        maml_model.eval()
        query_preds = maml_model(query_features)
        meta_loss = nn.CrossEntropyLoss()(query_preds, query_labels)

        meta_optimizer.zero_grad()
        meta_loss.backward()
        meta_optimizer.step()

# Prototypical Training Loop
def prototypical_train(model, task_loader, optimizer):
    """
    Train using Prototypical Networks on sampled tasks.
    """
    for task in task_loader:
        support, query = task
        support_features, support_labels = support
        query_features, query_labels = query

        # Compute prototypes
        model.eval()
        support_embeddings = model.get_embeddings(support_features)
        prototypes = compute_prototypes(support_embeddings, support_labels)

        # Prototypical loss
        query_embeddings = model.get_embeddings(query_features)
        proto_loss = prototypical_loss(query_embeddings, query_labels, prototypes)

        # Backward pass
        optimizer.zero_grad()
        proto_loss.backward()
        optimizer.step()

# Combined Training
def train_combined(model, task_loader, maml_steps=1, proto_steps=1, epochs=5):
    """
    Train the model using both MAML and Prototypical Networks.
    """
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        for step in range(maml_steps):
            maml_train(model, task_loader, optimizer)
        for step in range(proto_steps):
            prototypical_train(model, task_loader, optimizer)
    print("Training completed.")

# Task Loader (Simulating Few-Shot Tasks)
def create_task_loader(features, labels, num_tasks=100, k_shot=5, q_query=15):
    """
    Generate tasks for meta-learning with support and query splits.
    """
    task_loader = []
    unique_labels = list(set(labels.numpy()))
    for _ in range(num_tasks):
        selected_labels = np.random.choice(unique_labels, size=len(unique_labels), replace=False)
        support, query = [], []
        for label in selected_labels:
            indices = np.where(labels.numpy() == label)[0]
            np.random.shuffle(indices)
            support.extend(features[indices[:k_shot]])
            query.extend(features[indices[k_shot:k_shot + q_query]])
        task_loader.append(((torch.stack(support), torch.tensor(selected_labels)), (torch.stack(query), torch.tensor(selected_labels))))
    return task_loader

# Main Pipeline
if __name__ == "__main__":
    # Load preprocessed single-word dataset
    data = torch.load("single_word_data.pt")
    features = torch.stack(data["features"])  # Audio features (mel-spectrograms)
    labels = torch.tensor(data["labels"])     # Word labels

    # Train/Validation Split
    train_features, val_features, train_labels, val_labels = train_test_split(features, labels, test_size=0.2, random_state=42)

    # Create Task Loaders
    task_loader = create_task_loader(train_features, train_labels)

    # Load pretrained MobileNetV2+LSTM model
    pretrained_model = MobileNetLSTM(input_dim=features.size(-1), hidden_dim=128, num_classes=len(set(labels.numpy())))
    pretrained_model.load_state_dict(torch.load("mobilenet_lstm.pth"))

    # Few-Shot Learner
    feature_dim = 256  # Dimension after embedding
    few_shot_model = FewShotLearner(pretrained_model, feature_dim)

    # Train the combined model
    train_combined(few_shot_model, task_loader)


Unique labels: tensor([   0,    1,    2,  ..., 2722, 2723, 2724])
Number of unique classes: 2452
