In [None]:
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from sklearn.manifold import MDS
from PIL import Image
import os
import pingouin as pg
import seaborn as sns
from scipy.spatial.distance import pdist, squareform
from scipy.linalg import orthogonal_procrustes
from nltk.corpus import wordnet as wn
import json
import pickle
import gc

import ecoset
import categorization as cat
import train
import utils

# Fix the global NumPy and TensorFlow RNG seeds so stochastic steps are reproducible
randomSeed = 2023
np.random.seed(randomSeed)
tf.random.set_seed(randomSeed)

In [None]:
# Load an example model: ecoset-trained AlexNet from one of the ten training seeds
# (01-10), using the epoch-89 checkpoint.
exampleTrainingSeed = 1
weightPath = (
    "./models/AlexNet/ecoset_training_seeds_01_to_10/"
    f"training_seed_{exampleTrainingSeed:02d}/model.ckpt_epoch89"
)
model = ecoset.make_alex_net_v2(weights_path=weightPath)
model.summary()

Weights from ./models/AlexNet/ecoset_training_seeds_01_to_10/training_seed_01/model.ckpt_epoch89 loaded successfully.
Model: "model_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 conv1 (Conv2D)              (None, 54, 54, 64)        23296     
                                                                 
 pool1 (MaxPooling2D)        (None, 26, 26, 64)        0         
                                                                 
 conv2 (Conv2D)              (None, 26, 26, 192)       307392    
                                                                 
 pool2 (MaxPooling2D)        (None, 12, 12, 192)       0         
                                                                 
 conv3 (Conv2D)              (None, 12, 12, 384)       663936    
       

In [None]:
# Output nodes (class indices in ecoset's fc8 layer) for the categories of interest
ecosetCatNodes = {"dog": 8, "bird": 25, "car": 2, "bus": 27}

# fc8 kernel with singleton axes squeezed away; columns are indexed by class node.
# NOTE(review): assumes the squeezed kernel is 2-D (inputUnits, classNodes) — confirm.
fc8 = np.squeeze(model.get_layer("fc8").get_weights()[0])

# Median over *all* fc8 weights (only used as the bus threshold below)
fc8Med = np.median(fc8)

# Strict lower bound a unit's weight onto a category node must exceed for the unit
# to be selected for that category.
# NOTE(review): bus uses the global fc8 median while dog/bird/car use 0. This
# asymmetry is preserved from the original analysis — confirm it is intentional.
ecosetCatThresholds = {"dog": 0, "bird": 0, "car": 0, "bus": fc8Med}

# Incoming fc8 weights onto each category node, and the indices of the input
# units whose weight exceeds that category's threshold
ecosetCatWeights = {
    category: fc8[:, node] for category, node in ecosetCatNodes.items()
}
ecosetCatIdxs = {
    category: np.where(ecosetCatWeights[category] > ecosetCatThresholds[category])[0]
    for category in ecosetCatNodes
}

# Per-category names used by later cells
dogWeights, ecosetDogIdxs = ecosetCatWeights["dog"], ecosetCatIdxs["dog"]
birdWeights, ecosetBirdIdxs = ecosetCatWeights["bird"], ecosetCatIdxs["bird"]
carWeights, ecosetCarIdxs = ecosetCatWeights["car"], ecosetCatIdxs["car"]
busWeights, ecosetBusIdxs = ecosetCatWeights["bus"], ecosetCatIdxs["bus"]

In [None]:
# Selected fc8 unit indices for every category and level. Each value is a list of
# index arrays, one per contributing basic-level category.
ecosetIdxs = {
    "animal": [ecosetDogIdxs, ecosetBirdIdxs],
    "vehicle": [ecosetCarIdxs, ecosetBusIdxs],
    "dog": [ecosetDogIdxs],
    "bird": [ecosetBirdIdxs],
    "car": [ecosetCarIdxs],
    "bus": [ecosetBusIdxs],
    # CUB-200 bird species have no dedicated ecoset node: all reuse the bird units
    **{
        cubClass: [ecosetBirdIdxs]
        for cubClass in (
            "CUB_002.Laysan_Albatross",
            "CUB_048.European_Goldfinch",
            "CUB_094.White_breasted_Nuthatch",
            "CUB_136.Barn_Swallow",
            "CUB_199.Winter_Wren",
        )
    },
}

# Persist for later analyses
with open("ecosetIdxs.pkl", mode="wb") as f:
    pickle.dump(ecosetIdxs, f)

In [None]:
def find_nested_synsets(targetSynset, synsets):
    """
    Find the synsets in `synsets` that are nested under `targetSynset`.

    A synset counts as nested when `targetSynset` is among its lowest common
    hypernyms with `targetSynset`.

    Returns a pair (indices into `synsets`, the matching synsets), both lists in
    input order.
    """
    matches = [
        (position, candidate)
        for position, candidate in enumerate(synsets)
        if targetSynset in candidate.lowest_common_hypernyms(targetSynset)
    ]
    nestedIdxs = [position for position, _ in matches]
    nestedSynsets = [candidate for _, candidate in matches]
    return nestedIdxs, nestedSynsets

In [None]:
# Load the ImageNet class index.
# NOTE(review): assumes the format {"<classIdx>": ["<wnid>", "<label>"], ...} — confirm.
with open("./imagenet_class_index.json", "r") as f:
    imagenetCats = json.load(f)

# Positions in imagenetSynsets serve as ImageNet class indices downstream, so order
# by the integer key rather than trusting the JSON's key order, and require the
# keys to be exactly 0..N-1.
imagenetClassKeys = sorted(imagenetCats, key=int)
assert [int(key) for key in imagenetClassKeys] == list(range(len(imagenetClassKeys))), (
    "imagenet_class_index.json keys must be the contiguous integers 0..N-1"
)

# WordNet noun synset for each class; the wnid is "n" followed by the synset offset
imagenetSynsets = [
    wn.synset_from_pos_and_offset("n", int(imagenetCats[key][0][1:]))
    for key in imagenetClassKeys
]

# Superordinate level
imagenetAnimalIdxs, animalSynsets = find_nested_synsets(
    wn.synset("animal.n.01"), imagenetSynsets
)
imagenetVehicleIdxs, vehicleSynsets = find_nested_synsets(
    wn.synset("vehicle.n.01"), imagenetSynsets
)

# Basic level
imagenetDogIdxs, dogSynsets = find_nested_synsets(
    wn.synset("dog.n.01"), imagenetSynsets
)
imagenetBirdIdxs, birdSynsets = find_nested_synsets(
    wn.synset("bird.n.01"), imagenetSynsets
)
imagenetCarIdxs, carSynsets = find_nested_synsets(
    wn.synset("car.n.01"), imagenetSynsets
)
imagenetBusIdxs, busSynsets = find_nested_synsets(
    wn.synset("bus.n.01"), imagenetSynsets
)

# Subordinate level: no separate indices are computed here; the CUB bird species
# reuse the basic-level bird indices.

In [None]:
# ImageNet class indices for every category and level. Each value is a list holding
# one index list per contributing category.
imagenetIdxs = {
    "animal": [imagenetAnimalIdxs],
    "vehicle": [imagenetVehicleIdxs],
    "bird": [imagenetBirdIdxs],
    "dog": [imagenetDogIdxs],
    "car": [imagenetCarIdxs],
    "bus": [imagenetBusIdxs],
    # CUB-200 bird species have no dedicated ImageNet classes: all reuse bird indices
    **{
        cubClass: [imagenetBirdIdxs]
        for cubClass in (
            "CUB_002.Laysan_Albatross",
            "CUB_048.European_Goldfinch",
            "CUB_094.White_breasted_Nuthatch",
            "CUB_136.Barn_Swallow",
            "CUB_199.Winter_Wren",
        )
    },
}

# Persist for later analyses
with open("imagenetIdxs.pkl", mode="wb") as f:
    pickle.dump(imagenetIdxs, f)