In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install --upgrade pip setuptools wheel

In [None]:
%cd '/content/drive/MyDrive/guided_grounded_instructpix2pix/'

In [None]:
!pip install groundingdino-py

In [None]:
!python -m spacy download en_core_web_sm

In [None]:
!pip install diffusers transformers accelerate scipy safetensors

In [None]:
!pip install segment-anything

In [None]:
!pip install torchmetrics

In [None]:
!pip install git+https://github.com/openai/CLIP.git

In [None]:
!pip install transformers torch peft tqdm numpy scikit-learn

In [None]:
import math
import torch
import os
from PIL import Image
from pathlib import Path

from tqdm.notebook import tqdm
from PIL import Image, ImageOps
import matplotlib.pyplot as plt

from diffusers import DDIMScheduler, DDIMInverseScheduler
from pipeline_stable_diffusion_grounded_instruct_pix2pix import StableDiffusionInstructPix2PixPipeline

from external_mask_extractor_improved import ExternalMaskExtractor
from transformers import logging

import pandas as pd
import re

logging.set_verbosity_error()

In [10]:
def load_pil_image_no_pad(path, max_res=512):
    """Load an image as RGB and rescale it so its longest side is about ``max_res``.

    The aspect ratio is kept (no padding, no cropping) and each side is then
    floored to a multiple of 8, with a minimum of 8 pixels per side. The same
    scale factor is applied unconditionally, so images whose longest side is
    smaller than ``max_res`` are scaled *up*.

    Args:
        path: Anything ``PIL.Image.open`` accepts.
        max_res: Target length, in pixels, of the longest side.

    Returns:
        ``(resized, (orig_w, orig_h))``: the resized RGB image and the size of
        the image as stored on disk, so a result can later be resized back.
    """
    # convert() returns an independent in-memory image, so the source file can
    # be closed right away instead of waiting for garbage collection.
    with Image.open(path) as src:
        img = src.convert("RGB")
    orig_w, orig_h = img.size
    longest = max(orig_w, orig_h)

    def _scaled_side(side):
        # Multiply before dividing: a precomputed float ratio can land just
        # below an exact multiple and lose a whole 8-pixel step when floored.
        scaled = int(side * max_res // longest)
        # Clamp so a very thin image never ends up with a 0-pixel side,
        # which resize() rejects.
        return max(8, scaled // 8 * 8)

    new_size = (_scaled_side(orig_w), _scaled_side(orig_h))
    resized = img.resize(new_size, Image.Resampling.LANCZOS)
    return resized, (orig_w, orig_h)

def restore_original_resolution(edited_img, original_size):
    """Return ``edited_img`` resized to ``original_size`` with Lanczos resampling.

    ``original_size`` is passed straight to ``edited_img.resize``; callers in
    this notebook pass the ``(width, height)`` tuple returned by
    ``load_pil_image_no_pad``.
    """
    lanczos = Image.Resampling.LANCZOS
    restored = edited_img.resize(original_size, lanczos)
    return restored

In [None]:
# Device string used when moving the pipeline to the GPU below.
device = 'cuda:0'

# pipeline
# Step count shared by both schedulers below; inference() also reads this
# global and passes it as num_inference_steps.
num_timesteps = 100
# StableDiffusionInstructPix2PixPipeline here is the project-local class imported from
# pipeline_stable_diffusion_grounded_instruct_pix2pix, loaded with the
# "timbrooks/instruct-pix2pix" weights in float16 and without a safety checker.
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained("timbrooks/instruct-pix2pix",
                                                                  torch_dtype=torch.float16,
                                                                  safety_checker=None).to(device)
# Attach an inverse DDIM scheduler built from the forward scheduler's config.
pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config, set_alpha_to_zero=False)

# Configure the forward and inverse schedulers with the same number of steps.
pipeline.scheduler.set_timesteps(num_timesteps)
pipeline.inverse_scheduler.set_timesteps(num_timesteps)

In [None]:
# Release cached GPU memory that PyTorch is holding but not using before the
# mask-extraction models are loaded.
torch.cuda.empty_cache()
# NOTE(review): this is set after the pipeline has already been moved to the GPU
# in an earlier cell; the CUDA allocator may have read its configuration by
# then, in which case this assignment has no effect. Confirm, and if so set it
# before the first CUDA use.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Define model paths here
SAM_PATH = '/content/drive/MyDrive/SAM/sam_vit_h_4b8939.pth'
TARGET_PREDICTOR_PATH = "/content/drive/MyDrive/guided_grounded_instructpix2pix/models/target_finder_bert_lora_model/"
# Two size-predictor checkpoints, passed to the extractor together as a list.
SIZE_PREDICTOR_PATHS = ["/content/drive/MyDrive/guided_grounded_instructpix2pix/models/size_predictor_models/standardSmoothL1_model.pth",
                        "/content/drive/MyDrive/guided_grounded_instructpix2pix/models/size_predictor_models/CombLoss_model.pth"]

# Same value as the `device` assigned in the pipeline cell; repeated here so
# this cell does not depend on that assignment.
device = 'cuda:0'
# Module-level extractor instance; inference() reads it as a global.
extractor = ExternalMaskExtractor(
    device=device,
    target_predictor_model_path=TARGET_PREDICTOR_PATH,
    size_predictor_model_path=SIZE_PREDICTOR_PATHS,
    sam_path=SAM_PATH
)

def inference(pipeline, image_pil, instruction, image_guidance_scale, text_guidance_scale, seed, blending_range, verbose=True):
    """Extract an edit mask for ``instruction`` and run the masked edit on ``image_pil``.

    Relies on three module-level names defined in earlier cells:
    ``extractor`` (mask extractor), ``num_timesteps`` (step count) and
    ``device`` (device for the random generator).

    Args:
        pipeline: Object exposing ``invert(...)`` and ``__call__(...)`` as used below.
        image_pil: Input image, passed unchanged to the extractor and the pipeline.
        instruction: Edit instruction text given to the mask extractor.
        image_guidance_scale: Forwarded to the pipeline as ``image_guidance_scale``.
        text_guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
        seed: Seed for the generator; ``None`` leaves the generator unseeded.
        blending_range: Forwarded to ``pipeline.invert`` as ``inv_range``.
        verbose: Forwarded to ``extractor.get_external_mask``. Defaults to
            ``True``, the value that was previously hard-coded.

    Returns:
        ``(edited_image, final_mask, predicted_size_mask)`` where the masks are
        returned exactly as the extractor produced them.
    """
    # The extractor returns six values; only the mask, the rewritten sentence
    # and the predicted-size mask are used here.
    (final_mask, _first_target, _second_target, _direction,
     final_sentence, predicted_size_mask) = extractor.get_external_mask(image_pil, instruction, verbose=verbose)

    # The return value of invert() is not used by this function; the call itself is kept.
    # NOTE(review): presumably invert() stores its result on the pipeline for the edit call below — confirm.
    pipeline.invert(final_sentence, image_pil, num_inference_steps=num_timesteps, inv_range=blending_range)

    generator = torch.Generator(device)
    if seed is not None:
        generator.manual_seed(seed)

    edited_image = pipeline(final_sentence, src_mask=final_mask, image=image_pil,
                            guidance_scale=text_guidance_scale, image_guidance_scale=image_guidance_scale,
                            num_inference_steps=num_timesteps, generator=generator).images[0]
    return edited_image, final_mask, predicted_size_mask

# **##################################**
## Single-image edit
To edit a single image with one text instruction, run the following cell:

In [None]:
def edit_image(
    pipeline,
    image_path,
    edit_instruction,
    image_guidance_scale=1.5,
    guidance_scale=7.5,
    seed=42,
    start_blending_at_tstep=100,
    end_blending_at_tstep=1,
    verbose=False,
    output_root='/content/drive/MyDrive/guided_grounded_instructpix2pix/output',
    show=True
):
    """Edit one image, save the edit and its masks, and optionally plot them.

    The image is loaded with ``load_pil_image_no_pad``, edited through
    ``inference`` and every output is resized back to the original size before
    saving. Files are named after the input file's stem:

    * ``<output_root>/ggip2p_edited_images/<stem>.jpg`` (edited image)
    * ``<output_root>/ggip2p_edited_masks/<stem>.png`` (final mask)
    * ``<output_root>/ggip2p_edited_masks/<stem>_primary_mask.png`` (predicted-size mask)

    Args:
        pipeline: Editing pipeline, forwarded to ``inference``.
        image_path: Path of the image to edit.
        edit_instruction: Text instruction describing the edit.
        image_guidance_scale: Forwarded to ``inference``.
        guidance_scale: Forwarded to ``inference`` as ``text_guidance_scale``.
        seed: Random seed forwarded to ``inference``.
        start_blending_at_tstep: First element of the blending range.
        end_blending_at_tstep: Second element of the blending range.
        verbose: Accepted but not used by this function.
        output_root: Directory under which the two output folders are created.
        show: When true, display a 1x4 figure with the input, the edit and both masks.

    Returns:
        None.
    """
    blending_range = [start_blending_at_tstep, end_blending_at_tstep]

    image_resized, original_size = load_pil_image_no_pad(image_path)
    edited_image, final_mask, predicted_size_mask = inference(
        pipeline, image_resized, edit_instruction,
        image_guidance_scale=image_guidance_scale,
        text_guidance_scale=guidance_scale,
        seed=seed,
        blending_range=blending_range
    )

    # Output files are named after the input file (without its extension).
    original_filename = Path(image_path).stem

    edited_dir = Path(output_root) / 'ggip2p_edited_images'
    mask_dir = Path(output_root) / 'ggip2p_edited_masks'
    edited_dir.mkdir(parents=True, exist_ok=True)
    mask_dir.mkdir(parents=True, exist_ok=True)

    # Final mask: accept a PIL image, an array, or a tensor. Tensors have no
    # .astype(), so they are moved to a NumPy array first.
    if torch.is_tensor(final_mask):
        final_mask = final_mask.detach().cpu().numpy()
    if not isinstance(final_mask, Image.Image):
        final_mask = Image.fromarray(final_mask.astype('uint8'))

    # Predicted-size mask: may be missing, in which case a flat gray
    # placeholder keeps the saved files and the figure layout uniform.
    if predicted_size_mask is None:
        predicted_size_mask = Image.new('L', image_resized.size, 128)
    elif not isinstance(predicted_size_mask, Image.Image):
        # Tensors are scaled by 255 exactly as before. Plain arrays used to be
        # passed on unconverted and failed at resize/save.
        # NOTE(review): arrays are assumed to use the same value range as the
        # tensor case (scaled by 255) — confirm against the extractor.
        size_mask_array = torch.as_tensor(predicted_size_mask).detach().cpu().numpy()
        predicted_size_mask = Image.fromarray((size_mask_array * 255).astype('uint8'))

    edited_image_path = edited_dir / f'{original_filename}.jpg'
    final_mask_path = mask_dir / f'{original_filename}.png'
    predicted_size_mask_path = mask_dir / f'{original_filename}_primary_mask.png'

    # Bring every output back to the resolution of the file on disk.
    edited_image = restore_original_resolution(edited_image, original_size)
    final_mask = restore_original_resolution(final_mask, original_size)
    predicted_size_mask = restore_original_resolution(predicted_size_mask, original_size)

    edited_image.save(edited_image_path)
    final_mask.save(final_mask_path)
    predicted_size_mask.save(predicted_size_mask_path)

    if show:
        _, axes = plt.subplots(1, 4, figsize=(15, 5))
        # (image, title, extra imshow keyword arguments) for each panel, left to right.
        panels = [
            (image_resized, 'Original Image', {}),
            (edited_image, 'Edited Image', {}),
            (final_mask, 'Final Mask', {}),
            (predicted_size_mask, 'Predicted Size Mask', {'cmap': 'gray'}),
        ]
        for ax, (panel_image, title, imshow_kwargs) in zip(axes, panels):
            ax.imshow(panel_image, **imshow_kwargs)
            ax.set_title(title)
            ax.axis('off')

        plt.tight_layout()
        plt.show()

# Example: one edit using the same guidance, seed and blending values as the
# defaults declared by edit_image.
edit_image(
    pipeline=pipeline,
    image_path='/content/drive/MyDrive/imgs/cairn-2806850_640-pixabay.com.jpg',
    edit_instruction='convert the stone on the left of the bird to autumn leaves',
    image_guidance_scale=1.5,
    guidance_scale=7.5,
    seed=42,
    start_blending_at_tstep=100,
)

# **##################################**
## Full test set
To edit every image in the test set with its paired instruction from the Excel file, run the following cell:

In [None]:
# Passed as the `verbose` argument of every edit_image call in the batch loop.
verbose = False

def edit_image(
    pipeline,
    image_path,
    edit_instruction,
    image_guidance_scale=1.5,
    guidance_scale=7.5,
    seed=42,
    start_blending_at_tstep=100,
    end_blending_at_tstep=1,
    verbose=False,
    save_size_mask=False
):
    """Batch variant of ``edit_image``: edit one image and save the results, without plotting.

    Running this cell replaces the ``edit_image`` defined in the single-image
    cell. Compared with that version, output names also contain the sanitized
    instruction, so several edits of the same image do not overwrite each
    other, and nothing is displayed.

    Files written (all resized back to the original image size):

    * ``.../output/ggip2p_edited_images/<stem>_<instruction>.jpg``
    * ``.../output/ggip2p_edited_masks/<stem>_<instruction>.png``
    * ``.../output/ggip2p_edited_masks/<stem>_<instruction>_primary_mask.png``
      (only when ``save_size_mask`` is true)

    Args:
        pipeline: Editing pipeline, forwarded to ``inference``.
        image_path: Path of the image to edit.
        edit_instruction: Text instruction; also used (sanitized) in file names.
        image_guidance_scale: Forwarded to ``inference``.
        guidance_scale: Forwarded to ``inference`` as ``text_guidance_scale``.
        seed: Random seed forwarded to ``inference``.
        start_blending_at_tstep: First element of the blending range.
        end_blending_at_tstep: Second element of the blending range.
        verbose: Accepted but not used by this function.
        save_size_mask: Also save the predicted-size mask. Off by default,
            matching the previous behaviour where that save was disabled.

    Returns:
        None.
    """
    blending_range = [start_blending_at_tstep, end_blending_at_tstep]

    image_resized, original_size = load_pil_image_no_pad(image_path)
    edited_image, final_mask, predicted_size_mask = inference(
        pipeline, image_resized, edit_instruction,
        image_guidance_scale=image_guidance_scale,
        text_guidance_scale=guidance_scale,
        seed=seed,
        blending_range=blending_range
    )

    # Input file name without its extension.
    original_filename = Path(image_path).stem

    # Collapse every run of non-alphanumeric characters into one underscore so
    # the instruction is safe to embed in a file name.
    safe_instruction = re.sub(r'[^a-zA-Z0-9]+', '_', edit_instruction.strip()).strip('_')
    output_stem = f'{original_filename}_{safe_instruction}'

    edited_dir = Path('/content/drive/MyDrive/guided_grounded_instructpix2pix/output/ggip2p_edited_images')
    mask_dir = Path('/content/drive/MyDrive/guided_grounded_instructpix2pix/output/ggip2p_edited_masks')
    edited_dir.mkdir(parents=True, exist_ok=True)
    mask_dir.mkdir(parents=True, exist_ok=True)

    # Final mask: accept a PIL image, an array, or a tensor. Tensors have no
    # .astype(), so they are moved to a NumPy array first.
    if torch.is_tensor(final_mask):
        final_mask = final_mask.detach().cpu().numpy()
    if not isinstance(final_mask, Image.Image):
        final_mask = Image.fromarray(final_mask.astype('uint8'))

    edited_image = restore_original_resolution(edited_image, original_size)
    final_mask = restore_original_resolution(final_mask, original_size)

    edited_image.save(edited_dir / f'{output_stem}.jpg')
    final_mask.save(mask_dir / f'{output_stem}.png')

    # The predicted-size mask is only converted and resized when it is
    # actually going to be written.
    if save_size_mask:
        if predicted_size_mask is None:
            # Flat gray placeholder when the extractor produced no size mask.
            predicted_size_mask = Image.new('L', image_resized.size, 128)
        elif not isinstance(predicted_size_mask, Image.Image):
            # NOTE(review): arrays are assumed to use the same value range as
            # tensors (scaled by 255) — confirm against the extractor.
            size_mask_array = torch.as_tensor(predicted_size_mask).detach().cpu().numpy()
            predicted_size_mask = Image.fromarray((size_mask_array * 255).astype('uint8'))
        predicted_size_mask = restore_original_resolution(predicted_size_mask, original_size)
        predicted_size_mask.save(mask_dir / f'{output_stem}_primary_mask.png')

# Spreadsheet with one row per edit: the 'image_name' column holds a file name
# inside images_dir and the 'instruct' column holds the edit instruction.
excel_path = "/content/drive/MyDrive/All_RESULTS/All_img_instruct_info.xlsx"
df = pd.read_excel(excel_path)
images_dir = "/content/drive/MyDrive/imgs/"

# Blank cells are read as NaN, which would fail part-way through the run when
# the instruction is used as text. Drop such rows up front and report how many
# were skipped.
edit_jobs = df.dropna(subset=['image_name', 'instruct'])
skipped_rows = len(df) - len(edit_jobs)
if skipped_rows:
    print(f"Skipping {skipped_rows} row(s) with a missing image name or instruction.")

for index, row in tqdm(edit_jobs.iterrows(), total=len(edit_jobs), desc="Editing images"):
    image_name = str(row['image_name'])
    edit_instruction = str(row['instruct'])
    image_path = str(Path(images_dir) / image_name)

    edit_image(
        pipeline=pipeline,
        image_path=image_path,
        edit_instruction=edit_instruction,
        image_guidance_scale=1.5,
        guidance_scale=7.5,
        seed=42,
        start_blending_at_tstep=100,
        verbose=verbose
    )