In [None]:
# Load Model (Training)
import torch
from cldm.toss import TOSS
from cldm.model import load_state_dict
from omegaconf import OmegaConf

from ldm.util import instantiate_from_config

# Build the model described by the YAML config.
config = OmegaConf.load("models/toss_vae.yaml")
model = instantiate_from_config(config.model)

# Pretrained weights; every entry whose name mentions "pose_net" is dropped
# below so those parameters keep the freshly instantiated values.
state_dict = load_state_dict("ckpt/toss.ckpt")

keys_to_remove = list(filter(lambda name: "pose_net" in name, state_dict.keys()))
for k in keys_to_remove:
    print("Removing:", k)
    state_dict.pop(k)

# strict=False: the removed pose_net entries are reported as missing keys
# instead of raising.
m, u = model.load_state_dict(state_dict, strict=False)

print("Missing keys:", m)
print("Unexpected keys:", u)

In [None]:
# Portrait4D Dataset Class

import random
import os
from pathlib import Path
from PIL import Image
import numpy as np
import torch
from torchvision import transforms as T

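# Expected dataset layout (as read by Portrait4dDataset below):
# root/
#     251514-296055/                         # subject folders
#         src.jpg                            # source image (the loader opens "src.jpg")
#         00000.jpg ... 00020.jpg            # target views
#         00000_mask.png ... 00020_mask.png  # one mask per target view
#         poses.npy                          # poses, indexed by view number
#
# Each sample provides:
#     src_img, tgt_img, mask, delta_pose, txt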
import os

class Portrait4dDataset:
    """Map-style dataset over a Portrait4D-style directory tree.

    Every sub-directory of ``root_dir`` is treated as one subject. Inside a
    subject folder, each ``*.jpg`` file whose stem parses as an integer and
    that has a sibling ``<stem>_mask.png`` becomes one sample. Subjects and
    views are visited in sorted order so sample indices are reproducible
    across runs and file systems.

    Args:
        root_dir: Directory containing one sub-directory per subject.
        transform: Optional callable applied to the source and target PIL
            images. When falsy, the images and the mask are returned as
            untouched PIL images.
        src_idx: Row of ``poses.npy`` used as the source pose.
        src_name: File name of the source image inside each subject folder.
        mask_size: Size passed to ``T.Resize`` for the mask when a transform
            is set.
    """

    def __init__(self, root_dir, transform=None, src_idx=9,
                 src_name="src.jpg", mask_size=(256, 256)):
        self.root_dir = root_dir
        self.transform = transform
        self.src_idx = src_idx
        self.src_name = src_name
        self.mask_size = mask_size
        # Built lazily on first use so the pipeline is created once rather
        # than on every __getitem__ call.
        self._mask_transform = None

        # Sorted: os.listdir order is arbitrary, which would otherwise make
        # sample indices differ between runs/machines.
        self.subjects = sorted(
            d for d in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, d))
        )
        self.samples = []

        for sub in self.subjects:
            sub_path = os.path.join(root_dir, sub)
            views = sorted(f for f in os.listdir(sub_path) if f.endswith('.jpg'))

            for view_file in views:
                view_id = view_file.split('.')[0]
                mask_file = f"{view_id}_mask.png"

                if not os.path.exists(os.path.join(sub_path, mask_file)):
                    continue

                # The stem doubles as the row index into poses.npy; files
                # whose stem is not an integer cannot be indexed and are
                # skipped instead of aborting dataset construction.
                try:
                    view_index = int(view_id)
                except ValueError:
                    continue

                self.samples.append({
                    'sub_path': sub_path,
                    'view_file': view_file,
                    'mask_file': mask_file,
                    'view_index': view_index
                })

    def __len__(self):
        """Return the number of (subject, view) samples found."""
        return len(self.samples)

    def _get_mask_transform(self):
        """Return the mask pipeline (resize + to-tensor), creating it once."""
        if self._mask_transform is None:
            self._mask_transform = T.Compose([
                T.Resize(self.mask_size),
                T.ToTensor()
            ])
        return self._mask_transform

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            dict with keys ``"jpg"`` (target view), ``"hint"`` (source
            image), ``"mask"``, ``"delta_pose"`` (target pose minus source
            pose), ``"subject_id"`` (subject folder name) and ``"txt"``
            (always the empty string).
        """
        sample = self.samples[idx]
        sub_path = sample['sub_path']

        # Load Source Image
        src_path = os.path.join(sub_path, self.src_name)
        src = Image.open(src_path).convert("RGB")

        # Load View Image
        img_path = os.path.join(sub_path, sample['view_file'])
        image = Image.open(img_path).convert("RGB")

        # Load Mask
        mask_path = os.path.join(sub_path, sample['mask_file'])
        mask = Image.open(mask_path).convert("L")  # 1-channel grayscale

        # Load Pose
        poses = np.load(os.path.join(sub_path, 'poses.npy'))
        src_pose = torch.from_numpy(poses[self.src_idx]).float()
        # Target: the pose of the current view (000xx.jpg)
        tgt_pose = torch.from_numpy(poses[sample['view_index']]).float()

        # Element-wise difference between target and source pose.
        # NOTE(review): if poses.npy stores transformation matrices, a plain
        # subtraction is not a composed relative transform - confirm that
        # this is the encoding the pose network was designed for.
        delta_pose = tgt_pose - src_pose

        if self.transform:
            src = self.transform(src)
            image = self.transform(image)
            mask = self._get_mask_transform()(mask)

        return {
            "jpg": image,
            "hint": src,
            'mask': mask,
            "delta_pose": delta_pose,  # pose from original angle
            "subject_id": os.path.basename(sub_path),
            "txt": ""
        }

In [None]:
# Training

from torchvision import transforms as T
from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import torch.utils.checkpoint as cp

# Image preprocessing handed to the dataset: resize to 256x256 and convert to
# a tensor. No Normalize step is applied here.
transform = T.Compose([T.Resize((256, 256)), T.ToTensor()])

dataset = Portrait4dDataset(root_dir="datasets/portrait4d", transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

# Training-time settings set directly on the model object.
for _attr, _value in (
    ("learning_rate", 3e-5),
    ("sd_locked", True),
    ("first_stage_key", "jpg"),
    ("control_key", "hint"),
    ("cond_stage_key", "txt"),
):
    setattr(model, _attr, _value)
del _attr, _value

# Write a checkpoint every 500 training steps and keep a "last" checkpoint.
checkpoint_cb = ModelCheckpoint(
    dirpath="checkpoints/v2/",
    filename="toss-{step}",
    every_n_train_steps=500,
    save_last=True,
)

trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_steps=3000,
    log_every_n_steps=10,
    enable_checkpointing=True,
    callbacks=[checkpoint_cb],
)


def _call_without_checkpointing(func, *args, **kwargs):
    """Stand-in for torch.utils.checkpoint.checkpoint: call func directly.

    Keyword arguments are accepted but discarded, exactly like the original
    one-line replacement.
    """
    return func(*args)


# Replaces the module attribute; code that looks up
# torch.utils.checkpoint.checkpoint at call time now runs without
# activation checkpointing.
cp.checkpoint = _call_without_checkpointing

trainer.fit(model, dataloader)

# print("Initializing iterator...")
# data_iter = iter(dataloader)

# print("Fetching first batch...")
# batch = next(data_iter)
# print("Success! Batch keys:", batch.keys())

In [None]:
# Load Model (Inference)

from PIL import Image
import numpy as np
import torchvision.transforms as T

# Load the last checkpoint
import torch
from cldm.toss import TOSS
from cldm.model import load_state_dict
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms as T

from ldm.util import instantiate_from_config

# Fresh model instance from the YAML config.
config = OmegaConf.load("models/toss_vae.yaml")
model = instantiate_from_config(config.model)

# Base weights. weights_only=False unpickles arbitrary objects, so only load
# checkpoint files from a trusted source.
state_dict = torch.load(
    "ckpt/toss.ckpt", map_location="cpu", weights_only=False
)["state_dict"]

# Drop every pose_net entry from the base weights before loading them.
keys_to_remove = list(filter(lambda name: "pose_net" in name, state_dict.keys()))
for k in keys_to_remove:
    state_dict.pop(k)

model.load_state_dict(state_dict, strict=False)

# Overlay the weights stored in the last checkpoint written by the training
# run (same trust caveat as above for weights_only=False).
trained_ckpt = torch.load(
    "checkpoints/v2/last.ckpt", map_location="cpu", weights_only=False
)
m, u = model.load_state_dict(trained_ckpt["state_dict"], strict=False)

print("Unexpected keys (should be empty):", u)

In [None]:
# Inference v1 (Manually typed inference)
# Move the model to the GPU and switch to evaluation mode.
model.cuda()
model.eval()

# source image
src_img = Image.open("minion-toy-vbr.png").convert("RGB")

# Resize to 256x256, convert to a tensor, then normalize each channel with
# mean 0.5 / std 0.5.
# NOTE(review): the training cell's transform applies no Normalize, and the
# v2 inference cell feeds the un-normalized image as the hint and applies
# `*2-1` itself before encoding. Confirm which value range
# model.get_input expects for "jpg" and for "hint".
transform = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
    T.Normalize([0.5]*3, [0.5]*3)  # -> [-1,1]
])

src_img = transform(src_img)
src_img = src_img.unsqueeze(0).cuda()  # add a batch dimension, move to GPU
print(src_img.shape)

# Encode the source image into the first-stage latent space; z is the
# starting point for the img2img encode further down.
with torch.no_grad():
    z = model.encode_first_stage(src_img)
    z = model.get_first_stage_encoding(z)
print(z.shape)

# with torch.no_grad():
#     hint_post = model.encode_first_stage(src_img)
#     hint_latent = hint_post.mode() if hasattr(hint_post, "mode") else hint_post



# pose (example)
# delta_pose = torch.eye(4)[:3]   # (3,4)
# delta_pose = delta_pose.unsqueeze(0).cuda()
# delta_pose = torch.tensor([[0.0, 0.0, 0.0]], device="cuda")  # [B=1, 3]
# delta_pose = create_rotation_matrix(yaw_deg=30).cuda()

# Shows the input width of the first pose_net layer, i.e. the length
# delta_pose must have per sample.
print("pose_net in_features:", model.model.diffusion_model.pose_net[0].in_features)

# NOTE(review): yaw_deg / yaw are computed but not used anywhere below.
yaw_deg = 0
yaw = yaw_deg * torch.pi / 180
# Zero relative pose, one sample, three components, already on the GPU.
delta_pose = torch.tensor([[0.0, 0.0, 0.0]], device="cuda")
# delta_pose = torch.zeros(1, 16).cuda()
# delta_pose = torch.randn(1, 16).cuda() * 0.5
# delta_pose = torch.tensor([[9.50876296e-01, 3.86712351e-09,  3.09571296e-01,  1.24118519e+00,
#   -3.93458927e-16,  1.00000000e+00, -1.24918680e-08, -5.00844948e-08,
#   -3.09571296e-01,  1.18782193e-08,  9.50876296e-01,  4.01241302e+00,
#    0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  1.00000000e+00]], device="cuda")
print("DELTA_POSE:", delta_pose.shape)

# delta_pose_cfg = delta_pose.repeat(2, 1) # Shape becomes [2, 16]


# empty text
txt = [""]

# Batch in the same key layout the dataset produces; the source image is used
# both as the target ("jpg") and as the conditioning image ("hint").
batch = {
        "jpg": src_img,
        "hint": src_img,
        # 'mask': mask,
        "delta_pose": delta_pose, # pose from original angle
        # "subject_id": os.path.basename(sub_path),
        "txt": txt
    }

print(src_img.shape)
print(delta_pose.shape)

# NOTE(review): x_T is not used below; the img2img path starts from z_enc.
x_T = torch.randn_like(z)

# cond = {
#     "in_concat": [z],
#     "c_crossattn": [model.get_learned_conditioning(txt)],
#     "c_concat": [src_img],
#     "delta_pose": delta_pose,
# }

# uc_cond = {
#     "in_concat": [z * 0],
#     "c_crossattn": [model.get_unconditional_conditioning(1)],
#     "c_concat": [src_img],
#     "delta_pose": delta_pose,
# }

# Let the model build its own conditioning from the batch. Only `cond` is
# used afterwards; `x` is not referenced again in this cell.
x, cond = model.get_input(batch, model.first_stage_key)
# cond["delta_pose"] = torch.zeros_like(cond["delta_pose"])
print("COND", cond)

from ldm.models.diffusion.ddim import DDIMSampler
from torchvision.utils import save_image

sampler = DDIMSampler(model)

# with torch.no_grad():
#     samples, _ = sampler.sample(
#         S=50,                      # inference steps
#         batch_size=1,
#         shape=(4, 64, 64),         # latent shape (example)
#         conditioning=cond,         # conditioning (same as training)
#         unconditional_guidance_scale=7.5,
#         unconditional_conditioning=uc_cond,
#         eta=0.0
#     )

# txt2img
# with torch.no_grad():
#     samples, _ = sampler.sample(
#         S=20,                      # inference steps
#         batch_size=1,
#         shape=(z.shape[1], z.shape[2], z.shape[3]),         # latent shape
#         conditioning=cond,         # conditioning (same as training)
#         unconditional_guidance_scale=1,
#         # unconditional_conditioning=uc_cond,
#         # eta=0.0
#     )

# Sanity checks: at least one parameter tensor is non-zero, latent statistics,
# and the structure of the conditioning dict.
print("loaded?", any((p.abs().mean().item() > 0) for p in model.parameters()))
print("z:", z.shape, z.abs().mean().item(), z.std().item())
print("cond keys:", cond.keys(), {k: type(v) for k,v in cond.items()})

# img2img
# Build a 50-step DDIM schedule with eta = 0.
sampler.make_schedule(
    ddim_num_steps=50,
    ddim_eta=0.0,
    verbose=False
)

print(sampler.ddim_timesteps)

# Noise the source latent up to a fraction `strength` of the DDIM schedule,
# then denoise from there.
strength = 0.4
t_enc = int(strength * len(sampler.ddim_timesteps))
z_enc = sampler.stochastic_encode(z, torch.tensor([t_enc]).cuda())

# Batch sizes of the noised latent and of every conditioning entry, printed
# side by side so a mismatch is easy to spot.
print("x_dec batch:", z_enc.shape[0])
print("cond delta_pose batch:", cond["delta_pose"].shape[0])
print("cond crossattn batch:", cond["c_crossattn"][0].shape[0])
print("cond in_concat batch:", cond["in_concat"][0].shape[0])


# Decode
# NOTE(review): a guidance scale of 3.0 is passed together with
# unconditional_conditioning=None. In the stock LDM/ControlNet DDIM sampler
# guidance is skipped when no unconditional conditioning is given - confirm
# whether this sampler behaves the same way.
samples = sampler.decode(
    z_enc,
    cond,
    t_enc,
    unconditional_guidance_scale=3.0,
    unconditional_conditioning=None
)

# Decode
# Latent -> image space, then map from [-1, 1] to [0, 1] for saving.
with torch.no_grad():
    out = model.decode_first_stage(samples)
out = torch.clamp((out + 1) / 2, 0, 1)
print(out.shape)
save_image(out, "test_output.png")


# LOSS
    # Perceptual Loss, Contrastive Loss






In [None]:
# Inference v2 (Inference code used in Gradio)
# TOSS app.py to ipynb
# GPU + evaluation mode for inference.
model.cuda()
model.eval()

# Preprocessing: resize to 256x256 and convert to a tensor; no Normalize
# step, the `*2-1` rescaling is applied later where the image is encoded.
transform = T.Compose([T.Resize((256, 256)), T.ToTensor()])

# Source image as a single-item batch on the GPU.
src_img = Image.open("examples/00010.jpg").convert("RGB")
src_img = transform(src_img).unsqueeze(0).cuda()
print(src_img.shape)

# First-stage latent of the source image (only its shape is reported here).
with torch.no_grad():
    z = model.get_first_stage_encoding(model.encode_first_stage(src_img))
print(z.shape)

print("pose_net in_features:", model.model.diffusion_model.pose_net[0].in_features)

# Empty text prompt.
txt = [""]

# Sampling settings.
n_samples = 1
prompt_scale, img_scale = 0.0, 1.0
ddim_steps, ddim_eta = 50, 0.0
h = w = 256
import math
def get_T_from_relative(x, y, z, pose_enc="freq")->torch.Tensor:
    """Encode a relative camera pose as a 1-D tensor.

    Args:
        x: relative polar angle, in degrees
        y: relative azimuth angle, in degrees
        z: relative distance
        pose_enc: "freq" or "identity" give [polar_rad, azimuth_rad, z];
            "zero" gives [polar_rad, sin(azimuth), cos(azimuth), z].

    Examples of (x, y, z):
        (0., -90., 0.): left view
        (0., 90., 0.): right view
        (0., 180., 0.): back view
        (-90., 0., 0.): top view
        (90., 0., 0.): bottom view

    Raises:
        NotImplementedError: for any other ``pose_enc`` value.
    """
    print("POSE_ENC:", pose_enc)
    if pose_enc not in ("freq", "identity", "zero"):
        raise NotImplementedError

    polar_rad = math.radians(x)
    azimuth_rad = math.radians(y)
    if pose_enc == "zero":
        components = [polar_rad, math.sin(azimuth_rad), math.cos(azimuth_rad), z]
    else:
        components = [polar_rad, azimuth_rad, z]
    return torch.tensor(components)


# delta_pose = torch.zeros((1, 3), device="cuda")
# print("DELTA_POSE:", delta_pose.shape)

from ldm.models.diffusion.ddim import DDIMSampler
from torchvision.utils import save_image

# Relative camera pose (polar 0, azimuth 0, distance 0), encoded the way the
# model's pose encoder setting requires. The helper returns an unbatched 1-D
# tensor on the CPU.
delta_pose = get_T_from_relative(0, 0, 0, pose_enc=model.model.diffusion_model.pose_enc)

# Inference only: no autograd graph is needed for conditioning, sampling or
# decoding.
with torch.no_grad():
    # hint
    c_cat = src_img
    # text
    uc_cross = model.get_unconditional_conditioning(n_samples)
    c = model.get_learned_conditioning(txt)
    # camera pose: add the batch dimension and move the tensor next to the
    # other conditioning tensors. The app code this cell was ported from did
    # this as `T[None, :].repeat(n_samples, 1).to(c.device)`; `T` is
    # torchvision.transforms in this notebook, so the pose lives in
    # `delta_pose` instead. Without this step the model would receive a
    # CPU tensor with no batch dimension.
    delta_pose = delta_pose[None, :].repeat(n_samples, 1).to(c.device)
    # concat for concat pipeline: latent of the source image rescaled with
    # `*2-1`, using the mode of the encoder posterior.
    in_concat = model.encode_first_stage(((src_img*2-1).to(c.device))).mode().detach()
    # NOTE(review): c, c_cat and in_concat are single-item batches; only
    # delta_pose and uc_cross follow n_samples. Repeat the others as well
    # before raising n_samples above 1.

    # Full conditioning: text, hint image, source latent and pose.
    cond = {}
    cond['delta_pose'] = delta_pose
    cond['c_crossattn'] = [c]
    cond['c_concat'] = [c_cat]
    cond['in_concat'] = [in_concat]

    # uc2 for prompt: unconditional text, real source latent.
    uc2 = {}
    uc2['delta_pose'] = delta_pose
    uc2['c_crossattn'] = [uc_cross]
    uc2['c_concat'] = [c_cat]
    uc2['in_concat'] = [in_concat]

    # uc for image: unconditional text and a zeroed source latent.
    uc = {}
    uc['delta_pose'] = delta_pose
    uc['c_crossattn'] = [uc_cross]
    uc['c_concat'] = [c_cat]
    uc['in_concat'] = [in_concat*0]

    sampler = DDIMSampler(model)

    # Latent shape per sample: 4 channels at 1/8 of the image resolution.
    shape = [4, h // 8, w // 8]
    # Initial noise with the same shape as the source latent.
    x_T = torch.randn(in_concat.shape, device=c.device)
    samples_ddim, _ = sampler.sample(S=ddim_steps,
                                    conditioning=cond,
                                    batch_size=n_samples,
                                    shape=shape,
                                    verbose=False,
                                    unconditional_guidance_scale=img_scale,
                                    unconditional_conditioning=uc,
                                    unconditional_guidance_scale2=prompt_scale,
                                    unconditional_conditioning2=uc2,
                                    eta=ddim_eta,
                                    x_T=x_T)
    print(samples_ddim.shape)
    x_samples_ddim = model.decode_first_stage(samples_ddim)

# Map the decoded image from [-1, 1] to [0, 1] and bring it to the CPU.
out = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).cpu()

print(out.shape)
save_image(out, "test_output.png")