In [11]:
# 1. IMPORTS
# ---
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchaudio
import os # To navigate file paths
import torch

# 这行代码会自动选择GPU（如果可用），否则退回到CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cuda


In [5]:
torch.cuda.is_available()


True

In [6]:
# 2. THE DATASET CLASS
# ---
# Its job: Load an audio file, convert it to a spectrogram, and return it with its numerical label.
class SpeechCommandsDataset(Dataset):
    def __init__(self, data_path):
        # TODO: Write logic here to find all audio files.
        # You'll also need to create a mapping from word labels to integers (e.g., "yes" -> 0, "no" -> 1).
        self.audio_paths = []
        self.label_map = {}
        self.transform = torchaudio.transforms.MelSpectrogram(n_mels=128)

    def __len__(self):
        return len(self.audio_paths)

    def __getitem__(self, index):
        audio_path = self.audio_paths[index]

        # Load the audio file
        waveform, sample_rate = torchaudio.load(audio_path)

        # Transform to spectrogram
        spectrogram = self.transform(waveform)

        # Get the label
        # TODO: Implement a way to get the word from the file path and convert it to a number.
        label = 0 # Placeholder

        return spectrogram, label


In [9]:
# 3. THE MODEL CLASS
# ---
# Its job: Define the Transformer architecture.
class AudioTransformer(nn.Module):
    def __init__(self, num_input_features=128, num_classes=35):
        super().__init__()
        # Using PyTorch's pre-built Transformer components
        encoder_layer = nn.TransformerEncoderLayer(d_model=num_input_features, nhead=8, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        self.output_layer = nn.Linear(num_input_features, num_classes)

    def forward(self, spectrogram_batch):
        # Input shape needs to be (batch, time, features) for batch_first=True
        # Spectrograms are often (batch, features, time), so we might need to permute
        x = spectrogram_batch.permute(0, 2, 1)

        x = self.transformer_encoder(x)
        x = x.mean(dim=1) # Average over the time dimension
        predictions = self.output_layer(x)
        return predictions

In [10]:
import torchaudio

print("Downloading SpeechCommands dataset...")

# This command will download the data to a folder named "SpeechCommands"
# in your current directory if it's not already there.
train_dataset = torchaudio.datasets.SPEECHCOMMANDS(root=".", download=True)

print("Download complete!")
print(f"Number of training samples: {len(train_dataset)}")


Downloading SpeechCommands dataset...


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2.26G/2.26G [00:25<00:00, 95.2MB/s]


Download complete!
Number of training samples: 105829


In [None]:
# 4. THE TRAINING SCRIPT
# ---
# This block runs when you execute the python file.
# if __name__ == '__main__':
# Instantiate the Dataset and DataLoader
dataset = SpeechCommandsDataset(data_path="path/to/SpeechCommands/data")
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Instantiate the Model, Loss Function, and Optimizer
model = AudioTransformer()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# The Training Loop
print("Starting training...")
for epoch in range(10): # An "epoch" is one full pass over the dataset
    for spectrograms, labels in data_loader:
        # 1. PREDICT: Pass data through the model
        predictions = model(spectrograms)

        # 2. COMPARE: Calculate the error
        loss = loss_fn(predictions, labels)

        # 3. ADJUST: Update the model's weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1} finished, Loss: {loss.item()}")

print("Training complete!")
