<a href="https://colab.research.google.com/github/palis-dev/jupyter-notebooks/blob/main/ml_visao_comp_dnd_complete.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install diffusers transformers accelerate datasets torch torchvision


Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m26.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl 

In [None]:
import os
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, AutoencoderKL
from diffusers import DDPMScheduler
from datasets import load_dataset
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from accelerate import Accelerator
import json

# Directory holding the training images named in the captions file.
dataset_dir = "data/dnd_dataset/images"
# Text file with one "<image file name>|<caption>" entry per line.
captions_file = "data/dnd_dataset/captions.txt"
# Destination directory for the fine-tuned weights; created up front if missing.
output_dir = "models/dnd_fine_tuned"
os.makedirs(output_dir, exist_ok=True)

class CustomDataset(torch.utils.data.Dataset):
    """Image/caption dataset backed by an image directory and a captions file.

    Every non-blank line of the captions file has the form
    ``<image file name>|<caption>``. Image file names are resolved relative
    to ``dataset_dir``.

    Args:
        dataset_dir: Directory containing the image files.
        captions_file: Path to the ``|``-separated captions file.
        resolution: Side length (pixels) every image is resized to.

    Raises:
        ValueError: If a non-blank line of the captions file has no ``|``.
    """

    def __init__(self, dataset_dir, captions_file, resolution=512):
        self.dataset_dir = dataset_dir
        self.resolution = resolution
        # List of (image file name, caption) tuples, in file order.
        self.captions = self._read_captions(captions_file)

        # Resize to a square, convert to a tensor and normalize with
        # mean 0.5 / std 0.5.
        self.transform = transforms.Compose([
            transforms.Resize((resolution, resolution)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

    @staticmethod
    def _read_captions(captions_file):
        """Parse ``captions_file`` into a list of (image name, caption) tuples."""
        entries = []
        with open(captions_file, "r", encoding="utf-8") as f:
            for line_number, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                # Split on the first "|" only so captions may contain "|".
                image_path, separator, caption = line.partition("|")
                if not separator:
                    raise ValueError(
                        f"{captions_file}:{line_number}: expected "
                        f"'<image>|<caption>', got {line!r}"
                    )
                entries.append((image_path.strip(), caption.strip()))
        return entries

    def __len__(self):
        """Return the number of (image, caption) pairs."""
        return len(self.captions)

    def __getitem__(self, index):
        """Return ``{"image": transformed tensor, "caption": str}`` for ``index``."""
        image_path, caption = self.captions[index]
        # The context manager closes the underlying file once the pixel data
        # has been copied by convert().
        with Image.open(os.path.join(self.dataset_dir, image_path)) as raw_image:
            image = raw_image.convert("RGB")
        image = self.transform(image)
        return {"image": image, "caption": caption}


In [None]:
def load_models(pretrained_model_name="CompVis/stable-diffusion-v1-4"):
    """Load the individual Stable Diffusion components from a checkpoint.

    Args:
        pretrained_model_name: Model identifier or path passed to each
            component's ``from_pretrained``.

    Returns:
        The tuple ``(vae, unet, text_encoder, tokenizer, scheduler)``, each
        loaded from the subfolder of the same name.
    """
    def _from_subfolder(component_cls, subfolder):
        # Every component lives in its own subfolder of the checkpoint.
        return component_cls.from_pretrained(pretrained_model_name, subfolder=subfolder)

    return (
        _from_subfolder(AutoencoderKL, "vae"),
        _from_subfolder(UNet2DConditionModel, "unet"),
        _from_subfolder(CLIPTextModel, "text_encoder"),
        _from_subfolder(CLIPTokenizer, "tokenizer"),
        _from_subfolder(DDPMScheduler, "scheduler"),
    )

def fine_tune_stable_diffusion(dataset_dir, captions_file, output_dir, epochs=5, batch_size=8, resolution=512):
    """Fine-tune the Stable Diffusion UNet on an image/caption dataset.

    Only the UNet is optimized; the VAE and the text encoder are frozen and
    used purely as feature extractors. After every epoch the UNet is saved to
    ``<output_dir>/unet_epoch_<n>``. At the end it is saved to
    ``<output_dir>/fine_tuned_unet``, and the full set of components plus a
    ``model_index.json`` is written to ``output_dir`` so that
    ``generate_image`` can load them from the per-component subfolders.

    Args:
        dataset_dir: Directory containing the training images.
        captions_file: ``<image>|<caption>`` file describing the images.
        output_dir: Directory that receives all saved weights.
        epochs: Number of passes over the dataset.
        batch_size: Number of samples per optimization step.
        resolution: Side length (pixels) the images are resized to.
    """
    vae, unet, text_encoder, tokenizer, scheduler = load_models()

    # The optimizer below only holds UNet parameters, so freeze the other two
    # networks to avoid building autograd graphs for them.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    vae.eval()
    text_encoder.eval()
    unet.train()

    dataset = CustomDataset(dataset_dir, captions_file, resolution)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    optimizer = torch.optim.AdamW(unet.parameters(), lr=5e-6)

    accelerator = Accelerator()
    # Only objects taking part in optimization go through prepare(); the
    # frozen networks are just moved to the training device.
    unet, optimizer, dataloader = accelerator.prepare(unet, optimizer, dataloader)
    vae.to(accelerator.device)
    text_encoder.to(accelerator.device)

    # Read from the config: direct attribute access on the scheduler is
    # deprecated and emits a warning.
    num_train_timesteps = scheduler.config.num_train_timesteps

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        for batch in dataloader:
            images = batch["image"].to(accelerator.device)
            captions = batch["caption"]

            inputs = tokenizer(captions, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
            input_ids = inputs.input_ids.to(accelerator.device)

            with torch.no_grad():
                # 0.18215 is the latent scaling constant used by the original code.
                latents = vae.encode(images).latent_dist.sample() * 0.18215
                encoder_hidden_states = text_encoder(input_ids)[0]

            # Noise must match the latent shape whatever the resolution, so it
            # is derived from the latents instead of a hard-coded 4x64x64.
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, num_train_timesteps, (latents.size(0),), device=latents.device
            ).long()

            # Forward diffusion: the UNet has to see the *noised* latents,
            # otherwise the noise target is unrelated to its input.
            noisy_latents = scheduler.add_noise(latents, noise, timesteps)

            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample

            loss = torch.nn.functional.mse_loss(noise_pred, noise)
            accelerator.backward(loss)

            optimizer.step()
            optimizer.zero_grad()

        if accelerator.is_main_process:
            # unwrap_model strips any wrapper added by prepare() so that
            # save_pretrained is called on the underlying UNet.
            accelerator.unwrap_model(unet).save_pretrained(
                os.path.join(output_dir, f"unet_epoch_{epoch + 1}")
            )

    trained_unet = accelerator.unwrap_model(unet)
    if accelerator.is_main_process:
        trained_unet.save_pretrained(os.path.join(output_dir, "fine_tuned_unet"))

    print("Fine-tuning complete!")

    if accelerator.is_main_process:
        # generate_image loads "vae", "unet", "text_encoder", "tokenizer" and
        # "scheduler" subfolders from the model directory, so write them here.
        save_pipeline(output_dir, vae, trained_unet, text_encoder, tokenizer, scheduler)
        create_model_index_json(output_dir)

def save_pipeline(output_dir, vae, unet, text_encoder, tokenizer, scheduler):
    """Write each pipeline component into its own subfolder of ``output_dir``.

    Args:
        output_dir: Root directory of the saved pipeline.
        vae, unet, text_encoder, tokenizer, scheduler: Components exposing
            ``save_pretrained(path)``; each is stored in the subfolder that
            carries its parameter name.
    """
    print("Saving the complete fine-tuned pipeline...")
    named_components = (
        ("vae", vae),
        ("unet", unet),
        ("text_encoder", text_encoder),
        ("tokenizer", tokenizer),
        ("scheduler", scheduler),
    )
    for subfolder, component in named_components:
        target = os.path.join(output_dir, subfolder)
        component.save_pretrained(target)
    print(f"Pipeline saved to {output_dir}")
def create_model_index_json(output_dir):
    """Write a ``model_index.json`` describing the pipeline into ``output_dir``.

    The file names the pipeline class and, for each component, the library
    and class it is loaded with. The feature extractor and safety checker
    entries are set to ``null``.

    Args:
        output_dir: Existing directory that receives ``model_index.json``.
    """
    index_path = os.path.join(output_dir, "model_index.json")

    # (component name, [library, class name]) in the order they are written.
    component_entries = (
        ("scheduler", ["diffusers", "DDPMScheduler"]),
        ("text_encoder", ["transformers", "CLIPTextModel"]),
        ("tokenizer", ["transformers", "CLIPTokenizer"]),
        ("unet", ["diffusers", "UNet2DConditionModel"]),
        ("vae", ["diffusers", "AutoencoderKL"]),
    )

    model_index = {
        "_class_name": "StableDiffusionPipeline",
        "_module_name": "diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion",
        "feature_extractor": None,
        "safety_checker": None,
    }
    for component_name, library_and_class in component_entries:
        model_index[component_name] = library_and_class

    serialized = json.dumps(model_index, indent=2)
    with open(index_path, "w") as f:
        f.write(serialized)
    print(f"model_index.json created in {output_dir}")


In [None]:
# Run the fine-tuning on the configured dataset for 5 epochs at 512x512,
# with a batch size of 4 instead of the function's default of 8.
fine_tune_stable_diffusion(dataset_dir, captions_file, output_dir, epochs=5, batch_size=4, resolution=512)

Epoch 1/5


  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Fine-tuning complete!
Saving the complete fine-tuned pipeline...
Pipeline saved to models/dnd_fine_tuned
model_index.json created in models/dnd_fine_tuned


In [None]:
# Text prompt passed to generate_image at the end of this cell.
prompt = "A medieval fantasy map with forests, rivers, and mountains"
def generate_image(prompt, model_dir, output_file):
    """Generate one image for ``prompt`` and save it to ``output_file``.

    The pipeline is assembled from components loaded out of the ``vae``,
    ``unet``, ``text_encoder``, ``tokenizer`` and ``scheduler`` subfolders of
    ``model_dir``. The safety checker and feature extractor are disabled.

    Args:
        prompt: Text prompt to condition the generation on.
        model_dir: Directory containing the per-component subfolders.
        output_file: Path the generated image is written to.
    """
    vae = AutoencoderKL.from_pretrained(model_dir, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(model_dir, subfolder="unet")
    text_encoder = CLIPTextModel.from_pretrained(model_dir, subfolder="text_encoder")
    tokenizer = CLIPTokenizer.from_pretrained(model_dir, subfolder="tokenizer")
    scheduler = DDPMScheduler.from_pretrained(model_dir, subfolder="scheduler")

    # Prefer the GPU but fall back to the CPU, so the function does not crash
    # on runtimes without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    pipe = StableDiffusionPipeline(
        vae=vae,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        scheduler=scheduler,
        safety_checker=None,
        feature_extractor=None,
    ).to(device)

    # The pipeline returns a batch of images; only the first one is saved.
    image = pipe(prompt, guidance_scale=7.5).images[0]
    image.save(output_file)

# Generate a sample image from the components stored under
# models/dnd_fine_tuned and save it as generated_map.png.
generate_image(prompt, "models/dnd_fine_tuned", "generated_map.png")

You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


  0%|          | 0/50 [00:00<?, ?it/s]

model_index.json:   0%|          | 0.00/541 [00:00<?, ?B/s]

Fetching 16 files:   0%|          | 0/16 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

(…)ature_extractor/preprocessor_config.json:   0%|          | 0.00/342 [00:00<?, ?B/s]

safety_checker/config.json:   0%|          | 0.00/4.56k [00:00<?, ?B/s]

(…)kpoints/scheduler_config-checkpoint.json:   0%|          | 0.00/209 [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Pipeline loaded successfully!
