In [None]:
import cv2
import numpy as np
import matplotlib
import matplotlib.pyplot as pyplot
import matplotlib.image as mpimg
import random
import math
import torch
import os
import torchvision
from nn.nn import NeuralNetwork

# Size of the synthetic training/test images (pixels).
height = 256
width = 256
# Noise model passed to add_noise: "add" (additive Gaussian) or "mul" (multiplicative).
noise_type = "add"
# Standard deviation of the Gaussian noise drawn by add_noise.
std_deviation = 0.1
epochs = 1
images_to_generate = 10
images_to_test = 1

# Seed every RNG the notebook draws from (shape generation uses ``random``,
# noise uses ``np.random``, weight init / training use torch) so that
# Restart & Run All reproduces the same data, model and metrics.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print(f"Device: {device}")

In [None]:
def random_number(min, max):
  """Return a uniformly random integer in [floor(min), floor(max)], inclusive.

  The bounds may be floats (callers pass e.g. ``height * -0.1``); both are
  floored before being handed to ``random.randint``. Note that the parameter
  names shadow the ``min``/``max`` builtins inside this function only.
  """
  low = math.floor(min)
  high = math.floor(max)
  return random.randint(low, high)

In [None]:
def load_images_from_folder(folder):
  """Load every image OpenCV can decode from ``folder`` and return them as RGB arrays.

  Entries are visited in sorted filename order so the result is reproducible
  (``os.listdir`` order is arbitrary). Entries for which ``cv2.imread`` returns
  None (e.g. non-image files) are skipped. The scan is not recursive.

  Args:
    folder: path of the directory to scan.

  Returns:
    list of arrays as returned by ``cv2.imread`` (default colour mode),
    converted from BGR to RGB channel order; empty if nothing was readable.
  """
  images = []
  for filename in sorted(os.listdir(folder)):
    img = cv2.imread(os.path.join(folder, filename))
    # Check *before* converting: cvtColor(None) raises, so a guard placed
    # after the conversion can never skip an unreadable file.
    if img is None:
      continue
    images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
  return images

In [None]:
def display_image(image, cmap="viridis"):
  """Show ``image`` in a matplotlib figure with the axes hidden.

  Args:
    image: anything ``pyplot.imshow`` accepts, e.g. an (H, W, 3) RGB array
      with floats in [0, 1] or integers in [0, 255].
    cmap: colormap forwarded to ``imshow`` (matplotlib ignores it for RGB data).
  """
  # The notebook imports pyplot under the name ``pyplot``, not ``plt``; the
  # previous body referenced an undefined ``plt`` and raised NameError.
  import matplotlib.pyplot as plt

  plt.imshow(image, cmap=cmap)
  plt.axis('off')
  plt.show()

In [None]:
def get_rgb(image):
  """Split a channels-first batch into its first three channels.

  Args:
    image: numpy array or torch tensor of shape (N, C, H, W) with C >= 3.

  Returns:
    (r, g, b), each of shape (N, 1, H, W): slices of ``image`` along axis 1.
    The spatial size is read from ``image`` itself rather than from the global
    ``height``/``width``, so inputs of any size (e.g. photos loaded from disk)
    and batches with N > 1 work; for an (1, 3, 256, 256) input the result is
    the same as before.
  """
  return tuple(image[:, channel:channel + 1, :, :] for channel in range(3))

In [None]:
def expand(image):
  """Turn an (H, W, C) numpy image into a (1, C, H, W) batch.

  Moves the last (channel) axis to the front and then prepends a batch axis
  of length 1, which is the layout the network and ``get_rgb`` consume.
  """
  channels_first = np.transpose(image, (2, 0, 1))
  return channels_first[np.newaxis]

In [None]:
def fix_image(image):
  """Clip ``image`` to the range [0, 1] in place; returns None.

  Works on anything that supports comparison-produced boolean masks and masked
  assignment (numpy arrays and torch tensors are both used in this notebook).
  Values that compare False against both bounds (e.g. NaN) are left untouched.
  """
  # Both masks are built before any assignment happens.
  for mask, bound in ((image > 1, 1), (image < 0, 0)):
    image[mask] = bound

In [None]:
def create_triangle(rand_height_start, rand_height_end, rand_width_start, rand_width_end, image):
  """Fill one triangle with random vertices and a random colour onto ``image`` (in place).

  Each vertex is an (x, y) pair: x drawn from [rand_width_start, rand_width_end]
  and y from [rand_height_start, rand_height_end] via ``random_number`` (floored,
  inclusive). The ranges may lie partly outside the image.
  """
  vertices = [
    (random_number(rand_width_start, rand_width_end), random_number(rand_height_start, rand_height_end))
    for _ in range(3)
  ]
  # fillPoly expects int32 points shaped (n_points, 1, 2).
  polygon = np.array(vertices, dtype=np.int32).reshape(-1, 1, 2)
  colour = tuple(random_number(0, 255) for _ in range(3))
  # Final positional argument 8 is the line type (8-connected).
  cv2.fillPoly(image, [polygon], colour, 8)

In [None]:
def create_rectangle(rand_height_start, rand_height_end, rand_width_start, rand_width_end, image):
  """Fill one random four-cornered polygon with a random colour onto ``image`` (in place).

  Despite the name, the four corners are drawn independently, so the shape is an
  arbitrary (possibly non-convex or self-intersecting) quadrilateral rather than
  an axis-aligned rectangle. Corner x values come from
  [rand_width_start, rand_width_end] and y values from
  [rand_height_start, rand_height_end] via ``random_number``.
  """
  corners = []
  for _ in range(4):
    x = random_number(rand_width_start, rand_width_end)
    y = random_number(rand_height_start, rand_height_end)
    corners.append([x, y])
  # fillPoly expects int32 points shaped (n_points, 1, 2).
  contour = np.asarray(corners, dtype=np.int32).reshape((-1, 1, 2))
  red, green, blue = random_number(0, 255), random_number(0, 255), random_number(0, 255)
  # Final positional argument 8 is the line type (8-connected).
  cv2.fillPoly(image, [contour], (red, green, blue), 8)

In [None]:
def create_ellipse(height, width, image):
  """Draw one filled, randomly placed, sized and rotated ellipse onto ``image`` (in place).

  The centre is drawn inside [0, width] x [0, height] and the two half-axes from
  [0, width] and [0, height] respectively, so ellipses can be larger than the
  image. The rotation angle is drawn from [0, 360] degrees and the colour at random.
  """
  center = (random_number(0, width), random_number(0, height))
  axes = (random_number(0, width), random_number(0, height))
  angle = random_number(0, 360)
  colour = (random_number(0, 255), random_number(0, 255), random_number(0, 255))
  # Arc 0..360 draws the whole ellipse; thickness -1 fills it; 8 = line type.
  cv2.ellipse(image, center, axes, angle, 0, 360, colour, -1, 8)

In [None]:
def create_stars(height, width, image):
  """Draw a five-armed "star" onto ``image`` (in place).

  Five random tip points are drawn inside [0, width] x [0, height]. Each tip is
  joined to the (floored) centroid of all five tips by a 2-pixel line with its
  own random colour.
  """
  tips = [(random_number(0, width), random_number(0, height)) for _ in range(5)]

  cx, cy = np.mean(tips, axis=0)
  hub = (math.floor(cx), math.floor(cy))

  for tip in tips:
    colour = (random_number(0, 255), random_number(0, 255), random_number(0, 255))
    # thickness 2, line type 8
    cv2.line(image, tip, hub, colour, 2, 8)


In [None]:
def generate_image(height, width, triangles, rectangles, ellipses, stars):
  """Render a random synthetic RGB image made of simple shapes.

  The image starts as a single random background colour. Then the requested
  number of each shape kind is drawn on top in a random interleaved order.
  Polygon vertices may fall up to 10% outside the image on every side.

  Args:
    height, width: image size in pixels.
    triangles, rectangles, ellipses, stars: how many shapes of each kind to draw.

  Returns:
    (height, width, 3) uint8 array.
  """
  # uint8 is the standard 8-bit image type. np.full with a Python int produced
  # the platform default integer (int64 on Linux/macOS), which OpenCV's drawing
  # functions do not accept.
  background = (random_number(0, 255), random_number(0, 255), random_number(0, 255))
  image = np.full((height, width, 3), background, dtype=np.uint8)

  rand_height_start = height * -0.1
  rand_height_end = height * 1.1
  rand_width_start = width * -0.1
  rand_width_end = width * 1.1

  remaining = [triangles, rectangles, ellipses, stars]
  while any(count > 0 for count in remaining):
    # Pick the kind afresh on every iteration so shape kinds interleave in the
    # layering order. Previously the index was only re-drawn once its kind ran
    # out, so all shapes of one kind were drawn back to back.
    kind = random.choice([k for k, count in enumerate(remaining) if count > 0])

    if kind == 0:
      create_triangle(rand_height_start, rand_height_end, rand_width_start, rand_width_end, image)
    elif kind == 1:
      create_rectangle(rand_height_start, rand_height_end, rand_width_start, rand_width_end, image)
    elif kind == 2:
      create_ellipse(height, width, image)
    else:
      create_stars(height, width, image)

    remaining[kind] -= 1

  return image

In [None]:
def add_noise(image, type, std=None):
  """Return a noisy copy of ``image``; the input array is not modified.

  Args:
    image: numpy array to corrupt.
    type: "add" for additive Gaussian noise N(0, std), or "mul" for
      multiplicative Gaussian noise N(1, std).
    std: noise standard deviation; defaults to the notebook-level
      ``std_deviation`` setting.

  Returns:
    float array of the same shape as ``image``.

  Raises:
    ValueError: for any other ``type``. Previously an unknown type silently
      returned the clean image, so a typo in ``noise_type`` meant training and
      testing on noise-free data.
  """
  if std is None:
    std = std_deviation
  if type == "add":
    return image + np.random.normal(0, std, image.shape)
  if type == "mul":
    return image * np.random.normal(1, std, image.shape)
  raise ValueError(f"unknown noise type {type!r}; expected 'add' or 'mul'")

In [None]:
def train(noised, originals, model, optimizer, loss_fn, epochs):
  """Fit ``model`` to map each noisy image back to its clean original.

  Makes ``epochs`` passes over the data in dataset order, taking one
  optimiser step per image (batch size 1).

  Args:
    noised: noisy inputs, indexed by sample along axis 0 (the data cell builds
      an (N, 3, H, W) float32 array).
    originals: clean targets aligned with ``noised``.
    model: network called as ``model(image, r, g, b)`` with the full batch
      plus its per-channel slices.
    optimizer: optimiser over ``model``'s parameters.
    loss_fn: callable(prediction, target) -> scalar loss tensor.
    epochs: number of passes over the data.

  Returns:
    list with the mean training loss of each epoch. Previously the function
    returned None; callers that ignore the result are unaffected.
  """
  model = model.to(device)
  model.train()

  epoch_losses = []
  for _ in range(epochs):
    running_loss = 0.0
    for step in range(len(noised)):
      # Slicing with step:step+1 keeps the leading batch axis.
      image = torch.from_numpy(noised[step:step + 1])
      image_r, image_g, image_b = get_rgb(image)
      original = torch.from_numpy(originals[step:step + 1]).to(device)

      pred = model(image.to(device), image_r.to(device), image_g.to(device), image_b.to(device))
      loss = loss_fn(pred[0, :, :, :], original[0, :, :, :])

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      running_loss += loss.item()
    epoch_losses.append(running_loss / max(len(noised), 1))

  return epoch_losses

In [None]:
def _decibels(numerator, denominator):
  """Return 10 * log10(numerator / denominator), mapping degenerate ratios to +/-inf."""
  if denominator == 0:
    return math.inf
  if numerator == 0:
    return -math.inf
  return 10 * math.log10(numerator / denominator)


def test_generated(model):
  """Evaluate ``model`` on freshly generated synthetic images and print PSNR / SNR.

  For each of ``images_to_test`` images:
  - generate a clean image and scale it to [0, 1];
  - corrupt it with ``noise_type`` noise and clip it back to [0, 1];
  - denoise it and compare the output against the CLEAN image.

  The clean, noisy and denoised versions of the first image are displayed.

  Metrics, averaged over images (pixel values lie in [0, 1], so the peak is 1):
    PSNR = 10 * log10(1 / MSE)
    SNR  = 10 * log10(var(clean) / MSE)

  Returns:
    (psnr, snr) in dB; both are also printed.

  Raises:
    ValueError: if ``images_to_test`` is less than 1.
  """
  if images_to_test < 1:
    raise ValueError("images_to_test must be at least 1")

  model = model.to(device)
  model.eval()

  squared_error_sum = 0.0
  signal_variance_sum = 0.0

  with torch.no_grad():
    for i in range(images_to_test):
      clean = generate_image(height, width, 3, 3, 3, 3) / 255

      # add_noise returns a new array, so ``clean`` stays intact for scoring.
      noisy = add_noise(clean, noise_type)
      fix_image(noisy)

      batch = torch.from_numpy(expand(noisy).astype(np.float32)).to(device)
      r, g, b = get_rgb(batch)
      pred = model(batch, r.to(device), g.to(device), b.to(device))
      fix_image(pred)
      denoised = pred[0, :, :, :].cpu().detach().numpy().transpose(1, 2, 0)

      if i == 0:
        display_image(clean)
        display_image(noisy)
        display_image(denoised)

      # Score against the clean image. The previous version reused a variable
      # that had already been overwritten with the noisy input.
      squared_error_sum += np.mean(np.square(clean - denoised))
      signal_variance_sum += np.var(clean)

  mse = squared_error_sum / images_to_test
  signal_variance = signal_variance_sum / images_to_test

  # Peak value is 1 for [0, 1] data. Using 255 here inflated PSNR by
  # 20 * log10(255) ~= 48 dB.
  psnr = _decibels(1.0, mse)
  snr = _decibels(signal_variance, mse)

  print("PSNR: ", psnr)
  print("SNR: ", snr)
  return psnr, snr

In [None]:
def test_from_disk(model, folder="assets"):
  """Denoise every readable image in ``folder`` with ``model`` and show input/output.

  Images are scaled to [0, 1] and fed at their native resolution. Each
  prediction is clipped to [0, 1] before display.

  Args:
    model: network called as ``model(image, r, g, b)``.
    folder: directory to read images from (default "assets", as before).
  """
  model = model.to(device)
  # Inference mode: eval() switches layers such as dropout/batch-norm (if the
  # model has any) to inference behaviour, and no_grad() avoids building an
  # autograd graph for every photo.
  model.eval()

  with torch.no_grad():
    for image in load_images_from_folder(folder):
      image = image / 255
      display_image(image)

      batch = torch.from_numpy(expand(image).astype(np.float32)).to(device)
      # Per-channel (1, 1, H, W) slices sized from the photo itself, which need
      # not match the global height/width.
      r, g, b = (batch[:, channel:channel + 1, :, :] for channel in range(3))

      pred = model(batch, r, g, b)

      pred_img = pred[0, :, :, :].cpu().detach().numpy().transpose(1, 2, 0)
      fix_image(pred_img)
      display_image(pred_img)

In [None]:
# Build the training set: clean synthetic images and noisy copies of them.
originals = []
noised = []
for _ in range(images_to_generate):
  # Scale to [0, 1] before adding noise, exactly as test_generated does.
  # Without this the clean targets stayed in [0, 255] while fix_image clipped
  # the noisy inputs to [0, 1], so inputs and targets lived on different
  # scales and the inputs were almost binary.
  clean = generate_image(height, width, 3, 3, 3, 3) / 255
  originals.append(clean)

  noised_image = add_noise(clean, noise_type)
  fix_image(noised_image)
  noised.append(noised_image)

# Stack to (N, H, W, 3), then move channels first -> (N, 3, H, W).
originals = np.reshape(originals, (images_to_generate, height, width, 3))
originals = np.transpose(originals, (0, 3, 1, 2))

noised = np.reshape(noised, (images_to_generate, height, width, 3))
noised = np.transpose(noised, (0, 3, 1, 2))

In [None]:
# Train the denoiser on the synthetic (noisy -> clean) pairs and save it.
model = NeuralNetwork()
# Move the model to its device *before* building the optimiser, as the
# torch.optim documentation recommends, so the optimiser is constructed over
# the parameters it will actually update.
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()

# torch.from_numpy keeps the numpy dtype, so cast to float32 (presumably the
# network's parameter dtype, torch's default).
noised = noised.astype(np.float32)
originals = originals.astype(np.float32)
train(noised, originals, model, optimizer, loss_fn, epochs)

print('done training')

# NOTE(review): this pickles the whole module object, so loading it requires
# nn.nn.NeuralNetwork to be importable. Saving model.state_dict() would be more
# portable; kept as-is to preserve the on-disk format.
torch.save(model, "denoising.pt")

In [None]:
# Visualise the learned filters: one panel per output filter of
# model.filter_layer (first input channel only), then evaluate the model.
filters = model.filter_layer.weight.cpu().detach().numpy()
N = filters.shape[0]

# squeeze=False always returns a 2-D array of Axes. With the default, a layer
# with a single filter yields a bare Axes object and ax[n] raises TypeError.
fig, ax = pyplot.subplots(1, N, squeeze=False)
for n in range(N):
  panel = ax[0, n]
  panel.set_title(f'{n}')
  panel.imshow(filters[n, 0], cmap='gray')
  panel.set_axis_off()
pyplot.show()

test_generated(model)
test_from_disk(model)