In [None]:
from fastai.vision.all import *
import torch

device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(device)

# NOTE: `load_learner` unpickles the file, so only load a model.pkl you trust.
# It also maps the learner onto the CPU by default (`cpu=True`), so `device`
# above is informational only; it does not move the model anywhere.
learn = load_learner('model.pkl')

# Display the architecture to decide which layer's output to use as the embedding
learn.model

In [2]:
class SaveOutput:
    """Forward-hook callable that records the output of the module it is attached to.

    Attach with ``module.register_forward_hook(instance)``. PyTorch then calls the
    instance as ``hook(module, input, output)`` after each forward pass of that module.
    """

    def __init__(self):
        # Captured outputs, in the order the hook was invoked.
        self.outputs = []

    def __call__(self, module, module_input, module_output):
        # Only the output is kept; module and input are part of the hook signature.
        self.outputs += [module_output]

    def clear(self):
        """Forget all recorded outputs by rebinding to a fresh, empty list."""
        self.outputs = list()

# Layer whose output is treated as the image embedding (512 features).
# NOTE(review): this is often assumed to be the 512-output Linear layer, but in
# fastai's default `create_head` the second-to-last head module is the final
# Dropout (after ReLU + BatchNorm1d(512)); in eval mode it passes the BatchNorm
# output through unchanged. Confirm against the `learn.model` printout.
layer = learn.model[1][-2]

# Recorder for that layer's forward outputs, attached as a forward hook;
# `hook` is the handle used to detach it again when extraction is done.
save_output = SaveOutput()
hook = layer.register_forward_hook(save_output)

# Define a function to extract embeddings from a single image
def get_embedding(img_path, target_layer=None):
    """Return the output of one model layer for a single image, as a numpy array.

    The image goes through the learner's test-time transforms, the model runs once
    in eval mode without gradients, and the output of ``target_layer`` is captured
    by a forward hook registered only for the duration of this call.

    Args:
        img_path: path to an image file readable by ``PILImage.create``.
        target_layer: module whose output is the embedding. Defaults to the
            module-level ``layer`` selected earlier in this notebook.

    Returns:
        The layer's output for this image with the batch axis removed
        (shape ``(512,)`` for the layer hooked in this notebook).

    Raises:
        RuntimeError: if ``target_layer`` produced no output during the forward
            pass (e.g. it is not part of ``learn.model``).
    """
    if target_layer is None:
        target_layer = layer

    # Apply the same transforms as at inference time; element 0 of the batch
    # tuple is the input tensor (batch size 1).
    img = PILImage.create(img_path)
    x_batch = learn.dls.test_dl([img]).one_batch()[0]

    # Use a hook owned by this call and always remove it. The function then keeps
    # working after any module-level hook is removed, and never leaves a stray
    # hook behind if the forward pass raises.
    catcher = SaveOutput()
    handle = target_layer.register_forward_hook(catcher)
    try:
        learn.model.eval()  # disable dropout / use BatchNorm running statistics
        with torch.no_grad():
            learn.model(x_batch)
    finally:
        handle.remove()

    if not catcher.outputs:
        raise RuntimeError("target_layer produced no output; is it part of learn.model?")

    # Take item 0 to drop only the batch axis (squeeze() would also collapse any
    # other size-1 dims). Move to CPU first, because .numpy() fails on CUDA tensors.
    return catcher.outputs[0][0].detach().cpu().numpy()

# Test the function with an image
img_path = './bears/test/blackbear/image1.jpg'
try:
    embeddings = get_embedding(img_path)
finally:
    # Remove the hook even if extraction fails. Otherwise a failed run leaves it
    # attached, so later forward passes keep appending tensors to the recorder,
    # and re-running the cell stacks another hook on the same layer.
    hook.remove()

print("Embedding shape:", embeddings.shape)  # Should be (512,)
print(embeddings)


Embedding shape: (512,)
[-6.19231761e-01 -6.36461854e-01  1.18142045e+00 -4.07315731e-01
 -6.81153297e-01 -6.37736738e-01 -5.98264873e-01 -3.74436527e-01
  3.50174308e-01 -8.12551826e-02 -6.51802063e-01 -6.34675264e-01
 -3.59858781e-01 -5.22592187e-01 -6.09737277e-01 -4.81929153e-01
  9.04844999e-01 -5.92425406e-01  7.67802715e-01 -6.55589759e-01
 -1.12197906e-01  8.12714398e-01 -5.72820604e-01 -6.30509079e-01
 -6.25042200e-01  4.98911113e-01 -5.17139472e-02 -6.38888538e-01
  8.99106324e-01 -4.48392034e-01 -5.46490133e-01  2.51619726e-01
 -5.17981410e-01 -6.59129381e-01  7.57388175e-01 -6.49257362e-01
 -6.30341768e-01 -5.82361042e-01  3.28413874e-01 -2.71851867e-01
 -5.95298111e-01 -5.48532903e-01 -6.59836173e-01 -5.50057411e-01
 -5.66546679e-01 -6.04160309e-01 -6.16082668e-01 -1.34915948e-01
  1.30697265e-01 -6.15760326e-01 -6.42432272e-01 -6.95860863e-01
 -6.29767299e-01 -6.62366748e-01  3.98956120e-01  2.79097348e-01
  1.62113619e+00 -6.23016179e-01 -5.93878210e-01 -6.31720483e-01
 