# Fine tune stable diffusion model for brain MRI image generation
Victor Micha, Mindiiarova Renata

In [None]:
!pip install datasets
!pip install kagglehub
!pip install transformers torch torchvision accelerate
!pip uninstall diffusers -y
!pip install git+https://github.com/huggingface/diffusers.git@main # need latest version

In [2]:
import os
import pandas as pd
from diffusers import StableDiffusionPipeline
import torch

import kagglehub
import shutil

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cpu")
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
assert str(device)=='cuda' or str(device)=='mps', f"{device}" # need a GPU!
print(device)

mps


## Pair images with prompts
Images need to be paired with prompts for fine tuning of diffusion model

In [3]:
# DATASET AT https://www.kaggle.com/datasets/sartajbhuvaji/brain-tumor-classification-mri
# OR RUN FOLLOWING TO DOWNLOAD IT:

# Download latest version
path = kagglehub.dataset_download("sartajbhuvaji/brain-tumor-classification-mri")

# Get the current working directory
current_dir = os.getcwd()
new_path = os.path.join(current_dir, "archive")

# Move the downloaded directory to the current directory
shutil.move(path, new_path)

print("Path to dataset files:", new_path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/sartajbhuvaji/brain-tumor-classification-mri?dataset_version_number=2...


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 86.8M/86.8M [00:02<00:00, 36.4MB/s]

Extracting files...





Path to dataset files: /Users/victormicha/PythonProjects/IP_Paris/CVGenerativeAI_EP/GenAIBrainMRI/archive


In [3]:
#!ls archive/Training/* | wc -l
import subprocess

# Run the command and capture its output
result = subprocess.run("ls archive/Training/* | wc -l", shell=True, capture_output=True, text=True)

# Convert the output to an integer
file_count = int(result.stdout.strip())

# Assert that the file count is 2877
assert file_count == 2877, f"Expected 2877 files, but found {file_count}"

In [12]:

# Define paths
base_path = "archive/Training"  # Use training data for fine-tuning
classes = ["glioma_tumor", "meningioma_tumor", "pituitary_tumor", "no_tumor"]
N_CLASSES = len(classes)
assert N_CLASSES==4
prompts = {
    "glioma_tumor": "MRI of brain with glioma tumor",
    "meningioma_tumor": "MRI of brain with meningioma tumor",
    "pituitary_tumor": "MRI of brain with pituitary tumor",
    "no_tumor": "MRI of brain with no tumor"
}

# Collect image paths and prompts
data = []
for cls in classes:
    cls_path = os.path.join(base_path, cls)
    for img_file in os.listdir(cls_path):
        if img_file.endswith(".jpg"):
            img_path = os.path.join(cls_path, img_file)
            data.append({"image_path": img_path, "prompt": prompts[cls]})

# Create DataFrame
df = pd.DataFrame(data)

####################################################
######### TO TAKE A CERTAIN NUMBER OF SAMPLES PER CLASS
# Take 500-1000 samples (e.g., 125-250 per class)
NUM_SAMPLES_PER_CLASS=200 # increase this to fine tune model with more image, prompt pairs
# assert NUM_SAMPLES_PERCLASS <= 395
print(f"Total amount of image prompt pairs: {NUM_SAMPLES_PER_CLASS*N_CLASSES}, {NUM_SAMPLES_PER_CLASS} per class ({N_CLASSES} classes total)")
# samples a certain of number per class
df = df.groupby("prompt").sample(n=NUM_SAMPLES_PER_CLASS, random_state=42).reset_index(drop=True)
####################################################
####################################################
########### TO FINE TUNE ON ALL IMAGES IN TRAINING DIR!
# we are keeping all images
# make sure we have whole training data in df
"""
print("Number of samples per class:")
class_counts = df['prompt'].value_counts()
for prompt, count in class_counts.items():
    print(f"{prompt}: {count}")
    if 'pituitary' in prompt:
        assert count==827
    elif 'glioma' in prompt:
        assert count==826
    elif 'meningioma' in prompt:
        assert count==822
    elif 'no' in prompt:
        assert count==395
assert len(df) == 2870
"""
####################################################



print(f"Total amount of image prompt pairs: {len(df)}")

# Save to CSV
IMAGE_PROMPT_PAIRS_CSV="fine_tuning_metadata.csv"
df.to_csv(IMAGE_PROMPT_PAIRS_CSV, index=False, encoding='utf-8') # utf-8 for fine tuning (later)

print(f"'{IMAGE_PROMPT_PAIRS_CSV}' contains the image prompt pairs")
print(df.head())
print(df.tail())


Total amount of image prompt pairs: 800, 200 per class (4 classes total)
Total amount of image prompt pairs: 800
'fine_tuning_metadata.csv' contains the image prompt pairs
                                   image_path                          prompt
0  archive/Training/glioma_tumor/gg (399).jpg  MRI of brain with glioma tumor
1  archive/Training/glioma_tumor/gg (135).jpg  MRI of brain with glioma tumor
2  archive/Training/glioma_tumor/gg (544).jpg  MRI of brain with glioma tumor
3  archive/Training/glioma_tumor/gg (123).jpg  MRI of brain with glioma tumor
4  archive/Training/glioma_tumor/gg (333).jpg  MRI of brain with glioma tumor
                                       image_path  \
795  archive/Training/pituitary_tumor/p (132).jpg   
796  archive/Training/pituitary_tumor/p (122).jpg   
797  archive/Training/pituitary_tumor/p (186).jpg   
798  archive/Training/pituitary_tumor/p (628).jpg   
799  archive/Training/pituitary_tumor/p (618).jpg   

                                prompt  


In [18]:

#!ls archive/Training/pituitary_tumor | wc -l
#!ls archive/Training/glioma_tumor | wc -l
#!ls archive/Training/meningioma_tumor | wc -l
#!ls archive/Training/no_tumor | wc -l

    2877


## Obtaining Pre-trained Diffusion Model

In [None]:
# Load pretrained model
model_id = "runwayml/stable-diffusion-v1-5"
sd_model = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
sd_model = sd_model.to(device)  # Move to GPU

### Generate images before fine tuning just for testing

In [None]:
# Define prompts from your CSV
prompts = [
    "MRI of brain with glioma tumor",
    "MRI of brain with no tumor"
]

# Generate 1-2 test images
os.makedirs("pre_fine_tuning_test_images", exist_ok=True)
for i, prompt in enumerate(prompts):
    image = sd_model(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
    image.save(f"pre_fine_tuning_test_images/test_image_{i}.jpg")
    print(f"Generated: test_image_{i}.jpg with prompt: '{prompt}'")

## Fine tuning model with image, prompt pairs

In [None]:
#Download train_text_to_image.py from diffusers examples
#!wget -O train_text_to_image.py https://raw.githubusercontent.com/huggingface/diffusers/main/examples/text_to_image/train_text_to_image.py
# USE TESTING_train_text_to_image.py, it is the same file with a couple of tweaks for our use case

In [None]:
#!rm -rf ~/.cache/huggingface/datasets # clear datasets cache...

#from datasets import load_dataset
#dataset = load_dataset("csv", data_files={"train": "fine_tuning_metadata.csv"})
#print(dataset["train"][0])  # Should show first row

In [None]:
####################################################################################################################################################
mixed_precision = "fp16" if str(device) == "cuda" else "no"  # FP16 for CUDA, no for MPS
print(device)
print(mixed_precision)
####################################################################################################################################################

In [None]:
""" VERY LONG COMMAND TO FINE TUNE MODEL
!python TESTING_train_text_to_image.py \
    --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
    --train_data_dir="." \
    --dataset_name="csv" \
    --dataset_config_name="fine_tuning_metadata.csv" \
    --image_column="image_path" \
    --caption_column="prompt" \
    --resolution=256 \
    --train_batch_size=1 \
    --num_train_epochs=1 \
    --learning_rate=1e-6 \
    --max_train_steps=500 \
    --output_dir="fine_tuned_model" \
    --checkpointing_steps=250 \
    --mixed_precision={mixed_precision}
"""

## Notes

In [None]:
# NOTES
#Model: runwayml/stable-diffusion-v1-5 is a widely used pretrained Stable Diffusion model.

