In [1]:
#convert

# babilim.model.imagenet

> An implementation of various ImageNet models.

In [2]:
#export
from babilim.core.annotations import RunOnlyOnce
from babilim.core.module_native import ModuleNative

In [3]:
#export
class ImagenetModel(ModuleNative):
    def __init__(self, encoder_type, only_encoder=False, pretrained=False):
        """
        Create one of the iconic ImageNet models in one line.
        Allows for only using the encoder part.

        This model assumes the input image to be 0-255 (8 bit integer) with 3 channels.

        :param encoder_type: The encoder type that should be used. Must be in ("vgg16", "vgg16_bn", "vgg19", "vgg19_bn", "resnet50", "resnet101", "resnet152", "densenet121", "densenet169", "densenet201", "inception_v3", "mobilenet_v2")
        :param only_encoder: Leaves out the classification head, leaving you with a feature encoder.
            This is implemented by taking the `features` sub-module of the torchvision model, so it only
            works for architectures whose torchvision implementation exposes `features`
            (NOTE(review): e.g. the VGG variants; ResNet/Inception do not appear to - confirm before relying on it).
        :param pretrained: If you want imagenet weights for this network. The value is forwarded unchanged
            as the `pretrained` argument of the torchvision model constructor.
        """
        super().__init__()
        self.only_encoder = only_encoder
        self.pretrained = pretrained
        self.encoder_type = encoder_type

    @RunOnlyOnce
    def _build_tf(self, image):
        """
        Build hook for the tensorflow backend. Not implemented.

        :param image: An example input batch (unused).
        :raises NotImplementedError: Always.
        """
        raise NotImplementedError()

    def _call_tf(self, image):
        """
        Forward pass for the tensorflow backend. Not implemented.

        :param image: The input batch (unused).
        :raises NotImplementedError: Always.
        """
        raise NotImplementedError()

    @RunOnlyOnce
    def _build_pytorch(self, image):
        """
        Instantiate the torchvision network and the normalization constants.

        :param image: An example input batch. Only its dtype (after conversion to float) and its device
            are used, to create the mean/std tensors.
        :raises RuntimeError: If `encoder_type` is not one of the supported names.
        """
        # Imported lazily so that the module can be imported without pytorch installed.
        import torch
        from torchvision.models import vgg16, vgg16_bn, vgg19, vgg19_bn, resnet50, resnet101, resnet152, densenet121, densenet169, densenet201, inception_v3, mobilenet_v2
        from torch.nn import Sequential

        # Map every supported encoder name to its torchvision constructor.
        constructors = {
            "vgg16": vgg16,
            "vgg16_bn": vgg16_bn,
            "vgg19": vgg19,
            "vgg19_bn": vgg19_bn,
            "resnet50": resnet50,
            "resnet101": resnet101,
            "resnet152": resnet152,
            "densenet121": densenet121,
            "densenet169": densenet169,
            "densenet201": densenet201,
            "inception_v3": inception_v3,
            "mobilenet_v2": mobilenet_v2,
        }
        try:
            model = constructors[self.encoder_type]
        except (KeyError, TypeError):
            # TypeError covers unhashable values, which are unsupported just like unknown names.
            raise RuntimeError("Unsupported encoder type.") from None

        if self.only_encoder:
            # Keep only the layers of the `features` sub-module (drops the classification head).
            self.model = Sequential(*list(model(pretrained=self.pretrained).features))
        else:
            self.model = model(pretrained=self.pretrained)

        if torch.cuda.is_available():
            # `self.device` is not set in this class; presumably provided by ModuleNative.
            self.model = self.model.to(torch.device(self.device))

        # Just in case, make the image a float tensor
        image = image.float()

        # Standardization values from torchvision.models documentation
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # Create tensors for a 0-255 value range image.
        self.mean = torch.as_tensor([i * 255 for i in mean], dtype=image.dtype, device=image.device)
        self.std = torch.as_tensor([j * 255 for j in std], dtype=image.dtype, device=image.device)

    def _call_pytorch(self, image):
        """
        Standardize the image batch per channel and run it through the network.

        :param image: The input batch with the channels on axis 1 (the mean/std are broadcast over
            axis 1) and values in the 0-255 range.
        :return: The output of the wrapped torchvision network.
        """
        # Just in case, make the image a float tensor.
        image = image.float()
        # Normalize out-of-place: `float()` hands back the very same tensor when the input is already
        # float32, so in-place `sub_`/`div_` would silently overwrite the caller's data.
        mean = self.mean[None, :, None, None]
        std = self.std[None, :, None, None]
        image = (image - mean) / std

        return self.model(image)

In [4]:
from babilim.core.tensor import Tensor
import numpy as np

# Smoke run: build a VGG16-BN model with only_encoder=True, i.e. only its feature encoder part.
# NOTE(review): `pretrained` is given as the string "imagenet"; ImagenetModel forwards it unchanged to
# the torchvision constructor, which presumably treats any truthy value as "load weights" - confirm.
encoder = ImagenetModel("vgg16_bn", only_encoder=True, pretrained="imagenet")
# Zero-filled dummy batch of shape (1, 3, 256, 256), wrapped as a non-trainable babilim Tensor.
fake_image_batch_pytorch = Tensor(data=np.zeros((1, 3, 256, 256), dtype=np.float32), trainable=False)
print(fake_image_batch_pytorch.shape)
# Run the dummy batch through the encoder and show the shape of the resulting feature map.
result = encoder(fake_image_batch_pytorch)
print(result.shape)

(1, 3, 256, 256)
(1, 512, 8, 8)


In [5]:
from babilim.core.tensor import Tensor
import numpy as np

# Smoke run: build the complete ResNet50 network (only_encoder=False keeps the classification head).
# NOTE(review): `pretrained` is given as the string "imagenet"; ImagenetModel forwards it unchanged to
# the torchvision constructor, which presumably treats any truthy value as "load weights" - confirm.
model = ImagenetModel("resnet50", only_encoder=False, pretrained="imagenet")
# Zero-filled dummy batch of shape (1, 3, 256, 256), wrapped as a non-trainable babilim Tensor.
fake_image_batch_pytorch = Tensor(data=np.zeros((1, 3, 256, 256), dtype=np.float32), trainable=False)
print(fake_image_batch_pytorch.shape)
# Run the dummy batch through the network and show the shape of its output.
result = model(fake_image_batch_pytorch)
print(result.shape)

(1, 3, 256, 256)
Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to C:\Users\fuerst/.cache\torch\hub\checkpoints\resnet50-19c8e357.pth
100%|██████████| 97.8M/97.8M [00:23<00:00, 4.41MB/s]
(1, 1000)
