In [46]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import joblib
from PIL import Image
from torchvision import models, transforms

device = "mps" if torch.backends.mps.is_available() else "cpu"
print("Device:", device)
 

Device: mps


In [47]:
audio_model = joblib.load("audio_model.joblib")
audio_scaler = joblib.load("audio_scaler.joblib")
audio_feature_columns = joblib.load("audio_feature_columns.joblib")
audio_classes = joblib.load("audio_classes.joblib")

print("Audio classes example:", audio_classes[:5])
print("Num audio features:", len(audio_feature_columns))


Audio classes example: ['aguimp', 'alpina', 'aluco', 'apiaster', 'apivorus']
Num audio features: 169


In [48]:
ckpt = torch.load("image_model.pth", map_location=device)
image_classes = ckpt["classes"]

image_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
image_model.fc = nn.Linear(image_model.fc.in_features, len(image_classes))
image_model.load_state_dict(ckpt["model_state"])
image_model = image_model.to(device)
image_model.eval()

print("Image classes example:", image_classes[:5])
print("Num image classes:", len(image_classes))


Image classes example: ['001.Black_footed_Albatross', '002.Laysan_Albatross', '003.Sooty_Albatross', '004.Groove_billed_Ani', '005.Crested_Auklet']
Num image classes: 200


In [49]:
def topk_from_probs(probs, labels, k=5):
    idx = np.argsort(probs)[::-1][:k]
    return [(labels[i], float(probs[i])) for i in idx]

def pretty_image_label(label):
    # "001.Black_footed_Albatross" -> "Black footed Albatross"
    if "." in label:
        label = label.split(".", 1)[1]
    return label.replace("_", " ")

# Image transforms (ImageNet normalization for ResNet)
img_tfms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

def predict_image_topk(img_path, k=5):
    x = img_tfms(Image.open(img_path).convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        probs = F.softmax(image_model(x), dim=1).squeeze(0).cpu().numpy()
    return probs, topk_from_probs(probs, image_classes, k=k)

def predict_audio_topk(X_audio_row_df, k=5):
    # Ensure correct column order
    X_audio_row_df = X_audio_row_df[audio_feature_columns]
    probs = audio_model.predict_proba(audio_scaler.transform(X_audio_row_df)).squeeze(0)
    return probs, topk_from_probs(probs, audio_classes, k=k)


In [50]:
# Optional mapping: audio label -> keyword or canonical common name
AUDIO_TO_CANON = {
    "trochilus": "hummingbird",
    "argentatus": "gull",
    "arvensis": "lark",
}

def agreement_check(image_topk_labels, audio_top1_label):
    aud = AUDIO_TO_CANON.get(audio_top1_label, None)
    if aud is None:
        return ("no_mapping", f"No mapping for audio label '{audio_top1_label}' yet.")

    aud = aud.lower()
    img_topk_pretty = [pretty_image_label(l).lower() for l, _ in image_topk_labels]

    if aud in img_topk_pretty[0]:
        return ("strong", f"‚úÖ Strong agreement: '{aud.title()}' matches IMAGE top-1")
    if any(aud in lab for lab in img_topk_pretty[:5]):
        return ("weak", f"üü° Weak agreement: '{aud.title()}' appears in IMAGE top-5")
    return ("none", f"‚ùå No agreement: '{aud.title()}' not found in IMAGE top-5")


In [51]:
def multimodal_assistant(img_path, X_audio_row_df, k=5):
    _, img_topk = predict_image_topk(img_path, k=k)
    _, aud_topk = predict_audio_topk(X_audio_row_df, k=k)

    print("=== Multimodal Bird ID (learning mode) ===\n")

    print("IMAGE top predictions:")
    for lab, p in img_topk:
        print(f"  ‚Ä¢ {pretty_image_label(lab):35s}  {p:.3f}")

    print("\nAUDIO top predictions:")
    for lab, p in aud_topk:
        print(f"  ‚Ä¢ {lab:35s}  {p:.3f}")

    print("\nAgreement check:")
    status, msg = agreement_check(img_topk, aud_topk[0][0])
    print(" ", msg)
    print("\nRecommendation:")
    if status == "strong":
      print(f"  ‚úÖ Recommend: {pretty_image_label(img_topk[0][0])} (strong multimodal agreement)")
    elif status == "weak":
      print(f"  üü° Leaning: {pretty_image_label(img_topk[0][0])} (audio agrees in top-5)")
    else:
      print("  ‚ö†Ô∏è No agreement ‚Äî consider another photo or audio clip, or trust the stronger model confidence.")

    return {"image_topk": img_topk, "audio_topk": aud_topk, "agreement": status}


In [52]:
import os, random

train_root = "/Users/saramcghee/.cache/kagglehub/datasets/veeralakrishna/200-bird-species-with-11788-images/versions/1/CUB_200_2011/CUB_200_2011/splits/train"

cls = random.choice(os.listdir(train_root))
img = random.choice(os.listdir(os.path.join(train_root, cls)))
img_path = os.path.join(train_root, cls, img)

print("Using image:", img_path)


Using image: /Users/saramcghee/.cache/kagglehub/datasets/veeralakrishna/200-bird-species-with-11788-images/versions/1/CUB_200_2011/CUB_200_2011/splits/train/066.Western_Gull/Western_Gull_0065_55728.jpg


In [53]:
import pandas as pd

audio_path = "/Users/saramcghee/.cache/kagglehub/datasets/fleanend/birds-songs-numeric-dataset/versions/3"
audio_test_df = pd.read_csv(os.path.join(audio_path, "test.csv"))

row = audio_test_df.sample(1, random_state=2)
X_audio_row_df = row.drop(columns=["id", "genus", "species"])

multimodal_assistant(img_path, X_audio_row_df, k=5)


=== Multimodal Bird ID (learning mode) ===

IMAGE top predictions:
  ‚Ä¢ Ring billed Gull                     0.160
  ‚Ä¢ Anna Hummingbird                     0.152
  ‚Ä¢ Black Tern                           0.139
  ‚Ä¢ Rufous Hummingbird                   0.087
  ‚Ä¢ Northern Flicker                     0.069

AUDIO top predictions:
  ‚Ä¢ trochilus                            0.985
  ‚Ä¢ argentatus                           0.002
  ‚Ä¢ europaea                             0.001
  ‚Ä¢ arvensis                             0.001
  ‚Ä¢ philomelos                           0.001

Agreement check:
  üü° Weak agreement: 'Hummingbird' appears in IMAGE top-5

Recommendation:
  üü° Leaning: Ring billed Gull (audio agrees in top-5)


{'image_topk': [('064.Ring_billed_Gull', 0.15985079109668732),
  ('067.Anna_Hummingbird', 0.15161050856113434),
  ('142.Black_Tern', 0.13938046991825104),
  ('069.Rufous_Hummingbird', 0.08712928742170334),
  ('036.Northern_Flicker', 0.068853460252285)],
 'audio_topk': [('trochilus', 0.9850702318825041),
  ('argentatus', 0.002067355445219387),
  ('europaea', 0.0014971913708846768),
  ('arvensis', 0.0014679615772921779),
  ('philomelos', 0.001460973796353702)],
 'agreement': 'weak'}