In [None]:
%matplotlib inline

In [None]:
from fastai.vision import *
import pandas as pd
import numpy as np
from pathlib import Path
import omicronscala
import spym
import xarray
import os
import torch
import numpy as np
import random
from kmeans_pytorch import kmeans

In [None]:
def save_pickle(obj, filename):
    """Serialize ``obj`` with :mod:`pickle` into ``<filename>.pkl``.

    Parameters
    ----------
    obj : object
        Any picklable Python object.
    filename : str
        Output path without the ``.pkl`` extension (it is appended here).
    """
    # Imported explicitly: the notebook never imports pickle itself and would
    # otherwise depend on it leaking in through a star import.
    import pickle
    with open('{}.pkl'.format(filename), 'wb') as file:
        pickle.dump(obj, file)


def load_pickle(filename):
    """Load and return the object stored in ``<filename>.pkl``.

    Parameters
    ----------
    filename : str
        Input path without the ``.pkl`` extension (it is appended here).

    Returns
    -------
    object
        The unpickled object.

    Notes
    -----
    Unpickling can execute arbitrary code: only load files you created or
    otherwise trust.
    """
    import pickle
    with open('{}.pkl'.format(filename), 'rb') as file:
        return pickle.load(file)

In [None]:
def get_img(imgs_path, img):
    """Load one STM image and apply the spym flattening corrections.

    Parameters
    ----------
    imgs_path : str or path-like
        Directory containing the original image files.
    img : pandas.Series
        Metadata row of the image; its ``ImageOriginalName`` entry is the file
        name inside ``imgs_path``.

    Returns
    -------
    list
        ``[img, tf]`` where ``tf`` is the ``Z_Forward`` variable of the dataset
        returned by ``omicronscala.to_dataset``, after the corrections below.
    """
    file = img['ImageOriginalName']
    # Join with Path so imgs_path works with or without a trailing separator
    # (plain string concatenation silently produced "<dir><file>").
    ds = omicronscala.to_dataset(Path(imgs_path) / file)
    tf = ds.Z_Forward
    # Return values are discarded, so these spym accessors presumably modify
    # tf in place — NOTE(review): confirm against the spym documentation.
    tf.spym.plane()
    tf.spym.align()
    tf.spym.plane()
    tf.spym.fixzero(to_mean=True)
    return [img, tf]

In [None]:
def get_cluster_ids(features_df):
    """Map every DPA cluster label to the list of image ids in that cluster.

    Parameters
    ----------
    features_df : pandas.DataFrame
        Frame with a ``dpa`` label column; its index holds image ids that are
        convertible to ``int``.

    Returns
    -------
    dict
        ``{label: [image_id, ...]}`` with one entry per label that actually
        occurs, in ascending label order. Ids keep their row order.
    """
    clusters = {}
    # Iterate over the labels that are present rather than range(n_labels):
    # the latter assumed labels are exactly 0..k-1 and silently dropped
    # clusters whenever the label set had gaps or negative values.
    for label, members in features_df.groupby('dpa', sort=True):
        key = label.item() if isinstance(label, np.generic) else label
        clusters[key] = [int(idx) for idx in members.index]
    return clusters

In [None]:
def num_clusters(clusters, cutoff):
    """Report the clusters holding at least ``cutoff`` images.

    Prints ``<label> : <size>`` for every cluster whose size is >= ``cutoff``
    (in the dictionary's iteration order), followed by ``Total: <count>``.

    Parameters
    ----------
    clusters : dict
        Cluster label -> list of image ids.
    cutoff : int
        Minimum (inclusive) cluster size to report.
    """
    large = [(label, len(members))
             for label, members in clusters.items()
             if len(members) >= cutoff]
    for label, size in large:
        print(label, ":", size)
    print("Total:", len(large))

In [None]:
def save_cluster_plots(df, clusters, filename, cutoff, n=100, rows=10, cols=10, imgs_path=None):
    """Save a sample-image grid for every cluster with at least ``cutoff`` members.

    For each qualifying cluster, up to ``n`` image ids are drawn at random
    without replacement and loaded with ``get_img``. They are drawn on a
    ``rows`` x ``cols`` grid, which is written to
    ``clusters/<filename>/<cluster label>.png``.

    Parameters
    ----------
    df : pandas.DataFrame
        Image metadata indexed by image id. Rows must provide the
        ``ImageOriginalName``, ``Date`` and ``TF0_Filename`` entries.
    clusters : dict
        Cluster label -> list of image ids (as built by ``get_cluster_ids``).
    filename : str
        Name of the output sub-folder under ``clusters/``.
    cutoff : int
        Minimum (inclusive) cluster size for a plot to be produced.
    n : int
        Number of images to sample per cluster; capped at the cluster size.
    rows, cols : int
        Grid shape. Only the first ``rows * cols`` loaded images are drawn.
    imgs_path : str or path-like, optional
        Directory of the original image files. Defaults to the notebook-level
        ``path`` variable (resolved at call time).
    """
    if imgs_path is None:
        imgs_path = path
    for k, v in clusters.items():
        if len(v) < cutoff:
            continue
        # random.sample raises ValueError if asked for more items than exist.
        ids = random.sample(v, min(n, len(v)))
        samples = df.loc[ids]
        images = []
        for _, image in samples.iterrows():
            try:
                # get_img is the loader defined in this notebook (the old
                # call to the undefined `show_img` failed for every image and
                # was silently swallowed here, producing blank plots).
                images.append(get_img(imgs_path, image))
            except Exception as e:
                # Best-effort: one unreadable file should not abort the grid.
                print(e)
                print(image['ImageOriginalName'])

        # squeeze=False keeps `axs` 2-D even when rows == 1 or cols == 1.
        fig, axs = plt.subplots(rows, cols, figsize=(2 + (8 * cols), (8 * rows)), squeeze=False)
        fig.suptitle('Cluster {}, total imgs: {}'.format(k, len(v)), weight='bold', fontsize=30)
        for c, ax in enumerate(axs.flat):  # row-major, same order as the i/j loops
            if c < len(images):
                row, tf = images[c]
                tf.plot(ax=ax, cmap='afmhot', add_colorbar=False)
                ax.set_title('[{}] {}'.format(row['Date'], row['TF0_Filename']), weight='bold', fontsize=20)
                for item in ([ax.xaxis.label, ax.yaxis.label] +
                             ax.get_xticklabels() + ax.get_yticklabels()):
                    item.set_fontsize(12)
            else:
                ax.axis('off')
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])
        Path('clusters/{}'.format(filename)).mkdir(parents=True, exist_ok=True)
        fig.savefig('clusters/{}/{}.png'.format(filename, k), dpi=40)
        plt.close(fig)

In [None]:
def save_cluster_samples(df, clusters, img_id, filename, n=100, rows=10, cols=10, imgs_path=None):
    """Save a grid of up to ``n`` random sample images from one cluster.

    The figure is written to ``clusters/<filename>/<img_id>.png``.

    Parameters
    ----------
    df : pandas.DataFrame
        Image metadata indexed by image id. Rows must provide the
        ``ImageOriginalName``, ``Date`` and ``TF0_Filename`` entries.
    clusters : dict
        Cluster label -> list of image ids (as built by ``get_cluster_ids``).
    img_id : hashable
        Label of the cluster to plot (a key of ``clusters``).
    filename : str
        Name of the output sub-folder under ``clusters/``.
    n : int
        Number of images to sample; capped at the cluster size.
    rows, cols : int
        Grid shape. Only the first ``rows * cols`` loaded images are drawn.
    imgs_path : str or path-like, optional
        Directory of the original image files. Defaults to the notebook-level
        ``path`` variable (resolved at call time).
    """
    if imgs_path is None:
        imgs_path = path
    members = clusters[img_id]
    # random.sample raises ValueError if asked for more items than exist.
    ids = random.sample(members, min(n, len(members)))
    samples = df.loc[ids]
    images = []
    for _, image in samples.iterrows():
        try:
            # get_img is the loader defined in this notebook; the undefined
            # `show_img` previously failed here for every image.
            images.append(get_img(imgs_path, image))
        except Exception as e:
            # Best-effort: one unreadable file should not abort the grid.
            print(e)
            print(image['ImageOriginalName'])

    # squeeze=False keeps `axs` 2-D even when rows == 1 or cols == 1.
    fig, axs = plt.subplots(rows, cols, figsize=(2 + (8 * cols), (8 * rows)), squeeze=False)
    # Title uses img_id (the old code referenced an undefined name `ID`).
    fig.suptitle('Cluster {}, total imgs: {}'.format(img_id, len(members)), weight='bold', fontsize=30)
    for c, ax in enumerate(axs.flat):  # row-major, same order as the i/j loops
        if c < len(images):
            row, tf = images[c]
            tf.plot(ax=ax, cmap='afmhot', add_colorbar=False)
            ax.set_title('[{}] {}'.format(row['Date'], row['TF0_Filename']), weight='bold', fontsize=20)
            for item in ([ax.xaxis.label, ax.yaxis.label] +
                         ax.get_xticklabels() + ax.get_yticklabels()):
                item.set_fontsize(12)
        else:
            ax.axis('off')
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    Path('clusters/{}'.format(filename)).mkdir(parents=True, exist_ok=True)
    fig.savefig('clusters/{}/{}.png'.format(filename, img_id), dpi=40)
    plt.close(fig)

In [None]:
def plot_cluster_samples(df, clusters, img_id, n=100, rows=10, cols=10, imgs_path=None):
    """Display a grid of up to ``n`` random sample images from one cluster.

    Parameters
    ----------
    df : pandas.DataFrame
        Image metadata indexed by image id. Rows must provide the
        ``ImageOriginalName``, ``Date`` and ``TF0_Filename`` entries.
    clusters : dict
        Cluster label -> list of image ids (as built by ``get_cluster_ids``).
    img_id : hashable
        Label of the cluster to plot (a key of ``clusters``).
    n : int
        Number of images to sample; capped at the cluster size.
    rows, cols : int
        Grid shape. Only the first ``rows * cols`` loaded images are drawn.
    imgs_path : str or path-like, optional
        Directory of the original image files. Defaults to the notebook-level
        ``path`` variable (resolved at call time).
    """
    if imgs_path is None:
        imgs_path = path
    members = clusters[img_id]
    # random.sample raises ValueError if asked for more items than exist.
    ids = random.sample(members, min(n, len(members)))
    samples = df.loc[ids]
    images = []
    for _, image in samples.iterrows():
        try:
            # get_img is the loader defined in this notebook; the undefined
            # `show_img` previously failed here for every image.
            images.append(get_img(imgs_path, image))
        except Exception as e:
            # Best-effort: one unreadable file should not abort the grid.
            print(e)
            print(image['ImageOriginalName'])

    # squeeze=False keeps `axs` 2-D even when rows == 1 or cols == 1.
    fig, axs = plt.subplots(rows, cols, figsize=(2 + (8 * cols), (8 * rows)), squeeze=False)
    fig.suptitle('Cluster {}, total imgs: {}'.format(img_id, len(members)), weight='bold', fontsize=30)
    for c, ax in enumerate(axs.flat):  # row-major, same order as the i/j loops
        if c < len(images):
            row, tf = images[c]
            tf.plot(ax=ax, cmap='afmhot', add_colorbar=False)
            ax.set_title('[{}] {}'.format(row['Date'], row['TF0_Filename']), weight='bold', fontsize=20)
            for item in ([ax.xaxis.label, ax.yaxis.label] +
                         ax.get_xticklabels() + ax.get_yticklabels()):
                item.set_fontsize(12)
        else:
            ax.axis('off')
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()
    plt.close(fig)

In [None]:
def load_features(features_file, labels_file):
    """Return the pickled feature frame with its DPA labels attached.

    Parameters
    ----------
    features_file : str
        Pickle path without the ``.pkl`` extension (read via ``load_pickle``).
    labels_file : str
        ``.npy`` file holding one DPA label per feature row.

    Returns
    -------
    pandas.DataFrame
        The feature frame with the labels stored in a ``dpa`` column.
    """
    dpa_labels = np.load(labels_file)
    frame = load_pickle(features_file)
    frame['dpa'] = dpa_labels
    return frame

In [None]:
def get_cluster_features(features_df, labels, clusters, img_id):
    """Select one cluster's feature rows and relabel them with nested DPA labels.

    Used for nested clustering: the rows of cluster ``img_id`` are taken from
    ``features_df`` and their ``dpa`` column is overwritten with the labels
    stored in the ``.npy`` file ``labels``. The image count and number of
    distinct nested labels are printed.

    Parameters
    ----------
    features_df : pandas.DataFrame
        Feature frame whose index holds image ids as strings.
    labels : str
        Path of the ``.npy`` file with one nested label per selected row.
    clusters : dict
        Cluster label -> list of image ids.
    img_id : hashable
        Label of the cluster to extract.

    Returns
    -------
    pandas.DataFrame
        The selected rows with the nested labels in ``dpa``.
    """
    # Cluster ids are ints, while the feature index is keyed by strings.
    row_keys = list(map(str, clusters[img_id]))
    subset = features_df.loc[row_keys]
    nested = np.load(labels)
    print("Imgs:{}\tclusters:{}".format(len(nested), len(set(nested))))
    subset['dpa'] = nested
    return subset

In [None]:
# Directory holding the original image files (placeholder — set before running).
# Keep the trailing separator: older versions of get_img concatenated strings.
path = 'path_to_imgs/'

# Seed Python's RNG so the random cluster samples plotted below are reproducible.
SEED = 42
random.seed(SEED)

# Load the STM image metadata frame (indexed by image id).
stm = load_pickle('df_stm')

# Load the feature frame together with its DPA cluster labels.
features = load_features("df_features", "labels.npy")

# Map each cluster label to its image ids.
cl = get_cluster_ids(features)

# Print the size of every cluster with at least 500 images, and how many there are.
num_clusters(cl, cutoff=500)

# Show cluster 10: 100 random images in a 10x10 grid.
plot_cluster_samples(stm, cl, img_id=10, n=100, rows=10, cols=10)

# Save the same kind of plot for cluster 10 under clusters/test_plots/.
save_cluster_samples(stm, cl, img_id=10, filename="test_plots", n=100, rows=10, cols=10)

# Save a 10x10 sample grid for every cluster with at least 500 images.
save_cluster_plots(stm, cl, filename="test_plots", cutoff=500, n=100, rows=10, cols=10)

In [None]:
# Nested DPA: take the features of cluster 8 and attach the nested labels.
cl8_features = get_cluster_features(features_df=features, labels="nested_labels.npy", clusters=cl, img_id=8)

# Map each nested cluster label to its image ids.
cl8 = get_cluster_ids(cl8_features)

# Print the size of every nested cluster with at least 100 images, and how many there are.
num_clusters(clusters=cl8, cutoff=100)

# Save nested cluster 1: 16 random images in a 4x4 grid.
save_cluster_samples(df=stm, clusters=cl8, img_id=1, filename="test_plots", n=16, rows=4, cols=4)

# Save a 5x5 grid of 25 random images for every nested cluster with at least 100 images.
save_cluster_plots(df=stm, clusters=cl8, filename="test_plots", cutoff=100, n=25, rows=5, cols=5)