# Making a color image from images in three different bands

## Tools from previous lectures
* Loading FITS files.
* Displaying/plotting FITS images.
* Using Source Extractor to find positions of stars.
* Finding separation between points in images.
* Dictionaries, lists, strings, and loops.

## New for this lecture
* Use pixel value histogram to normalize image.
* Combine three images to make an RGB color image.
* Align different images of the same target.
* Clip and smooth images.


### Import packages

In [None]:
#Most of these you've seen before
import glob
import sep
import numpy as np
from astropy.wcs import WCS
import astropy.io.fits as pyfits
from matplotlib import pyplot as plt

#Except maybe some scipy tools for basic image manipulation
#(shift() lives directly in scipy.ndimage; the old ndimage.interpolation module is deprecated)
import scipy.ndimage as interp
from scipy.ndimage import gaussian_filter

# Outline:

At the beginning of the class we tried to observe some Messier objects with a robotic telescope. Now we're going to put those data together to make a color image, or what astronomers often call an RGB image. Astronomical imaging is usually performed through specific filters, each of which only allows certain wavelengths of light to pass. If you observe an object in three different filters, each image you get will be a greyscale image, but you can assign each image a primary color (red/green/blue) and combine them to make a color image. That is what we're going to do today.

# What filters did we observe with?

First we need to figure out what filters we observed with and which to assign to which primary color.

We can find this out by opening each FITS file (FITS is the image format typically used in astronomy) and extracting the filter information from the file header.

In this example we're going to be using images of M83 that were taken in 2021 (hopefully you have already downloaded these).

In [None]:
#Find all files in the current directory starting with "M83" and ending "fits"
files = glob.glob('M83*fits')

#Make an empty list to record filter names
filters = []

#Make an empty dictionary to record image filenames corresponding to each filter
imgname_filters = {}

#Loop over all the files that we found above
for file in files:
    #get the filter from the header
    _filter = pyfits.getheader(file)['FILTER']
    filters.append(_filter)
    
    #Make a new dictionary key for the filter name
    #and assign the filename to that filter
    imgname_filters[_filter] = file
    
print('All the filters we have: ', np.unique(filters))
print(imgname_filters)

## Question:

<font color=blue>
Which filter should we assign to Red, Green, and Blue?
</font>

**Hint**: Maybe this plot of filter transmission curves will help. (Ignore the colors used to plot the curves.)

![](http://spiff.rit.edu/classes/phys440/lectures/filters/bessell.png)

# Visualize the data from one filter

Now let's display one of the images. We're going to start by taking all the pixel values in an image and making a histogram of all the values. This will give us an idea of the minimum and maximum values that we want to set when we display the image.

## Exercise
<font color=blue>
    <ul>
        <li>Grab the filename of the B-band image from our dictionary.</li>
        <li>Plot a histogram of the pixel values from the entire image.</li>
    </ul>
</font>

#### Hints

* use `pyfits.getdata(filename)` to get the image data
* plot it using `plt.hist(data.flatten(), bins=nbins, range=(lower, higher))`
    * you'll have to figure out what to use for nbins, lower, higher.
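
One possible way to pick `lower` and `higher` is to look at percentiles of the pixel values first. The sketch below runs on a synthetic image, not the M83 data, and the names `fake_image`, `lower`, `higher`, and `nbins` are only illustrative. It uses `np.histogram`, which takes the same `bins` and `range` arguments as `plt.hist`.

```python
import numpy as np

# Synthetic stand-in for a FITS image: Gaussian sky noise plus one very bright pixel
rng = np.random.default_rng(0)
fake_image = rng.normal(loc=100.0, scale=5.0, size=(200, 200))
fake_image[50, 50] = 60000.0  # a single saturated "star" pixel

# Percentiles ignore the rare extreme pixels, so they make a sensible plot range
lower, higher = np.percentile(fake_image, [1, 99])
nbins = 50

# Same bins/range arguments as plt.hist(data.flatten(), bins=nbins, range=(lower, higher))
counts, edges = np.histogram(fake_image.flatten(), bins=nbins, range=(lower, higher))
print(lower, higher, counts.sum())
```

Without the `range` argument, the single bright pixel would stretch the histogram so far that all of the sky pixels fall into one bin.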

# Now...

Let's display this image as an example.

## Question:

<font color=blue>From the histogram above, what values would you use for `vmin` and `vmax`?</font>

In [None]:
vmin, vmax =  ????, ????

fig = plt.figure(figsize=(10,12))

#plot the image
data = pyfits.getdata(imgname_filters['B'])
plt.imshow(data, origin='lower', cmap='gray', vmin=vmin, vmax=vmax)

# Make an RGB image

* How will we do this:
    * our bands need to be ordered from red to blue
    * we loop over these ordered bands
        * place the data in the matching layer of our `RGBimage` array after applying a scale factor
        * the scale factor will be from 0-1; the closer it is to 1, the brighter the image will be in that color
        
    * `plt.imshow(RGBimage)`

In [None]:
#This needs to be ordered as R/G/B
bands = ['R', 'V', 'B'] 

#This will scale how bright images are in each filter
scale_factor = np.array([1.0, 0.9, 1.0])

#Use numpy to find size of image
image_size = np.shape(data)

#Make an empty set of three images
RGBimage=np.zeros((image_size[0],image_size[1],3),dtype=float) 

#Loop over all the bands
for i in range(len(bands)):
    #Use the dictionary "imgname_filters" to read the image file in each band
    data = pyfits.getdata(imgname_filters[bands[i]])
    
    #Set the minimum and maximum of the image
    #using percentiles of the pixel values
    min_value = np.percentile(data.flatten(), 2) #2nd percentile to 98th percentile
    max_value = np.percentile(data.flatten(), 98)
    
    #Scale the data so that the range goes from 0 to 1 for each image
    data = (data - min_value)/(max_value-min_value)
    
    #Place the data for the current filter in the appropriate layer of RGBimage
    RGBimage[:,:,i] = data*scale_factor[i]

Now we can display the stack of three images as a color image. If you give Matplotlib an array of shape (rows, columns, 3), it treats it as an RGB image and displays it as such automatically.
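
As a small illustration of the array layout, here is a sketch with three synthetic 4x4 "bands" (the names `red`, `green`, and `blue` are made up for this example). It uses `np.dstack` as an alternative to filling a zero array layer by layer. For float data, `imshow` expects RGB values between 0 and 1 and clips anything outside that range.

```python
import numpy as np

# Three synthetic 4x4 "bands", already scaled to roughly 0-1
red = np.full((4, 4), 0.8)
green = np.full((4, 4), 0.5)
blue = np.full((4, 4), 1.3)  # deliberately out of range

# Stack along a new last axis -> shape (rows, columns, 3)
rgb = np.dstack([red, green, blue])

# imshow clips float RGB values to 0-1; clipping explicitly avoids the warning
rgb_clipped = np.clip(rgb, 0.0, 1.0)
print(rgb.shape, rgb_clipped.max())
```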

In [None]:
fig = plt.figure(figsize=(10,12))

#vmin/vmax are ignored for RGB data, so we don't set them here
plt.imshow(RGBimage, origin='lower', interpolation='nearest')

# Hmmm....

It looks okay, but not great. The image stacking doesn't appear to be perfect.

## Question:

<font color=blue>
    <ul>
        <li>What signs can you see that there's a problem?</li>
        <li>What should be done about it?</li>
    </ul>
</font>

.

.

.

.

.

.

.

.


# Aligning the images

We need to align the images better to remove artifacts.

### How can we do this:

* set one image (here the B band) as the reference
* use `sep` to find the objects in the reference image
* for each image, loop over the reference objects and the objects detected in that image
    * match each reference object to a detected object and record the x and y offsets between them
    * average the offsets, then use `scipy.ndimage.shift()` to shift the image by that amount
* stack our shifted images
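
To see how the offset and the shift fit together, here is a minimal sketch on two synthetic images with a single "star" each. It finds the star with `np.argmax` instead of `sep`, purely to keep the example short, and the names `ref`, `moved`, and `aligned` are illustrative.

```python
import numpy as np
from scipy.ndimage import shift

# Synthetic reference image with one bright "star" at (row=20, col=30)
ref = np.zeros((64, 64))
ref[20, 30] = 100.0

# A second image of the same star, displaced by +3 rows and -2 columns
moved = np.zeros((64, 64))
moved[23, 28] = 100.0

# Offset = reference position - position in the other image
ref_y, ref_x = np.unravel_index(np.argmax(ref), ref.shape)
img_y, img_x = np.unravel_index(np.argmax(moved), moved.shape)
shift_y, shift_x = int(ref_y - img_y), int(ref_x - img_x)

# shift() takes offsets in array-axis order: (rows, columns) = (y, x)
aligned = shift(moved, [shift_y, shift_x])
peak = tuple(int(v) for v in np.unravel_index(np.argmax(aligned), aligned.shape))
print(peak)
```

Note the order of the offsets: the first array axis is y (rows) and the second is x (columns), which is why the shift is passed as `[shift_y, shift_x]`.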

In [None]:
#Here is a function that searches for a reference object
#once it finds it, it returns the shift in the x and y directions

def find_offset(ref_x, ref_y, _data):
    #Get the objects in our current image
    #convert to native byte order, which sep requires
    #(ndarray.newbyteorder() was removed in NumPy 2.0)
    _data = _data.astype(float)
    _data_bkg = sep.Background(_data)
    data_objs = sep.extract(_data, thresh=20.0, err=_data_bkg.globalrms, minarea=8)
    
    #loop over them and calculate their distances
    for obj_x, obj_y in zip(data_objs['x'], data_objs['y']):
        shift_x = ref_x - obj_x
        shift_y = ref_y - obj_y
        distance = np.sqrt(shift_x**2 + shift_y**2)
        
        # if the distance is less than this threshold, it's a match
        # return the offset x,y
        if distance < 10:
            return shift_x, shift_y
    
    #no object was found within 10 pixels of the reference position
    return None

In [None]:
shift_band_image = {}

image_size = np.shape(data)
#These commands get the reference image objects
ref_file = imgname_filters['B']
ref_img = pyfits.getdata(ref_file)
ref_img = ref_img.astype(float) # convert to native byte order, which sep requires
ref_bkg = sep.Background(ref_img)
ref_objects = sep.extract(ref_img, thresh=20.0, err=ref_bkg.globalrms, minarea=10)

#loop over the bands: RGB
for band in filters:
    print('Stacking all images in filter: {}'.format(band))
    _shift_data = np.zeros(image_size)
    
    #get the image data
    tmp_data = pyfits.getdata(imgname_filters[band])

    #our lists to hold the offsets
    sx, sy = [], []

    #loop over each reference object
    for ref_x, ref_y in zip(ref_objects['x'], ref_objects['y']):
        #find the offset; find_offset returns None if there is no match
        offset = find_offset(ref_x, ref_y, tmp_data)
        if offset is not None:
            sx.append(offset[0])
            sy.append(offset[1])

    #calculate the average offset
    shift_x, shift_y = np.mean(sx), np.mean(sy)

    print('Average offsets in x: {}, y: {}\n'.format(round(shift_x, 3), round(shift_y, 3)))

    #scipy method that shifts the image based off these offsets
    new_data = interp.shift(tmp_data, [shift_y, shift_x])
    #new coadded data from the shift
    _shift_data += new_data

    shift_band_image.update({band: _shift_data})

# Let's plot our aligned/stacked data 

In [None]:
bands = ['R', 'V', 'B'] # This needs to be ordered as R/G/B

scale_factor = np.array([1.0, 0.9, 1.0]) 

RGBimage = np.zeros((image_size[0],image_size[1],3),dtype=float)


for i in range(len(bands)):
    data = shift_band_image[bands[i]].copy()
    
    min_value = np.percentile(data.flatten(), 2) #2nd percentile to 98th percentile
    max_value = np.percentile(data.flatten(), 98)
    
    data = (data - min_value)/(max_value-min_value)
    RGBimage[:,:,i] = (data*scale_factor[i])**1.
                      
fig = plt.figure(figsize=(10,12))

plt.imshow(RGBimage, origin='lower', interpolation='nearest')

The mis-alignment is fixed! 

However, the image still looks pretty noisy and it's hard to clearly see the fainter regions of the galaxy. You can change the power the image is scaled by (`RGBimage[:,:,i]=(data*scale_factor[i])**1.`), e.g. try 0.5 or 2, but this will only help so much. 

Setting the correct visual scale in astronomy is always a trade-off. Making the fainter parts of the image stand out will result in the brightest parts being saturated. What scale you want to choose usually depends on what features you are most interested in, e.g. the galactic nucleus or the edges of the spiral arms.
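
To get a feel for what the power does, here is a small sketch on a handful of made-up pixel values that are already normalized to 0-1. Note that negative values raised to a fractional power become NaN, so it can help to clip the data to 0-1 before applying a power below 1.

```python
import numpy as np

# Normalized pixel values from faint (0.01) to bright (1.0)
data = np.array([0.01, 0.1, 0.5, 1.0])

# A power below 1 lifts the faint pixels; a power above 1 suppresses them
stretched = data ** 0.5
compressed = data ** 2.0
print(stretched)
print(compressed)
```

With a power of 0.5 the faintest pixel goes from 1% to 10% of full brightness, while the brightest pixel stays at 1.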

# Smooth the image

One simple way to enhance the image, reduce its noise, and bring out the faint features is to smooth (or blur) it. When an image is smoothed, random noise in adjacent pixels will tend to cancel out on average, but real emission coming from the galaxy won't cancel out. So the real emission will appear stronger against the background. The trade-off is that the image will be more blurry. For some science cases it's vital to have high resolution images and smoothing is not really an option, but for investigations of diffuse, faint objects smoothing will likely be beneficial.

To perform the smoothing we will use the `scipy` function `gaussian_filter`. Your cell phone camera probably has a Gaussian filter function to blur parts of images. The function smooths/blurs the image using a Gaussian function (a bell shape) with a standard deviation (sigma) of N pixels. Here we're going to use N=2, but you can experiment with different values.
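
As a quick check of the noise argument, the sketch below smooths a synthetic image of pure random noise (no galaxy) and compares the scatter before and after. The array names are illustrative and the image is not real data.

```python
import numpy as np
from scipy.ndimage import gaussian_filter

rng = np.random.default_rng(1)

# Flat "sky" with random noise of standard deviation 1
noise = rng.normal(loc=0.0, scale=1.0, size=(256, 256))

# Smooth with a Gaussian of sigma = 2 pixels
smoothed = gaussian_filter(noise, 2)

print(noise.std(), smoothed.std())
```

The scatter of the smoothed image is several times smaller than that of the original, which is why faint extended emission stands out better after smoothing.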

In [None]:
bands = ['R', 'V', 'B'] # This needs to be ordered as R/G/B
scale_factor = np.array([1.0, 0.9, 1.0])
RGBimage = np.zeros((image_size[0],image_size[1],3),dtype=float)

N_smooth = 2


for i in range(len(bands)):
    data = shift_band_image[bands[i]].copy()
    
    min_value = np.percentile(data, 2) #2nd percentile to 98th percentile
    max_value = np.percentile(data, 98)
    
    data = (data - min_value)/(max_value-min_value)
    RGBimage[:,:,i] = (data*scale_factor[i])**1.
    
    RGBimage[:,:,i] = gaussian_filter(RGBimage[:,:,i],N_smooth,mode='wrap')

fig= plt.figure(figsize=(10,12))

plt.imshow(RGBimage, origin='lower', interpolation='nearest')  

Ok, things are looking pretty good now. 

The image is aligned well, the smoothing reduced the noise and brought out the fainter features. However, there are black spots all over the image.



## Question:

<font color=blue>
    <ul>
        <li>What might be causing the black spots? </li>
        <li>Why are they visible in this image, but not the previous one? </li>
    </ul>
</font>

.

.

.

.

.

.


# Clip the image

To remove the black spots we can clip the image (before smoothing) to eliminate the bad pixels. In general in science you need to be very careful about manually modifying your data as this might create a bias in your final results. However, in this case our goal is to make a beautiful final image, like a publisher would for a popular science article, so this approach is OK. For an image which we were going to do science with, we would need to be more careful and identify each bad pixel and remove it individually.

To perform the clipping we will just take the `min_value` and `max_value` that we already defined and clip the data at these values using the `numpy` function `clip`, then smooth as before.
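
The effect of clipping before smoothing can be sketched on a synthetic flat image with a single bad pixel. The value -50 is made up to stand in for a bad pixel, and the clip limits 0 and 1 stand in for `min_value` and `max_value`.

```python
import numpy as np
from scipy.ndimage import gaussian_filter

# Flat synthetic image with one bad (very negative) pixel
image = np.full((21, 21), 0.5)
image[10, 10] = -50.0

# Smoothing without clipping spreads the bad value over its neighbours
unclipped = gaussian_filter(image, 2)

# Clipping first removes the outlier before it can spread
clipped = gaussian_filter(np.clip(image, 0.0, 1.0), 2)

print(unclipped.min(), clipped.min())
print((unclipped < 0).sum(), (clipped < 0).sum())
```

Without clipping, one bad pixel drags a whole patch of neighbouring pixels below zero, which shows up as a black spot. With clipping, the same pixel only causes a barely visible dip.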

In [None]:
bands = ['R', 'V', 'B'] # This needs to be ordered as R/G/B
scale_factor = np.array([1.0, 0.9, 1.0])
RGBimage = np.zeros((image_size[0],image_size[1],3),dtype=float)

for i in range(len(bands)):
    data = shift_band_image[bands[i]].copy()
    
    min_value = np.percentile(data, 2) #2nd percentile to 98th percentile
    max_value = np.percentile(data, 98)
    
    #Added one line to clip the data
    data = np.clip(data,min_value,max_value)
    data = (data - min_value)/(max_value-min_value)
    
    RGBimage[:,:,i] = (data*scale_factor[i])**1.

    RGBimage[:,:,i] = gaussian_filter(RGBimage[:,:,i],2.,mode='wrap')

fig= plt.figure(figsize=(10,12))

plt.imshow(RGBimage, origin='lower', interpolation='nearest')  

# Perfect!

Show your friends and family! You now know how to make astronomy color pictures the same way they are made from Hubble data!

Save all of your notebooks and use them for future projects. Maybe you can try creating a similar image for other objects observed with Skynet.

I hope you have a great summer : )