In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import torchvision.utils as vutils
from torch.utils.data import Dataset

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

import os
import matplotlib.pyplot as plt
import numpy as np
import imageio
import copy
import math
from tqdm.notebook import trange, tqdm
from PIL import Image

from diffusers.models import AutoencoderKL

In [17]:
# training parameters
batch_size =  32
lr = 2e-5

train_epoch = 1200

# data_loader
latent_size = 32

timesteps = 500
patch_size = 2


data_set_root = "../../datasets"

use_cuda = 0 # torch.cuda.is_available()
gpu_indx  = 0
device = torch.device(gpu_indx if use_cuda else "cpu")

In [18]:
class LatentDataset(Dataset):
    def __init__(self, latent_dir):
        self.latent_dir = latent_dir
        self.latent_files = sorted(os.listdir(latent_dir))

    def __len__(self):
        return len(self.latent_files)

    def __getitem__(self, idx):
        latent_file = self.latent_files[idx]
        latent = np.load(os.path.join(self.latent_dir, latent_file))
        return torch.tensor(latent)

In [19]:
EDGEBAGS = True

if EDGEBAGS:
    nm = "edges2shoes" # edges2handbags
    dataset_dir = f"/home/roman/PycharmProjects/jobs/dandy/pytorch-CycleGAN-and-pix2pix/datasets/{nm}/" # /class_A" #"."
    latent_save_dir = f"/home/roman/PycharmProjects/jobs/dandy/pytorch-CycleGAN-and-pix2pix/datasets/{nm}_latent" # "."
    latent_save_dir_right = latent_save_dir + "_right"
    latent_save_dir_left = latent_save_dir + "_left"

    
if False:
    
    data_set_root = "/media/luke/Quick_Storage/Data/CelebAHQ/image_latents"
    
    trainset = LatentDataset(data_set_root)
else:    
    dir_latent = latent_save_dir_right
    
    trainset = LatentDataset(dir_latent)# data_set_root)
    
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)

In [20]:
def extract_patches(image_tensor, patch_size=8):
    # Get the dimensions of the image tensor
    bs, c, h, w = image_tensor.size()
    
    # Define the Unfold layer with appropriate parameters
    unfold = torch.nn.Unfold(kernel_size=patch_size, stride=patch_size)
    
    # Apply Unfold to the image tensor
    unfolded = unfold(image_tensor)
    
    # Reshape the unfolded tensor to match the desired output shape
    # Output shape: BSxLxH, where L is the number of patches in each dimension
    unfolded = unfolded.transpose(1, 2).reshape(bs, -1, c * patch_size * patch_size)
    
    return unfolded


def reconstruct_image(patch_sequence, image_shape, patch_size=8):
    """
    Reconstructs the original image tensor from a sequence of patches.

    Args:
        patch_sequence (torch.Tensor): Sequence of patches with shape
        BS x L x (C x patch_size x patch_size)
        image_shape (tuple): Shape of the original image tensor (bs, c, h, w).
        patch_size (int): Size of the patches used in extraction.

    Returns:
        torch.Tensor: Reconstructed image tensor.
    """
    bs, c, h, w = image_shape
    num_patches_h = h // patch_size
    num_patches_w = w // patch_size
    
    # Reshape the patch sequence to match the unfolded tensor shape
    unfolded_shape = (bs, num_patches_h, num_patches_w, patch_size, patch_size, c)
    patch_sequence = patch_sequence.view(*unfolded_shape)
    
    # Transpose dimensions to match the original image tensor shape
    patch_sequence = patch_sequence.permute(0, 5, 1, 3, 2, 4).contiguous()
    
    # Reshape the sequence of patches back into the original image tensor shape
    reconstructed = patch_sequence.view(bs, c, h, w)
    
    return reconstructed

In [21]:
class ConditionalNorm2d(nn.Module):
    def __init__(self, hidden_size, num_features):
        super(ConditionalNorm2d, self).__init__()
        self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False)

        self.fcw = nn.Linear(num_features, hidden_size)
        self.fcb = nn.Linear(num_features, hidden_size)

    def forward(self, x, features):
        bs, s, l = x.shape
        
        out = self.norm(x)
        w = self.fcw(features).reshape(bs, 1, -1)
        b = self.fcb(features).reshape(bs, 1, -1)

        return w * out + b

    
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
    
    
# Transformer block with self-attention
class TransformerBlock(nn.Module):
    def __init__(self, hidden_size=128, num_heads=4, num_features=128):
        # Initialize the parent nn.Module
        super(TransformerBlock, self).__init__()
        
        # Layer normalization to normalize the input data
        self.norm = nn.LayerNorm(hidden_size)
        
        # Multi-head attention mechanism
        self.multihead_attn = nn.MultiheadAttention(hidden_size, num_heads=num_heads, 
                                                    batch_first=True, dropout=0.0)
        
        # Another layer normalization
        self.con_norm = ConditionalNorm2d(hidden_size, num_features)
        
        # Multi-layer perceptron (MLP) with a hidden layer and activation function
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.LayerNorm(hidden_size * 4),
            nn.ELU(),
            nn.Linear(hidden_size * 4, hidden_size)
        )
                
    def forward(self, x, features):
        # Apply the first layer normalization
        norm_x = self.norm(x)
        
        # Apply multi-head attention and add the input (residual connection)
        x = self.multihead_attn(norm_x, norm_x, norm_x)[0] + x
        
        # Apply the second layer normalization
        norm_x = self.con_norm(x, features)
        
        # Pass through the MLP and add the input (residual connection)
        x = self.mlp(norm_x) + x
        
        return x

    
# Define a Vision Encoder module for the Diffusion Transformer
class DiT(nn.Module):
    def __init__(self, image_size, channels_in, patch_size=16, 
                 hidden_size=128, num_features=128, 
                 num_layers=3, num_heads=4):
        super(DiT, self).__init__()
        
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(num_features),
            nn.Linear(num_features, 2 * num_features),
            nn.GELU(),
            nn.Linear(2 * num_features, num_features),
            nn.GELU()
        )
        
        self.patch_size = patch_size
        self.fc_in = nn.Linear(channels_in * patch_size * patch_size, hidden_size)
        
        seq_length = (image_size // patch_size) ** 2
        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_size).normal_(std=0.02))
        
        # Create multiple transformer blocks as layers
        self.blocks = nn.ModuleList([
            TransformerBlock(hidden_size, num_heads) for _ in range(num_layers)
        ])
        
        self.fc_out = nn.Linear(hidden_size, channels_in * patch_size * patch_size)
                
    def forward(self, image_in, index):  
        # Get timestep embedding
        index_features = self.time_mlp(index)

        # Split input into patches
        patch_seq = extract_patches(image_in, patch_size=self.patch_size)
        patch_emb = self.fc_in(patch_seq)

        # Add a unique embedding to each token embedding
        embs = patch_emb + self.pos_embedding
        
        # Pass the embeddings through each Transformer block
        for block in self.blocks:
            embs = block(embs, index_features)
        
        # Project to output
        image_out = self.fc_out(embs)
        
        # Reconstruct the input from patches and return result
        return reconstruct_image(image_out, image_in.shape, patch_size=self.patch_size)

In [22]:
def cosine_alphas_bar(timesteps, s=0.008):
    steps = timesteps + 1
    x = torch.linspace(0, steps, steps)
    alphas_bar = torch.cos(((x / steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_bar = alphas_bar / alphas_bar[0]
    return alphas_bar[:timesteps]
    
def noise_from_x0(curr_img, img_pred, alpha):
    return (curr_img - alpha.sqrt() * img_pred)/((1 - alpha).sqrt() + 1e-4)
    
def cold_diffuse(diffusion_model, sample_in, total_steps, start_step=0):
    diffusion_model.eval()
    bs = sample_in.shape[0]
    alphas = torch.flip(cosine_alphas_bar(total_steps), (0,)).to(device)
    random_sample = copy.deepcopy(sample_in)
    with torch.no_grad():
        for i in trange(start_step, total_steps - 1):
            index = (i * torch.ones(bs, device=sample_in.device)).long()

            img_output = diffusion_model(random_sample, index)

            noise = noise_from_x0(random_sample, img_output, alphas[i])
            x0 = img_output

            rep1 = alphas[i].sqrt() * x0 + (1 - alphas[i]).sqrt() * noise
            rep2 = alphas[i + 1].sqrt() * x0 + (1 - alphas[i + 1]).sqrt() * noise

            random_sample += rep2 - rep1

        index = ((total_steps - 1) * torch.ones(bs, device=sample_in.device)).long()
        img_output = diffusion_model(random_sample, index)

    return img_output


In [23]:
# Create a dataloader itterable object
dataiter = iter(train_loader)
# Sample from the itterable object
latents = next(dataiter)

print(f"latents={latents.shape} latent_size={latent_size}")

latents=torch.Size([32, 4, 32, 32]) latent_size=32


In [24]:
# network
dit = DiT(latent_size, channels_in=latents.shape[1], patch_size=patch_size, 
            hidden_size=768, num_layers=10, num_heads=8).to(device)

# Adam optimizer
optimizer = optim.Adam(dit.parameters(), lr=lr)

# Scaler for mixed precision training
# scaler = torch.cuda.amp.GradScaler()
scaler = torch.amp.GradScaler('cuda')

alphas = torch.flip(cosine_alphas_bar(timesteps), (0,)).to(device)

In [29]:
# Load Checkpoint
def show(dit):
    latent_noise = 0.95 * torch.randn(8, 4, latent_size, latent_size, device=device)
    
    with torch.no_grad():
        #with torch.cuda.amp.autocast():
        with torch.amp.autocast("cuda:0"): 
            # fake_latents = cold_diffuse(u_net, latent_noise, total_steps=timesteps)
            fake_latents = cold_diffuse(dit, latent_noise, total_steps=timesteps)
    
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(device)
    
    with torch.no_grad():
        #with torch.cuda.amp.autocast():
        with torch.amp.autocast("cuda:0"): 
            fake_sample = vae.decode(fake_latents / 0.18215).sample
            
    plt.figure(figsize = (20, 10))
    
    out = vutils.make_grid(fake_sample[:8].detach().float().cpu(), nrow=4, normalize=True)
    
    _ = plt.imshow(out.numpy().transpose((1, 2, 0)))


    
def load_model(fn="latent_dit_100.pt", device="cpu"):
    cp = torch.load(fn)
    
    # network
    dit_100 = DiT(latent_size, channels_in=latents.shape[1], patch_size=patch_size, 
                hidden_size=768, num_layers=10, num_heads=8).to(device)
    
    
    dit_100.load_state_dict(cp["model_state_dict"])
    
    optimizer_100 = optim.Adam(dit_100.parameters(), lr=lr)
    
    optimizer_100.load_state_dict(cp["optimizer_state_dict"])
    
    loss_log_100 = cp["train_data_logger"]
    
    start_epoch = cp["epoch"]
    
    print(f"start_epoch={start_epoch}")

    return dit_100, optimizer_100, loss_log_100, start_epoch

In [30]:
device

device(type='cpu')

In [31]:
dit, optimizer, loss_log, start_epoch = load_model("latent_dit_2000.pt", device="cpu")

RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
show(dit)