In [1]:
import torch
import torch.nn as nn
from torchvision.models import  efficientnet_v2_s, EfficientNet_V2_S_Weights
from torchvision import transforms as T

In [2]:
base_model = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.DEFAULT)
model = nn.Sequential(base_model.features,
                       base_model.avgpool)
model.eval()

Sequential(
  (0): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): FusedMBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): FusedMBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
           

In [3]:
transformations = nn.Sequential(
                                T.Resize(384, antialias=True, interpolation=T.InterpolationMode.BICUBIC),
                                T.CenterCrop(384)
                                )

fullmodel = nn.Sequential(
    transformations,
    base_model.features,
    base_model.avgpool
)

fullmodel

Sequential(
  (0): Sequential(
    (0): Resize(size=384, interpolation=bicubic, max_size=None, antialias=True)
    (1): CenterCrop(size=(384, 384))
  )
  (1): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): FusedMBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): FusedMBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 

In [5]:
import urllib
import numpy as np
import cv2

req = urllib.request.urlopen('http://assets.myntassets.com/v1/images/style/properties/504a27acee8e6d89d7eec2fae5b5ef01_images.jpg')
arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
img = cv2.imdecode(arr, -1)
img

array([[[255, 255, 255],
        [255, 255, 255],
        [255, 255, 255],
        ...,
        [255, 255, 255],
        [255, 255, 255],
        [255, 255, 255]],

       [[255, 255, 255],
        [255, 255, 255],
        [255, 255, 255],
        ...,
        [255, 255, 255],
        [255, 255, 255],
        [255, 255, 255]],

       [[255, 255, 255],
        [255, 255, 255],
        [255, 255, 255],
        ...,
        [255, 255, 255],
        [255, 255, 255],
        [255, 255, 255]],

       ...,

       [[255, 255, 255],
        [255, 255, 255],
        [255, 255, 255],
        ...,
        [255, 255, 255],
        [255, 255, 255],
        [255, 255, 255]],

       [[255, 255, 255],
        [255, 255, 255],
        [255, 255, 255],
        ...,
        [255, 255, 255],
        [255, 255, 255],
        [255, 255, 255]],

       [[255, 255, 255],
        [255, 255, 255],
        [255, 255, 255],
        ...,
        [255, 255, 255],
        [255, 255, 255],
        [255, 255, 255]]

In [6]:
transform = T.Compose([
              T.ToTensor()
            ])
im_tensor = transform(img)
im_tensor.size()

torch.Size([3, 2400, 1800])

In [27]:
fullmodel(im_tensor.unsqueeze(0)).detach().cpu().numpy().reshape(1, -1).shape

(1, 1280)

In [4]:
torch.save(fullmodel, '../src/backend/ML-models/feature-extractor.pth')

In [29]:
import mlflow

mlflow.set_tracking_uri('../mlflow')

if mlflow.get_experiment_by_name('image-based recsys') is None:
    mlflow.set_experiment('image-based recsys')

In [31]:
with mlflow.start_run(experiment_id=340802484178285161):
    mlflow.log_param('extractor', 'EfficientNet_V2_S')
    mlflow.log_param('dataset_hash', '44c73c23')
    mlflow.log_param('output_shape', 1280)
    mlflow.pytorch.log_model(fullmodel, 'EfficientNet-V2-S')