# In [None]:
import os
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from PIL import Image
import numpy as np

class MelSpectrogramDataset(Dataset):
    """Dataset of mel-spectrogram images stored as ``.png`` files in a folder.

    Each item is an ``(image, label)`` pair. The image is opened with PIL and
    converted to RGB, then passed through ``transform`` when one was given.
    The label is produced by :meth:`get_label_from_filename`.
    """

    def __init__(self, mel_folder, transform=None):
        """Index the ``.png`` files found directly inside ``mel_folder``.

        Args:
            mel_folder: Directory containing the spectrogram images.
            transform: Optional callable applied to each loaded RGB image.
        """
        self.mel_folder = mel_folder
        self.transform = transform
        # os.listdir() returns entries in arbitrary order; sorting makes an
        # index refer to the same file on every run and every machine.
        self.mel_files = sorted(
            f for f in os.listdir(mel_folder) if f.endswith('.png')
        )

    def __len__(self):
        """Return the number of indexed spectrogram files."""
        return len(self.mel_files)

    def __getitem__(self, idx):
        """Load the ``idx``-th spectrogram and return ``(image, label)``."""
        mel_file = self.mel_files[idx]
        mel_path = os.path.join(self.mel_folder, mel_file)

        # Load the mel spectrogram. convert() returns a new, fully loaded
        # image, so the underlying file can be closed right away instead of
        # staying open until garbage collection.
        with Image.open(mel_path) as raw_image:
            image = raw_image.convert('RGB')
        label = self.get_label_from_filename(mel_file)

        if self.transform is not None:
            image = self.transform(image)

        return image, label

    def get_label_from_filename(self, filename):
        """Return the label encoded in ``filename``.

        The naming scheme is project-specific, so this must be implemented
        (here or in a subclass) before items can be fetched.

        Raises:
            NotImplementedError: Always, until label extraction is provided.
                Failing here avoids silently yielding ``None`` labels.
        """
        raise NotImplementedError(
            'get_label_from_filename must be implemented to extract the '
            f'label from {filename!r}'
        )

# Cross-validation
def _pil_to_tensor(image):
    """Convert an RGB PIL image into a float32 CHW tensor scaled to [0, 1]."""
    # np.array() copies the pixel data, so torch.from_numpy() receives a
    # writable buffer.
    pixels = torch.from_numpy(np.array(image))
    return pixels.permute(2, 0, 1).float().div(255.0)


fold_acc = []
for fold in range(1, 6):
    print(f'Testing fold {fold}')

    # Create dataset for training and validation. A transform is required:
    # without one the dataset yields PIL images, which the DataLoader's
    # default collate function cannot stack into a batch.
    train_dataset = MelSpectrogramDataset(
        mel_folder=f'/path/to/train_fold{fold}', transform=_pil_to_tensor
    )
    val_dataset = MelSpectrogramDataset(
        mel_folder=f'/path/to/val_fold{fold}', transform=_pil_to_tensor
    )

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    model = ESCModel(pretrained=True)  # Adjust according to your model
    # NOTE(review): `optimizer`, `criterion` and `num_epochs` are not defined
    # in this section and must exist before the loop runs. Because a fresh
    # model is created for every fold, an optimizer built earlier around
    # another model's parameters would not update this one -- confirm the
    # optimizer is (re)created from `model.parameters()` for each fold.
    model.train()

    # Training loop
    for epoch in range(num_epochs):
        for images, labels in tqdm(train_loader):
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    # Evaluation
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for images, labels in val_loader:
            outputs = model(images)
            # Index of the highest score along dim 1 is the predicted class;
            # gradients are already disabled, so `.data` is unnecessary.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        # Report the real problem instead of a bare ZeroDivisionError.
        raise ValueError(
            f'No validation samples found for fold {fold}; '
            'cannot compute accuracy.'
        )

    accuracy = correct / total
    fold_acc.append(accuracy)
    print(f'Fold {fold} Accuracy: {accuracy:.4f}')

# Average accuracy across folds
average_accuracy = np.mean(fold_acc)
print(f'Average Accuracy: {average_accuracy:.4f}')