In [30]:
import torch.nn as nn

import data
from models import CAMManager, ClassifierHead, CNNBackbone
from pre_training import Trainer


def findLastConvLayer(model):
    """
    Find the last convolutional layer in a model.

    The module tree is walked in registration order and the final
    ``nn.Conv2d`` encountered is kept:

    * ``nn.Sequential`` containers are searched child by child.
    * Modules exposing a ``features`` attribute that is an ``nn.Sequential``
      (like our backbone) are searched through ``features`` only.
    * A bare ``nn.Conv2d`` is returned as-is.
    * Any other ``nn.Module`` is searched through its direct children, so
      convolutions wrapped inside custom blocks are not silently skipped.

    Every branch recurses, so convolutions nested at any depth are found.

    Args:
        model: The model to search through

    Returns:
        The last convolutional layer found, or None if not found
    """
    # Decide which collection of sub-modules to descend into. The order of
    # these checks mirrors the precedence callers already rely on:
    # Sequential first, then a `features` container, then a plain Conv2d.
    if isinstance(model, nn.Sequential):
        candidates = model
    elif isinstance(getattr(model, "features", None), nn.Sequential):
        candidates = model.features
    elif isinstance(model, nn.Conv2d):
        return model
    elif isinstance(model, nn.Module):
        # Generic container (custom block, ModuleList, ...): look inside it
        # instead of giving up.
        candidates = model.children()
    else:
        return None

    last_conv = None
    for module in candidates:
        found = findLastConvLayer(module)
        # Later hits overwrite earlier ones, so the final value is the last
        # convolution in traversal order.
        if found is not None:
            last_conv = found
    return last_conv


# Checkpoint file handed to both `set_model` and `load_checkpoint` below;
# kept in one place so the two calls cannot drift apart.
CHECKPOINT_PATH = "checkpoints/cnn_species_checkpoint_epoch10.pt"

# Build the trainer and the network parts, register them with the trainer,
# then load the saved checkpoint.
trainer = Trainer()
backbone = CNNBackbone()
head = ClassifierHead()
trainer.set_model(backbone, [head], CHECKPOINT_PATH)
trainer.load_checkpoint(CHECKPOINT_PATH)

# Only the third loader returned by `create_dataloaders` is kept
# (presumably the test split — confirm against `data.create_dataloaders`).
_, _, loader = data.create_dataloaders(
    target_type=["species", "segmentation"],
)

  checkpoint = torch.load(checkpoint_path)


Checkpoint loaded from checkpoints/cnn_species_checkpoint_epoch10.pt (epoch 10)
Images already downloaded: oxford_pet_data/images.tar.gz
Annotations already downloaded: oxford_pet_data/annotations.tar.gz
Dataset prepared with 37 classes.
Dataset split complete: training (70.0%), validation (15.0%), testing (15.0%)
Images already downloaded: oxford_pet_data/images.tar.gz
Annotations already downloaded: oxford_pet_data/annotations.tar.gz
Dataset prepared with 37 classes.
Dataset split complete: training (70.0%), validation (15.0%), testing (15.0%)
Images already downloaded: oxford_pet_data/images.tar.gz
Annotations already downloaded: oxford_pet_data/annotations.tar.gz
Dataset prepared with 37 classes.
Dataset split complete: training (70.0%), validation (15.0%), testing (15.0%)


In [31]:
# Chain the backbone and the classifier head into one module, then pick the
# convolution that `findLastConvLayer` reports for it.
model = nn.Sequential(
    backbone,
    head,
)
layer = findLastConvLayer(model)

In [32]:
# Wire the combined model, the data loader and the chosen conv layer into the
# CAM manager, targeting the "species" labels.
manager = CAMManager(
    model,
    loader,
    target_type="species",
    target_layer=layer,
)

In [34]:
# Bare last expression: shows the manager's `dataset` attribute as the cell output.
manager.dataset

<torch.utils.data.dataset.TensorDataset at 0x306192e70>