In [8]:
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamPredictor
import os
import shutil
from ultralytics import YOLO  # YOLOv8 for person detection
import random
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from datasets import load_dataset
import tensorflow as tf

In [10]:
model_type = "vit_b"
sam_checkpoint = "sam_vit_b.pth"  # Ensure this file is downloaded

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to("cuda")
predictor = SamPredictor(sam)

yolo_model = YOLO("yolov8n.pt")  # Small model for fast inference

ds = load_dataset("datasets/Harvard-Edge___wake-vision", cache_dir=".")

In [54]:
def segment_person_samv2(image):
    """
    Segment persons in an image using SAM.
    """
    # Run YOLOv8 to detect persons
    yolo_results = yolo_model(image)[0]
    
    person_boxes = []
    for result in yolo_results.boxes.data:
        x1, y1, x2, y2, conf, cls = result.cpu().numpy()
        if int(cls) == 0:  # Class 0 = Person
            
            # # check confidence
            if conf < 0.9:
                print(f"Person detected with low confidence {conf}.")
                return None, None
            
            person_boxes.append([int(x1), int(y1), int(x2), int(y2)])

    if not person_boxes:
        print("No persons detected.")
        return np.zeros_like(image[:, :, 0]), image  # Empty mask

    predictor.set_image(image)

    # Generate masks for all detected persons
    masks = []
    for box in person_boxes:
        mask, _, _ = predictor.predict(box=np.array(box), multimask_output=False)
        masks.append(mask[0])

    # Merge all person masks
    if masks:
        person_mask = np.any(masks, axis=0).astype(np.uint8) * 255
    else:
        person_mask = np.zeros_like(image[:, :, 0])  # Empty mask

    return person_mask, image

In [None]:
""" Segment persons in the dataset and save original image and mask """
save_folder = "persons_with_masks"
if os.path.exists(save_folder):
    shutil.rmtree(save_folder)
os.makedirs(save_folder)

for i in range(len(ds)):
    if ds[i]["person"] == 0:
        print(f"Person not present in image {i}.")
        continue
    
    image = np.asarray(ds[i]["image"])
    
    person_mask, original_image = segment_person_samv2(image)
    
    if person_mask is None:
        continue

    # Save the original image
    cv2.imwrite(f"{save_folder}/original_{i}.png", cv2.cvtColor(original_image, cv2.COLOR_RGB2BGR))
    
    # Save the person mask
    cv2.imwrite(f"{save_folder}/mask_{i}.png", person_mask)
    

In [62]:
""" Get background images """
n_backgrounds = len(os.listdir("persons_with_masks"))
background_folder_path = "backgrounds"

if os.path.exists(background_folder_path):
    shutil.rmtree(background_folder_path)
os.makedirs(background_folder_path)

# get images that are not labeled as persons and save to folder
for i in range(n_backgrounds):
# for i in range(10):
    if ds["train"][i]["person"] == 1:
        continue
    
    image = np.asarray(ds["train"][i]["image"])
    cv2.imwrite(f"{background_folder_path}/background_{i}.png", cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

In [63]:
def blend_person_with_background(person, mask, background):
    """
    Blend a person with a background using Poisson seamless cloning.
    """
    # Resize person to fit background
    h_bg, w_bg, _ = background.shape
    h_p, w_p, _ = person.shape

    scale_factor = random.uniform(0.6, 0.8)  # Random scaling
    # scale_factor = 0.9
    new_w, new_h = int(w_p * scale_factor), int(h_p * scale_factor)

    person_resized = cv2.resize(person, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    mask_resized = cv2.resize(mask, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

    # Random position on background
    x_offset = random.randint(0, w_bg - new_w)
    y_offset = random.randint(0, h_bg - new_h)
    center = (x_offset + new_w // 2, y_offset + new_h // 2)

    # Apply Poisson seamless cloning
    blended = cv2.seamlessClone(person_resized, background, mask_resized, center, cv2.NORMAL_CLONE)

    return blended

In [73]:
""" Blend persons with backgrounds and save blended images + backgrounds """
persons_with_masks_files = os.listdir("persons_with_masks")
person_files = [f for f in persons_with_masks_files if "original" in f]

# get background file names
background_files = os.listdir("backgrounds")

if os.path.exists("blended_images"):
    shutil.rmtree("blended_images")
os.makedirs("blended_images")

if os.path.exists("background_images"):
    shutil.rmtree("background_images")
os.makedirs("background_images")

# blend persons with backgrounds
for i in range(len(person_files)):
# for i in range(10):
    try:
        person = cv2.imread(f"persons_with_masks/{person_files[i]}")
        mask = cv2.imread(f"persons_with_masks/{person_files[i].replace('original', 'mask')}", cv2.IMREAD_GRAYSCALE)
        background = cv2.imread(f"backgrounds/{random.choice(background_files)}")
    except Exception as e:
        print(f"Error: {e}")
        break
    
    try:
        blended = blend_person_with_background(person, mask, background)
    except Exception as e:
        print(f"Error: {e}")
        continue
    # cv2.imwrite(f"blended_images/blended_{i}.png", cv2.cvtColor(blended, cv2.COLOR_BGR2RGB))
    
    # check again if yolo detects person in blended image
    yolo_results = yolo_model(blended)[0]
    person_detected = False
    
    for result in yolo_results.boxes.data:
        x1, y1, x2, y2, conf, cls = result.cpu().numpy()
        if int(cls) == 0:
            if conf > 0.05:
                person_detected = True
                break
        
    if person_detected:
        cv2.imwrite(f"blended_images/blended_{i}.png", blended)
    else:
        cv2.imwrite(f"background_images/background_{i}.png", blended)
    
    # show as subplots
    # fig, axs = plt.subplots(1, 4, figsize=(15, 5))
    # axs[0].imshow(cv2.cvtColor(person, cv2.COLOR_BGR2RGB))
    # axs[0].set_title("Person")
    # axs[0].axis("off")
    
    # axs[1].imshow(mask, cmap="gray")
    # axs[1].set_title("Mask")
    # axs[1].axis("off")
    
    # axs[2].imshow(cv2.cvtColor(background, cv2.COLOR_BGR2RGB))
    # axs[2].set_title("Background")
    # axs[2].axis("off")
    
    # axs[3].imshow(cv2.cvtColor(blended, cv2.COLOR_BGR2RGB))
    # if person_detected:
    #     axs[3].set_title("Person detected")
    # else:
    #     axs[3].set_title("Person not detected")
    # axs[3].axis("off")


0: 448x640 1 person, 1 bench, 14.0ms
Speed: 24.2ms preprocess, 14.0ms inference, 2.2ms postprocess per image at shape (1, 3, 448, 640)
Error: empty range for randrange() (0, -9, -9)
Error: empty range for randrange() (0, -15, -15)

0: 448x640 1 potted plant, 7.1ms
Speed: 2.0ms preprocess, 7.1ms inference, 1.0ms postprocess per image at shape (1, 3, 448, 640)

0: 640x448 (no detections), 6.1ms
Speed: 1.3ms preprocess, 6.1ms inference, 0.4ms postprocess per image at shape (1, 3, 640, 448)

0: 480x640 (no detections), 6.1ms
Speed: 1.4ms preprocess, 6.1ms inference, 0.4ms postprocess per image at shape (1, 3, 480, 640)

0: 416x640 1 person, 6.3ms
Speed: 1.3ms preprocess, 6.3ms inference, 0.9ms postprocess per image at shape (1, 3, 416, 640)

0: 480x640 4 bicycles, 6.1ms
Speed: 1.4ms preprocess, 6.1ms inference, 0.9ms postprocess per image at shape (1, 3, 480, 640)

0: 416x640 1 car, 2 bottles, 3 chairs, 1 dining table, 11.2ms
Speed: 1.3ms preprocess, 11.2ms inference, 2.1ms postprocess pe

In [9]:
# Paths
persons_folder = "blended_images"
no_persons_folder = "background_images"

# Function to create a TFRecord example
def create_tfrecord_example(image_path, label):
    img = tf.io.read_file(image_path)  # Read image
    feature = {
        "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.numpy()])),
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))

input_shape = (144, 144, 3)

data_preprocessing = tf.keras.Sequential([
    # Resize images to desired input shape
    tf.keras.layers.Resizing(input_shape[0], input_shape[1])
])

TFRecord saved at dataset_test.tfrecord


In [107]:
""" Create TFRecord for images not containing persons """
tfrecord_path = "dataset_no_person.tfrecord"

# Write images into TFRecord file
with tf.io.TFRecordWriter(tfrecord_path) as writer:
    for folder, label in [(no_persons_folder, 0)]:
        for filename in os.listdir(folder):
            image_path = os.path.join(folder, filename)
            if os.path.isfile(image_path):
                example = create_tfrecord_example(tf.convert_to_tensor(image_path), label)
                writer.write(example.SerializeToString())
            break

print(f"TFRecord saved at {tfrecord_path}")

In [None]:
""" Create TFRecord for images containing persons """
tfrecord_path = "dataset_person.tfrecord"

# Write images into TFRecord file
with tf.io.TFRecordWriter(tfrecord_path) as writer:
    for folder, label in [(persons_folder, 1)]:
        for filename in os.listdir(folder):
            image_path = os.path.join(folder, filename)
            if os.path.isfile(image_path):
                example = create_tfrecord_example(tf.convert_to_tensor(image_path), label)
                writer.write(example.SerializeToString())
            break

print(f"TFRecord saved at {tfrecord_path}")

In [None]:
""" Function to parse TFRecord example"""
def parse_tfrecord(example_proto):
    feature_description = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.int64),
    }
    parsed_example = tf.io.parse_single_example(example_proto, feature_description)
    
    # Decode image
    img = tf.image.decode_png(parsed_example["image"], channels=3)
    img = tf.image.convert_image_dtype(img, tf.float32)
    return img, parsed_example["label"]

# Create TensorFlow dataset from TFRecord
raw_dataset = tf.data.TFRecordDataset(tfrecord_path)
train_ds = raw_dataset.map(parse_tfrecord).map(lambda img, label: (data_preprocessing(img), label)).shuffle(1000).batch(32)

# Check dataset
for img, label in train_ds.take(1):
    plt.imshow(img[0])
    plt.show()
    print(f"Image shape: {img.shape}, Label: {label}")