# Questions/Tasks to Pursue

1. StyleX does not work without a CUDA GPU (such as a T4). Check the processor up front and do not proceed if CUDA is not set up.

2. Perfecting the prompt: experiment with multiple prompts until the output matches the style of the images in input_images. One possible approach is to (1) generate several prompts and images and (2) compare the outputs with the inputs.

3. Dig deeper into STYLE_VOCAB and the generate_one options. What does each one mean, and how does it affect image generation and the final output?

4. Is batch generation (giving 50 prompts and generating 50 output images) even feasible with the free model?

5. In the second cell, I tried to host input_images on Google Drive but could not make it work. Explore how to connect the notebook to Google Drive.
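
For task 1, a guard cell along the following lines could stop the notebook early when no CUDA device is visible. This is a minimal sketch, not part of the notebook: `describe_accelerator` and `require_cuda` are hypothetical helper names, and the only torch calls assumed are `torch.cuda.is_available()` and `torch.cuda.get_device_name()`.

```python
def describe_accelerator() -> tuple[bool, str]:
    """Return (cuda_available, description); does not raise if torch is missing."""
    try:
        import torch
    except ImportError:
        return False, 'torch is not installed'
    if not torch.cuda.is_available():
        return False, 'no CUDA device is visible to torch'
    return True, torch.cuda.get_device_name(0)


def require_cuda(available: bool, description: str) -> str:
    """Hypothetical guard: fail early instead of deep inside the pipeline."""
    if not available:
        raise RuntimeError(f'StyleX needs a CUDA GPU: {description}')
    return description
```

In Colab this might be called as `require_cuda(*describe_accelerator())` right after the install cell, so that a CPU-only runtime stops before any model download starts.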

# StyleX
## Style-Conditioned Image Generation System

### ICS499 - Software Engineering and Capstone Project

In [None]:
#@title Install dependencies
# in Colab (run once per new runtime)
%pip -q install --upgrade diffusers transformers accelerate safetensors sentencepiece huggingface_hub

In [None]:
#@title Mount Google Drive  (TBD - Not working)
# Optional: mount Google Drive for persistent input/output folders
USE_GOOGLE_DRIVE = True
DRIVE_PROJECT_PATH = '/content/drive/MyDrive/stylex'  # Change if your folder is elsewhere

from pathlib import Path

DRIVE_PROJECT_ROOT = None
if USE_GOOGLE_DRIVE:
    from google.colab import drive
    drive.mount('/content/drive')
    DRIVE_PROJECT_ROOT = Path(DRIVE_PROJECT_PATH)
    print('Drive project root:', DRIVE_PROJECT_ROOT)
    print('Exists:', DRIVE_PROJECT_ROOT.exists())
    if DRIVE_PROJECT_ROOT.exists():
        print('Preview:', [x.name for x in DRIVE_PROJECT_ROOT.iterdir()][:10])
    else:
        print('Path not found. Update DRIVE_PROJECT_PATH and rerun this cell.')
else:
    print('Using local Colab runtime storage.')
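
The mount cell above assigns `DRIVE_PROJECT_ROOT` even when the Drive folder is not found, and the next cell then uses it as `PROJECT_ROOT`. One way to make the Drive setup easier to debug (task 5) might be to fall back to local storage explicitly and report which expected subfolders are missing. The sketch below uses only `pathlib`; `resolve_project_root` and `missing_subfolders` are hypothetical helpers, not part of the notebook.

```python
from pathlib import Path


def resolve_project_root(drive_path: str, local_path: str = '.') -> Path:
    """Use the Drive folder only if it exists; otherwise fall back to local storage."""
    candidate = Path(drive_path)
    if candidate.is_dir():
        return candidate
    return Path(local_path).resolve()


def missing_subfolders(root: Path, required: tuple[str, ...] = ('input_images',)) -> list[str]:
    """List the required subfolders that are absent under root."""
    return [name for name in required if not (root / name).is_dir()]
```

After `drive.mount('/content/drive')`, calling `resolve_project_root(DRIVE_PROJECT_PATH)` and printing `missing_subfolders(...)` for the result would show whether the problem is the mount itself or the folder layout under it.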

In [None]:
#@title import torch, diffusers and transformers
import json
import time
from dataclasses import dataclass
from pathlib import Path

import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
from diffusers import StableDiffusion3Pipeline
from huggingface_hub import whoami

PROJECT_ROOT = DRIVE_PROJECT_ROOT if 'DRIVE_PROJECT_ROOT' in globals() and DRIVE_PROJECT_ROOT else Path('.').resolve()
INPUT_ROOT = PROJECT_ROOT / 'input_images'
OUTPUT_ROOT = PROJECT_ROOT / 'output_images'
OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)

print('Project root:', PROJECT_ROOT)
print('Input dir:', INPUT_ROOT)
print('Output dir:', OUTPUT_ROOT)
print('CUDA available:', torch.cuda.is_available())

Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.


Project root: /content
Input dir: /content/input_images
Output dir: /content/output_images
CUDA available: True


In [None]:
#@title Setup Hugging Face Credentials
# Hugging Face login (required for gated models like stabilityai/stable-diffusion-3.5-medium)
from huggingface_hub import login, notebook_login

# Optional: paste your token between the quotes; leave it empty to use the interactive login
HF_TOKEN = ''

if HF_TOKEN.strip():
    login(token=HF_TOKEN.strip(), add_to_git_credential=False)
else:
    notebook_login()

try:
    print('HF account:', whoami().get('name', '(unknown)'))
except Exception:
    print('Logged in, but could not verify account name.')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


HF account: sjasthi


In [None]:
#@title Setup the Style class
@dataclass
class Style:
    name: str
    folder: Path
    prompt_suffix: str = ''
    negative_prompt: str = ''
    use_style_embeddings: bool = True
    embeddings_top_k: int = 10  # same default that load_style falls back to
    embeddings_model_id: str = 'openai/clip-vit-base-patch32'


def list_styles(styles_root: Path) -> list[str]:
    if not styles_root.exists():
        return []
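    # Skip hidden folders such as .ipynb_checkpoints, which are not styles
    return sorted(p.name for p in styles_root.iterdir() if p.is_dir() and not p.name.startswith('.'))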


def _read_style_json(folder: Path) -> dict:
    for fname in ('style.json', 'styles.json'):
        p = folder / fname
        if p.exists():
            return json.loads(p.read_text(encoding='utf-8'))
    return {}


def load_style(styles_root: Path, style_name: str) -> Style:
    folder = styles_root / style_name
    if not folder.exists() or not folder.is_dir():
        raise FileNotFoundError(f"Style '{style_name}' not found in {styles_root}")

    data = _read_style_json(folder)

    return Style(
        name=style_name,
        folder=folder,
        prompt_suffix=data.get('prompt_suffix', ''),
        negative_prompt=data.get('negative_prompt', ''),
        use_style_embeddings=bool(data.get('use_style_embeddings', True)),
        embeddings_top_k=int(data.get('embeddings_top_k', 10)),
        embeddings_model_id=data.get('embeddings_model_id', 'openai/clip-vit-base-patch32'),
    )


def timestamp() -> str:
    return time.strftime('%Y%m%d_%H%M%S')
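
`load_style` looks for a `style.json` (or `styles.json`) file inside each style folder and reads five optional keys from it. The sketch below writes and reads back one such file to show the expected shape. The key names come from `load_style`; the values and the `demo_style` folder name are illustrative assumptions only.

```python
import json
import tempfile
from pathlib import Path

# Keys mirror what load_style reads; the values are illustrative only.
example_style = {
    'prompt_suffix': 'flat colors, clean outlines',
    'negative_prompt': 'blurry, low quality, watermark',
    'use_style_embeddings': True,
    'embeddings_top_k': 10,
    'embeddings_model_id': 'openai/clip-vit-base-patch32',
}

with tempfile.TemporaryDirectory() as tmp:
    style_dir = Path(tmp) / 'input_images' / 'demo_style'  # hypothetical style name
    style_dir.mkdir(parents=True)
    config_path = style_dir / 'style.json'
    config_path.write_text(json.dumps(example_style, indent=2), encoding='utf-8')
    loaded = json.loads(config_path.read_text(encoding='utf-8'))

print(sorted(loaded))
```

Any key left out of the file falls back to the default passed to `data.get(...)` in `load_style`.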

In [None]:
#@title Setup the model
IMG_EXTS = {'.png', '.jpg', '.jpeg', '.webp'}

STYLE_VOCAB = [
    'anime style', 'manga panel', 'cel shading', 'clean line art', 'vibrant colors',
    'soft pastel colors', 'high contrast lighting', 'cinematic lighting', 'dramatic shadows',
    'watercolor painting', 'oil painting', 'digital painting', 'concept art', '3d render',
    'pixel art', 'low poly 3d', 'photorealistic', 'film grain', 'bokeh background',
    'neon cyberpunk', 'futuristic cityscape', 'retro synthwave', 'comic book ink',
    'studio ghibli inspired', 'minimalist illustration', 'flat vector art', 'highly detailed',
    'intricate details', 'soft focus', 'sharp focus', 'dynamic composition',
    'symmetrical composition', 'black and white', 'monochrome', 'manga style',
    'painting style', 'drawing style', 'sketch style', 'fantasy art', 'sci-fi art',
    'nature landscape', 'urban cityscape', 'soft lighting', 'hard lighting',
    'volumetric lighting', 'golden hour lighting', 'moody lighting', 'studio lighting',
    'rim lighting', 'backlit subject', 'foggy atmosphere', 'misty environment',
    'dramatic sky', 'overcast lighting', 'professional photography', 'portrait photography',
    'street photography', 'cinematic photography', 'depth of field',
    'shallow depth of field', 'ultra realistic', 'high dynamic range', 'hdr photography',
    'natural skin texture', 'semi realistic painting', 'stylized illustration',
    'fantasy illustration', 'game concept art', 'character design',
    'environment concept art', 'matte painting', 'detailed background',
    'hand painted texture', 'realistic 3d render', 'cgi rendering',
    'octane render style', 'unreal engine style', 'ray traced lighting',
    'global illumination', 'subsurface scattering', 'high detail textures',
    'gritty texture', 'smooth surfaces', 'weathered materials', 'metallic reflections',
    'glass reflections', 'centered composition', 'rule of thirds composition',
    'high resolution', 'extremely detailed', 'masterpiece quality',
    'anime lighting', 'dynamic pose', 'expressive character', 'clean coloring',
    'illustration style shading',
]


@dataclass
class StyleEmbeddingResult:
    keywords: list[str]
    scores: list[float]
    used_images: list[str]


def list_reference_images(style_folder: Path) -> list[Path]:
    if not style_folder.exists():
        return []
    return sorted([p for p in style_folder.iterdir() if p.is_file() and p.suffix.lower() in IMG_EXTS])


def _device_and_dtype(device: str) -> tuple[str, torch.dtype]:
    use_cuda = device == 'cuda' and torch.cuda.is_available()
    return ('cuda' if use_cuda else 'cpu'), (torch.float16 if use_cuda else torch.float32)


@torch.no_grad()
def extract_style_keywords(
    style_folder: Path,
    device: str = 'cuda',
    top_k: int = 20,
    clip_model_id: str = 'openai/clip-vit-base-patch32',
) -> StyleEmbeddingResult:
    img_paths = list_reference_images(style_folder)
    if not img_paths:
        return StyleEmbeddingResult(keywords=[], scores=[], used_images=[])

    dev, _dtype = _device_and_dtype(device)

    model = CLIPModel.from_pretrained(clip_model_id).to(dev)
    processor = CLIPProcessor.from_pretrained(clip_model_id)
    if dev == 'cuda':
        model = model.half()

    imgs = [Image.open(p).convert('RGB') for p in img_paths]
    img_inputs = processor(images=imgs, return_tensors='pt')
    img_inputs = {k: v.to(dev) for k, v in img_inputs.items()}

    vision_out = model.vision_model(pixel_values=img_inputs['pixel_values'], return_dict=True)
    pooled = vision_out.pooler_output
    image_features = model.visual_projection(pooled)

    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    style_vec = image_features.mean(dim=0, keepdim=True)
    style_vec = style_vec / style_vec.norm(dim=-1, keepdim=True)

    text_inputs = processor(text=STYLE_VOCAB, return_tensors='pt', padding=True)
    text_inputs = {k: v.to(dev) for k, v in text_inputs.items()}

    text_out = model.text_model(
        input_ids=text_inputs['input_ids'],
        attention_mask=text_inputs.get('attention_mask', None),
        return_dict=True,
    )
    text_pooled = text_out.pooler_output
    text_features = model.text_projection(text_pooled)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    sims = (text_features @ style_vec.T).squeeze(1)
    top_k = max(0, min(int(top_k), sims.numel()))
    vals, idx = torch.topk(sims, k=top_k)

    keywords = [STYLE_VOCAB[i] for i in idx.tolist()]
    scores = [float(v) for v in vals.tolist()]
    used = [p.name for p in img_paths]

    return StyleEmbeddingResult(keywords=keywords, scores=scores, used_images=used)
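
`extract_style_keywords` normalizes each image embedding, averages them into one style vector, renormalizes it, and ranks the vocabulary by dot product. The same steps could be reused for task 2 to rank generated images against the reference style. The sketch below reproduces that arithmetic in plain Python on toy 2-D vectors so it runs without a GPU. The function names and the sample vectors are hypothetical; real inputs would be CLIP image embeddings.

```python
import math


def normalize(vec: list[float]) -> list[float]:
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vec))
    if norm == 0:
        raise ValueError('cannot normalize a zero vector')
    return [x / norm for x in vec]


def style_vector(embeddings: list[list[float]]) -> list[float]:
    """Normalize each embedding, average them, then renormalize the mean."""
    unit = [normalize(e) for e in embeddings]
    dim = len(unit[0])
    mean = [sum(u[i] for u in unit) / len(unit) for i in range(dim)]
    return normalize(mean)


def rank_by_similarity(
    style_vec: list[float],
    candidates: dict[str, list[float]],
) -> list[tuple[str, float]]:
    """Return (name, cosine similarity) pairs, most similar first."""
    scored = [
        (name, sum(a * b for a, b in zip(style_vec, normalize(vec))))
        for name, vec in candidates.items()
    ]
    return sorted(scored, key=lambda item: item[1], reverse=True)


# Toy vectors stand in for embeddings of reference and generated images.
refs = [[1.0, 0.0], [0.8, 0.6]]
generated = {'candidate_a.png': [0.9, 0.3], 'candidate_b.png': [0.0, 1.0]}
ranking = rank_by_similarity(style_vector(refs), generated)
```

Here `candidate_a.png` ranks first because it points in the same direction as the averaged reference vector. With real embeddings, the top-ranked output would be the one CLIP considers closest to the reference style.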

In [None]:
#@title Building the prompt to the model
def build_prompt(user_prompt: str, style: Style, style_keywords: list[str]) -> str:
    parts = [user_prompt]
    if style.prompt_suffix:
        parts.append(style.prompt_suffix)
    if style_keywords:
        parts.append(', '.join(style_keywords))
    return ', '.join([p for p in parts if p])


def _load_style_cache(cache_path: Path) -> dict | None:
    if not cache_path.exists():
        return None
    try:
        return json.loads(cache_path.read_text(encoding='utf-8'))
    except Exception:
        return None


def _save_style_cache(cache_path: Path, data: dict) -> None:
    cache_path.write_text(json.dumps(data, indent=2), encoding='utf-8')


def generate_one(
    model_id: str,
    user_prompt: str,
    style: Style,
    out_root: Path,
    steps: int = 20,
    guidance: float = 3.5,
    height: int = 512,
    width: int = 512,
    device: str = 'cuda',
    cpu_offload: bool = True,
    no_t5: bool = True,
):
    use_cuda = device == 'cuda' and torch.cuda.is_available()
    dtype = torch.float16 if use_cuda else torch.float32

    style_keywords: list[str] = []
    cache_file = style.folder / '.style_keywords_cache.json'

    if style.use_style_embeddings:
        cache = _load_style_cache(cache_file)
        if cache and cache.get('top_k') == style.embeddings_top_k and cache.get('model_id') == style.embeddings_model_id:
            style_keywords = cache.get('keywords', [])
        else:
            emb = extract_style_keywords(
                style_folder=style.folder,
                device='cuda' if use_cuda else 'cpu',
                top_k=style.embeddings_top_k,
                clip_model_id=style.embeddings_model_id,
            )
            style_keywords = emb.keywords
            _save_style_cache(cache_file, {
                'model_id': style.embeddings_model_id,
                'top_k': style.embeddings_top_k,
                'keywords': style_keywords,
                'scores': emb.scores,
                'used_images': emb.used_images,
            })

        print('Embedding style keywords:', style_keywords)

    pipe_kwargs = dict(torch_dtype=dtype)
    if no_t5:
        pipe_kwargs['text_encoder_3'] = None
        pipe_kwargs['tokenizer_3'] = None

    try:
        pipe = StableDiffusion3Pipeline.from_pretrained(model_id, **pipe_kwargs)
    except Exception as e:
        msg = str(e)
        if 'gated' in msg.lower() or '401' in msg or 'unauthorized' in msg.lower():
            raise RuntimeError(
                f"Cannot access model '{model_id}'. It is likely gated on Hugging Face. "
                "Run the HF login cell, request access on the model page, then rerun."
            ) from e
        raise
    pipe.enable_attention_slicing()
    if hasattr(pipe, 'vae'):
        pipe.vae.enable_slicing()
        pipe.vae.enable_tiling()

    if cpu_offload and use_cuda:
        # Model CPU offload needs an accelerator, so skip it on CPU-only runtimes
        pipe.enable_model_cpu_offload()
    else:
        pipe = pipe.to('cuda' if use_cuda else 'cpu')

    prompt = build_prompt(user_prompt, style, style_keywords)
    print('Final prompt:', prompt)

    call_kwargs = dict(
        num_inference_steps=steps,
        guidance_scale=guidance,
        height=height,
        width=width,
    )

    negative_prompt = getattr(style, 'negative_prompt', '') or ''
    if negative_prompt:
        call_kwargs['negative_prompt'] = negative_prompt

    result = pipe(prompt, **call_kwargs)
    images = result.images

    out_dir = out_root / style.name
    out_dir.mkdir(parents=True, exist_ok=True)

    ts = timestamp()
    out_paths = []
    for i, img in enumerate(images):
        out_path = out_dir / f'{ts}_{style.name}_{i}.png'
        img.save(out_path)
        out_paths.append(out_path)

    return out_paths
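
On task 4: `generate_one` calls `StableDiffusion3Pipeline.from_pretrained` on every invocation, so a 50-prompt batch built by calling it in a loop would reload the model 50 times. A batch version would probably load the pipeline once and then loop over the prompts. The sketch below shows only the loop and the per-prompt error collection; `run_batch` and `fake_generate` are hypothetical names, and `fake_generate` stands in for a real generation call.

```python
from typing import Callable


def run_batch(
    prompts: list[str],
    generate_fn: Callable[[str], list[str]],
) -> tuple[dict[str, list[str]], dict[str, str]]:
    """Call generate_fn once per prompt; collect outputs and per-prompt errors."""
    outputs: dict[str, list[str]] = {}
    failures: dict[str, str] = {}
    for prompt in prompts:
        try:
            outputs[prompt] = generate_fn(prompt)
        except Exception as exc:  # keep going so one bad prompt does not end the batch
            failures[prompt] = str(exc)
    return outputs, failures


def fake_generate(prompt: str) -> list[str]:
    """Stand-in for a real generation call; fails on an empty prompt."""
    if not prompt.strip():
        raise ValueError('empty prompt')
    return [f'{prompt[:10]}_0.png']


done, failed = run_batch(['a red saree', '', 'a blue lehenga'], fake_generate)
```

Collecting failures instead of raising means a long batch on a free runtime can be resumed from the prompts listed in `failed` rather than restarted from scratch.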

In [None]:
#@title Take User inputs

# TBD: Experiment with different prompts to get closer to the input_images

# Earlier attempt, kept for comparison (the PROMPT assignment below is the one in use)
# PROMPT = "Portrait of an Indian woman AI and ML scientist at her workstation, digital painting, soft brush strokes, teal-orange palette, cinematic rim lighting, highly detailed, in the visual style of the provided reference images"

PROMPT = "Indian woman researcher in machine learning presenting neural network diagrams, anime illustration, clean line art, vibrant colors, soft lighting, expressive character design, in the visual style of the provided reference images"

# Must match a folder name under input_images/
STYLE_NAME = 'first'
MODEL_ID = 'stabilityai/stable-diffusion-3.5-medium'

# Increase the steps for better results (25 to 40)
STEPS = 30

GUIDANCE = 3.5
HEIGHT = 512
WIDTH = 512
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
CPU_OFFLOAD = True

# Set to False for better prompt understanding (loads the T5 text encoder, which uses more memory)
NO_T5 = True

available = list_styles(INPUT_ROOT)
print('Available styles:', available)
if STYLE_NAME not in available:
    raise ValueError(f"Unknown style '{STYLE_NAME}'. Available: {available}")

style = load_style(INPUT_ROOT, STYLE_NAME)
style_files = [p for p in style.folder.glob('*.*')]
print('Using style:', style.name)
print('Style folder:', style.folder)
print('Files in style folder:', len(style_files))
for p in style_files[:10]:
    print(' -', p.name)

paths = generate_one(
    model_id=MODEL_ID,
    user_prompt=PROMPT,
    style=style,
    out_root=OUTPUT_ROOT,
    steps=STEPS,
    guidance=GUIDANCE,
    height=HEIGHT,
    width=WIDTH,
    device=DEVICE,
    cpu_offload=CPU_OFFLOAD,
    no_t5=NO_T5,
)

print('Saved:')
for p in paths:
    print(' -', p)

Available styles: ['.ipynb_checkpoints', 'babu']
Using style: babu
Style folder: /content/input_images/babu
Files in style folder: 15
 - 07 Garad Saree.jpg
 - 06 Maharashtra Nauvari.jpg
 - 14 Lehenga Choli.jpg
 - 05 Kerala Kasavu.jpg
 - 11 Coorgi Saree.jpg
 - .style_keywords_cache.json
 - 13 Bihu Dress.jpg
 - 15 Goa Koli Dress.jpg
 - 08 Madisar Kattu.jpg
 - 10 Lambada dress.jpg
Embedding style keywords: ['stylized illustration', 'illustration style shading', 'flat vector art', 'character design', 'pixel art', 'expressive character', 'drawing style', 'clean coloring', 'painting style', 'sketch style']


Final prompt: Indian woman researcher in machine learning presenting neural network diagrams, anime illustration, clean line art, vibrant colors, soft lighting, expressive character design, in the visual style of the provided reference images, stylized illustration, illustration style shading, flat vector art, character design, pixel art, expressive character, drawing style, clean coloring, painting style, sketch style


Saved:
 - /content/output_images/babu/20260226_044435_babu_0.png


In [None]:
#@title Download the output images
# Zip + download all generated images from output_images
import shutil
from google.colab import files

zip_base = str(PROJECT_ROOT / 'output_images')
zip_path = shutil.make_archive(zip_base, 'zip', root_dir=OUTPUT_ROOT)
print('Created zip:', zip_path)
files.download(zip_path)

Created zip: /content/output_images.zip

