In [2]:
import os
import numpy as np
import cv2
from glob import glob
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, CSVLogger, EarlyStopping
import matplotlib.pyplot as plt
from tqdm import tqdm
import urllib
import IPython

In [3]:
def conv_block(inputs, num_filters):
    x = Conv2D(num_filters, 3, padding='same')(inputs)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    x = Conv2D(num_filters, 3, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    return x

In [4]:
def encoder_block(inputs, num_filters):
    x = conv_block(inputs, num_filters)
    p = MaxPool2D((2, 2))(x)
    return x, p

In [5]:
def decoder_block(inputs, skip_features, num_filters):
    x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding='same')(inputs)
    x = Concatenate()([x, skip_features])
    x = conv_block(x, num_filters)
    return x

In [6]:
def build_unet(input_shape):
    inputs = Input(input_shape)

    """ Encoder """
    s1, p1 = encoder_block(inputs, 64)
    s2, p2 = encoder_block(p1, 128)
    s3, p3 = encoder_block(p2, 256)
    s4, p4 = encoder_block(p3, 512)
    b1 = conv_block(p4, 1024)

    """ Decoder """
    d1 = decoder_block(b1, s4, 512)
    d2 = decoder_block(d1, s3, 256)
    d3 = decoder_block(d2, s2, 128)
    d4 = decoder_block(d3, s1, 64)

    """ Output """
    outputs = Conv2D(1, (1, 1), padding="same", activation="sigmoid")(d4)
    return Model(inputs, outputs, name="U-Net")

In [7]:
def load_data(dataset_path):
    images = sorted(glob(os.path.join(dataset_path, "images/*")))
    masks = sorted(glob(os.path.join(dataset_path, "masks/*")))
    train_x, test_x = train_test_split(images, test_size=0.2, random_state=42)
    train_y, test_y = train_test_split(masks, test_size=0.2, random_state=42)
    return (train_x, train_y), (test_x, test_y)

In [8]:
def read_image(path):
    x = cv2.imread(path, cv2.IMREAD_COLOR)
    x = cv2.resize(x, (256, 256))
    x = x/255.0
    x = x.astype(np.float32)
    # (256, 256, 3)
    return x

In [9]:
def read_mask(path):
    x = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    x = cv2.resize(x, (256, 256))
    x = x.astype(np.float32)
    x = np.expand_dims(x, axis=-1)
    return x

In [10]:
def preprocess(image_path, mask_path):
    def f(image_path, mask_path):
        image_path = image_path.decode()
        mask_path = mask_path.decode()
        x = read_image(image_path)
        y = read_mask(mask_path)
        return x, y
    image, mask = tf.numpy_function(f, [image_path, mask_path], [tf.float32, tf.float32])
    image.set_shape([256, 256, 3])
    mask.set_shape([256, 256, 1])
    return image, mask

In [16]:
def tf_dataset(images, masks, batch=8):
    dataset = tf.data.Dataset.from_tensor_slices((images, masks))
    dataset = dataset.shuffle(buffer_size=5000)
    dataset = dataset.map(preprocess)
    dataset = dataset.batch(batch)
    dataset = dataset.prefetch(2)
    dataset = dataset.repeat()
    return dataset

In [22]:
""" Hyperparameters """
dataset_path = "../input/person-segmentation/people_segmentation"
input_shape = (256, 256, 3)
batch_size = 8
epochs = 20
lr = 1e-4
model_path = "./unet.keras"
csv_path = "./data.csv"

""" Loading the dataset """
(train_x, train_y), (test_x, test_y) = load_data(dataset_path)

train_dataset = tf_dataset(train_x, train_y, batch=batch_size)
valid_dataset = tf_dataset(test_x, test_y, batch=batch_size)

In [18]:
valid_dataset

<_RepeatDataset element_spec=(TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 256, 256, 1), dtype=tf.float32, name=None))>

In [23]:

""" Model """
model = build_unet(input_shape)
model.compile(
    loss="binary_crossentropy",
    optimizer=tf.keras.optimizers.Adam(lr),
    metrics=[
        tf.keras.metrics.MeanIoU(num_classes=2),
        tf.keras.metrics.Recall(),
        tf.keras.metrics.Precision()
    ]
)

callbacks = [
    ModelCheckpoint(model_path, monitor="val_loss", verbose=1),
    ReduceLROnPlateau(monitor="val_loss", patience=5, factor=0.1, verbose=1),
    CSVLogger(csv_path),
    EarlyStopping(monitor="val_loss", patience=10)
]

train_steps = len(train_x)//batch_size
if len(train_x) % batch_size != 0:
    train_steps += 1
    
valid_steps = len(test_x)//batch_size
if len(test_x) % batch_size != 0:
    valid_steps += 1

model.fit(
    train_dataset,
    validation_data=valid_dataset,
    epochs = epochs,
    steps_per_epoch=train_steps,
    validation_steps=valid_steps,
    callbacks=callbacks
)

Epoch 1/20
[1m568/568[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 219ms/step - loss: 0.4482 - mean_io_u_2: 0.2671 - precision_2: 0.6082 - recall_2: 0.5524
Epoch 1: saving model to ./unet.keras
[1m568/568[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m156s[0m 241ms/step - loss: 0.4481 - mean_io_u_2: 0.2671 - precision_2: 0.6084 - recall_2: 0.5525 - val_loss: 0.3494 - val_mean_io_u_2: 0.2500 - val_precision_2: 0.7007 - val_recall_2: 0.7633 - learning_rate: 1.0000e-04
Epoch 2/20
[1m568/568[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 207ms/step - loss: 0.3053 - mean_io_u_2: 0.2664 - precision_2: 0.7734 - recall_2: 0.6915
Epoch 2: saving model to ./unet.keras
[1m568/568[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m129s[0m 227ms/step - loss: 0.3053 - mean_io_u_2: 0.2664 - precision_2: 0.7734 - recall_2: 0.6915 - val_loss: 0.3285 - val_mean_io_u_2: 0.2500 - val_precision_2: 0.6879 - val_recall_2: 0.8007 - learning_rate: 1.0000e-04
Epoch 3/20
[1m568/568[0m [32m━━

<keras.src.callbacks.history.History at 0x78d6d0c9ebc0>

In [24]:
model.save("unet_model.keras")

In [25]:
# https://raw.githubusercontent.com/nikhilroxtomar/Unet-for-Person-Segmentation/main/images/Black-Widow-Avengers.jpg

#test_images = glob("images/*")


test_images = [
    'https://raw.githubusercontent.com/nikhilroxtomar/Unet-for-Person-Segmentation/main/images/Black-Widow-Avengers.jpg'
]

model = tf.keras.models.load_model("./unet.keras")
for path in tqdm(test_images, total=len(test_images)):
    
    req = urllib.request.urlopen(path)
    imgarr = np.asarray(bytearray(req.read()), dtype=np.uint8)
    
    x = cv2.imdecode(imgarr, -1)
    
    #x = cv2.imread(path, cv2.IMREAD_COLOR)
    original_image = x
    h, w, _ = x.shape
    
    x = cv2.resize(x, (256, 256))
    x = x/255.0
    x = x.astype(np.float32)
    x = np.expand_dims(x, axis=0)

    pred_mask = model.predict(x)[0]
    pred_mask = cv2.resize(pred_mask, (w, h))
    pred_mask = np.expand_dims(pred_mask, axis=-1)
    pred_mask = pred_mask > 0.5
    
    background_mask = np.abs(1- pred_mask)
        
    masked_image = original_image * pred_mask
    
    background_mask = np.concatenate([background_mask, background_mask, background_mask], axis=-1)
    background_mask = background_mask * [0, 0, 0]
    
    masked_image = masked_image + background_mask
    name = path.split("/")[-1]
    cv2.imwrite(f"{name}.png", masked_image)

  0%|          | 0/1 [00:00<?, ?it/s]

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step


100%|██████████| 1/1 [00:01<00:00,  1.31s/it]
