<a href="https://colab.research.google.com/github/vihaankrishna100/AI-Neural-Networks/blob/main/Generative_Image_Diffusion_With_MAB_UBC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the UI (gradio), diffusion/training stack and 8-bit optimizer support
# (bitsandbytes is presumably needed for the --use_8bit_adam training flag).
!pip install -q gradio diffusers transformers accelerate safetensors bitsandbytes
# NOTE(review): versions are unpinned and this second call re-installs/upgrades three of the
# packages above; pin versions (e.g. diffusers==X.Y.Z) for reproducible Colab runs.
!pip install --upgrade diffusers transformers safetensors



In [None]:
# Download the DreamBooth-LoRA training script that train_on_prompt() runs via `accelerate launch`.
# NOTE(review): fetched from the `main` branch of a fork, so its contents (and accepted CLI
# flags) can change between runs — consider pinning a commit hash in the URL.
!curl -L -o train_dreambooth_lora.py https://raw.githubusercontent.com/ShivamShrirao/diffusers/main/examples/dreambooth/train_dreambooth_lora.py
# Delete every line containing "logging_dir" from the downloaded script.
# NOTE(review): this is a blind line-level patch; if a deleted line belongs to a multi-line
# statement (e.g. an argparse definition or a constructor call) the script may no longer
# parse or run — verify the patched script before training.
!sed -i '/logging_dir/d' train_dreambooth_lora.py

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 42870  100 42870    0     0   100k      0 --:--:-- --:--:-- --:--:--  100k


In [None]:
# Print the installed diffusers version into the run log for reproducibility.
!pip show diffusers

# Mount Google Drive; the dataset, trained LoRA outputs and generated images are all
# persisted under /content/drive/MyDrive/tesla_dreambooth_project.
from google.colab import drive
drive.mount('/content/drive')

import os
import shutil
import numpy as np
from math import sqrt, log
import random
import json
import subprocess
from typing import List
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
import glob
import time

Name: diffusers
Version: 0.33.1
Summary: State-of-the-art diffusion in PyTorch and JAX.
Home-page: https://github.com/huggingface/diffusers
Author: The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/diffusers/graphs/contributors)
Author-email: diffusers@huggingface.co
License: Apache 2.0 License
Location: /usr/local/lib/python3.11/dist-packages
Requires: filelock, huggingface-hub, importlib-metadata, numpy, Pillow, regex, requests, safetensors
Required-by: 
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!accelerate config default

# Set up directories
drive_path = "/content/drive/MyDrive/tesla_dreambooth_project/dataset"
os.makedirs(drive_path, exist_ok=True)

# Define folder names
folder_names = ["image_luxury", "image_modern", "image_rugged", "image_sporty", "image_vibrant"]

# Copy data from local to Google Drive if it exists
for folder in folder_names:
    src = f"/content/{folder}"
    dst = os.path.join(drive_path, folder)
    if os.path.exists(src):
        shutil.copytree(src, dst, dirs_exist_ok=True)

Configuration already exists at /root/.cache/huggingface/accelerate/default_config.yaml, will not override. Run `accelerate config` manually or pass a different `save_location`.


In [None]:

# Ad-copy prompt templates, one per visual style. The bracketed tokens are literal
# text. The position of each entry is the prompt/arm id used by the bandit and by
# the per-style folder maps in train_on_prompt / generate_image.
_prompt_styles = ("rugged", "luxury", "modern", "vibrant", "sporty")
prompts = [
    f"a [high_ctr] web banner ad of a [sks_product] car in a [{style}] image style"
    for style in _prompt_styles
]

# UCB bookkeeping: pull count per arm, summed CTR reward per arm, and the round counter.
usage, reward = np.zeros(len(prompts)), np.zeros(len(prompts))
t = 1

In [None]:
def select_prompt_ucb() -> int:
    """Choose the next prompt index with the UCB rule and record the pull.

    A prompt that has never been used scores +inf, so every prompt is tried once
    before exploitation starts. Otherwise the score is the mean observed reward
    plus the exploration bonus sqrt(2 * ln(t) / usage[i]). Ties go to the lowest
    index (``np.argmax`` returns the first maximum).

    Side effects: increments ``usage`` for the chosen prompt and the global
    round counter ``t``.

    Returns:
        The selected index into ``prompts``.
    """
    global t
    scores = []
    for idx in range(len(prompts)):
        pulls = usage[idx]
        if pulls > 0:
            mean_reward = reward[idx] / pulls
            bonus = sqrt(2 * log(t) / pulls)
            scores.append(mean_reward + bonus)
        else:
            scores.append(float("inf"))
    chosen = int(np.argmax(scores))
    usage[chosen] += 1
    t += 1
    return chosen

In [None]:
def update_reward(prompt_id: int, ctr: float):
    """Credit an observed click-through rate to a prompt's cumulative reward.

    Adds ``ctr`` in place to the global ``reward`` array at ``prompt_id`` and
    prints the new running total.
    """
    reward[prompt_id] = reward[prompt_id] + ctr
    print(f"Updated reward for prompt {prompt_id}: +{ctr}, total: {reward[prompt_id]}")

def check_image_files(folder_path, extensions=(".jpg", ".jpeg", ".png", ".webp", ".avif")):
    """Count the image files directly inside ``folder_path``.

    Extension matching is case-insensitive, so files such as ``IMG_0001.JPG`` or
    ``photo.PNG`` (common from phones and cameras) are counted. Only regular
    files at the top level are considered; subdirectories are not searched.
    Dot-files (e.g. ``._photo.jpg`` metadata files) are skipped, as ``glob('*')``
    would skip them.

    Args:
        folder_path: Directory to inspect.
        extensions: File extensions, including the leading dot, that count as
            images. Compared case-insensitively.

    Returns:
        The number of matching files, or 0 if ``folder_path`` is not an
        existing directory.
    """
    if not os.path.isdir(folder_path):
        return 0

    wanted = {ext.lower() for ext in extensions}
    with os.scandir(folder_path) as entries:
        return sum(
            1
            for entry in entries
            if not entry.name.startswith(".")
            and entry.is_file()
            and os.path.splitext(entry.name)[1].lower() in wanted
        )

In [None]:
def train_on_prompt(prompt_id: int):
    """Train a DreamBooth LoRA for one style prompt and mirror the output to Drive.

    Steps:
      1. Resolve the style's dataset folder on Drive and require at least 3 images.
      2. Run ``train_dreambooth_lora.py`` through ``accelerate launch`` as a subprocess.
      3. Report which LoRA weight files landed in ``./output/prompt_<id>``.
      4. Copy that directory to
         ``/content/drive/MyDrive/tesla_dreambooth_project/output/prompt_<id>``.

    Args:
        prompt_id: Index into the global ``prompts`` list. It also selects the
            dataset folder and the ``sks<id>`` identifier token.

    Returns:
        A status string for the UI. Missing data and training failures are
        reported in this string rather than raised, so the caller can still
        generate an image with the base model.
    """
    prompt = prompts[prompt_id]
    folder_map = {
        0: "image_rugged",
        1: "image_luxury",
        2: "image_modern",
        3: "image_vibrant",
        4: "image_sporty"
    }
    folder_name = folder_map[prompt_id]
    instance_data_dir = f"/content/drive/MyDrive/tesla_dreambooth_project/dataset/{folder_name}"
    output_dir = f"./output/prompt_{prompt_id}"

    # Check that the dataset folder exists and has enough images.
    image_count = check_image_files(instance_data_dir)
    if image_count < 3:
        return f"Not enough images in {instance_data_dir}. Found {image_count}, need at least 3-5 images for Dreambooth."

    os.makedirs(output_dir, exist_ok=True)
    print(f"Training on: {prompt}")
    print(f"Found {image_count} images in {instance_data_dir}")

    # A distinct rare identifier per style, so each LoRA binds to its own token.
    class_token = f"sks{prompt_id}"
    instance_prompt = f"a photo of {class_token} car in {folder_name.replace('image_', '')} style"
    class_prompt = "a photo of car"

    print(f"Using instance prompt: {instance_prompt}")
    print(f"Using class prompt: {class_prompt}")

    # NOTE: --output_format and --save_steps were removed because they are not
    # arguments of upstream diffusers' train_dreambooth_lora.py, and argparse aborts
    # on unknown flags. --rank=4 was removed as well: 4 is the default where --rank
    # exists, and older copies of the script may not accept it.
    # TODO(review): verify the flag list against the downloaded script (`--help`).
    command = [
        "accelerate", "launch", "train_dreambooth_lora.py",
        "--pretrained_model_name_or_path=stabilityai/stable-diffusion-2-1-base",
        f"--instance_data_dir={instance_data_dir}",
        f"--output_dir={output_dir}",
        f"--instance_prompt={instance_prompt}",
        f"--class_prompt={class_prompt}",
        "--resolution=512",
        "--train_batch_size=1",
        "--gradient_accumulation_steps=1",
        "--learning_rate=1e-4",
        "--report_to=none",
        "--lr_scheduler=constant",
        "--lr_warmup_steps=0",
        "--max_train_steps=400",
        f"--validation_prompt={prompt}",
        "--validation_epochs=50",
        "--seed=0",
        "--mixed_precision=fp16",  # Mixed precision for speed / memory
        "--use_8bit_adam",         # 8-bit Adam for memory efficiency
        "--checkpointing_steps=100",
    ]

    # Previously this list was built but never executed, so no training happened.
    # Arguments are passed as a list (no shell), so prompt text is not shell-interpreted.
    print("Launching training:", " ".join(command))
    try:
        result = subprocess.run(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )
    except FileNotFoundError as e:
        return f"Could not launch training ({e}). Is `accelerate` installed and on PATH?"

    # Child-process output does not show up in the Colab cell by itself, so echo
    # the end of the training log.
    log_tail = "\n".join(result.stdout.splitlines()[-20:])
    print(log_tail)
    training_ok = result.returncode == 0
    if not training_ok:
        print(f"Training exited with code {result.returncode}")

    # List what was produced (Python listing; os.system output is not shown in Colab).
    print(f"Files in {output_dir}:")
    for name in sorted(os.listdir(output_dir)):
        print(f"  {name}")

    safetensors_files = glob.glob(os.path.join(output_dir, "*.safetensors"))
    bin_files = glob.glob(os.path.join(output_dir, "*.bin"))
    print(f"Found {len(safetensors_files)} .safetensors files and {len(bin_files)} .bin files")

    if not safetensors_files and not bin_files:
        print("WARNING: No model weight files found. Training may have failed.")

    # Copy the output to Google Drive (best effort: a failure is logged, not raised).
    dst_folder = f"/content/drive/MyDrive/tesla_dreambooth_project/output/prompt_{prompt_id}"
    os.makedirs(os.path.dirname(dst_folder), exist_ok=True)
    try:
        shutil.copytree(output_dir, dst_folder, dirs_exist_ok=True)
        print(f"Successfully copied output to {dst_folder}")
    except Exception as e:
        print(f"Error copying to Drive: {e}")

    if not training_ok:
        return f"Training FAILED for: {prompt} (exit code {result.returncode}).\n{log_tail}"
    return f"Trained on: {prompt}, output copied to Drive."

In [None]:
def generate_image(prompt_id: int):
    """Generate one ad image for a style prompt, using its trained LoRA if present.

    Steps:
      - Load Stable Diffusion 2.1 base in fp16 on CUDA.
      - Apply the first ``*.safetensors`` LoRA file in ``./output/prompt_<id>``,
        falling back to the first ``*.bin``.
      - Switch to ``DPMSolverMultistepScheduler`` and render a 30-step image.
      - Save the image to the working directory and to
        ``MyDrive/tesla_dreambooth_project/generated_images``.

    When a LoRA is loaded, the prompt includes the ``sks<id>`` identifier used in
    that LoRA's training instance prompt.

    Args:
        prompt_id: Index (0-4) selecting the style and the LoRA output directory.

    Returns:
        The file name of the saved JPEG, relative to the working directory.
    """
    import gc

    lora_dir = f"./output/prompt_{prompt_id}"

    # Show what is available (Python listing; os.system output is not shown in Colab).
    print(f"Files in {lora_dir}:")
    if os.path.isdir(lora_dir):
        for name in sorted(os.listdir(lora_dir)):
            print(f"  {name}")
    else:
        print("  (directory does not exist)")

    # Sort so that "first file found" is deterministic (glob order is arbitrary).
    lora_files = sorted(glob.glob(os.path.join(lora_dir, "*.safetensors")))
    bin_files = sorted(glob.glob(os.path.join(lora_dir, "*.bin")))
    print(f"Found {len(lora_files)} .safetensors files and {len(bin_files)} .bin files")

    pipe = StableDiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1-base",
        torch_dtype=torch.float16
    ).to("cuda")

    folder_map = {
        0: "rugged",
        1: "luxury",
        2: "modern",
        3: "vibrant",
        4: "sporty"
    }
    style = folder_map[prompt_id]

    try:
        lora_loaded = False
        if lora_files:
            lora_file = os.path.basename(lora_files[0])
            print(f"Using LoRA weights file (safetensors): {lora_file}")
            pipe.load_lora_weights(lora_dir, weight_name=lora_file, local_files_only=True)
            lora_loaded = True
        elif bin_files:
            bin_file = os.path.basename(bin_files[0])
            print(f"Using LoRA weights file (bin): {bin_file}")
            pipe.load_lora_weights(lora_dir, weight_name=bin_file, local_files_only=True)
            lora_loaded = True
        else:
            print(f"No weight files found in {lora_dir}, using base model without fine-tuning")

        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

        # The LoRA was trained with the identifier "sks<id>" in its instance prompt,
        # so include it when the LoRA is active; without a LoRA keep the plain prompt.
        subject = f"sks{prompt_id} Tesla car" if lora_loaded else "Tesla car"
        enhanced_prompt = (
            f"a high-quality advertisement banner showing a {subject} in {style} style, "
            "detailed, professional photography"
        )

        print(f"Generating with prompt: {enhanced_prompt}")
        image = pipe(enhanced_prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
    finally:
        # Release the pipeline's GPU memory: the next round launches the training
        # subprocess, which needs that memory and cannot use this process's cache.
        del pipe
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    filename = f"generated_prompt_{style}.jpg"
    image.save(filename)

    # Also save to Google Drive.
    generated_dir = "/content/drive/MyDrive/tesla_dreambooth_project/generated_images"
    os.makedirs(generated_dir, exist_ok=True)
    drive_filename = os.path.join(generated_dir, filename)
    image.save(drive_filename)
    print(f"Image saved to {drive_filename}")

    return filename

In [None]:
def run_ucb_interface(share: bool = True, debug: bool = True):
    """Build and launch the Gradio UI that drives the UCB prompt-selection loop.

    Buttons:
      - "Train & Generate": selects a prompt with ``select_prompt_ucb`` (which
        counts the pull), trains its LoRA and generates an image.
      - "Submit CTR": credits the entered click-through rate to that prompt via
        ``update_reward``.
      - "Check Datasets": reports the image count of each dataset folder.

    Args:
        share: Forwarded to ``iface.launch``. ``True`` publishes a public
            gradio.live URL, so anyone with the link can trigger GPU training.
        debug: Forwarded to ``iface.launch``. In Colab this keeps the cell
            running and streams logs.
    """
    # Prompt ids that were pulled but not yet credited with a CTR. UCB assumes one
    # reward per pull, so each submission consumes exactly one pending entry; this
    # prevents the same round from being credited twice.
    pending_feedback = []

    def run_round():
        prompt_id = select_prompt_ucb()
        pending_feedback.append(prompt_id)
        status = train_on_prompt(prompt_id)
        image_path = generate_image(prompt_id)
        return status, image_path, prompt_id

    def submit_ctr_and_enable(ctr, prompt_id):
        # gr.Number yields None while empty (e.g. before the first round).
        if prompt_id is None:
            return "No generated image yet - click 'Train & Generate' first."
        if ctr is None or not 0.0 <= ctr <= 1.0:
            return "Please enter a CTR between 0.0 and 1.0."
        prompt_id = int(prompt_id)
        if prompt_id not in pending_feedback:
            return f"No pending round for prompt {prompt_id}; its CTR was already recorded."
        pending_feedback.remove(prompt_id)
        update_reward(prompt_id, ctr)
        return f"CTR {ctr} recorded for prompt {prompt_id}."

    def debug_dataset():
        results = []
        for folder in folder_names:
            path = f"/content/drive/MyDrive/tesla_dreambooth_project/dataset/{folder}"
            count = check_image_files(path)
            results.append(f"Folder {folder}: {count} images")
        return "\n".join(results)

    with gr.Blocks() as iface:
        with gr.Row():
            run_button = gr.Button("Train & Generate")
            debug_button = gr.Button("Check Datasets")

        with gr.Row():
            output_text = gr.Textbox(label="Status", lines=10)
            output_image = gr.Image(label="Generated Image")
            prompt_id_box = gr.Number(label="Prompt ID", interactive=False)

        with gr.Row():
            ctr_input = gr.Number(label="CTR Feedback (0.0-1.0)")
            submit_button = gr.Button("Submit CTR")
            feedback_output = gr.Textbox(label="Feedback Status")

        run_button.click(run_round, outputs=[output_text, output_image, prompt_id_box])
        submit_button.click(submit_ctr_and_enable, inputs=[ctr_input, prompt_id_box], outputs=feedback_output)
        debug_button.click(debug_dataset, outputs=output_text)

    iface.launch(share=share, debug=debug)



In [None]:
# Entry point: build and launch the Gradio UI. Inside a notebook kernel __name__ is
# "__main__", so running this cell starts the interface; with debug=True the cell
# keeps running until interrupted.
if __name__ == "__main__":
    run_ucb_interface()

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://34cbf5cc306ab6da16.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Training on: a [high_ctr] web banner ad of a [sks_product] car in a [rugged] image style
Found 5 images in /content/drive/MyDrive/tesla_dreambooth_project/dataset/image_rugged
Using instance prompt: a photo of sks0 car in rugged style
Using class prompt: a photo of car
Files in ./output/prompt_0:
Found 0 .safetensors files and 0 .bin files
Successfully copied output to /content/drive/MyDrive/tesla_dreambooth_project/output/prompt_0
Files in ./output/prompt_0:
Found 0 .safetensors files and 0 .bin files


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

No weight files found in ./output/prompt_0, using base model without fine-tuning
Generating with prompt: a high-quality advertisement banner showing a Tesla car in rugged style, detailed, professional photography


  0%|          | 0/30 [00:00<?, ?it/s]

Image saved to /content/drive/MyDrive/tesla_dreambooth_project/generated_images/generated_prompt_rugged.jpg
Updated reward for prompt 0: +0.2, total: 0.2
Training on: a [high_ctr] web banner ad of a [sks_product] car in a [luxury] image style
Found 4 images in /content/drive/MyDrive/tesla_dreambooth_project/dataset/image_luxury
Using instance prompt: a photo of sks1 car in luxury style
Using class prompt: a photo of car
Files in ./output/prompt_1:
Found 0 .safetensors files and 0 .bin files
Successfully copied output to /content/drive/MyDrive/tesla_dreambooth_project/output/prompt_1
Files in ./output/prompt_1:
Found 0 .safetensors files and 0 .bin files


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

No weight files found in ./output/prompt_1, using base model without fine-tuning
Generating with prompt: a high-quality advertisement banner showing a Tesla car in luxury style, detailed, professional photography


  0%|          | 0/30 [00:00<?, ?it/s]

Image saved to /content/drive/MyDrive/tesla_dreambooth_project/generated_images/generated_prompt_luxury.jpg
Updated reward for prompt 1: +0.5, total: 0.5
Training on: a [high_ctr] web banner ad of a [sks_product] car in a [modern] image style
Found 4 images in /content/drive/MyDrive/tesla_dreambooth_project/dataset/image_modern
Using instance prompt: a photo of sks2 car in modern style
Using class prompt: a photo of car
Files in ./output/prompt_2:
Found 0 .safetensors files and 0 .bin files
Successfully copied output to /content/drive/MyDrive/tesla_dreambooth_project/output/prompt_2
Files in ./output/prompt_2:
Found 0 .safetensors files and 0 .bin files


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

No weight files found in ./output/prompt_2, using base model without fine-tuning
Generating with prompt: a high-quality advertisement banner showing a Tesla car in modern style, detailed, professional photography


  0%|          | 0/30 [00:00<?, ?it/s]

Image saved to /content/drive/MyDrive/tesla_dreambooth_project/generated_images/generated_prompt_modern.jpg
Updated reward for prompt 2: +0.3, total: 0.3
Training on: a [high_ctr] web banner ad of a [sks_product] car in a [vibrant] image style
Found 5 images in /content/drive/MyDrive/tesla_dreambooth_project/dataset/image_vibrant
Using instance prompt: a photo of sks3 car in vibrant style
Using class prompt: a photo of car
Files in ./output/prompt_3:
Found 0 .safetensors files and 0 .bin files
Successfully copied output to /content/drive/MyDrive/tesla_dreambooth_project/output/prompt_3
Files in ./output/prompt_3:
Found 0 .safetensors files and 0 .bin files


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

No weight files found in ./output/prompt_3, using base model without fine-tuning
Generating with prompt: a high-quality advertisement banner showing a Tesla car in vibrant style, detailed, professional photography


  0%|          | 0/30 [00:00<?, ?it/s]

Image saved to /content/drive/MyDrive/tesla_dreambooth_project/generated_images/generated_prompt_vibrant.jpg
Updated reward for prompt 3: +0.4, total: 0.4
Training on: a [high_ctr] web banner ad of a [sks_product] car in a [sporty] image style
Found 4 images in /content/drive/MyDrive/tesla_dreambooth_project/dataset/image_sporty
Using instance prompt: a photo of sks4 car in sporty style
Using class prompt: a photo of car
Files in ./output/prompt_4:
Found 0 .safetensors files and 0 .bin files
Successfully copied output to /content/drive/MyDrive/tesla_dreambooth_project/output/prompt_4
Files in ./output/prompt_4:
Found 0 .safetensors files and 0 .bin files


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

No weight files found in ./output/prompt_4, using base model without fine-tuning
Generating with prompt: a high-quality advertisement banner showing a Tesla car in sporty style, detailed, professional photography


  0%|          | 0/30 [00:00<?, ?it/s]

Image saved to /content/drive/MyDrive/tesla_dreambooth_project/generated_images/generated_prompt_sporty.jpg
Updated reward for prompt 4: +0.2, total: 0.2
Training on: a [high_ctr] web banner ad of a [sks_product] car in a [luxury] image style
Found 4 images in /content/drive/MyDrive/tesla_dreambooth_project/dataset/image_luxury
Using instance prompt: a photo of sks1 car in luxury style
Using class prompt: a photo of car
Files in ./output/prompt_1:
Found 0 .safetensors files and 0 .bin files
Successfully copied output to /content/drive/MyDrive/tesla_dreambooth_project/output/prompt_1
Files in ./output/prompt_1:
Found 0 .safetensors files and 0 .bin files


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

No weight files found in ./output/prompt_1, using base model without fine-tuning
Generating with prompt: a high-quality advertisement banner showing a Tesla car in luxury style, detailed, professional photography


  0%|          | 0/30 [00:00<?, ?it/s]

Image saved to /content/drive/MyDrive/tesla_dreambooth_project/generated_images/generated_prompt_luxury.jpg
Updated reward for prompt 1: +0.8, total: 1.3
Training on: a [high_ctr] web banner ad of a [sks_product] car in a [vibrant] image style
Found 5 images in /content/drive/MyDrive/tesla_dreambooth_project/dataset/image_vibrant
Using instance prompt: a photo of sks3 car in vibrant style
Using class prompt: a photo of car
Files in ./output/prompt_3:
Found 0 .safetensors files and 0 .bin files
Successfully copied output to /content/drive/MyDrive/tesla_dreambooth_project/output/prompt_3
Files in ./output/prompt_3:
Found 0 .safetensors files and 0 .bin files


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

No weight files found in ./output/prompt_3, using base model without fine-tuning
Generating with prompt: a high-quality advertisement banner showing a Tesla car in vibrant style, detailed, professional photography


  0%|          | 0/30 [00:00<?, ?it/s]

Image saved to /content/drive/MyDrive/tesla_dreambooth_project/generated_images/generated_prompt_vibrant.jpg
Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://34cbf5cc306ab6da16.gradio.live
