In [484]:
import matplotlib.pyplot as plt
import numpy as np
import cv2
import os
import torch
import torchinfo
import random
import string

from torch.utils.data import Dataset, DataLoader

TRAIN_FOLDER = 'train/'
TEST_FOLDER = 'test/'

EPOCHS = 5
BATCH_SIZE = 1
LR = 1e-4
IMAGES_PER_EPOCH = 10
SIZE = 256

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print(f"Device: {device}")

Device: cuda


In [485]:
class ImageDataset(Dataset):
  def __init__(self, X, Y):
    self.X = X
    self.Y = Y

  def __len__(self):
    return len(self.X)

  def __getitem__(self, index):
    return self.X[index], self.Y[index]

In [486]:
class Connect(torch.nn.Module):
  def __init__(self, input, output) -> None:
    super().__init__()
    self.conv2d = torch.nn.Conv2d(input, output, kernel_size=(3, 3), padding="same")
    self.batchnorm2d = torch.nn.BatchNorm2d(output)
    self.relu = torch.nn.ReLU()

  def forward(self, x):
    x = self.conv2d(x)
    x = self.batchnorm2d(x)
    x = self.relu(x)
    return x

In [487]:
class UNet(torch.nn.Module):
  def __init__(self) -> None:
    super().__init__()
    self.connect3_32 = Connect(3, 32)
    self.connect32_32 = Connect(32, 32)
    self.connect32_32_2 = Connect(32, 32)
    self.connect32_64 = Connect(32, 64)
    self.connect64_64 = Connect(64, 64)
    self.connect64_64_2 = Connect(64, 64)
    self.connect64_128 = Connect(64, 128)
    self.connect128_128 = Connect(128, 128)
    self.connect128_128_2 = Connect(128, 128)
    self.connect128_256 = Connect(128, 256)
    self.connect256_256 = Connect(256, 256)
    self.connect256_128 = Connect(256, 128)
    self.connect128_64 = Connect(128, 64)
    self.connect64_32 = Connect(64, 32)
    self.maxpool2d = torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
    self.conv2d = torch.nn.Conv2d(32, 1, (1, 1), padding="same")
    self.first_convtranspose2d = torch.nn.ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))
    self.second_convtranspose2d = torch.nn.ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
    self.third_convtranspose2d = torch.nn.ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2))

  def forward(self, x):
    x = self.connect3_32(x)
    x = self.connect32_32(x)
    first_layer = torch.clone(x)

    x = self.maxpool2d(x)
    x = self.connect32_64(x)
    x = self.connect64_64(x)
    second_layer = torch.clone(x)

    x = self.maxpool2d(x)
    x = self.connect64_128(x)
    x = self.connect128_128(x)
    third_layer = torch.clone(x)

    x = self.maxpool2d(x)
    x = self.connect128_256(x)
    x = self.connect256_256(x)

    x = self.first_convtranspose2d(x)
    x = torch.cat((third_layer, x), dim=1)
    x = self.connect256_128(x)
    x = self.connect128_128_2(x)

    x = self.second_convtranspose2d(x)
    x = torch.cat((second_layer, x), dim=1)
    x = self.connect128_64(x)
    x = self.connect64_64_2(x)
    
    x = self.third_convtranspose2d(x)
    x = torch.cat((first_layer, x), dim=1)
    x = self.connect64_32(x)
    x = self.connect32_32_2(x)
    x = self.conv2d(x)

    return x

In [488]:
def display_image(image, title, cmap="viridis"):
  plt.imshow(image, cmap=cmap)
  plt.title(title)
  plt.axis('off')
  plt.show()

In [489]:
def fix_image(image):
  image = image / 255
  image[image > 1] = 1
  image[image < 0] = 0
  image = image.astype(np.float32)
  return image

In [490]:
def transpose(image):
  image = image.transpose(-1, 0, 1)

  return image

In [491]:
def load_image(path):
  image = cv2.imread(path, -1)

  if image is not None:
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    return fix_image(image)

In [492]:
def get_image_portion(image):
  x1 = random.randint(0, image.shape[0] - SIZE * 2 - 1)
  y1 = random.randint(0, image.shape[1] - SIZE * 2 - 1)

  x2, y2 = x1 + SIZE * 2, y1 + SIZE * 2

  return image[x1:x2, y1:y2, 0:3]

In [493]:
def rotate_image(image, angle):
  image_center = tuple(np.array(image.shape[1::-1]) / 2)
  rot_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)
  result = cv2.warpAffine(image, rot_mat, image.shape[1::-1], flags=cv2.INTER_LINEAR)
  return result

In [494]:
def generate_images(background, foreground, display=False):
  background_portion = get_image_portion(background)
  foreground_portion = get_image_portion(foreground)

  text = ''.join(random.choice(string.ascii_lowercase + string.ascii_uppercase + string.digits + ' ') for _ in range(random.randint(20, 50)))
  text_location = (random.randint(SIZE / 2, SIZE), random.randint(SIZE / 2, SIZE))
  fonts = [0, 1, 2, 3, 4, 5, 6, 7, 16]

  mask = np.zeros((SIZE * 2, SIZE * 2, 3), dtype=np.float32)
  cv2.putText(
    mask, 
    text, 
    text_location,
    random.choice(fonts),
    random.uniform(1, 3),
    (255, 255, 255),
    random.randint(1, 5),
    cv2.LINE_AA,
    False
  )
  #mask = cv2.warpAffine(mask, cv2.getRotationMatrix2D(text_location, random.uniform(0, 360), 1), (mask.shape[1], mask.shape[0]))
  mask = fix_image(mask)
  target = (1 - mask) * background_portion + mask * foreground_portion

  angle = random.randint(0, 360)
  flip = random.randint(-1, 1)

  mask = rotate_image(mask, angle)
  target = rotate_image(target, angle)

  mask = cv2.flip(mask, flip)
  target = cv2.flip(target, flip)

  mask[mask >= 0.5] = 1
  mask[mask < 0.5] = 0

  mask = mask[128:384, 128:384, 0:3]
  target = target[128:384, 128:384, 0:3]

  if display:
    _, axes = plt.subplots(2, 3, figsize=(15, 15))
    axes[0, 0].imshow(background)
    axes[0, 0].set_title("Background " + str(background.shape))
    axes[0, 0].axis("off")

    axes[0, 1].imshow(foreground)
    axes[0, 1].set_title("Foreground " + str(foreground.shape))
    axes[0, 1].axis("off")
    
    axes[0, 2].imshow(background_portion)
    axes[0, 2].set_title("Background portion " + str(background_portion.shape))
    axes[0, 2].axis("off")
    
    axes[1, 0].imshow(foreground_portion)
    axes[1, 0].set_title("Foreground portion " + str(foreground_portion.shape))
    axes[1, 0].axis("off")
    
    axes[1, 1].imshow(mask)
    axes[1, 1].set_title("Mask " + str(mask.shape))
    axes[1, 1].axis("off")
    
    axes[1, 2].imshow(target)
    axes[1, 2].set_title("Target " + str(target.shape))
    axes[1, 2].axis("off")

  return target, mask

In [495]:
def train():
  model = UNet()
  model = model.to(device)
  optimizer = torch.optim.Adam(model.parameters(), lr=LR)
  loss_fn = torch.nn.BCEWithLogitsLoss()

  background = load_image(TRAIN_FOLDER + random.choice(os.listdir(TRAIN_FOLDER)))
  foreground = load_image(TRAIN_FOLDER + random.choice(os.listdir(TRAIN_FOLDER)))
  
  model.train()

  for epoch in range(EPOCHS):
    targets = []
    masks = []

    if (epoch + 1) % 20 == 0:
      print('epoch:' + str(epoch + 1) + ' of ' + str(EPOCHS))

    for i in range(IMAGES_PER_EPOCH):
      #if (i + 1) % 128 == 0:
      print('generating image: ' + str(i))
      target, mask = generate_images(background, foreground, False)

      targets.append(target)
      masks.append(mask)

    for i in range(len(targets)):
      targets[i] = transpose(targets[i])

    for i in range(len(masks)):
      masks[i] = transpose(masks[i])

    dataset = ImageDataset(targets, masks)
    dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

    for step, (target_batch, mask_batch) in enumerate(dataloader):
      if (step + 1) % 100 == 0:
        print('training image:' + str(step + 1) + ' of ' + str(len(dataloader)))

      target_batch = target_batch.to(device)
      mask_batch = mask_batch.to(device)

      optimizer.zero_grad()

      pred = model(target_batch)

      loss = loss_fn(pred, mask_batch[0:1, 0:1, 0:SIZE, 0:SIZE])

      loss.backward()

      optimizer.step()

  torch.save(model, "text-segmentation.pt")

In [496]:
#background = load_image(TRAIN_FOLDER + random.choice(os.listdir(TRAIN_FOLDER)))
#foreground = load_image(TRAIN_FOLDER + random.choice(os.listdir(TRAIN_FOLDER)))
#target, mask = generate_images(background, foreground, True)
train()

generating image: 0
generating image: 1
generating image: 2
generating image: 3
generating image: 4
generating image: 5
generating image: 6
generating image: 7
generating image: 8
generating image: 9
generating image: 0
generating image: 1
generating image: 2
generating image: 3
generating image: 4
generating image: 5
generating image: 6
generating image: 7
generating image: 8
generating image: 9
generating image: 0
generating image: 1
generating image: 2
generating image: 3
generating image: 4
generating image: 5
generating image: 6
generating image: 7
generating image: 8
generating image: 9
generating image: 0
generating image: 1
generating image: 2
generating image: 3
generating image: 4
generating image: 5
generating image: 6
generating image: 7
generating image: 8
generating image: 9
generating image: 0
generating image: 1
generating image: 2
generating image: 3
generating image: 4
generating image: 5
generating image: 6
generating image: 7
generating image: 8
generating image: 9


In [497]:
model = torch.load("text-segmentation.pt", map_location=device)
torchinfo.summary(model, (1, 3, 256, 256), depth=32)

Layer (type:depth-idx)                   Output Shape              Param #
UNet                                     [1, 1, 256, 256]          --
├─Connect: 1-1                           [1, 32, 256, 256]         --
│    └─Conv2d: 2-1                       [1, 32, 256, 256]         896
│    └─BatchNorm2d: 2-2                  [1, 32, 256, 256]         64
│    └─ReLU: 2-3                         [1, 32, 256, 256]         --
├─Connect: 1-2                           [1, 32, 256, 256]         --
│    └─Conv2d: 2-4                       [1, 32, 256, 256]         9,248
│    └─BatchNorm2d: 2-5                  [1, 32, 256, 256]         64
│    └─ReLU: 2-6                         [1, 32, 256, 256]         --
├─MaxPool2d: 1-3                         [1, 32, 128, 128]         --
├─Connect: 1-4                           [1, 64, 128, 128]         --
│    └─Conv2d: 2-7                       [1, 64, 128, 128]         18,496
│    └─BatchNorm2d: 2-8                  [1, 64, 128, 128]         128
│    └