In [None]:
import operator
import pickle 
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt 
import omicronscala
import spym
import xarray
import os
from pathlib import Path
from tqdm import tqdm

In [None]:
def load_stm():
    """Return the cleaned STM metadata object unpickled from ``clean_meta.pkl``.

    The file is opened relative to the current working directory.
    Only use with trusted files: unpickling can execute arbitrary code.
    """
    meta_file = 'clean_meta.pkl'
    with open(meta_file, 'rb') as handle:
        stm_meta = pickle.load(handle)
    return stm_meta

def get_columns(df, metadata, treshold=100):
    """Count images per combination of metadata values.

    Parameters
    ----------
    df : pandas.DataFrame
        Image metadata table.
    metadata : list of str
        Column names to group by (pass a list even for a single column).
        The list is NOT modified.
    treshold : int
        Minimum number of images a group needs to be kept.

    Returns
    -------
    pandas.DataFrame
        Columns ``['N imgs', *metadata]``, sorted by ``'N imgs'`` descending.
    """
    counts = df.groupby(metadata).size().reset_index(name='N imgs')
    counts = counts[counts['N imgs'] >= treshold]
    # Build a fresh list: inserting into ``metadata`` would mutate the caller's
    # list, and a second call with the same list would then group by 'N imgs'.
    cols = ['N imgs'] + list(metadata)
    return counts[cols].sort_values('N imgs', ascending=False)

def df_images(df, cols, values):
    """Return the rows of ``df`` that match two column values.

    ``df[cols[1]]`` is compared with ``values[0]`` and ``df[cols[2]]`` with
    ``values[1]``. ``cols[0]`` is ignored (presumably the ``'N imgs'`` column
    produced by ``get_columns`` — not checked here).
    """
    first_col, second_col = cols[1], cols[2]
    first_val, second_val = values[0], values[1]
    selector = df[first_col].eq(first_val) & df[second_col].eq(second_val)
    return df.loc[selector]
           
def get_imgs_df_by_size(df, category, size):
    """Select the images of ``category`` that fall in a field-size class.

    Classes on ``FieldXSizeinnm`` (as suggested by Mirco):
    ``'S'``: <= 20, ``'M'``: 20 < x <= 70, ``'L'``: > 70.
    Any other ``size`` yields an empty DataFrame.
    """
    size_rules = (
        ('S', lambda width: width <= 20.0),
        ('M', lambda width: (width <= 70.0) & (width > 20.0)),
        ('L', lambda width: width > 70.0),
    )
    for label, in_class in size_rules:
        if size == label:
            same_category = df['Categories'] == category
            selected = df[same_category & in_class(df['FieldXSizeinnm'])]
            return selected[df.columns]
    return pd.DataFrame()

def get_img(imgs_path, img):
    """Load and level the forward topography (``Z_Forward``) of one image.

    Parameters
    ----------
    imgs_path : str
        Directory prefix. It is joined to the file name by plain string
        concatenation, so it must already end with a separator if one is needed.
    img : pandas.Series
        Metadata row; ``img['ImageOriginalName']`` is the file name.

    Returns
    -------
    list
        ``[img, tf]`` where ``tf`` is ``Z_Forward`` after the spym steps
        ``align``, ``plane`` and ``fixzero(to_mean=True)``.
    """
    source = Path(imgs_path + img['ImageOriginalName'])
    topo = omicronscala.to_dataset(source).Z_Forward
    # The spym calls' return values are not used, so they are relied on to
    # modify ``topo`` in place; order matters.
    topo.spym.align()
    topo.spym.plane()
    topo.spym.fixzero(to_mean=True)
    return [img, topo]

def plot_images(df, category, size, rows=5, cols=3):
    """Save a grid of randomly sampled images of one category and size class.

    Up to ``rows * cols`` images are sampled from
    ``get_imgs_df_by_size(df, category, size)``, loaded with ``get_img`` from the
    hard-coded STM directory and drawn with the ``afmhot`` colormap. Only complete
    grid rows are drawn. The figure is saved as
    ``categories/{category}_{size}_{rows}x{cols}.png`` using the grid actually drawn.

    Returns None. Prints a message and saves nothing when there is nothing to plot.
    """
    plt.ioff()
    imgs = get_imgs_df_by_size(df, category, size)
    if len(imgs) == 0:
        return print('{} with size {} is empty.'.format(category, size))
    try:
        samples = imgs.sample(rows*cols)
    except ValueError as e:
        # Fewer images than grid cells: use all of them.
        print(e)
        samples = imgs

    images = []
    imgs_path = r'G:\STM'  # raw string: '\S' is an invalid escape sequence
    for _, image in samples.iterrows():
        try:
            images.append(get_img(imgs_path, image))
        except Exception as e:
            # Best effort: report unreadable files and keep going.
            print(e)
            print(image['ImageOriginalName'])

    if not images:
        return print('No image of {} with size {} could be loaded.'.format(category, size))

    if cols > len(images):
        cols = len(images)
        rows = 1
    else:
        # Only complete rows are drawn; leftover images are skipped.
        rows = len(images) // cols

    # squeeze=False keeps axs 2-D even for a single row, column or cell.
    fig, axs = plt.subplots(rows, cols, figsize=(2+(8*cols), (8*rows)), squeeze=False)
    fig.suptitle('[{}] {}'.format(size, category), weight='bold', fontsize=30)
    for ax, (meta, topo) in zip(axs.flat, images):
        topo.plot(ax=ax, cmap='afmhot', add_colorbar=False)
        ax.set_title('[{}] {}'.format(meta['Date'], meta['TF0_Filename']), weight='bold', fontsize=20)
        for item in ([ax.xaxis.label, ax.yaxis.label] +
                     ax.get_xticklabels() + ax.get_yticklabels()):
            item.set_fontsize(16)
    plt.tight_layout(rect=[0, 0.03, 1, 0.92])
    plt.draw()

    Path('categories').mkdir(parents=True, exist_ok=True)
    plt.savefig('categories/{}_{}_{}x{}.png'.format(category, size, rows, cols), dpi=100)
    plt.close(fig)
    
def images_overview(df, size, cols=3):
    """Save an overview of one size class with one row per single category.

    Categories are the non-empty ``df['Categories']`` values that contain no
    comma. For each category, up to ``cols`` random images of class ``size``
    (see ``get_imgs_df_by_size``) are plotted. Cells without an image are blanked.
    The figure is saved as ``overviewSize/overview_{size}{cols}.png``.
    """
    plt.ioff()
    imgs_path = r'G:\STM'  # raw string: '\S' is an invalid escape sequence
    categories = []
    images = []
    # Iterate the ``df`` argument rather than a notebook-global DataFrame.
    for category in df['Categories'].unique():
        if len(category.split(',')) != 1 or category == '':
            continue
        categories.append(category)
        imgs = get_imgs_df_by_size(df, category, size)
        if len(imgs) == 0:
            # Keep an empty row so rows stay aligned with ``categories``.
            print('{} with size {} is empty.'.format(category, size))
            images.append([])
            continue
        try:
            samples = imgs.sample(cols)
        except ValueError as e:
            # Fewer images than columns: use all of them.
            print(e)
            samples = imgs

        tmp_imgs = []
        for _, image in samples.iterrows():
            try:
                tmp_imgs.append(get_img(imgs_path, image))
            except Exception as e:
                print(e)
                print(image['ImageOriginalName'])
        images.append(tmp_imgs)

    if not categories:
        return print('No single-category images found.')

    rows = len(categories)
    # squeeze=False keeps axs 2-D even when there is a single category.
    fig, axs = plt.subplots(rows, cols, figsize=(2+(8*cols), (8*rows)), squeeze=False)
    fig.suptitle('[{}] {}'.format(size, 'Dataset Overview'), weight='bold', fontsize=30)
    for row_axes, row_images in zip(axs, images):
        for j, ax in enumerate(row_axes):
            if j >= len(row_images):
                ax.axis('off')
                continue
            meta, topo = row_images[j]
            topo.plot(ax=ax, cmap='afmhot', add_colorbar=False)
            ax.set_title('[{}] {}'.format(meta['Date'], meta['Categories']), weight='bold', fontsize=20)
            for item in ([ax.xaxis.label, ax.yaxis.label] +
                         ax.get_xticklabels() + ax.get_yticklabels()):
                item.set_fontsize(16)
    plt.tight_layout(rect=[0, 0.03, 1, 0.97])

    Path('overviewSize').mkdir(parents=True, exist_ok=True)

    plt.savefig('overviewSize/overview_{}{}.png'.format(size, cols), dpi=100)
    plt.close(fig)
    
    
def get_img_id(df, date, name):
    """Return the ``ID`` values of the rows matching ``Date`` and ``TF0_Filename``.

    The result is a Series (possibly empty, possibly several entries), not a scalar.
    """
    is_match = (df['Date'] == date) & (df['TF0_Filename'] == name)
    return df.loc[is_match, 'ID']


def filter_df(df, column, value, operation):
    """Return the rows where ``operation(df[column], value)`` is True.

    ``operation`` is a binary callable (e.g. ``operator.gt``) that returns a
    boolean mask when applied to a Series and ``value``.
    """
    keep = operation(df[column], value)
    return df.loc[keep]


def aligned_data(df, date, name):
    """Return the leveled ``Z_Forward`` array of one image, rotated for heatmap display.

    The file name is looked up by ``Date`` and ``TF0_Filename``. Exactly one row
    must match, otherwise ``.item()`` raises ValueError. The name is prefixed
    with the module-level ``path`` string (defined elsewhere in the notebook, not
    visible here). Leveling order is plane -> align -> plane ->
    fixzero(to_mean=True). The array is rotated 90 degrees clockwise
    (``np.rot90(..., -1)``).
    """
    row_mask = (df['Date'] == date) & (df['TF0_Filename'] == name)
    file = df.loc[row_mask, 'ImageOriginalName'].item()
    topo = omicronscala.to_dataset(Path(path + file)).Z_Forward
    # In-place spym leveling; the extra initial plane() distinguishes this
    # from final_data().
    topo.spym.plane()
    topo.spym.align()
    topo.spym.plane()
    topo.spym.fixzero(to_mean=True)
    return np.rot90(topo.data, k=-1)

def raw_data(df, date, name):
    """Return the unprocessed ``Z_Forward`` array of one image, rotated 90 degrees clockwise.

    The file is found by ``Date``/``TF0_Filename``; exactly one row must match
    (``.item()``). The name is prefixed with the module-level ``path`` string.
    """
    row_mask = (df['Date'] == date) & (df['TF0_Filename'] == name)
    file = df.loc[row_mask, 'ImageOriginalName'].item()
    raw = omicronscala.to_dataset(Path(path + file)).Z_Forward.data
    return np.rot90(raw, k=-1)

def final_data(df, date, name):
    """Return the leveled ``Z_Forward`` variable of one image (not rotated).

    The file is found by ``Date``/``TF0_Filename``; exactly one row must match
    (``.item()``). The name is prefixed with the module-level ``path`` string.
    Leveling order is align -> plane -> fixzero(to_mean=True).
    """
    row_mask = (df['Date'] == date) & (df['TF0_Filename'] == name)
    file = df.loc[row_mask, 'ImageOriginalName'].item()
    topo = omicronscala.to_dataset(Path(path + file)).Z_Forward
    # spym steps modify ``topo`` in place (their results are not reassigned).
    topo.spym.align()
    topo.spym.plane()
    topo.spym.fixzero(to_mean=True)
    return topo

def compare_data(df, date, name, path=None, fname=None):
    """Save a raw / aligned / final comparison figure for one image.

    Raw and aligned data (rotated arrays from ``raw_data``/``aligned_data``)
    are drawn as line plots, one curve per column. The final image from
    ``final_data`` is drawn as an ``afmhot`` heatmap.

    Parameters
    ----------
    df : pandas.DataFrame
        Metadata used to locate the file by ``Date`` and ``TF0_Filename``.
    date, name : str
        Image identifiers, also used in the title and output file name.
    path : str, optional
        Base directory; the figure is written to ``{path}/compareData/``
        (``./compareData/`` when None).
    fname : str, optional
        File-name prefix: ``{fname}_[{date}] {name}.png``
        (``[{date}] {name}.png`` when None).
    """
    # Load all data first so a failing read does not leave an open figure behind.
    raw = raw_data(df, date, name)
    aligned = aligned_data(df, date, name)
    final = final_data(df, date, name)

    plt.ioff()
    fig, axs = plt.subplots(1, 3, figsize=(18, 8))
    fig.suptitle('[{}] {}'.format(date, name), fontsize=16)
    axs[0].plot(raw)
    axs[0].set_title('raw')
    axs[1].plot(aligned)
    axs[1].set_title('aligned')
    final.plot(ax=axs[2], cmap='afmhot', add_colorbar=False)
    axs[2].set_title('final')
    plt.tight_layout(rect=[0, 0.03, 1, 0.90])
    plt.draw()

    # Directory and prefix are resolved independently so any combination of
    # path/fname writes into the directory that was actually created.
    out_dir = Path('compareData') if path is None else Path('{}/compareData'.format(path))
    out_dir.mkdir(parents=True, exist_ok=True)
    prefix = '' if fname is None else '{}_'.format(fname)
    plt.savefig(out_dir / '{}[{}] {}.png'.format(prefix, date, name), dpi=100)
    plt.close(fig)


def plot_samples(df, name, condition, filename, rows, cols):
    """Save a ``rows`` x ``cols`` grid of random images drawn from ``df``.

    ``name`` and ``condition`` only appear in the title
    (``'Samples with {name} {condition}'``). The figure is saved as
    ``samples/{filename}_{rows}x{cols}.png``. Grid cells without a loaded
    image are blanked.
    """
    plt.ioff()
    try:
        samples = df.sample(rows*cols)
    except ValueError as e:
        # Fewer rows than grid cells: use all of them.
        print(e)
        samples = df

    imgs_path = r'G:\STM'  # raw string: '\S' is an invalid escape sequence
    images = []
    for _, image in samples.iterrows():
        try:
            images.append(get_img(imgs_path, image))
        except Exception as e:
            # Best effort: report unreadable files and keep going.
            print(e)
            print(image['ImageOriginalName'])

    # squeeze=False keeps axs 2-D even for a single row or column.
    fig, axs = plt.subplots(rows, cols, figsize=(2+(8*cols), (8*rows)), squeeze=False)
    fig.suptitle('Samples with {} {}'.format(name, condition), weight='bold', fontsize=30)
    for c, ax in enumerate(axs.flat):
        if c >= len(images):
            ax.axis('off')
            continue
        meta, topo = images[c]
        topo.plot(ax=ax, cmap='afmhot', add_colorbar=False)
        ax.set_title('[{}] {}'.format(meta['Date'], meta['TF0_Filename']), weight='bold', fontsize=20)
        for item in ([ax.xaxis.label, ax.yaxis.label] +
                     ax.get_xticklabels() + ax.get_yticklabels()):
            item.set_fontsize(12)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    Path('samples').mkdir(parents=True, exist_ok=True)
    plt.savefig('samples/{}_{}x{}.png'.format(filename, rows, cols), dpi=100)
    plt.close(fig)

def show_slices_aligned_final(df, date, name, cut):
    """Save aligned (line plots) vs final (heatmap) views of an image, split into bands.

    The image is split into ``size // cut`` bands of ``cut`` rows, where
    ``size`` is the number of rows of the final image; leftover rows are not
    shown. For band ``i``:

    - Left: columns ``i*cut`` .. ``(i+1)*cut - 1`` of the rotated aligned array
      as curves, with a shared x axis per column.
    - Right: the matching rows of the final heatmap.

    The figure is saved as ``slices/AlignedFinal/{cut}/[{date}] {name}.png``.

    Raises ValueError when ``cut`` exceeds the image size.
    """
    plt.ioff()

    final = final_data(df, date, name)
    data = aligned_data(df, date, name)
    size = final.shape[0]
    rows = size // cut
    if rows == 0:
        raise ValueError('cut ({}) is larger than the image size ({})'.format(cut, size))

    # squeeze=False keeps axs 2-D even when there is a single band.
    fig, axs = plt.subplots(rows, 2, figsize=(16, 16), sharex='col', squeeze=False)
    fig.suptitle('[{}] {}'.format(date, name), fontsize=16)

    for i in range(rows):
        # np.rot90(A, -1)[:, j] is row size-1-j of A, so both slices cover
        # the same image rows.
        band = final[size-(i+1)*cut:size-i*cut, :]
        band.plot(ax=axs[i, 1], cmap='afmhot', add_colorbar=False)
        axs[i, 0].plot(data[:, (i*cut):(i+1)*cut], linewidth=0.5)
    plt.tight_layout(rect=[0, 0.03, 1, 0.90])
    plt.figtext(0.5,0.95, "                           Aligned                                                                                                                              Final                           ", ha="center", va="top", fontsize="14")
    plt.draw()

    out_dir = Path('slices/AlignedFinal/{}'.format(cut))
    out_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_dir / '[{}] {}.png'.format(date, name), dpi=100)
    plt.close(fig)
    
def show_slices_aligned_final_median(df, date, name, cut):
    """Like ``show_slices_aligned_final`` but overlays the per-band median curve.

    For each band the aligned columns are drawn thin, and their row-wise median
    (``np.median(..., axis=1)``) is drawn in red. Subplots do not share the
    x axis. The figure is saved as ``slices/Test/{cut}/[{date}] {name}.png``.

    Raises ValueError when ``cut`` exceeds the image size.
    """
    plt.ioff()

    final = final_data(df, date, name)
    data = aligned_data(df, date, name)
    size = final.shape[0]
    rows = size // cut
    if rows == 0:
        raise ValueError('cut ({}) is larger than the image size ({})'.format(cut, size))

    # squeeze=False keeps axs 2-D even when there is a single band.
    fig, axs = plt.subplots(rows, 2, figsize=(16, 16), squeeze=False)
    fig.suptitle('[{}] {}'.format(date, name), fontsize=16)

    for i in range(rows):
        band = final[size-(i+1)*cut:size-i*cut, :]
        band.plot(ax=axs[i, 1], cmap='afmhot', add_colorbar=False)
        x = data[:, (i*cut):(i+1)*cut]
        axs[i, 0].plot(x, linewidth=0.5)
        axs[i, 0].plot(np.median(x, axis=1), linewidth=2, color='red')

    plt.tight_layout(rect=[0, 0.03, 1, 0.90])
    plt.figtext(0.5,0.95, "                           Aligned                                                                                                                              Final                           ", ha="center", va="top", fontsize="14")
    plt.draw()

    out_dir = Path('slices/Test/{}'.format(cut))
    out_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_dir / '[{}] {}.png'.format(date, name), dpi=100)
    plt.close(fig)

def show_slices_full(df, date, name, cut):
    """Like ``compare_data`` but every panel is split into bands of ``cut`` rows.

    The columns are raw (line plots), aligned (line plots) and final (heatmap),
    with a shared x axis per column. The figure is saved as
    ``slices/Full/{cut}/[{date}] {name}.png``.

    Raises ValueError when ``cut`` exceeds the image size.
    """
    plt.ioff()

    final = final_data(df, date, name)
    data = aligned_data(df, date, name)
    raw = raw_data(df, date, name)
    size = final.shape[0]
    rows = size // cut
    if rows == 0:
        raise ValueError('cut ({}) is larger than the image size ({})'.format(cut, size))

    # squeeze=False keeps axs 2-D even when there is a single band.
    fig, axs = plt.subplots(rows, 3, figsize=(18, 18), sharex='col', squeeze=False)
    fig.suptitle('[{}] {}'.format(date, name), fontsize=16)

    for i in range(rows):
        lo, hi = i*cut, (i+1)*cut
        band = final[size-hi:size-lo, :]
        band.plot(ax=axs[i, 2], cmap='afmhot', add_colorbar=False)
        axs[i, 1].plot(data[:, lo:hi], linewidth=0.5)
        axs[i, 0].plot(raw[:, lo:hi], linewidth=0.5)
    plt.tight_layout(rect=[0, 0.03, 1, 0.90])
    plt.figtext(0.5,0.95, "                           Raw                                                                                         Aligned                                                                                       Final                           ", ha="center", va="top", fontsize="14")
    plt.draw()

    out_dir = Path('slices/Full/{}'.format(cut))
    out_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_dir / '[{}] {}.png'.format(date, name), dpi=100)
    plt.close(fig)

In [None]:
def get_img_by_id(df, img_id):
    """Load the leveled ``Z_Forward`` image whose ``ID`` equals ``img_id``.

    Like ``get_img`` but the row is looked up by ID and a ``(row, tf)`` tuple is
    returned. ``row`` is the matching one-row DataFrame. The file name is
    prefixed with the module-level ``path`` string. ``.item()`` raises
    ValueError unless exactly one row matches.
    """
    row = df.loc[df['ID'] == img_id]
    source = Path(path + row['ImageOriginalName'].item())
    topo = omicronscala.to_dataset(source).Z_Forward
    topo.spym.align()
    topo.spym.plane()
    topo.spym.fixzero(to_mean=True)
    return row, topo

def get_b_img_by_id(df, img_id):
    """Like ``get_img_by_id`` but opens the ``tb0`` companion file.

    The file name is ``ImageOriginalName`` with its last three characters
    replaced by ``'tb0'``. Its ``Z_Forward`` variable is leveled
    (align -> plane -> fixzero) and returned as ``(row, tf)``.
    """
    row = df.loc[df['ID'] == img_id]
    stem = row['ImageOriginalName'].item()[:-3]
    source = Path(path + stem + 'tb0')
    topo = omicronscala.to_dataset(source).Z_Forward
    topo.spym.align()
    topo.spym.plane()
    topo.spym.fixzero(to_mean=True)
    return row, topo

def plot_img(row, img):
    """Show ``img`` as an ``afmhot`` heatmap titled ``[category] filename``.

    ``row`` must be a one-row DataFrame (e.g. the first value returned by
    ``get_img_by_id``). Its ``Categories`` and ``TF0_Filename`` values build
    the title.
    """
    title = "[{}] {}".format(row.Categories.item(), row.TF0_Filename.item())

    fig = plt.figure(figsize=(16, 16))
    ax = fig.add_subplot(1, 1, 1)
    img.plot(ax=ax, cmap='afmhot', add_colorbar=False)
    ax.set_title(title, fontsize=16)
    plt.show()
    

In [None]:
def get_stripes(data, size):
    """Return the column indices of ``data`` flagged as stripe outliers.

    Each column gets two scores:

    1. Its total absolute deviation from the per-row median.
    2. The absolute change of that total relative to the previous column. The
       first column reuses the first change.

    Both scores are scaled so their maximum is 1, then multiplied. Columns whose
    product exceeds twice the standard deviation of all products are returned.

    ``size`` is how many times the median column is repeated; it must equal the
    number of columns of ``data`` (or be 1).
    """
    row_median = np.median(data, axis=1)
    median_grid = np.tile(row_median[:, np.newaxis], (1, size))
    column_dev = np.abs(median_grid - data).sum(axis=0)

    column_jump = np.abs(np.ediff1d(column_dev))
    # The first column has no predecessor: reuse the first jump.
    column_jump = np.r_[column_jump[0], column_jump]

    score = (column_dev / column_dev.max()) * (column_jump / column_jump.max())
    return list(np.flatnonzero(score > 2 * np.std(score)))

    
def show_slices_stripes(df, date, name, cut, legend=False, path=None, filename=None):
    """Like ``show_slices_aligned_final_median`` but only plots stripe columns.

    Stripe columns come from ``get_stripes`` on the aligned data. In each band,
    the left panel shows only the aligned columns flagged as stripes, labelled
    by index. When at least one stripe falls in the band, the band's row-wise
    median is drawn in red. Subplots do not share the x axis.

    Output: ``{path}/{filename}_[{date}] {name}.png``. ``path`` defaults to
    ``stripes/Test/{cut}`` and the ``{filename}_`` prefix is omitted when
    ``filename`` is None.

    Raises ValueError when ``cut`` exceeds the image size.
    """
    plt.ioff()
    final = final_data(df, date, name)
    data = aligned_data(df, date, name)

    size = final.shape[0]
    rows = size // cut
    if rows == 0:
        raise ValueError('cut ({}) is larger than the image size ({})'.format(cut, size))

    # A set gives O(1) membership tests in the band loop below.
    stripes = set(get_stripes(data, size))

    # squeeze=False keeps axs 2-D even when there is a single band.
    fig, axs = plt.subplots(rows, 2, figsize=(18, 16), squeeze=False)
    fig.suptitle('[{}] {}'.format(date, name), fontsize=16)

    for i in range(rows):
        lo, hi = i*cut, (i+1)*cut
        band = final[size-hi:size-lo, :]
        band.plot(ax=axs[i, 1], cmap='afmhot', add_colorbar=False)
        x = data[:, lo:hi]
        hits = [j for j in range(lo, hi) if j in stripes]
        for j in hits:
            axs[i, 0].plot(x[:, j-lo], linewidth=0.5, label=str(j))
        if hits:
            axs[i, 0].plot(np.median(x, axis=1), linewidth=1.5, color='red', label='median')
        if legend:
            axs[i, 0].legend()

    plt.tight_layout(rect=[0, 0.03, 1, 0.90])
    plt.figtext(0.5,0.95,
                "                           Stripes                                                           \
                                                                                   Final                           ",
                ha="center", va="top", fontsize="14")
    plt.draw()

    # Directory and prefix are resolved independently so every combination
    # of path/filename writes into the directory that was actually created.
    out_dir = Path('stripes/Test/{}'.format(cut)) if path is None else Path('{}'.format(path))
    out_dir.mkdir(parents=True, exist_ok=True)
    prefix = '' if filename is None else '{}_'.format(filename)
    plt.savefig(out_dir / '{}[{}] {}.png'.format(prefix, date, name), dpi=100)

    plt.close(fig)

In [None]:
def save_pickle(obj, filename):
    """Pickle ``obj`` to ``{filename}.pkl``; the ``.pkl`` extension is appended."""
    target = '{}.pkl'.format(filename)
    with open(target, 'wb') as sink:
        pickle.dump(obj, sink)
        
def load_pickle(filename):
    """Unpickle and return the object stored in ``{filename}.pkl``.

    Only load trusted files: unpickling can execute arbitrary code.
    """
    source = '{}.pkl'.format(filename)
    with open(source, 'rb') as stream:
        loaded = pickle.load(stream)
    return loaded


In [None]:
def _plot_months_load(page, imgs_path, img_errors):
    """Load every image of ``page`` with ``get_img``; append the ``ID`` of failures to ``img_errors``."""
    images = []
    for _, image in page.iterrows():
        try:
            images.append(get_img(imgs_path, image))
        except Exception:
            # Best effort: skip unreadable files but remember which ones failed.
            img_errors.append(image['ID'])
    return images


def _plot_months_page(images, month, start, stop, fig_size, dpi, max_cols):
    """Draw ``(row, tf)`` pairs in a grid and save it as ``test_months/{month}/{start}_{stop}.png``."""
    if not images:
        print('No image of {} [{}:{}] could be loaded.'.format(month, start, stop))
        return
    cols = min(max_cols, len(images))
    rows = -(-len(images) // cols)  # ceiling division: last row may be partial
    print('images: {} cols: {}, rows: {}'.format(len(images), cols, rows))

    # squeeze=False keeps axs 2-D even for a single row, column or cell.
    fig, axs = plt.subplots(rows, cols, figsize=((fig_size*cols), (fig_size*rows)), squeeze=False)
    for c, ax in enumerate(axs.flat):
        if c >= len(images):
            ax.axis('off')
            continue
        meta, topo = images[c]
        try:
            topo.plot(ax=ax, cmap='afmhot', add_colorbar=False)
            ax.set_title(r"[{}] {} $\bf{{{}}}$".format(meta['Date'], meta['TF0_Filename'], meta['ID']), fontsize=20)
            for item in ([ax.xaxis.label, ax.yaxis.label] +
                         ax.get_xticklabels() + ax.get_yticklabels()):
                item.set_fontsize(fig_size*2)
        except Exception as e:
            # Best effort: blank the cell instead of aborting the whole page.
            print(e)
            ax.axis('off')
    plt.tight_layout()
    plt.draw()

    Path('test_months/{}'.format(month)).mkdir(parents=True, exist_ok=True)
    plt.savefig('test_months/{}/{}_{}.png'.format(month, start, stop), dpi=dpi)
    plt.close(fig)


def plot_months(df, path=path, fig_size=8, dpi=40, rows=10, cols=10):
    """Plot the whole dataset month by month in pages of up to ``rows`` x ``cols`` images.

    Months are the first 7 characters of ``Date``. Each month with at least two
    images is split into consecutive pages of ``rows * cols`` images; the last
    page holds the remainder. Within a page, images are sorted by ``Date`` and
    ``Time``. Pages are saved as ``test_months/{month}/{first}_{last}.png``.

    Parameters
    ----------
    df : pandas.DataFrame
        Image metadata. It is not modified.
    path : str
        Prefix passed to ``get_img`` to locate files. The default is the
        module-level ``path``, defined elsewhere in the notebook.
    fig_size : int
        Size in inches of each grid cell. Axis and tick labels use font size
        ``2 * fig_size``.
    dpi : int
        Resolution of every saved PNG, including the remainder page.
    rows, cols : int
        Page grid (defaults keep the original 10 x 10).

    Returns
    -------
    tuple of list
        ``(sample_errors, img_errors)``. ``sample_errors`` is kept for backward
        compatibility and stays empty, because positional slicing cannot fail.
        ``img_errors`` holds the ``ID`` of every image that failed to load.
    """
    sample_errors = []
    img_errors = []
    plt.ioff()
    # Add the month column to a copy so the caller's DataFrame is untouched.
    df = df.assign(months=[x[:7] for x in df['Date']]).sort_values(by='months')
    months = df['months'].unique()
    page_size = rows * cols
    for h, month in enumerate(months):
        imgs = df[df['months'] == month]
        print('[{}/{}] Plotting {} - {} images '.format(h, len(months), month, len(imgs)))

        if len(imgs) == 0:
            print('There are no images of {}.'.format(month))
            continue

        if len(imgs) == 1:
            print('Only one image for {}.'.format(month))
            continue

        for start in range(0, len(imgs), page_size):
            stop = min(start + page_size, len(imgs))
            page = imgs.iloc[start:stop].sort_values(by=['Date', 'Time'])
            images = _plot_months_load(page, path, img_errors)
            _plot_months_page(images, month, start, stop, fig_size, dpi, cols)
    return sample_errors, img_errors

In [None]:
def months_barplot(df, year):
    """Bar-plot the number of images per month for ``year``.

    Rows are selected when ``Date`` starts with ``str(year)``. The month is the
    first 7 characters of ``Date``. ``df`` is not modified.
    """
    year_rows = df.loc[df['Date'].str.startswith(str(year))]
    per_month = (
        year_rows.assign(Date=[stamp[:7] for stamp in year_rows['Date']])
        .groupby('Date')
        .size()
        .reset_index(name='N imgs')
    )
    title = 'STM Images distribution of {}'.format(year)
    per_month.plot(kind='bar', x='Date', y='N imgs', title=title)

def months_barplot_full(df):
    """Bar-plot the number of images per month over the whole dataset.

    The month is the first 7 characters of ``Date``. Bars are not split by
    category. ``df`` is not modified.
    """
    # Truncate dates on a copy: assigning to df['Date'] would overwrite the
    # caller's full dates with month strings.
    monthly = df.assign(Date=[x[:7] for x in df['Date']])
    months = monthly.groupby('Date').size().reset_index(name='N imgs')
    months.plot(kind='bar', x='Date', y='N imgs', title='STM Images distribution of mixed category', figsize=(12,6))

In [None]:
def get_colormap_error(df, img_id):
    """Return the percentage of colors lost due to outliers for image ``img_id``.

    For each column of the leveled image (reductions over axis 0), the max-min
    spread is computed. For i = 1..99, the spreads are compared against
    ``(median column max - median column min) * 255 * (1 - i * 0.01)``. The
    first i at which some column exceeds the threshold gives ``100 - i``.
    Returns 0 when no threshold is ever exceeded.
    """
    _, img = get_img_by_id(df, img_id)
    pixels = img.data

    col_max = np.amax(pixels, axis=0)
    col_min = np.amin(pixels, axis=0)
    median_span = np.median(col_max) - np.median(col_min)
    spread = col_max - col_min

    for i in range(1, 100):
        threshold = median_span * (255 * (1 - (i * 0.01)))
        if np.any(spread > threshold):
            return 100 - i
    return 0


def get_df_colormap_error(df):
    """Return the colormap error of every image as a list; ``-1`` marks failures.

    Images are addressed by ``ID`` values ``1..len(df)``. This assumes IDs are
    consecutive integers starting at 1 (TODO confirm), so element ``k`` belongs
    to ``ID == k + 1``.
    """
    cmap_err_list = []
    for img_id in tqdm(range(1, len(df)+1)):
        try:
            cmap_err_list.append(get_colormap_error(df, img_id))
        except Exception:
            # Missing/unreadable image: encode as -1. Catching Exception (not a
            # bare except) lets KeyboardInterrupt stop this long loop.
            cmap_err_list.append(-1)
    return cmap_err_list


def plot_df_colormap_error(cmap_err_list, df):
    """Scatter-plot how many images have each colormap error; save ``colormap_error.png``.

    Side effect: adds or overwrites the ``'Colormap Error'`` column of ``df``.
    Values are assigned by position (``cmap_err_list[k]`` goes to the k-th
    row) whatever the index of ``df`` is. The list must therefore have one
    entry per row.
    """
    # A plain Series would align on a 0..N-1 index and yield NaN/misplaced
    # values when df has any other index; a list is assigned positionally.
    df['Colormap Error'] = list(cmap_err_list)
    w = df.groupby('Colormap Error').size().reset_index(name='N imgs')
    w.plot.scatter(x='Colormap Error', y='N imgs', logy=True, title="Colormap Error (%) distribution of STM images", figsize=(12,8))
    plt.savefig('colormap_error.png')

In [None]:
def mixed_images(df, month, rows=10, cols=10, path=path, fsize=8, dpi=40):
    """Save a grid of up to ``rows * cols`` random images taken in ``month``.

    ``month`` is compared with the first 7 characters of ``Date``. Sampled
    images are sorted by ``Date``, loaded with ``get_img(path, ...)`` and drawn
    with ``afmhot``; only complete grid rows are drawn. The figure is saved as
    ``mixed/{month}_f{fsize}_dpi{dpi}.png``. ``df`` is not modified.

    Returns None; prints a message when there is nothing to plot.
    """
    plt.ioff()
    # Select by month without adding a 'months' column to the caller's df.
    imgs = df[[x[:7] == month for x in df['Date']]]
    if len(imgs) == 0:
        return print('There are no images of {}.'.format(month))

    if len(imgs) == 1:
        return print('Only one image for {}.'.format(month))

    try:
        samples = imgs.sample(rows*cols)
    except ValueError:
        # Fewer images than grid cells: use all of them.
        samples = imgs

    samples = samples.sort_values(by='Date')

    images = []
    for _, image in samples.iterrows():
        try:
            images.append(get_img(path, image))
        except Exception as e:
            # Best effort: report unreadable files and keep going.
            print(e)
            print(image['ID'])

    if not images:
        return print('No image of {} could be loaded.'.format(month))

    if cols > len(images):
        cols = len(images)
        rows = 1
    else:
        # Only complete rows are drawn; leftover images are skipped.
        rows = len(images) // cols

    # squeeze=False keeps axs 2-D even for a single row, column or cell.
    fig, axs = plt.subplots(rows, cols, figsize=((fsize*cols), (fsize*rows)), squeeze=False)
    for ax, (meta, topo) in zip(axs.flat, images):
        topo.plot(ax=ax, cmap='afmhot', add_colorbar=False)
        ax.set_title(r"[{}] {} $\bf{{{}}}$".format(meta['Date'], meta['TF0_Filename'], meta['ID']), fontsize=20)
        for item in ([ax.xaxis.label, ax.yaxis.label] +
                     ax.get_xticklabels() + ax.get_yticklabels()):
            item.set_fontsize(fsize*2)
    plt.tight_layout()
    plt.draw()

    Path('mixed').mkdir(parents=True, exist_ok=True)
    plt.savefig('mixed/{}_f{}_dpi{}.png'.format(month, fsize, dpi), dpi=dpi)
    plt.close(fig)

In [None]:
def months_images(df, path=path, fsize=8, dpi=40):
    """Deprecated: save one grid of up to 10 x 10 random images per month.

    Months are the first 7 characters of ``Date``; months with fewer than two
    images are skipped. Sampled images are sorted by ``Date`` and only complete
    grid rows are drawn. Each figure is saved as
    ``months/{category}/{month}.png``, where ``category`` is the
    ``Categories`` value of the month's first row. ``df`` is not modified.
    """
    plt.ioff()
    # Add the month column to a copy so the caller's DataFrame is untouched.
    df = df.assign(months=[x[:7] for x in df['Date']]).sort_values(by='months')
    months = df['months'].unique()
    for h, month in enumerate(months):
        rows = 10
        cols = 10
        imgs = df[df['months'] == month]
        print('[{}/{}] Plotting {} - {} images '.format(h, len(months), month, len(imgs)))

        if len(imgs) == 0:
            print('There are no images of {}.'.format(month))
            continue

        if len(imgs) == 1:
            print('Only one image for {}.'.format(month))
            continue

        try:
            samples = imgs.sample(rows*cols)
        except ValueError:
            # Fewer images than grid cells: use all of them.
            samples = imgs

        category = imgs.iloc[0].Categories
        samples = samples.sort_values(by='Date')

        images = []
        for _, image in samples.iterrows():
            try:
                images.append(get_img(path, image))
            except Exception as e:
                # Best effort: report unreadable files and keep going.
                print(e)
                print(image['ID'])

        if not images:
            print('No image of {} could be loaded.'.format(month))
            continue

        if cols > len(images):
            cols = len(images)
            rows = 1
        else:
            # Only complete rows are drawn; leftover images are skipped.
            rows = len(images) // cols

        print('images: {} cols: {}, rows: {}'.format(len(images), cols, rows))

        # squeeze=False keeps axs 2-D even for a single row, column or cell.
        fig, axs = plt.subplots(rows, cols, figsize=((fsize*cols), (fsize*rows)), squeeze=False)
        for ax, (meta, topo) in zip(axs.flat, images):
            topo.plot(ax=ax, cmap='afmhot', add_colorbar=False)
            ax.set_title(r"[{}] {} $\bf{{{}}}$".format(meta['Date'], meta['TF0_Filename'], meta['ID']), fontsize=20)
            for item in ([ax.xaxis.label, ax.yaxis.label] +
                         ax.get_xticklabels() + ax.get_yticklabels()):
                item.set_fontsize(fsize*2)
        plt.tight_layout()
        plt.draw()

        Path('months/{}'.format(category)).mkdir(parents=True, exist_ok=True)
        plt.savefig('months/{}/{}.png'.format(category, month), dpi=dpi)
        plt.close(fig)