In [4]:
#| default_exp target_layer

In [5]:
#| export
import torchvision.models as models

def get_model_and_target_layer(model_name):
    """
    Load an ImageNet-pretrained torchvision model and select its attribution target layer.

    The model is switched to evaluation mode (``.eval()``) before it is returned.

    Args:
        model_name (str): Case-sensitive model identifier; one of
            'ResNet50', 'ResNet18', 'VGG16', 'InceptionV3'.

    Returns:
        tuple: ``(model, target_layer)`` where
            - ``model`` (torch.nn.Module) is the pretrained model in eval mode;
            - ``target_layer`` is a submodule of ``model`` for ResNet50,
              ResNet18 and InceptionV3, but for VGG16 it is the *string*
              ``"features.28"`` (a dotted submodule name), not a module.
              Callers must handle both forms.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names. The
            message includes the rejected value and the list of supported names.
    """
    # Keep this tuple in sync with the dispatch chain below; it is only used
    # to build an actionable error message.
    supported = ("ResNet50", "ResNet18", "VGG16", "InceptionV3")

    # NOTE(review): newer torchvision releases deprecate `pretrained=` in
    # favour of `weights=`; confirm the installed version before migrating,
    # and pick weights equivalent to the legacy ones so attributions do not shift.
    if model_name == "ResNet50":
        model = models.resnet50(pretrained=True).eval()
        # Last block of the final residual stage.
        target_layer = model.layer4[-1]
    elif model_name == "ResNet18":
        model = models.resnet18(pretrained=True).eval()
        target_layer = model.layer4[-1]
    elif model_name == "VGG16":
        model = models.vgg16(pretrained=True).eval()
        # Returned as a dotted name rather than a module. Index 28 of
        # `model.features` is presumably the last Conv2d in VGG16; verify.
        target_layer = "features.28"
    elif model_name == "InceptionV3":
        model = models.inception_v3(pretrained=True).eval()
        # NOTE(review): this is one branch (branch3x3dbl_3b) inside Mixed_7c,
        # not the Mixed_7c block output; confirm this choice is intentional.
        target_layer = model.Mixed_7c.branch3x3dbl_3b
    else:
        raise ValueError(
            f"Model not recognized: {model_name!r}. "
            f"Please specify one of the supported model names: {', '.join(supported)}."
        )

    return model, target_layer


In [6]:
#| hide
# Development-only cell: exports this notebook's `#| export` cells into the
# shared `commons/api` package. It has no `#| export` directive itself, so it
# is not written into the generated module.
# NOTE(review): 'target_layer.ipynb' is a relative path, so this presumably
# must run with the notebook's own directory as the working directory; the
# commented-out os.chdir below looks like a leftover way to force that.
# os.chdir("/project/validating_attribution_techniques/shardul/api_notebooks/")
from nbdev.export import nb_export
# The destination is an absolute, machine-specific path; adjust it when
# running outside the original project environment.
nb_export('target_layer.ipynb', '/project/validating_attribution_techniques/commons/api')