In [2]:
import os
import pandas as pd
import torch
from PIL import Image
from datasets import Dataset
from peft import LoraConfig, get_peft_model_state_dict
from transformers import CLIPTokenizer, CLIPTextModel
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from diffusers.training_utils import compute_snr, cast_training_params
from torch.utils.data import DataLoader
from torchvision import transforms
from accelerate import Accelerator
import torch.nn.functional as F
from tqdm import tqdm

    PyTorch 2.2.1+cu121 with CUDA 1201 (you have 2.2.1+cu118)
    Python  3.10.11 (you have 3.10.16)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


In [52]:
print(torch.__version__)
print(f"Using device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

2.7.1+cu118
Using device: cuda


In [4]:
def load_from_csv(csv_path):
    df = pd.read_csv(csv_path)
    examples = []

    for _, row in df.iterrows():
        # Use only the image filename, ignore bad prefix if needed
        image_filename = os.path.basename(row["image_path"])
        image_path = os.path.join("dataset", image_filename)

        if not os.path.exists(image_path):
            print(f"❌ Image not found: {image_path}")
            continue

        parts = [
            str(row.get("description", "")).strip(),
            "with side salad" if str(row.get("Beilagensalat", "")).lower() in ["yes", "ja"] else "",
            "with regional apple" if str(row.get("Regio Apfel", "")).lower() in ["yes", "ja"] else "",
            f"type: {row.get('type', '')}",
            f"diet: {row.get('diet', '')}",
            f"served at {row.get('mensa', '')}"
        ]

        prompt = ", ".join([p for p in parts if p])
        examples.append({"image": Image.open(image_path).convert("RGB"), "text": prompt})

    return examples


In [5]:
test_examples = load_from_csv("./dataset/dataset.csv")
print(test_examples)

[{'image': <PIL.Image.Image image mode=RGB size=3840x2880 at 0x2C9ABD0F610>, 'text': 'Apfelstrudel Vanillesauce Sauerkirschkompott, type: Essen 1, diet: Vegetarisch, served at Mensa Institutsviertel'}, {'image': <PIL.Image.Image image mode=RGB size=3468x4624 at 0x2C9ABD0F700>, 'text': 'Aglio Spaghetti mit getrockneten Tomaten, type: Schneller Teller, diet: Vegan, served at Mensa Institutsviertel'}]


In [6]:
def tokenize_and_transform(example, tokenizer, resolution):
    encoding = tokenizer(
        example["text"],
        padding="max_length",
        truncation=True,
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    )

    # Squeeze to get (seq_len,) shape instead of (1, seq_len)
    example["input_ids"] = encoding["input_ids"].squeeze(0)

    transform = transforms.Compose([
        transforms.Resize(resolution, interpolation=Image.BILINEAR),
        transforms.CenterCrop(resolution),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    example["pixel_values"] = transform(example["image"])
    return example


In [7]:
accelerator = Accelerator()
device = accelerator.device
resolution = 512

# Set paths for the dataset and model output
dataset_csv = "./dataset/dataset.csv"  # CSV with metadata and paths
output_dir = "./lora_cafeteria_output"  # Where to save LoRA weights
pretrained_model = "CompVis/stable-diffusion-v1-4"  # Pretrained SD base

In [8]:


tokenizer = CLIPTokenizer.from_pretrained(pretrained_model, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(pretrained_model, subfolder="text_encoder").to(device)
vae = AutoencoderKL.from_pretrained(pretrained_model, subfolder="vae").to(device)
unet = UNet2DConditionModel.from_pretrained(pretrained_model, subfolder="unet").to(device)
noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model, subfolder="scheduler")



In [9]:
# Freeze all layers except LoRA
unet.requires_grad_(False)
text_encoder.requires_grad_(False)
vae.requires_grad_(False)

AutoencoderKL(
  (encoder): Encoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_blocks): ModuleList(
      (0): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0-1): 2 x ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (downsamplers): ModuleList(
          (0): Downsample2D(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
          )
        )
      )
      (1): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0): ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (c

In [10]:
lora_config = LoraConfig(
    r=4,
    lora_alpha=16,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    lora_dropout=0.1,
    bias="none"
)
unet.add_adapter(lora_config)
cast_training_params(unet, dtype=torch.float32)

In [11]:

examples = load_from_csv(dataset_csv)
dataset = Dataset.from_list(examples)
print(dataset)
dataset = dataset.map(lambda e: tokenize_and_transform(e, tokenizer, resolution))
dataset.set_format(type="torch", columns=["input_ids", "pixel_values"])

sample = dataset[0]
print(type(sample["input_ids"]))          # should be torch.Tensor
print(sample["input_ids"].shape)          # should be (77,)
print(sample["pixel_values"].shape) 

Dataset({
    features: ['image', 'text'],
    num_rows: 2
})


Map:   0%|          | 0/2 [00:00<?, ? examples/s]

<class 'torch.Tensor'>
torch.Size([77])
torch.Size([3, 512, 512])


In [12]:
def collate_fn(batch):
    return {
        "input_ids": torch.stack([x["input_ids"] for x in batch]),
        "pixel_values": torch.stack([x["pixel_values"] for x in batch])
    }


In [13]:
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)

optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, unet.parameters()), lr=1e-4)

unet, optimizer, dataloader = accelerator.prepare(unet, optimizer, dataloader)


In [14]:
unet.train()
for epoch in range(10):
    for step, batch in enumerate(tqdm(dataloader)):
        with accelerator.accumulate(unet):
            # VAE encode
            latents = vae.encode(batch["pixel_values"].to(device)).latent_dist.sample()
            latents = latents * vae.config.scaling_factor
            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            encoder_hidden_states = text_encoder(batch["input_ids"].to(device))[0]
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

            loss = F.mse_loss(model_pred, noise)

            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
    print(f"Epoch {epoch + 1} complete. Loss: {loss.item():.4f}")

  hidden_states = F.scaled_dot_product_attention(
100%|██████████| 2/2 [00:07<00:00,  3.81s/it]


Epoch 1 complete. Loss: 0.0275


100%|██████████| 2/2 [00:09<00:00,  4.80s/it]


Epoch 2 complete. Loss: 0.0275


100%|██████████| 2/2 [00:08<00:00,  4.21s/it]


Epoch 3 complete. Loss: 0.1605


100%|██████████| 2/2 [00:07<00:00,  3.72s/it]


Epoch 4 complete. Loss: 0.0379


100%|██████████| 2/2 [00:07<00:00,  3.64s/it]


Epoch 5 complete. Loss: 0.1746


100%|██████████| 2/2 [00:07<00:00,  3.61s/it]


Epoch 6 complete. Loss: 0.0079


100%|██████████| 2/2 [00:07<00:00,  3.69s/it]


Epoch 7 complete. Loss: 0.2042


100%|██████████| 2/2 [00:07<00:00,  3.72s/it]


Epoch 8 complete. Loss: 0.0174


100%|██████████| 2/2 [00:07<00:00,  3.75s/it]


Epoch 9 complete. Loss: 0.0244


100%|██████████| 2/2 [00:07<00:00,  3.85s/it]


Epoch 10 complete. Loss: 0.2331


In [15]:
accelerator.wait_for_everyone()
lora_state_dict = get_peft_model_state_dict(unet)
torch.save(lora_state_dict, os.path.join(output_dir, "pytorch_lora_weights.bin"))
print("LoRA fine-tuning complete.")

LoRA fine-tuning complete.
