In [1]:
!pip install torch torchvision tensorflow transformers diffusers

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [2]:
import torch
from diffusers import AutoencoderKL
from torchvision import transforms

# Load the pretrained VAE that ships with Stable Diffusion 2 (base).
vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-2-base",
    subfolder="vae",
)

# Switch the VAE to inference mode.
vae.eval()

# Freeze the VAE: none of its weights should receive gradients.
for vae_weight in vae.parameters():
    vae_weight.requires_grad_(False)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/716 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

# dataloader

In [3]:
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import shutil


def preprocess_directory(data_dir):
    """
    Flatten a nested image directory so that ImageFolder can handle it.

    Every sub-directory of ``data_dir`` (at any depth) that directly contains
    image files is mirrored to ``<data_dir>/processed/<folder name>/``; the
    folder name is used as the class label.

    Args:
        data_dir: Root directory that holds the (possibly nested) class folders.

    Returns:
        Path of the flattened directory, ``os.path.join(data_dir, "processed")``.

    Notes:
        * The ``processed`` output folder is pruned from the walk; otherwise
          its files would be copied onto themselves (``shutil.SameFileError``).
        * Folders without images (e.g. an intermediate ``train/`` level) do not
          get a label directory, since ImageFolder rejects empty class folders.
        * Files already present with the same size are skipped, so re-running
          the cell does not copy the whole dataset again.
        * Images located directly in ``data_dir`` are ignored (no label folder).
    """
    # Local import: this cell runs before the notebook's top-level `import os`.
    import os

    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp')

    print(f"Preprocessing directory: {data_dir}")
    temp_dir = os.path.join(data_dir, "processed")
    os.makedirs(temp_dir, exist_ok=True)

    source_root = os.path.abspath(data_dir)
    processed_root = os.path.abspath(temp_dir)

    for root, dirs, files in os.walk(data_dir):
        # Prune the output directory in place so os.walk never descends into it.
        dirs[:] = [
            name for name in dirs
            if os.path.abspath(os.path.join(root, name)) != processed_root
        ]

        if os.path.abspath(root) == source_root:
            continue  # Only sub-folders carry a label

        image_files = [name for name in files if name.lower().endswith(image_extensions)]
        if not image_files:
            continue  # Do not create empty class folders

        label = os.path.basename(root)  # Use folder name as label
        label_dir = os.path.join(temp_dir, label)
        os.makedirs(label_dir, exist_ok=True)

        for img_file in image_files:
            src_path = os.path.join(root, img_file)
            dst_path = os.path.join(label_dir, img_file)
            # Skip files copied by a previous run; the size comparison re-copies
            # files truncated by an interrupted run.
            if os.path.isfile(dst_path) and os.path.getsize(dst_path) == os.path.getsize(src_path):
                continue
            shutil.copy(src_path, dst_path)

    print(f"Preprocessed directory created at: {temp_dir}")
    return temp_dir


def get_imagenet_dataloader(data_dir, batch_size=32, num_workers=4):
    """
    Build an ImageFolder dataset over ``data_dir`` and wrap it in a DataLoader.

    Args:
        data_dir: Root directory of the raw images; it is flattened with
            ``preprocess_directory`` before being read.
        batch_size: Number of images per batch.
        num_workers: Number of DataLoader worker processes.

    Returns:
        A ``(dataloader, dataset)`` tuple; the loader shuffles the dataset.
    """
    # ImageFolder expects <root>/<class>/<image>, so flatten the tree first.
    flattened_root = preprocess_directory(data_dir)

    # 256x256 resize -> tensor in [0, 1] -> per-channel shift/scale to [-1, 1].
    image_pipeline = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    dataset = torchvision.datasets.ImageFolder(root=flattened_root, transform=image_pipeline)

    # Show how folder names were mapped to integer class ids.
    print("Class-to-Index Mapping:")
    for class_name in dataset.class_to_idx:
        print(f"Label: {class_name}, Index: {dataset.class_to_idx[class_name]}")

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    return loader, dataset


# latent DiffiT

In [4]:
import torch
import torch.nn as nn
from einops import rearrange
import torch.nn as nn
import torch.nn.functional as F
import math

In [5]:
# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [6]:
class Encoder(nn.Module):
    """
    Encode images into VAE latents and embed the latents as a patch sequence.

    Args:
        vae: Either a wrapper object exposing ``encode_to_latent(images)`` or a
            diffusers-style autoencoder whose ``encode(images)`` result has a
            ``latent_dist`` with a ``sample()`` method.
        img_size: Side length of the input images in pixels. The latent side
            is computed as ``img_size // 8``.
        patch_size: Side length of one patch, measured in latent pixels.
        in_channels: Unused by this class; kept for interface compatibility.
        hidden_dim: Size of each patch embedding.
    """

    def __init__(self, vae, img_size=512, patch_size=16, in_channels=3, hidden_dim=768):
        super().__init__()
        self.vae = vae

        latent_channels = 4
        latent_size = img_size // 8  # e.g. 64x64 latents for 512x512 images

        self.patch_size = patch_size
        self.latent_size = latent_size
        self.hidden_dim = hidden_dim

        # The latent grid must split into whole patches.
        assert latent_size % patch_size == 0, "latent_size phải chia hết cho patch_size"
        self.patches_per_side = latent_size // patch_size
        self.num_patches = self.patches_per_side ** 2

        # Non-overlapping patches: kernel size == stride == patch_size.
        self.patch_embedding = nn.Conv2d(
            in_channels=latent_channels,
            out_channels=hidden_dim,
            kernel_size=patch_size,
            stride=patch_size
        )

        # Learned positional embedding, one vector per patch (zero-initialised).
        self.position_embedding = nn.Parameter(torch.zeros(1, self.num_patches, hidden_dim))

    def encode_to_latent(self, noisy_images):
        """
        Encode a batch of images into VAE latents.

        A 3-D input is treated as a single image and gets a batch dimension.
        When the VAE exposes ``encode_to_latent`` that method is used as is.
        Otherwise the diffusers autoencoder API is used: a sample from
        ``vae.encode(x).latent_dist`` multiplied by ``vae.config.scaling_factor``
        (1.0 when the VAE has no such config entry).
        """
        if noisy_images.ndim == 3:
            noisy_images = noisy_images.unsqueeze(0)

        if hasattr(self.vae, "encode_to_latent"):
            return self.vae.encode_to_latent(noisy_images)

        # Plain diffusers autoencoders have no `encode_to_latent`; calling it
        # raised AttributeError with the VAE loaded at the top of the notebook.
        posterior = self.vae.encode(noisy_images).latent_dist
        scaling_factor = getattr(getattr(self.vae, "config", None), "scaling_factor", 1.0)
        return posterior.sample() * scaling_factor

    def forward(self, noisy_images):
        """
        Return patch embeddings of shape ``[batch, num_patches, hidden_dim]``.

        Inputs whose values all lie in [0, 1] are mapped to [-1, 1] first.
        Inputs containing negative values are assumed to be in [-1, 1] already
        and are left untouched (checking only ``max() <= 1`` would re-scale
        them a second time). This is a heuristic: an image normalised to
        [-1, 1] that has no negative pixel is still re-scaled.
        """
        if noisy_images.min() >= 0.0 and noisy_images.max() <= 1.0:
            noisy_images = noisy_images * 2 - 1

        latents = self.encode_to_latent(noisy_images)  # [batch, 4, latent, latent]
        patches = self.patch_embedding(latents)        # [batch, hidden_dim, p, p]

        # 'b c h w -> b (h w) c', then add the positional embedding.
        embedded = patches.flatten(2).transpose(1, 2)
        embedded = embedded + self.position_embedding

        return embedded  # [batch, num_patches, hidden_dim]

In [7]:
class SinusoidalPositionEmbeddings(nn.Module):
    """
    Sinusoidal embedding of a batch of scalar timesteps.

    Args:
        dim: Size of the returned embedding per timestep.

    The first ``dim // 2`` features are sines and the next ``dim // 2`` are
    cosines of the timestep multiplied by geometrically spaced frequencies
    (from 1 down to 1/10000). For an odd ``dim`` a trailing zero column is
    appended so that the output always has exactly ``dim`` features.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """
        Args:
            time: 1-D tensor of timesteps, shape ``[batch]``.

        Returns:
            Tensor of shape ``[batch, dim]``.
        """
        device = time.device
        half_dim = self.dim // 2
        # max(..., 1) avoids dividing by zero when half_dim == 1 (dim 2 or 3).
        log_step = math.log(10000) / max(half_dim - 1, 1)
        frequencies = torch.exp(torch.arange(half_dim, device=device) * -log_step)
        angles = time[:, None] * frequencies[None, :]
        embeddings = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2 == 1:
            # Pad odd sizes so downstream layers sized for `dim` features work.
            padding = embeddings.new_zeros(embeddings.shape[0], 1)
            embeddings = torch.cat((embeddings, padding), dim=-1)
        return embeddings

In [8]:
class Swish(nn.Module):
    """Swish activation: the input multiplied element-wise by its sigmoid."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return torch.mul(x, gate)

In [9]:
class TimeEmbedding(nn.Module):
    """
    Map timesteps to a conditioning vector of size ``model_dim``.

    The timesteps go through a sinusoidal embedding of size ``time_embed_dim``
    followed by a two-layer MLP with a Swish activation in between.
    """

    def __init__(self, time_embed_dim, model_dim):
        super().__init__()
        self.time_embed_dim = time_embed_dim

        # Sinusoidal features -> Linear -> Swish -> Linear.
        stages = [
            SinusoidalPositionEmbeddings(time_embed_dim),
            nn.Linear(time_embed_dim, model_dim),
            Swish(),
            nn.Linear(model_dim, model_dim),
        ]
        self.time_embed = nn.Sequential(*stages)

    def forward(self, time):
        """Return the embedding produced by the MLP for the given timesteps."""
        timestep_features = self.time_embed(time)
        return timestep_features

In [10]:
class LabelEmbedding(nn.Module):
    """
    Map integer class labels to a conditioning vector of size ``model_dim``.

    Args:
        num_classes: Number of rows in the embedding table.
        embed_dim: Size of the looked-up label embedding.
        model_dim: Size of the projected output vector.
    """

    def __init__(self, num_classes, embed_dim, model_dim):
        super().__init__()
        self.embedding = nn.Embedding(num_classes, embed_dim)

        # Two-layer MLP with a Swish activation in between.
        projection_layers = [
            nn.Linear(embed_dim, model_dim),
            Swish(),
            nn.Linear(model_dim, model_dim),
        ]
        self.projection = nn.Sequential(*projection_layers)

    def forward(self, labels):
        """Look up the label embeddings and project them to ``model_dim``."""
        return self.projection(self.embedding(labels))

In [11]:
class TimeDependentMultiHeadAttention(nn.Module):
    """
    Time-dependent Multi-head Self-Attention (TMSA).

    Queries, keys and values are the sum of a spatial and a temporal
    projection:
        qs = xs*Wqs + xt*Wqt
        ks = xs*Wks + xt*Wkt
        vs = xs*Wvs + xt*Wvt

    Args:
        dim: Size of the token and time embeddings.
        heads: Number of attention heads.
        dim_head: Size of each head.
        dropout: Dropout probability for attention weights and the output.
        max_rel_pos: Side length of the learned position-bias table. The bias
            is only added when the sequence has at most this many tokens.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, max_rel_pos=49):
        super().__init__()
        inner_dim = dim_head * heads

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        # Spatial projection weights (Wqs, Wks, Wvs)
        self.to_q_spatial = nn.Linear(dim, inner_dim, bias=False)
        self.to_k_spatial = nn.Linear(dim, inner_dim, bias=False)
        self.to_v_spatial = nn.Linear(dim, inner_dim, bias=False)

        # Temporal projection weights (Wqt, Wkt, Wvt)
        self.to_q_temporal = nn.Linear(dim, inner_dim, bias=False)
        self.to_k_temporal = nn.Linear(dim, inner_dim, bias=False)
        self.to_v_temporal = nn.Linear(dim, inner_dim, bias=False)

        # Output projection
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

        # Learned per-head position bias (B in the attention formula).
        self.rel_pos_bias = nn.Parameter(torch.zeros(heads, max_rel_pos, max_rel_pos))

    def _split_heads(self, projected):
        """Reshape ``[b, n, heads * d_head]`` into ``[b, heads, n, d_head]``."""
        batch_size, length, _ = projected.shape
        return projected.reshape(batch_size, length, self.heads, -1).permute(0, 2, 1, 3)

    def forward(self, x, time_emb):
        """
        x: [batch_size, seq_len, dim] - Spatial embeddings (xs)
        time_emb: [batch_size, dim] - Time token (xt)

        Returns a tensor of shape [batch_size, seq_len, dim].
        """
        batch_size, seq_len, _ = x.shape

        # The time token is one extra "token" per sample: [batch_size, 1, dim].
        time_token = time_emb.unsqueeze(1)

        # Spatial part plus temporal part. The temporal tensors have length 1
        # on the sequence axis and broadcast over every token.
        q = self._split_heads(self.to_q_spatial(x)) + self._split_heads(self.to_q_temporal(time_token))
        k = self._split_heads(self.to_k_spatial(x)) + self._split_heads(self.to_k_temporal(time_token))
        v = self._split_heads(self.to_v_spatial(x)) + self._split_heads(self.to_v_temporal(time_token))

        # Softmax((QK^T)/sqrt(d) + B)V
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # Only add the bias when the table covers the whole sequence. Slicing a
        # smaller table silently yields a smaller block, and adding that block
        # raises a shape error (this happened for 50..64 tokens when the check
        # was a hard-coded `seq_len <= 64` against a 49x49 table).
        if seq_len <= self.rel_pos_bias.shape[-1]:
            bias = self.rel_pos_bias[:, :seq_len, :seq_len]
            dots = dots + bias.unsqueeze(0)

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        # Merge the heads back: [b, heads, n, d_head] -> [b, n, heads * d_head].
        out = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, -1)

        return self.to_out(out)

In [12]:
class FeedForward(nn.Module):
    """
    Token-wise MLP with additive time conditioning.

    Args:
        dim: Size of the token and time embeddings.
        hidden_dim: Width of the hidden layer of both MLPs.
        dropout: Dropout probability used inside the token MLP.
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()

        # MLP applied to every spatial token.
        token_layers = [
            nn.Linear(dim, hidden_dim),
            Swish(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*token_layers)

        # MLP applied to the time embedding.
        time_layers = [
            nn.Linear(dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, dim),
        ]
        self.time_mlp = nn.Sequential(*time_layers)

    def forward(self, x, time_emb):
        """Return ``net(x)`` plus the projected time embedding, broadcast over tokens."""
        # [batch_size, dim] -> [batch_size, 1, dim] so it broadcasts over the sequence.
        time_shift = self.time_mlp(time_emb).unsqueeze(1)

        # Additive conditioning (no gating is applied).
        token_features = self.net(x)
        return token_features + time_shift

In [13]:
class LatentDiffiTTransformerBlock(nn.Module):
    """
    One pre-norm transformer block: time-dependent attention followed by a
    time-conditioned feed-forward network, each wrapped in a residual
    connection.

    Args:
        dim: Size of the token and time embeddings.
        heads: Number of attention heads.
        dim_head: Size of each attention head.
        mlp_dim: Hidden width of the feed-forward network; ``dim * 4`` when
            not given.
        dropout: Dropout probability passed to attention and feed-forward.
    """

    def __init__(
        self,
        dim,
        heads=8,
        dim_head=64,
        mlp_dim=None,
        dropout=0.0
    ):
        super().__init__()
        if not mlp_dim:
            mlp_dim = dim * 4

        self.norm1 = nn.LayerNorm(dim)
        self.attn = TimeDependentMultiHeadAttention(
            dim, heads=heads, dim_head=dim_head, dropout=dropout
        )
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = FeedForward(dim, mlp_dim, dropout=dropout)

    def forward(self, x, time_emb):
        """Apply attention and feed-forward to ``x``, both conditioned on ``time_emb``."""
        # Residual branch 1: LayerNorm -> TMSA.
        attended = self.attn(self.norm1(x), time_emb)
        x = x + attended

        # Residual branch 2: LayerNorm -> MLP.
        transformed = self.mlp(self.norm2(x), time_emb)
        return x + transformed

In [14]:
class LatentDiffiTTransformer(nn.Module):
    """
    Stack of ``LatentDiffiTTransformerBlock`` layers conditioned on the
    timestep and, optionally, on a class label.

    Args:
        dim: Size of the token embeddings.
        depth: Number of transformer blocks.
        heads: Number of attention heads per block.
        dim_head: Size of each attention head.
        mlp_dim: Hidden width of the feed-forward networks.
        dropout: Dropout probability used in every block.
        time_embed_dim: Size of the sinusoidal time embedding; ``dim * 4``
            when not given.
        label_embed_dim: Size of the label embedding table entries; ``dim``
            when not given.
        num_classes: Number of class labels.
    """

    def __init__(
        self,
        dim,
        depth,
        heads=8,
        dim_head=64,
        mlp_dim=None,
        dropout=0.0,
        time_embed_dim=None,
        label_embed_dim=None,
        num_classes=1000
    ):
        super().__init__()

        self.dim = dim
        if not time_embed_dim:
            time_embed_dim = dim * 4
        if not label_embed_dim:
            label_embed_dim = dim

        # Conditioning embeddings (MLPs with Swish activation).
        self.time_embedding = TimeEmbedding(time_embed_dim, dim)
        self.label_embedding = LabelEmbedding(num_classes, label_embed_dim, dim)

        # The transformer blocks themselves.
        block_kwargs = dict(
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            mlp_dim=mlp_dim,
            dropout=dropout,
        )
        self.transformer_blocks = nn.ModuleList(
            [LatentDiffiTTransformerBlock(**block_kwargs) for _ in range(depth)]
        )

        # Normalisation applied after the last block.
        self.final_norm = nn.LayerNorm(dim)

    def combine_embeddings(self, time_emb, label_emb=None):
        """Return ``time_emb`` plus ``label_emb`` when a label embedding is given."""
        return time_emb if label_emb is None else time_emb + label_emb

    def forward(self, x, time, labels=None):
        """
        Run the token sequence ``x`` through all blocks.

        The conditioning vector is the time embedding, plus the label
        embedding when ``labels`` is given. It is passed to every block.
        """
        time_emb = self.time_embedding(time)
        conditioning = (
            time_emb
            if labels is None
            else self.combine_embeddings(time_emb, self.label_embedding(labels))
        )

        for layer in self.transformer_blocks:
            x = layer(x, conditioning)

        return self.final_norm(x)

In [15]:
class Unpatchify(nn.Module):
    """
    Turn a sequence of patch tokens back into a 2-D feature map.

    The tokens are laid out on a square grid and the grid is enlarged by a
    factor of ``patch_size`` with bilinear interpolation; the channel count
    stays ``hidden_dim``.
    """

    def __init__(self, patch_size, hidden_dim):
        super().__init__()
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim

    def forward(self, x):
        """
        x: (B, L, C) - batch size, number of patches, channels (C = hidden_dim)
        return: (B, C, H, W) with H = W = sqrt(L) * patch_size
        """
        batch, num_tokens, C = x.shape
        assert C == self.hidden_dim, f"Số kênh đầu vào phải là {self.hidden_dim}, nhận được {C}"

        grid_side = int(math.sqrt(num_tokens))
        target_side = grid_side * self.patch_size

        # Token sequence -> square grid, then channels-first layout.
        token_grid = x.reshape(batch, grid_side, grid_side, C)
        feature_map = token_grid.permute(0, 3, 1, 2)  # [B, C, grid_side, grid_side]

        # Upsample the grid to the full resolution.
        return F.interpolate(
            feature_map,
            size=(target_side, target_side),
            mode='bilinear',
            align_corners=False,
        )

In [16]:
class Decoder(nn.Module):
    """
    Small convolutional head that maps the unpatchified feature map to the
    predicted noise.

    Args:
        in_channels: Channels of the incoming feature map.
        hidden_dim: Channels of the two intermediate convolution layers.
        out_channels: Channels of the output (4 by default).
    """

    def __init__(self, in_channels, hidden_dim, out_channels=4):
        super().__init__()

        def conv3x3(channels_in, channels_out):
            # 3x3 convolution with padding 1 keeps the spatial size unchanged.
            return nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1)

        # conv -> GELU -> BatchNorm, twice, then a final projection conv.
        self.decoder = nn.Sequential(
            conv3x3(in_channels, hidden_dim),
            nn.GELU(),
            nn.BatchNorm2d(hidden_dim),
            conv3x3(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.BatchNorm2d(hidden_dim),
            conv3x3(hidden_dim, out_channels),
        )

    def forward(self, x):
        """
        x: (B, C, H, W) - feature map coming from unpatchify
        return: (B, out_channels, H, W) - predicted noise in latent space
        """
        return self.decoder(x)

# train

In [17]:
import torch
import torch.optim as optim
import torchvision.utils as vutils
import torch.optim as optim

def train_diffit(pipeline, dataloader, num_epochs, num_timesteps, device, learning_rate):
    """
    Train ``pipeline`` to predict the noise that was mixed into VAE latents.

    For every batch the images are encoded to latents (without gradients), a
    random integer timestep t in [0, num_timesteps) is drawn per sample, and
    the latents are blended with Gaussian noise as
    ``(1 - t/T) * latents + (t/T) * noise``. The pipeline is trained with an
    MSE loss between its output and that noise.

    After every epoch the function
      * saves ``pipeline.state_dict()`` to ``latent_diffit_epoch_<n>.pth``;
      * samples 8 images and writes them as a grid to
        ``generated_images/epoch_<n>.png``.

    Args:
        pipeline: Model exposing ``encoder.encode_to_latent``, ``__call__`` as
            ``(noisy_latents, timesteps, labels)`` and ``sample``.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        num_epochs: Number of passes over ``dataloader``.
        num_timesteps: Number of diffusion timesteps T.
        device: Device the batches are moved to.
        learning_rate: Learning rate of the Adam optimizer.
    """
    # Local import: `os` is only imported in a later notebook cell.
    import os

    optimizer = optim.Adam(pipeline.parameters(), lr=learning_rate)
    criterion = torch.nn.MSELoss()

    # Create the sample directory here so that the function does not depend on
    # the caller having created it before the first epoch finishes.
    sample_dir = "generated_images"
    os.makedirs(sample_dir, exist_ok=True)

    for epoch in range(num_epochs):
        pipeline.train()
        total_loss = 0.0
        num_batches = 0
        for batch_idx, (images, labels) in enumerate(dataloader):
            images = images.to(device)
            labels = labels.to(device)
            batch_size = images.shape[0]

            # Draw one random timestep per sample.
            timesteps = torch.randint(0, num_timesteps, (batch_size,), device=device).float()

            # Encode the images into latent space.
            with torch.no_grad():
                latents = pipeline.encoder.encode_to_latent(images)  # [B, 4, H, W]

            # Blend the latents with noise (simplified linear schedule).
            noise = torch.randn_like(latents)
            t = (timesteps / num_timesteps).view(-1, 1, 1, 1)
            noisy_latents = (1 - t) * latents + t * noise

            # Predict the noise and take an optimisation step.
            optimizer.zero_grad()
            predicted_noise = pipeline(noisy_latents, timesteps, labels)  # [B, 4, H, W]
            loss = criterion(predicted_noise, noise)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
            if batch_idx % 100 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(dataloader)}], Loss: {loss.item():.4f}")

        # max(..., 1) avoids a ZeroDivisionError when the dataloader is empty.
        avg_loss = total_loss / max(num_batches, 1)
        print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")

        torch.save(pipeline.state_dict(), f"latent_diffit_epoch_{epoch+1}.pth")

        pipeline.eval()
        with torch.no_grad():
            num_samples = 8
            generated_images = pipeline.sample(num_samples=num_samples, timesteps=num_timesteps, device=device)
            # Map from [-1, 1] to [0, 1] before writing the image grid.
            generated_images = (generated_images + 1) / 2
            vutils.save_image(generated_images, f"{sample_dir}/epoch_{epoch+1}.png", nrow=4)

# main

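Pipeline: the `LatentDiffiTPipeline` class below chains the patch embedding, the transformer, the unpatchify step and the decoder.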


In [18]:
import torch
import torch.nn as nn

class LatentDiffiTPipeline(nn.Module):
    """
    End-to-end latent DiffiT model: patch embedding -> transformer ->
    unpatchify -> convolutional decoder, predicting the noise in VAE latents.

    Args:
        vae: Autoencoder used to encode images and decode latents. Either a
            wrapper exposing ``decode_from_latent`` or a diffusers-style
            autoencoder with ``decode(latents).sample``.
        img_size: Side length of the images in pixels; latents have side
            ``img_size // 8``.
        patch_size: Side length of one patch, measured in latent pixels.
        in_channels: Stored and forwarded to the encoder.
        hidden_dim: Size of the token embeddings.
        depth: Number of transformer blocks.
        heads: Number of attention heads per block.
        dim_head: Size of each attention head.
        mlp_dim: Hidden width of the feed-forward networks; ``hidden_dim * 4``
            when not given.
        dropout: Dropout probability used in the transformer.
        time_embed_dim: Passed to the transformer.
        label_embed_dim: Passed to the transformer.
        num_classes: Number of class labels.
    """

    def __init__(
        self,
        vae,
        img_size=256,
        patch_size=16,
        in_channels=3,
        hidden_dim=1152,
        depth=30,
        heads=16,
        dim_head=64,
        mlp_dim=None,
        dropout=0.0,
        time_embed_dim=None,
        label_embed_dim=None,
        num_classes=1000
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes

        if mlp_dim is None:
            mlp_dim = hidden_dim * 4

        # Build the components.
        self.encoder = Encoder(
            vae=vae,
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            hidden_dim=hidden_dim
        )
        self.transformer = LatentDiffiTTransformer(
            dim=hidden_dim,
            depth=depth,
            heads=heads,
            dim_head=dim_head,
            mlp_dim=mlp_dim,
            dropout=dropout,
            time_embed_dim=time_embed_dim,
            label_embed_dim=label_embed_dim,
            num_classes=num_classes
        )
        self.unpatchify = Unpatchify(patch_size=patch_size, hidden_dim=self.hidden_dim)
        self.decoder = Decoder(
            in_channels=hidden_dim,
            hidden_dim=hidden_dim // 2,
            out_channels=4  # Matches the channel count of the latent space
        )
        self.vae = vae

    def forward(self, noisy_latents, timesteps, labels=None):
        """
        Predict the noise contained in ``noisy_latents``.

        Args:
            noisy_latents: Latents of shape [B, 4, H, W]. The tensor is read
                only; it is never modified in place.
            timesteps: Timestep per sample, shape [B].
            labels: Optional class label per sample, shape [B].

        Returns:
            Predicted noise of shape [B, 4, H, W].
        """
        # The input is already in latent space, so it is patch-embedded directly
        # instead of being run through the VAE again.
        embedded = self.encoder.patch_embedding(noisy_latents)  # [B, hidden_dim, H/patch, W/patch]
        embedded = rearrange(embedded, 'b c h w -> b (h w) c') + self.encoder.position_embedding
        transformer_output = self.transformer(embedded, timesteps, labels)
        unpatched = self.unpatchify(transformer_output)  # [B, hidden_dim, H, W]
        predicted_noise = self.decoder(unpatched)        # [B, 4, H, W]
        return predicted_noise

    def _decode_latents(self, latents):
        """Decode latents [B, 4, H, W] to images with whichever API the VAE offers."""
        if hasattr(self.vae, "decode_from_latent"):
            # Wrapper path; the channels-last permutation is kept as is.
            return self.vae.decode_from_latent(latents.permute(0, 2, 3, 1))
        # diffusers autoencoder path: undo the latent scaling, then decode
        # channels-first latents.
        scaling_factor = getattr(getattr(self.vae, "config", None), "scaling_factor", 1.0)
        return self.vae.decode(latents / scaling_factor).sample

    def sample(self, num_samples, timesteps, device, labels=None):
        """
        Generate ``num_samples`` images starting from Gaussian noise.

        Args:
            num_samples: Number of images to generate.
            timesteps: Number of denoising steps T; steps run from T-1 to 0.
            device: Device on which the latents are created.
            labels: Optional class labels, shape [num_samples]; random labels
                are drawn when omitted.

        Returns:
            The decoded images as returned by the VAE.
        """
        latent_size = self.img_size // 8

        # Sampling never needs gradients. Without no_grad the loop would keep
        # the autograd graph of every step alive.
        with torch.no_grad():
            latents = torch.randn(num_samples, 4, latent_size, latent_size, device=device)
            timesteps_tensor = torch.arange(timesteps - 1, -1, -1, device=device).float()
            if labels is None:
                labels = torch.randint(0, self.num_classes, (num_samples,), device=device)

            for t in timesteps_tensor:
                t_batch = t.repeat(num_samples).float()
                predicted_noise = self(latents, t_batch, labels)
                # Simplified update, not a full DDPM schedule.
                # NOTE(review): this replaces the latents by the clean-latent
                # estimate at every step instead of re-noising to step t-1;
                # confirm against the intended sampler.
                alpha = 1 - t / timesteps
                latents = (latents - (1 - alpha) * predicted_noise) / alpha

            images = self._decode_latents(latents)
        return images

In [19]:
# Mount Google Drive at /content/drive (requires the google.colab package).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [24]:
import os

def main():
    """
    Mount Google Drive, build the dataloader and the pipeline, and train.

    Returns early, after printing a diagnostic, when the dataset directory
    does not exist.
    """
    # Mount Google Drive and make sure the sample output folder exists.
    drive.mount('/content/drive', force_remount=True)
    os.makedirs("generated_images", exist_ok=True)

    # Locate the dataset.
    data_dir = "/content/drive/MyDrive/DiffiT_latent_space/image-net-256/archive"
    data_dir_missing = not os.path.exists(data_dir)
    if data_dir_missing:
        print(f"Error: Path {data_dir} does not exist!")
        # List the siblings of the missing folder to help debugging.
        parent_dir = os.path.dirname(data_dir)
        if os.path.exists(parent_dir):
            print(f"Parent directory {parent_dir} exists. Subdirectories are:")
            print(os.listdir(parent_dir))
        return

    # Load the data.
    batch_size = 32
    dataloader, dataset = get_imagenet_dataloader(data_dir, batch_size=batch_size)
    num_classes = len(dataset.classes)

    # Report the classes that were found.
    print(f"Number of classes: {num_classes}")
    print(f"Labels: {dataset.classes}")

    # Build the model and move it to the selected device.
    model_config = dict(
        img_size=256,
        patch_size=32,
        in_channels=3,
        hidden_dim=768,
        depth=12,
        heads=8,
        dim_head=64,
        num_classes=num_classes,
    )
    pipeline = LatentDiffiTPipeline(vae=vae, **model_config).to(device)

    # Train.
    training_config = dict(num_epochs=10, num_timesteps=1000, learning_rate=0.00003)
    train_diffit(
        pipeline,
        dataloader,
        training_config["num_epochs"],
        training_config["num_timesteps"],
        device,
        training_config["learning_rate"],
    )

In [25]:
# Entry point: run the complete data-loading and training workflow.
if __name__ == "__main__":
    main()

Mounted at /content/drive
Preprocessing directory: /content/drive/MyDrive/DiffiT_latent_space/image-net-256/archive


KeyboardInterrupt: 