<a href="https://colab.research.google.com/github/ritwikraha/Open-Generative-Fill/blob/main/notebooks/open_generative_fill_baseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup and Installation

In [None]:
!pip install -qq git+https://github.com/ritwikraha/Open-Generative-Fill

In [None]:
import torch

from open_generative_fill import config
from open_generative_fill.lm_models import run_lm_model
from open_generative_fill.load_data import load_image
from open_generative_fill.vision_models import (
    run_caption_model,
    run_inpainting_pipeline,
    run_segmentaiton_pipeline,
)

In [None]:
# @title Enter values for generation
# User-tunable inputs. The `# @param` annotations are read by Colab to render a
# form for these values, so keep that comment syntax intact when editing.
# URL of the source image to edit.
image_url = "https://i.imgur.com/4ujXoav.jpeg" # @param {type:"string"}
# Natural-language instruction describing the edit to apply to the image.
edit_prompt = "change the bottle to a firecracker" # @param {type:"string"}
# Seed for the torch generator that is passed to the inpainting pipeline.
seed_value = 178334 # @param {type:"slider", min:0, max:999999, step:1}

## Device, Seed, and Image Loading

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed torch's global RNG too, so any sampling done inside the captioning /
# language-model helpers is repeatable across runs.
# NOTE(review): assumes those helpers draw from torch's default generator — confirm.
torch.manual_seed(seed_value)
# Dedicated (CPU) generator handed to the inpainting pipeline below.
GENERATOR = torch.Generator().manual_seed(seed_value)

# Load the image from `image_url`, passing config.IMAGE_SIZE as the target size.
image = load_image(image_url=image_url, image_size=config.IMAGE_SIZE)

## Image Captioning Model

In [None]:
# Describe the input image in text. This caption is later given to the
# language model together with `edit_prompt`.
caption = run_caption_model(
    image=image,
    device=device,
    model_id=config.CAPTION_MODEL_ID,
)
print(caption)

## Language Model

In [None]:
# From the caption and the user's edit request, the language model returns:
#   - `to_replace`: what to replace (fed to the segmentation step below)
#   - `replaced_caption`: the rewritten caption (used as the inpainting prompt)
to_replace, replaced_caption = run_lm_model(
    caption=caption,
    edit_prompt=edit_prompt,
    model_id=config.LANGUAGE_MODEL_ID,
    device=device,
)

# One call, one value per line.
print(to_replace, replaced_caption, sep="\n")

## Segmentation Model

In [None]:
# Build a mask for `to_replace` using the configured detection and segmentation
# models; the mask is passed to the inpainting step as the region to edit.
# (sic) `run_segmentaiton_pipeline` is the spelling exported by
# open_generative_fill.vision_models, so it is used as-is here.
segmentation_mask = run_segmentaiton_pipeline(
    image=image,
    to_replace=to_replace,
    detection_model_id=config.DETECTION_MODEL_ID,
    segmentation_model_id=config.SEGMENTATION_MODEL_ID,
    device=device,
)

# Show the mask as this cell's output.
segmentation_mask

## Inpainting Model

In [None]:
# Repaint the masked region using `replaced_caption` as the prompt. The seeded
# GENERATOR from the setup cell is passed through, intended to make the
# result repeatable for a given `seed_value`.
output = run_inpainting_pipeline(
    image=image,
    mask=segmentation_mask,
    replaced_caption=replaced_caption,
    generator=GENERATOR,
    image_size=config.IMAGE_SIZE,
    inpainting_model_id=config.INPAINTING_MODEL_ID,
    device=device,
)

## Final Output

In [None]:
# Render the inpainted result as this cell's output (last-expression display).
output