# Libraries

In [4]:
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision import transforms
import torchaudio.transforms as T
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import torch.optim as optim
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
from ipywidgets import widgets
from IPython.display import Audio, clear_output, Image

# Dataset class

In [2]:
class DS(Dataset):
    """Image-classification dataset described by a CSV file.

    The CSV is read with ``pd.read_csv``. Two columns are used: ``image_path``
    (file opened for each sample) and ``label`` (class name). Class names are
    mapped to integer indices in sorted order.
    """

    def __init__(self, csv, transform=None):
        """Read the CSV and build the label <-> index lookups.

        Args:
            csv: Path (or anything ``pd.read_csv`` accepts) of the CSV file.
            transform: Optional callable applied to every loaded PIL image.
        """
        self.csv = csv
        self.df = pd.read_csv(csv)
        self.transform = transform
        # label to index mapping; sorting makes the mapping deterministic
        labels = sorted(self.df['label'].unique())
        self.label_to_idx = {label: idx for idx, label in enumerate(labels)}
        self.idx_to_label = dict(enumerate(labels))
        self.class_names = list(labels)
        print(f'\nDataset initialized with: {len(self.df)} samples')
        print(f'Label mapping: {self.label_to_idx}')
        # Double quotes inside the f-string: reusing the outer quote character
        # is a syntax error on Python versions before 3.12.
        print(f'Class distribution:\n{self.df["label"].value_counts()}\n')

    # get total number of samples
    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.df)

    # get a single example (image tensor + label index) for a given index
    def __getitem__(self, idx):
        """Return ``(image, label_index)`` for the row at position ``idx``.

        The image is opened with PIL, converted to RGB and passed through
        ``self.transform`` when one is set.
        """
        # Imported here under an alias: at module level the name ``Image`` is
        # rebound by ``from IPython.display import ..., Image``, and that class
        # has no ``open``.
        from PIL import Image as PILImage

        row = self.df.iloc[idx]
        # load image
        image = PILImage.open(row['image_path']).convert('RGB')
        # apply transforms
        if self.transform:
            image = self.transform(image)
        # get label
        label = self.label_to_idx[row['label']]
        return image, label

# Data preparation and loaders
## Data transform & augmentation

In [3]:
def get_transforms(normalize_mean, normalize_std):
    """Build the image preprocessing pipelines for training and validation.

    Each pipeline resizes to 224x224, converts the PIL image to a tensor and
    normalizes every colour channel. No augmentation is active right now, so
    the two pipelines perform the same steps (they are still separate objects).

    Args:
        normalize_mean: Per-channel means passed to ``transforms.Normalize``.
        normalize_std: Per-channel standard deviations passed to
            ``transforms.Normalize``.

    Returns:
        Tuple ``(train_transform, val_transform)``.
    """

    def build_pipeline():
        steps = [
            transforms.Resize((224, 224)),
            # PIL Image (H×W×C, values 0-255) → PyTorch tensor (C×H×W, values 0.0-1.0)
            transforms.ToTensor(),
            # per channel: normalized_value = (pixel_value - mean) / std
            transforms.Normalize(mean=normalize_mean, std=normalize_std),
        ]
        return transforms.Compose(steps)

    # Augmentations that have been tried for training and are currently off
    # (they would go between Resize and ToTensor):
    #   transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
    #   transforms.RandomGrayscale(p=0.3)
    train_transform = build_pipeline()
    val_transform = build_pipeline()
    return train_transform, val_transform

## Data loaders (train & valid sets)

In [4]:
def prepare_data_loaders(csv,
                         val_split=0.2, 
                         batch_size=32, 
                         normalize_mean=[0.485, 0.456, 0.406], 
                         normalize_std=[0.229, 0.224, 0.225]
                        ):
    """Create the training and validation ``DataLoader`` objects from one CSV.

    The dataset is split with a fixed seed (42), so the same rows end up in
    each subset on every call with the same CSV and ``val_split``.

    Args:
        csv: CSV file handed to ``DS``.
        val_split: Fraction of samples assigned to the validation subset.
        batch_size: Batch size used by both loaders.
        normalize_mean: Per-channel means for normalization (not modified).
        normalize_std: Per-channel standard deviations (not modified).

    Returns:
        Tuple ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    import copy

    # create full ds
    full_ds = DS(csv, transform=None)

    # split into train and validation
    val_size = int(val_split * len(full_ds))
    train_size = len(full_ds) - val_size
    train_ds, val_ds = random_split(full_ds, [train_size, val_size], generator=torch.Generator().manual_seed(42))
    print(f'\nDataset splits (val_split={val_split}):')
    print(f'Training set: {train_size}')
    print(f'Validation set: {val_size}\n')

    # dataset transform
    train_transform, val_transform = get_transforms(normalize_mean, normalize_std)
    # Both subsets returned by random_split wrap the SAME dataset object, so
    # setting `.dataset.transform` on one and then on the other leaves only the
    # last assignment in effect. Give the validation subset its own shallow
    # copy (the DataFrame is shared, only the transform differs).
    full_ds.transform = train_transform
    val_base = copy.copy(full_ds)
    val_base.transform = val_transform
    val_ds.dataset = val_base

    # create data loaders
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader

# Model Architecture (CNN Class)

In [5]:
class SpectrogramCNN(nn.Module):
    """Small CNN classifier for 3-channel spectrogram images.

    Four blocks of Conv(3x3, padding 1) -> BatchNorm -> ReLU -> MaxPool(2x2)
    are followed by global average pooling and three fully connected layers.
    ``forward`` returns raw logits (no softmax).
    """

    def __init__(self, dropout_rate=0.5, num_classes=2):
        """Create the layers.

        Args:
            dropout_rate: Dropout probability used before ``fc1`` and ``fc2``.
            num_classes: Number of output logits. Defaults to 2, which matches
                the previously hard-coded output size.
        """
        super().__init__()

        # 1) convolutional blocks (for each block: ConvLayer + BatchNorm + MaxPooling)
        # convolutional layers
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        # batch normalization
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        # max pooling (one module reused after every block; it has no weights)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # 2) adaptive (global) pooling layer: reduces every channel to 1x1,
        # so the size of the fc1 input does not depend on the image size
        self.global_pool = nn.AdaptiveAvgPool2d((1,1))

        # 3) fully connected layers
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, num_classes)  # one logit per class
        # dropout
        self.dropout = nn.Dropout(dropout_rate)

        print('\nSpectrogram (224x224) CNN model')
        print('Architecture:')
        print('- 4 convolutional blocks (BatchNorm, ReLU, MaxPool)')
        print('- global average pooling')
        print('- 3 fully connected layers with dropout\n')

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        # convolutional blocks
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = self.pool(F.relu(self.bn4(self.conv4(x))))

        # adaptive (global) pooling layer
        x = self.global_pool(x)
        # flatten to (batch, features)
        x = x.view(x.size(0), -1)

        # fully connected layers (dropout is applied to the input of fc1 and fc2)
        x = F.relu(self.fc1(self.dropout(x)))
        x = F.relu(self.fc2(self.dropout(x)))
        x = self.fc3(x)

        return x

# Training Function

In [6]:
def train_epoch(model, train_loader, criterion, optimizer):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: Network to optimise; switched to training mode here.
        train_loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates the model's weights.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: the mean of the per-batch losses
        and the accuracy in percent over all processed samples.
    """
    model.train()
    running_loss = 0.0  # sum of the per-batch losses
    seen = 0            # samples processed so far
    correct = 0         # correct predictions so far

    # progress bar that disappears once the epoch is complete
    progress = tqdm(train_loader, desc='Training', leave=False)

    for images, labels in progress:
        # gradients accumulate by default, so clear them first
        optimizer.zero_grad()
        # forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # backward pass and weight update
        loss.backward()
        optimizer.step()

        # running statistics
        batch_loss = loss.item()  # loss tensor -> plain Python number
        running_loss += batch_loss
        seen += labels.size(0)  # first dimension = samples in this batch
        predicted = torch.max(outputs.data, 1)[1]  # index of the largest logit per row
        correct += (predicted == labels).sum().item()

        progress.set_postfix({'loss': f'{batch_loss:.4f}', 'acc': f'{100. * correct / seen:.2f}%'})

    # epoch averages: loss per batch, accuracy per sample
    epoch_loss = running_loss / len(train_loader)
    epoch_acc = 100. * correct / seen
    return epoch_loss, epoch_acc

# Validation Function

In [7]:
def validate_epoch(model, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` without updating any weights.

    Args:
        model: Network to evaluate; switched to evaluation mode here.
        val_loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.

    Returns:
        Tuple ``(epoch_loss, epoch_acc, all_predictions, all_labels)``: the
        mean of the per-batch losses, the accuracy in percent, and flat lists
        holding the predicted and the true class index of every sample.
    """
    model.eval()
    running_loss = 0.0  # sum of the per-batch losses
    seen = 0            # samples processed so far
    correct = 0         # correct predictions so far
    all_predictions, all_labels = [], []

    # progress-bar-wrapped version of val_loader
    progress = tqdm(val_loader, desc='Validation', leave=False)

    with torch.no_grad():  # gradients are not needed for evaluation
        for images, labels in progress:
            # forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # running statistics
            batch_loss = loss.item()  # loss tensor -> plain Python number
            running_loss += batch_loss
            seen += labels.size(0)  # first dimension = samples in this batch
            predicted = torch.max(outputs.data, 1)[1]  # index of the largest logit per row
            correct += (predicted == labels).sum().item()

            # keep per-sample results (elements of the NumPy arrays)
            all_predictions.extend(predicted.numpy())
            all_labels.extend(labels.numpy())

            progress.set_postfix({'loss': f'{batch_loss:.4f}', 'acc': f'{100. * correct / seen:.2f}%'})

    # epoch averages: loss per batch, accuracy per sample
    epoch_loss = running_loss / len(val_loader)
    epoch_acc = 100. * correct / seen
    return epoch_loss, epoch_acc, all_predictions, all_labels

# Main train loop

In [4]:
def train_model(model,
                train_loader,
                val_loader,
                criterion,
                optimizer,
                num_epochs,
                scheduler,
                patience_epochs=10,
                save_path='/Users/hela/Code/pata/best_model.pth'
               ):
    """Train with per-epoch validation, checkpointing and early stopping.

    Args:
        model: Network to train.
        train_loader: Loader passed to ``train_epoch``.
        val_loader: Loader passed to ``validate_epoch``.
        criterion: Loss function.
        optimizer: Optimizer; its state is stored in the checkpoint.
        num_epochs: Maximum number of epochs.
        scheduler: Learning-rate scheduler whose ``step`` receives the
            validation loss.
        patience_epochs: Stop after this many consecutive epochs without a
            new best validation accuracy.
        save_path: File the best checkpoint is written to.

    Returns:
        Dict with the per-epoch lists ``train_loss``, ``train_acc``,
        ``val_loss`` and ``val_acc`` (for plotting).
    """
    history = {'train_loss':[], 'train_acc':[], 'val_loss':[], 'val_acc':[]}
    best_val_acc = 0.0
    stale_epochs = 0  # consecutive epochs without improvement
    banner = '='*70

    print('\n'+banner)
    print('Training')
    print(banner+'\n')

    for epoch in range(num_epochs):
        print(f'\nEpoch [{epoch+1}/{num_epochs}]')

        # one training pass, then one validation pass
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_acc, _, _ = validate_epoch(model, val_loader, criterion)
        # let the scheduler react to the validation loss
        scheduler.step(val_loss)

        # report
        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
        print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")

        # record history for graphs
        epoch_metrics = (
            ('train_loss', train_loss),
            ('train_acc', train_acc),
            ('val_loss', val_loss),
            ('val_acc', val_acc),
        )
        for key, value in epoch_metrics:
            history[key].append(value)

        # checkpoint only on a strictly better validation accuracy
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            checkpoint = {'epoch':epoch,
                          'model_state_dict': model.state_dict(),
                          'optimizer_state_dict': optimizer.state_dict(),
                          'val_acc': val_acc,
                          'val_loss': val_loss}
            torch.save(checkpoint, save_path)
            print(f"✓ Model saved. (Val Acc: {val_acc:.2f}%)")
            stale_epochs = 0
        else:
            stale_epochs += 1

        # early stopping
        if stale_epochs >= patience_epochs:
            print(f"\nEarly stopping triggered after {epoch+1} epochs")
            break

    print('\n'+banner)
    print('Training completed')
    print(banner+'\n')

    return history

# Load best model

In [5]:
def load_model(model, load_path='/Users/hela/Code/pata/best_model.pth'):
    """Restore weights from a saved checkpoint and switch to evaluation mode.

    The checkpoint is expected to be a dict with the keys ``model_state_dict``,
    ``epoch``, ``val_acc`` and ``val_loss``; the last three are only printed.

    Args:
        model: Network whose architecture matches the stored weights.
        load_path: Checkpoint file to read.

    Returns:
        The same ``model`` object, with weights loaded and in eval mode.
    """
    saved = torch.load(load_path)
    weights = saved['model_state_dict']
    model.load_state_dict(weights)
    model.eval()

    # the stored epoch is zero-based, so show it one-based
    print(f"\nBest model from: epoch {saved['epoch']+1}")
    print(f"Validation Accuracy: {saved['val_acc']:.2f}%")
    print(f"Validation Loss: {saved['val_loss']:.4f}")

    return model

# Evaluate preds
## Get preds

In [12]:
def predict_image(trained_model, path_img, class_names=['pa','ta']):
    """Classify a single image file.

    Args:
        trained_model: Model called on a ``(1, 3, 224, 224)`` normalized tensor.
        path_img: Path of the image to classify.
        class_names: Class name for each output index (not modified).

    Returns:
        Tuple ``(image, confidence, label)``: the RGB PIL image, a one-element
        tensor with the softmax probability of the predicted class (callers
        use ``.item()``), and the predicted class name.
    """
    # Imported here under an alias: at module level the name ``Image`` is
    # rebound by ``from IPython.display import ..., Image``, and that class
    # has no ``open``.
    from PIL import Image as PILImage

    # load image
    image = PILImage.open(path_img).convert('RGB')

    # transform (tensor)
    test_transform = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image_tensor = test_transform(image).unsqueeze(0)  # add the batch dimension

    # Inference must run in eval mode: in train mode dropout is active and
    # BatchNorm updates its running statistics. Switch temporarily and put the
    # caller's mode back afterwards.
    was_training = getattr(trained_model, 'training', False)
    if was_training:
        trained_model.eval()
    try:
        # get results
        with torch.no_grad():
            outputs = trained_model(image_tensor)
            probabilities = torch.softmax(outputs, dim=1)
            confidence, prediction = torch.max(probabilities, 1)
            label = class_names[prediction.item()]
    finally:
        if was_training:
            trained_model.train()

    return image, confidence, label

## Visualize preds

In [11]:
def eval_predict(trained_model, csv_file='/Users/hela/Code/pata/data_labeling.csv'):
    """Interactive labelling tool: confirm or correct model predictions.

    A random row without a label is picked, the model's prediction is shown
    together with the audio and the image, and the buttons store a label for
    that row in memory. Nothing is written to disk until 'Save progress' is
    pressed, which writes the table back to ``csv_file``.

    Uses the columns ``label``, ``audio_path``, ``image_path`` and ``name``.
    ``display`` is not imported in the visible part of this file; it is
    presumably provided by the Jupyter/IPython session.

    Args:
        trained_model: Model handed to ``predict_image``.
        csv_file: CSV that is read at start and overwritten on save.
    """
    df = pd.read_csv(csv_file)
    row_index = None
    label = None
    confidence = None

    def set_new_trial():
        """Pick a random unlabeled row, predict it and show it."""
        nonlocal row_index, label, confidence
        # get unlabeled trials only
        unlabeled_trials = df[df['label'].isna()]
        if unlabeled_trials.empty:
            # sample(n=1) raises on an empty frame; also drop the previous
            # selection so the buttons cannot relabel the last row again
            row_index = label = confidence = None
            print('No unlabeled trials left.')
            return
        # select random row and corresponding file info
        rand_row = unlabeled_trials.sample(n=1)
        row_index = rand_row.index[0]
        path_audio = rand_row['audio_path'].iloc[0]
        path_img = rand_row['image_path'].iloc[0]
        # get prediction
        image, confidence, label = predict_image(trained_model, path_img)
        # display example and results
        # (double quotes inside the f-string: reusing the outer quote
        # character is a syntax error before Python 3.12)
        print(f'File name: {rand_row["name"].iloc[0]}')
        print(f'Prediction: {label}\nConfidence: {confidence.item():.2f}')
        display(Audio(path_audio, autoplay=True))
        display(image)
        print(df['label'].value_counts())
        # NOTE(review): 8640 is hard-coded; presumably the total number of
        # trials in the CSV - confirm, or replace with len(df).
        print(df['label'].notna().sum(), '/', 8640, 'labeled trials.')

    def update_display():
        """Clear the output area and show a new trial."""
        with output:
            clear_output()
            set_new_trial()

    def record_and_advance(new_label):
        """Store ``new_label`` for the current row (if any), then move on."""
        with output:
            if row_index is not None:
                df.at[row_index, 'label'] = new_label
            update_display()

    # Widgets
    output = widgets.Output()
    button_corr = widgets.Button(description = '✅')
    button_incorr = widgets.Button(description = '❌')
    button_err = widgets.Button(description = 'Broken')
    button_next = widgets.Button(description = '->')
    button_stop = widgets.Button(description = 'Save progress')
    display(widgets.HBox([button_incorr, button_corr]))
    display(widgets.HBox([button_err, button_next]))
    display(widgets.HBox([output]))
    display(widgets.HBox([button_stop]))

    def click_corr(_):
        # prediction accepted as-is
        record_and_advance(label)
    button_corr.on_click(click_corr)

    def click_incorr(_):
        # two classes only, so "wrong" means the other one
        record_and_advance('ta' if label == 'pa' else 'pa')
    button_incorr.on_click(click_incorr)

    def click_err(_):
        # unusable recording / image
        record_and_advance('err')
    button_err.on_click(click_err)

    def click_next(_):
        # skip without labelling
        with output:
            update_display()
    button_next.on_click(click_next)

    def click_stop(_):
        with output:
            # write back to the file that was loaded, not to a fixed path
            df.to_csv(csv_file, index=False)
            print('New predictions are saved.')
    button_stop.on_click(click_stop)

    # Start
    update_display()

# Final table
## Predict table

In [2]:
def predict_table(trained_model, csv_file='/Users/hela/Code/pata/data_labeling.csv'):
    """Predict every unlabeled row of ``csv_file`` and save the result.

    Rows whose ``label`` is empty are run through ``predict_image``; the
    predicted class and its confidence are added as the columns ``predict``
    and ``confidence``. Only those rows are written, always to
    '/Users/hela/Code/pata/data_predicted.csv'.

    Args:
        trained_model: Model handed to ``predict_image``.
        csv_file: CSV with at least the columns ``label`` and ``image_path``.
    """
    df = pd.read_csv(csv_file)
    # .copy(): columns are added below, and doing that on a filtered slice of
    # `df` triggers pandas' SettingWithCopyWarning
    df_preds = df[df['label'].isna()].copy()

    preds = []
    confs = []

    # only the image path is needed, so iterate that column directly
    for path_img in tqdm(df_preds['image_path'], total=len(df_preds)):
        _, confidence, label = predict_image(trained_model, path_img)
        preds.append(label)
        confs.append(confidence.item())  # one-element tensor -> float

    df_preds['predict'] = preds
    df_preds['confidence'] = confs

    df_preds.to_csv('/Users/hela/Code/pata/data_predicted.csv', index=False)
    print("CSV file with predicted labels is saved as 'data_predicted.csv'.")

## Eval predict table

In [3]:
def eval_predict_table(csv_file='/Users/hela/Code/pata/data_predicted.csv', conf_low=0, conf_high=1, unchecked_only=True):
    """Interactive review of stored predictions within a confidence range.

    A random row with ``conf_low < confidence < conf_high`` is shown (audio,
    image, stored prediction) and the buttons store a checked label for it in
    memory. Nothing is written to disk until 'Save progress' is pressed, which
    writes the full table back to ``csv_file``.

    Uses the columns ``confidence``, ``label``, ``predict``, ``audio_path``,
    ``image_path`` and ``name``. ``display`` is not imported in the visible
    part of this file; it is presumably provided by the Jupyter session.

    Args:
        csv_file: CSV that is read at start and overwritten on save.
        conf_low: Exclusive lower confidence bound.
        conf_high: Exclusive upper confidence bound; a confidence exactly
            equal to either bound is not selected.
        unchecked_only: Unless this compares equal to ``False``, only rows
            without a label are offered.
    """
    df_original = pd.read_csv(csv_file)
    if 'label' in df_original.columns:
        # A completely empty 'label' column is parsed as float64, and writing
        # strings such as 'pa' into a float column is deprecated/rejected by
        # pandas. Object dtype accepts both NaN and strings.
        df_original['label'] = df_original['label'].astype(object)
    in_range = (df_original['confidence'] > conf_low) & (df_original['confidence'] < conf_high)
    # .copy(): this frame is modified by the buttons, and writing into a
    # filtered slice triggers pandas' SettingWithCopyWarning
    df_filtered = df_original[in_range].copy()
    row_index = None
    label = None

    def set_new_trial():
        """Pick a random candidate row and show it."""
        nonlocal row_index, label
        # check condition (same comparison as before, so that only a value
        # equal to False disables the filter)
        if unchecked_only == False:  # noqa: E712
            df_to_sample = df_filtered
        else:
            df_to_sample = df_filtered[df_filtered['label'].isna()]
        # check condition
        if len(df_to_sample) == 0:
            # drop the previous selection so the buttons cannot overwrite
            # the row that was reviewed last
            row_index = None
            label = None
            print(f'No unchecked trials with confidence from {conf_low} to {conf_high}: {len(df_to_sample)}')
            return
        # select random row and corresponding file info
        rand_row = df_to_sample.sample(n=1)
        row_index = rand_row.index[0]
        path_audio = rand_row['audio_path'].iloc[0]
        path_img = rand_row['image_path'].iloc[0]
        label = rand_row['predict'].iloc[0]
        image = Image(path_img)
        # display example and results
        # (double quotes inside the f-strings: reusing the outer quote
        # character is a syntax error before Python 3.12)
        print(f'Selected trials with confidence from {conf_low} to {conf_high}: {len(df_to_sample)}')
        print('-' * 30)
        print(f'File name: {rand_row["name"].iloc[0]}')
        print(f'Prediction: {label}\nConfidence: {rand_row["confidence"].iloc[0]:.2f}')
        display(Audio(path_audio, autoplay=True))
        display(image)

    def update_display():
        """Clear the output area and show a new trial."""
        with output:
            clear_output()
            set_new_trial()

    def record_and_advance(new_label):
        """Store ``new_label`` for the current row (if any), then move on."""
        with output:
            if row_index is not None:
                # keep the saved table and the sampling pool in sync
                df_original.at[row_index, 'label'] = new_label
                df_filtered.at[row_index, 'label'] = new_label
            update_display()

    # Widgets
    output = widgets.Output()
    button_corr = widgets.Button(description = '✅')
    button_incorr = widgets.Button(description = '❌')
    button_err = widgets.Button(description = 'Broken')
    button_next = widgets.Button(description = '->')
    button_stop = widgets.Button(description = 'Save progress')
    display(widgets.HBox([button_incorr, button_corr]))
    display(widgets.HBox([button_err, button_next]))
    display(widgets.HBox([output]))
    display(widgets.HBox([button_stop]))

    def click_corr(_):
        # stored prediction accepted as-is
        record_and_advance(label)
    button_corr.on_click(click_corr)

    def click_incorr(_):
        # two classes only, so "wrong" means the other one
        record_and_advance('ta' if label == 'pa' else 'pa')
    button_incorr.on_click(click_incorr)

    def click_err(_):
        # unusable recording / image
        record_and_advance('err')
    button_err.on_click(click_err)

    def click_next(_):
        # skip without labelling
        with output:
            update_display()
    button_next.on_click(click_next)

    def click_stop(_):
        with output:
            # write back to the file that was loaded, not to a fixed path
            df_original.to_csv(csv_file, index=False)
            print('Corrected predictions are saved.')
    button_stop.on_click(click_stop)

    # Start
    update_display()