In [1]:
# Mount Google Drive (for Colab)
try:
    from google.colab import drive
    drive.mount('/content/drive')
    IN_COLAB = True
    print("Google Drive mounted successfully")
except:
    IN_COLAB = False
    print("Not running in Google Colab or drive already mounted")

Mounted at /content/drive
Google Drive mounted successfully


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
from PIL import Image
from torchvision import transforms
from torchvision.models import vit_b_16, ViT_B_16_Weights
import math

# Ignore specific warnings
warnings.filterwarnings("ignore", category=UserWarning)

In [4]:
# Print system info
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

PyTorch version: 2.6.0+cu124
CUDA available: True
CUDA device: Tesla T4


In [5]:

# Dataset implementation
class FaceDataset(Dataset):
    def __init__(self, root_dir, transform=None, img_size=224):
        """
        Args:
            root_dir (string): Directory with all the face images.
            transform (callable, optional): Optional transform to be applied on a sample.
            img_size (int): Size of images to resize to
        """
        self.root_dir = root_dir
        self.image_files = [f for f in os.listdir(root_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

        print(f"Found {len(self.image_files)} images in {root_dir}")

        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((img_size, img_size)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
        else:
            self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.image_files[idx])
        image = Image.open(img_name).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image

def get_dataloaders(data_dir, batch_size=32, img_size=224, train_ratio=0.9):
    """
    Create train and validation data loaders
    """
    dataset = FaceDataset(root_dir=data_dir, img_size=img_size)

    # Split into train and validation sets
    train_size = int(train_ratio * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )

    return train_loader, val_loader

# Encoder implementation (similar to previous code)
class DINOv2FaceEncoder(nn.Module):
    def __init__(self, embedding_dim=768, finetune=True):
        super().__init__()
        # Initialize with pretrained DINOv2 weights
        self.model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)

        # Remove the classification head
        self.embedding_dim = embedding_dim
        self.feature_dim = self.model.heads.head.in_features

        # Replace head with identity
        self.model.heads = nn.Identity()

        # Projection layer to desired embedding dimension
        if embedding_dim != self.feature_dim:
            self.projection = nn.Linear(self.feature_dim, embedding_dim)
        else:
            self.projection = nn.Identity()

        # Freeze or unfreeze the model
        if not finetune:
            for param in self.model.parameters():
                param.requires_grad = False

    def forward(self, x):
        features = self.model(x)
        embeddings = self.projection(features)
        return embeddings

In [6]:
# Diffusion model implementation (similar to previous code)
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

In [7]:
class ConditionalUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, base_channels=64,
                 time_dim=256, embedding_dim=768, device="cuda"):
        super().__init__()
        self.device = device
        self.time_dim = time_dim
        self.embedding_dim = embedding_dim

        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        # Condition embedding
        self.cond_encoder = nn.Sequential(
            nn.Linear(embedding_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

        # Downsampling and upsampling paths
        self.downs = nn.ModuleList([
            self._down_block(in_channels, base_channels),
            self._down_block(base_channels, base_channels * 2),
            self._down_block(base_channels * 2, base_channels * 4),
            self._down_block(base_channels * 4, base_channels * 8)
        ])

        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(base_channels * 8, base_channels * 8, 3, padding=1),
            nn.BatchNorm2d(base_channels * 8),
            nn.GELU(),
            nn.Conv2d(base_channels * 8, base_channels * 8, 3, padding=1),
            nn.BatchNorm2d(base_channels * 8),
            nn.GELU()
        )

        # Upsampling path with skip connections
        self.ups = nn.ModuleList([
            self._up_block(base_channels * 16, base_channels * 4),
            self._up_block(base_channels * 8, base_channels * 2),
            self._up_block(base_channels * 4, base_channels),
            self._up_block(base_channels * 2, base_channels)
        ])

        # Output convolution
        self.final_conv = nn.Conv2d(base_channels, out_channels, kernel_size=1)

    def _down_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.GELU()
        )

    def _up_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.GELU()
        )

    def forward(self, x, timestep, embedding):
        # Encode time
        t = self.time_mlp(timestep)

        # Encode conditioning
        c = self.cond_encoder(embedding)

        # Combine time and conditioning
        t = t + c

        # Store skip connections
        skip_connections = []

        # Downsampling
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)

        # Bottleneck
        x = self.bottleneck(x)

        # Upsampling with skip connections
        for up, skip in zip(self.ups, reversed(skip_connections)):
            x = torch.cat([x, skip], dim=1)
            x = up(x)

        # Final convolution
        return self.final_conv(x)

In [8]:
class DiffusionModel:
    def __init__(self, model, beta_start=1e-4, beta_end=0.02, num_diffusion_steps=1000, device="cuda"):
        self.model = model.to(device)
        self.device = device
        self.num_diffusion_steps = num_diffusion_steps

        # Define beta schedule
        self.betas = torch.linspace(beta_start, beta_end, num_diffusion_steps).to(device)
        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.)

        # Calculations for diffusion
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

    def get_noisy_image(self, x_start, t):
        x_start = x_start.to(self.device)
        noise = torch.randn_like(x_start).to(self.device)

        return (
            self.sqrt_alphas_cumprod[t, None, None, None] * x_start +
            self.sqrt_one_minus_alphas_cumprod[t, None, None, None] * noise,
            noise
        )

    def train_step(self, clean_images, embeddings, optimizer, criterion):
        optimizer.zero_grad()

        # Sample random timesteps
        batch_size = clean_images.shape[0]
        t = torch.randint(0, self.num_diffusion_steps, (batch_size,), device=self.device).long()

        # Get noisy image and true noise to predict
        noisy_images, true_noise = self.get_noisy_image(clean_images, t)

        # Predict noise
        pred_noise = self.model(noisy_images, t, embeddings)

        # Calculate loss
        loss = criterion(pred_noise, true_noise)

        # Backpropagation
        loss.backward()
        optimizer.step()

        return loss.item()

    @torch.no_grad()
    def sample(self, embedding, image_size=224, batch_size=1, channels=3):
        # Start with random noise
        img = torch.randn(batch_size, channels, image_size, image_size).to(self.device)
        embedding = embedding.to(self.device)

        # Iterative denoising
        for i in tqdm(reversed(range(0, self.num_diffusion_steps)), desc='Sampling'):
            timesteps = torch.full((batch_size,), i, device=self.device, dtype=torch.long)

            # Get model prediction (predicted noise)
            predicted_noise = self.model(img, timesteps, embedding)

            # Compute denoise step
            alpha = self.alphas[i]
            alpha_cumprod = self.alphas_cumprod[i]
            beta = self.betas[i]

            if i > 0:
                noise = torch.randn_like(img)
            else:
                noise = torch.zeros_like(img)

            # Compute step according to the formula
            img = (1 / torch.sqrt(alpha)) * (
                img - ((1 - alpha) / torch.sqrt(1 - alpha_cumprod)) * predicted_noise
            ) + torch.sqrt(beta) * noise

        # Scale image to [0, 1]
        img = (img.clamp(-1, 1) + 1) / 2

        return img

In [9]:
def train_face_generation(
    data_dir='/content/drive/MyDrive/img_align_celeba',
    output_dir='outputs',
    batch_size=32,
    img_size=224,
    embedding_dim=768,
    encoder_epochs=10,
    diffusion_epochs=10,
    device='cuda'
):
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Get dataloaders
    train_loader, val_loader = get_dataloaders(
        data_dir,
        batch_size=batch_size,
        img_size=img_size
    )

    # Initialize encoder
    encoder = DINOv2FaceEncoder(embedding_dim=embedding_dim, finetune=True).to(device)

    # Initialize diffusion model
    unet = ConditionalUNet(
        in_channels=3,
        out_channels=3,
        base_channels=64,
        time_dim=256,
        embedding_dim=embedding_dim,
        device=device
    ).to(device)

    diffusion = DiffusionModel(unet, num_diffusion_steps=1000, device=device)

    # Optimizers and loss
    encoder_optimizer = optim.AdamW(encoder.parameters(), lr=1e-4)
    diffusion_optimizer = optim.AdamW(diffusion.model.parameters(), lr=2e-4)
    criterion = nn.MSELoss()

    # Training loop
    for epoch in range(diffusion_epochs):
        # Generate embeddings and train diffusion model
        encoder.train()
        diffusion.model.train()
        total_loss = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{diffusion_epochs}"):
            images = batch.to(device)

            # Generate embeddings
            with torch.no_grad():
                embeddings = encoder(images)

            # Train diffusion model
            loss = diffusion.train_step(images, embeddings, diffusion_optimizer, criterion)
            total_loss += loss

        print(f"Epoch {epoch+1}, Average Loss: {total_loss/len(train_loader):.4f}")

        # Periodically generate and save samples
        if (epoch + 1) % 5 == 0:
            diffusion.model.eval()
            with torch.no_grad():
                # Sample some embeddings from validation set
                val_batch = next(iter(val_loader)).to(device)
                val_embeddings = encoder(val_batch)

                # Generate samples
                samples = diffusion.sample(val_embeddings[:16], image_size=img_size)

                # Save samples
                save_path = os.path.join(output_dir, f'samples_epoch_{epoch+1}.png')
                plt.figure(figsize=(10, 10))
                grid = make_grid(samples, nrow=4, normalize=True)
                plt.imshow(grid.permute(1, 2, 0).cpu())
                plt.axis('off')
                plt.tight_layout()
                plt.savefig(save_path)
                plt.close()

    # Save final models
    torch.save({
        'encoder': encoder.state_dict(),
        'diffusion': diffusion.model.state_dict()
    }, os.path.join(output_dir, 'final_model.pth'))

    return encoder, diffusion

def make_grid(tensor, nrow=8, padding=2, normalize=False, scale_each=False, pad_value=0):
    """
    Make a grid of images from a tensor
    """
    from torchvision.utils import make_grid as tv_make_grid
    return tv_make_grid(tensor, nrow=nrow, padding=padding,
                        normalize=normalize,
                        scale_each=scale_each,
                        pad_value=pad_value)

In [10]:
# Main execution
if __name__ == "__main__":
    # Specific path to face dataset
    data_dir = '/content/drive/MyDrive/img_align_celeba'

    # Train the face generation model
    encoder, diffusion_model = train_face_generation(
        data_dir=data_dir,
        output_dir='face_generation_outputs',
        batch_size=32,
        img_size=224,
        embedding_dim=768,
        encoder_epochs=10,
        diffusion_epochs=10
    )

Found 5434 images in /content/drive/MyDrive/img_align_celeba


Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:02<00:00, 150MB/s]
Epoch 1/10: 100%|██████████| 153/153 [03:54<00:00,  1.54s/it]


Epoch 1, Average Loss: 0.3194


Epoch 2/10: 100%|██████████| 153/153 [01:47<00:00,  1.43it/s]


Epoch 2, Average Loss: 0.1108


Epoch 3/10: 100%|██████████| 153/153 [01:47<00:00,  1.43it/s]


Epoch 3, Average Loss: 0.0850


Epoch 4/10: 100%|██████████| 153/153 [01:46<00:00,  1.43it/s]


Epoch 4, Average Loss: 0.0744


Epoch 5/10: 100%|██████████| 153/153 [01:46<00:00,  1.43it/s]

Epoch 5, Average Loss: 0.0658



Sampling: 1000it [00:04, 209.75it/s]
Epoch 6/10: 100%|██████████| 153/153 [01:47<00:00,  1.42it/s]


Epoch 6, Average Loss: 0.0605


Epoch 7/10: 100%|██████████| 153/153 [01:47<00:00,  1.43it/s]


Epoch 7, Average Loss: 0.0566


Epoch 8/10: 100%|██████████| 153/153 [01:46<00:00,  1.43it/s]


Epoch 8, Average Loss: 0.0513


Epoch 9/10: 100%|██████████| 153/153 [01:46<00:00,  1.43it/s]


Epoch 9, Average Loss: 0.0510


Epoch 10/10: 100%|██████████| 153/153 [01:46<00:00,  1.43it/s]

Epoch 10, Average Loss: 0.0455



Sampling: 1000it [00:04, 220.22it/s]


In [11]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
from PIL import Image
from torchvision import transforms
from torchvision.models import vit_b_16, ViT_B_16_Weights
import math

# Ignore specific warnings
warnings.filterwarnings("ignore", category=UserWarning)

# Dataset implementation
class FaceDataset(Dataset):
    def __init__(self, root_dir, transform=None, img_size=224):
        """
        Args:
            root_dir (string): Directory with all the face images.
            transform (callable, optional): Optional transform to be applied on a sample.
            img_size (int): Size of images to resize to
        """
        self.root_dir = root_dir
        self.image_files = [f for f in os.listdir(root_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

        print(f"Found {len(self.image_files)} images in {root_dir}")

        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((img_size, img_size)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
        else:
            self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.image_files[idx])
        image = Image.open(img_name).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image

def get_dataloaders(data_dir, batch_size=32, img_size=224, train_ratio=0.9):
    """
    Create train and validation data loaders
    """
    dataset = FaceDataset(root_dir=data_dir, img_size=img_size)
    train_size = int(train_ratio * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )
    return train_loader, val_loader

# Encoder implementation
class DINOv2FaceEncoder(nn.Module):
    def __init__(self, embedding_dim=768, finetune=True):
        super().__init__()
        # Initialize with pretrained DINOv2 weights
        self.model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        # Remove the classification head
        self.embedding_dim = embedding_dim
        self.feature_dim = self.model.heads.head.in_features
        # Replace head with identity
        self.model.heads = nn.Identity()
        # Projection layer to desired embedding dimension
        if embedding_dim != self.feature_dim:
            self.projection = nn.Linear(self.feature_dim, embedding_dim)
        else:
            self.projection = nn.Identity()
        # Freeze or unfreeze the model
        if not finetune:
            for param in self.model.parameters():
                param.requires_grad = False

    def forward(self, x):
        features = self.model(x)
        embeddings = self.projection(features)
        return embeddings

# Diffusion model components
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

class ConditionalUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, base_channels=64,
                 time_dim=256, embedding_dim=768, device="cuda"):
        super().__init__()
        self.device = device
        self.time_dim = time_dim
        self.embedding_dim = embedding_dim
        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )
        # Condition embedding
        self.cond_encoder = nn.Sequential(
            nn.Linear(embedding_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )
        # Downsampling and upsampling paths
        self.downs = nn.ModuleList([
            self._down_block(in_channels, base_channels),
            self._down_block(base_channels, base_channels * 2),
            self._down_block(base_channels * 2, base_channels * 4),
            self._down_block(base_channels * 4, base_channels * 8)
        ])
        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(base_channels * 8, base_channels * 8, 3, padding=1),
            nn.BatchNorm2d(base_channels * 8),
            nn.GELU(),
            nn.Conv2d(base_channels * 8, base_channels * 8, 3, padding=1),
            nn.BatchNorm2d(base_channels * 8),
            nn.GELU()
        )
        # Upsampling with skip connections
        self.ups = nn.ModuleList([
            self._up_block(base_channels * 16, base_channels * 4),
            self._up_block(base_channels * 8, base_channels * 2),
            self._up_block(base_channels * 4, base_channels),
            self._up_block(base_channels * 2, base_channels)
        ])
        # Output convolution
        self.final_conv = nn.Conv2d(base_channels, out_channels, kernel_size=1)
    def _down_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.GELU()
        )
    def _up_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.GELU()
        )
    def forward(self, x, timestep, embedding):
        # Encode time
        t = self.time_mlp(timestep)
        # Encode conditioning
        c = self.cond_encoder(embedding)
        # Combine time and conditioning
        t = t + c
        # Store skip connections
        skip_connections = []
        # Downsampling
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
        # Bottleneck
        x = self.bottleneck(x)
        # Upsampling with skip connections
        for up, skip in zip(self.ups, reversed(skip_connections)):
            x = torch.cat([x, skip], dim=1)
            x = up(x)
        # Final convolution
        return self.final_conv(x)

class DiffusionModel:
    def __init__(self, model, beta_start=1e-4, beta_end=0.02, num_diffusion_steps=1000, device="cuda"):
        self.model = model.to(device)
        self.device = device
        self.num_diffusion_steps = num_diffusion_steps
        # Define beta schedule
        self.betas = torch.linspace(beta_start, beta_end, num_diffusion_steps).to(device)
        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.)
        # Calculations for diffusion
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
    def get_noisy_image(self, x_start, t):
        x_start = x_start.to(self.device)
        noise = torch.randn_like(x_start).to(self.device)
        return (
            self.sqrt_alphas_cumprod[t, None, None, None] * x_start +
            self.sqrt_one_minus_alphas_cumprod[t, None, None, None] * noise,
            noise
        )
    def train_step(self, clean_images, embeddings, optimizer, criterion):
        optimizer.zero_grad()
        # Sample random timesteps
        batch_size = clean_images.shape[0]
        t = torch.randint(0, self.num_diffusion_steps, (batch_size,), device=self.device).long()
        # Get noisy image and true noise to predict
        noisy_images, true_noise = self.get_noisy_image(clean_images, t)
        # Predict noise
        pred_noise = self.model(noisy_images, t, embeddings)
        # Calculate loss
        loss = criterion(pred_noise, true_noise)
        # Backpropagation
        loss.backward()
        optimizer.step()
        return loss.item()
    @torch.no_grad()
    def sample(self, embedding, image_size=224, batch_size=1, channels=3):
        # Start with random noise
        img = torch.randn(batch_size, channels, image_size, image_size).to(self.device)
        embedding = embedding.to(self.device)
        # Iterative denoising
        for i in tqdm(reversed(range(0, self.num_diffusion_steps)), desc='Sampling'):
            timesteps = torch.full((batch_size,), i, device=self.device, dtype=torch.long)
            # Get model prediction (predicted noise)
            predicted_noise = self.model(img, timesteps, embedding)
            # Compute denoise step
            alpha = self.alphas[i]
            alpha_cumprod = self.alphas_cumprod[i]
            beta = self.betas[i]
            if i > 0:
                noise = torch.randn_like(img)
            else:
                noise = torch.zeros_like(img)
            # Compute step according to the formula
            img = (1 / torch.sqrt(alpha)) * (
                img - ((1 - alpha) / torch.sqrt(1 - alpha_cumprod)) * predicted_noise
            ) + torch.sqrt(beta) * noise
        # Scale image to [0, 1]
        img = (img.clamp(-1, 1) + 1) / 2
        return img

def train_face_generation(
    data_dir='/content/drive/MyDrive/img_align_celeba',
    output_dir='outputs',
    batch_size=32,
    img_size=224,
    embedding_dim=768,
    encoder_epochs=10,
    diffusion_epochs=10,
    device='cuda'
):
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    # Get dataloaders
    train_loader, val_loader = get_dataloaders(data_dir, batch_size=batch_size, img_size=img_size)
    # Initialize encoder
    encoder = DINOv2FaceEncoder(embedding_dim=embedding_dim, finetune=True).to(device)
    # Initialize diffusion model
    unet = ConditionalUNet(
        in_channels=3,
        out_channels=3,
        base_channels=64,
        time_dim=256,
        embedding_dim=embedding_dim,
        device=device
    ).to(device)
    diffusion = DiffusionModel(unet, num_diffusion_steps=1000, device=device)
    # Optimizers and loss
    encoder_optimizer = optim.AdamW(encoder.parameters(), lr=1e-4)
    diffusion_optimizer = optim.AdamW(diffusion.model.parameters(), lr=2e-4)
    criterion = nn.MSELoss()

    # For tracking loss per epoch
    loss_history = []

    # Training loop
    for epoch in range(diffusion_epochs):
        encoder.train()
        diffusion.model.train()
        total_loss = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{diffusion_epochs}"):
            images = batch.to(device)
            # Generate embeddings (without gradient computation for efficiency)
            with torch.no_grad():
                embeddings = encoder(images)
            # Train diffusion model
            loss = diffusion.train_step(images, embeddings, diffusion_optimizer, criterion)
            total_loss += loss

        avg_loss = total_loss / len(train_loader)
        loss_history.append(avg_loss)
        print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")

        # Periodically generate and save samples
        if (epoch + 1) % 5 == 0:
            diffusion.model.eval()
            with torch.no_grad():
                # Sample some embeddings from validation set
                val_batch = next(iter(val_loader)).to(device)
                val_embeddings = encoder(val_batch)
                # Generate samples
                samples = diffusion.sample(val_embeddings[:16], image_size=img_size)
                # Save samples
                save_path = os.path.join(output_dir, f'samples_epoch_{epoch+1}.png')
                plt.figure(figsize=(10, 10))
                grid = make_grid(samples, nrow=4, normalize=True)
                plt.imshow(grid.permute(1, 2, 0).cpu())
                plt.axis('off')
                plt.tight_layout()
                plt.savefig(save_path)
                plt.close()

    # Save final models
    torch.save({
        'encoder': encoder.state_dict(),
        'diffusion': diffusion.model.state_dict()
    }, os.path.join(output_dir, 'final_model.pth'))

    # Plot and save the training loss curve
    plt.figure()
    plt.plot(np.arange(1, diffusion_epochs+1), loss_history, marker='o')
    plt.title("Training Loss per Epoch")
    plt.xlabel("Epoch")
    plt.ylabel("Average Loss")
    plt.grid(True)
    loss_plot_path = os.path.join(output_dir, "training_loss.png")
    plt.savefig(loss_plot_path)
    plt.close()

    return encoder, diffusion

def make_grid(tensor, nrow=8, padding=2, normalize=False, scale_each=False, pad_value=0):
    """
    Make a grid of images from a tensor
    """
    from torchvision.utils import make_grid as tv_make_grid
    return tv_make_grid(tensor, nrow=nrow, padding=padding,
                        normalize=normalize,
                        scale_each=scale_each,
                        pad_value=pad_value)

# Main execution
if __name__ == "__main__":
    # Specific path to face dataset
    data_dir = '/content/drive/MyDrive/img_align_celeba'
    # Train the face generation model
    encoder, diffusion_model = train_face_generation(
        data_dir=data_dir,
        output_dir='face_generation_outputs',
        batch_size=32,
        img_size=224,
        embedding_dim=768,
        encoder_epochs=10,
        diffusion_epochs=10
    )


Found 5434 images in /content/drive/MyDrive/img_align_celeba


Epoch 1/10: 100%|██████████| 153/153 [02:39<00:00,  1.04s/it]


Epoch 1, Average Loss: 0.3279


Epoch 2/10: 100%|██████████| 153/153 [01:47<00:00,  1.43it/s]


Epoch 2, Average Loss: 0.1126


Epoch 3/10: 100%|██████████| 153/153 [01:47<00:00,  1.43it/s]


Epoch 3, Average Loss: 0.0873


Epoch 4/10: 100%|██████████| 153/153 [01:47<00:00,  1.43it/s]


Epoch 4, Average Loss: 0.0737


Epoch 5/10: 100%|██████████| 153/153 [01:47<00:00,  1.42it/s]

Epoch 5, Average Loss: 0.0659



Sampling: 1000it [00:04, 220.37it/s]
Epoch 6/10: 100%|██████████| 153/153 [01:47<00:00,  1.43it/s]


Epoch 6, Average Loss: 0.0576


Epoch 7/10: 100%|██████████| 153/153 [02:37<00:00,  1.03s/it]


Epoch 7, Average Loss: 0.0548


Epoch 8/10: 100%|██████████| 153/153 [01:47<00:00,  1.42it/s]


Epoch 8, Average Loss: 0.0508


Epoch 9/10: 100%|██████████| 153/153 [01:47<00:00,  1.43it/s]


Epoch 9, Average Loss: 0.0462


Epoch 10/10: 100%|██████████| 153/153 [01:47<00:00,  1.43it/s]

Epoch 10, Average Loss: 0.0500



Sampling: 1000it [00:04, 219.42it/s]
