In [3]:
%load_ext autoreload
%autoreload 2
import os
import pandas as pd
import torch
import sys
sys.path.append('/home/siddsuresh97/Projects/vision_robustness_using_semantic_norms')
from src.utils import cogsci_helper_functions, vss_helper_functions

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# zsd

In [4]:
# load leuven full
leuven_full = pd.read_csv('../../../data/leuven_full.csv')
leuven_bce = pd.read_csv('../../../data/leuven_bce.csv')

leuven_bce = leuven_bce.set_index('Unnamed: 0')

# take out the columns 'appears_in_comics' and 'breeds_rapidly' from leuven_full
leuven_full = leuven_full.drop(columns=['appears_in_comics', 'breeds_rapidly'])
leuven_full = leuven_full.set_index('Unnamed: 0')
leuven_full = leuven_full.applymap(lambda x: 1 if x >= 1 else 0)

In [160]:
loss = 'bce'
config = 'scratch'
epoch = 299
img_dir = 'zsd_images'
model_weights_root_dir = '/media/external/siddsuresh97/model_weights/vss_exp_after_cogsci'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_bce = cogsci_helper_functions.load_model(f"alexnet_bce_scratch", epoch, model_weights_root_dir, device)
model_ce = cogsci_helper_functions.load_model(f"alexnet_ce_scratch", epoch, model_weights_root_dir, device)
model_hybrid = cogsci_helper_functions.load_model(f"alexnet_hybrid_scratch", epoch, model_weights_root_dir, device)


/media/external/siddsuresh97/model_weights/vss_exp_after_cogsci/alexnet/bce/alexnet_bce_scratch_states.pkl
Loading checkpoint from /media/external/siddsuresh97/model_weights/vss_exp_after_cogsci/alexnet/bce/alexnet_bce_scratch_states.pkl
/media/external/siddsuresh97/model_weights/vss_exp_after_cogsci/alexnet/ce/alexnet_ce_scratch_states.pkl
Loading checkpoint from /media/external/siddsuresh97/model_weights/vss_exp_after_cogsci/alexnet/ce/alexnet_ce_scratch_states.pkl
/media/external/siddsuresh97/model_weights/vss_exp_after_cogsci/alexnet/hybrid/alexnet_hybrid_scratch_states.pkl
Loading checkpoint from /media/external/siddsuresh97/model_weights/vss_exp_after_cogsci/alexnet/hybrid/alexnet_hybrid_scratch_states.pkl


In [166]:
imageloader = vss_helper_functions.get_image_loader(img_dir)

bce_final_layer_outputs = []
ce_final_layer_outputs = []
hybrid_final_layer_outputs = []
targets = []
with torch.no_grad():
    for batch_idx, (data, target) in enumerate(imageloader):
        data =  data.to(device)
        bce_final_layer_outputs.append(torch.sigmoid(model_bce(data)).cpu())
        ce_final_layer_outputs.append(torch.sigmoid(model_ce(data)).cpu())
        hybrid_final_layer_outputs.append(torch.sigmoid(model_hybrid(data)).cpu())
        targets.append(target.cpu())
# reduce the list of tensors to a single tensor

bce_final_layer_outputs = torch.cat(bce_final_layer_outputs)
ce_final_layer_outputs = torch.cat(ce_final_layer_outputs)
hybrid_final_layer_outputs = torch.cat(hybrid_final_layer_outputs)

targets = torch.cat(targets)


def get_top_k_similarities(final_layer_outputs, leuven_full, type, k=10):
    # find cosine sim of final_layer_outputs[0] with all columns of leuven_full.T
    # pick the top 5 most similar columns and pick the corresponding column names
    if type == 'bce':
        top_5_similarities = []
        for i in range(len(final_layer_outputs)):
            similarities = []
            for j in range(leuven_full.T.shape[1]):
                similarities.append(torch.nn.CosineSimilarity(dim=1, eps=1e-6)(final_layer_outputs[i], torch.Tensor(leuven_full.T.iloc[:, j].values.reshape(1, -1))))    
            top_5_similarities.append(torch.Tensor(similarities).topk(k, largest=True, sorted=True))
            # get the names of the top 5 columns
        top_5_column_names = []
        for i in range(len(top_5_similarities)):
            top_5_column_names.append(leuven_full.T.columns[top_5_similarities[i][1].numpy()])
        top_5_column_names = pd.DataFrame(top_5_column_names)
    else:
        top_5_similarities = []
        for i in range(len(final_layer_outputs)):
            top_5_similarities.append(final_layer_outputs[i].topk(k, largest=True, sorted=True))
        top_5_column_names = []
        for i in range(len(top_5_similarities)):
            top_5_column_names.append(leuven_full.T.columns[top_5_similarities[i][1].numpy()])
        top_5_column_names = pd.DataFrame(top_5_column_names)

    top_5_column_names['target'] = targets
    # find the names of imageloader.dataset.targets and not the indices
    top_5_column_names['target'] = top_5_column_names['target'].apply(lambda x: imageloader.dataset.classes[x])

    return top_5_column_names

bce_top_5_column_names = get_top_k_similarities(bce_final_layer_outputs, leuven_full, k = 10, type='bce')
ce_top_5_column_names = get_top_k_similarities(ce_final_layer_outputs, leuven_full, k = 10, type='ce')
hybrid_top_5_column_names = get_top_k_similarities(hybrid_final_layer_outputs, leuven_full, k = 10, type='hybrid')

In [167]:
bce_top_5_column_names

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,target
0,lizard,iguana,salamander,chameleon,gecko,monitor_lizard,snake,python,blindworm,boa,caiman
1,beetle,maybug,cricket,grashopper,cockroach,centipede,ladybug,fruit_fly,fly,caterpillar,ladybug
2,whale,orca,dolphin,shark,sperm_whale,swordfish,piranha,squid,carp,stickleback,orca
3,car,bus,truck.1,truck,taxi,delivery_van,jeep,motorbike.1,airplane,tractor,taxi
4,towel,place_mat,apron,plate,mug,bowl,colander,glass,fork,sieve,trout
5,sheep,deer,pig,llama,cow,horse,donkey,kangaroo,zebra,giraffe,woodpecker


In [168]:
ce_top_5_column_names

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,target
0,squid,bee_,parrot,duck,bison,anchovy,crow,llama,cod,elephant,caiman
1,bison,rhinoceros,donkey,zebra,goldfish,sheep,blackbird,kangaroo,cod,stickleback,ladybug
2,whale,bat,bee_,eel,vulture,flatfish,elephant,squid,squirrel,heron,orca
3,cow,cat,sardine,ray,deer,plaice,cuckoo,monkey,falcon,dolphin,taxi
4,turkey,crow,flatfish,parrot,elephant,herring,horsefly,trout,penguin,beaver,trout
5,parakeet,tiger,hippopotamus,pig,stork,trout,hamster,kip,fox,seagull,woodpecker


In [169]:
hybrid_top_5_column_names

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,target
0,bison,parrot,donkey,crow,cod,orca,heron,squid,beaver,rhinoceros,caiman
1,donkey,rhinoceros,penguin,bison,dove,sole,pike,sheep,zebra,kangaroo,ladybug
2,whale,hedgehog,orca,elephant,bat,beaver,eel,hamster,flatfish,heron,orca
3,cow,sardine,cat,ray,deer,giraffe,sperm_whale,cuckoo,monkey,plaice,taxi
4,duck,parrot,squid,crow,rabbit,flatfish,pike,cod,sheep,beaver,trout
5,parakeet,swordfish,pig,rabbit,stork,seagull,trout,peacock,kip,fox,woodpecker


In [136]:
a = list(set(leuven_full.index) - set(leuven_bce.index))
for i in a:
    print(i)

orca
bazooka
butterfly
carriage
rifle
tambourine
filling-knife
apron
vacuum_cleaner
grenade
chickadee
hat
rope
teaspoon
knuckle_dusters
horsefly
scarf
sparrow
fridge
cabary
flee
pelican
cobra
microwave
bra
flatfish
go-cart
skateboard
robin
bow
carp
rocket
sole
rhinoceros
dragonfly
banjo
fork
swordfish
stork
goldfish
louse
eel
viger
peacock
herring
shirt
plaice
pike
pistol
dove
kip
ostrich
salmon
Zeppelin
can_opener
crowbar.1
trout
colander
whisk
python
whip
spanner
swallow
pyjamas
belt
contrabass
bee_
mixer
socks
jumper
harpsichord
taxi
pheasant
cuckoo
shorts
motorbike
squid
oven
file
cricket
sweater
panpipe
cod
cap
flute
leech
trombone
tram
magpie
grashopper
boots
ladybug
scales
drum_set
harp
subway_train
machinegun
piranha
hippopotamus
scissor
duck
fly
owl
fox
seagull
parrot
bass_guitar
monitor_lizard
accordion
club
tank
blackbird
plane
nail
oilcan
trumpet
parakeet
cart
eagle
triangle
place_mat
electric_kettle
cello
tie
pickaxe
woodpecker
piccolo
giraffe
polar_bear
caiman
hat.1
boa
v

# adversarial attacks