# Order of patterns

The evolution of patterns is investigated by comparing the most common labels of consecutive timesteps (Terra -> Aqua on the same day) or of consecutive days (Terra (Day X) -> Terra (Day X+1), and likewise for Aqua). The analysis is done pixel by pixel: for every pixel of the image, the most common label at one timestep is compared with the most common label at the next.

In [None]:
import csv
import matplotlib.pyplot as plt
import xarray as xr
import pandas as pd

import numpy as np
import seaborn as sns
import datetime as datetime
from netCDF4 import num2date, date2num
import zooniverse as zn

from PIL import Image
import matplotlib.patches as patches
import matplotlib.image as mpimage

import numba as nb

%load_ext autoreload
%autoreload 2

In [None]:
# Paths to the raw Zooniverse exports read by this notebook: the
# classification table and the subject table.
# NOTE(review): the numeric suffixes look like export dates (yy_mm_dd) -- confirm.
classfile_in = '../zooniverse_raw/sugar-flower-fish-or-gravel-classifications_18_11_30.csv'
subject_in = '../zooniverse_raw/sugar-flower-fish-or-gravel-subjects_18_11_05.csv'

In [None]:
# Read the subject export through the project-local `zooniverse` helper.
# NOTE(review): the helper is named load_classifications but is applied to the
# subjects file here; presumably it is a generic CSV/JSON loader -- confirm.
subject_data = zn.load_classifications(subject_in)

In [None]:
def retrieve_fn(x):
    """Return the ``"fn"`` entry of one subject's metadata, or ``None``.

    Parameters
    ----------
    x : mapping
        Metadata record of a subject (presumably the parsed JSON dict of the
        Zooniverse subject export -- TODO confirm).

    Returns
    -------
    object or None
        ``x["fn"]`` if the lookup succeeds; ``None`` if the key is missing or
        ``x`` cannot be indexed with a string (e.g. ``None`` or ``NaN``).
    """
    try:
        return x["fn"]
    except (KeyError, TypeError, IndexError):
        # Only lookup failures mean "no filename"; a bare ``except`` would
        # also swallow KeyboardInterrupt and SystemExit.
        return None

def retrieve_date(x):
    """Extract the 8-digit date stamp from a subject filename.

    The date is read from the basename of ``x``: the second
    underscore-separated token, last eight characters, converted to ``int``.

    Parameters
    ----------
    x : str
        Filename or path of the subject image.

    Returns
    -------
    int or None
        The date stamp as an integer, or ``None`` when ``x`` is not a string
        or does not follow the expected naming pattern.
    """
    try:
        # Built-in ``int`` replaces ``np.int`` (a plain alias that was removed
        # in NumPy 1.24); with the old bare ``except`` the resulting
        # AttributeError silently turned every date into ``None``.
        return int(x.split('/')[-1].split('_')[1][-8:])
    except (AttributeError, IndexError, ValueError, TypeError):
        return None

def retrieve_region(x):
    """Return the index of the first known region tag contained in ``x``.

    Parameters
    ----------
    x : str
        Filename or path of the subject image.

    Returns
    -------
    int or None
        ``0`` for ``"Region1_DJF"``, ``1`` for ``"Region1_MAM"``; ``None`` if
        no tag is found or ``x`` does not support containment tests
        (e.g. ``None`` or ``NaN``).
    """
    # Region2_DJF, Region3_DJF and Region3_SON are deliberately left out here.
    regions = ["Region1_DJF", "Region1_MAM"]
    try:
        for r, region_str in enumerate(regions):
            if region_str in x:
                return r
    except TypeError:
        # Raised by ``in`` for missing filenames; anything else should surface.
        return None
    return None

def retrieve_satellite(x):
    """Return the index of the first satellite name contained in ``x``.

    Parameters
    ----------
    x : str
        Filename or path of the subject image.

    Returns
    -------
    int or None
        ``0`` for ``"Aqua"``, ``1`` for ``"Terra"``; ``None`` if neither name
        occurs or ``x`` does not support containment tests (e.g. ``None``).
    """
    sats = ["Aqua", "Terra"]
    try:
        for s, sat_str in enumerate(sats):
            if sat_str in x:
                return s
    except TypeError:
        # Raised by ``in`` for missing filenames; anything else should surface.
        return None
    return None

In [None]:
# Filename of each subject, taken from its metadata record by retrieve_fn.
subject_data["filename"] = subject_data["metadata"].apply(retrieve_fn)

In [None]:
# Integer date stamp parsed from each filename by retrieve_date.
subject_data["date"] = subject_data["filename"].apply(retrieve_date)

In [None]:
# Region index from retrieve_region (0 = Region1_DJF, 1 = Region1_MAM).
subject_data["region"] = subject_data["filename"].apply(retrieve_region)

In [None]:
# Satellite index from retrieve_satellite (0 = Aqua, 1 = Terra).
subject_data["satellite"] = subject_data["filename"].apply(retrieve_satellite)

In [None]:
# Calculate pairs of (subject_id(t), subject_id(t+1))
# for every region, satellite
def _next_calendar_day(date):
    """Return the day after ``date`` (a YYYYMMDD integer) as a YYYYMMDD integer.

    Returns ``None`` when ``date`` is missing (``None``/``NaN``) or is not a
    valid calendar date in YYYYMMDD form.
    """
    from datetime import datetime as _datetime, timedelta as _timedelta

    try:
        day = _datetime.strptime(str(int(date)), "%Y%m%d")
    except (TypeError, ValueError):
        return None
    return int((day + _timedelta(days=1)).strftime("%Y%m%d"))


day_pairs = []
# 'days':      same satellite and region, one calendar day apart
# 'satellite': same day and region, Terra followed by Aqua
pair_type = 'days'
for i, subject in subject_data.iterrows():
    region = subject.region
    satellite = subject.satellite
    date = subject.date

    if pair_type == "days":
        # Real calendar arithmetic: ``date + 1`` on a YYYYMMDD integer never
        # matches across month or year ends (20180131 + 1 == 20180132).
        # Assumes the date stamp is YYYYMMDD; rows with an unparsable date
        # are skipped.
        next_day = _next_calendar_day(date)
        if next_day is None:
            continue
        match = ((subject_data.date == next_day)
                 & (subject_data.region == region)
                 & (subject_data.satellite == satellite))
    elif pair_type == "satellite" and satellite == 1:  # Terra overpass is first
        match = ((subject_data.date == date)
                 & (subject_data.region == region)
                 & (subject_data.satellite == 0))
    else:
        continue

    ind = np.flatnonzero(np.asarray(match))
    # If several subjects match, the first one is used (as before).
    if ind.size >= 1:
        # Built-in int replaces np.int, which was removed in NumPy 1.24.
        day_pairs.append([subject.subject_id,
                          int(subject_data.iloc[ind[0]]["subject_id"])])

In [None]:
# Number of subject pairs found above.
len(day_pairs)

In [None]:
# Parse the raw classification export with the project-local helper; the
# listed columns are handed over as `json_columns`.
classification_data = zn.parse_classifications(
    classfile_in,
    json_columns=['metadata', 'annotations', 'subject_data'],
)

# Keep only the classifications that belong to workflow 8073.
fulldataset = classification_data[classification_data.workflow_id == 8073]
fulldataset.head()

In [None]:
def extract_labels(annotations):
    """Flatten annotation records into one list of labelled boxes.

    Every element of ``annotations`` must provide a ``'value'`` entry holding
    an iterable of box mappings with the keys ``'x'``, ``'y'``, ``'width'``,
    ``'height'`` and ``'tool_label'``.

    Returns a list of ``[x, y, width, height, tool_label]`` lists, in the
    order the boxes are encountered.
    """
    fields = ('x', 'y', 'width', 'height', 'tool_label')
    return [
        [drawn[field] for field in fields]
        for record in annotations
        for drawn in record['value']
    ]

In [None]:
# Exploratory check: collect the labelled boxes for the first pair only.
for subject_id1, subject_id2 in day_pairs[:1]:
    print(subject_id1, subject_id2)

    first_mask = fulldataset.subject_ids == subject_id1
    classifications_1 = fulldataset[first_mask]
    second_mask = fulldataset.subject_ids == subject_id2
    classifications_2 = fulldataset[second_mask]

    boxes_1 = extract_labels(classifications_1.annotations.values)
    boxes_2 = extract_labels(classifications_2.annotations.values)

In [None]:
# Inspect the boxes extracted for the second subject of the first pair.
boxes_2

In [None]:
# Integer code per pattern label. Code 0 is not assigned here; elsewhere in
# this notebook it stands for 'None' (no label).
pattern_dic = {'Sugar': 1, 'Flower': 2, 'Fish': 3, 'Gravel': 4}

In [None]:
# Zero-initialised counter for every ordered combination
# pattern(day1) -> pattern(day2), keyed as "<before>-><after>".
from itertools import product

patterns = ['0', '1', '2', '3', '4']
freq_dic = {
    f"{before}->{after}": 0
    for before, after in product(patterns, repeat=2)
}
print(freq_dic)

In [None]:
def calc_freq(common_boxes_1, common_boxes_2, output):
    """Add the pixel-wise label transitions of one image pair to ``output``.

    For every pixel, ``output[before, after]`` is increased by
    ``1 / number_of_pixels``, so one call adds exactly 1 in total.

    Parameters
    ----------
    common_boxes_1, common_boxes_2 : array_like of int
        Label grids of the earlier and the later image; they must have the
        same number of elements.
    output : numpy.ndarray
        2-D float table indexed as ``[before, after]``; modified in place.

    Raises
    ------
    ValueError
        If the two grids differ in size.
    IndexError
        If a label does not fit into ``output``.

    Notes
    -----
    The normalisation uses the actual pixel count instead of a hard-coded
    2100*1400 and the counting is vectorised, so no JIT compilation is needed.
    """
    before = np.asarray(common_boxes_1).ravel()
    after = np.asarray(common_boxes_2).ravel()
    if before.size != after.size:
        raise ValueError("both label grids must have the same number of pixels")
    if before.size == 0:
        return

    n_rows, n_cols = output.shape
    # Validate first: the flat index used below would otherwise alias
    # out-of-range labels onto other cells.
    if (before.min() < 0 or after.min() < 0
            or before.max() >= n_rows or after.max() >= n_cols):
        raise IndexError("label outside the range covered by the output table")

    counts = np.bincount(before * n_cols + after, minlength=n_rows * n_cols)
    output += counts.reshape(n_rows, n_cols) / before.size

In [None]:
def most_common_boxes(boxes, visualize=False, shape=(2100, 1400)):
    """Combine all labelled boxes of one image into a grid of most common labels.

    Parameters
    ----------
    boxes : iterable
        Items of the form ``[x, y, width, height, label]`` with ``label`` one
        of ``'Sugar'``, ``'Flower'``, ``'Fish'``, ``'Gravel'``.
    visualize : bool, optional
        If true, plot the per-pattern counts and the resulting grid.
    shape : tuple of int, optional
        Grid size ``(nx, ny)`` in pixels; defaults to the previously
        hard-coded 2100 x 1400.

    Returns
    -------
    numpy.ndarray
        Integer array of shape ``shape``: 0 where no box covers the pixel,
        otherwise the code of the most frequent pattern (1 Sugar, 2 Flower,
        3 Fish, 4 Gravel). Ties go to the lowest code.

    Raises
    ------
    KeyError
        If a box carries an unknown label.

    Notes
    -----
    Not JIT-compiled: the body uses a string-keyed dict, mixed-type lists and
    plotting calls, which Numba cannot compile in nopython mode.
    """
    pattern_dic = {'Sugar': 1, 'Flower': 2, 'Fish': 3, 'Gravel': 4}
    nx, ny = shape

    # One count layer per pattern; layer 0 stays empty, so argmax yields 0
    # ("no label") wherever nothing was drawn.
    grid = np.zeros((nx, ny, len(pattern_dic) + 1), dtype="int")
    for box in boxes:
        x, y, width, height = np.round(box[0:4], 0).astype(int)
        # Clip both corners to the domain. A negative start would otherwise
        # be read as "count from the end" and the box would be lost or
        # painted in the wrong place.
        x0, x1 = np.clip([x, x + width], 0, nx)
        y0, y1 = np.clip([y, y + height], 0, ny)
        grid[x0:x1, y0:y1, pattern_dic[box[4]]] += 1

    if visualize:
        visualize_grid(grid)
    common_box = np.argmax(grid, axis=2)
    if visualize:
        visualize_common_box(common_box)
    return common_box

In [None]:
def visualize_grid(grid):
    """Plot the count layer of each pattern side by side.

    Parameters
    ----------
    grid : numpy.ndarray
        3-D array whose last axis is indexed by pattern code; layers 1-4 are
        shown as Sugar, Flower, Fish and Gravel.
    """
    pattern_dic_inv = {1: 'Sugar', 2: 'Flower', 3: 'Fish', 4: 'Gravel'}
    fig, axes = plt.subplots(1, len(pattern_dic_inv))
    # Iterate the local map so the function does not depend on the
    # notebook-global ``pattern_dic`` being defined and in sync.
    for axis, (layer, name) in zip(axes, pattern_dic_inv.items()):
        axis.set_title(name)
        # Transposed with origin="lower", as in the other plots of this file.
        axis.imshow(grid[:, :, layer].T, origin="lower", cmap="Blues", alpha=0.8)
        axis.axis('off')
    plt.show()

In [None]:
def visualize_common_box(common_box):
    """Show the grid of most common labels as a single image titled 'Common'."""
    figure, axis = plt.subplots(1, 1)
    image = common_box[:, :].T
    axis.imshow(image, origin="lower", cmap="flag", alpha=0.8)
    axis.set_title('Common')
    axis.axis('off')
    plt.show()

In [None]:
def calc_mutation_freq(subject_id_pair, dataset, visualize=False, output=None):
    """Accumulate the label transitions between the two subjects of a pair.

    Parameters
    ----------
    subject_id_pair : sequence
        ``(subject_id1, subject_id2)``: the earlier and the later subject.
    dataset : pandas.DataFrame
        Classifications providing the columns ``subject_ids`` and
        ``annotations``. Previously this argument was ignored in favour of
        the global ``fulldataset``.
    visualize : bool, optional
        Forwarded to ``most_common_boxes``.
    output : numpy.ndarray, optional
        Transition table updated in place; defaults to the notebook-global
        ``freq_arr`` for backward compatibility.

    Returns
    -------
    None
    """
    subject_id1, subject_id2 = subject_id_pair

    common_boxes = []
    for subject_id in (subject_id1, subject_id2):
        classifications = dataset[dataset.subject_ids == subject_id]
        boxes = extract_labels(classifications.annotations.values)
        # Reduce all boxes of this subject to one most-common-label grid.
        common_boxes.append(most_common_boxes(boxes, visualize))

    if output is None:
        output = freq_arr
    calc_freq(common_boxes[0], common_boxes[1], output=output)
    return None

In [None]:
# Start from an empty 5x5 transition table (rows: label before, columns:
# label after) and accumulate every pair; progress is printed every 100 pairs.
freq_arr = np.zeros((5, 5))
for pair, subject_id_pair in enumerate(day_pairs):
    calc_mutation_freq(subject_id_pair, fulldataset, visualize=False)
    if pair % 100 == 0:
        print(pair)

In [None]:
# Mean transition table over all pairs, in percent; the total is printed as
# a sanity check.
pairs_evaluated = len(day_pairs)
freq_percent = freq_arr / pairs_evaluated * 100
print(np.round(freq_percent, 0))
print(np.sum(freq_percent))

In [None]:
# Readable view of the transition table: "<before>-><after>" -> percent,
# rounded to one decimal.
pattern_dic_inv = {0: 'None', 1: 'Sugar', 2: 'Flower', 3: 'Fish', 4: 'Gravel'}
dic = {
    pattern_dic_inv[i] + '->' + pattern_dic_inv[j]:
        np.round(freq_arr[i, j] / pairs_evaluated * 100, 1)
    for i in range(5)
    for j in range(5)
}
dic

In [None]:
# Heatmap of all transitions (including 'None'), in percent of all pixels.
sns.set_context('talk')
tick_labels = pattern_dic_inv.values()
ax = sns.heatmap(
    np.round(freq_arr / pairs_evaluated * 100, 1),
    annot=True,
    vmin=0,
    vmax=3,
    cmap='BuGn',
    cbar=False,
    xticklabels=tick_labels,
    yticklabels=tick_labels,
)
ax.xaxis.set_ticks_position('top')
plt.gcf().set_dpi(300)

In [None]:
# Heatmap restricted to pixels labelled at both times, renormalised so the
# shown cells sum to 100 %.
sns.set_context('talk')
percentage_all = freq_arr / pairs_evaluated * 100
labelled_only = percentage_all[1:, 1:]
percentage_not_nones = labelled_only / np.sum(labelled_only) * 100
pattern_names = list(pattern_dic_inv.values())[1:]
ax = sns.heatmap(
    np.round(percentage_not_nones, 1),
    annot=True,
    vmin=0,
    vmax=20,
    cmap='BuGn',
    cbar=False,
    xticklabels=pattern_names,
    yticklabels=pattern_names,
)
ax.xaxis.set_ticks_position('top')
plt.gcf().set_dpi(300)

In [None]:
# Heatmap of the raw accumulated table without the 'None' row and column.
# NOTE(review): despite the variable names, no percentage conversion is
# applied here -- confirm that vmax=20 is intended for these values.
sns.set_context('talk')
percentage_all = freq_arr
percentage_not_nones = percentage_all[1:, 1:]
pattern_names = list(pattern_dic_inv.values())[1:]
ax = sns.heatmap(
    np.round(percentage_not_nones, 1),
    annot=True,
    vmin=0,
    vmax=20,
    cmap='BuGn',
    cbar=False,
    xticklabels=pattern_names,
    yticklabels=pattern_names,
)
ax.xaxis.set_ticks_position('top')
plt.gcf().set_dpi(300)

In [None]:
# Display the raw accumulated transition table.
freq_arr