## Mayo Clinic - STRIP AI -  Understanding image processing 

Use `PIL` (the Pillow package) and `torchvision` to load and process images.

- Get image metadata
    - get the file size and timestamps via `pathlib` (`Path.stat()`)
    - get image metadata via the `PIL` package
        - image width, height, mode, and so on
- Resize images with the `PIL` package
    - use the `PIL` thumbnail method to resize images while keeping the original height/width ratio
    - note that when converting a `PIL` object to numpy, the data is in `[0, 255]`, not `[0, 1]`
- Crop and pad images with `torchvision` transforms
    - use `torchvision` to crop and pad images
    - **crop** image: 
        - when the original size is 512*480 and the image is center-cropped to 512, the new image will be 512*512, and the additional area is filled with 0 (shown as black)
        - when the original size is 512*480 and the image is center-cropped to 480, the new image will be 480*480
    - **pad** image: 
        - when the original size is 512*480 and the image is padded by 10, the new image will be 532*500 (10 pixels are added on every side), and the additional area is filled with 0 (shown as black)
- Add Gaussian blur to images with `torchvision` transforms
- Normalize images
    - note that before normalizing, the pixel values must be scaled from `[0, 255]` to `[0, 1]`



In [None]:
import pandas as pd
import numpy as np
import os
from pathlib import Path

from datetime import datetime, timedelta
import time

import gc
import copy

import pyarrow.parquet as pq
import pyarrow as pa

 
from dateutil.relativedelta import relativedelta
from sklearn.preprocessing import StandardScaler, MinMaxScaler

from sklearn.metrics import mean_squared_error, roc_auc_score
from sklearn.model_selection import StratifiedKFold, KFold

pd.options.display.max_rows = 100
pd.options.display.max_columns = 100

import warnings
warnings.filterwarnings("ignore")

import pytorch_lightning as pl
random_seed=1234
pl.seed_everything(random_seed)



import torch
from torch import nn
import numpy as np


import torch
from torch.utils.data import (Dataset, DataLoader)


#basic libs

import pandas as pd
import numpy as np
import os
from pathlib import Path

from datetime import datetime, timedelta
import time
from dateutil.relativedelta import relativedelta

import gc
import copy

#additional data processing

import pyarrow.parquet as pq
import pyarrow as pa

from sklearn.preprocessing import StandardScaler, MinMaxScaler


#visualization
import seaborn as sns
import matplotlib.pyplot as plt

#load images
import matplotlib.image as mpimg
import PIL
from PIL import Image




#settings
pd.options.display.max_rows = 100
pd.options.display.max_columns = 100

Image.MAX_IMAGE_PIXELS = None

import warnings
warnings.filterwarnings("ignore")

import pytorch_lightning as pl
random_seed=1234
pl.seed_everything(random_seed)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt

In [None]:
img_folder = 'images'

In [None]:
img_path = f'{img_folder}/777311_0.png' 
# img_path = f'{img_folder}/006388_0.png'

### Get image metadata

In [None]:
#check the file info
Path(img_path).stat()

In [None]:
# Get image metadata using Pillow.
# https://pillow.readthedocs.io/en/stable/reference/Image.html#image-attributes
# The context manager guarantees the file handle is closed even if reading an
# attribute raises.
with Image.open(img_path) as img:
    meta_dict = {
        'filename': img.filename,
        'format': img.format,
        'mode': img.mode,
        'size': img.size,  # 2-tuple (width, height)
        'width': img.width,
        'height': img.height,
        'palette': img.palette,
        'info': img.info,
        # `is_animated` / `n_frames` are only defined by plugins that support
        # multi-frame images; fall back to single-frame defaults otherwise.
        'is_animated': getattr(img, 'is_animated', False),
        'n_frames': getattr(img, 'n_frames', 1),
    }

del img
gc.collect()

meta_dict

### Load and resize images

In [None]:
%%time
img = Image.open(img_path)
print(img.size)
# img = np.asarray(img)

In [None]:
#display the original image
plt.figure(figsize=(8, 8))
plt.imshow(img)
plt.show()

In [None]:
#https://stackoverflow.com/questions/71738218/module-pil-has-not-attribute-resampling
#dealing with pillow version differences
print(PIL.__version__)

In [None]:
# Create a thumbnail of the image in place (keeping the height/width ratio) and
# rotate portrait images so the longer side is horizontal.
# https://stackoverflow.com/questions/71738218/module-pil-has-not-attribute-resampling
# Pillow >= 9.1 exposes the `Image.Resampling` / `Image.Transpose` enums; older
# versions only provide the module-level constants. Look them up without
# assigning anything onto the PIL module itself.
resampling = getattr(Image, 'Resampling', Image)
transposition = getattr(Image, 'Transpose', Image)

img.thumbnail((1024, 1024), resample=resampling.LANCZOS, reducing_gap=10)
if img.height > img.width:
    img = img.transpose(transposition.ROTATE_90)

plt.figure(figsize=(8, 8))
plt.imshow(img)
plt.show()

In [None]:
np.asarray(img, np.uint8).min(), np.asarray(img, np.uint8).max()

### Crop and pad images with torchvision transforms



In [None]:
#https://stackoverflow.com/questions/10965417/how-to-convert-a-numpy-array-to-pil-image-applying-matplotlib-colormap

#use torchvision to center crop the image
img2 = transforms.functional.center_crop(img, 1024)
print(img2.size)
plt.figure(figsize=(8, 8))
plt.imshow(img2)
plt.show()

In [None]:
np.asarray(img2, np.uint8).min(), np.asarray(img2, np.uint8).max()

In [None]:
img3 = transforms.functional.pad(img, 10)
print(img3.size)
plt.figure(figsize=(8, 8))
plt.imshow(img3)
plt.show()

In [None]:
np.asarray(img3, np.uint8).min(), np.asarray(img3, np.uint8).max()

### Add Gaussian Blur to images

In [None]:
img4 = transforms.functional.gaussian_blur(img, kernel_size=(5, 9), sigma=(0.1, 5))
print(img4.size)
plt.figure(figsize=(8, 8))
plt.imshow(img4)
plt.show()

In [None]:
np.asarray(img4, np.uint8).min(), np.asarray(img4, np.uint8).max()

In [None]:
img4 = transforms.functional.gaussian_blur(img2, kernel_size=(5, 9), sigma=(0.1, 5))
print(img4.size)
plt.figure(figsize=(8, 8))
plt.imshow(img4)
plt.show()

In [None]:
np.asarray(img4, np.uint8).min(), np.asarray(img4, np.uint8).max()

### Normalize images

- to apply the `torchvision` transforms normalize function:
    - first convert the `PIL` image object into a numpy array (the data range is `[0, 255]`)
    - then transpose the numpy array from height*width*channels (for RGB images, the number of channels is 3) to channels*height*width
    - scale the data range from `[0, 255]` to `[0, 1]`
    - normalize the data using `torchvision` *transforms.functional.normalize*
    - transpose the array back to height*width*channels
    

In [None]:
# Normalize the thumbnail with the ImageNet per-channel mean/std.
pixels = np.asarray(img)
print(pixels.shape)
print(pixels.min(), pixels.max())
chw = pixels.transpose((2, 0, 1))  # HWC -> CHW, the layout `normalize` expects
print(chw.shape)
# The mean/std below are defined for [0, 1] data, so rescale from [0, 255] first.
scaled = chw / 255
print(scaled.min(), scaled.max())
normalized = transforms.functional.normalize(torch.Tensor(scaled), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
img5 = normalized.numpy().transpose((1, 2, 0))  # back to HWC for matplotlib
print(img5.min(), img5.max())

In [None]:

plt.figure(figsize=(8, 8))
plt.imshow(img5)
# plt.imshow(Image.fromarray(np.uint8(img5)*255))
plt.show()

In [None]:
plt.figure(figsize=(8, 8))
plt.imshow(np.clip(img5, 0, 1))
# plt.imshow(Image.fromarray(np.uint8(img5)*255))
plt.show()

In [None]:
# Min-max rescale the normalized image to 8-bit [0, 255] for display. The
# normalized values can be negative, so shift by the minimum before scaling.
img5_min, img5_max = img5.min(), img5.max()
img5_range = img5_max - img5_min
if img5_range > 0:
    img5_1 = (img5 - img5_min) / img5_range
else:  # constant image: avoid dividing by zero
    img5_1 = np.zeros_like(img5)
img5_1 = np.uint8(np.round(img5_1 * 255))

plt.figure(figsize=(8, 8))
plt.imshow(img5_1)
plt.show()

In [None]:
# Cast to uint8 only after scaling: `np.uint8(img5) * 255` truncates first (negative
# floats have no defined uint8 result) and then overflows when multiplied by 255.
img5_uint8 = np.uint8(np.clip(img5, 0, 1) * 255)
img5_uint8.min(), img5_uint8.max()

In [None]:


plt.figure(figsize=(8, 8))
# Clip the normalized values to [0, 1] and scale *before* casting to uint8;
# `np.uint8(img5) * 255` truncates to integers first and then overflows.
plt.imshow(Image.fromarray(np.uint8(np.clip(img5, 0, 1) * 255)))
plt.show()

In [None]:
img5 = np.asarray(img2)
print(img5.shape)

img5 = img5.transpose((2,0,1))
img5 = img5/255
img5 = transforms.functional.normalize(torch.Tensor(img5), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
img5 = img5.numpy().transpose((1,2,0))
plt.figure(figsize=(8, 8))
plt.imshow(img5)
plt.show()

In [None]:
print(img5.min(), img5.max())

In [None]:
#https://www.kaggle.com/code/jirkaborovec/bloodclots-eda-load-wsi-prune-background?scriptVersionId=101797769

def prune_image_rows_cols(im, mask, thr=0.990):
    """Drop the columns and rows of ``im`` that are mostly masked.

    A column (row) is removed when the fraction of truthy entries in the
    corresponding column (row) of ``mask`` is strictly greater than ``thr``.
    Columns are pruned first, then rows, both judged on the full ``mask``.

    Args:
        im: Image array whose first two axes are (rows, columns); any trailing
            axes (e.g. channels) are kept unchanged.
        mask: Array of shape ``im.shape[:2]`` marking pixels to prune.
        thr: Masked fraction above which a row/column is dropped.

    Returns:
        A new array containing only the kept rows and columns.
    """
    mask = np.asarray(mask)
    # One boolean-index pass per axis instead of calling `np.delete` (a full
    # copy of the image) once per removed column/row. `~(x > thr)` rather than
    # `x <= thr` keeps the original behaviour for NaN fractions (empty axes).
    keep_cols = ~(mask.mean(axis=0) > thr)
    keep_rows = ~(mask.mean(axis=1) > thr)
    return im[:, keep_cols][keep_rows]


def mask_median(im, val=255, offset=5):
    """Mask pixels that are at least the per-channel median and paint them.

    A pixel is masked when each of its first three channels is
    ``>= median(channel) - offset``.

    Args:
        im: Array with a trailing channel axis of length >= 3 (e.g. H x W x 3).
            It is modified in place.
        val: Value written to every channel of masked pixels.
        offset: Margin below each channel's median that still counts as masked.

    Returns:
        ``(im, mask)``: the modified image and the boolean mask over the
        pixel grid (shape ``im.shape[:-1]``).
    """
    masks = [im[..., c] >= np.median(im[..., c]) - offset for c in range(3)]
    # `np.logical_and(*masks)` would treat the third mask as the ufunc's `out`
    # argument and silently ignore that channel; reduce over all three instead.
    mask = np.logical_and.reduce(masks)
    im[mask, :] = val
    return im, mask


In [None]:
# Full preprocessing for one image:
#   1. mask pixels at/above the per-channel median and prune mostly-masked
#      rows/columns;
#   2. rotate to landscape and resize to 512 px wide;
#   3. center-crop/pad to 512x512 and blur;
#   4. normalize with the ImageNet mean/std.
# Pillow >= 9.1 exposes the `Resampling` / `Transpose` enums; older versions only
# have the module-level constants.
resampling = getattr(Image, 'Resampling', Image)
transposition = getattr(Image, 'Transpose', Image)

with Image.open(f'{img_folder}/777311_0.png') as src:
    print(src.size)
    img, mask = mask_median(np.array(src))
img = prune_image_rows_cols(img, mask)
img = Image.fromarray(np.uint8(img))
print(img.size)
if img.height > img.width:
    img = img.transpose(transposition.ROTATE_90)
ratio = img.height / img.width
# Keep at least one row so that very wide images do not produce a 0-px height.
img = img.resize((512, max(1, int(512 * ratio))), resample=resampling.LANCZOS, reducing_gap=10)
print(img.size)
# center_crop pads with 0 (black) when the image is smaller than 512 on an edge.
img = transforms.functional.center_crop(img, 512)
img = transforms.functional.gaussian_blur(img, kernel_size=(5, 9), sigma=(0.1, 5))

img = np.asarray(img)
print(img.shape)

img = img.transpose((2, 0, 1))  # HWC -> CHW for `normalize`
img = img / 255  # the ImageNet mean/std below assume [0, 1] inputs
print(img.shape)
img = transforms.functional.normalize(torch.FloatTensor(img), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
img = img.numpy().transpose((1, 2, 0))  # CHW -> HWC for matplotlib
plt.figure(figsize=(8, 8))
plt.imshow(img)
plt.show()

In [None]:
# NOTE(review): if the cells above ran in order, `img` is the normalized float
# array from the preprocessing cell (not 0-255 pixel values), so the median
# threshold and the fill value 255 in `mask_median` act on normalized values.
# Confirm this cell is meant to run on the normalized output, not the raw image.
img6, mask6 = mask_median(np.array(img))
img6 = prune_image_rows_cols(img6, mask6)

plt.figure(figsize=(8, 8))
plt.imshow(img6)
plt.show()

In [None]:
img6.shape

In [None]:
plt.figure(figsize=(8, 8))
plt.imshow(img6/255)
plt.show()

In [None]:
img.shape

In [None]:
# NOTE(review): if the preprocessing cell ran last, `img` is already
# height x width x channels (it was transposed with (1, 2, 0) from CHW), so this
# second (1, 2, 0) transpose yields width x channels x height. Verify the
# intended axis order before using the result.
img = img.transpose((1,2,0))/255

In [None]:
type(img2)

In [None]:
img = np.zeros((500, 500, 3))

In [None]:
# `from a.b.c` without an `import` clause is a SyntaxError; import the function.
from torchvision.transforms.functional import center_crop

import torchvision
from torchvision import datasets, models, transforms

In [None]:
# Adapted from the torchvision transforms gallery, whose `T`, `orig_img` and
# `plot` are not defined in this notebook; bind local equivalents instead.
T = transforms
orig_img = img2  # PIL image produced by the center-crop cell above
# CenterCrop takes (height, width) while PIL's `.size` is (width, height).
center_crops = [T.CenterCrop(size=size)(orig_img) for size in (30, 50, 100, orig_img.size[::-1])]

fig, axes = plt.subplots(1, len(center_crops), figsize=(16, 4))
for ax, crop in zip(axes, center_crops):
    ax.imshow(crop)
    ax.set_title(str(crop.size))
    ax.axis('off')
plt.show()

In [None]:
# Training: random resized crop + horizontal flip for augmentation, then
# normalization. Validation: deterministic resize + center crop, then the same
# normalization. Mean/std are the ImageNet channel statistics.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])
data_transforms = {'train': train_transform, 'val': val_transform}

In [None]:
img.shape, img.transpose((2,0,1)).shape

In [None]:
######################################################################
# Visualize a few images
# ^^^^^^^^^^^^^^^^^^^^^^
# Let's visualize a few training images so as to understand the data
# augmentations.

def imshow(inp, title=None, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display a normalized channels-first image tensor with matplotlib.

    The tensor is detached, moved to the CPU, transposed to HWC,
    un-normalized with ``std * x + mean`` and clipped to [0, 1] before plotting.

    Args:
        inp: Tensor of shape (channels, height, width), e.g. the output of
            ``torchvision.utils.make_grid``.
        title: Optional plot title.
        mean: Per-channel mean that was used to normalize ``inp``.
        std: Per-channel standard deviation that was used to normalize ``inp``.
    """
    # `.numpy()` raises on CUDA tensors and tensors that require grad.
    inp = inp.detach().cpu().numpy().transpose((1, 2, 0))
    inp = np.asarray(std) * inp + np.asarray(mean)
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# Get a batch of training data
# NOTE(review): `dataloaders` and `class_names` are not defined in the visible
# notebook (this cell appears to be copied from the PyTorch transfer-learning
# tutorial); they must be created, e.g. using `data_transforms` above, first.
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])
