/data/
    /species_1/
        audio1.wav
        audio2.wav
    /species_2/
        audio1.wav
        audio2.wav


In [None]:
import torch
import torch.nn as nn
import torchaudio
import torchvision.models as models

# Define a BirdNET-style classifier (a pretrained ResNet-18 stands in for the real BirdNET architecture)
class BirdNET(nn.Module):
    """Bird-species classifier: a frozen pretrained ResNet-18 with a new head.

    Every backbone parameter is frozen; only the replacement final linear
    layer (created after the freeze, so it keeps ``requires_grad=True``) is
    trainable.

    Args:
        num_classes: Number of output classes (bird species).
    """

    def __init__(self, num_classes):
        super().__init__()
        # Using a ResNet backbone for example. Prefer the `weights=` API and
        # fall back to the legacy `pretrained=True` flag on older torchvision
        # releases that do not ship the weights enum. Both select the same
        # ImageNet checkpoint.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            self.backbone = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
        else:
            self.backbone = models.resnet18(pretrained=True)

        # Freeze the whole backbone to keep the pretrained weights
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Replace final layer to match number of bird species
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, num_classes)

    def forward(self, x):
        """Return class logits for a batch of spectrograms.

        Args:
            x: 4-D tensor ``(batch, channels, height, width)``. The backbone's
                first convolution expects 3 channels, so any other channel
                count (e.g. 1 for mono audio, 2 for stereo) is averaged to a
                single channel and replicated three times.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with unnormalised logits.
        """
        if x.dim() == 4 and x.size(1) != 3:
            x = x.mean(dim=1, keepdim=True).expand(-1, 3, -1, -1)
        return self.backbone(x)

# Number of output classes: one per bird species in your dataset
num_classes = 10  # Update with your specific number of species

# Build the classifier with a head sized for `num_classes`
model = BirdNET(num_classes=num_classes)


In [None]:
from torch.utils.data import DataLoader
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB
import os
from PIL import Image

# Data preprocessing
# Module-level transforms shared by every dataset item; both are built with
# the library's default settings.
amplitude_to_db = AmplitudeToDB()   # power spectrogram -> decibel scale
mel_spectrogram = MelSpectrogram()  # waveform -> mel-scaled spectrogram

class BirdDataset(torch.utils.data.Dataset):
    """Audio dataset laid out as one sub-directory of ``.wav`` files per species.

    Label ids are the positions of the species directory names in sorted
    order, so they are stable across runs and identical for any two roots
    (e.g. train and test) that contain the same species folders.

    Args:
        data_path: Root directory containing one folder per species.
        transform: Optional callable applied to each dB-scaled mel spectrogram.

    Attributes:
        classes: Sorted species directory names; ``classes[label]`` is the
            species name for a label id.
        data: Paths of the audio files.
        labels: Label id for each entry of ``data``.
    """

    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        self.data = []
        self.labels = []

        # Sort for deterministic label ids (os.listdir order is arbitrary) and
        # keep directories only, so stray files in the root (README,
        # .DS_Store, ...) neither crash the scan nor consume a label id.
        self.classes = sorted(
            entry for entry in os.listdir(data_path)
            if os.path.isdir(os.path.join(data_path, entry))
        )

        for label, species in enumerate(self.classes):
            species_path = os.path.join(data_path, species)
            for file_name in sorted(os.listdir(species_path)):
                # Case-insensitive so files named *.WAV are picked up as well
                if file_name.lower().endswith('.wav'):
                    self.data.append(os.path.join(species_path, file_name))
                    self.labels.append(label)

    def __len__(self):
        """Return the number of audio files found."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load one clip and return ``(spectrogram, label)``.

        The waveform is converted with the module-level ``mel_spectrogram``
        and ``amplitude_to_db`` transforms, then passed through
        ``self.transform`` when one was given.
        """
        audio_path = self.data[idx]
        waveform, sample_rate = torchaudio.load(audio_path)

        # The mel filterbank is built for one specific sample rate; resample
        # clips recorded at any other rate so frequency bins stay comparable.
        target_rate = getattr(mel_spectrogram, 'sample_rate', None)
        if target_rate is not None and sample_rate != target_rate:
            waveform = torchaudio.functional.resample(waveform, sample_rate, target_rate)

        mel_spec = mel_spectrogram(waveform)
        db_spec = amplitude_to_db(mel_spec)

        if self.transform:
            db_spec = self.transform(db_spec)

        label = self.labels[idx]
        return db_spec, label

# Dataset locations: placeholders, point these at your own data.
TRAIN_DATA_DIR = '/path/to/your/train_data'
TEST_DATA_DIR = '/path/to/your/test_data'  # held-out clips, same species folders
BATCH_SIZE = 16
NUM_FRAMES = 256  # spectrogram width (time frames) each clip is padded/cropped to; tune for your clips


def your_transform(db_spec):
    """Pad or crop a spectrogram along its last (time) axis to NUM_FRAMES.

    Clips of different durations produce spectrograms of different widths,
    which the DataLoader's default collate cannot stack into one batch.
    Replace this function with your own transform if you need augmentation
    or a different sizing strategy.
    """
    missing = NUM_FRAMES - db_spec.size(-1)
    if missing > 0:
        # Pad with the clip's own quietest value so padding reads as silence.
        fill = db_spec.min().item() if db_spec.numel() else 0.0
        db_spec = torch.nn.functional.pad(db_spec, (0, missing), value=fill)
    return db_spec[..., :NUM_FRAMES]


train_dataset = BirdDataset(TRAIN_DATA_DIR, transform=your_transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Held-out data for the evaluation cell; no shuffling needed for scoring.
test_dataset = BirdDataset(TEST_DATA_DIR, transform=your_transform)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
# Set up optimizer and loss
criterion = nn.CrossEntropyLoss()
# Only hand the optimizer parameters that can receive gradients; parameters
# with requires_grad=False would never be updated anyway.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=0.001)

# Run every batch on whatever device the model currently lives on.
device = next(model.parameters()).device

# Fine-tuning loop
num_epochs = 5  # Adjust based on your dataset size
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    # Fail with a clear message instead of a ZeroDivisionError below.
    if num_batches == 0:
        raise ValueError('train_loader produced no batches; check the training data path')

    print(f'Epoch {epoch+1}, Loss: {running_loss/num_batches}')


In [None]:
# Save the fine-tuned model
torch.save(model.state_dict(), 'fine_tuned_birdnet.pth')

# Evaluate the model
# This cell expects `test_loader` (a DataLoader over held-out data that yields
# (inputs, labels) batches like `train_loader`) to already exist.
model.eval()
device = next(model.parameters()).device  # score on the model's own device
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        # Predicted class = index of the largest logit for each sample
        predicted = model(inputs).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Fail with a clear message instead of a ZeroDivisionError below.
if total == 0:
    raise ValueError('test_loader produced no samples; check the test data path')

print(f'Accuracy: {100 * correct / total} %')


In [None]:
import mlflow
import mlflow.pytorch

# MLflow-tracked version of the fine-tuning loop. It reuses `model`,
# `criterion`, `optimizer`, `train_loader` and `num_epochs` from the cells
# above. NOTE: if the plain training loop has already been run, this cell
# trains the same `model` for a further `num_epochs` epochs, so the model
# logged here is not the checkpoint saved/evaluated earlier.
with mlflow.start_run():
    device = next(model.parameters()).device

    # Training loop with MLflow logging
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        # Log this epoch's own mean loss (max() guards an empty loader).
        mlflow.log_metric("training_loss", running_loss / max(num_batches, 1), step=epoch)

    # Log the final model
    mlflow.pytorch.log_model(model, "fine_tuned_birdnet")