 # Notebook 06: Convolutions



 ### Primary Goal:



 Explore convolutions and how they work



 #### Background



 Convolutional Neural Networks (CNNs) are powerful machine learning models that can learn both large- and small-scale patterns from multi-dimensional data. Before we can train a CNN, we need to understand what convolutions are and how they work. Convolutions can be a bit confusing if you are unfamiliar with them, so let's take some time here to explore this important operation.



 #### Step 1: Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd 
import numpy as np 
import tqdm 
import matplotlib
import matplotlib.pyplot as plt 
import matplotlib.patheffects as path_effects
import xarray as xr


#outlines for text 
pe1 = [path_effects.withStroke(linewidth=1.5,
                             foreground="k")]
pe2 = [path_effects.withStroke(linewidth=1.5,
                             foreground="w")]


#plot parameters that I personally like, feel free to make these your own.
matplotlib.rcParams['axes.facecolor'] = [0.9,0.9,0.9]
matplotlib.rcParams['axes.labelsize'] = 14
matplotlib.rcParams['axes.titlesize'] = 14
matplotlib.rcParams['xtick.labelsize'] = 12
matplotlib.rcParams['ytick.labelsize'] = 12
matplotlib.rcParams['legend.fontsize'] = 12
matplotlib.rcParams['legend.facecolor'] = 'w'
matplotlib.rcParams['savefig.transparent'] = False
%config InlineBackend.figure_format = 'retina'

#one quick thing here, we need to set the random seed so we all get the same results no matter the computer or python session 
torch.manual_seed(43)


 #### Step 2: Load in some data



 For this notebook we will use the same hook echo example shown in the paper (see Figures 3 and 4). The data for this example comes from [Lagerquist et al. (2020)](https://journals.ametsoc.org/view/journals/mwre/148/7/mwrD190372.xml). For convenience, we have already isolated the one storm that has a prominent hook echo in radar reflectivity.



In [None]:
ds_sample = xr.open_dataset('../datasets/lagerquist_2020/lagerquist_storm_example.nc',engine='netcdf4')

#print out the dataset 
ds_sample


 #### Step 3: Reminder about what images are



 Remember that images are just matrices of values, where each value determines the color shown at that pixel. To drive this point home, let's zoom into the radar reflectivity image of our sample storm.



In [None]:
#import some helper functions from our scripts directory
import sys
sys.path.insert(1, '../scripts/')

#load the show_vals helper, used below to label each pixel with its value
from aux_functions import show_vals


#make a big figure so we can see the pixels
plt.figure(figsize=(15,15))
plt.imshow(ds_sample.radar_image_matrix[:,:,0,0],vmin=0,vmax=60,cmap='Spectral_r')
plt.gca().axis('off')

#draw box to see where the next plot will zoom into
x_vertices = np.array([4.,16.,16.,4.,4.,])
y_vertices = np.array([10.,10.,22.,22.,10.])
plt.plot(x_vertices-0.5,y_vertices-0.5,'-w',lw=2)

#add a manual annotation
plt.text(2,6,'Zoom in box',color='w',fontsize=32)


#make a big figure so we can see the pixels
plt.figure(figsize=(15,15))
da = ds_sample.radar_image_matrix.isel(grid_row=slice(10,22),grid_column=slice(4,16),radar_height=0,radar_field=0)
plt.imshow(da,vmin=0,vmax=60,cmap='Spectral_r')
show_vals(da,plt.gca())
plt.gca().axis('off')



 Hopefully this helps illustrate that images are just matrices of values.





 Next, we will take a step-by-step look at the convolution process to demonstrate what a convolution does and how it is used in CNNs.



 #### Step 4: What are Convolutions?



 I often find animations helpful for describing an idea. Below is an animated version of Figure 3 from the paper. Don't get bogged down in the math right now; just notice how the kernel (filter) steps systematically through the image. This example uses the same 12x12 image (the zoom-in) from above.



 <center><img src="../images/convolution_animation_01.gif" alt="drawing" width="600"/></center>



 Hopefully it is more apparent now what a ```convolution``` is. At this point we have not used any machine learning; this is simply an image processing technique (i.e., a mathematical operator applied to the image). For those more mathematically inclined, the convolution is:



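 $$ p'_{x,y} = \sum_{j=y-k}^{y+k} \sum_{i=x-k}^{x+k} w_{i-x,\,j-y}\, p_{i,j}, $$



 where $p_{i,j}$ is the pixel value at position $(i,j)$ in the input image matrix, $p'_{x,y}$ is the output value at position $(x,y)$, $w_{m,n}$ (with $m,n \in \{-k,\dots,k\}$) are the kernel weights, and $k$ is the half-width of the kernel, so the kernel is $(2k+1)\times(2k+1)$ pixels ($k=1$ for a 3x3 kernel). This sum is evaluated at every pixel position in the image. Strictly speaking, this is a cross-correlation, since the kernel is not flipped; CNN libraries, including PyTorch, compute it this way, and for kernels that look the same when rotated by 180 degrees, such as the ones below, the two operations are identical.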
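 To connect the equation to code, here is a minimal, loop-based sketch of it in NumPy. `naive_conv2d` is a hypothetical helper written only for illustration (it is not part of this notebook's `aux_functions`), and it assumes a square, odd-sized kernel and no padding. It is far slower than the PyTorch layer used in Step 5, but each line maps onto a term in the equation.

```python
import numpy as np


def naive_conv2d(image, kernel):
    """Slide a square (2k+1)x(2k+1) kernel over a 2D image with no padding.

    Each output value is the sum of the kernel weights times the pixels they
    cover, as in the equation above (the kernel is not flipped).
    """
    image = np.asarray(image, dtype=float)
    kernel = np.asarray(kernel, dtype=float)
    if kernel.shape[0] != kernel.shape[1] or kernel.shape[0] % 2 == 0:
        raise ValueError("kernel must be square with an odd size")
    k = kernel.shape[0] // 2  # half-width: k = 1 for a 3x3 kernel
    n_rows, n_cols = image.shape
    out = np.zeros((n_rows - 2 * k, n_cols - 2 * k))
    # Only visit pixels whose full (2k+1)x(2k+1) neighborhood lies inside the image
    for y in range(k, n_rows - k):
        for x in range(k, n_cols - k):
            patch = image[y - k:y + k + 1, x - k:x + k + 1]
            out[y - k, x - k] = np.sum(kernel * patch)
    return out


# Usage: a small 5x5 "image" whose values are 0..24
image = np.arange(25).reshape(5, 5)

identity = np.array([[0, 0, 0],
                     [0, 1, 0],
                     [0, 0, 0]])
blur = np.ones((3, 3)) / 9.0

print(naive_conv2d(image, identity))  # the interior 3x3 of `image`, unchanged
# The values increase linearly across the image, so each 3x3 mean equals its center value
print(naive_conv2d(image, blur))
```

 Note that the 5x5 input produces a 3x3 output; Step 7 below explains why.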



 Normally, the convolution kernel (the middle picture in the animation) has more sophisticated weights than the ones shown. You might have noticed that the right image is the same as the left image, just with a white border around it. That is because the animation uses the identity kernel for the convolution weights:



 $$\begin{bmatrix} 0 & 0 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 0 \end{bmatrix}$$



 Thus, the kernel is simply passing the middle pixel through with no change. Other common kernels in image processing are:



 ##### **Sharpen**

 $$\begin{bmatrix} 0 & -1 & 0 \\ -1 & 5 & -1 \\ 0 & -1 & 0 \end{bmatrix}$$



 Here is this kernel being applied to our zoomed-in patch:



 <center><img src="../images/sharpen_image.png" alt="drawing" width="600"/></center>



 and on the entire image:



 <center><img src="../images/sharpen_image_full.png" alt="drawing" width="600"/></center>
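 Why does this kernel sharpen? Its weights sum to 1, so a perfectly flat region passes through unchanged, while any pixel that differs from its four neighbors has that difference amplified. The sketch below checks this with `torch.nn.functional.conv2d` on a made-up 5x5 patch; the 20 and 30 dBZ values are purely illustrative.

```python
import torch
import torch.nn.functional as F

sharpen = torch.tensor([[0., -1., 0.],
                        [-1., 5., -1.],
                        [0., -1., 0.]])
print(sharpen.sum().item())  # 1.0, so flat regions are left unchanged

# Made-up 5x5 patch: a flat 20 dBZ field with one brighter 30 dBZ pixel in the center
patch = torch.full((5, 5), 20.0)
patch[2, 2] = 30.0

# F.conv2d wants inputs as (N, C, H, W) and weights as (out_C, in_C, kH, kW)
out = F.conv2d(patch.view(1, 1, 5, 5), sharpen.view(1, 1, 3, 3)).squeeze()
print(out)
# The 3x3 output (no padding) is:
#   20 10 20
#   10 70 10
#   20 10 20
# The bright pixel jumps from 30 to 70 (5*30 - 4*20) and its neighbors dip to 10,
# so a 10 dBZ contrast becomes a 60 dBZ contrast.
```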





 ##### **Blur** (i.e., mean)



 $$ \frac{1}{9} \begin{bmatrix} 1 & 1 & 1 \\ 1 & 1 & 1 \\ 1 & 1 & 1 \end{bmatrix}$$



 Likewise, here is the blur kernel applied to the zoomed-in patch:



 <center><img src="../images/blur_image.png" alt="drawing" width="600"/></center>



 and on the whole image:



 <center><img src="../images/blur_image_full.png" alt="drawing" width="600"/></center>
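 The blur kernel simply replaces each pixel with the mean of its 3x3 neighborhood. This sketch shows the effect on a made-up, isolated spike (the 90 dBZ value is illustrative) and checks the "local mean" interpretation on a random patch.

```python
import torch
import torch.nn.functional as F

blur = torch.ones(3, 3) / 9.0

# Made-up 7x7 patch: zero everywhere except one 90 dBZ spike in the center
spike = torch.zeros(7, 7)
spike[3, 3] = 90.0

out = F.conv2d(spike.view(1, 1, 7, 7), blur.view(1, 1, 3, 3)).squeeze()
# The 5x5 output holds about 10 (= 90 / 9) in its central 3x3 block and 0 around it:
# the spike is spread over a 3x3 area and its peak is cut to one ninth.
print(out)

# Each output pixel is the mean of the 3x3 input window it is computed from
random_patch = torch.rand(7, 7)
blurred = F.conv2d(random_patch.view(1, 1, 7, 7), blur.view(1, 1, 3, 3)).squeeze()
print(torch.isclose(blurred[0, 0], random_patch[0:3, 0:3].mean()).item())  # True
```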





 For more examples, I recommend the Wikipedia article on [image-processing kernels](https://en.wikipedia.org/wiki/Kernel_(image_processing)).





 Again, we have not actually applied any machine learning so far. The discussion has solely been focused on the convolution operation.







 #### Step 5: Implement a convolution with code



 It is possible to code a convolution manually, but for simplicity and speed, PyTorch provides a convolutional layer that handles this for you:



 ```torch.nn.Conv2d()```



 We will need to configure the layer to use the desired set of weights, but this is relatively simple to do.



 **Important Note on Data Shapes:** PyTorch expects images in the format `(Batch, Channels, Height, Width)`, often abbreviated as NCHW. However, many image libraries (and Matplotlib) use `(Height, Width, Channels)`. We will need to reshape our data accordingly.
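 Here is a small sketch of that shape bookkeeping on made-up tensors: adding the batch and channel dimensions to a 2D image, moving a channels-last `(H, W, C)` image to channels-first with `permute`, and the `(out_channels, in_channels, kernel_H, kernel_W)` layout of `Conv2d` weights. Note that `permute` reorders dimensions, whereas a plain `reshape` from `(H, W, C)` to `(C, H, W)` would scramble the pixels.

```python
import torch
import torch.nn as nn

# A made-up 12x12 single-channel image stored as (Height, Width)
image_hw = torch.rand(12, 12)

# Add the batch and channel dimensions PyTorch expects: (H, W) -> (N, C, H, W)
image_nchw = image_hw.unsqueeze(0).unsqueeze(0)
print(image_nchw.shape)  # torch.Size([1, 1, 12, 12])

# A made-up 3-channel image stored channels-last, as Matplotlib and PIL do: (H, W, C)
rgb_hwc = torch.rand(12, 12, 3)
# Reorder (H, W, C) -> (C, H, W), then add the batch dimension
rgb_nchw = rgb_hwc.permute(2, 0, 1).unsqueeze(0)
print(rgb_nchw.shape)  # torch.Size([1, 3, 12, 12])

# Conv2d weights are stored as (out_channels, in_channels, kernel_H, kernel_W)
conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, bias=False)
print(conv.weight.shape)     # torch.Size([4, 3, 3, 3])
print(conv(rgb_nchw).shape)  # torch.Size([1, 4, 10, 10])
```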

In [None]:
# Select the lowest radar height and the first radar field. Indexing with integers
# drops those dimensions, leaving a 2D (Height, Width) array.
input_data = ds_sample.radar_image_matrix.isel(radar_height=0,radar_field=0).values

# Convert to a torch tensor and add batch/channel dimensions to match PyTorch expectations:
# (H, W) -> (1, 1, H, W), i.e., (Batch, Channel, Height, Width)
input_tensor = torch.from_numpy(input_data).unsqueeze(0).unsqueeze(0).float()

# Resize to 36x36 using bilinear interpolation ('bilinear' requires a 4D tensor)
more_points = F.interpolate(input_tensor, size=(36, 36), mode='bilinear', align_corners=False)

#define the sharpen filter
kernel_weights = np.array([[ 0, -1, 0],[ -1,5,-1],[0,-1,0]])

# Define conv with specific weights. 
# PyTorch Conv2d weights shape is (Out_Channels, In_Channels, Kernel_H, Kernel_W)
conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, bias=False)

# Set the weights manually
with torch.no_grad():
    # We reshape our 3x3 kernel to (1, 1, 3, 3)
    conv.weight.data = torch.from_numpy(kernel_weights).float().view(1, 1, 3, 3)

# Run the data through 
res = conv(more_points)


# Plotting
fig,(ax1,ax2) = plt.subplots(1,2,figsize=(10,5))

# We need to squeeze the dimensions back out to (36, 36) for matplotlib
ax1.imshow(more_points.squeeze(), vmin=0, vmax=60, cmap='Spectral_r')
ax1.set_title('Original')
ax1.axis('off')

# Show result. 
ax2.imshow(res.detach().squeeze(), vmin=0, vmax=60, cmap='Spectral_r')
ax2.set_title('Convolved')
ax2.axis('off')


 As you hopefully can see, PyTorch's built-in convolution layer is a convenient and easy way to perform a convolution.
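 As a sanity check, this sketch compares `nn.Conv2d` against a hand-written loop over the equation from Step 4. The image and kernel are made up, and the kernel is deliberately non-symmetric so the check can tell whether the layer flips the kernel. It does not: PyTorch computes the cross-correlation, and a flipped kernel gives a different answer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
image = torch.rand(1, 1, 6, 6)  # made-up (N, C, H, W) input
kernel = torch.tensor([[1., 2., 0.],
                       [0., 1., -1.],
                       [3., 0., 1.]])  # deliberately not symmetric

conv = nn.Conv2d(1, 1, kernel_size=3, bias=False)
with torch.no_grad():
    conv.weight.copy_(kernel.view(1, 1, 3, 3))  # set the weights, as in Step 5
    layer_out = conv(image)[0, 0]               # 4x4 output (no padding)

# Slide the kernel over the image by hand, as in the equation from Step 4
manual = torch.zeros(4, 4)
for r in range(4):
    for c in range(4):
        manual[r, c] = (kernel * image[0, 0, r:r + 3, c:c + 3]).sum()
print(torch.allclose(layer_out, manual))  # True

# Flipping the kernel first (a textbook convolution) gives a different result here
flipped_out = F.conv2d(image, torch.flip(kernel, dims=(0, 1)).view(1, 1, 3, 3))[0, 0]
print(torch.allclose(layer_out, flipped_out))  # False
```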





 #### Step 6: Try a new image



In [None]:
#define URL of image here. (right click an image online and get image address)
url = 'https://dopplerchase-ai2es-schooner-hpc.readthedocs.io/en/latest/_images/ai2es-logo-web-trans.png'

#load in some image packages, don't worry about these
from PIL import Image,ImageOps
import requests
from io import BytesIO

#this grabs the image and turns it into an array of data we can use
response = requests.get(url)
img = Image.open(BytesIO(response.content))
img = ImageOps.grayscale(img)
arr = np.array(img.convert('F'))

# Preprocess for PyTorch: (H, W) -> (1, 1, H, W)
arr_tensor = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0).float()

# Resize the image so this code works for any image given
# PyTorch interpolate expects (Batch, Channel, H, W)
arr_tensor = F.interpolate(arr_tensor, size=(260, 260), mode='bilinear', align_corners=False)


#define a filter!

#blur 
kernel_weights = np.array([[ 1, 1, 1],[ 1,1,1],[1,1,1]])*(1/9.)

#sharpen
# kernel_weights = np.array([[ 0, -1, 0],[ -1,5,-1],[0,-1,0]])

# Define conv with specific weights
conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, bias=False)

# Set weights
with torch.no_grad():
    conv.weight.data = torch.from_numpy(kernel_weights).float().view(1, 1, 3, 3)

# Run the conv
res = conv(arr_tensor)


# Plot it up. 
fig,(ax1,ax2,ax3) = plt.subplots(1,3,figsize=(15,5),facecolor='w')

# Squeeze dimensions for plotting (1, 1, H, W) -> (H, W)
input_img = arr_tensor.squeeze()
output_img = res.detach().squeeze()

ax1.imshow(input_img, vmin=0, vmax=255, cmap='Greys_r')
ax1.set_title('Original')
ax1.axis('off')

# Show result. 
ax2.imshow(output_img, vmin=0, vmax=255, cmap='Greys_r')
ax2.set_title('Convolved')
ax2.axis('off')

# Calculate the difference. With no padding (padding=0, the Conv2d default), a 3x3 kernel
# drops 1 pixel on each side, so we trim the input to match the (H-2, W-2) output.
diff = input_img[1:-1, 1:-1] - output_img

ax3.imshow(diff, vmin=-100, vmax=100, cmap='seismic')
ax3.set_title('Difference')
ax3.axis('off')



 How about a space cowboy?

In [None]:
url = 'https://cdn.openai.com/dall-e-2/demos/text2im/astronaut/horse/photo/0.jpg'

#load in some image packages, don't worry about these
from PIL import Image,ImageOps
import requests
from io import BytesIO

#this grabs the image and turns it into an array of data we can use
response = requests.get(url)
img = Image.open(BytesIO(response.content))
img = ImageOps.grayscale(img)
arr = np.array(img.convert('F'))

# Preprocess for PyTorch: (H, W) -> (1, 1, H, W)
arr_tensor = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0).float()

# Resize the image so this code works for any image given
arr_tensor = F.interpolate(arr_tensor, size=(260, 260), mode='bilinear', align_corners=False)


#define a filter!

#blur 
kernel_weights = np.array([[ 1, 1, 1],[ 1,1,1],[1,1,1]])*(1/9.)

#sharpen
# kernel_weights = np.array([[ 0, -1, 0],[ -1,5,-1],[0,-1,0]])

# Define conv with specific weights
conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, bias=False)

with torch.no_grad():
    conv.weight.data = torch.from_numpy(kernel_weights).float().view(1, 1, 3, 3)

#run the conv
res = conv(arr_tensor)


#plot it up: a figure with 3 subplots side by side (1 row, 3 columns)
fig,(ax1,ax2,ax3) = plt.subplots(1,3,figsize=(15,5))

input_img = arr_tensor.squeeze()
output_img = res.detach().squeeze()

ax1.imshow(input_img, vmin=0, vmax=255, cmap='Greys_r')
ax1.set_title('Original')
ax1.axis('off')

#show result. 
ax2.imshow(output_img, vmin=0, vmax=255, cmap='Greys_r')
ax2.set_title('Convolved')
ax2.axis('off')

diff = input_img[1:-1, 1:-1] - output_img

ax3.imshow(diff, vmin=-100, vmax=100, cmap='seismic')
ax3.set_title('Difference')
ax3.axis('off')


 #### Step 7: Padding...



 Look again at this GIF, paying close attention to the edge of the right image:



 <center><img src="../images/convolution_animation_01.gif" alt="drawing" width="600"/></center>



 See the border of missing data around the image? This is not a bug. The 3x3 kernel shown here covers the center pixel and its neighbors on all sides (left, right, top, bottom, and the diagonals). That means the kernel cannot be centered on pixels along the edge of the image, because there is no data on at least one side of them. As a result, the kernel skips those pixels and the output is smaller than the input: a 3x3 kernel removes one pixel from each side, so the 12x12 image becomes 10x10. In other words, convolutions shrink the image (fewer pixels covering a smaller area), although the spacing between pixels, and hence the resolution, is unchanged.
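 The shrinkage follows a simple rule: with stride 1 and no padding, an axis of n pixels convolved with a kernel of K pixels gives n - K + 1 output pixels. The sketch below checks that rule against `nn.Conv2d` for a few kernel sizes on a 12x12 input; `valid_output_size` is a hypothetical helper written for illustration.

```python
import torch
import torch.nn as nn


def valid_output_size(n, kernel_size):
    """Output length along one axis for stride 1 and no padding."""
    return n - kernel_size + 1


image = torch.zeros(1, 1, 12, 12)  # same size as the 12x12 zoom-in; values are irrelevant here
for kernel_size in (3, 5, 7):
    conv = nn.Conv2d(1, 1, kernel_size=kernel_size, bias=False)  # padding=0 by default
    out_hw = tuple(conv(image).shape[-2:])
    print(kernel_size, out_hw, valid_output_size(12, kernel_size))
# 3 (10, 10) 10
# 5 (8, 8) 8
# 7 (6, 6) 6
```

 Because every unpadded layer removes pixels, stacking several of them in a CNN shrinks the image further with each layer.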



 One way to prevent the loss of pixels is a process known as **padding**, where the *input* image is padded with 0s so that the output of the convolution keeps the same shape as the original image. The next code block adds a one-pixel border of zeros around the image.

In [None]:
import copy 

#define image on the left
Z_old = copy.deepcopy(da.values)

# F.pad takes the padding for the last 2 dimensions as (left, right, top, bottom).
# Here we want 1 pixel of padding on all sides.
# First, convert Z_old to a float tensor with dims (1, 1, H, W); this is not required for
# constant padding, but it matches the (N, C, H, W) layout used by Conv2d
Z_tensor = torch.from_numpy(Z_old).unsqueeze(0).unsqueeze(0).float()

# Pad: (left=1, right=1, top=1, bottom=1)
Z_padded = F.pad(Z_tensor, (1, 1, 1, 1), mode='constant', value=0)

# Squeeze back to (H, W) for plotting
Z = Z_padded.squeeze().numpy()

fig,(ax1,ax2) = plt.subplots(1,2,figsize=(10,5))
ax1.imshow(Z_old,vmin=0,vmax=60,cmap='Spectral_r')
ax1.set_title('Original')
ax1.axis('off')
#show result. 
ax2.imshow(Z,vmin=0,vmax=60,cmap='Spectral_r')
ax2.set_title('Padded')
ax2.axis('off')


 Here is the same GIF as before, but with the padded image as the input. Notice that the output now keeps the original image size.



 <center><img src="../images/convolution_animation_01_padded.gif" alt="drawing" width="600"/></center>



 PyTorch implements this for us. All we need to do is set the ```padding``` argument of the convolution layer. Setting `padding='same'` pads the input with zeros so that the output has the same height and width as the input; for a 3x3 kernel, this means one pixel on each side, which is equivalent to `padding=1`.
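 This sketch checks that equivalence on a made-up 36x36 input by giving a `padding='same'` layer and a `padding=1` layer identical weights. It also spells out the general stride-1 rule, output = n + 2P - K + 1, which shows why P = (K - 1) / 2 preserves the size for odd kernel sizes K.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.rand(1, 1, 36, 36)  # made-up input, the same size as `more_points`

conv_same = nn.Conv2d(1, 1, kernel_size=3, bias=False, padding='same')
conv_pad1 = nn.Conv2d(1, 1, kernel_size=3, bias=False, padding=1)
with torch.no_grad():
    conv_pad1.weight.copy_(conv_same.weight)  # give both layers identical weights
    out_same = conv_same(x)
    out_pad1 = conv_pad1(x)

print(out_same.shape)                      # torch.Size([1, 1, 36, 36])
print(torch.allclose(out_same, out_pad1))  # True

# Stride-1 rule: output = n + 2 * padding - kernel_size + 1,
# so padding = (kernel_size - 1) // 2 keeps the size for odd kernels
for k in (3, 5, 7):
    p = (k - 1) // 2
    assert 36 + 2 * p - k + 1 == 36
```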

In [None]:
# Convert input to tensor and resize (integer indexing drops radar_height and radar_field, leaving a 2D (H, W) array)
input_data = ds_sample.radar_image_matrix.isel(radar_height=0,radar_field=0).values
input_tensor = torch.from_numpy(input_data).unsqueeze(0).unsqueeze(0).float()
more_points = F.interpolate(input_tensor, size=(36, 36), mode='bilinear', align_corners=False)

#define the sharpen filter
kernel_weights = np.array([[ 0, -1, 0],[ -1,5,-1],[0,-1,0]])
# kernel_weights = np.array([[ 1, 1, 1],[ 1,1,1],[1,1,1]])*(1/9.)

# Define conv with specific weights. 
# We add padding='same' to preserve dimensions. 
# Note: In older PyTorch versions you might see padding=1 used for a 3x3 kernel.
conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, bias=False, padding='same') 

with torch.no_grad():
    conv.weight.data = torch.from_numpy(kernel_weights).float().view(1, 1, 3, 3)

#run the data through 
res = conv(more_points)


fig,(ax1,ax2) = plt.subplots(1,2,figsize=(10,5))
ax1.imshow(more_points.squeeze(), vmin=0, vmax=60, cmap='Spectral_r')
ax1.set_title('Original')
ax1.axis('off')
#show result. 
ax2.imshow(res.detach().squeeze(), vmin=0, vmax=60, cmap='Spectral_r')
ax2.set_title('Convolved')
ax2.axis('off')


 Notice the line of higher (red) values along the far-right side of the right image. This is a result of the padding. The sharpen filter is particularly prone to this: for a pixel on the edge, some of the neighbors that the kernel would normally subtract (the -1 weights) are padded 0s, so less is subtracted and the output is artificially high.
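 To see exactly where this edge artifact comes from, here is a sketch that applies the sharpen kernel with zero padding to a made-up, perfectly flat 10 dBZ field. Away from the edges the field is unchanged, but each padded zero that replaces a real neighbor removes one -1 term, so edge and corner pixels come out too high.

```python
import torch
import torch.nn.functional as F

sharpen = torch.tensor([[0., -1., 0.],
                        [-1., 5., -1.],
                        [0., -1., 0.]]).view(1, 1, 3, 3)

# Made-up, perfectly flat 5x5 field of 10 dBZ
flat = torch.full((1, 1, 5, 5), 10.0)

out = F.conv2d(flat, sharpen, padding=1).squeeze()  # zero padding keeps the 5x5 size
print(out)
# interior pixels: 5*10 - 4*10 = 10  (unchanged)
# edge pixels:     5*10 - 3*10 = 20  (one neighbor is a padded zero)
# corner pixels:   5*10 - 2*10 = 30  (two neighbors are padded zeros)
```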



 While it looks weird here, remember we still haven't done any machine learning yet. So maybe for our example here, the sharpen filter isn't the best choice for tornado classification.



 In the next notebook, we will apply what we've learned here by training a CNN to detect lightning in an image using the ```sub-sevir``` dataset.