In [1]:
import import_ipynb
from pytorch_model_fns import SpectrogramCNN, load_model
import torch
import pandas as pd
from PIL import Image
from torchvision import transforms
from IPython.display import Audio, clear_output
from ipywidgets import widgets

# Load model

In [2]:
model = SpectrogramCNN(dropout_rate=0.5)
trained_model = load_model(model, load_path='/Users/hela/Code/pata/best_model.pth')


Spectrogram (224x224) CNN model
Architecture:
- 4 convolutional blocks (BatchNorm, ReLU, MaxPool)
- global average pooling
- 3 fully connected layers with dropout


Best model from: epoch 19
Validation Accuracy: 95.31%
Validation Loss: 0.1281


# Get model prediction
### manual image choice

In [3]:
def predict_image_manual(trained_model, csv_file='/Users/hela/Code/pata/data_unlabeled.csv', class_names=['pa','ta']):
    # load image
    df = pd.read_csv(csv_file)
    image = Image.open(df.iloc[91]['image_path']).convert('RGB')

    # transform (tensor)
    test_transform = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image_tensor = test_transform(image).unsqueeze(0)

    # get results
    with torch.no_grad():
        outputs = trained_model(image_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        confidence, prediction = torch.max(probabilities, 1)
        label = class_names[prediction.item()]
    
    #return prediction.item(), confidence.item(), probabilities.numpy()[0]
    return outputs, probabilities, confidence, prediction, label

In [4]:
outputs, probabilities, confidence, prediction, label = predict_image_manual(trained_model)
outputs, probabilities, confidence, prediction, label

FileNotFoundError: [Errno 2] No such file or directory: '/Users/hela/Code/pata/data_unlabeled.csv'

## random image choice

In [5]:
def predict_image(trained_model, path_img, class_names=['pa','ta']):
    # load image
    image = Image.open(path_img).convert('RGB')

    # transform (tensor)
    test_transform = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image_tensor = test_transform(image).unsqueeze(0)

    # get results
    with torch.no_grad():
        outputs = trained_model(image_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        confidence, prediction = torch.max(probabilities, 1)
        label = class_names[prediction.item()]
    
    #return prediction.item(), confidence.item(), probabilities.numpy()[0]
    return image, confidence, label

In [6]:
data_df = pd.read_csv('/Users/hela/Code/pata/data_unlabeled.csv')

unlabeled_trials = data_df[data_df['label'].isna()]
rand_row = unlabeled_trials.sample(n=1)
path_img = rand_row['image_path'].iloc[0]
image, confidence, label = predict_image(trained_model, path_img)
path_audio = rand_row['audio_path'].iloc[0]
print(f'File name: {rand_row['name'].iloc[0]}')
print(f'Label: {label}\nConfidence: {confidence.item():.2f}')
display(Audio(path_audio, autoplay=True))
display(image)

FileNotFoundError: [Errno 2] No such file or directory: '/Users/hela/Code/pata/data_unlabeled.csv'

# Widget
### Widget to look at preds

In [7]:
def visual_predict(trained_model, csv_file='/Users/hela/Code/pata/data_labeling.csv'):
    df = pd.read_csv(csv_file)
    row_index = None
        
    def set_new_trial():
        nonlocal row_index
        # get unlabeled trials only
        unlabeled_trials = df[df['label'].isna()]
        # select random row and corresponding file info
        rand_row = unlabeled_trials.sample(n=1)
        row_index = rand_row.index[0]
        path_audio = rand_row['audio_path'].iloc[0]
        path_img = rand_row['image_path'].iloc[0]
        # get prediction
        image, confidence, label = predict_image(trained_model, path_img)
        # display example and results
        print(f'File name: {rand_row['name'].iloc[0]}')
        print(f'Label: {label}\nConfidence: {confidence.item():.2f}')
        display(Audio(path_audio, autoplay=True))
        display(image)
        
    def update_display():
        with output: 
            clear_output()
            set_new_trial()
                
    # Widgets
    output = widgets.Output()
    button_next = widgets.Button(description = '->')
    display(widgets.HBox([button_next]))
    display(widgets.HBox([output]))
    
    def click_next(_):
        with output:
            update_display()
    button_next.on_click(click_next)
    
    # Start
    update_display()

In [8]:
visual_predict(trained_model)

HBox(children=(Button(description='->', style=ButtonStyle()),))

HBox(children=(Output(),))

### Widget to eval preds

In [11]:
def eval_predict(trained_model, csv_file='/Users/hela/Code/pata/data_labeling.csv'):
    df = pd.read_csv(csv_file)
    row_index = None
    label = None
    confidence = None
        
    def set_new_trial():
        nonlocal row_index, label, confidence
        # get unlabeled trials only
        unlabeled_trials = df[df['label'].isna()]
        # select random row and corresponding file info
        rand_row = unlabeled_trials.sample(n=1)
        row_index = rand_row.index[0]
        path_audio = rand_row['audio_path'].iloc[0]
        path_img = rand_row['image_path'].iloc[0]
        # get prediction
        image, confidence, label = predict_image(trained_model, path_img)
        # display example and results
        print(f'File name: {rand_row['name'].iloc[0]}')
        print(f'Label: {label}\nConfidence: {confidence.item():.2f}')
        display(Audio(path_audio, autoplay=True))
        display(image)
        print(df['label'].value_counts())
        print(df['label'].notna().sum(), '/', 8640, 'labeled trials.')
        
    def update_display():
        with output: 
            clear_output()
            set_new_trial()
                
    # Widgets
    output = widgets.Output()
    button_corr = widgets.Button(description = '✅')
    button_incorr = widgets.Button(description = '❌')
    button_err = widgets.Button(description = 'Broken')
    button_next = widgets.Button(description = '->')
    button_stop = widgets.Button(description = 'Save progress')
    display(widgets.HBox([button_incorr, button_corr]))
    display(widgets.HBox([button_err, button_next]))
    display(widgets.HBox([output]))
    display(widgets.HBox([button_stop]))
    
    def click_corr(_):
        with output:
            df.at[row_index, 'label'] = label
            update_display()
    button_corr.on_click(click_corr)

    def click_incorr(_):
        with output:
            opposite_label = 'ta' if label == 'pa' else 'pa'
            df.at[row_index, 'label'] = opposite_label
            update_display()
    button_incorr.on_click(click_incorr)
    
    def click_err(_):
        with output:
            df.at[row_index, 'label'] = 'err'
            update_display()
    button_err.on_click(click_err)

    def click_next(_):
        with output:
            update_display()
    button_next.on_click(click_next)

    def click_stop(_):
        with output:
            df.to_csv('/Users/hela/Code/pata/data_labeling.csv', index=False)
            print('New predictions are saved.')
    button_stop.on_click(click_stop)

    # Start
    update_display()

In [12]:
eval_predict(trained_model)

HBox(children=(Button(description='❌', style=ButtonStyle()), Button(description='✅', style=ButtonStyle())))

HBox(children=(Button(description='Broken', style=ButtonStyle()), Button(description='->', style=ButtonStyle()…

HBox(children=(Output(),))

HBox(children=(Button(description='Save progress', style=ButtonStyle()),))

In [13]:
# CHECK NEW LABELS
check_df = pd.read_csv('/Users/hela/Code/pata/data_labeling.csv')
check_df['label'].value_counts()

label
ta     895
pa     805
err     15
Name: count, dtype: int64