In [None]:
# Mount Google Drive so the dataset under "My Drive/Project/chest_xray/train"
# is reachable at /content/drive/My Drive/... (see `dataset_path` below).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#/content/drive/My Drive/Project/chest_xray/train

In [None]:
# Install required libraries
!pip install diffusers transformers accelerate
!pip install --upgrade diffusers  # Ensure the latest version
!pip install huggingface_hub

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
import numpy as np
from PIL import Image
import os

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

In [None]:
# Reproducibility: seed the global PyTorch and NumPy RNGs before anything random runs.
torch.manual_seed(0)
np.random.seed(0)

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

# Root folder of the training images (searched recursively by the dataset class).
dataset_path = '/content/drive/My Drive/Project/chest_xray/train'

# Preprocessing pipeline:
#   resize to 256x256 -> tensor in [0, 1] -> (x - 0.5) / 0.5, i.e. values in [-1, 1].
# The single-element mean/std broadcasts across all three RGB channels.
transform = transforms.Compose([
    transforms.Resize(size=(256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

Using device: cuda


In [None]:
# Custom Dataset
from torch.utils.data import Dataset
import glob

class ChestXRayDataset(Dataset):
    """Unlabelled image dataset over every PNG/JPEG file found under a folder tree.

    Args:
        image_folder: root directory; searched recursively (hidden files are skipped,
            as with ``glob``).
        transform: optional callable applied to each loaded PIL image.

    Each item is the (optionally transformed) RGB image only -- no label is returned.
    """

    # Compared case-insensitively, so e.g. "IM-0001.JPEG" or "scan.PNG" are not skipped.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, image_folder, transform=None):
        # Escape the folder so characters such as '[' in the path are not treated as glob syntax.
        pattern = os.path.join(glob.escape(image_folder), '**', '*')
        candidates = glob.glob(pattern, recursive=True)
        # Sorted so the index -> file mapping is identical across runs and filesystems
        # (glob's own order depends on the filesystem).
        self.image_paths = sorted(
            path for path in candidates
            if os.path.isfile(path) and path.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.transform = transform

        print(f"Found {len(self.image_paths)} images in {image_folder}")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # The context manager releases the file handle right away; convert('RGB') returns a
        # fully loaded copy (grayscale X-rays are replicated to three channels).
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

dataset = ChestXRayDataset(dataset_path, transform=transform)

# Fail fast with a readable message instead of letting the DataLoader iterate nothing.
if len(dataset) == 0:
    raise ValueError(f"No images found in {dataset_path}. Please check the path and file extensions.")

data_loader = DataLoader(dataset, batch_size=2, shuffle=True)

# Load pre-trained models from Stable Diffusion
model_name = "stabilityai/stable-diffusion-2-1-base"

# Every pre-trained network is put in eval mode AND frozen right after loading.
# Doing it here, instead of relying only on the separate freeze cell, means that cell
# being skipped or run out of order (it has failed with a NameError before) can no longer
# leave the UNet/VAE/text-encoder requiring grad. Otherwise backward() would allocate
# gradients for all of their parameters.
pretrained_unet = UNet2DConditionModel.from_pretrained(
    model_name, subfolder="unet"
).to(device)
pretrained_unet.eval()
pretrained_unet.requires_grad_(False)

vae = AutoencoderKL.from_pretrained(
    model_name, subfolder="vae"
).to(device)
vae.eval()
vae.requires_grad_(False)

tokenizer = CLIPTokenizer.from_pretrained(
    model_name, subfolder="tokenizer"
)

text_encoder = CLIPTextModel.from_pretrained(
    model_name, subfolder="text_encoder"
).to(device)
text_encoder.eval()
text_encoder.requires_grad_(False)


In [None]:
# Freeze every pre-trained component: only the adaptation layers and the
# discriminator are handed to optimizers, so nothing here needs gradients.
for frozen_module in (pretrained_unet, vae, text_encoder):
    frozen_module.requires_grad_(False)

NameError: name 'pretrained_unet' is not defined

In [None]:
# Adaptation layer definition with residual connections
class AdaptationLayer(nn.Module):
    """Residual adapter computing ``x + ReLU(GroupNorm(Conv(x)))``; output shape == input shape.

    Args:
        channels: number of input (and output) channels.
        num_groups: requested GroupNorm group count. ``nn.GroupNorm`` requires it to divide
            ``channels``. When it does not, the largest divisor of ``channels`` that does
            not exceed it is used (the old hard-coded 32 raised for e.g. 48 channels).
        kernel_size: odd convolution kernel size. Padding of ``kernel_size // 2`` keeps
            the spatial size unchanged.

    NOTE(review): the residual branch is non-negative (ReLU), so the adapter can only
    shift features upward, never below the input. Confirm this is intended.
    """

    def __init__(self, channels, num_groups=32, kernel_size=3):
        super(AdaptationLayer, self).__init__()
        groups = max(1, min(num_groups, channels))
        while channels % groups:
            groups -= 1
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=kernel_size, padding=kernel_size // 2)
        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=channels)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        residual = x
        out = self.relu(self.norm1(self.conv1(x)))
        return out + residual

In [None]:
# Function to get the number of channels from the model
def get_layer_output_channels(unet=None, latent_size=64):
    """Run one dummy forward pass and record the output shape of selected UNet blocks.

    Temporary forward hooks are attached to ``down_blocks[0..2]``, ``mid_block`` and
    ``up_blocks[0..2]``. They are always removed, even if the forward pass raises.

    Args:
        unet: model to probe; defaults to the global ``pretrained_unet``.
        latent_size: height/width of the dummy latent input.

    Returns:
        ``(output_channels, output_sizes)``: dicts keyed by block name ("down_block_0",
        ..., "mid_block", "up_block_0", ...), in forward-execution order. They hold
        ``out.shape[1]`` and ``out.shape[2:]`` of each block's (first) output.
    """
    if unet is None:
        unet = pretrained_unet
    # Put the dummy inputs wherever the model's weights live.
    probe_device = next(unet.parameters()).device
    config = unet.config

    # Sampled on the CPU and then moved, so the global CPU RNG stream is consumed exactly as before.
    dummy_input = torch.randn(1, config.in_channels, latent_size, latent_size).to(probe_device)
    dummy_timestep = torch.tensor([0], dtype=torch.long).to(probe_device)
    dummy_encoder_hidden_states = torch.randn(1, 77, config.cross_attention_dim).to(probe_device)

    output_channels = {}
    output_sizes = {}

    def hook_fn(name):
        def hook(module, inputs, output):
            # Some blocks return (hidden_states, extras); the first element is the feature map.
            out = output[0] if isinstance(output, tuple) else output
            output_channels[name] = out.shape[1]
            output_sizes[name] = out.shape[2:]
        return hook

    probed_blocks = [
        ('down_block_0', unet.down_blocks[0]),
        ('down_block_1', unet.down_blocks[1]),
        ('down_block_2', unet.down_blocks[2]),
        ('mid_block', unet.mid_block),
        ('up_block_2', unet.up_blocks[2]),
        ('up_block_1', unet.up_blocks[1]),
        ('up_block_0', unet.up_blocks[0]),
    ]
    hooks = [block.register_forward_hook(hook_fn(name)) for name, block in probed_blocks]
    try:
        with torch.no_grad():
            unet(dummy_input, dummy_timestep, dummy_encoder_hidden_states)
    finally:
        # A leaked probe would fire on every later forward pass of the shared model.
        for hook in hooks:
            hook.remove()

    return output_channels, output_sizes

In [None]:
!pip install git+https://github.com/huggingface/transformers.git

from transformers import AutoProcessor, Blip2ForConditionalGeneration
import torch

processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

import requests
from PIL import Image

url = 'https://media.newyorker.com/cartoons/63dc6847be24a6a76d90eb99/master/w_1160,c_limit/230213_a26611_838.jpg'
image = Image.open(requests.get(url, stream=True).raw).convert('RGB')
display(image.resize((596, 437)))


inputs = processor(image, return_tensors="pt").to(device, torch.float16)
generated_ids = model.generate(**inputs, max_new_tokens=20)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)


print("Pixel values shape:", inputs.get("pixel_values", None).shape)
print("Input IDs shape:", inputs.get("input_ids", None).shape)
print(generated_text)




In [None]:
# Probe the frozen UNet once to learn each hooked block's channel count and spatial size.
output_channels, output_sizes = get_layer_output_channels()
print("Output channels of each block:", output_channels)
print("Output sizes of each block:", output_sizes)

# One residual adaptation layer per probed block, sized to that block's channels
# (built in the same order as the probe results).
adaptation_layers = nn.ModuleDict({
    block_name: AdaptationLayer(num_channels).to(device)
    for block_name, num_channels in output_channels.items()
})

# Feature map judged by the discriminator. With 64x64 latents 'down_block_1' is 16x16
# (see the printed sizes); 'down_block_0' would give 32x32.
discriminator_input_layer = 'down_block_1'


Output channels of each block: {'down_block_0': 320, 'down_block_1': 640, 'down_block_2': 1280, 'mid_block': 1280, 'up_block_0': 1280, 'up_block_1': 1280, 'up_block_2': 640}
Output sizes of each block: {'down_block_0': torch.Size([32, 32]), 'down_block_1': torch.Size([16, 16]), 'down_block_2': torch.Size([8, 8]), 'mid_block': torch.Size([8, 8]), 'up_block_0': torch.Size([16, 16]), 'up_block_1': torch.Size([32, 32]), 'up_block_2': torch.Size([64, 64])}


In [None]:
# Discriminator network with spectral normalization
class Discriminator(nn.Module):
    """Feature-map discriminator producing one raw logit per sample.

    The network has four spectrally-normalised 3x3 stride-2 convolutions
    (64 -> 128 -> 256 -> 512 channels), each followed by LeakyReLU(0.2). Then come
    global average pooling and a plain 1x1 convolution down to a single channel.

    Args:
        input_channels: channel count of the feature map being judged.

    ``forward`` maps ``(N, input_channels, H, W)`` to ``(N, 1)`` logits.
    """

    def __init__(self, input_channels):
        super(Discriminator, self).__init__()
        widths = [input_channels, 64, 128, 256, 512]
        layers = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            layers.append(spectral_norm(nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1)))
            layers.append(nn.LeakyReLU(0.2, inplace=False))
        layers.append(nn.AdaptiveAvgPool2d(1))
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=1, stride=1, padding=0))
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.main(x)
        return logits.view(logits.size(0), -1)

# Discriminator sized to the channel count of the feature map selected above.
discriminator = Discriminator(output_channels[discriminator_input_layer]).to(device)


In [None]:
# Loss functions and optimizers:
# logit-space BCE for the adversarial terms, and separate Adam optimizers for the
# adaptation layers (generator side) and the discriminator.
bce_loss = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(params=adaptation_layers.parameters(), lr=1e-4)
# NOTE(review): the discriminator LR is 100x smaller than the adapters'. The logged
# adversarial loss sitting near 0.693 (= ln 2) suggests it barely learns; worth tuning.
disc_optimizer = optim.Adam(params=discriminator.parameters(), lr=1e-6)


In [None]:
# Scheduler: DDPM noise schedule loaded from the same checkpoint's "scheduler" config.
scheduler = DDPMScheduler.from_pretrained(model_name, subfolder="scheduler")


scheduler/scheduler_config.json:   0%|          | 0.00/346 [00:00<?, ?B/s]

In [None]:
# Function to encode text prompts
def encode_prompts(prompts):
    """Embed text prompts with the frozen CLIP text encoder.

    Args:
        prompts: list of strings, one per sample.

    Returns:
        The first element of the text encoder's output (its last hidden state) for
        token ids padded/truncated to ``tokenizer.model_max_length``. It is computed
        without gradient tracking.
    """
    token_ids = tokenizer(
        prompts,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids.to(device)
    with torch.no_grad():
        return text_encoder(token_ids)[0]

In [None]:
# Modified hook function
def create_hook(name):
    """Build a forward hook that routes a UNet block's output through ``adaptation_layers[name]``.

    For tuple outputs only element 0 is adapted; the remaining elements pass through
    unchanged. As a side effect the hook records, under ``name``:
    - the detached pre-adaptation tensor in the module-level ``original_features``;
    - the adapted tensor (still in the autograd graph) in ``adapted_features``.
    Both dicts are looked up when the hook fires, so the training loop can rebind them per batch.
    """
    def hook(module, input, output):
        is_tuple = isinstance(output, tuple)
        features = output[0] if is_tuple else output
        original_features[name] = features.detach()
        adapted = adaptation_layers[name](features)
        adapted_features[name] = adapted
        if is_tuple:
            return (adapted,) + output[1:]
        return adapted
    return hook

In [None]:
# Training loop: only the adaptation layers and the discriminator receive updates.
epochs = 10  # Adjust as needed
adv_loss_weight = 0.1
cons_loss_weight = 0.0001

# UNet blocks routed through the adaptation layers (names must match `adaptation_layers` keys).
hooked_blocks = [
    ('down_block_0', pretrained_unet.down_blocks[0]),
    ('down_block_1', pretrained_unet.down_blocks[1]),
    ('down_block_2', pretrained_unet.down_blocks[2]),
    ('mid_block', pretrained_unet.mid_block),
    ('up_block_2', pretrained_unet.up_blocks[2]),
    ('up_block_1', pretrained_unet.up_blocks[1]),
    ('up_block_0', pretrained_unet.up_blocks[0]),
]
# VAE latent scaling factor (previously the hard-coded 0.18215), read from the checkpoint config.
latent_scale = getattr(vae.config, 'scaling_factor', 0.18215)
# Regression target of the checkpoint. 'epsilon' (predict the added noise) is what this loop
# always assumed; 'v_prediction' checkpoints need the velocity target instead.
prediction_type = getattr(scheduler.config, 'prediction_type', 'epsilon')

for epoch in range(epochs):
    for batch_idx, images in enumerate(data_loader):
        images = images.to(device)
        batch_size = images.size(0)

        # Encode images to latent space (frozen VAE, no gradients needed)
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample() * latent_scale

        # Forward diffusion: noise the latents at uniformly sampled timesteps
        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_size,), device=device).long()
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        # Every sample shares the same text conditioning
        encoder_hidden_states = encode_prompts(["a chest X-ray image"] * batch_size)

        # Filled by the hooks built by `create_hook` (looked up as module-level names)
        adapted_features = {}
        original_features = {}

        hooks = [block.register_forward_hook(create_hook(name)) for name, block in hooked_blocks]
        try:
            model_pred = pretrained_unet(noisy_latents, timesteps, encoder_hidden_states).sample
        finally:
            # Always detach the hooks. If the forward pass fails (e.g. CUDA OOM), leaked hooks
            # would stay on the shared UNet and apply the adaptation twice on the next call.
            for hook in hooks:
                hook.remove()

        # Diffusion loss
        if prediction_type == 'v_prediction':
            target = scheduler.get_velocity(latents, noise, timesteps)
        else:
            target = noise
        diffusion_loss = F.mse_loss(model_pred, target)

        # Consistency loss: keep adapted features close to the frozen model's features.
        # Start from a tensor so `.item()` in the logging below works even if no hook fired.
        cons_loss = torch.zeros((), device=device)
        for key, adapted_feat in adapted_features.items():
            original_feat = original_features[key]

            # Adjust channels if needed
            if adapted_feat.shape[1] != original_feat.shape[1]:
                min_channels = min(adapted_feat.shape[1], original_feat.shape[1])
                adapted_feat = adapted_feat[:, :min_channels, :, :]
                original_feat = original_feat[:, :min_channels, :, :]

            cons_loss = cons_loss + F.mse_loss(adapted_feat, original_feat)

        # Adversarial loss: original (frozen) features are "real", adapted ones are "fake"
        disc_input_real = original_features[discriminator_input_layer]
        disc_input_fake = adapted_features[discriminator_input_layer]

        disc_optimizer.zero_grad()
        # Discriminator loss on real data
        disc_real = discriminator(disc_input_real.detach())
        real_labels = torch.ones_like(disc_real, device=device)
        disc_loss_real = bce_loss(disc_real, real_labels)

        # Discriminator loss on fake data (detached: this step must not update the adapters)
        disc_fake = discriminator(disc_input_fake.detach())
        fake_labels = torch.zeros_like(disc_fake, device=device)
        disc_loss_fake = bce_loss(disc_fake, fake_labels)

        disc_loss = (disc_loss_real + disc_loss_fake) / 2
        disc_loss.backward()
        disc_optimizer.step()

        # Generator adversarial loss. Gradients that also land on the discriminator here are
        # cleared by `disc_optimizer.zero_grad()` before its next step.
        adv_output = discriminator(disc_input_fake)
        adv_loss = bce_loss(adv_output, real_labels)

        # Total loss
        total_loss = diffusion_loss + cons_loss_weight * cons_loss + adv_loss_weight * adv_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Logging
        if batch_idx % 10 == 0:
            print(f"Epoch [{epoch+1}/{epochs}], Batch [{batch_idx}], Total Loss: {total_loss.item():.4f}, "
                  f"Diffusion Loss: {diffusion_loss.item():.4f}, Consistency Loss: {cons_loss.item():.4f}, "
                  f"Adversarial Loss: {adv_loss.item():.4f}")

# Save the trained adaptation layers
os.makedirs('saved_models', exist_ok=True)
torch.save(adaptation_layers.state_dict(), 'saved_models/adaptation_layers.pth')
torch.save(discriminator.state_dict(), 'saved_models/discriminator.pth')



[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch [31/50], Batch [2200], Total Loss: 0.1014, Diffusion Loss: 0.0309, Consistency Loss: 12.0525, Adversarial Loss: 0.6929
Epoch [31/50], Batch [2210], Total Loss: 0.2719, Diffusion Loss: 0.2014, Consistency Loss: 13.5040, Adversarial Loss: 0.6921
Epoch [31/50], Batch [2220], Total Loss: 0.2620, Diffusion Loss: 0.1914, Consistency Loss: 13.2393, Adversarial Loss: 0.6927
Epoch [31/50], Batch [2230], Total Loss: 0.1708, Diffusion Loss: 0.0998, Consistency Loss: 16.1581, Adversarial Loss: 0.6941
Epoch [31/50], Batch [2240], Total Loss: 0.3831, Diffusion Loss: 0.3119, Consistency Loss: 17.0882, Adversarial Loss: 0.6947
Epoch [31/50], Batch [2250], Total Loss: 0.1099, Diffusion Loss: 0.0393, Consistency Loss: 13.5381, Adversarial Loss: 0.6927
Epoch [31/50], Batch [2260], Total Loss: 0.1257, Diffusion Loss: 0.0549, Consistency Loss: 14.3212, Adversarial Loss: 0.6934
Epoch [31/50], Batch [2270], Total Loss: 0.1997, Diffusion L

In [None]:
# Generation Function
def _make_adaptation_hook(name):
    """Return a forward hook that replaces a block's (first) output with ``adaptation_layers[name](output)``."""
    def hook(module, inputs, output):
        adaptation_layer = adaptation_layers[name]
        if isinstance(output, tuple):
            # Adapt the hidden states only; extra tuple elements pass through unchanged.
            return (adaptation_layer(output[0]),) + output[1:]
        return adaptation_layer(output)
    return hook


def generate_images(prompt, num_images=2, guidance_scale=7.5, batch_size=4,
                    num_inference_steps=50, latent_size=64):
    """Sample images from the frozen SD UNet with the trained adaptation layers applied.

    Uses classifier-free guidance with an empty prompt as the unconditional branch.

    Args:
        prompt: text prompt shared by every generated image.
        num_images: total number of images to return.
        guidance_scale: classifier-free guidance weight.
        batch_size: images denoised and decoded per chunk. Chunking bounds peak GPU memory
            (all 50 images in one batch ran out of CUDA memory). ``None`` processes every
            image in a single batch.
        num_inference_steps: scheduler steps per image (was hard-coded to 50).
        latent_size: latent height/width (was hard-coded to 64).

    Returns:
        A list of ``num_images`` RGB ``PIL.Image`` objects.
    """
    pretrained_unet.eval()
    vae.eval()
    text_encoder.eval()
    adaptation_layers.eval()

    if num_images <= 0:
        return []
    chunk_size = num_images if batch_size is None else max(1, min(batch_size, num_images))
    latent_scale = getattr(vae.config, 'scaling_factor', 0.18215)

    adapted_blocks = [
        ('down_block_0', pretrained_unet.down_blocks[0]),
        ('down_block_1', pretrained_unet.down_blocks[1]),
        ('down_block_2', pretrained_unet.down_blocks[2]),
        ('mid_block', pretrained_unet.mid_block),
        ('up_block_2', pretrained_unet.up_blocks[2]),
        ('up_block_1', pretrained_unet.up_blocks[1]),
        ('up_block_0', pretrained_unet.up_blocks[0]),
    ]
    # Register the hooks once per call (not once per timestep) and always remove them,
    # so a failure such as OOM cannot leave adaptation hooks on the shared UNet.
    hooks = [block.register_forward_hook(_make_adaptation_hook(name))
             for name, block in adapted_blocks]

    pil_images = []
    try:
        with torch.no_grad():
            for start in range(0, num_images, chunk_size):
                n = min(chunk_size, num_images - start)

                # Unconditional embeddings first, matching the chunk(2) split below.
                encoder_hidden_states_full = torch.cat([
                    encode_prompts([""] * n),
                    encode_prompts([prompt] * n),
                ])

                latents = torch.randn((n, pretrained_unet.config.in_channels, latent_size, latent_size)).to(device)

                scheduler.set_timesteps(num_inference_steps)
                for t in scheduler.timesteps:
                    # One latent copy per guidance branch.
                    latent_model_input = scheduler.scale_model_input(torch.cat([latents] * 2), t)
                    noise_pred = pretrained_unet(latent_model_input, t, encoder_hidden_states_full).sample
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
                    latents = scheduler.step(noise_pred, t, latents).prev_sample

                # Decode and map [-1, 1] -> [0, 255] uint8 (rounded, not truncated).
                decoded = vae.decode(latents / latent_scale).sample
                decoded = (decoded / 2 + 0.5).clamp(0, 1)
                arrays = decoded.cpu().permute(0, 2, 3, 1).numpy()
                pil_images.extend(
                    Image.fromarray((arr * 255).round().astype(np.uint8)) for arr in arrays
                )
    finally:
        for hook in hooks:
            hook.remove()

    return pil_images

# Generate images and save each one exactly once.
# (The save loop used to appear twice, rewriting every file a second time.)
generated_images = generate_images("a chest X-ray image", num_images=50)

output_dir = 'generated_images'
os.makedirs(output_dir, exist_ok=True)
for idx, img in enumerate(generated_images):
    out_path = f'{output_dir}/generated_image_{idx}.png'
    img.save(out_path)
    print(f'Saved {out_path}')

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.50 GiB. GPU 0 has a total capacity of 14.75 GiB of which 1.58 GiB is free. Process 2941 has 13.17 GiB memory in use. Of the allocated memory 9.69 GiB is allocated by PyTorch, and 3.34 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
#/content/drive/My Drive/Project/chest_xray/train

In [None]:
# Mount Google Drive again (repeat of the first cell; presumably for a fresh runtime).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Install required libraries
!pip install diffusers transformers accelerate
!pip install --upgrade diffusers  # Ensure the latest version
!pip install huggingface_hub

Collecting diffusers
  Downloading diffusers-0.31.0-py3-none-any.whl.metadata (18 kB)
Downloading diffusers-0.31.0-py3-none-any.whl (2.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.9/2.9 MB[0m [31m75.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: diffusers
  Attempting uninstall: diffusers
    Found existing installation: diffusers 0.30.3
    Uninstalling diffusers-0.30.3:
      Successfully uninstalled diffusers-0.30.3
Successfully installed diffusers-0.31.0


In [None]:
# -----------------------------------------------
# Adaptive Diffusion Network with Cross-Domain Consistency Regularization (ADN-CDCR)
# Implementation in PyTorch
# Using the Food-101 Dataset
# -----------------------------------------------

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
import numpy as np
from PIL import Image
import os

# Reproducibility: seed the global PyTorch and NumPy RNGs.
torch.manual_seed(0)
np.random.seed(0)

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

# -----------------------------------------------
# 1. Dataset Preparation
# -----------------------------------------------

# Preprocessing: resize to 256x256 -> tensor in [0, 1] -> per-channel (x - 0.5) / 0.5, i.e. [-1, 1].
transform = transforms.Compose([
    transforms.Resize(size=(256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Food-101 training split; downloaded on first use (~5GB of disk) under `root`.
dataset = datasets.Food101(root='.', split='train', transform=transform, download=True)
print(f"Total images in the Food-101 dataset: {len(dataset)}")

# Batches of (images, labels); worker processes and pinned memory speed up host->GPU copies.
data_loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)

# -----------------------------------------------
# 2. Model Loading
# -----------------------------------------------

# Pre-trained Stable Diffusion components, loaded in the original order.
model_name = "stabilityai/stable-diffusion-2-1-base"

pretrained_unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet").to(device)
vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae").to(device)
tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder").to(device)

# Inference mode and no gradients for every pre-trained network; only the
# adaptation layers and the discriminator are optimized.
for frozen_module in (pretrained_unet, vae, text_encoder):
    frozen_module.eval()
    frozen_module.requires_grad_(False)

# -----------------------------------------------
# 3. Adaptation Layers and Discriminator
# -----------------------------------------------

# Adaptation layer definition with residual connections
class AdaptationLayer(nn.Module):
    """Residual adapter computing ``x + ReLU(GroupNorm(Conv(x)))``; output shape == input shape.

    Args:
        channels: number of input (and output) channels.
        num_groups: requested GroupNorm group count. ``nn.GroupNorm`` requires it to divide
            ``channels``. When it does not, the largest divisor of ``channels`` that does
            not exceed it is used (the old hard-coded 32 raised for e.g. 48 channels).
        kernel_size: odd convolution kernel size (default 1, i.e. a per-pixel channel
            mix). Padding of ``kernel_size // 2`` keeps the spatial size unchanged.

    NOTE(review): the residual branch is non-negative (ReLU), so the adapter can only
    shift features upward, never below the input. Confirm this is intended.
    """

    def __init__(self, channels, num_groups=32, kernel_size=1):
        super(AdaptationLayer, self).__init__()
        groups = max(1, min(num_groups, channels))
        while channels % groups:
            groups -= 1
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=kernel_size, padding=kernel_size // 2)
        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=channels)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        residual = x
        out = self.relu(self.norm1(self.conv1(x)))
        return out + residual

# Function to get the number of channels from the model
def get_layer_output_channels(unet=None, latent_size=64):
    """Run one dummy forward pass and record the output shape of every UNet block.

    Temporary forward hooks are attached to all ``down_blocks``, the ``mid_block`` and
    all ``up_blocks``. They are always removed, even if the forward pass raises.

    Args:
        unet: model to probe; defaults to the global ``pretrained_unet``.
        latent_size: height/width of the dummy latent input.

    Returns:
        ``(output_channels, output_sizes)``: dicts keyed by "down_block_{i}",
        "mid_block" and "up_block_{i}", in forward-execution order. They hold
        ``out.shape[1]`` and ``out.shape[2:]`` of each block's (first) output.
    """
    if unet is None:
        unet = pretrained_unet
    # Put the dummy inputs wherever the model's weights live.
    probe_device = next(unet.parameters()).device
    config = unet.config

    # Sampled on the CPU and then moved, so the global CPU RNG stream is consumed exactly as before.
    dummy_input = torch.randn(1, config.in_channels, latent_size, latent_size).to(probe_device)
    dummy_timestep = torch.tensor([0], dtype=torch.long).to(probe_device)
    dummy_encoder_hidden_states = torch.randn(1, 77, config.cross_attention_dim).to(probe_device)

    output_channels = {}
    output_sizes = {}

    def hook_fn(name):
        def hook(module, inputs, output):
            # Some blocks return (hidden_states, extras); the first element is the feature map.
            out = output[0] if isinstance(output, tuple) else output
            output_channels[name] = out.shape[1]
            output_sizes[name] = out.shape[2:]
        return hook

    probed_blocks = [(f'down_block_{idx}', block) for idx, block in enumerate(unet.down_blocks)]
    probed_blocks.append(('mid_block', unet.mid_block))
    probed_blocks += [(f'up_block_{idx}', block) for idx, block in enumerate(unet.up_blocks)]

    hooks = [block.register_forward_hook(hook_fn(name)) for name, block in probed_blocks]
    try:
        with torch.no_grad():
            unet(dummy_input, dummy_timestep, dummy_encoder_hidden_states)
    finally:
        # A leaked probe would fire on every later forward pass of the shared model.
        for hook in hooks:
            hook.remove()

    return output_channels, output_sizes

# Probe the frozen UNet once to learn every block's channel count and spatial size.
output_channels, output_sizes = get_layer_output_channels()
print("Output channels of each block:", output_channels)
print("Output sizes of each block:", output_sizes)

# One residual adaptation layer per probed block, sized to that block's channels
# (built in the same order as the probe results).
adaptation_layers = nn.ModuleDict({
    block_name: AdaptationLayer(num_channels).to(device)
    for block_name, num_channels in output_channels.items()
})

# Feature map judged by the discriminator (a block with relatively large spatial size).
discriminator_input_layer = 'down_block_1'  # Adjust as needed

# Discriminator network with spectral normalization
class Discriminator(nn.Module):
    """Feature-map discriminator producing one raw logit per sample.

    The network has four spectrally-normalised 3x3 stride-2 convolutions
    (64 -> 128 -> 256 -> 512 channels), each followed by LeakyReLU(0.2). Then come
    global average pooling and a plain 1x1 convolution down to a single channel.

    Args:
        input_channels: channel count of the feature map being judged.

    ``forward`` maps ``(N, input_channels, H, W)`` to ``(N, 1)`` logits.
    """

    def __init__(self, input_channels):
        super(Discriminator, self).__init__()
        widths = [input_channels, 64, 128, 256, 512]
        layers = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            layers.append(spectral_norm(nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1)))
            layers.append(nn.LeakyReLU(0.2, inplace=False))
        layers.append(nn.AdaptiveAvgPool2d(1))
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=1, stride=1, padding=0))
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.main(x)
        return logits.view(logits.size(0), -1)

# Discriminator sized to the channel count of the chosen feature map.
discriminator = Discriminator(input_channels=output_channels[discriminator_input_layer]).to(device)

# -----------------------------------------------
# 4. Loss Functions and Optimizers
# -----------------------------------------------

# Logit-space BCE for the adversarial terms; separate Adam optimizers for the
# adaptation layers (generator side) and the discriminator.
bce_loss = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(params=adaptation_layers.parameters(), lr=1e-4)
# NOTE(review): discriminator LR is 100x smaller than the adapters'; consider tuning.
disc_optimizer = optim.Adam(params=discriminator.parameters(), lr=1e-6)

# DDPM noise schedule loaded from the same checkpoint's "scheduler" config.
scheduler = DDPMScheduler.from_pretrained(model_name, subfolder="scheduler")

# -----------------------------------------------
# 5. Helper Functions
# -----------------------------------------------

# Function to encode text prompts
def encode_prompts(prompts):
    """Embed text prompts with the frozen CLIP text encoder.

    Args:
        prompts: list of strings, one per sample.

    Returns:
        The first element of the text encoder's output (its last hidden state) for
        token ids padded/truncated to ``tokenizer.model_max_length``. It is computed
        without gradient tracking.
    """
    token_ids = tokenizer(
        prompts,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids.to(device)
    with torch.no_grad():
        return text_encoder(token_ids)[0]

# Hook function to apply adaptation layers
def create_hook(name):
    """Build a forward hook that routes a UNet block's output through ``adaptation_layers[name]``.

    For tuple outputs only element 0 is adapted; the remaining elements pass through
    unchanged. As a side effect the hook records, under ``name``:
    - the detached pre-adaptation tensor in the module-level ``original_features``;
    - the adapted tensor (still in the autograd graph) in ``adapted_features``.
    Both dicts are looked up when the hook fires, so the training loop can rebind them per batch.
    """
    def hook(module, input, output):
        is_tuple = isinstance(output, tuple)
        features = output[0] if is_tuple else output
        original_features[name] = features.detach()
        adapted = adaptation_layers[name](features)
        adapted_features[name] = adapted
        if is_tuple:
            return (adapted,) + output[1:]
        return adapted
    return hook

# -----------------------------------------------
# 6. Training Loop
# -----------------------------------------------

# Training loop: only the adaptation layers and the discriminator receive updates.
epochs = 10  # Adjust as needed
adv_loss_weight = 0.1
cons_loss_weight = 0.0001

# VAE latent scaling factor (previously the hard-coded 0.18215), read from the checkpoint config.
latent_scale = getattr(vae.config, 'scaling_factor', 0.18215)
# Regression target of the checkpoint. 'epsilon' (predict the added noise) is what this loop
# always assumed; 'v_prediction' checkpoints need the velocity target instead.
prediction_type = getattr(scheduler.config, 'prediction_type', 'epsilon')

for epoch in range(epochs):
    print(f"\nStarting Epoch {epoch+1}/{epochs}")
    for batch_idx, (images, labels) in enumerate(data_loader):
        images = images.to(device)
        batch_size = images.size(0)

        # Encode images to latent space (frozen VAE, no gradients needed)
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample() * latent_scale

        # Forward diffusion: noise the latents at uniformly sampled timesteps
        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_size,), device=device).long()
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        # Class-conditioned prompts; `.tolist()` turns the label tensor into plain int indices.
        class_names = [dataset.classes[label] for label in labels.tolist()]
        encoder_hidden_states = encode_prompts([f"a photo of {class_name}" for class_name in class_names])

        # Filled by the hooks built by `create_hook` (looked up as module-level names)
        adapted_features = {}
        original_features = {}

        # Register forward hooks on every down block, the mid block and every up block
        hooks = []
        for idx, block in enumerate(pretrained_unet.down_blocks):
            hooks.append(block.register_forward_hook(create_hook(f'down_block_{idx}')))
        hooks.append(pretrained_unet.mid_block.register_forward_hook(create_hook('mid_block')))
        for idx, block in enumerate(pretrained_unet.up_blocks):
            hooks.append(block.register_forward_hook(create_hook(f'up_block_{idx}')))

        try:
            model_pred = pretrained_unet(noisy_latents, timesteps, encoder_hidden_states).sample
        finally:
            # Always detach the hooks. If the forward pass fails (e.g. CUDA OOM), leaked hooks
            # would stay on the shared UNet and apply the adaptation twice on the next call.
            for hook in hooks:
                hook.remove()

        # Diffusion loss
        if prediction_type == 'v_prediction':
            target = scheduler.get_velocity(latents, noise, timesteps)
        else:
            target = noise
        diffusion_loss = F.mse_loss(model_pred, target)

        # Consistency loss: keep adapted features close to the frozen model's features.
        # Start from a tensor so `.item()` in the logging below works even if no hook fired.
        cons_loss = torch.zeros((), device=device)
        for key, adapted_feat in adapted_features.items():
            original_feat = original_features[key]

            # Adjust channels if needed
            if adapted_feat.shape[1] != original_feat.shape[1]:
                min_channels = min(adapted_feat.shape[1], original_feat.shape[1])
                adapted_feat = adapted_feat[:, :min_channels, :, :]
                original_feat = original_feat[:, :min_channels, :, :]

            cons_loss = cons_loss + F.mse_loss(adapted_feat, original_feat)

        # Adversarial loss: original (frozen) features are "real", adapted ones are "fake"
        disc_input_real = original_features[discriminator_input_layer]
        disc_input_fake = adapted_features[discriminator_input_layer]

        disc_optimizer.zero_grad()
        # Discriminator loss on real data
        disc_real = discriminator(disc_input_real.detach())
        real_labels = torch.ones_like(disc_real, device=device)
        disc_loss_real = bce_loss(disc_real, real_labels)

        # Discriminator loss on fake data (detached: this step must not update the adapters)
        disc_fake = discriminator(disc_input_fake.detach())
        fake_labels = torch.zeros_like(disc_fake, device=device)
        disc_loss_fake = bce_loss(disc_fake, fake_labels)

        disc_loss = (disc_loss_real + disc_loss_fake) / 2
        disc_loss.backward()
        disc_optimizer.step()

        # Generator adversarial loss. Gradients that also land on the discriminator here are
        # cleared by `disc_optimizer.zero_grad()` before its next step.
        adv_output = discriminator(disc_input_fake)
        adv_loss = bce_loss(adv_output, real_labels)

        # Total loss
        total_loss = diffusion_loss + cons_loss_weight * cons_loss + adv_loss_weight * adv_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Logging
        if batch_idx % 50 == 0:
            print(f"Epoch [{epoch+1}/{epochs}], Batch [{batch_idx}], Total Loss: {total_loss.item():.4f}, "
                  f"Diffusion Loss: {diffusion_loss.item():.4f}, Consistency Loss: {cons_loss.item():.4f}, "
                  f"Adversarial Loss: {adv_loss.item():.4f}")

# Save the trained adaptation layers
os.makedirs('saved_models', exist_ok=True)
torch.save(adaptation_layers.state_dict(), 'saved_models/adaptation_layers_food101.pth')
torch.save(discriminator.state_dict(), 'saved_models/discriminator_food101.pth')

print("Training completed and models saved.")

# -----------------------------------------------
# 7. Image Generation
# -----------------------------------------------

def generate_images(prompts, num_images=2, guidance_scale=7.5, num_inference_steps=50):
    """Sample images from the adapted UNet using classifier-free guidance.

    The adaptation layers are attached as forward hooks to every down, mid and up
    block of ``pretrained_unet`` for the duration of sampling. They are always
    detached afterwards, even if sampling is interrupted.

    Args:
        prompts: list of text prompts, one per image. A single prompt (or a bare
            string) is repeated ``num_images`` times.
        num_images: number of images to generate; must equal ``len(prompts)``
            unless exactly one prompt is given.
        guidance_scale: classifier-free guidance weight; 1.0 reduces to the
            purely conditional noise prediction.
        num_inference_steps: number of scheduler denoising steps.

    Returns:
        list of ``PIL.Image.Image``, one per generated sample.

    Raises:
        ValueError: if ``prompts`` is empty or its length does not match
            ``num_images``.
    """
    if isinstance(prompts, str):
        prompts = [prompts]
    prompts = list(prompts)
    if len(prompts) == 1 and num_images > 1:
        prompts = prompts * num_images
    if not prompts or len(prompts) != num_images:
        # The unconditional and conditional halves of the CFG batch must match the
        # latent batch size, otherwise the UNet call fails with a shape error.
        raise ValueError(
            f"got {len(prompts)} prompts for num_images={num_images}; "
            "pass one prompt per image or a single prompt to repeat"
        )

    pretrained_unet.eval()
    vae.eval()
    text_encoder.eval()
    adaptation_layers.eval()

    # Unconditional embeddings come first; noise_pred.chunk(2) below relies on this order.
    encoder_hidden_states_cond = encode_prompts(prompts)
    encoder_hidden_states_uncond = encode_prompts([""] * num_images)
    encoder_hidden_states_full = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_cond])

    # Initial latents (init_noise_sigma is the scheduler's required noise scale).
    latents = torch.randn((num_images, pretrained_unet.config.in_channels, 64, 64)).to(device)
    latents = latents * scheduler.init_noise_sigma

    scheduler.set_timesteps(num_inference_steps)

    def apply_adaptation(name):
        """Build a forward hook that replaces a block's (first) output with its adapted version."""
        adaptation_layer = adaptation_layers[name]

        def hook(module, input, output):
            if isinstance(output, tuple):
                return (adaptation_layer(output[0]),) + output[1:]
            return adaptation_layer(output)

        return hook

    # Register the hooks once for the whole sampling loop. The finally block
    # guarantees removal; a leaked hook would make every later UNet forward
    # apply the adaptation layers a second time.
    hooks = []
    try:
        for idx, block in enumerate(pretrained_unet.down_blocks):
            hooks.append(block.register_forward_hook(apply_adaptation(f'down_block_{idx}')))
        hooks.append(pretrained_unet.mid_block.register_forward_hook(apply_adaptation('mid_block')))
        for idx, block in enumerate(pretrained_unet.up_blocks):
            hooks.append(block.register_forward_hook(apply_adaptation(f'up_block_{idx}')))

        with torch.no_grad():
            for t in scheduler.timesteps:
                # Duplicate the latents so the unconditional and conditional
                # predictions come from a single batched forward pass.
                latent_model_input = torch.cat([latents] * 2)
                latent_model_input = scheduler.scale_model_input(latent_model_input, t)

                noise_pred = pretrained_unet(latent_model_input, t, encoder_hidden_states_full).sample

                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                latents = scheduler.step(noise_pred, t, latents).prev_sample
    finally:
        for hook in hooks:
            hook.remove()

    # Undo the latent scaling applied when encoding, then decode with the VAE.
    with torch.no_grad():
        images = vae.decode(1 / 0.18215 * latents).sample

    # Map [-1, 1] back to [0, 1], then convert to 8-bit PIL images.
    images = (images / 2 + 0.5).clamp(0, 1)
    images = images.cpu().permute(0, 2, 3, 1).numpy()
    pil_images = [Image.fromarray((img * 255).astype(np.uint8)) for img in images]

    return pil_images

# Generate five images for each of the first ten Food-101 classes and save them
# under generated_images/.
class_names = dataset.classes
os.makedirs('generated_images', exist_ok=True)

for class_name in class_names[:10]:
    print(f"Generating images for: {class_name}")
    prompts = [f"a photo of {class_name}" for _ in range(5)]
    generated_images = generate_images(prompts, num_images=len(prompts))

    for idx, img in enumerate(generated_images):
        out_path = f'generated_images/{class_name}_{idx}.png'
        img.save(out_path)
        print(f'Saved {out_path}')

print("Image generation completed and images saved.")

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

Using device: cuda
Downloading https://data.vision.ee.ethz.ch/cvl/food-101.tar.gz to ./food-101.tar.gz


100%|██████████| 5.00G/5.00G [00:24<00:00, 208MB/s]


Extracting ./food-101.tar.gz to .
Total images in the Food-101 dataset: 75750


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


unet/config.json:   0%|          | 0.00/911 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/3.46G [00:00<?, ?B/s]

vae/config.json:   0%|          | 0.00/553 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

tokenizer/tokenizer_config.json:   0%|          | 0.00/807 [00:00<?, ?B/s]

tokenizer/vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

tokenizer/merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

tokenizer/special_tokens_map.json:   0%|          | 0.00/460 [00:00<?, ?B/s]



text_encoder/config.json:   0%|          | 0.00/613 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.36G [00:00<?, ?B/s]

Output channels of each block: {'down_block_0': 320, 'down_block_1': 640, 'down_block_2': 1280, 'down_block_3': 1280, 'mid_block': 1280, 'up_block_0': 1280, 'up_block_1': 1280, 'up_block_2': 640, 'up_block_3': 320}
Output sizes of each block: {'down_block_0': torch.Size([32, 32]), 'down_block_1': torch.Size([16, 16]), 'down_block_2': torch.Size([8, 8]), 'down_block_3': torch.Size([8, 8]), 'mid_block': torch.Size([8, 8]), 'up_block_0': torch.Size([16, 16]), 'up_block_1': torch.Size([32, 32]), 'up_block_2': torch.Size([64, 64]), 'up_block_3': torch.Size([64, 64])}


scheduler/scheduler_config.json:   0%|          | 0.00/346 [00:00<?, ?B/s]


Starting Epoch 1/10
Epoch [1/10], Batch [0], Total Loss: 0.1352, Diffusion Loss: 0.0649, Consistency Loss: 4.5010, Adversarial Loss: 0.6978
Epoch [1/10], Batch [50], Total Loss: 0.3670, Diffusion Loss: 0.2960, Consistency Loss: 4.5400, Adversarial Loss: 0.7057
Epoch [1/10], Batch [100], Total Loss: 0.0811, Diffusion Loss: 0.0104, Consistency Loss: 4.5190, Adversarial Loss: 0.7026
Epoch [1/10], Batch [150], Total Loss: 0.2714, Diffusion Loss: 0.2003, Consistency Loss: 4.6436, Adversarial Loss: 0.7064
Epoch [1/10], Batch [200], Total Loss: 0.1350, Diffusion Loss: 0.0639, Consistency Loss: 4.6126, Adversarial Loss: 0.7066
Epoch [1/10], Batch [250], Total Loss: 0.2314, Diffusion Loss: 0.1602, Consistency Loss: 4.7101, Adversarial Loss: 0.7072
Epoch [1/10], Batch [300], Total Loss: 0.4619, Diffusion Loss: 0.3903, Consistency Loss: 4.7664, Adversarial Loss: 0.7107
Epoch [1/10], Batch [350], Total Loss: 0.2140, Diffusion Loss: 0.1429, Consistency Loss: 4.6959, Adversarial Loss: 0.7066
Epoch 

KeyboardInterrupt: 

In [None]:
# -----------------------------------------------
# Adaptive Diffusion Network with Cross-Domain Consistency Regularization (ADN-CDCR)
# Implementation in PyTorch
# Using the Food-101 Dataset with Guided Diffusion based on Labels
# -----------------------------------------------

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
import numpy as np
from PIL import Image
import os

# Fix the global torch and NumPy RNGs so data shuffling and noise are repeatable.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Using device: {device}')

# -----------------------------------------------
# 1. Dataset Preparation
# -----------------------------------------------

# Transformations: resize to 256x256, ToTensor scales pixels to [0, 1], and
# Normalize(0.5, 0.5) shifts each RGB channel to [-1, 1]. The `(images / 2 + 0.5)`
# step applied after VAE decoding inverts this mapping.
transform = transforms.Compose([
    transforms.Resize(size=(256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Food-101 training split, downloaded into the current directory on first use
# (the archive is roughly 5GB, so check free disk space first).
dataset = datasets.Food101(
    root='.',  # Change this to your desired root directory
    download=True,
    split='train',
    transform=transform,
)

print(f"Total images in the Food-101 dataset: {len(dataset)}")

# Shuffled mini-batches of 2, loaded by 4 worker processes into pinned memory.
data_loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    pin_memory=True,
    num_workers=4,
)

# -----------------------------------------------
# 2. Model Loading
# -----------------------------------------------

# Load the pre-trained Stable Diffusion 2.1-base components from the Hugging Face Hub.
model_name = "stabilityai/stable-diffusion-2-1-base"

pretrained_unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet").to(device)
vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae").to(device)
tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder").to(device)

# Keep every pre-trained network in eval mode with its weights frozen; the only
# optimizers in this notebook target the adaptation layers and the discriminator.
for frozen_model in (pretrained_unet, vae, text_encoder):
    frozen_model.eval()
    frozen_model.requires_grad_(False)

# -----------------------------------------------
# 3. Adaptation Layers and Discriminator
# -----------------------------------------------

# Adaptation layer definition with residual connections
class AdaptationLayer(nn.Module):
    """Residual adapter computing ``x + ReLU(GroupNorm(Conv1x1(x)))``.

    The 1x1 convolution maps ``channels`` to ``channels``, so the output has the
    same shape as the input. ``channels`` must be divisible by 32 because the
    GroupNorm uses 32 groups.
    """

    def __init__(self, channels):
        super().__init__()
        # These attribute names determine the state_dict keys of saved checkpoints.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=channels)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        branch = self.relu(self.norm1(self.conv1(x)))
        return x + branch

# Function to get the number of channels from the model
def get_layer_output_channels(latent_size=64):
    """Probe the UNet once to record each hooked block's output channels and spatial size.

    A forward hook is attached to every down block, the mid block and every up
    block of ``pretrained_unet``; a single no-grad forward pass on random inputs
    fills the result dicts. The hooks are always removed, even if the forward
    pass raises.

    Args:
        latent_size: height/width of the dummy latent fed to the UNet. It
            determines the reported ``output_sizes``.

    Returns:
        (output_channels, output_sizes): dicts keyed by ``'down_block_{i}'``,
        ``'mid_block'`` and ``'up_block_{i}'``, holding ``out.shape[1]`` and
        ``out.shape[2:]`` of each block's (first) output.
    """
    dummy_input = torch.randn(1, pretrained_unet.config.in_channels, latent_size, latent_size).to(device)
    dummy_timestep = torch.tensor([0], dtype=torch.long).to(device)
    dummy_encoder_hidden_states = torch.randn(1, 77, pretrained_unet.config.cross_attention_dim).to(device)

    output_channels = {}
    output_sizes = {}

    def hook_fn(name):
        def hook(module, input, output):
            # Some blocks return a tuple; the feature map is its first element.
            out = output[0] if isinstance(output, tuple) else output
            output_channels[name] = out.shape[1]
            output_sizes[name] = out.shape[2:]
        return hook

    hooks = []
    try:
        for idx, block in enumerate(pretrained_unet.down_blocks):
            hooks.append(block.register_forward_hook(hook_fn(f'down_block_{idx}')))
        hooks.append(pretrained_unet.mid_block.register_forward_hook(hook_fn('mid_block')))
        for idx, block in enumerate(pretrained_unet.up_blocks):
            hooks.append(block.register_forward_hook(hook_fn(f'up_block_{idx}')))

        with torch.no_grad():
            _ = pretrained_unet(dummy_input, dummy_timestep, dummy_encoder_hidden_states)
    finally:
        # Never leave probe hooks attached to the shared UNet.
        for hook in hooks:
            hook.remove()

    return output_channels, output_sizes

# Get actual output channels and sizes from the model
# Probe the UNet for the feature shapes each adaptation layer must match.
output_channels, output_sizes = get_layer_output_channels()
print("Output channels of each block:", output_channels)
print("Output sizes of each block:", output_sizes)

# One shape-preserving adaptation layer per hooked UNet block, keyed by block name.
# A list of (key, module) pairs keeps the insertion order on every PyTorch version.
adaptation_layers = nn.ModuleDict([
    (block_name, AdaptationLayer(n_channels).to(device))
    for block_name, n_channels in output_channels.items()
])

# Feature map fed to the discriminator; a block with larger spatial dimensions is preferred.
discriminator_input_layer = 'down_block_1'  # Adjust as needed

# Discriminator network with spectral normalization
class Discriminator(nn.Module):
    """Feature-map discriminator producing one logit per sample.

    Four stride-2, 3x3 spectral-normalized convolutions (64, 128, 256 and 512
    channels, each followed by LeakyReLU(0.2)) are followed by global average
    pooling and a 1x1 convolution to a single channel. An input of shape
    (N, input_channels, H, W) yields logits of shape (N, 1).
    """

    def __init__(self, input_channels):
        super().__init__()
        widths = [input_channels, 64, 128, 256, 512]
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
            layers.append(nn.LeakyReLU(0.2, inplace=False))
        layers.append(nn.AdaptiveAvgPool2d(1))
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=1, stride=1, padding=0))
        # `main` and its layer order define the state_dict keys (main.0 ... main.9).
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.main(x)
        return logits.view(logits.size(0), -1)

discriminator = Discriminator(output_channels[discriminator_input_layer]).to(device)

# -----------------------------------------------
# 4. Loss Functions and Optimizers
# -----------------------------------------------

# The discriminator outputs raw logits, hence BCEWithLogitsLoss. It learns 100x
# more slowly than the adaptation layers.
generator_lr, discriminator_lr = 1e-4, 1e-6
bce_loss = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(adaptation_layers.parameters(), lr=generator_lr)
disc_optimizer = optim.Adam(discriminator.parameters(), lr=discriminator_lr)

# Noise scheduler shipped with the same Stable Diffusion checkpoint.
scheduler = DDPMScheduler.from_pretrained(model_name, subfolder="scheduler")

# -----------------------------------------------
# 5. Helper Functions
# -----------------------------------------------

# Function to encode text prompts
def encode_prompts(prompts):
    """Embed a batch of text prompts with the frozen CLIP text encoder.

    Each prompt is padded/truncated to ``tokenizer.model_max_length`` tokens.
    The token ids are moved to ``device``, and the encoder's first output (the
    last hidden state for ``CLIPTextModel``) is returned without gradient
    tracking.
    """
    token_ids = tokenizer(
        prompts,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids.to(device)
    with torch.no_grad():
        return text_encoder(token_ids)[0]

# Hook function to apply adaptation layers
def create_hook(name):
    """Build a forward hook that routes block ``name``'s output through its adaptation layer.

    The hook stores the detached pre-adaptation features in the global
    ``original_features`` dict and the adapted features in ``adapted_features``,
    both under ``name``. It then returns the adapted output to the UNet. For
    tuple outputs only the first element is adapted; the rest pass through
    unchanged.
    """
    def hook(module, input, output):
        is_tuple = isinstance(output, tuple)
        features = output[0] if is_tuple else output
        original_features[name] = features.detach()
        adapted = adaptation_layers[name](features)
        adapted_features[name] = adapted
        if is_tuple:
            return (adapted,) + output[1:]
        return adapted
    return hook

# -----------------------------------------------
# 6. Training Loop with Guided Diffusion
# -----------------------------------------------

# Training loop
epochs = 10  # Adjust as needed
adv_loss_weight = 0.1
cons_loss_weight = 0.0001


def _register_training_hooks():
    """Attach a `create_hook` forward hook to every down, mid and up block of the UNet.

    Returns the list of hook handles; the caller is responsible for removing them.
    """
    handles = []
    for idx, block in enumerate(pretrained_unet.down_blocks):
        handles.append(block.register_forward_hook(create_hook(f'down_block_{idx}')))
    handles.append(pretrained_unet.mid_block.register_forward_hook(create_hook('mid_block')))
    for idx, block in enumerate(pretrained_unet.up_blocks):
        handles.append(block.register_forward_hook(create_hook(f'up_block_{idx}')))
    return handles


def _save_checkpoints():
    """Write the adaptation-layer and discriminator state_dicts under saved_models/."""
    os.makedirs('saved_models', exist_ok=True)
    torch.save(adaptation_layers.state_dict(), 'saved_models/adaptation_layers_food101.pth')
    torch.save(discriminator.state_dict(), 'saved_models/discriminator_food101.pth')


# generate_images() puts adaptation_layers into eval mode; switch back in case
# this cell is re-run after generating.
adaptation_layers.train()
discriminator.train()

for epoch in range(epochs):
    print(f"\nStarting Epoch {epoch+1}/{epochs}")
    for batch_idx, (images, labels) in enumerate(data_loader):
        images = images.to(device)
        batch_size = images.size(0)

        # Encode images to latent space
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample() * 0.18215  # Scaling factor from Stable Diffusion

        # Forward diffusion: add noise at a random timestep per sample.
        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_size,), device=device).long()
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        # Labels are only used to look up class-name strings, so keep them on the
        # CPU instead of indexing a Python list with device tensors.
        class_names = [dataset.classes[label] for label in labels.tolist()]
        encoder_hidden_states = encode_prompts([f"a photo of {class_name}" for class_name in class_names])

        # Filled by the create_hook forward hooks during the UNet forward pass.
        adapted_features = {}
        original_features = {}

        # Always remove the hooks, even if the forward pass raises (e.g. a
        # KeyboardInterrupt). Leaked hooks stay attached to the shared UNet and
        # would apply the adaptation layers twice on every later forward.
        hooks = _register_training_hooks()
        try:
            model_pred = pretrained_unet(noisy_latents, timesteps, encoder_hidden_states).sample
        finally:
            for hook in hooks:
                hook.remove()

        # Diffusion loss: the target is the noise that was added.
        diffusion_loss = F.mse_loss(model_pred, noise)

        # Consistency loss between adapted and pre-adaptation features of every
        # hooked block. AdaptationLayer preserves shape, so the pairs align.
        cons_loss = torch.zeros((), device=device)
        for key, adapted_feat in adapted_features.items():
            cons_loss = cons_loss + F.mse_loss(adapted_feat, original_features[key])

        # Adversarial loss on the selected block's features.
        disc_input_real = original_features[discriminator_input_layer]
        disc_input_fake = adapted_features[discriminator_input_layer]

        # Discriminator step. Its inputs are detached so only discriminator weights
        # receive gradients here.
        disc_optimizer.zero_grad()
        disc_real = discriminator(disc_input_real.detach())
        real_labels = torch.ones_like(disc_real, device=device)
        disc_loss_real = bce_loss(disc_real, real_labels)

        disc_fake = discriminator(disc_input_fake.detach())
        fake_labels = torch.zeros_like(disc_fake, device=device)
        disc_loss_fake = bce_loss(disc_fake, fake_labels)

        disc_loss = (disc_loss_real + disc_loss_fake) / 2
        disc_loss.backward()
        disc_optimizer.step()

        # Generator adversarial loss: the adapters try to make fake features look
        # real. The gradients this leaves on discriminator weights are cleared by
        # disc_optimizer.zero_grad() at the next step.
        adv_output = discriminator(disc_input_fake)
        adv_loss = bce_loss(adv_output, real_labels)

        total_loss = diffusion_loss + cons_loss_weight * cons_loss + adv_loss_weight * adv_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Logging
        if batch_idx % 50 == 0:
            print(f"Epoch [{epoch+1}/{epochs}], Batch [{batch_idx}], Total Loss: {total_loss.item():.4f}, "
                  f"Diffusion Loss: {diffusion_loss.item():.4f}, Consistency Loss: {cons_loss.item():.4f}, "
                  f"Adversarial Loss: {adv_loss.item():.4f}")

    # Checkpoint after every epoch so an interrupted run keeps its progress; the
    # files after the last epoch are the final trained models.
    _save_checkpoints()

print("Training completed and models saved.")

# -----------------------------------------------
# 7. Image Generation with Guided Diffusion
# -----------------------------------------------

def generate_images(labels, num_images=1, guidance_scale=7.5, num_inference_steps=50):
    """Generate images for Food-101 class labels with the adapted UNet and guidance.

    Each label index is looked up in ``dataset.classes`` and turned into the prompt
    ``"a photo of {class_name}"``, the template used in the training loop. Images
    are sampled with classifier-free guidance while the adaptation layers are
    hooked into every down, mid and up block of ``pretrained_unet``. The hooks are
    always detached afterwards, even if sampling is interrupted.

    Args:
        labels: sequence of class indices, one per image. A single label is
            repeated ``num_images`` times.
        num_images: number of images to generate; must equal ``len(labels)``
            unless exactly one label is given.
        guidance_scale: classifier-free guidance weight; 1.0 reduces to the
            purely conditional noise prediction.
        num_inference_steps: number of scheduler denoising steps.

    Returns:
        list of ``PIL.Image.Image``, one per generated sample.

    Raises:
        ValueError: if ``labels`` is empty or its length does not match
            ``num_images``.
    """
    labels = list(labels)
    if len(labels) == 1 and num_images > 1:
        labels = labels * num_images
    if not labels or len(labels) != num_images:
        # The unconditional and conditional halves of the CFG batch must match the
        # latent batch size, otherwise the UNet call fails with a shape error.
        raise ValueError(
            f"got {len(labels)} labels for num_images={num_images}; "
            "pass one label per image or a single label to repeat"
        )

    pretrained_unet.eval()
    vae.eval()
    text_encoder.eval()
    adaptation_layers.eval()

    # Convert labels to class names and build the conditioning prompts.
    class_names = [dataset.classes[label] for label in labels]
    prompts = [f"a photo of {class_name}" for class_name in class_names]

    # Unconditional embeddings come first; noise_pred.chunk(2) below relies on this order.
    encoder_hidden_states_cond = encode_prompts(prompts)
    encoder_hidden_states_uncond = encode_prompts([""] * num_images)
    encoder_hidden_states_full = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_cond])

    # Initial latents (init_noise_sigma is the scheduler's required noise scale).
    latents = torch.randn((num_images, pretrained_unet.config.in_channels, 64, 64)).to(device)
    latents = latents * scheduler.init_noise_sigma

    scheduler.set_timesteps(num_inference_steps)

    def apply_adaptation(name):
        """Build a forward hook that replaces a block's (first) output with its adapted version."""
        adaptation_layer = adaptation_layers[name]

        def hook(module, input, output):
            if isinstance(output, tuple):
                return (adaptation_layer(output[0]),) + output[1:]
            return adaptation_layer(output)

        return hook

    # Register the hooks once for the whole sampling loop. The finally block
    # guarantees removal; a leaked hook would make every later UNet forward
    # apply the adaptation layers a second time.
    hooks = []
    try:
        for idx, block in enumerate(pretrained_unet.down_blocks):
            hooks.append(block.register_forward_hook(apply_adaptation(f'down_block_{idx}')))
        hooks.append(pretrained_unet.mid_block.register_forward_hook(apply_adaptation('mid_block')))
        for idx, block in enumerate(pretrained_unet.up_blocks):
            hooks.append(block.register_forward_hook(apply_adaptation(f'up_block_{idx}')))

        with torch.no_grad():
            for t in scheduler.timesteps:
                # Duplicate the latents so the unconditional and conditional
                # predictions come from a single batched forward pass.
                latent_model_input = torch.cat([latents] * 2)
                latent_model_input = scheduler.scale_model_input(latent_model_input, t)

                noise_pred = pretrained_unet(latent_model_input, t, encoder_hidden_states_full).sample

                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                latents = scheduler.step(noise_pred, t, latents).prev_sample
    finally:
        for hook in hooks:
            hook.remove()

    # Undo the latent scaling applied when encoding, then decode with the VAE.
    with torch.no_grad():
        images = vae.decode(1 / 0.18215 * latents).sample

    # Map [-1, 1] back to [0, 1], then convert to 8-bit PIL images.
    images = (images / 2 + 0.5).clamp(0, 1)
    images = images.cpu().permute(0, 2, 3, 1).numpy()
    pil_images = [Image.fromarray((img * 255).astype(np.uint8)) for img in images]

    return pil_images

# Generate and save images based on labels
os.makedirs('generated_images', exist_ok=True)

# Generate `images_per_class` samples for each of the first
# `num_classes_to_generate` Food-101 class labels.
num_classes_to_generate = 10
images_per_class = 5

for label_idx in range(num_classes_to_generate):
    class_label = dataset.classes[label_idx]
    print(f"Generating images for class: {class_label}")
    labels = [label_idx for _ in range(images_per_class)]
    generated_images = generate_images(labels, num_images=images_per_class)

    for idx, img in enumerate(generated_images):
        save_path = f'generated_images/{class_label}_{idx}.png'
        img.save(save_path)
        print(f'Saved {save_path}')

print("Image generation completed and images saved.")

Using device: cuda
Total images in the Food-101 dataset: 75750
Output channels of each block: {'down_block_0': 320, 'down_block_1': 640, 'down_block_2': 1280, 'down_block_3': 1280, 'mid_block': 1280, 'up_block_0': 1280, 'up_block_1': 1280, 'up_block_2': 640, 'up_block_3': 320}
Output sizes of each block: {'down_block_0': torch.Size([32, 32]), 'down_block_1': torch.Size([16, 16]), 'down_block_2': torch.Size([8, 8]), 'down_block_3': torch.Size([8, 8]), 'mid_block': torch.Size([8, 8]), 'up_block_0': torch.Size([16, 16]), 'up_block_1': torch.Size([32, 32]), 'up_block_2': torch.Size([64, 64]), 'up_block_3': torch.Size([64, 64])}

Starting Epoch 1/10
Epoch [1/10], Batch [0], Total Loss: 0.1352, Diffusion Loss: 0.0649, Consistency Loss: 4.5010, Adversarial Loss: 0.6978
Epoch [1/10], Batch [50], Total Loss: 0.3670, Diffusion Loss: 0.2960, Consistency Loss: 4.5400, Adversarial Loss: 0.7057


KeyboardInterrupt: 