In [None]:
import matplotlib.pyplot as plt
import numpy as np
import cv2
import os
import torch
import torchinfo
import random
import string

from torch.utils.data import Dataset, DataLoader

# Dataset locations, relative to the notebook's working directory.
TRAIN_FOLDER = 'train/'
TEST_FOLDER = 'test/'
PHOTOS_FOLDER = 'photos/'

# Training hyperparameters.
EPOCHS = 1000
BATCH_SIZE = 32
LR = 1e-4
IMAGES_PER_EPOCH = 1024  # synthetic samples generated anew for every epoch
SIZE = 256  # side length, in pixels, of the square patches fed to the network

# Use the GPU whenever CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Device: {device}")

In [None]:
class ImageDataset(Dataset):
  """Pairs two equally indexed sequences so a DataLoader can batch them together."""

  def __init__(self, X, Y):
    # X holds the samples and Y the matching targets; both are stored as given.
    self.X, self.Y = X, Y

  def __len__(self):
    """Number of samples, taken from the length of X."""
    return len(self.X)

  def __getitem__(self, index):
    """Return the (sample, target) tuple stored at `index`."""
    sample = self.X[index]
    label = self.Y[index]
    return sample, label

In [None]:
class Connect(torch.nn.Module):
  """3x3 same-padded convolution followed by batch normalisation and ReLU."""

  def __init__(self, input, output) -> None:
    """Build the three layers.

    Args:
      input: number of channels of the incoming feature map.
      output: number of channels produced by the convolution.
    """
    super().__init__()
    self.conv2d = torch.nn.Conv2d(
      in_channels=input,
      out_channels=output,
      kernel_size=3,
      padding="same",
    )
    self.batchnorm2d = torch.nn.BatchNorm2d(num_features=output)
    self.relu = torch.nn.ReLU()

  def forward(self, x):
    """Apply convolution, normalisation and activation; spatial size is preserved."""
    return self.relu(self.batchnorm2d(self.conv2d(x)))

In [None]:
class DownBlock(torch.nn.Module):
  """Encoder stage: two consecutive Connect layers."""

  def __init__(self, input, output) -> None:
    """Create the two Connect layers (input -> output, then output -> output channels)."""
    super().__init__()
    self.first_connection = Connect(input, output)
    self.second_connection = Connect(output, output)

  def forward(self, x):
    """Return the stage output twice.

    The same tensor is handed back as both elements so the caller can use one
    for the next stage and keep the other as a skip connection.
    """
    features = self.second_connection(self.first_connection(x))
    return features, features

In [None]:
class UpBlock(torch.nn.Module):
  """Decoder stage: 2x transposed-convolution upsampling, skip concatenation, two Connect layers."""

  def __init__(self, input, output) -> None:
    """Create the layers.

    Args:
      input: channels of the incoming (low-resolution) feature map. It is also
        the channel count expected after concatenation, so the skip tensor
        must carry `input - output` channels.
      output: channels produced by the upsampling and by the block.
    """
    super().__init__()
    self.convtranspose2d = torch.nn.ConvTranspose2d(input, output, kernel_size=(2, 2), stride=(2, 2), padding=0)
    self.first_connection = Connect(input, output)
    self.second_connection = Connect(output, output)

  def forward(self, x, y):
    """Upsample `x`, merge it with the skip tensor `y` and refine the result.

    Args:
      x: feature map coming from the deeper stage.
      y: skip-connection feature map from the matching encoder stage.

    Returns:
      Feature map with `output` channels and the spatial size of `y`.
    """
    x = self.convtranspose2d(x)

    # 2x2 max pooling floors odd sizes, so for inputs whose sides are not
    # divisible by the total downsampling factor the upsampled map can be
    # smaller than the skip map. Pad with zeros (negative values would crop) so
    # the concatenation below is always valid. When the sizes already match
    # this is skipped entirely.
    pad_h = y.shape[-2] - x.shape[-2]
    pad_w = y.shape[-1] - x.shape[-1]
    if pad_h != 0 or pad_w != 0:
      x = torch.nn.functional.pad(
        x,
        (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2),
      )

    x = torch.cat((y, x), dim=1)
    x = self.first_connection(x)
    x = self.second_connection(x)
    return x

In [None]:
class UNet(torch.nn.Module):
  """U-Net with four encoder stages and three decoder stages.

  Takes a 3-channel image and produces a 1-channel map of logits (no
  activation is applied to the output).
  """

  def __init__(self) -> None:
    super().__init__()
    # Encoder: the channel count doubles at every stage.
    self.first_down_block = DownBlock(3, 32)
    self.second_down_block = DownBlock(32, 64)
    self.third_down_block = DownBlock(64, 128)
    self.fourth_down_block = DownBlock(128, 256)

    # Decoder: mirrors the encoder and halves the channel count at every stage.
    self.first_up_block = UpBlock(256, 128)
    self.second_up_block = UpBlock(128, 64)
    self.third_up_block = UpBlock(64, 32)

    # A single pooling layer is shared between all encoder stages.
    self.maxpool2d = torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
    # 1x1 convolution that maps the last feature map to one output channel.
    self.conv2d = torch.nn.Conv2d(32, 1, kernel_size=(1, 1), padding="same")

  def forward(self, x):
    """Run the encoder, then the decoder with skip connections, and return logits."""
    # Each DownBlock returns the same tensor twice, so the first element can be
    # used both as the skip connection and as the input of the next stage.
    first_layer, _ = self.first_down_block(x)
    second_layer, _ = self.second_down_block(self.maxpool2d(first_layer))
    third_layer, _ = self.third_down_block(self.maxpool2d(second_layer))
    bottleneck, _ = self.fourth_down_block(self.maxpool2d(third_layer))

    decoded = self.first_up_block(bottleneck, third_layer)
    decoded = self.second_up_block(decoded, second_layer)
    decoded = self.third_up_block(decoded, first_layer)

    return self.conv2d(decoded)

In [None]:
def display_image(image, title, cmap="viridis"):
  """Show one image with the given title, without axes, and render it immediately."""
  plt.imshow(image, cmap=cmap)
  plt.axis('off')
  plt.title(title)
  plt.show()

In [None]:
def display_images(input, target, output):
  """Plot input, target, output and a per-pixel agreement map side by side.

  In the agreement map, pixels where `target == output` are green and all
  other pixels are red.
  """
  colour_shape = target.shape[:2] + (3,)
  green = np.full(colour_shape, [0, 1, 0], dtype=np.float32)
  red = np.full(colour_shape, [1, 0, 0], dtype=np.float32)
  diff = np.where(target == output, green, red)

  # (title, image, extra imshow arguments) for each of the four panels.
  panels = (
    ("Input", input, {}),
    ("Target", target, {"cmap": "gray"}),
    ("Output", output, {"cmap": "gray"}),
    ("Diff", diff, {}),
  )

  _, axes = plt.subplots(1, 4, figsize=(15, 15), squeeze=False)
  for ax, (title, image, imshow_kwargs) in zip(axes[0], panels):
    ax.imshow(image, **imshow_kwargs)
    ax.set_title(title)
    ax.axis("off")

  plt.tight_layout()

In [None]:
def fix_image(image):
  """Return a float32 copy of `image` divided by 255 and limited to [0, 1].

  The argument itself is not modified.
  """
  scaled = image.astype(np.float32) / 255
  return np.clip(scaled, 0, 1)

In [None]:
def transpose(image):
  """Move the last axis of a 3-D array to the front (e.g. H x W x C -> C x H x W)."""
  return np.transpose(image, axes=(-1, 0, 1))

In [None]:
def load_image(path):
  """Read an image file as an RGB float32 array scaled to [0, 1].

  Returns None when OpenCV cannot read the file.
  """
  bgr = cv2.imread(path)
  if bgr is None:
    return None

  # OpenCV decodes to BGR channel order; the rest of the notebook works in RGB.
  return fix_image(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))

In [None]:
def load_images_from_folder(folder):
  """Load every readable image in `folder` as an RGB float32 array in [0, 1].

  Entries that OpenCV cannot read are skipped. The result follows the order
  returned by os.listdir.
  """
  decoded = (cv2.imread(os.path.join(folder, filename)) for filename in os.listdir(folder))
  return [
    fix_image(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    for bgr in decoded
    if bgr is not None
  ]

In [None]:
def iou(target, prediction):
  """Mean intersection-over-union of the foreground and background classes.

  Only channel 0 of each array is used, and both are binarised at 0.5. The IoU
  is computed once for the foreground (value 1) and once for the background
  (value 0), and the two scores are averaged. A class that is absent from both
  arrays has an empty union and counts as a perfect match (1.0).

  The previous implementation summed the predicted pixels and divided by the
  total pixel count, which is not an IoU, and it overwrote `target` in place.
  Neither argument is modified here.

  Args:
    target: H x W x C ground-truth array.
    prediction: H x W x C predicted array.

  Returns:
    Mean IoU as a float in [0, 1].
  """
  target_mask = np.asarray(target)[:, :, 0] >= 0.5
  prediction_mask = np.asarray(prediction)[:, :, 0] >= 0.5

  def class_iou(truth, guess):
    union = np.logical_or(truth, guess).sum()
    if union == 0:
      return 1.0
    return float(np.logical_and(truth, guess).sum() / union)

  foreground = class_iou(target_mask, prediction_mask)
  background = class_iou(~target_mask, ~prediction_mask)

  return (foreground + background) / 2

In [None]:
def dice(target, prediction):
  """Mean Dice coefficient of the foreground and background classes.

  Only channel 0 of each array is used, and both are binarised at 0.5. For each
  class the score is 2 * |A and B| / (|A| + |B|), and the two scores are
  averaged. A class that is absent from both arrays counts as a perfect match
  (1.0).

  The previous implementation divided the summed predictions by the total
  pixel count, which is not a Dice score, and it overwrote `target` in place.
  Neither argument is modified here.

  Args:
    target: H x W x C ground-truth array.
    prediction: H x W x C predicted array.

  Returns:
    Mean Dice coefficient as a float in [0, 1].
  """
  target_mask = np.asarray(target)[:, :, 0] >= 0.5
  prediction_mask = np.asarray(prediction)[:, :, 0] >= 0.5

  def class_dice(truth, guess):
    total = truth.sum() + guess.sum()
    if total == 0:
      return 1.0
    return float(2.0 * np.logical_and(truth, guess).sum() / total)

  foreground = class_dice(target_mask, prediction_mask)
  background = class_dice(~target_mask, ~prediction_mask)

  return (foreground + background) / 2

In [None]:
def get_image_portion(image):
  """Cut a random (2 * SIZE) x (2 * SIZE) window out of `image`.

  Args:
    image: array of at least 2 * SIZE rows and columns with a channel axis.

  Returns:
    A view of the first three channels of the selected window.

  Raises:
    ValueError: if the image is smaller than the window in either dimension.
  """
  crop = SIZE * 2
  height, width = image.shape[:2]
  if height < crop or width < crop:
    raise ValueError(
      f"image of size {height}x{width} is smaller than the required {crop}x{crop} crop"
    )

  # random.randint is inclusive on both ends, so `height - crop` is the last
  # valid offset. The earlier `- 1` excluded it and rejected images that are
  # exactly `crop` pixels tall or wide.
  top = random.randint(0, height - crop)
  left = random.randint(0, width - crop)

  return image[top:top + crop, left:left + crop, 0:3]

In [None]:
def rotate_image(image, angle):
  """Rotate `image` by `angle` degrees around its centre, keeping the original size."""
  width, height = image.shape[1::-1]
  rotation = cv2.getRotationMatrix2D((width / 2, height / 2), angle, 1.0)
  return cv2.warpAffine(image, rotation, (width, height), flags=cv2.INTER_LINEAR)

In [None]:
def generate_images(background, foreground, display=False):
  """Create one synthetic (input, target) training pair.

  Random text is rendered into a mask, and the mask blends a crop of
  `foreground` (inside the glyphs) over a crop of `background`. Both the blended
  image and the mask are then rotated by the same random angle, flipped the same
  way and centre-cropped to SIZE x SIZE.

  Args:
    background: image the text is drawn on.
    foreground: image that supplies the texture of the glyphs.
    display: when True, plot the source images, the crops and the result.

  Returns:
    Tuple (input, mask), both SIZE x SIZE x 3. The mask holds only 0 and 1.
  """
  canvas = SIZE * 2
  # Integer division: random.randint rejects float bounds on recent Python versions.
  margin = SIZE // 2

  background_portion = get_image_portion(background)
  foreground_portion = get_image_portion(foreground)

  alphabet = string.ascii_lowercase + string.ascii_uppercase + string.digits + ' '
  text = ''.join(random.choice(alphabet) for _ in range(random.randint(20, 50)))
  text_location = (random.randint(margin, SIZE), random.randint(margin, SIZE))
  fonts = [0, 1, 2, 3, 4, 5, 6, 7, 16]

  mask = np.zeros((canvas, canvas, 3), dtype=np.float32)
  cv2.putText(
    mask,
    text,
    text_location,
    random.choice(fonts),
    random.uniform(1, 3),
    (255, 255, 255),
    random.randint(1, 5),
    cv2.LINE_AA,
    False
  )
  mask = fix_image(mask)
  # The mask is still soft here (anti-aliased edges), so glyph borders are blended.
  input = (1 - mask) * background_portion + mask * foreground_portion

  angle = random.randint(0, 360)
  flip = random.randint(-1, 1)

  mask = rotate_image(mask, angle)
  input = rotate_image(input, angle)

  mask = cv2.flip(mask, flip)
  input = cv2.flip(input, flip)

  # Binarise only after the geometric transforms, which interpolate mask values.
  mask[mask >= 0.5] = 1
  mask[mask < 0.5] = 0

  # Keep the central SIZE x SIZE window of the 2 * SIZE canvas.
  mask = mask[margin:margin + SIZE, margin:margin + SIZE, 0:3]
  input = input[margin:margin + SIZE, margin:margin + SIZE, 0:3]

  if display:
    panels = (
      ("Background ", background),
      ("Foreground ", foreground),
      ("Background portion ", background_portion),
      ("Foreground portion ", foreground_portion),
      ("Target ", mask),
      ("Input ", input),
    )
    _, axes = plt.subplots(2, 3, figsize=(15, 15))
    for ax, (label, image) in zip(axes.flat, panels):
      ax.imshow(image)
      ax.set_title(label + str(image.shape))
      ax.axis("off")

  return input, mask

In [None]:
def train():
  """Train a UNet on freshly generated synthetic samples and save it.

  Every epoch generates IMAGES_PER_EPOCH new (input, mask) pairs from random
  pairs of images in TRAIN_FOLDER and runs one pass over them with Adam and
  BCEWithLogitsLoss. The whole model object is saved to "text-segmentation.pt"
  after the last epoch.
  """
  model = UNet().to(device)
  optimizer = torch.optim.Adam(model.parameters(), lr=LR)
  loss_fn = torch.nn.BCEWithLogitsLoss()

  images = load_images_from_folder(TRAIN_FOLDER)
  print('IMAGES LOADED')

  model.train()

  for epoch in range(EPOCHS):
    if (epoch + 1) % 20 == 0:
      print('epoch:' + str(epoch + 1) + ' of ' + str(EPOCHS))

    inputs = []
    targets = []

    for i in range(IMAGES_PER_EPOCH):
      if (i + 1) % 128 == 0:
        print('generating image: ' + str(i))
      # Only the first five samples of the first epoch are plotted.
      input, target = generate_images(random.choice(images), random.choice(images), i < 5 and epoch == 0)

      # Move the channel axis to the front, as the network expects.
      inputs.append(transpose(input))
      targets.append(transpose(target))

    dataset = ImageDataset(inputs, targets)
    dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

    for step, (input_batch, target_batch) in enumerate(dataloader):
      if (step + 1) % 100 == 0:
        print('training image:' + str(step + 1) + ' of ' + str(len(dataloader)))

      input_batch = input_batch.to(device)
      target_batch = target_batch.to(device)

      optimizer.zero_grad()

      pred = model(input_batch)

      # The mask has three identical channels while the network outputs one, so
      # keep channel 0 of EVERY sample. Slicing the batch axis to 0:1 (as
      # before) kept a single sample, which mismatched the prediction shape for
      # any batch size above 1.
      loss = loss_fn(pred, target_batch[:, 0:1])

      loss.backward()

      optimizer.step()

  torch.save(model, "text-segmentation.pt")

In [None]:
def predict(model, input):
  """Run `model` on a one-image batch and return a binary H x W x C mask on the CPU.

  The model is switched to evaluation mode (and left in it). Its outputs are
  treated as logits: sigmoid is applied and the result rounded to 0 or 1.
  """
  model.eval()
  with torch.no_grad():
    logits = model(input)
    mask = torch.sigmoid(logits).round().cpu().detach()

  # Drop the batch axis, then move channels last for plotting and metrics.
  return mask.squeeze(0).permute(1, 2, 0)

In [None]:
def test_generated(n, display=False):
  """Evaluate the saved model on `n` synthetic samples built from TEST_FOLDER.

  Args:
    n: number of samples to generate and score.
    display: when True, print the metrics of every sample and plot the
      input / target / output / diff figure for it.

  Raises:
    RuntimeError: if no readable image can be found in TEST_FOLDER.
  """
  # The checkpoint is a fully pickled nn.Module (see torch.save in train()), so
  # it cannot be loaded with weights_only=True, the default in recent PyTorch.
  # NOTE: unpickling executes code from the file; only load checkpoints you created.
  model = torch.load("text-segmentation.pt", map_location=device, weights_only=False)

  # List the folder once instead of twice per sample.
  filenames = os.listdir(TEST_FOLDER)

  def load_random_image(max_attempts=100):
    # load_image returns None for entries OpenCV cannot read; retry in that case.
    for _ in range(max_attempts):
      image = load_image(os.path.join(TEST_FOLDER, random.choice(filenames)))
      if image is not None:
        return image
    raise RuntimeError(f"no readable image found in {TEST_FOLDER}")

  ious = []
  dices = []

  for _ in range(n):
    background = load_random_image()
    foreground = load_random_image()
    input, target = generate_images(background, foreground, False)

    input = torch.from_numpy(transpose(input)).unsqueeze(0).to(device)

    prediction = predict(model, input).numpy()

    sample_iou = iou(target, prediction)
    sample_dice = dice(target, prediction)
    ious.append(sample_iou)
    dices.append(sample_dice)

    if display:
      print(f"IOU: {sample_iou:.4f}\tDICE: {sample_dice:.4f}")
      display_images(input.cpu().detach().squeeze().permute(1, 2, 0), target, prediction)

  print(f"IoU (povprečje +- standardni odklon): {np.mean(ious) :.2f} +- {np.std(ious) :.2f}")
  print(f"Dice (povprečje +- standardni odklon): {np.mean(dices) :.2f} +- {np.std(dices) :.2f}")

In [None]:
def test_photos(display=True):
  """Run the saved model on every readable image in PHOTOS_FOLDER.

  Args:
    display: when True, plot each photo next to its predicted mask.
  """
  images = load_images_from_folder(PHOTOS_FOLDER)
  # The checkpoint is a fully pickled nn.Module (see torch.save in train()), so
  # it cannot be loaded with weights_only=True, the default in recent PyTorch.
  # NOTE: unpickling executes code from the file; only load checkpoints you created.
  model = torch.load("text-segmentation.pt", map_location=device, weights_only=False)

  for image in images:
    # Channels first plus a batch axis of one, as the network expects.
    batch = torch.from_numpy(transpose(image)).unsqueeze(0).to(device)

    prediction = predict(model, batch).numpy()

    if display:
      _, axes = plt.subplots(1, 2, figsize=(15, 15), squeeze=False)

      # The loaded array is already H x W x C, so it can be shown directly.
      axes[0, 0].imshow(image)
      axes[0, 0].axis("off")
      axes[0, 0].set_title("Input")
      axes[0, 1].imshow(prediction, cmap="gray")
      axes[0, 1].axis("off")
      axes[0, 1].set_title("Output")
      plt.tight_layout()

In [None]:
# Optional sanity check of the synthetic sample generator (left disabled).
# Note that generate_images returns (input, mask), in that order.
#background = load_image(TRAIN_FOLDER + random.choice(os.listdir(TRAIN_FOLDER)))
#foreground = load_image(TRAIN_FOLDER + random.choice(os.listdir(TRAIN_FOLDER)))
#target, mask = generate_images(background, foreground, True)
# Train the network; train() saves the model to "text-segmentation.pt" when it finishes.
train()

In [None]:
# Reload the trained model and print a layer-by-layer summary for one SIZE x SIZE RGB image.
# weights_only=False is required because the checkpoint is a fully pickled nn.Module;
# only load checkpoints you created yourself.
model = torch.load("text-segmentation.pt", map_location=device, weights_only=False)
torchinfo.summary(model, (1, 3, SIZE, SIZE), depth=32)

In [None]:
# Score three generated samples and show the comparison figure for each of them.
test_generated(3, True)

In [None]:
# Report mean and standard deviation of both metrics over 1000 generated samples (no plots).
test_generated(1000)

In [None]:
# Run the saved model on the images in PHOTOS_FOLDER and plot each prediction.
test_photos()