# Random Forest Classification of X-ray Images

This notebook uses a scikit-learn random forest to classify X-ray images as having either good coverage or poor coverage.

## Import packages, define metrics, prepare datasets

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import scipy.io
from astropy.io import fits
from astropy.stats import sigma_clipped_stats
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle

In [None]:
def countZeroPixels(image):
    """Count how many pixels in the image are exactly zero."""
    pixels = image.ravel()
    n_nonzero = np.count_nonzero(pixels)
    return pixels.size - n_nonzero

def getImageContrast(image):
    """Return (max-min)/(max+min) computed over the strictly positive pixels.

    Zero and negative pixels are ignored. Raises ValueError (from the
    max/min reduction) if the image has no positive pixel.
    """
    positive = image[image > 0].ravel()
    brightest = positive.max()
    faintest = positive.min()
    return (brightest - faintest) / (brightest + faintest)

def npixelsAboveNoise(image,threshold=5):
    """Count the positive pixels brighter than `threshold` times the RMS.

    The RMS is taken over the strictly positive pixels only; zero and
    negative pixels are discarded before anything is computed.
    """
    positive = image[image > 0].ravel()
    rms = np.sqrt(np.mean(positive ** 2.))
    cutoff = threshold * rms
    return int(np.count_nonzero(positive > cutoff))

def meanSepBrightPixels(image,threshold=5):
    """Return the mean separation (in px) between bright pixels.

    A pixel is "bright" when it exceeds `threshold` times the RMS of the
    strictly positive pixels. The separation is estimated as
    sqrt(total_pixels / n_bright), where total_pixels is taken from the
    image itself rather than being fixed at 300x300.

    If there are no bright pixels (including the case of an image with no
    positive pixels at all), the square root of the pixel count is returned,
    i.e. the side length of a square image (300 for a 300x300 image).
    """
    n_total = image.size
    # Value used when nothing is bright: the side length of a square image.
    no_bright_value = float(n_total) ** 0.5

    positive = image[image > 0]
    if positive.size == 0:
        # Avoid taking the mean of an empty array (NaN plus a RuntimeWarning).
        return no_bright_value

    rmsnoise = np.sqrt(np.mean(positive ** 2.))
    n_bright = np.count_nonzero(positive > threshold * rmsnoise)
    if n_bright == 0:
        return no_bright_value
    return (n_total / n_bright) ** 0.5

def medianBrightYPosition(image,threshold=10):
    """Median first-axis (row) index of the bright pixels.

    Bright means above `threshold` times the RMS of the whole image.
    """
    cutoff = threshold * np.sqrt(np.mean(image ** 2.))
    bright_rows = np.nonzero(image > cutoff)[0]
    return np.median(bright_rows)

def medianBrightXPosition(image, threshold=10):
    """Median second-axis (column) index of the bright pixels.

    Bright means above `threshold` times the RMS of the whole image.
    """
    cutoff = threshold * np.sqrt(np.mean(image ** 2.))
    bright_cols = np.nonzero(image > cutoff)[1]
    return np.median(bright_cols)

def lr_contrast(image):
    """Absolute difference in contrast between the left and right parts.

    The image is cut at column 150 (presumably the midpoint of the 300x300
    images used in this notebook).
    """
    split_col = 150
    contrast_left = getImageContrast(image[:, :split_col])
    contrast_right = getImageContrast(image[:, split_col:])
    return np.abs(contrast_left - contrast_right)

def ud_contrast(image):
    """Absolute difference in contrast between the two row-halves.

    The image is cut at row 150 (presumably the midpoint of the 300x300
    images used in this notebook).
    """
    split_row = 150
    contrast_up = getImageContrast(image[split_row:, :])
    contrast_down = getImageContrast(image[:split_row, :])
    return np.abs(contrast_up - contrast_down)

def triangular_contrast1(image):
    """Absolute contrast difference between the upper and lower triangles.

    Both triangles include the main diagonal; pixels outside a triangle are
    zero and are therefore ignored by getImageContrast.
    """
    contrast_upper = getImageContrast(np.triu(image))
    contrast_lower = getImageContrast(np.tril(image))
    return np.abs(contrast_upper - contrast_lower)

def triangular_contrast2(image):
    """Same as triangular_contrast1, but on the left-right mirrored image.

    Mirroring first means the triangles are split along the anti-diagonal
    of the original image.
    """
    mirrored = np.fliplr(image)
    contrast_upper = getImageContrast(np.triu(mirrored))
    contrast_lower = getImageContrast(np.tril(mirrored))
    return np.abs(contrast_upper - contrast_lower)

def ud_difference(image, threshold=10):
    """Absolute difference in bright-pixel counts between the two row-halves.

    Bright means above `threshold` times the RMS of the whole image; the
    image is cut at row 150.
    """
    cutoff = threshold * np.sqrt(np.mean(image ** 2.))
    n_bright_up = np.count_nonzero(image[150:, :] > cutoff)
    n_bright_down = np.count_nonzero(image[:150, :] > cutoff)
    return np.abs(n_bright_up - n_bright_down)

def lr_difference(image, threshold=10):
    """Absolute difference in bright-pixel counts between left and right parts.

    Bright means above `threshold` times the RMS of the whole image; the
    image is cut at column 150.
    """
    cutoff = threshold * np.sqrt(np.mean(image ** 2.))
    n_bright_l = np.count_nonzero(image[:, :150] > cutoff)
    n_bright_r = np.count_nonzero(image[:, 150:] > cutoff)
    return np.abs(n_bright_l - n_bright_r)

def tri_difference1(image,threshold=10):
    """Absolute difference in bright-pixel counts between the two triangles.

    The upper and lower triangles both include the main diagonal. Bright
    means above `threshold` times the RMS of the whole image.
    """
    cutoff = threshold * np.sqrt(np.mean(image ** 2.))
    n_upper = np.count_nonzero(np.triu(image) > cutoff)
    n_lower = np.count_nonzero(np.tril(image) > cutoff)
    return np.abs(n_upper - n_lower)

def tri_difference2(image,threshold=10):
    """Same as tri_difference1, but on the left-right mirrored image.

    The RMS threshold is computed from the image as passed in; the triangles
    are taken after mirroring, i.e. split along the original anti-diagonal.
    """
    cutoff = threshold * np.sqrt(np.mean(image ** 2.))
    mirrored = np.fliplr(image)
    n_upper = np.count_nonzero(np.triu(mirrored) > cutoff)
    n_lower = np.count_nonzero(np.tril(mirrored) > cutoff)
    return np.abs(n_upper - n_lower)

def lr_zero(image, threshold=10):
    """Absolute difference in zero-pixel counts between left and right parts.

    The image is cut at column 150. `threshold` is accepted but not used.
    """
    n_zero_left = np.count_nonzero(image[:, :150] == 0)
    n_zero_right = np.count_nonzero(image[:, 150:] == 0)
    return np.abs(n_zero_left - n_zero_right)

def ud_zero(image, threshold=10):
    """Absolute difference in zero-pixel counts between the two row-halves.

    The image is cut at row 150. `threshold` is accepted but not used.
    """
    n_zero_up = np.count_nonzero(image[150:, :] == 0)
    n_zero_down = np.count_nonzero(image[:150, :] == 0)
    return np.abs(n_zero_up - n_zero_down)

def symmetry_lr(image, clipsigma=10):
    """Return the RMS difference between the image and its left-right mirror.

    Pixels brighter than `clipsigma` times the sigma-clipped mean of the
    non-zero pixels are first replaced by that mean. A result of 0 means the
    clipped image is exactly left-right symmetric.

    The clipping is done on a float copy, so the caller's array is never
    modified and the replacement value is not truncated for integer images.
    """
    clipped = np.array(image, dtype=float)
    # NOTE(review): the clipping statistic always uses sigma=10, independent
    # of `clipsigma` -- confirm that this is intended.
    clippedmean, _, _ = sigma_clipped_stats(clipped[clipped!=0], sigma=10, maxiters=2, cenfunc=np.mean)
    clipped[clipped>clipsigma*clippedmean] = clippedmean
    diff = clipped - np.fliplr(clipped)
    return np.sqrt(np.mean(diff**2.))

def symmetry_ud(image, clipsigma=10):
    """Return the RMS difference between the image and its up-down mirror.

    Pixels brighter than `clipsigma` times the sigma-clipped mean of the
    non-zero pixels are first replaced by that mean. A result of 0 means the
    clipped image is exactly up-down symmetric.

    The clipping is done on a float copy, so the caller's array is never
    modified and the replacement value is not truncated for integer images.
    """
    clipped = np.array(image, dtype=float)
    # NOTE(review): the clipping statistic always uses sigma=10, independent
    # of `clipsigma` -- confirm that this is intended.
    clippedmean, _, _ = sigma_clipped_stats(clipped[clipped!=0], sigma=10, maxiters=2, cenfunc=np.mean)
    clipped[clipped>clipsigma*clippedmean] = clippedmean
    diff = clipped - np.flipud(clipped)
    return np.sqrt(np.mean(diff**2.))


# Feature extractors applied to every image. Each one takes an image array and
# returns a single number; the order of this list is the column order of the
# feature vectors given to the classifier.
metadata_funcs = [countZeroPixels, symmetry_lr, symmetry_ud]#, ud_difference, lr_difference, lr_zero, ud_zero]
# Alternative feature set, currently disabled:
#metadata_funcs = [ud_difference, lr_difference, lr_zero, ud_zero, meanSepBrightPixels,\
#                  npixelsAboveNoise, getImageContrast, countZeroPixels, tri_difference1, tri_difference2]

## Metadata proof of concept

In [None]:
# Read the pixel data of the first HDU; the context manager closes the HDU
# list instead of leaving the file open.
with fits.open('/srv/scratch/zhutchen/khess_images/poor_coverage/RASS-Int_Hard_grp9530.0_.fits') as hdul:
    testimage = hdul[0].data

plt.figure()
plt.imshow(testimage)
plt.show()

# Evaluate every metadata function on this image (shown as the cell output).
[fx(testimage) for fx in metadata_funcs]

In [None]:
# Read the pixel data of the first HDU; the context manager closes the HDU
# list instead of leaving the file open.
with fits.open('/srv/scratch/zhutchen/khess_images/poor_coverage/RASS-Int_Hard_grp3562.0_.fits') as hdul:
    testimage = hdul[0].data

plt.figure()
plt.imshow(testimage)
plt.show()

# Evaluate every metadata function on this image (shown as the cell output).
[fx(testimage) for fx in metadata_funcs]

In [None]:
# Read the pixel data of the first HDU; the context manager closes the HDU
# list instead of leaving the file open.
with fits.open('/srv/scratch/zhutchen/khess_images/nondetections/RASS-Int_Hard_grp10007.0_.fits') as hdul:
    testimage = hdul[0].data

plt.figure()
plt.imshow(testimage)
plt.show()

# Evaluate every metadata function on this image (shown as the cell output).
[fx(testimage) for fx in metadata_funcs]

# Create training/validation dataset

In [None]:
def _extract_features(directory):
    """Return one feature vector per ``.fits`` file in `directory`.

    Each vector holds the outputs of `metadata_funcs`, in list order,
    evaluated on the data of the file's first HDU. Files are visited in the
    order returned by ``os.listdir``. `directory` must end with a path
    separator, because file names are appended to it directly.
    """
    features = []
    for fname in os.listdir(directory):
        if not fname.endswith('.fits'):
            continue
        # Compute the features while the file is open and let the context
        # manager close it, so open files do not accumulate over the loop.
        with fits.open(directory + fname) as hdul:
            image = hdul[0].data
            features.append(np.array([fx(image) for fx in metadata_funcs]))
    return features


dpath = "/srv/scratch/zhutchen/khess_images/detections/"
ndpath = "/srv/scratch/zhutchen/khess_images/nondetections/"
pcpath = "/srv/scratch/zhutchen/khess_images/poor_coverage_augmented/"

# Detections and non-detections are both labelled as good coverage.
imagesXgood = _extract_features(dpath) + _extract_features(ndpath)
labelsygood = ['good_coverage'] * len(imagesXgood)

# Images from the augmented poor-coverage directory.
imagesXpoor = _extract_features(pcpath)
labelsypoor = ['poor_coverage'] * len(imagesXpoor)

Now that we have the arrays filled in, we need to separate them into training and validation data. Typically we would do something like
```
imagesXtrain, imagesXtest, labelsytrain, labelsytest = train_test_split(imagesX, labelsy,\
                                                                        test_size=0.2, random_state=46)
```                                                                        
but it's more complicated here. For the good coverage (detections + nondetections path), we can simply split the array on the training percentage (typically 80%). 

In [None]:
# Fraction of the good-coverage samples that goes into the training set.
trsplit = 0.8

In [None]:
# The first `trsplit` fraction of the good-coverage samples is used for training.
n_good_train = int(trsplit*len(imagesXgood))
imagesXgood_train = imagesXgood[:n_good_train]
labelsygood_train = labelsygood[:n_good_train]

In [None]:
# Everything after the training portion of the good-coverage samples is held out.
good_split = int(trsplit*len(imagesXgood))
imagesXgood_test, labelsygood_test = imagesXgood[good_split:], labelsygood[good_split:]

In [None]:
# Sanity check: the train and test portions together cover every good-coverage sample.
assert (len(imagesXgood_train)+len(imagesXgood_test))==len(imagesXgood)

For the poor coverage, however, most of our data are transformations of about 150 original images. We need to ensure that the validation dataset includes only fresh images (and their transformations), so that the classifier does not "see" a training image while testing its accuracy. Some of our metadata metrics (e.g. number of zero pixels) could be invariant with the transformation (e.g. rotation), so we want to ensure that our validation is not biased by the classifier already having seen a variant of the image.

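Note that `os.listdir` makes no guarantee about the order in which it returns files. This step relies on the listing coming back sorted by file name (apparently the case when this notebook was run), so that each original image sits next to its transformations. Given that ordering, we just need to find the first original image near the ~80% mark and split the poor-coverage data there. If the listing is not sorted on your system, wrap the `os.listdir` calls in `sorted(...)` and re-derive the split index.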

In [None]:
# Keep only the .fits files, exactly as the loading loop does, so that an
# index into this listing is also an index into imagesXpoor even when the
# directory contains other files.
pcfiles = [nm for nm in os.listdir(pcpath) if nm.endswith('.fits')]

npc = len(imagesXpoor) # number of poor images
perc80 = int(0.8*npc)
print(perc80)
print(pcfiles[perc80])
print('----')

# Print index, fractional position and file name around the 80% mark, to
# find where a new group of images begins.
for i, nm in enumerate(pcfiles):
    if perc80-20 < i < perc80+20:
        print(i,i/npc,nm)

The split lands us on group 20170, and we can work backwards to include all of them in the validation set. Group 20170 first appears at index 3045. So that's where we split the training and validation datasets for poor images.

In [None]:
#os.listdir(pcpath)[0:3045] # note this doesn't include group 20170
#os.listdir(pcpath)[3045:] # note starts on first 20170 image

In [None]:
# Index of the first group-20170 file, read off the listing printed above.
poor_split = 3045
imagesXpoor_train, imagesXpoor_test = imagesXpoor[:poor_split], imagesXpoor[poor_split:]
labelsypoor_train, labelsypoor_test = labelsypoor[:poor_split], labelsypoor[poor_split:]

In [None]:
# Sanity check: the train and test portions together cover every poor-coverage sample.
assert len(imagesXpoor_train) + len(imagesXpoor_test) == len(imagesXpoor)

Now that we've split it up appropriately, let's combine everything into a final training set and a final validation set.

In [None]:
# Build new lists rather than extending the per-class lists in place, so that
# re-running this cell cannot append the poor-coverage samples a second time.
imagesXtrain = imagesXgood_train + imagesXpoor_train
labelsytrain = labelsygood_train + labelsypoor_train

imagesXtest = imagesXgood_test + imagesXpoor_test
labelsytest = labelsygood_test + labelsypoor_test

In [None]:
# Sizes of the combined training and validation sets (shown as the cell output).
len(imagesXtrain), len(imagesXtest)

Now just shuffle the data to remove the pattern of file transformations.

In [None]:
# Shuffle features and labels together so each sample keeps its label;
# the fixed random_state makes the ordering repeatable.
imagesXtrain, labelsytrain = shuffle(imagesXtrain, labelsytrain, random_state=46)
imagesXtest, labelsytest = shuffle(imagesXtest, labelsytest, random_state=46)



## Initiate Classifier 

In [None]:
# Random forest with default settings. No random_state is given, so
# NOTE(review): results will presumably differ from run to run -- confirm
# whether a fixed seed is wanted here.
clf = RandomForestClassifier()

## Training the Model

In [None]:
# Break down the class balance of the training and validation samples.
labelsytrain = np.array(labelsytrain)
labelsytest = np.array(labelsytest)
# The mean of the boolean mask is the poor-coverage fraction; scale it by 100
# so the printed number really is the percentage the label promises.
print("Percent of poor coverage in training sample: {:.2f}%".format(100*np.mean(labelsytrain=='poor_coverage')))
print("Percent of poor coverage in validation sample: {:.2f}%".format(100*np.mean(labelsytest=='poor_coverage')))

In [None]:
# Train on the per-image feature vectors and their coverage labels.
clf.fit(imagesXtrain, labelsytrain)

## Test the results 

In [None]:
# Predict labels for the held-out validation set and report the accuracy.
preds = clf.predict(imagesXtest)
print("Accuracy: ", accuracy_score(labelsytest,preds))

In [None]:
# Importance of each feature, with the spread across the individual trees
# drawn as error bars.
feature_names = [func.__name__ for func in metadata_funcs]
importances = clf.feature_importances_
per_tree = np.array([tree.feature_importances_ for tree in clf.estimators_])
std = per_tree.std(axis=0)

fig, ax = plt.subplots(figsize=(20,5))
ax.bar(feature_names, importances, yerr=std)
ax.set_xticklabels(feature_names, fontsize=8)
plt.show()
print(importances.sum())

## Test individual images

In [None]:
# Read the pixel data of the first HDU; the context manager closes the HDU
# list instead of leaving the file open.
with fits.open("/srv/two/zhutchen/rosat_xray_stacker/g3rassimages/eco/RASS-Int_Broad_grp10300_ECO06627.fits") as hdul:
    image = hdul[0].data
# A single row of features, reshaped to 2-D for the classifier calls below.
image_metadata = np.array([fx(image) for fx in metadata_funcs]).reshape(1,-1)

plt.figure()
plt.imshow(image)
plt.show()

print(clf.predict_proba(image_metadata))
print(clf.predict(image_metadata))

In [None]:
# Read the pixel data of the first HDU; the context manager closes the HDU
# list instead of leaving the file open.
with fits.open("/srv/scratch/zhutchen/khess_images/poor_coverage/RASS-Int_Hard_grp11771.0_.fits") as hdul:
    image = hdul[0].data
# A single row of features, reshaped to 2-D for the classifier calls below.
image_metadata = np.array([fx(image) for fx in metadata_funcs]).reshape(1,-1)

plt.figure(figsize=(7,7))
plt.imshow(image)
plt.show()

print(clf.predict_proba(image_metadata))
print(clf.predict(image_metadata))

In [None]:
# Read the pixel data of the first HDU; the context manager closes the HDU
# list instead of leaving the file open.
with fits.open("/srv/two/zhutchen/rosat_xray_stacker/g3rassimages/eco/RASS-Int_Soft_grp10003_ECO05407.fits") as hdul:
    image = hdul[0].data
# A single row of features, reshaped to 2-D for the classifier calls below.
image_metadata = np.array([fx(image) for fx in metadata_funcs]).reshape(1,-1)

plt.figure(figsize=(7,7))
plt.imshow(image)
plt.show()

print(clf.predict_proba(image_metadata))
print(clf.predict(image_metadata))

## Test on all the original poor coverage images

In [None]:
origtestimages=[]
origtestlabels=[]

origpcpath = "/srv/scratch/zhutchen/khess_images/poor_coverage/"
for f in os.listdir(origpcpath):
    if f.endswith('.fits'):
        # Open the file from origpcpath: the names being iterated come from
        # that directory, not from the augmented-image directory.
        with fits.open(origpcpath+f) as hdul:
            image = hdul[0].data
            origtestimages.append(np.array([fx(image) for fx in metadata_funcs]))
        origtestlabels.append('poor_coverage')

norig = len(origtestimages)
print(norig)

# Add up to `norig` good-coverage samples taken from the portion that was
# held out of training (everything after the `trsplit` cut).
good_split = int(trsplit*len(imagesXgood))
origtestimages.extend(imagesXgood[good_split:][0:norig])
origtestlabels.extend(labelsygood[good_split:][0:norig])

In [None]:
# Total number of samples in this test set (shown as the cell output).
len(origtestimages)

In [None]:
# Predict labels for the original-image test set and report the accuracy.
preds = clf.predict(origtestimages)
print("Accuracy: ", accuracy_score(origtestlabels,preds))
