**CS 4501 - Digital Signal Processing**

Phyl Peng (hp9psb), Brian Mbogo (bpm4pkz), Anna Williamson (amw4uet)

*Image filter design for different noise distributions*

# Part 1: A neural network to classify the noise distribution

Koehrsen, W. (2018). Transfer Learning with Convolutional Neural Networks in PyTorch. Towards Data Science. https://towardsdatascience.com/transfer-learning-with-convolutional-neural-networks-in-pytorch-dd09190245ce

In [6]:
from torchvision import transforms, datasets, models
from torchvision.models.resnet import Bottleneck
import torch
from torch import optim, cuda
from torch.utils.data import DataLoader, RandomSampler
import torch.nn as nn
from torchsummary import summary

from scipy.fft import ifft2
import seaborn as sns
import numpy as np
import pandas as pd
import os
from timeit import default_timer as timer

from PIL import Image, ImageDraw
import matplotlib.pyplot as plt

# salt and pepper
from scipy import signal
from statistics import median
from numpy import asarray

In [None]:
# Location of data: train/ and valid/ image folders under the working directory
datadir = os.path.curdir
traindir, validdir = (os.path.join(datadir, split) for split in ('train', 'valid'))
testdir = validdir  # validation set = testing set for simplicity

# Output files: train() saves the best weights to save_file_name;
# checkpoint_path is not referenced in this part of the notebook.
save_file_name, checkpoint_path = (os.path.join(datadir, f'noise-model.{ext}') for ext in ('pt', 'pth'))
model_str = "inception"  # anything other than "vgg" selects inception_transforms

batch_size = 32

In [None]:
# Image transformations
# NOTE(review): get_pretrained_model() also prepends the same ImageNet Normalize
# to the network's first convolution, so channel 0 appears to be normalized
# twice (channels 1-2 are overwritten with FFT data in train()). Confirm which
# normalization step is intended before changing either one.


def _imagenet_normalize():
    """ImageNet per-channel mean/std normalization."""
    return transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])


def _eval_pipeline(resize, crop):
    """Deterministic resize + centre crop pipeline (no augmentation), used for val and test."""
    return transforms.Compose([
        transforms.Resize(size=resize),
        transforms.CenterCrop(size=crop),
        transforms.ToTensor(),
        _imagenet_normalize(),
    ])


vgg_transforms = {
    # Train uses data augmentation: random scale crop + horizontal flip, then 224 crop
    'train': transforms.Compose([
        transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.CenterCrop(size=224),  # Image net standards
        transforms.ToTensor(),
        _imagenet_normalize(),  # Imagenet standards
    ]),
    # Validation and test do not use augmentation
    'val': _eval_pipeline(256, 224),
    'test': _eval_pipeline(256, 224),
}

inception_transforms = {
    # Train uses data augmentation (horizontal flip only)
    'train': transforms.Compose([
        transforms.Resize(342),
        #transforms.RandomRotation(degrees=30),
        transforms.RandomHorizontalFlip(),
        transforms.CenterCrop(size=299),  # Inception-specific
        transforms.ToTensor(),
        _imagenet_normalize(),  # Imagenet standards
    ]),
    # Validation and test do not use augmentation
    'val': _eval_pipeline(342, 299),
    'test': _eval_pipeline(342, 299),
}

image_transforms = vgg_transforms if model_str == "vgg" else inception_transforms


In [None]:
model_seed = 12600
torch.random.manual_seed(model_seed)
train_on_gpu = cuda.is_available()

# One ImageFolder dataset per split, each with its own transform pipeline
data = {
    split: datasets.ImageFolder(root=root, transform=image_transforms[split])
    for split, root in (('train', traindir), ('val', validdir), ('test', testdir))
}

# Dataloader iterators: training draws 200 random samples per pass;
# validation and test iterate over their whole folder in shuffled order.
dataloaders = {
    'train': DataLoader(
        data['train'],
        sampler=RandomSampler(data['train'], num_samples=200),
        batch_size=batch_size,
    ),
}
for split in ('val', 'test'):
    dataloaders[split] = DataLoader(data[split], batch_size=batch_size, shuffle=True)
del split

n_classes = len(data['train'].classes)
print(f"n_classes: {n_classes} \n {data['train'].classes}")

In [None]:
# Peek at one (randomly sampled) training batch to sanity-check tensor shapes
trainiter = iter(dataloaders['train'])
features, labels = next(trainiter)
(features.shape, labels.shape)

In [None]:
# First channel of one example per noise class. The label -> class mapping
# comes from ImageFolder's class_to_idx (printed in a later cell).
sp_example = features[labels == 1][0, 0]  # salt and pepper
ga_example = features[labels == 0][0, 0]  # gaussian

for position, example in enumerate([sp_example, ga_example]):
    if position > 0:
        plt.figure()  # give each example its own figure
    plt.imshow(example)

In [None]:
def get_pretrained_model(model_name, num_classes=None):
    """Load an ImageNet-pretrained CNN and adapt it for noise classification.

    Params
    --------
        model_name (str): "vgg16" or "inceptionv3"
        num_classes (int, optional): size of the new output layer; defaults to
            the notebook-level ``n_classes``

    Returns
    --------
        model (PyTorch model): network with most layers frozen, an ImageNet
        Normalize prepended to its first convolution, and a new head ending in
        Softmax(dim=1); moved to CUDA when ``train_on_gpu`` is True

    Raises
    --------
        ValueError: if ``model_name`` is not a supported name (previously the
        function silently returned None)
    """
    if num_classes is None:
        num_classes = n_classes

    if model_name == "vgg16":
        # weights= enum, consistent with the inception branch (pretrained=True is deprecated)
        model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)

        # Freeze early layers
        for param in model.parameters():
            param.requires_grad = False

        # enable last 2 convolution layers (the final 6 modules of `features`)
        for layer in model.features[-6:]:
            for param in layer.parameters():
                param.requires_grad = True

        # prepend transform layer to normalize fft channels
        # NOTE(review): the data transforms already apply this Normalize — confirm
        # that normalizing channel 0 twice is intended.
        model.features[0] = nn.Sequential(
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            model.features[0],
        )

        # enable other fully-connected layers
        for layer in model.classifier[1:]:
            for param in layer.parameters():
                param.requires_grad = True

        # Replace the last classifier layer with a small trainable head
        n_inputs = model.classifier[6].in_features
        model.classifier[6] = nn.Sequential(
            nn.Linear(n_inputs, 256), nn.ReLU(),
            nn.Linear(256, num_classes), nn.Softmax(dim=1))

    elif model_name == "inceptionv3":
        model = models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1)

        # Freeze early layers
        for param in model.parameters():
            param.requires_grad = False

        # prepend transform layer to normalize fft channels
        # NOTE(review): the data transforms already apply this Normalize — confirm
        # that normalizing channel 0 twice is intended.
        model.Conv2d_1a_3x3 = nn.Sequential(
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            model.Conv2d_1a_3x3,
        )

        # enable last inception layer (features)
        for param in model.Mixed_7c.parameters():
            param.requires_grad = True

        # New trainable head (fc is created fresh, so it is trainable by default)
        model.fc = nn.Sequential(
            nn.Linear(model.fc.in_features, 1024), nn.ReLU(),
            nn.Linear(1024, num_classes), nn.Softmax(dim=1))

    else:
        raise ValueError(
            f"Unsupported model_name {model_name!r}; expected 'vgg16' or 'inceptionv3'")

    # Move to gpu (no data parallelism is set up here)
    if train_on_gpu:
        model = model.to('cuda')

    return model

In [None]:
model = get_pretrained_model("inceptionv3")
# torchsummary prints the table itself and returns None, so don't wrap it in
# print(). Its device defaults to "cuda", so pass the device the model is on;
# 299x299 matches the CenterCrop in inception_transforms.
summary(model, input_size=(3, 299, 299), batch_size=batch_size,
        device="cuda" if train_on_gpu else "cpu")

In [None]:
# Attach the label <-> class-name maps to the model for later inference
model.class_to_idx = data['train'].class_to_idx
model.idx_to_class = dict((idx, class_) for class_, idx in model.class_to_idx.items())

print(list(model.idx_to_class.items()))

# Draw a fresh training batch
trainiter = iter(dataloaders['train'])
features, labels = next(trainiter)

In [None]:
def _add_fft_channels(inputs):
    """Write the 2-D inverse FFT of channel 0 into channels 1 (real) and 2 (imag), in place.

    Applied to every image of the batch ``inputs`` (indexed [batch, channel, H, W],
    at least 3 channels). Returns the same tensor for convenience.
    """
    spectrum = ifft2(inputs[:, 0].detach().cpu().numpy(), axes=(-2, -1))
    inputs[:, 1] = torch.tensor(spectrum.real, dtype=inputs.dtype)
    inputs[:, 2] = torch.tensor(spectrum.imag, dtype=inputs.dtype)
    return inputs


def _run_batches(model, criterion, loader, device, model_type, optimizer=None):
    """One pass over ``loader``; backpropagates and steps when ``optimizer`` is given.

    The loss/accuracy use the predicted probability of class index 1, compared
    against the binary target. Returns (summed loss, number correct, samples seen).
    """
    total_loss, total_correct, n_seen = 0.0, 0, 0
    for inputs, target in loader:
        if optimizer is not None:
            optimizer.zero_grad()

        # use fft in the other channels of every image in the batch
        inputs = _add_fft_channels(inputs).to(device)
        target = target.to(device)

        output = model(inputs)
        # Inception returns (logits, aux_logits) in training mode, a tensor in eval mode
        if model_type == "inception" and isinstance(output, tuple):
            output = output[0]
        prob = output[:, 1]
        loss = criterion(prob, target.float())

        if optimizer is not None:
            loss.backward()
            optimizer.step()

        batch = inputs.size(0)
        # criterion returns the batch mean, so scale back to a sum
        total_loss += loss.item() * batch
        total_correct += torch.round(prob).eq(target.view_as(prob)).sum().item()
        n_seen += batch
    return total_loss, total_correct, n_seen


def _history_frame(history):
    """Per-epoch [train loss, valid loss, train acc, valid acc] rows as a DataFrame."""
    return pd.DataFrame(
        history, columns=['train loss', 'valid loss', 'train acc', 'valid acc'])


def train(model,
          criterion,
          optimizer,
          train_loader,
          valid_loader,
          save_file_name,
          model_type="vgg",
          max_epochs_stop=3,
          n_epochs=20,
          print_every=2):
    """Train a PyTorch Model

    Params
    --------
        model (PyTorch model): cnn to train
        criterion (PyTorch loss): objective to minimize (applied to the class-1 probability)
        optimizer (PyTorch optimizer): optimizer to compute gradients of model parameters
        train_loader (PyTorch dataloader): training dataloader to iterate through
        valid_loader (PyTorch dataloader): validation dataloader used for early stopping
        save_file_name (str ending in '.pt'): file path to save the model state dict
        model_type (str either vgg or inception)
        max_epochs_stop (int): maximum number of epochs with no improvement in validation loss for early stopping
        n_epochs (int): maximum number of training epochs
        print_every (int): frequency of epochs to print training stats

    Returns
    --------
        model (PyTorch model): trained cnn with best weights (lowest validation loss)
        history (DataFrame): history of train and validation loss and accuracy
    """
    # Run batches on whichever device the model's parameters live on
    params = list(model.parameters())
    device = params[0].device if params else torch.device("cpu")

    # Early stopping initialization
    epochs_no_improve = 0
    valid_loss_min = np.inf  # np.Inf was removed in NumPy 2.0
    valid_best_acc = 0.0
    best_epoch = None
    history = []

    # Number of epochs already trained (if using loaded in model weights)
    if hasattr(model, "epochs"):
        print(f'Model has been trained for: {model.epochs} epochs.\n')
    else:
        model.epochs = 0
        print(f'Starting Training from Scratch.\n')

    def restore_best():
        # Nothing was saved if validation loss never improved (e.g. NaN losses)
        if best_epoch is not None:
            model.load_state_dict(torch.load(save_file_name, map_location=device))

    for epoch in range(n_epochs):
        model.train()
        train_loss, train_correct, n_train = _run_batches(
            model, criterion, train_loader, device, model_type, optimizer)
        model.epochs += 1

        # Don't need to keep track of gradients during validation
        with torch.no_grad():
            model.eval()
            valid_loss, valid_correct, n_valid = _run_batches(
                model, criterion, valid_loader, device, model_type)

        # Average over the samples actually seen: the training sampler may draw
        # fewer samples than len(train_loader.dataset)
        train_loss /= max(n_train, 1)
        valid_loss /= max(n_valid, 1)
        train_acc = train_correct / max(n_train, 1)
        valid_acc = valid_correct / max(n_valid, 1)

        history.append([train_loss, valid_loss, train_acc, valid_acc])

        # Print training and validation results
        if (epoch + 1) % print_every == 0:
            print(
                f'\nEpoch: {epoch} \tTraining Loss: {train_loss:.4f} \tValidation Loss: {valid_loss:.4f}'
            )
            print(
                f'\t\tTraining Accuracy: {100 * train_acc:.2f}%\t Validation Accuracy: {100 * valid_acc:.2f}%'
            )

        # Save the model if validation loss decreases
        if valid_loss < valid_loss_min:
            torch.save(model.state_dict(), save_file_name)
            epochs_no_improve = 0
            valid_loss_min = valid_loss
            valid_best_acc = valid_acc
            best_epoch = epoch
        else:
            # Otherwise increment count of epochs with no improvement
            epochs_no_improve += 1
            if epochs_no_improve >= max_epochs_stop:
                print(
                    f'\nEarly Stopping! Total epochs: {epoch + 1}. Best epoch: {best_epoch} with loss: {valid_loss_min:.2f} and acc: {100 * valid_best_acc:.2f}%'
                )
                restore_best()
                model.optimizer = optimizer
                return model, _history_frame(history)

    # Finished all epochs: still return the best weights, not the last ones
    restore_best()
    model.optimizer = optimizer
    print(
        f'\nBest epoch: {best_epoch} with loss: {valid_loss_min:.2f} and acc: {100 * valid_best_acc:.2f}%'
    )
    return model, _history_frame(history)


In [None]:
# BCE on the softmax probability of class index 1 (see train()); Adam is given
# every parameter, and PyTorch optimizers skip parameters whose .grad is None.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters())

model, history = train(
    model, criterion, optimizer,
    dataloaders['train'], dataloaders['val'],
    save_file_name=save_file_name,
    model_type=model_str,
    max_epochs_stop=15,
    print_every=1,
)

# Part 2: Filters to perform denoising by noise distribution

In [7]:
# Gaussian Filter - used to remove noise and detail


'''
SOURCES:

https://www.cs.auckland.ac.nz/courss/compsci373s1c/PatricesLectures/Image%20Filtering.pdf
https://blog.en.uwa4d.com/2022/08/11/screen-post-processing-effects-chapter-1-basic-algorithm-of-gaussian-blur-and-its-implementation/
https://en.wikipedia.org/wiki/Salt-and-pepper_noise
https://www.geeksforgeeks.org/python-pil-getpixel-method/
https://stackoverflow.com/questions/52307290/what-is-the-difference-between-images-in-p-and-l-mode-in-pil#:~:text=If%20you%20have%20an%20L,stores%20a%20greyscale%2C%20not%20colour.
https://ijesc.org/upload/a2d11768dad7f56db1cc12bb3650879a.A%20Comparison%20of%20Salt%20and%20Pepper%20Noise%20Removal%20Filters.pdf
https://www.geeksforgeeks.org/python-pil-copy-method/
'''

'''
NOTES:

- "An effective noise reduction method for this type of noise is a median filter or a morphological filter."
- first attempt: median filter
- note: images are in "L"-mode... maps to black and white pixels/greyscale
- a median filter is the best of a variety of filters to handle salt and pepper noise
'''

def convolution2DPadded(px, width, height, gaussian_kernel, ks_w, ks_h):
    """Zero-pad ``px`` and convolve it with ``gaussian_kernel``.

    Args:
        px: 2-D array of grey levels (rows x cols).
        width, height: image size; unused, kept for interface compatibility.
        gaussian_kernel: 2-D kernel (nested lists or array) of ks_h rows x ks_w cols.
        ks_w, ks_h: kernel width and height, which set the column/row padding.

    Returns:
        Float array; for odd kernel sizes it has the same shape as ``px``.
    """
    kernel = np.asarray(gaussian_kernel, dtype=np.float64)
    # Pad by half the kernel on each side so the 'valid' convolution keeps the
    # input size (previously this read an undefined global `ks`).
    pad_rows = ks_h // 2
    pad_cols = ks_w // 2
    input_padded = np.pad(np.asarray(px, dtype=np.float64),
                          ((pad_rows, pad_rows), (pad_cols, pad_cols)),
                          mode='constant')

    # Convolve the padded matrix with the Gaussian kernel
    return signal.convolve2d(input_padded, kernel, mode='valid')
    
def gaussianDiscrete2D(theta, x, y):
    """Average 2-D Gaussian density (std ``theta``) over the unit pixel cell centred at (x, y).

    The cell [x-0.5, x+0.5] x [y-0.5, y+0.5] is sampled on an 11 x 11 grid with
    0.1 spacing and the 121 samples are averaged. (The previous code sampled
    around 0.1*x instead of x and relied on an un-imported `math` module.)
    """
    offsets = np.arange(-5, 6) * 0.1              # -0.5, -0.4, ..., 0.5
    xs = x + offsets
    ys = y + offsets
    r2 = np.add.outer(ys ** 2, xs ** 2)           # squared radius at every sample
    density = np.exp(-r2 / (2 * theta * theta)) / (2 * np.pi * theta * theta)
    return float(density.mean())


def gaussian2D(theta, size):
    """Build a normalised ``size`` x ``size`` Gaussian kernel as a list of lists.

    Entry [i][j] is the cell-averaged density at offset (i - size//2, j - size//2),
    and all entries sum to 1.
    """
    # Integer centre keeps odd-sized kernels symmetric; `size / 2` shifted the
    # kernel by half a pixel.
    centre = size // 2
    kernel = [[gaussianDiscrete2D(theta, i - centre, j - centre) for j in range(size)]
              for i in range(size)]

    kernel_sum = sum(sum(row) for row in kernel)
    return [[element / kernel_sum for element in row] for row in kernel]

def smooth(px, width, height, ks, theta):
    """Blur a 2-D grey-level array with a ks x ks Gaussian kernel of standard deviation ``theta``.

    Returns the array produced by convolution2DPadded (zero padding at the
    borders). The leftover debug print of ``px.shape`` has been removed.
    """
    gaussian_kernel = gaussian2D(theta, ks)
    return convolution2DPadded(px, width, height, gaussian_kernel, ks, ks)

def smooth_image(px, w, h, ks, theta):
    """Gaussian-blur a grey-level pixel matrix and return it as an 8-bit PIL image.

    Args:
        px: 2-D array of grey levels.
        w, h: image width/height; forwarded to smooth() for interface
            compatibility (the output size comes from ``px`` itself).
        ks: Gaussian kernel size.
        theta: Gaussian standard deviation.

    Returns:
        PIL.Image in mode 'L' whose values are rounded and clamped to 0..255.
    """
    output_2d = np.asarray(smooth(px, w, h, ks, theta), dtype=np.float64)
    # Round, then clamp before the uint8 cast: a bare cast truncates (254.9 -> 254)
    # and does not saturate out-of-range values.
    grey = np.clip(np.rint(output_2d), 0, 255).astype(np.uint8)
    return Image.fromarray(grey)

 
# load image: path relative to the notebook, like the salt-and-pepper images
# below. Converting to 'L' guarantees the 2-D grey-level array the Gaussian
# filter expects.
im = Image.open("guassian_bridge.jpg").convert("L")
px = np.asarray(im)

# numpy shape is (rows, cols) == (height, width)
h, w = px.shape
print("Original photo:")
display(im)
original_im = im.copy()

# filter image
print("Gaussian blurred photo:")
ks = 5
im = smooth_image(px, w, h, ks, 0.9)
display(im)

# MSE() is defined in a later cell; run that cell first before enabling this line.
#print("Mean Squared Error:", MSE(original_im, im))

FileNotFoundError: [Errno 2] No such file or directory: '/Users/philpeng/Documents/image-denoising/guassian_bridge.jpg'

## Salt-and-Pepper Noise Filters

In [8]:
# function for calculating the MSE (Mean Squared Error) between two images
# Source: https://www.statology.org/mean-squared-error-python/

def MSE(actual_im, predict_im):
    """Mean squared error between two equally-sized images or arrays.

    Inputs are converted to float64 first. PIL 'L' images become uint8 arrays,
    and uint8 subtraction wraps around (0 - 1 == 255), which inflated the error.
    """
    actual = np.asarray(actual_im, dtype=np.float64)
    predicted = np.asarray(predict_im, dtype=np.float64)
    return np.square(actual - predicted).mean()

# initialize the image to filter
# Candidate test images (relative to the working directory); pass `imb` or
# `imc` to Image.open below to filter a different one.
ima = r"saltpepper_car.jpg"
imb = r"saltpepper_hiking.jpg"
imc = r"saltpepper_seagull.jpg"

# The filters below treat each pixel as a single grey level (px[w, h] compared
# with numbers, asarray(im).shape unpacked into two values), so force 8-bit
# greyscale; this is a no-op for images that are already in 'L' mode.
imsp = Image.open(ima).convert("L")
display(imsp)

FileNotFoundError: [Errno 2] No such file or directory: 'saltpepper_car.jpg'

In [None]:
# Median Filter - used to remove salt and pepper noise


'''
SOURCES:

https://en.wikipedia.org/wiki/Salt-and-pepper_noise
https://www.cs.auckland.ac.nz/courss/compsci373s1c/PatricesLectures/Image%20Filtering.pdf
https://www.geeksforgeeks.org/python-pil-getpixel-method/
https://stackoverflow.com/questions/52307290/what-is-the-difference-between-images-in-p-and-l-mode-in-pil#:~:text=If%20you%20have%20an%20L,stores%20a%20greyscale%2C%20not%20colour.
https://ijesc.org/upload/a2d11768dad7f56db1cc12bb3650879a.A%20Comparison%20of%20Salt%20and%20Pepper%20Noise%20Removal%20Filters.pdf
https://www.geeksforgeeks.org/python-pil-copy-method/
'''

'''
NOTES:

- "An effective noise reduction method for this type of noise is a median filter or a morphological filter."
- first attempt: median filter
- note: images are in "L"-mode... maps to black and white pixels/greyscale
- a median filter is the best of a variety of filters to handle salt and pepper noise
'''


def median_pixel(width, height, w, h, px, src=None):
    """Replace px[w, h] with the median of its 3x3 neighbourhood.

    Args:
        width, height: image size; pixels in the first/last column or row are
            left unchanged.
        w, h: x (column) and y (row) of the pixel, as used by ``Image.load()`` access.
        px: pixel access object that is written to.
        src: pixel access object the neighbourhood is read from; defaults to
            ``px``. Passing an unmodified copy gives a standard (non-recursive)
            median filter.

    Returns:
        None; ``px`` is modified in place.
    """
    if src is None:
        src = px
    if 0 < w < width - 1 and 0 < h < height - 1:  # don't compute the edges
        px[w, h] = median([src[w + dw, h + dh] for dw in (-1, 0, 1) for dh in (-1, 0, 1)])


def median_filter(im):
    """Apply a 3x3 median filter to ``im`` in place (border pixels untouched).

    Neighbourhoods are read from a snapshot of the input, so pixels that were
    already filtered do not feed into later medians.
    """
    px = im.load()
    src = im.copy().load()
    width, height = im.size
    for w in range(width):
        for h in range(height):
            median_pixel(width, height, w, h, px, src)

# median_filter works in place, so filter a copy and keep `imsp` as the noisy reference
original1 = imsp.copy()

print("Original photo:")
display(imsp)  # same pixels as original1 before filtering

print("Median filtered photo:")
median_filter(original1)
display(original1)

print("Mean Squared Error:", MSE(imsp, original1))


In [None]:
# Laplacian Filter - used to sharpen edges/remove blurriness in an image

'''
SOURCES:

https://www.youtube.com/watch?v=kewse-JsjH0
https://www.geeksforgeeks.org/laplacian-filter-using-matlab/
https://www.pluralsight.com/guides/importing-image-data-into-numpy-arrays
https://www.l3harrisgeospatial.com/docs/laplacianfilters.html#:~:text=Laplacian%20filter%20kernels%20usually%20contain,be%20either%20negative%20or%20positive.
https://www.geeksforgeeks.org/image-sharpening-using-laplacian-filter-and-high-boost-filtering-in-matlab/
'''

def sharpen_edges(im, new_image, mult):
    """Combine an image with a Laplacian edge image: ``im + mult * new_image``.

    Args:
        im: greyscale image (PIL image or 2-D array).
        new_image: edge image of the same size, e.g. from ``apply_laplacian``.
        mult: weight of the edge term (this notebook uses -1 with the
            negative-centre kernel and +1 with the positive-centre one).

    Returns:
        PIL.Image in mode 'L' with values clamped to 0..255.

    NOTE(review): apply_laplacian returns |Laplacian|, so this adds/subtracts the
    edge magnitude rather than the signed Laplacian of textbook sharpening
    (g = f + c * Laplacian(f)); confirm this is intended.
    """
    base = asarray(im).astype(np.int32)
    detail = asarray(new_image).astype(np.int32)
    # int32 avoids uint8 wrap-around; clamping (instead of abs) saturates
    # out-of-range results to black/white rather than mirroring them.
    combined = base + mult * detail
    return Image.fromarray(np.clip(combined, 0, 255).astype(np.uint8))


def apply_laplacian(im, kernel):
    """Convolve a greyscale image with ``kernel`` and return the edge magnitude.

    Returns:
        PIL.Image in mode 'L' holding |im (*) kernel| clamped to 0..255, with the
        same size as the input (symmetric boundary handling).
    """
    data = asarray(im)  # load image as a 2d array
    response = signal.convolve2d(data, kernel, boundary='symm', mode='same')
    # For integer inputs convolve2d returns a wide integer array (int64 on most
    # platforms), which Image.fromarray may reject; store an 8-bit magnitude.
    magnitude = np.clip(np.abs(response), 0, 255).astype(np.uint8)
    return Image.fromarray(magnitude)

# Two 3x3 Laplacian kernels: 4-neighbour (negative centre) and 8-neighbour (positive centre)
kernel1 = [[0, 1, 0], [1, -4, 1], [0, 1, 0]]
kernel2 = [[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]

# copy images with median filter already applied (one working copy per kernel)
denoised_image1, denoised_image2 = original1.copy(), original1.copy()

# --- kernel 1: edge map, then combine with mult = -1 ---
new_image1 = apply_laplacian(denoised_image1, kernel1)
print("1 Laplacian filtered photo - outline edges (kernel 1)")
new_image1.show()  # need to use show()! Using display() won't work here
sharp_im1 = sharpen_edges(denoised_image1, new_image1, -1)
print("2 Laplacian filtered photo (kernel 1)")
sharp_im1.show()
print("Mean Squared Error (kernel 1):", MSE(imsp, sharp_im1))

# --- kernel 2: edge map, then combine with mult = +1 ---
new_image2 = apply_laplacian(denoised_image2, kernel2)
print("3 Laplacian filtered photo - outline edges (kernel 2)")
new_image2.show()  # need to use show()! Using display() won't work here
sharp_im2 = sharpen_edges(denoised_image2, new_image2, 1)
print("4 Laplacian filtered photo (kernel 2)")
sharp_im2.show()
print("Mean Squared Error (kernel 2):", MSE(imsp, sharp_im2))


In [None]:
# edited version of a Noise Adaptive Fuzzy Switching Median filter (NAFSM) - used to remove salt and pepper noise 
# while maintaining sharp edges/prevent blurring of the original photo

'''
SOURCES:

https://ieeexplore.ieee.org/document/5356178
https://stackoverflow.com/questions/39554660/np-arrays-being-immutable-assignment-destination-is-read-only
'''

'''
NOTES:

The below implementation is an edited version of the Noise Adaptive Fuzze Switching Median Filter (NAFSM) described in
the paper linked above. The major difference is the method for determining the new value of a salt and pepper pixel.
The paper provided an algorithm using various constants (and little substantiation for the values of these constants)
which led to a subpar filtering result. By switching out this custom algorithm and using the median value of the
surrounding pixels (excluding other salt and pepper pixels) to determine the value of the current salt and pepper pixel
led to a better result. 
'''

def check_G(data_im, noise_mask, w, h):
    """Gather the local NAFSM statistics for the pixel at data_im[w][h].

    Looks at the 3x3 window centred on (w, h), clipped at the array border.

    Args:
        data_im: 2-D array of grey levels; ``w`` indexes the first axis, ``h`` the second.
        noise_mask: array of the same shape; 1 marks noise-free pixels, 0 noise pixels.
        w, h: indices of the centre pixel.

    Returns:
        (sum_N, N_one_list, F) where
        sum_N: number of noise-free pixels in the window,
        N_one_list: their grey levels,
        F: fuzzy weight in [0, 1] from the largest absolute grey-level difference
           between the centre and its in-window neighbours (0 if it has none).
    """
    n_rows, n_cols = noise_mask.shape
    centre = int(data_im[w][h])
    N_one_list = []
    luminance_diff = []
    for m in (-1, 0, 1):
        for n in (-1, 0, 1):
            r, c = w + m, h + n
            if not (0 <= r < n_rows and 0 <= c < n_cols):
                continue
            if noise_mask[r][c] == 1:
                N_one_list.append(data_im[r][c])
            # all eight neighbours, excluding only the centre itself
            if m != 0 or n != 0:
                luminance_diff.append(abs(int(data_im[r][c]) - centre))

    sum_N = len(N_one_list)
    max_luminance_diff = max(luminance_diff, default=0)
    # values given by the paper
    T_1 = 10
    T_2 = 30
    if max_luminance_diff < T_1:
        F = 0  # flat neighbourhood: keep the original pixel value
    elif max_luminance_diff >= T_2:
        F = 1  # strong local variation: use the median
    else:
        F = (max_luminance_diff - T_1) / (T_2 - T_1)
    return sum_N, N_one_list, F

def NAFSM_filter(im, noise_mask, use_median):
    """Restore the pixels flagged as noise (noise_mask == 0) in a greyscale image.

    Args:
        im: greyscale image (PIL image or 2-D array).
        noise_mask: array with im's array shape; 0 marks noise pixels, 1 noise-free ones.
        use_median: if True, replace each noise pixel with the median of the
            noise-free pixels in its 3x3 window; otherwise blend the pixel with
            that median using the fuzzy weight F from ``check_G``.

    Returns:
        A new array (the input is not modified) with restored values. Noise pixels
        without any noise-free neighbour are left unchanged.
    """
    data_im = asarray(im).copy()
    n_rows, n_cols = data_im.shape  # numpy shape is (rows, cols)

    for w in range(n_rows):
        for h in range(n_cols):
            if noise_mask[w][h] != 0:
                continue
            sum_N, N_one_list, F = check_G(data_im, noise_mask, w, h)
            if sum_N < 1:
                continue
            # np.median works in float; statistics.median would add two uint8
            # values (overflow) when averaging an even-length list.
            local_median = np.median(N_one_list)  # excludes surrounding salt&pepper pixels
            if use_median:
                restored = local_median
            else:
                restored = (1 - F) * float(data_im[w][h]) + F * local_median
            # round rather than letting the integer array truncate (e.g. 99.9 -> 99)
            data_im[w][h] = np.rint(restored)
    return data_im

def detection(im, low=20, high=235):
    """Flag likely salt-and-pepper pixels in a greyscale image.

    Args:
        im: greyscale image (PIL image or 2-D array).
        low, high: grey levels strictly below ``low`` or strictly above ``high``
            are treated as noise (defaults 20 and 235, i.e. near pepper 0 / salt 255).

    Returns:
        float64 array with im's array shape:
        N(i,j)=1 represents "noise-free pixels", N(i,j)=0 represents "noise pixels".
    """
    # numpy arrays are indexed (rows, cols) == (height, width), which is why the
    # axes look swapped compared with PIL's im.size == (width, height).
    data_im = asarray(im)
    is_noise = (data_im < low) | (data_im > high)
    return np.where(is_noise, 0.0, 1.0)

# Detect salt-and-pepper pixels on a fresh copy of the noisy image, then restore them
original2 = imsp.copy()
noise_mask = detection(original2)
data_result = NAFSM_filter(original2, noise_mask, use_median=True)
final_result = Image.fromarray(data_result)

print("Noise adaptive fuzzy switching median filter - adapted to use median filter")
final_result.show()
print("Mean Squared Error:", MSE(imsp, final_result))

In [None]:
# implementation of novel image-denoising technique Elastic Median Filter 1 and 2 (EMF1) (EMF2)

'''
SOURCES:

https://www.researchgate.net/publication/299474315_Two_new_methods_for_removing_salt-and-pepper_noise_from_digital_images
'''

def _window_3x3(w, h, px):
    """Grey levels of the 3x3 window around (w, h), in a fixed neighbour order."""
    offsets = ((0, 1), (0, -1), (1, 0), (-1, 0), (-1, -1), (1, -1), (1, 1), (-1, 1), (0, 0))
    return [px[w + dw, h + dh] for dw, dh in offsets]


def _is_interior(width, height, w, h):
    """True unless (w, h) lies in the first/last column or row."""
    return w not in (0, width - 1) and h not in (0, height - 1)


def med_pixel(width, height, w, h, px):
    """Median of the 3x3 window around (w, h); -1 for border pixels."""
    if not _is_interior(width, height, w, h):
        return -1
    return np.median(_window_3x3(w, h, px))


def diff_pixels(width, height, w, h, px):
    """Return (sum of |median - p| over the 3x3 window, median); (-1, -1) for border pixels."""
    if not _is_interior(width, height, w, h):
        return -1, -1
    window = _window_3x3(w, h, px)
    window_median = np.median(window)
    total = 0
    for value in window:
        total += np.abs(window_median - value)
    return total, window_median


def EMF1(im, beta):
    """Elastic Median Filter 1, applied to ``im`` in place.

    An interior pixel is replaced by int(median) of its 3x3 window when it lies
    at or beyond median +/- (9**beta + sqrt(sum of |median - p|)). Windows are
    read from the image while it is being updated.
    """
    pixels = im.load()
    width, height = im.size
    alpha = (3*3)**beta

    for w, h in ((x, y) for x in range(width) for y in range(height)):
        sum_diff, med_val = diff_pixels(width, height, w, h, pixels)
        if sum_diff == -1 or med_val == -1:
            continue  # border pixel
        upper = med_val + alpha + np.sqrt(sum_diff)
        lower = med_val - alpha - np.sqrt(sum_diff)
        value = pixels[w, h]
        if value >= upper or value <= lower:
            pixels[w, h] = int(med_val)


def EMF2(im, beta):
    """Elastic Median Filter 2, applied to ``im`` in place.

    An interior pixel is replaced by int(median) of its 3x3 window when it lies
    at or beyond median +/- 9**beta. Windows are read from the image while it is
    being updated.
    """
    pixels = im.load()
    width, height = im.size
    alpha = (3*3)**beta

    for w, h in ((x, y) for x in range(width) for y in range(height)):
        med_val = med_pixel(width, height, w, h, pixels)
        if med_val == -1:
            continue  # border pixel
        value = pixels[w, h]
        if not (med_val - alpha < value < med_val + alpha):
            pixels[w, h] = int(med_val)


# load image
# load image: one untouched reference for the MSE and one working copy per
# filter (EMF1/EMF2 modify their image in place)
original3 = imsp.copy()
comp_im = original3.copy()
original_im = original3.copy()
print("Original photo:")
display(original3)

# smaller beta value - less salt and pepper pixels, but blurrier?
beta = 1.0

for emf_label, emf_filter, emf_image in (("EMF1", EMF1, original3), ("EMF2", EMF2, original_im)):
    print(f"{emf_label}:")
    emf_filter(emf_image, beta)
    display(emf_image)
    print(f"Mean Squared Error ({emf_label}):", MSE(comp_im, emf_image))