In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import random
import math
import torch
import torchvision
from nn.nn import NeuralNetwork

# Notebook-wide settings read by the cells below.
height, width = 256, 256   # size of the generated images, in pixels
noise_type = "mul"         # noise model passed to add_noise: "add" or "mul"
std_deviation = 0.3        # standard deviation of the Gaussian noise
file_path = ""             # image file to load; empty string means generate one

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Device: {device}")

In [None]:
def random_number(min, max):
  """Return a random integer between floor(min) and floor(max), both inclusive.

  Both bounds are rounded down before drawing, so fractional (and negative)
  limits are accepted.
  """
  lower = math.floor(min)
  upper = math.floor(max)
  return random.randint(lower, upper)

In [None]:
def display_image(image):
  """Render `image` with matplotlib, with the axes hidden."""
  plt.imshow(image)
  plt.axis("off")
  plt.show()

In [None]:
def get_rgb(image):
  """Split a channels-first image batch into three single-channel batches.

  Parameters
  ----------
  image : array or tensor indexed as (batch, channel, height, width)
    Must have at least three channels; channels 0, 1 and 2 are returned.

  Returns
  -------
  (image_r, image_g, image_b)
    Each has shape (batch, 1, height, width).

  The channel axis is kept by slicing with a width-1 range rather than by
  reshaping to the notebook's global `height`/`width`, so the function works
  for any batch size and any image size.
  """
  image_r = image[:, 0:1, :, :]
  image_g = image[:, 1:2, :, :]
  image_b = image[:, 2:3, :, :]
  return image_r, image_g, image_b

In [None]:
def expand(image):
  """Convert an (H, W, C) array into a (1, C, H, W) batch of one image."""
  # Move the last (channel) axis to the front ...
  channels_first = image.transpose((-1, 0, 1))
  # ... then prepend a batch axis of length 1.
  return channels_first[np.newaxis, ...]

In [None]:
def fix_image(image):
  """Clamp every value of `image` into [0, 1], modifying it in place.

  Nothing is returned; the caller's array (or tensor) is changed directly.
  """
  below_range = image < 0
  above_range = image > 1
  image[above_range] = 1
  image[below_range] = 0

In [None]:
def create_triangle(rand_height_start, rand_height_end, rand_width_start, rand_width_end, image):
  """Draw a filled triangle with random vertices and a random colour on `image`.

  Vertex x coordinates are drawn from [rand_width_start, rand_width_end] and
  y coordinates from [rand_height_start, rand_height_end] (via random_number).
  `image` is modified in place; nothing is returned.
  """
  # Each vertex is [x, y]; x is drawn before y for every vertex.
  vertices = [
    [random_number(rand_width_start, rand_width_end), random_number(rand_height_start, rand_height_end)]
    for _ in range(3)
  ]
  # fillPoly is given the points as an (N, 1, 2) int32 array.
  polygon = np.array(vertices, np.int32).reshape((-1, 1, 2))
  colour = tuple(random_number(0, 255) for _ in range(3))
  cv2.fillPoly(image, [polygon], colour, 8)

In [None]:
def create_rectangle(rand_height_start, rand_height_end, rand_width_start, rand_width_end, image):
  """Draw a filled four-vertex polygon with a random colour on `image`.

  Despite the name, the four corners are drawn independently, so the result is
  an arbitrary quadrilateral rather than an axis-aligned rectangle. Vertex x
  coordinates come from [rand_width_start, rand_width_end] and y coordinates
  from [rand_height_start, rand_height_end]. `image` is modified in place;
  nothing is returned.
  """
  # Each corner is [x, y]; x is drawn before y for every corner.
  corners = [
    [random_number(rand_width_start, rand_width_end), random_number(rand_height_start, rand_height_end)]
    for _ in range(4)
  ]
  # fillPoly is given the points as an (N, 1, 2) int32 array.
  polygon = np.array(corners, np.int32).reshape((-1, 1, 2))
  colour = tuple(random_number(0, 255) for _ in range(3))
  cv2.fillPoly(image, [polygon], colour, 8)

In [None]:
def create_ellipse(height, width, image):
  """Draw a filled, randomly placed, sized, rotated and coloured ellipse on `image`.

  The centre and the two axis lengths are each drawn from [0, width] (x) and
  [0, height] (y); the rotation is drawn from [0, 360] degrees. `image` is
  modified in place; nothing is returned.
  """
  centre = (random_number(0, width), random_number(0, height))
  axes = (random_number(0, width), random_number(0, height))
  rotation = random_number(0, 360)
  colour = (random_number(0, 255), random_number(0, 255), random_number(0, 255))
  # Arc from 0 to 360 degrees = full ellipse; thickness -1 fills it; line type 8.
  cv2.ellipse(image, centre, axes, rotation, 0, 360, colour, -1, 8)

In [None]:
def create_stars(height, width, image):
  """Draw a five-armed "star" on `image`: five random points joined to their centroid.

  Each arm is a 2-pixel-wide line in its own random colour. `image` is modified
  in place; nothing is returned.
  """
  tips = [(random_number(0, width), random_number(0, height)) for _ in range(5)]

  # The arms meet at the mean of the five points, rounded down to whole pixels.
  centroid = np.array(tips).mean(axis = 0)
  hub = (math.floor(centroid[0]), math.floor(centroid[1]))

  for tip in tips:
    colour = (random_number(0, 255), random_number(0, 255), random_number(0, 255))
    cv2.line(image, tip, hub, colour, 2, 8)


In [None]:
def generate_image(height, width, triangles, rectangles, ellipses, stars):
  """Create a synthetic RGB image: a random flat background with random shapes.

  Parameters
  ----------
  height, width : int
    Size of the image in pixels.
  triangles, rectangles, ellipses, stars : int
    How many shapes of each kind to draw. Negative counts are treated as 0.

  Returns
  -------
  numpy.ndarray
    Array of shape (height, width, 3) and dtype uint8 with values in 0..255.
  """
  # Background colour, drawn in R, G, B order. The canvas is created as uint8
  # explicitly: without a dtype NumPy picks the platform default integer, which
  # the OpenCV drawing functions do not accept on every platform.
  background = [random_number(0, 255) for _ in range(3)]
  image = np.full((height, width, 3), background, dtype=np.uint8)

  # Polygon vertices may fall up to 10% outside the canvas on every side.
  rand_height_start = height * -0.1
  rand_height_end = height * 1.1
  rand_width_start = width * -0.1
  rand_width_end = width * 1.1

  # One drawer per shape kind, in the same order as the `shapes` counters.
  drawers = [
    lambda: create_triangle(rand_height_start, rand_height_end, rand_width_start, rand_width_end, image),
    lambda: create_rectangle(rand_height_start, rand_height_end, rand_width_start, rand_width_end, image),
    lambda: create_ellipse(height, width, image),
    lambda: create_stars(height, width, image),
  ]
  shapes = [max(0, count) for count in (triangles, rectangles, ellipses, stars)]

  # NOTE(review): a new kind is only picked once the current kind's counter
  # reaches zero, so shapes are drawn grouped by kind (in a random kind order)
  # rather than randomly interleaved. Kept as-is; confirm this is intended.
  index = random_number(0, 3)
  while sum(shapes) > 0:
    while shapes[index] == 0:
      index = random_number(0, 3)

    drawers[index]()
    shapes[index] -= 1

  return image

In [None]:
def add_noise(image, type, std=None):
  """Return a noisy copy of `image` using Gaussian noise.

  Parameters
  ----------
  image : numpy.ndarray
    Input image; it is not modified.
  type : str
    "add" adds zero-mean noise; "mul" multiplies by noise with mean 1.
    Any other value returns `image` itself, unchanged.
  std : float, optional
    Standard deviation of the noise. When omitted, the notebook-level
    `std_deviation` setting is used.

  Returns
  -------
  numpy.ndarray
    The noisy image. Values are not clamped, so they may leave [0, 1].
  """
  if std is None:
    std = std_deviation

  if type == "add":
    noise = np.random.normal(0, std, image.shape)
    image = image + noise
  elif type == "mul":
    noise = np.random.normal(1, std, image.shape)
    image = image * noise

  return image

In [None]:
def train(noised, originals, model, optimizer, loss_fn, epochs, print_step):
  """Train `model` to map noisy images to their originals, one image per step.

  Parameters
  ----------
  noised, originals : numpy.ndarray
    Noisy inputs and their targets, indexed by image along the first axis.
    Presumably (N, 3, H, W) float32, as produced by `expand` - confirm
    against the caller.
  model : torch.nn.Module
    Called as model(image, image_r, image_g, image_b).
  optimizer : torch.optim.Optimizer
    Optimizer already bound to the model's parameters.
  loss_fn : callable
    Called as loss_fn(prediction, original).
  epochs : int
    Number of passes over the images.
  print_step : int
    Interpreted as the number of optimisation steps between progress lines;
    0 or None disables logging.

  Returns
  -------
  list of float
    Mean loss of each epoch, in order.
  """
  model = model.to(device)
  model.train()

  num_images = len(noised)
  epoch_losses = []
  steps_done = 0

  for epoch in range(epochs):
    loss_sum = 0.0

    for step in range(num_images):
      # Slicing with step:step+1 keeps the leading batch axis (batch of one).
      image = torch.from_numpy(noised[step:step+1]).to(device)
      original = torch.from_numpy(originals[step:step+1]).to(device)
      image_r, image_g, image_b = get_rgb(image)

      pred = model(image, image_r, image_g, image_b)
      loss = loss_fn(pred, original)

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      loss_value = loss.item()
      loss_sum += loss_value
      steps_done += 1

      if print_step and steps_done % print_step == 0:
        print(f"epoch {epoch + 1}/{epochs}, step {step + 1}/{num_images}, loss {loss_value:.6f}")

    # Guard against an empty dataset so the mean is always defined.
    epoch_losses.append(loss_sum / num_images if num_images else 0.0)

  return epoch_losses

In [None]:
# Obtain the clean image: generate a synthetic one, or load it from file_path.
if(file_path == ""):
  image = generate_image(height, width, 3, 3, 3, 3)
else:
  image = cv2.imread(file_path)
  # cv2.imread signals failure by returning None instead of raising.
  if image is None:
    raise FileNotFoundError(f"Could not read an image from {file_path!r}")
  # OpenCV loads images as BGR; the rest of the notebook works in RGB.
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  # Bring the loaded image to the configured height x width so it has the
  # same size as the generated images used elsewhere in the notebook.
  if image.shape[:2] != (height, width):
    image = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)

# Scale 0..255 pixel values to floats in [0, 1].
image = image / 255

display_image(image)

In [None]:
# Build the training pair: the clean image is the target, its noisy version the input.
originals = expand(image)

# add_noise does not clamp, so bring the noisy values back into [0, 1] in place.
noised_image = add_noise(image, noise_type)
fix_image(noised_image)

display_image(noised_image)

noised = expand(noised_image)

In [None]:
# Build the model on the target device and train it on the noisy/clean pair.
model = NeuralNetwork().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Denoising is a per-pixel regression onto the clean image, so use a regression
# loss; CrossEntropyLoss would treat the colour channels as class scores.
loss_fn = torch.nn.MSELoss()

# torch.from_numpy keeps the NumPy dtype, so cast to float32 before training.
noised = noised.astype(np.float32)
originals = originals.astype(np.float32)
train(noised, originals, model, optimizer, loss_fn, 10, 100)

print('done training')

In [None]:
# Evaluate the trained model on a freshly generated, freshly noised image.
new_image = generate_image(height, width, 3, 3, 3, 3)
new_image = new_image / 255

display_image(new_image)

new_image = add_noise(new_image, noise_type)
fix_image(new_image)

display_image(new_image)

# Batch of one, float32, on the same device as the model. Splitting the
# channels after the move keeps all four model inputs on that device.
new_images = expand(new_image).astype(np.float32)
new_images = torch.from_numpy(new_images).to(device)
new_image_r, new_image_g, new_image_b = get_rgb(new_images)

# Inference only: switch to eval mode and skip gradient tracking.
model.eval()
with torch.no_grad():
  pred = model(new_images, new_image_r, new_image_g, new_image_b)

# Move the result to the CPU before clamping and converting to NumPy;
# .numpy() is not available on CUDA tensors.
pred = pred.cpu()
fix_image(pred)

pred_img = pred[0, :, :, :].numpy()
pred_img = pred_img.transpose(1, 2, 0)

display_image(pred_img)