<a href="https://colab.research.google.com/github/ritwikraha/GenerativeFill-with-Keras-and-Diffusers/blob/main/imagediting_through_text.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Image Editing Through Text

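This notebook combines KerasCV (Segment Anything), Grounding DINO, and Hugging Face Transformers and Diffusers into a pipeline that edits an image based on *just* the text prompt supplied: a language model extracts the object to replace and its replacement, the object is located and segmented, and the masked region is inpainted.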



## Installations and Imports

In [None]:
# Upgrade KerasCV and TensorFlow, then Keras itself in a separate step
# (the notebook relies on `from keras import ops`).
# NOTE(review): none of these versions are pinned, so results may drift over time.
!pip install --upgrade -qq keras-cv tensorflow
!pip install --upgrade -qq keras

# Hugging Face libraries used for the chat model and the inpainting pipeline.
!pip install --upgrade -qq diffusers accelerate transformers

# Grounding DINO is installed directly from its GitHub repository.
!pip install --upgrade -qq git+https://github.com/IDEA-Research/GroundingDINO.git

In [None]:
# Download the Grounding DINO Swin-T OGC checkpoint (release v0.1.0-alpha) and its
# config file (tag v0.1.0-alpha2) into the working directory.
# NOTE(review): checkpoint and config come from different tags — confirm they match.
!wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
!wget -q https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/v0.1.0-alpha2/groundingdino/config/GroundingDINO_SwinT_OGC.py

In [None]:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torch
import keras
from keras import ops
import keras_cv


from diffusers import AutoPipelineForInpainting
from groundingdino.util.inference import Model as GroundingDINO
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
# Grounding DINO assets, given as filenames relative to the working directory.
GDINO_WEIGHTS_PATH = "groundingdino_swint_ogc.pth"
GDINO_CONFIG_PATH = "GroundingDINO_SwinT_OGC.py"

# Preset name handed to keras_cv's SegmentAnythingModel.from_preset.
SAM_MODEL_NAME = "sam_huge_sa1b"

# Square working resolution; the input image and the SAM masks are resized to it.
IMAGE_SIZE = (1024,) * 2

## Get the Image

In [None]:
# Download the sample photo through Keras' file cache.
image_url = "https://storage.googleapis.com/keras-cv/test-images/mountain-dog.jpeg"
filepath = keras.utils.get_file(origin=image_url)

# Resize as a batch of one, then drop the batch axis and convert back to NumPy.
image_batch = np.array(keras.utils.load_img(filepath))[None, ...]
image = ops.convert_to_numpy(ops.image.resize(image_batch, IMAGE_SIZE)[0])

# Preview the resized input; pixel values are scaled to [0, 1] for imshow.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image / 255.0)
ax.axis("on")
plt.show()

## Get the Text  

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the 7B chat model in half precision on GPU: the default float32 weights
# need about 4 bytes per parameter, which does not fit on typical Colab GPUs.
# float32 is kept on CPU, where float16 kernels are poorly supported.
llm_dtype = torch.float16 if device == "cuda" else torch.float32

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained(
    "NousResearch/Llama-2-7b-chat-hf",
    torch_dtype=llm_dtype,
).to(device)

In [None]:
# @title What do you want to do? { run: "auto", form-width: "100px" }
# Free-form edit instruction (e.g. "Change the dog with cat"); it is sent to the chat
# model as the final user message, so fill it in before running the cells below.
input_prompt = "" # @param {type:"string"}


In [None]:
# Fail early on an empty form field instead of asking the model to complete nothing.
if not input_prompt.strip():
    raise ValueError(
        "input_prompt is empty; describe the edit, e.g. 'Change the dog with cat'."
    )

# Few-shot chat: the system message states the task and two worked examples show
# the expected "<object>, <replacement>" answer format before the real request.
messages = [
    {"role": "system", "content": "Find the objects that are swapped"},
    {"role": "user", "content": "Swap mountain and lion"},  # example 1
    {"role": "assistant", "content": "mountain, lion"},  # example 1
    {"role": "user", "content": "Change the dog with cat"},  # example 2
    {"role": "assistant", "content": "dog, cat"},  # example 2
    {"role": "user", "content": input_prompt},
]

input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt")

# Greedy decoding keeps the extraction reproducible across runs, and the expected
# answer is only a couple of words, so a small token budget is sufficient.
generated_ids = model.generate(
    input_ids.to(device),
    max_new_tokens=64,
    do_sample=False,
)

# Decoded text still contains the prompt and special tokens; it is parsed downstream.
outputs = tokenizer.batch_decode(generated_ids)

In [None]:
# Keep only the model's final answer: the text after the last "[/INST]" marker and
# before the first end-of-sequence token.
reply = outputs[0].split("[/INST]")[-1].split("</s>")[0]
data = reply.split(",")

# The pipeline needs exactly two usable fields; report what the model actually
# said instead of failing with a bare IndexError.
if len(data) < 2 or not data[0].strip() or not data[1].strip():
    raise ValueError(
        f"Expected '<object>, <replacement>' from the language model, got: {reply!r}"
    )

target_object = data[0].strip()  # Remove leading/trailing spaces
replacement = data[1].strip()

print(f"object: {target_object}")
print(f"replacement: {replacement}")

## Utilities

In [None]:
def show_mask(mask, ax):
    """Draw ``mask`` on ``ax`` as a translucent blue RGBA overlay.

    Args:
        mask: Array whose last two axes are (height, width); it must hold exactly
            height * width elements. Each value scales the overlay colour, so a
            boolean mask leaves the background fully transparent.
        ax: Object exposing ``imshow`` (e.g. a Matplotlib axes).
    """
    # RGBA in [0, 1]: a bright blue at 60% opacity.
    rgba = np.array([30 / 255, 144 / 255, 1.0, 0.6])
    height, width = mask.shape[-2], mask.shape[-1]
    # Broadcast (height, width, 1) against (4,) to get an (height, width, 4) image.
    overlay = mask.reshape(height, width)[..., None] * rgba
    ax.imshow(overlay)


def show_box(box, ax):
    """Outline a bounding box on ``ax`` in green.

    Args:
        box: Array that is flattened; its first four values are read as
            (x0, y0, x1, y1) corner coordinates.
        ax: Matplotlib axes that receives the rectangle patch.
    """
    coords = box.reshape(-1)
    origin = (coords[0], coords[1])
    width = coords[2] - coords[0]
    height = coords[3] - coords[1]
    # Fully transparent face so only the 2 pt green outline is visible.
    outline = plt.Rectangle(
        origin, width, height, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2
    )
    ax.add_patch(outline)

## Get the Model

In [None]:
# Segment Anything (KerasCV preset) turns box prompts into masks.
sam_model = keras_cv.models.SegmentAnythingModel.from_preset(SAM_MODEL_NAME)

# Grounding DINO turns a text phrase into bounding boxes. Pass the device chosen
# earlier explicitly: the wrapper otherwise defaults to CUDA, which fails on a
# CPU-only runtime.
grounding_dino = GroundingDINO(GDINO_CONFIG_PATH, GDINO_WEIGHTS_PATH, device=device)

# Segmentation based on Text Input

In [None]:
object_name = target_object

# Grounding DINO: locate the object named in the prompt.
# The wrapper's preprocessing assumes an OpenCV-style BGR image and converts it
# to RGB itself, so flip the channel order of our RGB array before passing it in.
image_bgr = np.ascontiguousarray(image.astype(np.uint8)[..., ::-1])
detections = grounding_dino.predict_with_caption(
    image_bgr,
    object_name,
)[0]
boxes = np.array(detections.xyxy)

# Without at least one box there is nothing to prompt SAM with; stop here with a
# clear message instead of failing inside the SAM call on an empty batch.
if boxes.shape[0] == 0:
    raise ValueError(
        f"Grounding DINO found no boxes for {object_name!r}; try a different prompt."
    )

# SAM: one copy of the image per detected box, each box given as two (x, y) corners.
outputs = sam_model.predict(
    {
        "images": np.repeat(image[np.newaxis, ...], boxes.shape[0], axis=0),
        "boxes": boxes.reshape(-1, 1, 2, 2),
    },
    batch_size=1,
)

## Show the Segmentation

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image / 255.0)

# SAM was run once per detected box, so masks and boxes line up one-to-one.
# Draw each mask with its own box (passing the whole `boxes` array would only
# ever outline the first detection).
for mask, box in zip(outputs["masks"], boxes):
    # Take the first predicted mask, upsample it to the image size and threshold
    # the values at zero to obtain a boolean mask.
    mask = ops.image.resize(mask[0][..., None], IMAGE_SIZE)[..., 0]
    mask = ops.convert_to_numpy(mask) > 0.0
    show_mask(mask, ax)
    show_box(box, ax)

ax.axis("off")
plt.show()

# Image Editing using Inpainting

In [None]:
# Half-precision weights: request the fp16 variant and load it as float16.
# NOTE(review): confirm this checkpoint id is still available on the Hugging Face Hub.
inpainting_load_options = dict(torch_dtype=torch.float16, variant="fp16")
pipeline = AutoPipelineForInpainting.from_pretrained(
    "runwayml/stable-diffusion-inpainting", **inpainting_load_options
)

# Let accelerate move sub-models between CPU and GPU as they are needed.
pipeline.enable_model_cpu_offload()

## Creating the Prompt based on Text Input

In [None]:
prompt = f"A {replacement} highly detailed, 8K"

# Build the inpainting mask directly from the SAM predictions instead of relying
# on a loop variable left over from the plotting cell. Every detected instance is
# merged into one mask so all occurrences of the object are replaced.
inpaint_mask = np.zeros(IMAGE_SIZE, dtype=bool)
for sam_masks in outputs["masks"]:
    resized = ops.image.resize(sam_masks[0][..., None], IMAGE_SIZE)[..., 0]
    inpaint_mask |= ops.convert_to_numpy(resized) > 0.0

# White (255) pixels are repainted from the prompt; black pixels are kept.
output = pipeline(
    prompt=prompt,
    image=Image.fromarray(image.astype(np.uint8)),
    mask_image=Image.fromarray(inpaint_mask.astype(np.uint8) * 255),
    strength=0.6,
).images[0]

## Ta-Daa!

In [None]:
# Display the edited image without axis ticks or frame.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(output)
ax.axis("off")
plt.show()