In [None]:
import matplotlib.pyplot as plt
import numpy as np
import cv2
import os
import torch
import torchinfo
import random
import string

from torch.utils.data import Dataset, DataLoader

# Dataset folders, given relative to the notebook's working directory.
TRAIN_FOLDER = 'train/'
TEST_FOLDER = 'test/'

# Training hyper-parameters (presumably consumed by cells further down -- confirm).
EPOCHS = 1_000
BATCH_SIZE = 32
LR = 1e-4
IMAGES_PER_EPOCH = 1024

# Base edge length, in pixels, used by the image helpers in this notebook.
SIZE = 256

# Use CUDA when torch reports it as available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

In [None]:
class Connect(torch.nn.Module):
  """Conv2d (3x3, "same" padding) -> BatchNorm2d -> ReLU building block.

  The channel counts are supplied at call time. The conv/batch-norm pair for
  each distinct ``(input, output)`` combination is created the first time it
  is requested and then kept in ``self.blocks``, so that:

  * the same weights are reused on every later call (they are no longer
    re-randomised per forward pass),
  * the weights show up in ``parameters()`` / ``state_dict()`` and follow
    ``.to(device)``, ``.train()`` and ``.eval()`` like any other sub-module.

  NOTE: because the layers are created lazily, they only appear in
  ``parameters()`` after the first forward pass that uses them. Run one
  forward pass before constructing an optimizer from ``parameters()``.
  Two calls with the same ``(input, output)`` pair on the same ``Connect``
  instance share one set of weights.
  """

  def __init__(self) -> None:
    super().__init__()
    self.relu = torch.nn.ReLU()
    # Registered container: one Conv2d+BatchNorm2d pair per "<in>_<out>" key.
    self.blocks = torch.nn.ModuleDict()

  def forward(self, x, input, output):
    """Apply the conv block mapping ``input`` channels to ``output`` channels.

    Args:
      x: tensor whose channel dimension has ``input`` entries (the layout
        expected by ``torch.nn.Conv2d``).
      input: number of input channels.
      output: number of output channels.

    Returns:
      The activated tensor with ``output`` channels and unchanged spatial size.
    """
    key = f"{input}_{output}"
    if key not in self.blocks:
      block = torch.nn.Sequential(
        torch.nn.Conv2d(input, output, kernel_size=(3, 3), padding="same"),
        torch.nn.BatchNorm2d(output),
      )
      # New modules start in training mode; mirror this module's current mode.
      block.train(self.training)
      # Create the weights on the same device as the data they will process.
      self.blocks[key] = block.to(x.device)

    x = self.blocks[key](x)
    return self.relu(x)

In [None]:
class UNet(torch.nn.Module):
  """Three-level U-Net mapping a 3-channel image to a 1-channel map.

  Encoder widths are 32 -> 64 -> 128 with a 256-channel bottleneck; each level
  applies two Conv3x3 + BatchNorm + ReLU stages and is followed by 2x2 max
  pooling. The decoder mirrors this with stride-2 transposed convolutions and
  skip connections concatenated on the channel axis. The final 1x1 convolution
  returns raw values (no activation is applied here).

  Every layer is created once in ``__init__`` so its weights are registered
  with the module: they are trained by optimizers built from ``parameters()``,
  saved in ``state_dict()`` and moved by ``.to(device)``.

  Input height and width should be multiples of 8; otherwise the upsampled
  maps do not match the skip connections and ``torch.cat`` fails.
  """

  def __init__(self) -> None:
    super().__init__()
    self.maxpool2d = torch.nn.MaxPool2d((2, 2))

    # Contracting path.
    self.encoder1 = self._double_conv(3, 32)
    self.encoder2 = self._double_conv(32, 64)
    self.encoder3 = self._double_conv(64, 128)
    self.bottleneck = self._double_conv(128, 256)

    # Expanding path. ConvTranspose2d requires kernel_size; a 2x2 kernel with
    # stride 2 exactly doubles height and width, undoing one max-pool step.
    self.upconv3 = torch.nn.ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))
    self.decoder3 = self._double_conv(256, 128)
    self.upconv2 = torch.nn.ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
    self.decoder2 = self._double_conv(128, 64)
    self.upconv1 = torch.nn.ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2))
    self.decoder1 = self._double_conv(64, 32)

    # 1x1 projection to the single output channel.
    self.conv2d = torch.nn.Conv2d(32, 1, (1, 1), padding="same")

  @staticmethod
  def _double_conv(in_channels, out_channels):
    """Return two stacked Conv3x3("same") -> BatchNorm2d -> ReLU stages."""
    layers = []
    for channels in (in_channels, out_channels):
      layers += [
        torch.nn.Conv2d(channels, out_channels, kernel_size=(3, 3), padding="same"),
        torch.nn.BatchNorm2d(out_channels),
        torch.nn.ReLU(),
      ]
    return torch.nn.Sequential(*layers)

  def forward(self, x):
    """Run the network.

    Args:
      x: tensor with 3 channels, laid out as expected by ``torch.nn.Conv2d``.

    Returns:
      Tensor with 1 channel and the same spatial size as ``x``.
    """
    first_layer = self.encoder1(x)
    second_layer = self.encoder2(self.maxpool2d(first_layer))
    third_layer = self.encoder3(self.maxpool2d(second_layer))
    x = self.bottleneck(self.maxpool2d(third_layer))

    # Each decoder step: upsample, prepend the matching skip tensor, refine.
    x = self.upconv3(x)
    x = self.decoder3(torch.cat((third_layer, x), dim=1))

    x = self.upconv2(x)
    x = self.decoder2(torch.cat((second_layer, x), dim=1))

    x = self.upconv1(x)
    x = self.decoder1(torch.cat((first_layer, x), dim=1))

    return self.conv2d(x)

In [None]:
def display_image(image, title, cmap="viridis"):
  """Show ``image`` on the current pyplot figure with ``title`` and no axes.

  Args:
    image: data accepted by ``plt.imshow``.
    title: text placed above the image.
    cmap: colormap name forwarded to ``plt.imshow``.
  """
  plt.imshow(X=image, cmap=cmap)
  plt.title(label=title)
  plt.axis("off")  # hide ticks and frame
  plt.show()

In [None]:
def fix_image(image):
  """Rescale 0-255 pixel values to the unit interval.

  The input is divided by 255, values outside [0, 1] are saturated, and the
  result is returned as a new float32 array (the input is not modified).
  """
  unit_range = np.clip(image / 255, 0, 1)
  return unit_range.astype(np.float32)

In [None]:
def load_image(path):
  """Read an image file and return it as a float32 RGB array in [0, 1].

  The file is read with OpenCV's "unchanged" flag, converted from BGR to RGB
  and normalised by ``fix_image``.

  Returns:
    The normalised image, or ``None`` when OpenCV could not read ``path``.
  """
  raw = cv2.imread(path, cv2.IMREAD_UNCHANGED)
  if raw is None:
    return None

  rgb = cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)
  return fix_image(rgb)

In [None]:
def get_image_portion(image, size=None):
  """Cut a random square crop out of ``image`` and keep its first 3 channels.

  Args:
    image: array indexed as (row, column, channel).
    size: edge length of the crop in pixels. Defaults to ``SIZE * 2``.

  Returns:
    ``image[row:row + size, col:col + size, 0:3]`` for a uniformly chosen
    top-left corner (a slice of ``image``, not a copy).

  Raises:
    ValueError: if ``image`` is smaller than ``size`` along rows or columns.
  """
  if size is None:
    size = SIZE * 2

  max_row = image.shape[0] - size
  max_col = image.shape[1] - size
  if max_row < 0 or max_col < 0:
    raise ValueError(
      f"image of shape {image.shape} is too small for a {size}x{size} crop"
    )

  # randint is inclusive at both ends, so the largest valid offset is
  # exactly shape - size (an image of exactly `size` pixels yields offset 0).
  row = random.randint(0, max_row)
  col = random.randint(0, max_col)

  return image[row:row + size, col:col + size, 0:3]

In [None]:
def rotate_image(image, angle):
  """Rotate ``image`` about its centre by ``angle`` degrees, keeping its size.

  The rotation matrix comes from ``cv2.getRotationMatrix2D`` with scale 1.0
  and is applied with bilinear interpolation; the output has the same width
  and height as the input.
  """
  height, width = image.shape[:2]
  centre = (width / 2, height / 2)  # OpenCV points are (x, y)

  matrix = cv2.getRotationMatrix2D(centre, angle, 1.0)
  return cv2.warpAffine(image, matrix, (width, height), flags=cv2.INTER_LINEAR)

In [None]:
def generate_images(background, foreground, display=False):
  """Build one synthetic (image, text-mask) training pair.

  Random text is drawn into a mask; where the mask is set the foreground
  crop shows through, elsewhere the background crop does. The composite and
  the mask are then rotated and flipped identically, the mask is binarised,
  and the central SIZE x SIZE window of both is returned.

  Args:
    background: image providing the pixels outside the text.
    foreground: image providing the pixels inside the text.
      Both are passed to ``get_image_portion`` and are expected to be
      normalised to [0, 1] (as ``load_image`` returns them).
    display: when True, plot the inputs, crops, mask and target on a 2x3 grid.

  Returns:
    ``(target, mask)``: the composited image and its 3-channel mask holding
    only 0.0 and 1.0, each SIZE x SIZE x 3.
  """
  background_portion = get_image_portion(background)
  foreground_portion = get_image_portion(foreground)

  alphabet = string.ascii_lowercase + string.ascii_uppercase + string.digits + ' '
  text = ''.join(random.choice(alphabet) for _ in range(random.randint(20, 50)))
  # randint needs integer bounds: floats are deprecated since Python 3.10 and
  # raise TypeError from 3.12, hence floor division.
  text_location = (random.randint(SIZE // 2, SIZE), random.randint(SIZE // 2, SIZE))
  fonts = [0, 1, 2, 3, 4, 5, 6, 7, 16]  # OpenCV font face ids

  # Draw white text on a black canvas twice the output size, so the later
  # rotation does not pull empty corners into the central window.
  mask = np.zeros((SIZE * 2, SIZE * 2, 3), dtype=np.float32)
  cv2.putText(
    mask,
    text,
    text_location,
    random.choice(fonts),    # font face
    random.uniform(1, 3),    # font scale
    (255, 255, 255),         # colour
    random.randint(1, 5),    # stroke thickness
    cv2.LINE_AA,
    False                    # bottomLeftOrigin
  )
  mask = fix_image(mask)
  target = (1 - mask) * background_portion + mask * foreground_portion

  # Apply the same random rotation and flip to both so they stay aligned.
  angle = random.randint(0, 360)
  flip = random.randint(-1, 1)

  mask = cv2.flip(rotate_image(mask, angle), flip)
  target = cv2.flip(rotate_image(target, angle), flip)

  # Anti-aliasing and interpolation leave fractional values; make it binary.
  mask = (mask >= 0.5).astype(np.float32)

  # Central SIZE x SIZE window of the (2 * SIZE)-sized canvas.
  start = SIZE // 2
  stop = start + SIZE
  mask = mask[start:stop, start:stop, 0:3]
  target = target[start:stop, start:stop, 0:3]

  if display:
    panels = [
      ("Background", background),
      ("Foreground", foreground),
      ("Background portion", background_portion),
      ("Foreground portion", foreground_portion),
      ("Mask", mask),
      ("Target", target),
    ]
    _, axes = plt.subplots(2, 3, figsize=(15, 15))
    for axis, (label, picture) in zip(axes.flat, panels):
      axis.imshow(picture)
      axis.set_title(f"{label} {picture.shape}")
      axis.axis("off")

  return target, mask

In [None]:
# Preview: composite two randomly chosen training images and plot every step.
background_name = random.choice(os.listdir(TRAIN_FOLDER))
background = load_image(TRAIN_FOLDER + background_name)

foreground_name = random.choice(os.listdir(TRAIN_FOLDER))
foreground = load_image(TRAIN_FOLDER + foreground_name)

target, mask = generate_images(background, foreground, display=True)