In [1]:
import gc

import joblib
import numpy as np
import pandas as pd
import safetensors
import torch
from safetensors.torch import save_file

from service_embedder import MobileNetWrapper, CountVectorizerWrapper, BertWrapper
from service_embedder import image_data_to_tensor

In [2]:
# Directory (relative to the working directory) that holds the input dataset
# read by this notebook and receives every id/embedding file it writes.
ASSETS_DIR = '../assets'

In [3]:
# Load the zstd-compressed movie table.
# NOTE: unpickling executes arbitrary code — only load files from a trusted source.
df = pd.read_pickle(
    f'{ASSETS_DIR}/movies.pkl.zst',
    compression='zstd',
)

## MobileNet

In [4]:
# Instantiate the project's MobileNet wrapper; its `_mobilenet` attribute is what
# the batch loop further down calls on the poster tensors.
mobilenet = MobileNetWrapper()

In [5]:
# Keep only the movies that actually have poster data, and only that column.
df_poster = df.loc[df['poster'].notna(), ['poster']]
df_poster.info()

<class 'pandas.core.frame.DataFrame'>
Index: 26753 entries, 5 to 4530184
Data columns (total 1 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   poster  26753 non-null  object
dtypes: object(1)
memory usage: 418.0+ KB


In [6]:
# Persist the index labels of the poster subset, in the same row order the
# posters are converted and embedded in the following cells.
np.save(
    file=f'{ASSETS_DIR}/poster_ids.npy',
    arr=df_poster.index.to_numpy(),
)

In [7]:
# Convert every stored poster into a tensor, preserving df_poster's row order.
posters = [image_data_to_tensor(raw_poster) for raw_poster in df_poster['poster']]

In [12]:
# Embed the posters in fixed-size batches so that only one batch of images has
# to be resident on the accelerator at a time.
batch_size = 1000
embedding_batches = []

with torch.inference_mode():
    for start in range(0, len(posters), batch_size):
        # Stack the per-poster tensors along a new leading batch dimension.
        batch = torch.stack(posters[start:start + batch_size]).to('mps')
        # NOTE(review): this calls the wrapper's private `_mobilenet` attribute
        # directly; presumably it is the underlying model living on 'mps' — confirm.
        # Copy the result to the CPU straight away so finished embeddings do not
        # pile up in device memory for the whole run.
        embedding_batches.append(mobilenet._mobilenet(batch).cpu())
        # Drop the last reference to the device batch *before* collecting and
        # emptying the cache; otherwise its memory cannot be released.
        del batch
        gc.collect()
        torch.mps.empty_cache()

# One row per poster, in the same order as `posters`.
mobilenet_embeddings = torch.cat(embedding_batches)
mobilenet_embeddings.shape

In [17]:
# Write the poster embeddings to a safetensors file under the single key
# 'embedding'. `save_file` is imported explicitly from `safetensors.torch`:
# a bare `import safetensors` does not load that submodule, so the attribute
# lookup `safetensors.torch` only works if something else imported it first.
save_file(
    {'embedding': mobilenet_embeddings},
    f'{ASSETS_DIR}/embeddings_mobilenet.safetensors'
)

## CountVectorizer

In [4]:
# Restrict the frame to movies that have a plot text, keeping just that column.
df_plot = df.loc[df['plot'].notna(), ['plot']]
df_plot.info()

<class 'pandas.core.frame.DataFrame'>
Index: 44491 entries, 1 to 7158814
Data columns (total 1 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   plot    44491 non-null  object
dtypes: object(1)
memory usage: 695.2+ KB


In [5]:
# Persist the index labels of the plot subset, in the same row order the plot
# texts are extracted and embedded in the following cells.
np.save(
    file=f'{ASSETS_DIR}/plot_ids.npy',
    arr=df_plot.index.to_numpy(),
)

In [6]:
# Plain Python list of the plot texts, in df_plot row order.
plots = df_plot['plot'].tolist()

In [7]:
# NOTE(review): `train()` is invoked on the class before an instance is created —
# presumably it fits and persists the vectorizer that the constructor then loads;
# confirm in `service_embedder` before reordering these two lines.
CountVectorizerWrapper.train()
count_vectorizer = CountVectorizerWrapper()

In [9]:
# Embed every plot text in a single call. `plots` is in df_plot row order, so
# (assuming `get_embedding` preserves input order — TODO confirm) the result
# lines up with the ids written to plot_ids.npy.
count_embeddings = count_vectorizer.get_embedding(plots)

In [13]:
# Serialize the CountVectorizer embeddings with joblib; the call returns the
# list of written file names, which is what the cell displays.
joblib.dump(
    value=count_embeddings,
    filename=f'{ASSETS_DIR}/embeddings_count_vectorizer.joblib.gz',
)

['../assets/embeddings_count_vectorizer.joblib.gz']

## BERT

In [10]:
# Instantiate the project's BERT wrapper used to embed the plot texts.
bert = BertWrapper()

In [18]:
# Embed the same `plots` list the CountVectorizer section uses (built from
# df_plot['plot']), so both models are fed the identical texts in df_plot row
# order instead of materializing the list a second time.
bert_vectors = bert.get_embedding(plots)
# Stack the per-plot tensors along a new leading dimension: one row per plot.
# `list(...)` keeps this working whether `get_embedding` returns a list of
# tensors or a single tensor that iterates row by row.
bert_embeddings = torch.stack(list(bert_vectors))
bert_embeddings.shape

In [None]:
# Write the BERT plot embeddings to a safetensors file under the single key
# 'embedding'. `save_file` is imported explicitly from `safetensors.torch`:
# a bare `import safetensors` does not load that submodule, so the attribute
# lookup `safetensors.torch` only works if something else imported it first.
save_file(
    {'embedding': bert_embeddings},
    f'{ASSETS_DIR}/embeddings_cls_bert.safetensors'
)