In [1]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from PIL import Image

In [15]:
def plot_example_Hough_line(data, image_number, canny_low=50, canny_high=150, hough_threshold=1):
    """Overlay the first Hough line detected in one image of ``data`` and display it.

    The image is written to ``img.png`` with ``plt.imsave`` (which applies the
    default colormap) and read back with ``cv2.imread``. Edge detection therefore
    runs on the colour-mapped BGR rendering, not on the raw array values.

    Parameters
    ----------
    data : indexable collection of images
        ``data[image_number]`` must be accepted by ``plt.imsave``.
    image_number : int
        Index of the image to plot.
    canny_low, canny_high : int, optional
        Hysteresis thresholds passed to ``cv2.Canny`` (defaults 50 and 150).
    hough_threshold : int, optional
        Accumulator threshold passed to ``cv2.HoughLines`` (default 1).

    Returns
    -------
    None
        If no line is detected, the image is shown without an overlay
        instead of raising.
    """
    plt.imsave("img.png", data[image_number])
    image = cv2.imread("img.png")
    edges = cv2.Canny(image, canny_low, canny_high)
    lines = cv2.HoughLines(edges, 1, np.pi / 180, threshold=hough_threshold)

    if lines is None:
        # cv2.HoughLines returns None (not an empty array) when no line
        # reaches the threshold; indexing it would raise a TypeError.
        title = "No Hough line detected"
    else:
        rho, theta = lines[0][0]
        a = np.cos(theta)
        b = np.sin(theta)
        # (x0, y0) is the point of the line closest to the origin; (-b, a) is
        # the line's direction. Extend at least the image diagonal each way
        # so the drawn segment crosses the whole image, whatever its size.
        height, width = image.shape[:2]
        reach = max(1000, int(np.ceil(np.hypot(height, width))))
        x0 = a * rho
        y0 = b * rho
        start = (int(x0 - reach * b), int(y0 + reach * a))
        end = (int(x0 + reach * b), int(y0 - reach * a))
        cv2.line(image, start, end, (0, 0, 255), 1)
        title = f"Hough line: rho = {rho:.1f}, theta = {theta * 180 / np.pi:.1f} deg"

    plt.figure(figsize=(8, 6))
    plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    plt.title(title)
    plt.show()

Define functions for creating an oriented dataset: each image is rotated according to the first line that the Hough transform detects in it.

In [None]:
def rotate_array(array, angle):
    """Rotate an image array by ``angle`` degrees with Pillow and return it as a NumPy array.

    ``Image.rotate`` is called with its defaults:
    - a positive angle rotates counter-clockwise;
    - the output keeps the input's size, so corners rotated out of frame are cut off;
    - nearest-neighbour resampling is used.
    """
    pil_image = Image.fromarray(array)
    return np.asarray(pil_image.rotate(angle))

In [7]:
def create_aligned_dataset(data, canny_low=50, canny_high=150, hough_threshold=1):
    """Rotate each image according to the first line ``cv2.HoughLines`` finds in it.

    For every image in ``data``:
    1. Round-trip it through ``img.png`` (``plt.imsave`` then ``cv2.imread``),
       so detection runs on the colour-mapped rendering.
    2. Edge-detect it with ``cv2.Canny``.
    3. Pass the edges to ``cv2.HoughLines``.
    4. Rotate the original array with ``rotate_array`` by the line's normal
       angle ``theta``, expressed in degrees in the range (-90, 90].

    In OpenCV's (rho, theta) parametrisation, theta = 0 is a vertical line, so
    the intent appears to be to make the detected line vertical. Images in
    which no line is detected are copied unchanged.

    Parameters
    ----------
    data : numpy.ndarray
        Stack of images; ``data.shape`` is also the output shape.
    canny_low, canny_high : int, optional
        Hysteresis thresholds passed to ``cv2.Canny`` (defaults 50 and 150).
    hough_threshold : int, optional
        Accumulator threshold passed to ``cv2.HoughLines`` (default 1).

    Returns
    -------
    numpy.ndarray
        Float64 array (``np.empty`` default dtype) with the same shape as ``data``.
    """
    oriented_data = np.empty(data.shape)
    for i, image_array in enumerate(data):
        plt.imsave("img.png", image_array)
        image = cv2.imread("img.png")
        edges = cv2.Canny(image, canny_low, canny_high)
        lines = cv2.HoughLines(edges, 1, np.pi / 180, threshold=hough_threshold)

        if lines is None:
            # No line detected: keep the image as it is.
            oriented_data[i] = image_array
            continue

        _, theta = lines[0][0]
        theta_degrees = theta * 180 / np.pi
        # theta lies in [0, 180); map it to (-90, 90] so the image is turned
        # through the smaller of the two equivalent rotations.
        angle = theta_degrees - 180 if theta_degrees > 90 else theta_degrees
        oriented_data[i] = rotate_array(image_array, angle)

    return oriented_data

Define functions to clean up the oriented dataset, e.g. by finding images in which the track has been rotated out of frame and removing them.

In [8]:
def plot_example_images(data, n_rows, n_cols):
    """Show the first ``n_rows * n_cols`` images of ``data`` in a grid with one shared colorbar.

    All panels use the viridis colormap with a common colour range taken from
    the displayed images, so the single colorbar is a valid key for every panel.
    If ``data`` has fewer images than grid cells, the remaining cells are left
    blank.
    """
    n_shown = min(n_rows * n_cols, len(data))
    # squeeze=False keeps `axes` a 2-D array even for 1x1 or 1xN grids.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(30, 6), squeeze=False)
    cmap = matplotlib.colormaps['viridis']

    shown = [data[i] for i in range(n_shown)]
    # A shared vmin/vmax is required for one colorbar to describe all panels;
    # otherwise each imshow autoscales to its own range.
    vmin = min(np.nanmin(img) for img in shown) if shown else None
    vmax = max(np.nanmax(img) for img in shown) if shown else None

    for ax in axes.flat:
        ax.set_axis_off()

    mappable = None
    for ax, img in zip(axes.flat, shown):
        mappable = ax.imshow(img, cmap=cmap, vmin=vmin, vmax=vmax)

    fig.suptitle("Example images")
    if mappable is not None:
        # Use a real image as the mappable so the colorbar reflects the data range.
        fig.colorbar(mappable, ax=axes.ravel().tolist())
    plt.show()

In [9]:
def mean_pixel_intensity_histogram(data, bins=300, cutoff=None):
    """Plot a histogram of the mean pixel intensity of each image in ``data``.

    Intended to help choose the ``cutoff`` for ``filter_dataset``, which keeps
    images whose mean intensity is strictly greater than the cutoff.

    Parameters
    ----------
    data : iterable of arrays
        Images; each may have any shape, since the mean is taken per image.
    bins : int, optional
        Number of histogram bins (default 300).
    cutoff : float, optional
        If given, a dashed vertical line is drawn at this intensity.

    Returns
    -------
    None
    """
    means = [np.mean(array) for array in data]

    fig, ax = plt.subplots()
    ax.hist(means, bins=bins)
    if cutoff is not None:
        ax.axvline(cutoff, color="red", linestyle="--", label=f"cutoff = {cutoff}")
        ax.legend()
    ax.set_xlabel("Mean pixel intensity")
    ax.set_ylabel("Frequency")
    ax.set_title("Distribution of mean pixel intensities")
    plt.show()

    return None

In [12]:
def filter_dataset(data, names, cutoff):
    """Keep only the images whose mean pixel intensity is strictly above ``cutoff``.

    Parameters
    ----------
    data : array_like, shape (N, ...)
        Stack of N images. The mean is taken over every axis except the first,
        so both (N, H, W) and (N, H, W, C) stacks are supported.
    names : array_like, length N
        Labels aligned with ``data``. A plain list or tuple is converted to an
        ndarray so it can be boolean-masked.
    cutoff : float
        Images with mean intensity ``<= cutoff`` are dropped.

    Returns
    -------
    filtered_array : ndarray
        The entries of ``data`` that passed the filter.
    filtered_names
        The corresponding entries of ``names``.
    """
    data = np.asanyarray(data)
    if isinstance(names, (list, tuple)):
        # Lists cannot be indexed with a boolean mask.
        names = np.asarray(names)

    # For (N, H, W) input this is exactly axis=(1, 2); for colour stacks it
    # also averages over the channel axis so the mask stays 1-D.
    mean_array = np.mean(data, axis=tuple(range(1, data.ndim)))
    mask = mean_array > cutoff
    filtered_array = data[mask]
    filtered_names = names[mask]

    return filtered_array, filtered_names
