In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.models
from torchvision.datasets import ImageNet

import torch.autograd as autograd

In [2]:
class ModelFeatureExtracter(nn.Module):
    def __init__(self, model, layer):
        super().__init__()
        self.model = model
        self.layer = layer
        self.model.eval()
        
    def extract_layer_output(self, x):
        outputs = []
        def extracting_hook(module, input, output):
            outputs.append(output)
        
        hook_handle = self.layer.register_forward_hook(extracting_hook)
        
        self.model(x)
        
        hook_handle.remove()
        return outputs[-1]
    
    def forward(self, x):
        with torch.no_grad():
            return self.model(x)

In [125]:
def get_matvec_transposed(input, model, vector):
    input.requires_grad = True
    layer_output = model.extract_layer_output(input)
    dotproduct = torch.matmul(layer_output.flatten(), vector.flatten())
    return autograd.grad(dotproduct, input, create_graph=True)[0]

In [112]:
def get_matvec(input, model, vector, hidden_dim):
    w = torch.zeros(hidden_dim, requires_grad=True)
    matvec_transposed = get_matvec_transposed(input, model, w)
    dotproduct = torch.matmul(matvec_transposed.flatten(), vector.flatten())
    return autograd.grad(dotproduct, w)[0]

In [126]:
model = torchvision.models.vgg16(pretrained=True)

In [127]:
layer_to_extract_from = model.features[5]

In [128]:
me = ModelFeatureExtracter(model, layer_to_extract_from)

In [129]:
test_x = torch.rand((1, 3, 224, 224))

In [155]:
input_dim = torch.prod(torch.tensor(test_x.shape[1:])).item()
hidden_dim = torch.prod(torch.tensor(me.extract_layer_output(test_x).shape[1:])).item()

In [156]:
tr_vector = torch.rand((1, hidden_dim))
vector = torch.rand((1, input_dim))

In [157]:
mv_tr = get_matvec_transposed(test_x, me, tr_vector)

In [158]:
mv_tr.shape

torch.Size([1, 3, 224, 224])

In [159]:
mv = get_matvec(test_x, me, vector, hidden_dim)

In [160]:
mv.shape

torch.Size([1605632])

In [161]:
mv_tr

tensor([[[[ -2.6177,  -0.3186,  18.1138,  ...,   6.7650,  11.3487,   2.3672],
          [-19.5007, -26.8231,  -7.9017,  ..., -22.7644,  -2.3856,  -5.9028],
          [  9.0767,  -1.6857,  53.3410,  ...,  -6.1996,  18.0566,  -0.4757],
          ...,
          [-13.4677,   0.3108, -14.4383,  ...,  14.2157,  -8.0355,  -0.2123],
          [ -5.9817,  -1.4515, -16.3002,  ..., -16.8298,   4.4615,  -3.2136],
          [ -2.3038,  -0.4171,   6.7171,  ...,   4.9648,   6.8502,  -9.1637]],

         [[ -0.6216,  -4.1958,  15.8530,  ...,  20.1880,  15.3076,   1.4236],
          [-29.2631, -49.8833, -34.2983,  ..., -15.0438, -15.7315, -21.0461],
          [  7.2467, -11.7496,  52.0045,  ...,   0.2439,  -1.6616, -18.6072],
          ...,
          [ -9.0355,   9.6693, -43.3694,  ...,   2.7809, -11.6049,   4.7787],
          [  3.2395,  13.7466, -36.2825,  ..., -35.4500,  14.1948,   3.4723],
          [  0.9470,   6.8685,   3.8797,  ...,  -5.0725,  14.2642, -11.2471]],

         [[  5.1349,   5.4529,

In [163]:
mv[1000:2000]

tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.9365, -1.5801,
         1.1569,  0.0000,  0.0000,  0.0000,  0.0000, -1.6761,  0.0000,  0.0000,
         0.0000,  0.0000,  0.8053,  0.0000,  1.7712,  0.6996,  0.0000,  0.0000,
         0.0000,  0.0000, -2.0480,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000, -0.7948,  0.0000,  0.0000,  0.0000, -4.3773,
         0.0000,  0.0000,  0.0000,  0.0000, -2.0755,  0.0000,  0.0000,  0.0000,
        -0.5874,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         1.2734, -1.9622,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.6203,
         0.0000,  0.0000,  0.0000,  0.0000, -1.8764,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -1.7528,  0.3713,  0.0000,
         0.0000, -1.2745,  0.0000,  0.0000,  0.0000,  0.0000,  0.4725,  0.0000,
         0.0000,  0.6001,  0.0000,  0.00