Description
Hi,
Thank you for maintaining this nice library. :)
I'm currently trying to apply various XAI modules to my seq2seq model.
To analyze the seq2seq model, I wrap it in an additional forward_func (see below):
def forward_func(inputs, seq_idx):
    # Keep a single output step so that pred matches the target's dimensions.
    pred = model(inputs)               # (BN, seq_N, Class_N)
    single_pred = pred[:, seq_idx, :]  # (BN, Class_N)
    return single_pred
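A runnable sketch of the wrapper above, using a hypothetical `ToySeq2Seq` model that only reproduces the `(BN, seq_N, Class_N)` output shape (it is not the real seq2seq model). `functools.partial` fixes `seq_idx`, which gives a one-argument callable per output step; Captum's gradient-based methods can alternatively receive `seq_idx` through their `additional_forward_args` parameter.

```python
import functools

import torch
import torch.nn as nn


class ToySeq2Seq(nn.Module):
    """Hypothetical stand-in for the seq2seq model: returns (BN, seq_N, Class_N)."""

    def __init__(self, features: int = 4, seq_n: int = 3, class_n: int = 5):
        super().__init__()
        self.seq_n, self.class_n = seq_n, class_n
        self.linear = nn.Linear(features, seq_n * class_n)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (BN, features) -> (BN, seq_N * Class_N) -> (BN, seq_N, Class_N)
        return self.linear(x).reshape(x.shape[0], self.seq_n, self.class_n)


model = ToySeq2Seq()


def forward_func(inputs: torch.Tensor, seq_idx: int) -> torch.Tensor:
    # Keep a single output step so that pred matches the per-step target's shape.
    pred = model(inputs)         # (BN, seq_N, Class_N)
    return pred[:, seq_idx, :]   # (BN, Class_N)


# Bind seq_idx to get a one-argument callable for output step 1.
forward_step1 = functools.partial(forward_func, seq_idx=1)
```

Note that `forward_step1` is still a plain callable, not an `nn.Module`.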
However, this doesn't work with LRP:
Traceback (most recent call last):
  File "evaluate.py", line 365, in <module>
    evaluate()
  File "evaluate.py", line 328, in evaluate
    pred, label, conf, info = validate(test_loader, model, args.arch_type, save_path_root=eval_path)
  File "evaluate.py", line 160, in validate
    attr = xai.get_attributes(data, label, pred, info) # (BN, ouput_seq_N, input_seq_N, features, sample_len)
  File "/home/tylee/sleepbot/sleepbot/utils/XAI.py", line 128, in get_attributes
    target=target_tensor[:,seq_idx])
  File "/home/tylee/captum/captum/attr/_core/lrp.py", line 153, in attribute
    self._original_state_dict = self.model.state_dict()
AttributeError: 'function' object has no attribute 'state_dict'
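The traceback shows that LRP's `attribute` calls `self.model.state_dict()`, a method that exists on `torch.nn.Module` but not on a plain Python function. A minimal sketch, assuming the real model is an `nn.Module`, is to move the `forward_func` logic into a small wrapper module. `SeqStepWrapper` and `ToySeq2Seq` are hypothetical names, and the toy model only stands in for the seq2seq model's output shape.

```python
import torch
import torch.nn as nn


class ToySeq2Seq(nn.Module):
    """Hypothetical stand-in for the seq2seq model: returns (BN, seq_N, Class_N)."""

    def __init__(self, features: int = 4, seq_n: int = 3, class_n: int = 5):
        super().__init__()
        self.seq_n, self.class_n = seq_n, class_n
        self.linear = nn.Linear(features, seq_n * class_n)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x).reshape(x.shape[0], self.seq_n, self.class_n)


class SeqStepWrapper(nn.Module):
    """Hypothetical wrapper: the forward_func logic, expressed as an nn.Module."""

    def __init__(self, model: nn.Module, seq_idx: int):
        super().__init__()
        self.model = model      # registered as a submodule, so state_dict() includes its weights
        self.seq_idx = seq_idx  # output step to explain

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        pred = self.model(inputs)        # (BN, seq_N, Class_N)
        return pred[:, self.seq_idx, :]  # (BN, Class_N)


wrapped = SeqStepWrapper(ToySeq2Seq(), seq_idx=1)
```

With Captum installed, the wrapped module could then be passed as `LRP(wrapped).attribute(inputs, target=...)`, with one wrapper per `seq_idx`. This part is untested here: LRP also needs a propagation rule for every leaf layer, so a model containing layer types without a rule (recurrent layers, for instance) may still be rejected.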
Is there any way to use a forward function like this with the LRP module?
Thank you :)