In [None]:
import torch

In [None]:
from src.networks.pathology import CoxResNet


# Checkpoint to load; the commented-out line is the alternative checkpoint from the same date.
# NOTE(review): absolute, machine-specific path — consider deriving it from a configurable DATA_DIR.
CHECKPOINT_PATH = '/data2/projects/DigiStrudMed_sklein/huggingface/2023-04-05_None_no_scheduler_CI.pth'
# CHECKPOINT_PATH = '/data2/projects/DigiStrudMed_sklein/huggingface/2023-04-05_None_no_scheduler_CPH.pth'

model = CoxResNet(input_channels=8)
# Inference mode: turns off training-only behaviour (e.g. dropout / batch-norm statistic updates)
# in whatever such layers the network contains.
model.eval()

# load model weights from file.
# map_location='cpu' lets a checkpoint saved from GPU tensors load on a machine without CUDA;
# the model above lives on the CPU anyway.
# torch.load unpickles the file, so only point it at trusted checkpoints
# (torch >= 1.13 also accepts weights_only=True for plain state dicts).
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'))


In [None]:
import torch
import torchvision.models as models
# captum.attr has no MultiInputGradientShap; GradientShap itself accepts a tuple of inputs.
from captum.attr import GradientShap, LayerAttribution, LayerIntegratedGradients
from captum.attr import visualization as viz

# Define your multimodal model that takes in a dictionary of inputs
class MultimodalModel(torch.nn.Module):
    """Late-fusion classifier combining a ResNet-18 image embedding with numerical features.

    The image is passed through torchvision's ResNet-18, whose full output vector is used as
    the image feature. That vector is concatenated with the numerical features and fed through
    a two-layer MLP (ReLU in between).

    Args:
        num_numerical_features: Number of numerical features per sample (default 1).
        num_classes: Length of the output logit vector (default 10).
        pretrained: Load ImageNet weights into the ResNet-18 backbone (default True).
    """

    def __init__(self, num_numerical_features=1, num_classes=10, pretrained=True):
        super().__init__()
        self.image_model = self._build_backbone(pretrained)
        # Size the fusion layer from the backbone instead of hard-coding its 1000 outputs.
        image_feat_dim = self.image_model.fc.out_features
        self.fc1 = torch.nn.Linear(image_feat_dim + num_numerical_features, 512)
        self.fc2 = torch.nn.Linear(512, num_classes)

    @staticmethod
    def _build_backbone(pretrained):
        """Build ResNet-18, preferring the ``weights=`` API when torchvision provides it."""
        weights_enum = getattr(models, 'ResNet18_Weights', None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the (later deprecated) ``pretrained`` flag.
            return models.resnet18(pretrained=pretrained)
        return models.resnet18(weights=weights_enum.DEFAULT if pretrained else None)

    def forward(self, inputs, numerical=None):
        """Compute class logits.

        Args:
            inputs: Either a mapping with keys ``'image'`` and ``'numerical'`` (the original
                calling convention), or the image tensor itself, in which case the numerical
                features go in ``numerical``. The positional form is what Captum uses, since
                it calls the model as ``model(*inputs)``.
            numerical: Numerical features when ``inputs`` is an image tensor; ignored otherwise.
                Either ``(batch, num_numerical_features)`` or 1-D ``(batch,)``, which is read
                as one feature per sample.

        Returns:
            Logits of shape ``(batch, num_classes)``.

        Raises:
            ValueError: If ``inputs`` is a tensor and ``numerical`` is not given.
        """
        if isinstance(inputs, torch.Tensor):
            if numerical is None:
                raise ValueError(
                    "numerical features are required when the image tensor is passed positionally"
                )
            img, num = inputs, numerical
        else:
            img, num = inputs['image'], inputs['numerical']

        img_feat = self.image_model(img)
        # torch.cat along dim=1 needs a 2-D tensor: (batch,) -> (batch, 1).
        if num.dim() == 1:
            num = num.unsqueeze(1)
        # Align dtype/device with the image features so cat and fc1 accept e.g. int or float64 input.
        num = num.to(dtype=img_feat.dtype, device=img_feat.device)
        combined_feat = torch.cat((img_feat, num), dim=1)
        out = torch.relu(self.fc1(combined_feat))
        return self.fc2(out)

# Seed torch so the random demo input, the freshly initialised fusion layers and GradientShap's
# random sampling are reproducible on Restart & Run All.
torch.manual_seed(0)

# Create an instance of the multimodal model. eval() puts the ResNet's batch-norm layers in
# inference mode; otherwise they would normalise with statistics of the batches Captum builds
# internally (interpolation steps / GradientShap samples) and update running stats while explaining.
model = MultimodalModel()
model.eval()

# Demo inputs. The numerical features are concatenated with the image features along dim=1,
# so they must be 2-D: (batch, n_numerical_features).
inputs = {
    'image': torch.randn(1, 3, 224, 224),  # batch x channels x height x width
    'numerical': torch.tensor([[0.5]]),    # batch x numerical features (a single value)
}


def forward_fn(image, numerical):
    """Captum adapter: Captum calls the forward function with positional tensors, while the
    model's forward reads its inputs from a dict."""
    return model({'image': image, 'numerical': numerical})


# LayerIntegratedGradients w.r.t. the output of the ResNet's first convolution. The numerical
# input is held fixed and passed via additional_forward_args.
lig = LayerIntegratedGradients(forward_fn, model.image_model.conv1)
attr_conv1 = lig.attribute(
    inputs['image'],
    additional_forward_args=(inputs['numerical'],),
    target=0,
)
# Layer attributions have conv1's output shape, not the image's: sum over channels and upsample
# to the input resolution so they can be drawn as an image-sized heat map.
attr_image = LayerAttribution.interpolate(
    attr_conv1.sum(dim=1, keepdim=True), tuple(inputs['image'].shape[2:])
)

# GradientShap over both modalities at once: inputs and baselines are tuples with one tensor per
# forward_fn argument. GradientShap requires baselines; all-zeros is used here.
gradient_shap = GradientShap(forward_fn)
attr_multimodal = gradient_shap.attribute(
    inputs=(inputs['image'], inputs['numerical']),
    baselines=(torch.zeros_like(inputs['image']), torch.zeros_like(inputs['numerical'])),
    target=0,
)


def to_hwc(tensor):
    """Convert a (1, C, H, W) attribution tensor to the (H, W, C) numpy layout viz expects."""
    return tensor.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()


# The demo image is random noise, so show attributions as plain heat maps instead of blending
# them onto the "image".
viz.visualize_image_attr(to_hwc(attr_image), method='heat_map', sign='all',
                         title='LayerIntegratedGradients (conv1, upsampled)');
viz.visualize_image_attr(to_hwc(attr_multimodal[0]), method='heat_map', sign='all',
                         title='GradientShap: image');

# The numerical input is a single feature, so its attribution is displayed as a tensor, not an image.
attr_multimodal[1]
