In [1]:
"""
Helper utilities
"""

from PIL import Image
import torch
import torchvision.transforms as T

def load_png(path) -> Image:
    """
    Loads an PNG image replacing the alpha background with a neutral gray.
    Result is returned in RGB
    """
    image = Image.open(path).convert('RGBA')
    bg = Image.new('RGB', image.size, (128, 128, 128))
    image_rgb = Image.alpha_composite(bg.convert('RGBA'), image).convert('RGB')
    return image_rgb

def image_to_tensor(image: Image) -> torch.Tensor:
    """Converts an RGB image to a Tensor"""
    transform = T.Compose([
        T.ToTensor(),
        T.Lambda(lambda t: 2.0 * t - 1.0)
    ])
    return transform(image)

In [2]:
"""
ReferenceNet
"""
from torch import nn

class FiLM(nn.Module):
    def __init__(self, features, out_channels, hidden_dims):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(features, hidden_dims),
            nn.SiLU(),
            nn.Linear(hidden_dims, out_channels*2)
        )
    
    def forward(self, x, features):
        features = self.proj(features)
        gamma, beta = features.chunk(2, dim=-1)

        gamma = gamma.unsqueeze(-1).unsqueeze(-1)
        beta = beta.unsqueeze(-1).unsqueeze(-1)

        return gamma * x + beta

class ReferenceNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, out_width, out_height,
                 unet_channels, clip_features):
        super().__init__()

        self.clip_film = FiLM(clip_features, in_channels, 64)

        self.unet_film = FiLM(out_channels*out_width*out_height, unet_channels, 128)
        
        self.down = nn.Sequential(
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=3, stride=2, padding=1), # Stride = 2 will half the resolution
            nn.GroupNorm(4, out_channels), # channels must be divisible by groups
            nn.SiLU()
        )
    
    def forward(self, unet, latent, clip):
        # Compute the next latent by comining with CLIP and downsampling
        latent = self.clip_film(latent, clip)
        latent = self.down(latent)

        # Flatten the resulting latent and combine with the UNet
        latent_flat = latent.view(latent.shape[0], -1)
        unet = self.unet_film(unet, latent_flat)

        # Return the UNet result and the latent for use by the next ReferenceNet block
        return unet, latent


# Example usage:
# ref_block = ReferenceNetBlock(16, 32, 8, 8, 3, 512)

# latent = torch.randn((1, 16, 16, 16))
# clip_features = torch.randn((1, 512))
# x_unet = torch.randn((1, 3, 128, 128))

# x_unet, latent = ref_block(x_unet, latent, clip)
# print(f"x_unet: {x_unet.shape}, latent: {latent.shape}")

In [3]:
"""
Pose Guider

Encodes the pose into some (??) feature space that is simply added
to the noisy starting image.

Simply a 4-layer conv with a sizable hidden dimension.
"""

class PoseBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=3, stride=1, padding=1),
            nn.SiLU()
        )
    
    def forward(self, x):
        return self.block(x)


class PoseGuider(nn.Module):
    def __init__(self, in_channels, out_channels, hidden_dims):
        super().__init__()
        self.blocks = nn.Sequential(
            PoseBlock(in_channels, hidden_dims),
            PoseBlock(hidden_dims, hidden_dims),
            PoseBlock(hidden_dims, hidden_dims),
            PoseBlock(hidden_dims, out_channels)
        )
    
    def forward(self, x):
        return self.blocks(x)

# Example usage

# Load the pose
# image = load_png('assets/pose.png')
# pose = image_to_tensor(image)

# # Create a noisy initial image
# noisy = torch.randn_like(pose)

# # Call the model and add the encoded result to the noise
# pose_model = PoseGuider(3, 3, 64)
# y = pose_model(ref)
# y += noisy
# print(f"UNet input shape: {y.shape}")

In [4]:
"""
Denoising UNet
"""
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            # nn.GroupNorm(2, out_channels),
            nn.SiLU()
        )
    def forward(self, x):
        return self.block(x)

class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            ConvBlock(in_channels, out_channels),
            nn.Conv2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)  # downsample
        )
    def forward(self, x):
        return self.block(x)

class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            ConvBlock(out_channels, out_channels)
        )
    def forward(self, x):
        return self.block(x)

class DenoisingUNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Pose Encoder
        self.pose_enc = PoseGuider(3, 3, 64)

        # Reference Encoder
        self.ref1 = ReferenceNetBlock(16, 32, 8, 8, 3, 512)
        self.enc1 = ConvBlock(3, 16)

        self.ref2 = ReferenceNetBlock(32, 64, 4, 4, 16, 512)
        self.enc2 = DownBlock(16, 32)

        self.ref3 = ReferenceNetBlock(64, 128, 2, 2, 32, 512)
        self.enc3 = DownBlock(32, 64)

        self.ref4 = ReferenceNetBlock(128, 256, 1, 1, 64, 512)
        self.enc4 = DownBlock(64, 128)

        # Bottleneck
        self.bottleneck = ConvBlock(128, 128)

        # Decoder
        self.dec4 = UpBlock(128, 64)
        self.dec3 = UpBlock(128, 32)
        self.dec2 = UpBlock(64, 16)
        self.dec1 = ConvBlock(32, 8)

        self.out = nn.Conv2d(8, 3, kernel_size=1)
    
    def forward(self, noisy, pose, latent, clip):
        # Encode
        pose = self.pose_enc(pose) + noisy

        tmp, latent = self.ref1(pose, latent, clip)
        e1 = self.enc1(tmp)
        
        tmp, latent = self.ref2(e1, latent, clip)
        e2 = self.enc2(tmp)

        tmp, latent = self.ref3(e2, latent, clip)
        e3 = self.enc3(tmp)

        tmp, _ = self.ref4(e3, latent, clip)
        e4 = self.enc4(tmp)

        # Bottleneck
        b = self.bottleneck(e4)

        # Decoder
        d4 = self.dec4(b)
        d3 = self.dec3(torch.cat([d4, e3], dim=1))
        d2 = self.dec2(torch.cat([d3, e2], dim=1))
        d1 = self.dec1(torch.cat([d2, e1], dim=1))

        return self.out(d1)


# Example usage
# noise = torch.randn((1, 3, 128, 128))
# latent = torch.randn((1, 16, 16, 16))
# clip = torch.randn((1, 512))
# unet = DenoisingUNet()
# y = unet(noise, latent, clip)
# print(y.shape)

In [None]:
from typing import List
from torch.utils.data import Dataset

class MultiFileDataset(Dataset):
    def __init__(self, file_paths: List):
        self.file_paths = file_paths
        self.file_sizes = [len(torch.load(f, map_location='cpu')) for f in file_paths]
        self._cum_file_sizes = [sum(self.file_sizes[:i+1]) for i in range(len(self.file_sizes))]

        self._file_idx = None
        self._data = []
    
    def __len__(self):
        return sum(self.file_sizes)
    
    def __getitem__(self, idx):
        target_file_idx, data_idx = self._find_file(idx)
        if target_file_idx != self._file_idx:
            self._data = torch.load(self.file_paths[target_file_idx], map_location='cpu')
        return self._data[data_idx]

    def _find_file(self, idx):
        for i, c in enumerate(self._cum_file_sizes):
            if idx < c:
                file_idx = i
                data_idx = idx if i == 0 else idx - self._cum_file_sizes[i-1]
                return file_idx, data_idx


In [33]:
from pathlib import Path
from torch.utils.data import random_split, DataLoader

# Create the dataset
file_paths = [str(p) for p in Path('trainset').glob('*.pt')]
dataset = MultiFileDataset(file_paths)

# Split
train_split, val_split, test_split = 0.8, 0.1, 0.1
n_total = len(dataset)
n_train = int(train_split * n_total)
n_val = int(val_split * n_total)
n_test = n_total - n_train - n_val

train_set, val_set, test_set = random_split(
    dataset,
    [n_train, n_val, n_test]
)

# Create DataLoaders
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_set, batch_size=64, shuffle=False, num_workers=4)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=4)

In [34]:
from tqdm import tqdm
import torch.nn.functional as F

# Basic setup
N_EPOCHS = 10000
device = torch.device('cuda')
# unet = DenoisingUNet().to(device)
# optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-4)

# Noise params
T = 1000
betas = torch.linspace(1e-4, 0.02, T).to(device)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

# losses = []

for epoch in range(N_EPOCHS):
    unet.train()
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}", leave=False):
        # Move to GPU
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}

        # Get inputs
        # TODO: look into this extra dim  issue. Shapes are 8, 1, 1, 3, 128, 128 for some reason
        latent = batch['reference_vae'].squeeze(1).squeeze(1)
        clip = batch['reference_clip'].squeeze(1).squeeze(1)
        pose = batch['pose'].squeeze(1).squeeze(1)

        # Prepare the target
        target = batch['target'].squeeze(1).squeeze(1)
        B = target.size(0)
        t = torch.randint(0, T, (B,), device=device)
        eps = torch.randn_like(target)
        a_bar_t = alphas_cumprod[t].view(B, 1, 1, 1)
        noisy = torch.sqrt(a_bar_t) * target + torch.sqrt(1 - a_bar_t) * eps

        # Forward
        pred_eps = unet(noisy, pose, latent, clip)

        # Loss
        loss = F.mse_loss(pred_eps, eps)

        # Backprop
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    losses.append(loss.item())
    print(f"Epoch {epoch+1}: loss = {loss.item():.4f}")

                                                          

Epoch 1: loss = 0.0310


                                                          

Epoch 2: loss = 0.0510


                                                          

Epoch 3: loss = 0.0433


                                                          

Epoch 4: loss = 0.0253


                                                          

Epoch 5: loss = 0.0491


                                                          

Epoch 6: loss = 0.0400


                                                          

Epoch 7: loss = 0.0263


                                                          

Epoch 8: loss = 0.0244


                                                          

Epoch 9: loss = 0.0259


                                                           

Epoch 10: loss = 0.0420


                                                           

Epoch 11: loss = 0.0340


                                                           

Epoch 12: loss = 0.0294


                                                           

Epoch 13: loss = 0.0331


                                                           

Epoch 14: loss = 0.0411


                                                           

Epoch 15: loss = 0.0273


                                                           

Epoch 16: loss = 0.0315


                                                           

Epoch 17: loss = 0.0343


                                                           

Epoch 18: loss = 0.0294


                                                           

Epoch 19: loss = 0.0343


                                                           

Epoch 20: loss = 0.0288


                                                           

KeyboardInterrupt: 

In [30]:
losses


[0.36390936374664307,
 0.167206808924675,
 0.13038450479507446,
 0.10169705003499985,
 0.12740829586982727,
 0.08056183904409409,
 0.10691474378108978,
 0.10771314799785614,
 0.11733394861221313,
 0.1254330575466156,
 0.1074734628200531,
 0.08809306472539902,
 0.09401794523000717,
 0.10569876432418823,
 0.11191048473119736,
 0.07192958891391754,
 0.09839329123497009,
 0.09563539922237396,
 0.095150887966156,
 0.12466983497142792,
 0.10174554586410522,
 0.0638820081949234,
 0.06841405481100082,
 0.05968412384390831,
 0.057460665702819824,
 0.061801727861166,
 0.048900969326496124,
 0.07124795019626617,
 0.07378151267766953,
 0.055166613310575485,
 0.08690949529409409,
 0.07315545529127121,
 0.06633316725492477,
 0.042446475476026535,
 0.07170681655406952,
 0.038401179015636444,
 0.03812115639448166,
 0.04186611622571945,
 0.05014296621084213,
 0.06284771859645844,
 0.05711023509502411,
 0.07623037695884705,
 0.040004950016736984,
 0.05278850719332695,
 0.026685133576393127,
 0.041526362

In [38]:
import random
import torch
from tqdm import tqdm

@torch.no_grad()
def generate(unet, pose, latent, clip, T=2000, device="cuda"):
    unet.eval()

    # Diffusion schedule (must match training)
    betas = torch.linspace(1e-4, 0.02, T).to(device)
    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_cumprod_prev = torch.cat([torch.tensor([1.0], device=device), alphas_cumprod[:-1]])

    # Start from pure noise (same shape as your target)
    x_t = torch.randn((1, 3, 128, 128), device=device)

    for t in tqdm(reversed(range(T)), desc="Sampling"):
        # Prepare t tensor
        t_batch = torch.full((1,), t, device=device, dtype=torch.long)

        # Predict the noise at this step
        eps_pred = unet(x_t, pose, latent, clip)

        # Compute posterior mean (DDPM update)
        beta_t = betas[t]
        alpha_t = alphas[t]
        alpha_bar_t = alphas_cumprod[t]
        alpha_bar_prev = alphas_cumprod_prev[t]

        # Equation from DDPM paper
        pred_x0 = (x_t - torch.sqrt(1 - alpha_bar_t) * eps_pred) / torch.sqrt(alpha_bar_t)
        coef1 = torch.sqrt(alpha_bar_prev) * beta_t / (1 - alpha_bar_t)
        coef2 = torch.sqrt(alpha_t) * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
        mean = coef1 * pred_x0 + coef2 * x_t

        if t > 0:
            noise = torch.randn_like(x_t)
            sigma_t = torch.sqrt(beta_t)
            x_t = mean + sigma_t * noise
        else:
            x_t = mean

    # Optional: decode with VAE
    img = x_t

    # Map [-1,1] → [0,1] for display
    img = (img.clamp(-1, 1) + 1) / 2
    return img


idx = random.randint(0, len(dataset) - 1)
sample = train_set[idx]

latent = sample['reference_vae'].squeeze(0).to(device)
clip = sample['reference_clip'].squeeze(0).to(device)
pose = sample['pose'].squeeze(0).to(device)

image = generate(unet, pose, latent, clip)

from torchvision.transforms.functional import to_pil_image
img_tensor = image.squeeze(0).cpu()
img_pil = to_pil_image(img_tensor)
img_pil.show()


Sampling: 2000it [00:02, 699.02it/s]


In [31]:
torch.save({
    "epoch": epoch,
    "model_state": unet.state_dict(),
    "optimizer_state": optimizer.state_dict()
}, "checkpoints/unet_checkpoint.pt")


In [9]:
checkpoint = torch.load("checkpoints/unet_checkpoint.pt", map_location=device)
unet.load_state_dict(checkpoint["model_state"])
optimizer.load_state_dict(checkpoint["optimizer_state"])
start_epoch = checkpoint["epoch"] + 1