-
Notifications
You must be signed in to change notification settings - Fork 0
/
image_embedder.py
30 lines (28 loc) · 1.12 KB
/
image_embedder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
from facenet_pytorch import InceptionResnetV1
from torchvision import transforms
import time
class ImageEmbedder:
    """Compute embeddings for PIL images with a pretrained InceptionResnetV1.

    The model, device and preprocessing pipeline are created once at class
    definition time and shared by every call; all entry points are class
    methods, so the class is used without instantiation.

    NOTE(review): preprocessing is only ``ToTensor()``; facenet_pytorch's
    documentation describes an additional standardization step for its
    pretrained weights -- confirm whether omitting it here is intentional.
    It is left unchanged so that existing embeddings stay comparable.
    """

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Weights pretrained on VGGFace2; eval() switches the network to
    # evaluation mode because it is only ever used for inference here.
    model = InceptionResnetV1(pretrained="vggface2", device=device).eval()

    # The attribute keeps the name `transforms` for backward compatibility.
    # Inside the class body the right-hand side is evaluated before the name
    # is rebound, so it still refers to the imported torchvision module.
    transforms = transforms.Compose([
        transforms.ToTensor(),
    ])

    # Size every image is resized to by the batch methods so that the
    # per-image tensors share a shape and can be stacked.
    BATCH_IMAGE_SIZE = (112, 112)

    @classmethod
    def _embed_batch(cls, batch):
        """Run the model on an already-batched image tensor.

        The forward pass runs under ``torch.no_grad()`` so no autograd graph
        is built or kept alive by the returned tensor.

        Args:
            batch: tensor of stacked image tensors (batch dimension first).

        Returns:
            The model output as a CPU tensor that does not require grad.
        """
        with torch.no_grad():
            return cls.model(batch.to(cls.device)).cpu()

    @classmethod
    def embed(cls, image):
        """Return the embedding of a single PIL image as a NumPy array.

        Unlike the batch methods, the image is NOT resized before being fed
        to the model.

        Args:
            image: a PIL image.

        Returns:
            The embedding vector of ``image`` (batch dimension removed).
        """
        image_tensor = cls.transforms(image)
        # The model expects a batch dimension; add it, then strip it again.
        return cls._embed_batch(image_tensor.unsqueeze(dim=0)).numpy()[0]

    @classmethod
    def embeds(cls, image_list):
        """Return the embeddings of several PIL images as a NumPy array.

        Args:
            image_list: non-empty iterable of PIL images; each one is resized
                to ``BATCH_IMAGE_SIZE`` before embedding.

        Returns:
            Array with one embedding per input image, in input order.
        """
        return cls.embeds_to_tensor(image_list).numpy()

    @classmethod
    def embeds_to_tensor(cls, image_list):
        """Return the embeddings of several PIL images as a CPU tensor.

        Args:
            image_list: non-empty iterable of PIL images; each one is resized
                to ``BATCH_IMAGE_SIZE`` before embedding. ``torch.stack``
                raises on an empty sequence.

        Returns:
            CPU tensor with one embedding per input image, in input order.
            The tensor is computed under ``torch.no_grad()`` and therefore
            does not require grad.
        """
        batch = torch.stack([
            cls.transforms(image.resize(cls.BATCH_IMAGE_SIZE))
            for image in image_list
        ])
        return cls._embed_batch(batch)