In [45]:
from typing import Any
import random

import numpy as np

from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from torchvision.transforms import ToTensor

from validation_model import FaceValidateV1
from settings import *
from avapix.avapix_loss import AvapixLoss
import avapix.avapix_utils as utils

In [None]:
class EmbeddedFacesDataset(Dataset):
    """Dataset that yields several generated inputs per source face image.

    Each entry of ``images`` is expanded into ``gen_per_image`` samples, so
    flat index ``i`` maps to image ``i // gen_per_image`` and slot
    ``i % gen_per_image``. Every sample is produced by
    ``utils.generate_input_v1(image_tensor, DEFAULT_RANDOM_SEED, rand_len)``,
    where ``rand_len`` comes from a shuffled permutation of
    ``range(max_rand_len)``. The lengths are therefore distinct among one
    image's samples when those samples are fetched together.

    Args:
        images: Sequence of image sources accepted by ``PIL.Image.open``
            (presumably file paths — confirm against caller).
        gen_per_image: Number of samples generated per image. Must be in
            ``[0, max_rand_len]`` so each slot can get its own length.
        max_rand_len: Size of the pool of random lengths (default 128).

    Raises:
        ValueError: If ``gen_per_image`` is outside ``[0, max_rand_len]``.
    """

    def __init__(self, images, gen_per_image, max_rand_len=128) -> None:
        # Validate eagerly: slots beyond the pool size would otherwise fail
        # with an IndexError only when they are fetched.
        if not 0 <= gen_per_image <= max_rand_len:
            raise ValueError(
                f"gen_per_image must be in [0, {max_rand_len}], got {gen_per_image}"
            )
        super().__init__()

        self.transform = ToTensor()

        self.images = images
        self.gen_per_image = gen_per_image
        self.max_rand_len = max_rand_len

        # Cache of the most recently loaded image, its index and its length
        # permutation. Lookups are driven by the requested index, not by
        # call order, so shuffled sampling and DataLoader workers are safe.
        self.random_lengths = []
        self.curr_image_tensor = None
        self._curr_image_idx = None

    def __getitem__(self, index) -> Any:
        """Return the generated input for flat sample ``index``."""
        image_idx, slot = divmod(index, self.gen_per_image)

        # Reload when a different image is requested. Also reload at slot 0
        # so that every sequential pass (epoch) draws a fresh permutation,
        # even for a single-image dataset.
        if slot == 0 or image_idx != self._curr_image_idx:
            with Image.open(self.images[image_idx]) as img:
                # ToTensor reads the pixel data, so the file can close after.
                self.curr_image_tensor = self.transform(img)
            self._curr_image_idx = image_idx
            self.random_lengths = list(range(self.max_rand_len))
            random.shuffle(self.random_lengths)

        # Index by slot instead of pop(): works for any access order and
        # never exhausts the pool, since gen_per_image <= max_rand_len.
        curr_rand_len = self.random_lengths[slot]

        return utils.generate_input_v1(self.curr_image_tensor,
                                       DEFAULT_RANDOM_SEED,
                                       curr_rand_len)

    def __len__(self):
        """Total number of samples: ``len(images) * gen_per_image``."""
        return len(self.images) * self.gen_per_image


In [None]:
# Build the validation network and load its saved weights.
# `load_state_dict` is an instance method. It takes a state dict (not a file
# path) and returns the missing/unexpected keys, not the model. So the file
# is read with `torch.load` and the weights are loaded into an instance.
# NOTE(review): assumes FaceValidateV1 has a no-arg constructor and that
# VALIDATION_MODEL_PATH holds a saved state_dict — confirm.
model = FaceValidateV1()
model.load_state_dict(torch.load(VALIDATION_MODEL_PATH, map_location="cpu"))
# Inference mode (affects dropout/batchnorm); drop this if the model is
# fine-tuned later in the notebook.
model.eval()