In [None]:
# Importing Libraries

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os

import skimage
from skimage.transform import resize
from skimage.io import imread, imshow
from skimage.color import rgb2gray
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.utils import shuffle
from sklearn.metrics import roc_curve, auc

from sklearn import svm

import cv2

import torch
from torchvision import models
from torchvision import transforms
from torchvision.io import read_image, ImageReadMode
from PIL import Image

import sys

from sklearn.model_selection import train_test_split

from torchvision import transforms
from torch.utils.data import DataLoader, Dataset

from skimage.transform import resize
from sklearn import svm

import torch
import torchvision
from torchvision import transforms
from PIL import Image
import time

In [None]:
import re

# ImageNet class names treated as birds (compared case-insensitively).
# Originally produced by asking an LLM to pick the bird classes out of the
# AlexNet classes file. 'white stork' and 'black stork' are ImageNet bird
# classes that list missed, so they were added.
# NOTE(review): in the standard ImageNet labels 'crane' names both a bird and a
# construction-crane class, so a construction crane is also counted as 'Bird';
# disambiguating would require the class index rather than the name.
_BIRD_SPECIES = frozenset(name.lower() for name in (
    'cock', 'hen', 'ostrich', 'brambling', 'goldfinch', 'house finch',
    'junco', 'indigo bunting', 'robin', 'bulbul', 'jay', 'magpie',
    'chickadee', 'water ouzel', 'kite', 'bald eagle', 'vulture',
    'great grey owl', 'black grouse', 'ptarmigan', 'ruffed grouse',
    'prairie chicken', 'peacock', 'quail', 'partridge', 'African grey',
    'macaw', 'sulphur-crested cockatoo', 'lorikeet', 'coucal',
    'bee eater', 'hornbill', 'hummingbird', 'jacamar', 'toucan',
    'drake', 'red-breasted merganser', 'goose', 'black swan',
    'white stork', 'black stork',
    'spoonbill', 'flamingo', 'little blue heron', 'American egret',
    'bittern', 'crane', 'limpkin', 'European gallinule', 'American coot',
    'bustard', 'ruddy turnstone', 'red-backed sandpiper', 'redshank',
    'dowitcher', 'oystercatcher', 'pelican', 'king penguin', 'albatross',
))

# ImageNet class names treated as squirrels (marmots belong to the squirrel family).
_SQUIRREL_SPECIES = frozenset(('squirrel', 'fox squirrel', 'marmot'))


def _label_synonyms(label):
    '''
    Split one AlexNet class string into its set of lower-cased names.

    NOTE(review): the exact layout of the classes file is not visible here.
    This tolerates both plain lines ("robin") and dict-literal lines whose
    numeric key has already been stripped
    (" 'robin, American robin, Turdus migratorius',"): the text is split on
    commas, and quotes, braces and whitespace are trimmed from each name.
    '''
    names = (part.strip(" \t\r\n'\"{}") for part in label.split(','))
    return {name.lower() for name in names if name}


def _coarse_label(pred):
    '''
    Map one AlexNet class string to 'Bird', 'Squirrel' or 'Other'.

    Whole names are compared instead of substrings. With substring matching,
    'cock' matched 'cocker spaniel', 'cockroach' and 'cocktail shaker', 'hen'
    matched 'hen-of-the-woods', and 'squirrel' matched 'squirrel monkey'.
    '''
    names = _label_synonyms(pred)
    if names & _BIRD_SPECIES:
        return 'Bird'
    if names & _SQUIRREL_SPECIES:
        return 'Squirrel'
    return 'Other'


def _report_results(y_true, y_pred):
    '''
    Print accuracy and a classification report, then show a confusion-matrix heatmap.

    scikit-learn metrics take (y_true, y_pred) in that order. In the confusion
    matrix, rows are true labels and columns are predicted labels.
    '''
    # Sorted union of labels, the same ordering scikit-learn uses by default.
    # It is passed explicitly so the heatmap tick labels line up with the matrix.
    labels = sorted(set(y_true) | set(y_pred))
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    print("Accuracy:", round(accuracy_score(y_true, y_pred), 2))
    print(classification_report(y_true, y_pred))

    # Draw on a fresh figure so earlier plots in the notebook are not overdrawn.
    _, ax = plt.subplots()
    sns.heatmap(cm, annot=True, fmt='d', cmap="Blues",
                xticklabels=labels, yticklabels=labels, ax=ax)
    ax.set(xlabel='Predicted label', ylabel='True label')
    plt.show()


def alex_net_model(X_test, y_test=None):

    '''

    Classify the saved test images with a pretrained AlexNet and evaluate them.

    For each i in range(len(X_test)), the image /kaggle/working/Test_{i}.png
    is run through an ImageNet-pretrained AlexNet. Its top-1 class is mapped
    to 'Bird', 'Squirrel' or 'Other' and compared with the ground-truth labels.
    The function prints the accuracy and a classification report, and shows a
    confusion-matrix heatmap.

    Args:

    X_test: sequence of test images. Only its length is used, to decide how
        many Test_{i}.png files to read.
    y_test: ground-truth labels, one per image. They are compared directly
        with the strings 'Bird'/'Squirrel'/'Other' (presumably the notebook's
        labels use these strings; TODO confirm). If omitted, the
        notebook-global ``y_test`` is used, so existing
        ``alex_net_model(X_test)`` calls keep working.

    Returns:

    list of str: the top-1 AlexNet class text for each image, with the
        numeric "N:" key removed.

    Raises:

    NameError: if y_test is not passed and no global y_test exists.

    '''

    if y_test is None:
        # Backward compatibility: the original version read the global y_test.
        try:
            y_test = globals()['y_test']
        except KeyError:
            raise NameError(
                "alex_net_model: y_test was not passed and no global y_test is defined"
            ) from None

    # initializing pretrained alexnet model and putting it in evaluation mode
    model = models.alexnet(pretrained=True)
    model.eval()

    # transforms applied to every image before it enters the network
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])

    # AlexNet class strings, one per line, indexed by the model's output position
    with open('/kaggle/input/document4/Doc4.txt', encoding='utf-8') as f:
        classes = [line.strip() for line in f]

    preds = []
    # Inference only: no_grad skips building the autograd graph for every image.
    with torch.no_grad():
        for i in range(len(X_test)):
            with Image.open(f'/kaggle/working/Test_{i}.png') as img:
                # Kept from the original: force 256x256 before Resize/CenterCrop.
                img_resized = img.convert('RGB').resize((256, 256))
            batch_t = preprocess(img_resized).unsqueeze(0)
            # Only the top-1 class is needed, so argmax replaces a full sort.
            top_idx = model(batch_t).argmax(dim=1).item()
            preds.append(classes[top_idx])

    # stripping the numeric "N:" key from each class string
    preds = [re.sub(r'\d+:', '', text) for text in preds]

    # reclassifying predictions as either Bird, Squirrel, or Other, then evaluating
    pred_classes = [_coarse_label(pred) for pred in preds]
    _report_results(y_test, pred_classes)

    # returning purely text predictions
    return preds