In this notebook, I will use Stable Diffusion inpainting and the Segment Anything Model (SAM) to [change clothes in a photo](https://pjoshi15.com/change-clothes-stable-diffusion-sam/).

I will be using a stock image for this experiment. You will be able to access that image once you run this notebook.

Let's get started!

In [None]:
!pip install diffusers

In [None]:
import torch
from torchvision import transforms
from transformers import SamModel, SamProcessor
from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image, make_image_grid
import matplotlib.pyplot as plt

## Load input image

In [None]:
# load image
img = load_image("/kaggle/input/stock-images/white_tshirt.jpg")

# display image
img

## Import SlimSAM for creating masks

Now I will import the [SlimSAM model](https://github.com/czg1225/SlimSAM), a compressed version of SAM. This model will be used to segment the object of our choice.

In [None]:
model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-50").to("cuda")
processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-50")

## Create mask for T-shirt

In [6]:
input_points = [[[320, 600]]] # coordinates of a point on the object of interest

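# preprocess the image and the prompt point, then run SlimSAM without tracking gradients
inputs = processor(img, input_points=input_points, return_tensors="pt").to("cuda")
with torch.no_grad():
    outputs = model(**inputs)

# extract mask tensors, resized back to the original image size
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)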

In [None]:
# count number of masks
len(masks[0][0])

The model returns several candidate masks for the selected point. Plot them one by one and pick the one that best covers the object you want to select.

In [None]:
# best mask tensor
plt.imshow(masks[0][0][2])
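
Besides plotting each candidate, a few numbers can help narrow the choice. The sketch below is not part of the original workflow: `summarize_masks` is a hypothetical helper that reports, for every candidate mask, the fraction of the image it covers and whether it contains the prompt point. It uses synthetic masks so it runs on its own.

```python
import numpy as np

def summarize_masks(candidate_masks, point_xy):
    """Return one dict per candidate mask: index, covered fraction, point hit."""
    # Expected shape: (num_masks, height, width)
    stack = np.asarray(candidate_masks).astype(bool)
    x, y = point_xy
    rows = []
    for idx, mask in enumerate(stack):
        rows.append({
            "index": idx,
            "area_fraction": float(mask.mean()),
            # arrays are indexed [row, column], i.e. [y, x]
            "contains_point": bool(mask[y, x]),
        })
    return rows

# Toy demo with three synthetic 4x4 masks (stand-ins for real SAM output)
demo_masks = np.zeros((3, 4, 4), dtype=bool)
demo_masks[0, 1:3, 1:3] = True  # small square
demo_masks[1, :, :] = True      # whole image
# demo_masks[2] stays empty

summary = summarize_masks(demo_masks, point_xy=(1, 2))
for row in summary:
    print(row)
```

In this notebook the call would probably look like `summarize_masks(masks[0][0], point_xy=(320, 600))`, assuming the mask tensor lives on the CPU and converts cleanly with `np.asarray`. A mask that covers almost the whole image, or one that misses the prompt point, is usually not the one you want, but the final choice is still best made by looking at the plots.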

## Create mask for pants

In [9]:
input_points_2 = [[[200, 850]]] # coordinates of a point on the object of interest

inputs_2 = processor(img, input_points=input_points_2, return_tensors="pt").to("cuda")
outputs_2 = model(**inputs_2)
# extract mask tensors, using the sizes recorded for this second input
masks_2 = processor.image_processor.post_process_masks(outputs_2.pred_masks.cpu(), inputs_2["original_sizes"].cpu(), inputs_2["reshaped_input_sizes"].cpu())

In [None]:
# count number of masks
len(masks_2[0][0])

In [None]:
plt.imshow(masks_2[0][0][1])

## Convert PyTorch tensors to PIL image format

In [12]:
# Create a ToPILImage transform
to_pil = transforms.ToPILImage()

# Convert boolean tensors to binary tensors
binary_matrix_1 = masks[0][0][2].to(dtype=torch.uint8)
binary_matrix_2 = masks_2[0][0][1].to(dtype=torch.uint8)

# apply the transform to the tensors (tensor to PIL)
mask_1 = to_pil(binary_matrix_1*255)
mask_2 = to_pil(binary_matrix_2*255)

In [None]:
# display original image with masks
make_image_grid([img, mask_1, mask_2], cols = 3, rows = 1)
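
A black-and-white mask shown next to the photo can be hard to judge at the edges. As an optional check, the mask can be tinted onto the photo itself. The sketch below is an illustration only: `overlay_mask` is a hypothetical helper, and it works on NumPy arrays so it runs without the notebook's image.

```python
import numpy as np

def overlay_mask(image_rgb, mask, color=(255, 0, 0), alpha=0.5):
    """Blend `color` into the pixels of `image_rgb` where `mask` is True."""
    image = np.asarray(image_rgb).astype(np.float32)  # (height, width, 3)
    selected = np.asarray(mask).astype(bool)          # (height, width)
    tint = np.array(color, dtype=np.float32)
    blended = image.copy()
    # Mix the tint into the masked pixels only; other pixels stay unchanged
    blended[selected] = (1 - alpha) * image[selected] + alpha * tint
    return blended.round().astype(np.uint8)

# Toy demo: a black 2x2 image with only the top-left pixel masked
demo_image = np.zeros((2, 2, 3), dtype=np.uint8)
demo_mask = np.array([[True, False], [False, False]])
demo_overlay = overlay_mask(demo_image, demo_mask, color=(200, 0, 0))
print(demo_overlay[0, 0], demo_overlay[1, 1])
```

With the notebook's variables, something like `overlay_mask(img, masks[0][0][2])` should return an array that `PIL.Image.fromarray` can turn back into an image for display, assuming the image and the mask have the same height and width.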

## Import Stable Diffusion inpainting model

In [None]:
# Create inpainting pipeline
pipeline = AutoPipelineForInpainting.from_pretrained(
    "redstonehero/ReV_Animated_Inpainting", 
    torch_dtype=torch.float16
)
 
pipeline.enable_model_cpu_offload()

## Edit T-shirt in image

In [None]:
prompt = "flower-print, t-shirt"

# inpainting pipeline
image1 = pipeline(prompt=prompt,
                 width=512,
                 height=768,
                 num_inference_steps=28,
                 image=img, 
                 mask_image=mask_1,
                 guidance_scale=3,
                 strength=1.0).images[0]

In [None]:
# compare input and output
make_image_grid([img.resize([512,768]), image1], rows = 1, cols = 2)

## Edit pants in image

In [None]:
prompt = "tactical pants"
 
image = pipeline(prompt=prompt,
                 width=512,
                 height=768,
                 num_inference_steps=30,
                 image=img, 
                 mask_image=mask_2,
                 guidance_scale=2.5,
                 strength=1.0).images[0]
 
make_image_grid([img.resize([512,768]), image], rows = 1, cols = 2)
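
Both edits above start from the original photo, so each result changes only one garment. To get a single image with both changes, one option is to feed the output of the first edit into the second. The sketch below shows that idea: `apply_edits` is a hypothetical helper, and `fake_pipe` is a stand-in that only mimics the pipeline's `.images` return value so the example runs without a GPU.

```python
from types import SimpleNamespace

def apply_edits(pipe, image, edits, **shared_kwargs):
    """Run inpainting edits in sequence, feeding each result into the next call.

    `edits` is a list of (prompt, mask_image) pairs. `pipe` is any callable
    that returns an object with an `.images` list.
    """
    current = image
    for prompt, mask in edits:
        result = pipe(prompt=prompt, image=current, mask_image=mask, **shared_kwargs)
        current = result.images[0]
    return current

# Stand-in for the real pipeline: it just records which prompts were applied
def fake_pipe(prompt, image, mask_image, **kwargs):
    return SimpleNamespace(images=[f"{image} -> {prompt}"])

chained = apply_edits(
    fake_pipe,
    "photo",
    [("flower-print, t-shirt", None), ("tactical pants", None)],
    width=512,
    height=768,
)
print(chained)
```

With the real pipeline, the call might look like `apply_edits(pipeline, img, [("flower-print, t-shirt", mask_1), ("tactical pants", mask_2)], width=512, height=768, strength=1.0)`. This uses one set of shared settings for both edits, whereas the cells above use different `num_inference_steps` and `guidance_scale` values for each garment, so treat it as a starting point rather than an exact reproduction.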

We can also use an [IP-Adapter model](https://pjoshi15.com/generate-images-ipadapters-diffusers/) here: given an image of an existing dress, it can guide the model to generate a similar one. Combining different AI models opens up a wide range of possibilities, just as combining SlimSAM and Stable Diffusion let us generate new clothes in this image.