In [1]:
from pathlib import Path
from PIL import Image
import torch
from torch.utils.data import IterableDataset
import torchvision.transforms as T

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f"Using device: {DEVICE}")

def load_png(path) -> Image.Image:
    """
    Loads a PNG image, replacing the alpha background with a neutral gray.

    Args:
        path: Location of the image file; passed straight to ``Image.open``.

    Returns:
        A new RGB image of the same size as the source, with transparent
        regions blended onto a (128, 128, 128) background.
    """
    # Image.open is lazy; the context manager guarantees the underlying file
    # handle is released as soon as the pixels have been converted.
    with Image.open(path) as source:
        image = source.convert('RGBA')

    # Fully opaque gray canvas, created directly in RGBA so that it can be
    # composited without an intermediate RGB -> RGBA conversion.
    bg = Image.new('RGBA', image.size, (128, 128, 128, 255))
    return Image.alpha_composite(bg, image).convert('RGB')

def image_to_tensor(image: Image.Image) -> torch.Tensor:
    """
    Converts an RGB image to a batched tensor on DEVICE.

    Args:
        image: The PIL image to convert.

    Returns:
        The output of ``T.ToTensor()`` rescaled with ``2 * t - 1`` (which maps
        ToTensor's [0, 1] range to [-1, 1]), with a leading batch dimension of
        size 1 added, moved to ``DEVICE``.
    """
    x = T.ToTensor()(image)
    # Affine rescale applied directly; no Compose/Lambda pipeline needs to be
    # rebuilt on every call for a single arithmetic step.
    x = 2.0 * x - 1.0
    return x.unsqueeze(0).to(DEVICE)

class PixelPoseDataset(IterableDataset):
    """
    Streams training samples from a ``<character>/<animation>`` directory tree.

    Every animation directory that contains both a ``frames`` and a ``poses``
    sub-directory contributes one sample per (frame, pose) pair. Frames and
    poses are paired by sorted filename order, and the first frame of each
    animation is encoded once and reused as the reference for all samples of
    that animation.
    """

    def __init__(self, character_dir: Path, vae_model, clip_model, clip_preprocess):
        """
        Args:
            character_dir: Root directory holding one sub-directory per character.
            vae_model: Model used by `_encode_frame_vae` to produce the reference latent.
            clip_model: Model used by `_encode_frame_clip` to produce the reference embedding.
            clip_preprocess: Callable turning a PIL image into the CLIP input tensor.
        """
        self.character_dir = character_dir
        self.vae_model = vae_model
        self.clip_model = clip_model
        self.clip_preprocess = clip_preprocess

    @staticmethod
    def _sorted_subdirs(parent: Path):
        """Returns the sub-directories of `parent` in sorted order, skipping stray files."""
        return sorted(entry for entry in parent.iterdir() if entry.is_dir())

    def __iter__(self):
        """
        Yields one dict per (frame, pose) pair with the keys 'reference_vae',
        'reference_clip', 'pose' and 'target'.
        """
        for character in self._sorted_subdirs(self.character_dir):
            for animation in self._sorted_subdirs(character):
                frames_dir = animation / 'frames'
                poses_dir = animation / 'poses'

                if not (frames_dir.is_dir() and poses_dir.is_dir()):
                    continue

                # Path.iterdir() yields entries in arbitrary order. Sort both
                # listings so that frame i is paired with pose i and the
                # reference is always the same (first) frame.
                # NOTE(review): this assumes frame and pose filenames sort the
                # same way (e.g. zero-padded indices) - confirm against the data.
                frames = sorted(frames_dir.iterdir())
                poses = sorted(poses_dir.iterdir())

                ref_vae_tensor, ref_clip_tensor = None, None
                # zip stops at the shorter listing, so unmatched trailing
                # frames or poses are ignored.
                for frame, pose in zip(frames, poses):
                    # Load frame
                    frame_img = load_png(frame)

                    # Load target
                    target_tensor = image_to_tensor(frame_img)

                    # If this is the first frame, encode and store as the reference
                    if ref_vae_tensor is None:
                        with torch.no_grad():
                            ref_vae_tensor = self._encode_frame_vae(frame_img)
                            ref_clip_tensor = self._encode_frame_clip(frame_img)

                    # Load pose
                    pose_img = load_png(pose)
                    pose_tensor = image_to_tensor(pose_img)

                    yield {
                        'reference_vae': ref_vae_tensor,
                        'reference_clip': ref_clip_tensor,
                        'pose': pose_tensor,
                        'target': target_tensor
                    }

    def _encode_frame_vae(self, frame: "Image.Image") -> torch.Tensor:
        """Encodes `frame` with the VAE and returns the sampled, scaled latent."""
        # Compute latent with the vae
        x = image_to_tensor(frame)
        z = self.vae_model.encode(x).latent_dist.sample() # performs the reparameterization trick for us
        # NOTE(review): only `scaling_factor` is applied. If the VAE comes from
        # the SD3 family its config may also define `shift_factor`; confirm
        # which normalisation the downstream model/decoder expects.
        z *= self.vae_model.config.scaling_factor
        return z

    def _encode_frame_clip(self, frame: "Image.Image") -> torch.Tensor:
        """Encodes `frame` with CLIP and returns the L2-normalised image features."""
        x = self.clip_preprocess(frame).unsqueeze(0).to(DEVICE)
        features = self.clip_model.encode_image(x)
        features /= features.norm(dim=-1, keepdim=True) # Normalize
        return features

Using device: cuda


In [2]:
from diffusers import AutoencoderKL
import open_clip

# VAE: used to encode the reference frame into a latent.
vae_model = AutoencoderKL.from_pretrained(
    'stabilityai/stable-diffusion-3.5-large',
    subfolder='vae',
)
# Inference mode, then move to the selected device.
vae_model = vae_model.eval().to(DEVICE)

# CLIP: used to embed the reference frame; also returns its matching preprocessing transform.
clip_model, clip_preprocess = open_clip.create_model_from_pretrained(
    'hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
)
clip_model = clip_model.eval().to(DEVICE)

In [3]:
from tqdm import tqdm
from torch.utils.data import DataLoader

# Stream every sample through the dataset and persist it to disk in chunks.
dataset = PixelPoseDataset(
    Path('/home/gbear/Dev/pixelpose/results'),
    vae_model=vae_model,
    clip_model=clip_model,
    clip_preprocess=clip_preprocess,
)
# NOTE(review): with the default DataLoader settings each sample is collated
# into a batch of one, which presumably adds a leading dimension to every
# tensor - confirm the training code expects that shape.
loader = DataLoader(dataset)

out_dir = Path('trainset')
out_dir.mkdir(exist_ok=True)

chunk_size = 500
buffer = []

with tqdm() as pbar:
    for i, data in enumerate(loader):
        # Keep only CPU copies in memory.
        buffer.append({key: tensor.cpu() for key, tensor in data.items()})
        if len(buffer) == chunk_size:
            # A full chunk: write it out under its zero-padded chunk index.
            torch.save(buffer, out_dir / f'{i//chunk_size:04d}.pt')
            buffer = []
            torch.cuda.empty_cache()
        pbar.update(1)

# Write out whatever is left over after the last full chunk.
if buffer:
    torch.save(buffer, out_dir / f'{i//chunk_size:04d}.pt')

20185it [01:22, 245.62it/s]


KeyboardInterrupt: 