In [6]:
!pip install diffusers transformers accelerate datasets safetensors peft



In [7]:
!pip install -U peft



In [1]:
from huggingface_hub import login
login("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")


In [2]:
import os
import json
import torch
from PIL import Image
from torchvision import transforms
from datasets import Dataset
from tqdm import tqdm
from diffusers import StableDiffusionPipeline
from peft import get_peft_model, LoraConfig, TaskType

# ---------------------
# Config
# ---------------------
model_id = "runwayml/stable-diffusion-v1-5"
data_dir = "/home/rguktongole/LIFE/NULLCLASS/TASK5/flowers"
images_dir = os.path.join(data_dir, "images")
captions_file = os.path.join(data_dir, "captions_blip.json")
output_dir = "home/rguktongole/LIFE/NULLCLASS/TASK5/lora-finetuned-flower"

# ---------------------
# Load and Prepare Dataset
# ---------------------
with open(captions_file, "r") as f:
    caption_data = json.load(f)

dataset_entries = []
for item in caption_data:
    if isinstance(item, dict) and "image" in item and "text" in item:
        image_path = os.path.join(images_dir, item["image"])
        if os.path.exists(image_path):
            dataset_entries.append({
                "image_path": image_path,
                "caption": item["text"]
            })

dataset = Dataset.from_list(dataset_entries)

def preprocess(example):
    image = Image.open(example["image_path"]).convert("RGB").resize((512, 512))
    transform = transforms.ToTensor()
    example["pixel_values"] = transform(image)
    return example

dataset = dataset.map(preprocess)

# ---------------------
# Load Stable Diffusion
# ---------------------
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16
).to("cuda")

# ---------------------
# Apply LoRA with Fallback
# ---------------------
try:
    lora_config = LoraConfig(
        r=4,
        lora_alpha=16,
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
        lora_dropout=0.1,
        bias="none",
        task_type=TaskType.UNET  # Preferred if available
    )
except AttributeError:
    print("⚠️ Warning: TaskType.UNET not found, using TaskType.FEATURE_EXTRACTION instead.")
    lora_config = LoraConfig(
        r=4,
        lora_alpha=16,
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
        lora_dropout=0.1,
        bias="none",
        task_type=TaskType.FEATURE_EXTRACTION
    )

pipe.unet = get_peft_model(pipe.unet, lora_config)

# ---------------------
# Optimizer
# ---------------------
optimizer = torch.optim.AdamW(pipe.unet.parameters(), lr=1e-5)

# ---------------------
# Training Loop
# ---------------------
for epoch in range(3):  # You can adjust this
    for i in tqdm(range(len(dataset))):
        sample = dataset[i]

        inputs = pipe.tokenizer(
            sample["caption"],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=77
        ).to("cuda")

        image_tensor = sample["pixel_values"].unsqueeze(0).to("cuda")
        with torch.no_grad():
            latents = pipe.vae.encode(image_tensor).latent_dist.sample()

        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, 1000, (1,), device="cuda").long()
        noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

        with torch.no_grad():
            encoder_hidden_states = pipe.text_encoder(**inputs).last_hidden_state

        model_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states).sample

        loss = torch.nn.functional.mse_loss(model_pred, noise)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 50 == 0:
            print(f"Epoch {epoch}, Step {i}, Loss: {loss.item():.4f}")

# ---------------------
# Save Fine-Tuned Model
# ---------------------
pipe.unet.save_pretrained(output_dir)
pipe.save_pretrained(output_dir)
print(f"✅ Model saved at {output_dir}")


2025-05-13 18:17:45.783641: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747140466.811461    5142 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747140467.019390    5142 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1747140469.125898    5142 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1747140469.125981    5142 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1747140469.125987    5142 computation_placer.cc:177] computation placer alr

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

In [None]:
# Save full pipeline components manually
pipe.unet.save_pretrained(os.path.join(output_dir, "unet"))
pipe.vae.save_pretrained(os.path.join(output_dir, "vae"))
pipe.text_encoder.save_pretrained(os.path.join(output_dir, "text_encoder"))
pipe.tokenizer.save_pretrained(os.path.join(output_dir, "tokenizer"))
pipe.scheduler.save_pretrained(os.path.join(output_dir, "scheduler"))
pipe.feature_extractor.save_pretrained(os.path.join(output_dir, "feature_extractor"))
pipe.safety_checker.save_pretrained(os.path.join(output_dir, "safety_checker"))
print(f"✅ All components saved to {output_dir}")


In [None]:
from diffusers import StableDiffusionPipeline
import torch

pipe = StableDiffusionPipeline.from_pretrained(
    "./lora-finetuned-flower",  # contains subfolders like "unet", "vae", etc.
    torch_dtype=torch.float16
).to("cuda")

prompt = "A watercolor painting of pink roses blooming in a garden"
image = pipe(prompt).images[0]
image.show()
