下載必要資源

In [None]:
!pip install torch torchvision diffusers transformers tqdm

import所有要用的library

In [None]:
import os
import json
import argparse
from glob import glob
import random

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from tqdm import tqdm

from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
from transformers import CLIPTokenizer, CLIPTextModel

Dataset Loader負責將所有圖片rescale並與文字配對後讀進程式中；有1成的文字會被換成空字串，用於之後Classifier-free guidance的部分。
training則是用5層的unet，並把loss最小的模型存起來。
batchsize開到128，剛好把memory吃滿，並訓練30個epoch。

In [None]:
# -------- Dataset Loader --------
class TextImageDataset(Dataset):
    """Pairs every ``*.png`` under ``data_root`` with a caption from a JSON file.

    The caption JSON maps a key to a dict holding ``given_description`` (a
    sequence of alternative descriptions) and ``action_description``.  The key
    of an image is its file name with the trailing ``_<suffix>`` removed.

    Args:
        data_root: Directory that is globbed for ``*.png`` files.
        caption_file: Path to the caption JSON file.
        tokenizer: Callable tokenizer; it is invoked with
            ``padding="max_length"``, ``truncation=True``, ``max_length=77``
            and must return an object with an ``input_ids`` tensor.
        size: Side length of the random crop fed to the model.
    """

    def __init__(self, data_root, caption_file, tokenizer, size=256):
        self.data_root = data_root
        self.tokenizer = tokenizer
        with open(caption_file, 'r', encoding="utf-8") as f:
            self.captions = json.load(f)
        # glob() returns files in arbitrary order; sort so that an index
        # always refers to the same image across runs.
        self.image_files = sorted(glob(os.path.join(data_root, "*.png")))
        self.keys = list(self.captions.keys())
        self.transform = transforms.Compose([
            # Honour the requested size instead of a hard-coded 256.
            transforms.RandomResizedCrop(size, scale=(0.7, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.ToTensor(),
            # Maps ToTensor's [0, 1] range to [-1, 1].
            transforms.Normalize([0.5], [0.5])
        ])

    @staticmethod
    def _caption_key(img_file):
        """Return the caption key for ``img_file``.

        Takes the base name up to the first ``.`` and drops its last
        ``_``-separated token, e.g. ``"red_bird_3.png" -> "red_bird"``.
        """
        stem = os.path.basename(img_file).split(".")[0]
        return "_".join(stem.split("_")[:-1])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``{"pixel_values": image tensor, "input_ids": token ids}``."""
        img_file = self.image_files[idx]
        image = Image.open(img_file).convert("RGB")
        image = self.transform(image)

        entry = self.captions[self._caption_key(img_file)]
        given_description = random.choice(entry['given_description'])
        caption = f"{given_description} {entry['action_description']}"
        # Replace the caption with the empty prompt 10% of the time so the
        # model also sees unconditional samples (used for classifier-free
        # guidance at sampling time).
        if random.random() < 0.1:
            caption = ""

        input_ids = self.tokenizer(
            caption,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=77,
        ).input_ids
        return {"pixel_values": image, "input_ids": input_ids.squeeze(0)}

  # -------- Training Function --------
def train(args):
    """Train (or resume training of) the text-conditioned latent U-Net.

    A frozen CLIP text encoder and a frozen Stable Diffusion VAE provide the
    conditioning and the latent space; only the U-Net is optimised with the
    noise-prediction MSE objective.  After every epoch the U-Net and the
    optimizer state are saved to ``args.ckpt_dir`` when the epoch's *mean*
    loss improves on the best mean seen so far.

    Args:
        args: Namespace with the attributes ``train_data``, ``caption_file``,
            ``resume_ckpt``, ``resume_opt``, ``lr``, ``batch_size``,
            ``epochs``, ``start_epoch``, ``output_dir`` and ``ckpt_dir``.
    """
    from transformers import get_scheduler

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load pretrained, frozen modules.
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").eval().to(device)
    text_encoder.requires_grad_(False)
    vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(device)
    vae.requires_grad_(False)

    if args.resume_ckpt:
        print(f"Resuming U-Net from {args.resume_ckpt}")
        unet = UNet2DConditionModel.from_pretrained(args.resume_ckpt).to(device)
    else:
        unet = UNet2DConditionModel(
            sample_size=32,  # size in latent space after VAE, e.g. 256 // 8
            in_channels=4,
            out_channels=4,
            layers_per_block=2,
            block_out_channels=(160, 320, 640, 1280, 1280),
            down_block_types=(
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "DownBlock2D",  # optional: no attention at deepest
            ),
            up_block_types=(
                "UpBlock2D",  # optional: no attention
                "CrossAttnUpBlock2D",
                "CrossAttnUpBlock2D",
                "CrossAttnUpBlock2D",
                "CrossAttnUpBlock2D",
            ),
            # Must equal the text encoder's hidden size (512 for this CLIP).
            cross_attention_dim=512,
        ).to(device)

    unet.train()

    optimizer = torch.optim.AdamW(unet.parameters(), lr=args.lr, weight_decay=0.01)
    noise_scheduler = DDPMScheduler()

    if args.resume_opt:
        print(f"Resuming optimizer from {args.resume_opt}")
        optimizer.load_state_dict(torch.load(args.resume_opt, map_location=device))

    # Load dataset
    dataset = TextImageDataset(args.train_data, args.caption_file, tokenizer)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=2)

    # Total training steps.
    # NOTE(review): this counts all ``args.epochs`` even when resuming with
    # ``start_epoch > 0``, and the LR schedule itself is not restored.
    num_epochs = args.epochs
    num_training_steps = num_epochs * len(dataloader)

    lr_scheduler = get_scheduler(
        name="cosine",                  # or "linear"
        optimizer=optimizer,
        num_warmup_steps=100,           # you can also try 500
        num_training_steps=num_training_steps
    )

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.ckpt_dir, exist_ok=True)

    best_loss = float('inf')

    for epoch in range(args.start_epoch, args.epochs):
        pbar = tqdm(dataloader)
        epoch_loss_sum = 0.0
        num_batches = 0

        for batch in pbar:
            images = batch["pixel_values"].to(device)
            input_ids = batch["input_ids"].to(device)

            with torch.no_grad():
                text_emb = text_encoder(input_ids)[0]
                # Scale the latents; the sampling code divides by the same
                # constant before decoding.
                latents = vae.encode(images).latent_dist.sample() * 0.18215

            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0,
                noise_scheduler.config.num_train_timesteps,
                (latents.size(0),),
                device=device,
            ).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=text_emb).sample

            loss = F.mse_loss(noise_pred, noise)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            lr_scheduler.step()

            batch_loss = loss.item()
            epoch_loss_sum += batch_loss
            num_batches += 1
            pbar.set_description(f"Epoch {epoch+1}, Loss: {batch_loss:.4f}")

        if num_batches == 0:
            # Empty dataloader: there is no loss to compare against.
            continue

        # Select the best checkpoint on the epoch mean; a single mini-batch
        # loss depends heavily on the sampled timesteps and is too noisy.
        epoch_loss = epoch_loss_sum / num_batches
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            print(f"New best loss: {best_loss:.4f}. Saving model...")
            # Save the model and optimizer
            unet.save_pretrained(os.path.join(args.ckpt_dir, "best_unet"))
            torch.save(optimizer.state_dict(), os.path.join(args.ckpt_dir, "best_optimizer.pt"))



# Step 1: Collect the run options in a plain dict and wrap them in a namespace
import argparse

run_options = {
    "train_data": "/content/drive/MyDrive/GAI_HW6/train",
    "caption_file": "/content/drive/MyDrive/GAI_HW6/train_info.json",
    "start_epoch": 0,
    "resume_opt": "",
    "batch_size": 128,
    "epochs": 100,
    "lr": 1e-4,
    "save_every": 5000,
    "ckpt_dir": "/content/drive/MyDrive/GAI_HW6/ckpts",
    "output_dir": "/content/drive/MyDrive/GAI_HW6/output",
    "resume_ckpt": "",
}
args = argparse.Namespace(**run_options)

# Step 2: Launch training (train() must already be defined above)
train(args)

生成時用PNDMScheduler，雖然生成時間長，但效果比較好，timesteps = 70。
guidance_scale設定為10，讓模型更傾向遵循test.json中的文字提示。

In [None]:
import json
import os
from tqdm import tqdm
import torch
from PIL import Image
from torchvision import transforms
from diffusers import PNDMScheduler
from diffusers import UNet2DConditionModel, AutoencoderKL, DDIMScheduler
from transformers import CLIPTokenizer, CLIPTextModel
from diffusers import DPMSolverMultistepScheduler
from diffusers import EulerDiscreteScheduler

# Text-to-image sampling: for every prompt in test.json, denoise random
# latents with classifier-free guidance and save the decoded image.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load pretrained modules
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").eval().to(device)
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(device)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)

# Load your trained U-Net model
# NOTE(review): train() writes its checkpoint to "<ckpt_dir>/best_unet";
# "best_unet_last" is presumably a manually renamed copy — confirm it exists.
unet = UNet2DConditionModel.from_pretrained("/content/drive/MyDrive/GAI_HW6/ckpts/best_unet_last").to(device)
unet.eval()

num_inference_steps = 70
guidance_scale = 10

scheduler = PNDMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
scheduler.set_timesteps(num_inference_steps)

# Load test prompts
with open("/content/drive/MyDrive/GAI_HW6/test.json", "r", encoding="utf-8") as f:
    test_data = json.load(f)

# Output folder
output_dir = "/content/drive/MyDrive/GAI_HW6/results"
os.makedirs(output_dir, exist_ok=True)

# The unconditional (empty-prompt) embedding is identical for every prompt,
# so encode it once instead of once per image.
with torch.no_grad():
    uncond_input = tokenizer([""], return_tensors="pt", padding="max_length", max_length=77).input_ids.to(device)
    uncond_emb = text_encoder(uncond_input)[0]

# For each prompt
for key, item in tqdm(test_data.items()):
    text_prompt = item["text_prompt"]
    filename = item["image_name"]

    with torch.no_grad():
        text_input = tokenizer([text_prompt], return_tensors="pt", padding="max_length", max_length=77, truncation=True).input_ids.to(device)
        cond_emb = text_encoder(text_input)[0]

    # PNDMScheduler is stateful (it keeps a history of previous model outputs
    # and a step counter).  Re-initialise it for every image so that no
    # history from the previous prompt leaks into this one.
    scheduler.set_timesteps(num_inference_steps)

    latents = torch.randn((1, 4, 32, 32)).to(device)

    for t in scheduler.timesteps:
        with torch.no_grad():
            noise_pred_cond = unet(latents, t, encoder_hidden_states=cond_emb).sample
            noise_pred_uncond = unet(latents, t, encoder_hidden_states=uncond_emb).sample

            # Classifier-free guidance: move from the unconditional
            # prediction towards the conditional one.
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            latents = scheduler.step(noise_pred, t, latents).prev_sample

    # Decode latent to image (undo the scaling applied during training)
    latents = latents / 0.18215
    with torch.no_grad():
        image = vae.decode(latents).sample
        image = (image.clamp(-1, 1) + 1) / 2  # [-1, 1] -> [0, 1]
        image = transforms.ToPILImage()(image[0].cpu())
        image.save(os.path.join(output_dir, filename))


除了一般的隨機生成，我也嘗試先用text embeddings比對的方式，找到與test中的text_prompt相近的training descriptions，再利用training descriptions找到最相近的圖片。

In [None]:
import os
import json
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import CLIPTokenizer, CLIPTextModel

# -------------------- Setup --------------------
# Match every test prompt to the most similar training description using
# cosine similarity between CLIP text features.
# CLIPModel is required for get_text_features() but is not part of this
# cell's import list, so bring it into scope here.
from transformers import CLIPModel

device = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 512

# Load full CLIP model (text + vision encoders)
# NOTE(review): force_download=True re-downloads the weights on every run;
# presumably a workaround for a bad cache — drop it once no longer needed.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32", force_download=True)
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", force_download=True).eval().to(device)


def _encode_texts(texts):
    """Return L2-normalised CLIP text features for ``texts`` as a CPU tensor."""
    with torch.no_grad():
        tokens = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=77).to(device)
        embs = model.get_text_features(**tokens)  # includes final projection
        embs = F.normalize(embs, dim=-1)
    # Keep results on the CPU to limit GPU memory usage.
    return embs.cpu()


# -------------------- Load Training Descriptions --------------------
with open("/content/drive/MyDrive/GAI_HW6/train_info.json", "r", encoding="utf-8") as f:
    train_data = json.load(f)

train_desc_list = []      # List of all descriptions
train_key_list = []       # Corresponding image key for each description
train_desc_idx_list = []  # Index of description in given_description[]

for key, entry in train_data.items():
    for i, desc in enumerate(entry["given_description"]):
        full_desc = f"{desc} {entry['action_description']}".strip()
        train_desc_list.append(full_desc)
        train_key_list.append(key)
        train_desc_idx_list.append(i)

# -------------------- Encode Training Descriptions in Batches --------------------
train_embs = []

for i in tqdm(range(0, len(train_desc_list), BATCH_SIZE), desc="Encoding train descriptions"):
    train_embs.append(_encode_texts(train_desc_list[i:i + BATCH_SIZE]))

train_embs = torch.cat(train_embs, dim=0)  # Shape: (N_descriptions, 512)

# -------------------- Load Test Prompts --------------------
with open("/content/drive/MyDrive/GAI_HW6/test.json", "r", encoding="utf-8") as f:
    test_data = json.load(f)

# -------------------- Match Each Prompt --------------------
results = {}

for key, item in tqdm(test_data.items(), desc="Matching test prompts"):
    prompt = item["text_prompt"]

    # Encode prompt
    prompt_emb = _encode_texts(prompt)

    # Cosine similarity: both sides are unit-normalised, so a dot product suffices.
    sims = torch.matmul(train_embs, prompt_emb.T).squeeze(1)  # (N_train,)
    best_idx = torch.argmax(sims).item()

    # Store result
    results[key] = {
        "text_prompt": prompt,
        "matched_train_key": train_key_list[best_idx],
        "matched_description": train_desc_list[best_idx],
        "description_index": train_desc_idx_list[best_idx],
        "similarity_score": sims[best_idx].item()
    }

# -------------------- Save Result --------------------
with open("/content/drive/MyDrive/GAI_HW6/matched_prompt_to_best_description.json", "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2, ensure_ascii=False)

print("✅ Matching complete. Results saved to matched_prompt_to_best_description.json")


並用找到的圖片取代隨機生成的起始點：先將該圖片編碼成latent並加上雜訊，然後再去雜訊，可得到不錯的成果。

In [None]:
import os
import json
from tqdm import tqdm
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from transformers import CLIPTokenizer, CLIPTextModel
from diffusers import UNet2DConditionModel, AutoencoderKL, DDIMScheduler
from diffusers import PNDMScheduler

# Image-to-image sampling: start from the noised latent of the training image
# matched to each prompt and denoise it conditioned on the prompt text.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------------- Load pretrained modules ----------------
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32", force_download=True)
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32", force_download=True).eval().to(device)

vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(device)
vae.requires_grad_(False)

# NOTE(review): train() writes its checkpoint to "<ckpt_dir>/best_unet";
# "best_unet_last" is presumably a manually renamed copy — confirm it exists.
unet = UNet2DConditionModel.from_pretrained("/content/drive/MyDrive/GAI_HW6/ckpts/best_unet_last").to(device)
unet.eval()

num_inference_steps = 70

scheduler = PNDMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
scheduler.set_timesteps(num_inference_steps)

# ---------------- Load prompts ----------------
# The file is written as UTF-8 with ensure_ascii=False, so read it the same way.
with open("/content/drive/MyDrive/GAI_HW6/matched_prompt_to_best_description.json", "r", encoding="utf-8") as f:
    test_data = json.load(f)

# ---------------- Output folder ----------------
output_dir = "/content/drive/MyDrive/GAI_HW6/results_img2img"
os.makedirs(output_dir, exist_ok=True)

# ---------------- Image transform ----------------
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# Index into scheduler.timesteps, which runs from most to least noisy.
# A LARGER index therefore adds LESS noise and leaves fewer denoising steps,
# keeping the result closer to the hint image; lower it for more variation.
start_timestep = 60

# ---------------- Generation loop ----------------
for key, item in tqdm(test_data.items()):
    # 1. Prepare text prompt
    text_prompt = item["text_prompt"]
    filename = item["matched_train_key"] + "_0.png"

    with torch.no_grad():
        tokens = tokenizer(text_prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=77).to(device)

        cond_emb = text_encoder(**tokens).last_hidden_state  # Shape: (1, 77, 512)

    # 2. Load correlated image
    hint_path = os.path.join("/content/drive/MyDrive/GAI_HW6/train", filename)
    hint_image = Image.open(hint_path).convert("RGB")
    image_tensor = transform(hint_image).unsqueeze(0).to(device)

    # 3. Encode to latent
    with torch.no_grad():
        latent = vae.encode(image_tensor).latent_dist.sample() * 0.18215

    # PNDMScheduler is stateful (it keeps a history of previous model outputs
    # and a step counter).  Re-initialise it for every image so that no
    # history from the previous image leaks into this one.
    scheduler.set_timesteps(num_inference_steps)

    # 4. Add noise at timestep
    t_start = scheduler.timesteps[start_timestep]
    noise = torch.randn_like(latent)
    noisy_latent = scheduler.add_noise(latent, noise, t_start)

    # 5. Denoise using U-Net
    latents = noisy_latent
    for t in scheduler.timesteps[start_timestep:]:
        with torch.no_grad():
            noise_pred = unet(latents, t, encoder_hidden_states=cond_emb).sample
            latents = scheduler.step(noise_pred, t, latents).prev_sample

    # 6. Decode from latent (undo the scaling applied when encoding)
    latents = latents / 0.18215
    with torch.no_grad():
        image = vae.decode(latents).sample
        image = (image.clamp(-1, 1) + 1) / 2  # [-1, 1] -> [0, 1]
        image = transforms.ToPILImage()(image[0].cpu())
        image.save(os.path.join(output_dir, key + ".png"))
