In [None]:
%matplotlib inline

In [None]:
from fastai.vision import *
import pandas as pd
import numpy as np
from pathlib import Path
import omicronscala
import spym
import xarray
import os
import torch
import numpy as np
import random
from kmeans_pytorch import kmeans

In [None]:
def savePickle(obj, filename):
    """Serialize ``obj`` with pickle to the file ``<filename>.pkl``.

    Args:
        obj: any picklable Python object.
        filename: path without extension; '.pkl' is appended.
    """
    # Import explicitly instead of relying on `from fastai.vision import *`
    # happening to export the `pickle` module.
    import pickle

    with open('{}.pkl'.format(filename), 'wb') as file:
        pickle.dump(obj, file)
        
def loadPickle(filename):
    """Load and return the object stored in ``<filename>.pkl``.

    Args:
        filename: path without extension; '.pkl' is appended.

    Returns:
        The unpickled object.

    Security: ``pickle.load`` can execute arbitrary code while loading, so only
    use this on files produced by trusted code (e.g. ``savePickle``).
    """
    # Import explicitly instead of relying on `from fastai.vision import *`
    # happening to export the `pickle` module.
    import pickle

    with open('{}.pkl'.format(filename), 'rb') as file:
        return pickle.load(file)

In [None]:
def show_img(path, img):
    """Load one STM scan and apply the standard spym corrections to it.

    Despite the name, nothing is displayed. The function opens the file named
    by ``img['ImageOriginalName']`` inside directory ``path`` with
    ``omicronscala.to_dataset`` and takes its ``Z_Forward`` channel. It then
    calls the spym corrections plane -> align -> plane -> fixzero(to_mean=True).

    Args:
        path: directory (str or Path) containing the original image files.
        img: metadata record (e.g. a DataFrame row) with an
            'ImageOriginalName' entry.

    Returns:
        ``[img, tf]``: the metadata record as given and the processed
        ``Z_Forward`` channel.
    """
    # Join with Path rather than `path + file`. Plain string concatenation
    # produced 'dirfile.ext' whenever `path` lacked a trailing separator, and
    # it failed outright when `path` was already a Path object.
    file_path = Path(path) / img['ImageOriginalName']
    ds = omicronscala.to_dataset(file_path)
    tf = ds.Z_Forward
    # The accessor results are not reassigned, so these calls presumably
    # modify `tf` in place. NOTE(review): confirm against the spym docs.
    tf.spym.plane()
    tf.spym.align()
    tf.spym.plane()
    tf.spym.fixzero(to_mean=True)
    return [img, tf]

In [None]:
# Directory that holds the original STM scan files (placeholder: set it to
# the real data location).
path = 'path_to_imgs'

# Metadata table describing the STM images.
stm = loadPickle('clean_stm')

# DPA cluster label per image. The assignment below is positional, so the
# labels are assumed to follow the row order of df_features.
labels = np.load("labels.npy")

# Feature table with the cluster labels appended as a 'dpa' column.
features = loadPickle("df_features").assign(dpa=labels)
features.head()

In [None]:
def get_cluster_ids(df, features_df):
    """Build a dictionary mapping each DPA cluster label to its image IDs.

    Args:
        df: unused. Kept so existing calls such as
            ``get_cluster_ids(stm, features)`` keep working.
        features_df: DataFrame with a 'dpa' column of cluster labels, indexed
            by image ID.

    Returns:
        dict whose keys are the cluster labels present in
        ``features_df['dpa']`` (ascending order). Each value is the list of
        image IDs (converted with ``int``) belonging to that cluster, in the
        frame's row order.
    """
    clusters = {}
    # Group the *parameter* (the old code read the global `features`). groupby
    # visits only labels that actually occur, so non-contiguous labels (gaps,
    # or an outlier label like -1) are not lost as they were with range(n).
    for label, group in features_df.groupby('dpa'):
        # Use native Python scalars as keys, like the old range()-based ints.
        key = label.item() if isinstance(label, np.generic) else label
        clusters[key] = [int(idx) for idx in group.index]
    return clusters

In [None]:
def save_cluster_plots(df, clusters, fname, N=100, rows=10, cols=10):
    """Save one grid of randomly sampled images per cluster as a PNG.

    For each cluster, up to ``N`` IDs are sampled without replacement. The
    matching rows of ``df`` are loaded with ``show_img`` from the global
    ``path`` directory. The images are drawn on a ``rows`` x ``cols`` grid and
    saved to ``clusters/<fname>/<cluster>.png``.

    Args:
        df: metadata DataFrame indexed by image ID. Rows need
            'ImageOriginalName', 'Date' and 'TF0_Filename'.
        clusters: dict mapping cluster label -> list of image IDs.
        fname: name of the sub-folder of ./clusters/ that receives the PNGs.
        N: maximum number of images sampled per cluster. Smaller clusters use
            all of their members.
        rows, cols: grid shape. Images beyond rows * cols are not drawn.
    """
    # Explicit import rather than relying on the fastai star import for `plt`.
    import matplotlib.pyplot as plt

    out_dir = Path('clusters') / str(fname)
    out_dir.mkdir(parents=True, exist_ok=True)  # once, not once per cluster

    for k, v in clusters.items():
        # random.sample raises ValueError when N > len(v); cap it for small clusters.
        IDs = random.sample(v, min(N, len(v)))
        samples = df.loc[IDs]
        images = []
        for _, image in samples.iterrows():
            try:
                images.append(show_img(path, image))
            except Exception as e:
                # Best effort: report the unreadable file and keep going.
                print(e)
                print(image['ImageOriginalName'])

        # squeeze=False keeps `axs` 2-D so the grid also works when rows or cols is 1.
        fig, axs = plt.subplots(rows, cols, figsize=(2 + 8 * cols, 8 * rows), squeeze=False)
        try:
            fig.suptitle('Samples of cluster {}'.format(k), weight='bold', fontsize=30)
            # axs.flat walks the grid row by row, pairing panel c with images[c].
            for c, ax in enumerate(axs.flat):
                if c < len(images):
                    meta, tf = images[c]
                    tf.plot(ax=ax, cmap='afmhot', add_colorbar=False)
                    ax.set_title('[{}] {}'.format(meta['Date'], meta['TF0_Filename']), weight='bold', fontsize=20)
                    for item in [ax.xaxis.label, ax.yaxis.label] + ax.get_xticklabels() + ax.get_yticklabels():
                        item.set_fontsize(12)
                else:
                    ax.axis('off')
            fig.tight_layout(rect=[0, 0.03, 1, 0.95])
            fig.savefig(out_dir / '{}.png'.format(k), dpi=40)
        finally:
            # Always release the figure, even if a plot failed, to avoid piling up memory.
            plt.close(fig)

In [None]:
def plot_cluster_samples(df, clusters, ID, N=100, rows=10, cols=10):
    """Display a grid of randomly sampled images from cluster ``ID``.

    Up to ``N`` IDs are sampled without replacement from ``clusters[ID]``. The
    matching rows of ``df`` are loaded with ``show_img`` from the global
    ``path`` directory and drawn on a ``rows`` x ``cols`` grid.

    Args:
        df: metadata DataFrame indexed by image ID. Rows need
            'ImageOriginalName', 'Date' and 'TF0_Filename'.
        clusters: dict mapping cluster label -> list of image IDs.
        ID: cluster label to show. A missing label raises KeyError.
        N: maximum number of images sampled. Smaller clusters use all members.
        rows, cols: grid shape. Images beyond rows * cols are not drawn.
    """
    # Explicit import rather than relying on the fastai star import for `plt`.
    import matplotlib.pyplot as plt

    members = clusters[ID]
    # random.sample raises ValueError when N > len(members); cap it.
    IDs = random.sample(members, min(N, len(members)))
    samples = df.loc[IDs]
    images = []
    for _, image in samples.iterrows():
        try:
            images.append(show_img(path, image))
        except Exception as e:
            # Best effort: report the unreadable file and keep going.
            print(e)
            print(image['ImageOriginalName'])

    # squeeze=False keeps `axs` 2-D so the grid also works when rows or cols is 1.
    fig, axs = plt.subplots(rows, cols, figsize=(2 + 8 * cols, 8 * rows), squeeze=False)
    try:
        fig.suptitle('Samples of cluster {}'.format(ID), weight='bold', fontsize=30)
        # axs.flat walks the grid row by row, pairing panel c with images[c].
        for c, ax in enumerate(axs.flat):
            if c < len(images):
                meta, tf = images[c]
                tf.plot(ax=ax, cmap='afmhot', add_colorbar=False)
                ax.set_title('[{}] {}'.format(meta['Date'], meta['TF0_Filename']), weight='bold', fontsize=20)
                for item in [ax.xaxis.label, ax.yaxis.label] + ax.get_xticklabels() + ax.get_yticklabels():
                    item.set_fontsize(12)
            else:
                ax.axis('off')
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()
    finally:
        # Always release the figure, even if a plot failed.
        plt.close(fig)

In [None]:
# Map every DPA cluster label to the list of image IDs assigned to it.
cl = get_cluster_ids(df=stm, features_df=features)

In [None]:
# Save 25 random samples per cluster, on a 5 x 5 grid, under ./clusters/test/.
save_cluster_plots(stm, cl, fname="test", N=25, rows=5, cols=5)

In [None]:
# Show 25 random samples from cluster 3 on a 5 x 5 grid.
plot_cluster_samples(stm, cl, ID=3, N=25, rows=5, cols=5)