In [1]:
import torch
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel


class BERTEmbedder(torch.nn.Module):
    """Frozen ``bert-base-uncased`` encoder producing one unit-length vector per text.

    The tokenizer output is fed through BERT, the final hidden states are
    averaged over the positions marked by the attention mask (so padding does
    not contribute), and the result is L2-normalised.

    The backbone is frozen: its parameters have ``requires_grad = False`` and
    it is kept in eval mode even when ``.train()`` is called on this module or
    on a parent module, so dropout never perturbs the embeddings.

    Args:
        device: Device the backbone is moved to at construction. When falsy,
            ``"cuda"`` is used if available, otherwise ``"cpu"``.

    Attributes:
        tokenizer: The ``BertTokenizer`` used to encode input text.
        model_bert: The frozen ``BertModel`` backbone.
        device: The device chosen at construction time. ``forward`` does not
            rely on it; it reads the backbone's current device instead.
    """

    def __init__(self, device=None):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model_bert = BertModel.from_pretrained("bert-base-uncased")

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model_bert = self.model_bert.to(self.device)
        self.model_bert.eval()

        # Freeze the backbone: it is used purely as a feature extractor.
        for p in self.model_bert.parameters():
            p.requires_grad = False

    def train(self, mode=True):
        """Set the training flag on this module but keep the backbone in eval mode.

        ``nn.Module.train`` recurses into every child, which would switch the
        frozen BERT back to training mode and re-enable dropout. Re-applying
        ``eval()`` afterwards keeps the embeddings deterministic.

        Returns:
            self, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.model_bert.eval()
        return self

    def forward(self, input_text):
        """Embed text into L2-normalised, mask-aware mean-pooled BERT vectors.

        Args:
            input_text: Anything the tokenizer accepts as its first argument
                (typically a string or a list of strings).

        Returns:
            A float tensor with one unit-norm row per input sequence, located
            on the backbone's device.
        """
        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            padding=True,
            truncation=True
        )
        # Follow the backbone's *current* device rather than the value stored
        # at construction, so the module keeps working after `.to(...)`.
        target_device = next(self.model_bert.parameters()).device
        inputs = {k: v.to(target_device) for k, v in inputs.items()}

        outputs = self.model_bert(**inputs)

        token_embeddings = outputs.last_hidden_state
        attention_mask = inputs["attention_mask"]

        # Mean-pool over real tokens only: zero out padded positions and
        # divide by the number of unmasked tokens (clamped to avoid 0/0).
        mask = attention_mask.unsqueeze(-1).float()
        sentence_embeddings = (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

        embeddings = F.normalize(sentence_embeddings, p=2, dim=-1)
        return embeddings


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
import torch.nn.functional as F
import open_clip


class OpenClipVitEmbedder(torch.nn.Module):
    """Frozen OpenCLIP ``ViT-B-32`` (``openai`` weights) image embedder.

    Images are encoded with the model's ``encode_image`` and the resulting
    features are L2-normalised along the last dimension.

    The backbone is frozen: its parameters have ``requires_grad = False`` and
    it is kept in eval mode even when ``.train()`` is called on this module or
    on a parent module.

    Args:
        device: Device the backbone is moved to at construction. When falsy,
            ``"cuda"`` is used if available, otherwise ``"cpu"``.

    Attributes:
        model: The frozen OpenCLIP model.
        preprocess: The third value returned by
            ``open_clip.create_model_and_transforms``; it is stored but not
            applied by ``forward``.
        device: The device chosen at construction time. ``forward`` does not
            rely on it; it reads the backbone's current device instead.
    """

    def __init__(self, device=None):
        super().__init__()
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            model_name="ViT-B-32",
            pretrained="openai"
        )

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.model.eval()

        # Freeze the backbone: it is used purely as a feature extractor.
        for p in self.model.parameters():
            p.requires_grad = False

    def train(self, mode=True):
        """Set the training flag on this module but keep the backbone in eval mode.

        ``nn.Module.train`` recurses into every child, which would switch the
        frozen CLIP model back to training mode. Re-applying ``eval()``
        afterwards keeps its inference behaviour fixed.

        Returns:
            self, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.model.eval()
        return self

    def forward(self, image_tensor):
        """Encode an image tensor into L2-normalised CLIP image features.

        Args:
            image_tensor: Tensor passed directly to ``encode_image``.
                NOTE(review): presumably an already-preprocessed image batch;
                ``self.preprocess`` is not applied here - confirm at call sites.

        Returns:
            The image features, normalised to unit L2 norm along the last
            dimension, on the backbone's device.
        """
        # Follow the backbone's *current* device rather than the value stored
        # at construction, so the module keeps working after `.to(...)`.
        target_device = next(self.model.parameters()).device
        image_tensor = image_tensor.to(target_device)
        image_features = self.model.encode_image(image_tensor)
        image_features = F.normalize(image_features, p=2, dim=-1)
        return image_features