In [None]:
# Training

import random
import os
from pathlib import Path
from PIL import Image
import numpy as np
import torch
from torchvision import transforms as T

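# Expected layout of the dataset root:
# root/
#     251514-296055.../                        # one directory per subject
#         src.jpg                              # source image (Portrait4dDataset opens "src.jpg")
#         00000.jpg ... 00020.jpg              # target views
#         00000_mask.png ... 00020_mask.png    # per-view masks
#         poses.npy                            # poses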

class Portrait4dDataset:
    """Map-style dataset yielding (source image, target view, mask, pose delta).

    Layout read from ``root_dir`` (one sub-directory per subject)::

        <subject>/
            src.jpg            # source ("hint") image, see ``src_filename``
            <view>.jpg         # target views; <view> must parse as an int
            <view>_mask.png    # mask of that view
            poses.npy          # array indexed by view number along axis 0

    A view is indexed only when its name parses as an integer and the matching
    ``<view>_mask.png`` exists; every other ``.jpg`` is skipped.

    Args:
        root_dir: Directory containing one sub-directory per subject.
        transform: Optional callable applied to the source and target PIL
            images. When it is falsy, images and mask are returned as PIL
            images, untouched.
        src_idx: Row of ``poses.npy`` used as the source pose.
            NOTE(review): confirm this row is the pose of ``src_filename``.
        src_filename: File name of the source image inside each subject.
        mask_size: ``(height, width)`` the mask is resized to when
            ``transform`` is set.
    """

    def __init__(self, root_dir, transform=None, src_idx=9,
                 src_filename="src.jpg", mask_size=(256, 256)):
        self.root_dir = root_dir
        self.transform = transform
        self.src_idx = src_idx
        self.src_filename = src_filename
        self.mask_size = tuple(mask_size)
        # Built on first use so that constructing the dataset never touches
        # torchvision.
        self._mask_transform = None

        # os.listdir order is arbitrary; sort so sample indices are stable.
        self.subjects = sorted(
            d for d in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, d))
        )
        self.samples = []

        for sub in self.subjects:
            sub_path = os.path.join(root_dir, sub)
            views = sorted(f for f in os.listdir(sub_path) if f.endswith('.jpg'))

            for view_file in views:
                view_id = view_file.split('.')[0]
                try:
                    view_index = int(view_id)
                except ValueError:
                    # Not a numbered view (e.g. the source image): it has no
                    # row in poses.npy, so it cannot be a target.
                    continue

                mask_file = f"{view_id}_mask.png"
                if not os.path.exists(os.path.join(sub_path, mask_file)):
                    continue

                self.samples.append({
                    'sub_path': sub_path,
                    'view_file': view_file,
                    'mask_file': mask_file,
                    'view_index': view_index
                })

    def __len__(self):
        """Return the number of indexed (view, mask) pairs."""
        return len(self.samples)

    @staticmethod
    def _load_image(path, mode):
        """Open ``path``, convert it to ``mode`` and release the file handle."""
        with Image.open(path) as img:
            return img.convert(mode)

    def _get_mask_transform(self):
        """Return the resize + to-tensor pipeline for masks, building it once."""
        if self._mask_transform is None:
            self._mask_transform = T.Compose([
                T.Resize(self.mask_size),
                T.ToTensor()
            ])
        return self._mask_transform

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            dict with keys
            ``"jpg"`` (target view), ``"hint"`` (source image), ``'mask'``
            (single-channel mask), ``"delta_pose"`` (target pose minus source
            pose, float tensor), ``"subject_id"`` (subject directory name) and
            ``"txt"`` (always the empty string).
        """
        sample = self.samples[idx]
        sub_path = sample['sub_path']

        src = self._load_image(os.path.join(sub_path, self.src_filename), "RGB")
        image = self._load_image(os.path.join(sub_path, sample['view_file']), "RGB")
        # "L" = single-channel grayscale
        mask = self._load_image(os.path.join(sub_path, sample['mask_file']), "L")

        poses = np.load(os.path.join(sub_path, 'poses.npy'))
        src_pose = torch.from_numpy(poses[self.src_idx]).float()
        # Target: the pose of the current view (<view>.jpg)
        tgt_pose = torch.from_numpy(poses[sample['view_index']]).float()
        # Delta: element-wise difference between target and source pose
        delta_pose = tgt_pose - src_pose

        if self.transform:
            src = self.transform(src)
            image = self.transform(image)
            mask = self._get_mask_transform()(mask)

        return {
            "jpg": image,
            "hint": src,
            'mask': mask,
            "delta_pose": delta_pose,  # pose relative to the source view
            "subject_id": os.path.basename(sub_path),
            "txt": ""
        }

In [None]:
# Inference

from PIL import Image
import numpy as np
import torchvision.transforms as T

# Load the last checkpoint
import torch
from cldm.toss import TOSS
from cldm.model import load_state_dict
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms as T

from ldm.util import instantiate_from_config

# Imports used only by the sampling part of this cell.
from ldm.models.diffusion.ddim import DDIMSampler
from torchvision.utils import save_image

# ---------------------------------------------------------------------------
# Model: base TOSS weights first, then the fine-tuned checkpoint on top.
# ---------------------------------------------------------------------------
config = OmegaConf.load("models/toss_vae.yaml")
model = instantiate_from_config(config.model)

# NOTE(security): weights_only=False unpickles arbitrary Python objects.
# Only load checkpoints that come from a trusted source.
state_dict = torch.load("ckpt/toss.ckpt", map_location="cpu", weights_only=False)["state_dict"]
# Drop the base checkpoint's pose_net entries so the trained pose_net loaded
# below takes their place.
keys_to_remove = [k for k in state_dict.keys() if "pose_net" in k]
for k in keys_to_remove:
    del state_dict[k]
model.load_state_dict(state_dict, strict=False)

trained_ckpt = torch.load("checkpoints/v2/last.ckpt", map_location="cpu", weights_only=False)
# strict=False: only the keys present in the trained checkpoint are overwritten.
m, u = model.load_state_dict(trained_ckpt["state_dict"], strict=False)

print("Unexpected keys (should be empty):", u)

model.cuda()
model.eval()

# The encode step below draws random noise; seed it so reruns give the same image.
torch.manual_seed(0)

# ---------------------------------------------------------------------------
# Source image -> normalized tensor
# ---------------------------------------------------------------------------
transform = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
    T.Normalize([0.5]*3, [0.5]*3)  # -> [-1, 1]
])

with Image.open("00010.jpg") as src_file:
    src_img = transform(src_file.convert("RGB"))
src_img = src_img.unsqueeze(0).cuda()
print(src_img.shape)

# No gradients are needed anywhere in inference.
with torch.no_grad():
    # Latent of the source image
    z = model.encode_first_stage(src_img)
    z = model.get_first_stage_encoding(z)
    print(z.shape)

    # Pose delta fed to the model: all zeros, batch of 1
    delta_pose = torch.tensor([[0.0, 0.0, 0.0]], device="cuda")
    print(delta_pose.shape)

    # empty text
    txt = [""]

    # NOTE(review): not consumed by the encode/decode path below.
    x_T = torch.randn_like(z)

    cond = {
        "in_concat": [z],
        "c_crossattn": [model.get_learned_conditioning(txt)],
        "c_concat": [src_img],
        "delta_pose": delta_pose,
    }

    # Unconditional branch for classifier-free guidance: zeroed latent input.
    uc = {
        "in_concat": [z * 0],
        "c_crossattn": [model.get_unconditional_conditioning(1)],
        "c_concat": [src_img],
        "delta_pose": delta_pose,
    }

    sampler = DDIMSampler(model)
    sampler.make_schedule(
        ddim_num_steps=50,
        ddim_eta=0.0,
        verbose=False
    )
    print(sampler.ddim_timesteps)

    # Noise the source latent part of the way, then denoise from there.
    strength = 0.4
    t_enc = int(strength * len(sampler.ddim_timesteps))

    z_enc = sampler.stochastic_encode(z, torch.tensor([t_enc]).cuda())

    # `uc` must be passed for the guidance scale to have any effect: the stock
    # DDIM sampler skips guidance when unconditional_conditioning is None.
    # NOTE(review): confirm this repo's DDIMSampler accepts dict conditioning
    # for the unconditional branch.
    samples = sampler.decode(
        z_enc,
        cond,
        t_enc,
        unconditional_guidance_scale=3.0,
        unconditional_conditioning=uc
    )

    # Latents -> pixels
    out = model.decode_first_stage(samples)

# [-1, 1] -> [0, 1] for saving
out = torch.clamp((out + 1) / 2, 0, 1)
print(out.shape)
save_image(out, "test_output.png")


# LOSS
    # Perceptual Loss, Contrastive Loss
