In [None]:
%matplotlib inline

In [None]:
from fastai.vision import *  # keep first: explicit imports below take precedence over star-exported names

import os
import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import xarray

import omicronscala
import spym
from kmeans_pytorch import kmeans

In [None]:
def save_pickle(obj, filename):
    """Pickle `obj` into the file `<filename>.pkl`, overwriting any existing file."""
    target = f'{filename}.pkl'
    with open(target, 'wb') as handle:
        pickle.dump(obj, handle)
        
def load_pickle(filename):
    """Return the object stored in the pickle file `<filename>.pkl`.

    NOTE: unpickling can execute arbitrary code — only load files you trust.
    """
    with open(f'{filename}.pkl', 'rb') as handle:
        return pickle.load(handle)

In [None]:
# credits https://github.com/aayushmnit/Deep_learning_explorations
class Hook:
    """Attach `hook_func` to module `m` and keep its latest return value in `stored`.

    The hook is registered as a forward hook when `is_forward` is True and as a
    backward hook otherwise. When `detach` is True the tensors are detached before
    `hook_func(module, inputs, outputs)` sees them. Usable as a context manager:
    leaving the `with` block removes the hook.
    """
    def __init__(self, m:nn.Module, hook_func:HookFunc, is_forward:bool=True, detach:bool=True):
        self.hook_func = hook_func
        self.detach = detach
        self.stored = None
        if is_forward:
            register = m.register_forward_hook
        else:
            register = m.register_backward_hook
        self.hook = register(self.hook_fn)
        self.removed = False

    def hook_fn(self, module:nn.Module, input_tensor:Tensors, output_tensor:Tensors):
        """Callback handed to PyTorch: forwards (module, inputs, outputs) to `hook_func` and stores the result."""
        if self.detach:
            def _detach_all(tensors):
                # list-like -> lazy generator of detached tensors; single tensor -> detached tensor
                if is_listy(tensors):
                    return (t.detach() for t in tensors)
                return tensors.detach()
            input_tensor = _detach_all(input_tensor)
            output_tensor = _detach_all(output_tensor)
        self.stored = self.hook_func(module, input_tensor, output_tensor)

    def remove(self):
        """Unregister the hook from the model; later calls are no-ops."""
        if self.removed:
            return
        self.hook.remove()
        self.removed = True

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        self.remove()
        
def get_output(output, *hook_args):
    """Hook function returning a module's output flattened from dim 1 onwards.

    `Hook` calls its `hook_func` as `hook_func(module, input, output)`; the
    original single-parameter version raised TypeError there. Both call forms
    are accepted:
        get_output(output)                  -> output.flatten(1)
        get_output(module, input, output)   -> output.flatten(1)
    """
    if hook_args:
        # called as a hook: the output tensor is the last positional argument
        output = hook_args[-1]
    return output.flatten(1)

def get_input(input_value, *hook_args):
    """Hook function returning the first input tensor a module received.

    `Hook` calls its `hook_func` as `hook_func(module, input, output)`; the
    original single-parameter version raised TypeError there. Both call forms
    are accepted:
        get_input(inputs)                   -> first element of `inputs`
        get_input(module, inputs, output)   -> first element of `inputs`
    `inputs` may be any iterable (e.g. the generator `Hook` builds when detaching).
    """
    if hook_args:
        # called as a hook: (module, input, output) -> the inputs are the 2nd argument
        input_value = hook_args[0]
    return list(input_value)[0]

def get_named_module_from_model(nn_model, name):
    """Return the sub-module of `nn_model` registered under the dotted `name`, or None if there is none."""
    matches = (module for module_name, module in nn_model.named_modules() if module_name == name)
    return next(matches, None)

In [None]:
def get_dataframe(pickle_name='clean_stm', img_dir='data/train'):
    """Create the DataFrame consumed by fastai's `ImageList.from_df`.

    Loads `<pickle_name>.pkl` with `load_pickle` (it must have a 'Categories'
    column) and returns, for every row and in the same index order:
      - 'path': '<img_dir>/<index>.png'
      - 'category_label': pandas category code of 'Categories' (NaN -> -1)
      - 'dataset': always 'train'
      - 'category': the raw 'Categories' value
      - 'is_valid': True where 'Categories' == 'mixed' (validation split)

    Args:
        pickle_name: pickle file name without the '.pkl' extension.
        img_dir: directory prefix used to build each image path.
    """
    df = load_pickle(pickle_name)
    categories = df['Categories']
    df['path'] = ['{}/{}.png'.format(img_dir, x) for x in df.index.values]
    df['category_label'] = categories.astype('category').cat.codes
    df['dataset'] = 'train'
    df['category'] = categories
    # vectorised comparison; NaN categories compare False, as before
    df['is_valid'] = categories == 'mixed'
    return df[['path', 'category_label', 'dataset', 'category', 'is_valid']]

def get_dict_category_labels(df):
    """Map every integer `category_label` in `df` to its `category` name.

    Pairs are visited in groupby (sorted) order; if a label occurred with several
    categories, the last one would win.
    """
    pairs = (df.groupby(['category', 'category_label'])
               .size()
               .reset_index())
    mapping = {}
    for label, name in zip(pairs['category_label'].to_list(), pairs['category'].to_list()):
        mapping[label] = name
    return mapping

def get_imgs_data(df, imgs_path, size=224, bs=32):
    """Build the fastai labelled image lists and a normalized DataBunch from `df`.

    `df` needs the columns 'path', 'is_valid' and 'category_label' (see `get_dataframe`).

    Args:
        df: image DataFrame.
        imgs_path: base path passed to `ImageList.from_df`.
        size: image size for `transform`.
        bs: DataBunch batch size.

    Returns:
        (imgs_list, imgs): the split/labelled lists and the DataBunch built from
        them with `get_transforms()` and ImageNet normalization.
    """
    imgs_list = (ImageList.from_df(df=df, path=imgs_path, cols=['path'])
                 .split_from_df(col='is_valid')
                 .label_from_df(cols='category_label'))
    transforms = get_transforms()
    # build from the lists created above (previously used the undefined global `data_source`)
    imgs = imgs_list.transform(transforms, size=size).databunch(bs=bs).normalize(imagenet_stats)
    return imgs_list, imgs

In [None]:
def get_img_features_df(layer, dataloader, nn_model=None, labels_data=None):
    """Extract one feature vector per image by hooking the input of `layer`.

    Every batch of `dataloader` is run through the model in eval mode without
    gradient tracking; the `get_input` hook on `layer` captures the layer input,
    which is flattened per image and keyed by the image path.

    Args:
        layer: module whose input is captured.
        dataloader: fastai dataloader iterated in order; its `.items` must be
            aligned with the iteration order (i.e. not shuffled).
        nn_model: model to run; defaults to the notebook-global `model`.
        labels_data: DataBunch providing `classes` and `train_ds.y.items`;
            defaults to the notebook-global `inference_data`.

    Returns:
        DataFrame indexed by 'ID' (file name up to the first '.') with columns
        'img_path', 'img_repr', 'label' and 'label_id'.
    """
    if nn_model is None:
        nn_model = model
    if labels_data is None:
        labels_data = inference_data

    dict_features = {}
    nn_model.eval()
    offset = 0
    with Hook(layer, get_input, True, True) as hook, torch.no_grad():
        for xb, _ in dataloader:
            bs = xb.shape[0]
            # track the position explicitly instead of assuming every full batch holds 32 items
            img_ids = dataloader.items[offset:offset + bs]
            offset += bs
            nn_model(xb)
            features = hook.stored.cpu().numpy().reshape(bs, -1)
            for img_id, img_repr in zip(img_ids, features):
                dict_features[img_id] = img_repr

    features_df = pd.DataFrame(list(dict_features.items()), columns=['img_path', 'img_repr'])
    features_df['ID'] = [x.split('/')[-1].split('.')[0] for x in features_df['img_path']]
    features_df.set_index('ID', inplace=True)
    # labels are taken positionally, relying on the unshuffled dataloader order
    label_ids = labels_data.train_ds.y.items[0:features_df.shape[0]]
    features_df['label'] = [labels_data.classes[x] for x in label_ids]
    features_df['label_id'] = label_ids
    return features_df

def torch_save(features_df, filename):
    """Stack the 'img_repr' vectors of `features_df` (row order) into one tensor and `torch.save` it.

    Args:
        features_df: DataFrame whose 'img_repr' column holds equal-length 1-D numpy arrays.
        filename: destination passed to `torch.save` (no extension is added).
    """
    # only 'img_repr' reaches the saved tensor, so no other column is required
    features = np.stack(features_df['img_repr'].to_numpy())
    torch.save(torch.from_numpy(features), '{}'.format(filename))

In [None]:
# prepare dataframe
data_df = get_dataframe()

# mapping categories -> labels
dict_category_labels = get_dict_category_labels(data_df)
save_pickle(dict_category_labels, 'labels_dict')

# data folder path
images_path = Path('path_to_data_folder')

# ImageLoader, Images
data_source, data = get_imgs_data(data_df, images_path)

# get resnet pretrained on imagenet
learner = cnn_learner(data, models.resnet50, pretrained=True)
model = learner.model

# select layer for feature extraction: the sub-module registered as '1.4' (model[1][4]);
# the get_input hook captures the *input* of this module
linear_output_layer = get_named_module_from_model(model, '1.4')

# prepare images and dataloader (the transforms variable was previously the undefined `tmfs`)
tfms = get_transforms()
# NOTE(review): get_transforms() applies random augmentation to the train set, so the
# extracted features may differ between runs; consider disabling augmentation here.
inference_data = data_source.transform(tfms, size=224).databunch(bs=32).normalize(imagenet_stats)
inference_dataloader = inference_data.train_dl.new(shuffle=False, drop_last=False)

# get features df
img_repr_df = get_img_features_df(linear_output_layer, inference_dataloader)
save_pickle(img_repr_df, "stm_features_df")

# simpler df
df_features = img_repr_df[['img_repr', 'label']]
save_pickle(df_features, "stm_df_features")
torch_save(img_repr_df, 'S_features_resnet50_4096')

# categories distribution: fraction of images per category name
len_dict = {}
for k, v in dict_category_labels.items():
    len_dict[v] = len(img_repr_df[img_repr_df['label'] == k]) / len(img_repr_df)

In [None]:
import time
from scipy.spatial.distance import cosine
from scipy.spatial.distance import euclidean

def get_similar_images(features_df, img_id, n=10):
    """Return the `n` images most cosine-similar to `img_id`, excluding the query itself.

    Args:
        features_df: DataFrame indexed by string image ID with at least the
            columns 'img_path', 'img_repr' (1-D feature vector) and 'label'.
        img_id: query image ID (converted with `str` for the index lookup).
        n: number of neighbours to return.

    Returns:
        (img_path, label, neighbours_df): the query's path and label, and the
        neighbour rows of `features_df`, most similar first.
    """
    key = str(img_id)
    # look the query up in `features_df` itself (previously read the global `img_repr_df`)
    query = features_df.loc[key]
    query_vec = query['img_repr']
    # drop the query explicitly rather than assuming it ranks first
    others = features_df.drop(index=key)
    similarity = 1 - others['img_repr'].apply(lambda x: cosine(x, query_vec))
    # zero-norm vectors give NaN similarity: rank them last, not first
    order = np.argsort(similarity.fillna(-np.inf).to_numpy(), kind='stable')[::-1][:n]
    return query['img_path'], query['label'], others.iloc[order]

def get_similar_images_euclidean(features_df, img_id, n=10):
    """Return the `n` images closest to `img_id` in Euclidean distance, excluding the query itself.

    Mirrors `get_similar_images`: the query is looked up in `features_df`, which
    must be indexed by string image ID and have the columns 'img_path',
    'img_repr' and 'label'. The query is no longer returned as its own nearest
    neighbour, which previously inflated same-label counts.

    Returns:
        (img_path, label, neighbours_df) with the nearest neighbours first.
    """
    key = str(img_id)
    query = features_df.loc[key]
    query_vec = query['img_repr']
    others = features_df.drop(index=key)
    distance = others['img_repr'].apply(lambda x: euclidean(x, query_vec))
    order = np.argsort(distance.to_numpy(), kind='stable')[:n]
    return query['img_path'], query['label'], others.iloc[order]

def show_similar_images(features_df):
    """Show the images listed in `features_df` ('img_path', 'label_id') with their class names.

    Uses the notebook-global `learner` to reconstruct labels and draw the grid.
    """
    images = list(map(open_image, features_df['img_path']))
    categories = []
    for label_id in features_df['label_id']:
        categories.append(learner.data.train_ds.y.reconstruct(label_id))
    return learner.data.show_xys(images, categories)

def _similarity_label_stats(df, features_df, category_labels, trials, n_imgs, search, random_state=None):
    """Shared driver: per category, sample `trials` queries, run `search` and tally neighbour labels."""
    analysis = {}
    for category in df['label'].unique():
        start = time.time()
        members = df[df['label'] == category]
        size = len(members)
        if size < trials:
            print('skipping {}: requested {} samples out of {} images'.format(category_labels[category], trials, size))
            continue
        list_ids = []
        dataframes = []
        results_dict = {}
        for img_id in members.sample(trials, random_state=random_state).index:
            _, _, nn_features = search(features_df, img_id, n_imgs)
            list_ids.append(img_id)
            dataframes.append(nn_features)
            counts = (nn_features.groupby('label').size()
                      .reset_index(name='N imgs')
                      .sort_values(by='N imgs', ascending=False))
            # distinct loop names: the original overwrote the `n_imgs` argument here,
            # so the next search call received a list and failed
            for nn_label, count in zip(counts['label'].to_list(), counts['N imgs'].to_list()):
                name = category_labels[nn_label]
                results_dict[name] = results_dict.get(name, 0) + count
        analysis[category] = {'ID': list_ids, 'dfs': dataframes, 'res': results_dict,
                              'label': category_labels[category], 'len': size}
        end = time.time()
        print(f'{end - start} secs')
    return analysis


def cosine_similarity_analysis(df, features_df, category_labels, trials=100, n_imgs=100, random_state=None):
    """Cosine-similarity retrieval statistics per category.

    For each label in `df['label']` with at least `trials` images, samples
    `trials` query images (seeded by `random_state`), retrieves `n_imgs`
    neighbours with `get_similar_images` and sums the neighbours' label counts,
    keyed by the names in `category_labels`.

    Returns:
        dict label -> {'ID': query ids, 'dfs': neighbour frames,
        'res': {label name: count}, 'label': name, 'len': images in category}.
    """
    return _similarity_label_stats(df, features_df, category_labels, trials, n_imgs,
                                   get_similar_images, random_state)


def euclidean_stats(df, trials=100, n_imgs=100, features_df=None, category_labels=None, random_state=None):
    """Euclidean-distance counterpart of `cosine_similarity_analysis`.

    `features_df` defaults to `df` and `category_labels` to the notebook-global
    `dict_category_labels`; both names were previously referenced without being defined.
    """
    if features_df is None:
        features_df = df
    if category_labels is None:
        category_labels = dict_category_labels
    return _similarity_label_stats(df, features_df, category_labels, trials, n_imgs,
                                   get_similar_images_euclidean, random_state)

In [None]:
def is_same_date(new_date, start_date):
    """Return True when the two date values compare equal, False otherwise."""
    return bool(new_date == start_date)

def is_same_offset(xoff, yoff, start_xoff, start_yoff, rounded=False):
    """Return True when both (x, y) offsets match.

    With `rounded`, each value is truncated with `int()` before comparing.
    """
    cast = int if rounded else (lambda value: value)
    return bool(cast(xoff) == cast(start_xoff)) and bool(cast(yoff) == cast(start_yoff))

def filter_ids(df, df_ids, img_id, check_off=False, rounded=False, max_results=24):
    """Greedily keep candidate images that do not clash with already accepted ones.

    Candidates are visited in `df_ids.index` order. A candidate is rejected if it
    has the same 'Date' as the query or any accepted image; with `check_off` a
    same-date image is rejected only if its ('XOffset', 'YOffset') also match
    (see `is_same_offset`, `rounded` truncates offsets).

    Args:
        df: metadata frame indexed by integer ID with 'XOffset', 'YOffset', 'Date'.
        df_ids: candidates; only their index (image IDs) is used.
        img_id: query image ID (always treated as accepted, never returned).
        max_results: maximum number of IDs returned (24 fills a 5x5 grid around the query).

    Returns:
        list of int IDs, at most `max_results` long.
    """
    def _meta(i):
        return tuple(df.loc[int(i)][["XOffset", "YOffset", "Date"]].tolist())

    good_imgs = [int(img_id)]
    # cache accepted metadata instead of re-reading df.loc for every comparison
    good_meta = [_meta(img_id)]
    for i in df_ids.index.tolist():
        if len(good_imgs) > max_results:
            break  # the query plus `max_results` images are enough; later candidates are discarded anyway
        x, y, d = _meta(i)
        clashes = any(
            is_same_date(d, dj) and (not check_off or is_same_offset(x, y, xj, yj, rounded))
            for xj, yj, dj in good_meta
        )
        if not clashes:
            good_imgs.append(int(i))
            good_meta.append((x, y, d))
    return good_imgs[1:max_results + 1]
    

In [None]:
def plot_cosine(df, imgs_path, img_id, list_ids, fig_size=8, dpi=40, fsize=8):
    """Save a 5x5 grid: the query image in the centre, its neighbours around it.

    Args:
        df: metadata frame indexed by image ID (rows passed to `get_img`).
        imgs_path: directory with the original image files.
        img_id: query image ID.
        list_ids: neighbour IDs, plotted in the given (similarity) order.
        fig_size: size in inches of each subplot.
        dpi: resolution of the saved PNG.
        fsize: axis label/tick font size is `fsize * 2` (was an undefined name).

    Saves to 'cosine/<query Categories>/<query ID>.png' and closes the figure.
    """
    start_img = get_img_by_id(df, imgs_path, int(img_id))
    # keep the ranking of `list_ids` (Index.intersection followed df order instead)
    ranked_ids = [i for i in list_ids if i in df.index]
    images = []
    for i, image in df.loc[ranked_ids].iterrows():
        try:
            images.append(get_img(imgs_path, image))
        except Exception as exc:  # unreadable file: report and skip instead of swallowing silently
            print(i, exc)

    def _draw(ax, entry):
        meta, topo = entry
        topo.plot(ax=ax, cmap='afmhot', add_colorbar=False)
        ax.set_title(r"[{}] {} $\bf{{{}}}$".format(meta['Date'], meta['TF0_Filename'], meta.name), fontsize=20)
        for item in [ax.xaxis.label, ax.yaxis.label] + ax.get_xticklabels() + ax.get_yticklabels():
            item.set_fontsize(fsize * 2)

    plt.ioff()
    rows, cols = 5, 5
    centre = (rows // 2, cols // 2)
    figure, axs = plt.subplots(rows, cols, figsize=((fig_size * cols), (fig_size * rows)))
    neighbours = iter(images)
    for r in range(rows):
        for c in range(cols):
            if (r, c) == centre:
                _draw(axs[r, c], start_img)
                continue
            entry = next(neighbours, None)
            if entry is not None:
                _draw(axs[r, c], entry)
    plt.tight_layout()
    plt.draw()
    Path('cosine/{}'.format(start_img[0]['Categories'])).mkdir(parents=True, exist_ok=True)
    figure.savefig('cosine/{}/{}.png'.format(start_img[0]['Categories'], start_img[0].name), dpi=dpi)
    # close the figure created here (`fig` was undefined / a stale global)
    plt.close(figure)

In [None]:
def get_img(imgs_path, img):
    """Load and flatten the forward topography of one image.

    Args:
        imgs_path: directory containing the original files (str or Path).
        img: metadata row (pandas Series) with an 'ImageOriginalName' entry.

    Returns:
        [img, tf]: the row and the `Z_Forward` variable of the dataset read by
        `omicronscala.to_dataset`, after spym plane, align, plane and
        fixzero(to_mean=True).
    """
    # Path join works with or without a trailing separator (string concatenation did not)
    ds = omicronscala.to_dataset(Path(imgs_path) / img['ImageOriginalName'])
    tf = ds.Z_Forward
    tf.spym.plane()
    tf.spym.align()
    tf.spym.plane()
    tf.spym.fixzero(to_mean=True)
    return [img, tf]


def get_img_by_id(df, imgs_path, img_id):
    """Same as `get_img` for the row `df.loc[img_id]`; returns [row, tf]."""
    return get_img(imgs_path, df.loc[img_id])

In [None]:
def plot_cosine_results(stats, trials=100, n_imgs=100):
    """Plot, per category, neighbour-label counts scaled by their maximum, on a 4x4 grid.

    Args:
        stats: dict as returned by `cosine_similarity_analysis` (uses 'res' and 'label').
            Bars of the query's own label are green, others blue.
        trials, n_imgs: only used in the title and in the output filename.

    Saves the figure to 'cosine_S_<trials>x<n_imgs>.png'.
    """
    plt.ioff()
    rows, cols = 4, 4
    figure, axs = plt.subplots(rows, cols, figsize=(2 + (10 * cols), 10 * rows))
    figure.suptitle('Cosine similarity for {} trials of {} images for each category'.format(trials, n_imgs),
                    fontsize=36)
    keys = list(stats.keys())
    for idx, ax in enumerate(axs.flat):
        if idx >= len(keys):
            ax.axis('off')
            continue
        entry = stats[keys[idx]]
        labels = list(entry['res'].keys())
        counts = list(entry['res'].values())
        peak = max(counts)
        heights = [x / peak for x in counts]
        label = entry['label']
        ax.bar(labels, heights, color=['g' if x == label else 'b' for x in labels])
        ax.set_title('{}'.format(label), fontsize=28)
        for item in [ax.xaxis.label, ax.yaxis.label] + ax.get_xticklabels() + ax.get_yticklabels():
            item.set_fontsize(16)
        ax.set_ylabel('N images')
        for tick in ax.get_xticklabels():
            tick.set_rotation(90)

    # the figure is `figure`; `fig` was undefined (or a stale global) here
    figure.tight_layout(rect=[0, 0.03, 1, 0.95])
    figure.show()
    figure.savefig('cosine_S_{}x{}.png'.format(trials, n_imgs))

    
def plot_analysis_results_norm(stats, prefix, trials, n_imgs, len_label_dict):
    """Compare, per category, the neighbour-label distribution with a per-label reference value.

    For each entry of `stats` ('res' counts and 'label' name), two bar series are drawn:
      - the fraction of retrieved neighbours per label (green for the query's own label, blue otherwise);
      - in red, `len_label_dict[label]` (e.g. `len_dict`, each label's share of the dataset).

    Args:
        stats: dict from `cosine_similarity_analysis` / `euclidean_stats`.
        prefix: metric name used in the title and filename.
        trials, n_imgs: used in the title and filename.
        len_label_dict: label name -> reference value.

    Saves the 4x4 grid to '<prefix>_S_<trials>x<n_imgs>.png'.
    """
    plt.ioff()
    rows, cols = 4, 4
    figure, axs = plt.subplots(rows, cols, figsize=(2 + (10 * cols), 10 * rows))
    figure.suptitle('{} similarity for {} trials of {} images for each category'.format(prefix, trials, n_imgs),
                    fontsize=36)
    keys = list(stats.keys())
    w = 0.3  # bar width; reference bars are shifted right by one width
    for idx, ax in enumerate(axs.flat):
        if idx >= len(keys):
            ax.axis('off')
            continue
        entry = stats[keys[idx]]
        labels = list(entry['res'].keys())
        counts = list(entry['res'].values())
        reference = [len_label_dict[x] for x in labels]
        size = sum(counts)
        # separate name: the original overwrote the `n_imgs` argument used in the filename below
        fractions = [x / size for x in counts]
        label = entry['label']
        interval = np.arange(len(labels))
        ax.bar(interval + 0.0, fractions, w, color=['g' if x == label else 'b' for x in labels])
        ax.bar(interval + w, reference, w, color='r')
        ax.set_xticks(interval)
        ax.set_xticklabels(labels)
        ax.set_title('{}'.format(label), fontsize=28)
        for item in [ax.xaxis.label, ax.yaxis.label] + ax.get_xticklabels() + ax.get_yticklabels():
            item.set_fontsize(16)
        ax.set_ylabel('N images')
        for tick in ax.get_xticklabels():
            tick.set_rotation(90)

    # the figure is `figure`; `fig` was undefined (or a stale global) here
    figure.tight_layout(rect=[0, 0.03, 1, 0.95])
    figure.show()
    figure.savefig('{}_S_{}x{}.png'.format(prefix, trials, n_imgs))
    

In [None]:
# path to the directory with the original image files (read by get_img / get_img_by_id)
path = 'path_to_images'

# load stm metadata df
stm = load_pickle('clean_stm')

# good images manually selected
goods_dict = {
    "N_Gr_Ni111": [87980, 87931, 88019, 87795, 84551, 84568, 87048, 83206, 87912, 87774],
    "Gr_Ni111": [85804, 83502, 83080, 83018, 83005, 82729, 82061, 50736, 50701, 49062],
    "Gr_Ni100": [77857, 77795, 77809, 77690, 77696, 79863, 77649, 77626, 79779, 79729],
    "NFFA_ID617": [86169, 86687, 86374, 83699, 83734, 84700, 83880, 84133, 85763, 85800],
}

# similarity search with filtering for each image
for k, v in goods_dict.items():
    print('\n{}'.format(k))
    for image_id in v:
        print('\t{}'.format(image_id))
        _, _, ids = get_similar_images(img_repr_df, image_id, 250)
        nn_ids = filter_ids(stm, ids, image_id, check_off=True, rounded=True)
        # `imgs_path` and `ID` were undefined and pushed nn_ids into the `fig_size` slot
        plot_cosine(stm, path, image_id, nn_ids)

In [None]:
# single image similarity example (Euclidean distance in feature space)
base_image, base_label, similar_images_df = get_similar_images_euclidean(img_repr_df, 44148, 100)
print(base_label)
print(base_image)
# only a cell's last expression is rendered, so the query image must be displayed explicitly
display(open_image(base_image))
show_similar_images(similar_images_df)

In [None]:
# features validation by statistical analysis on extracted images from similarity search

# features_df and category_labels were missing, binding 500/20 to the wrong parameters
stats2 = cosine_similarity_analysis(img_repr_df, img_repr_df, dict_category_labels, 500, 20)
plot_analysis_results_norm(stats2, 'cosine', 500, 20, len_dict)
save_pickle(stats2, 'cosine_500_20')

stats3 = euclidean_stats(img_repr_df, 100, 100)
plot_analysis_results_norm(stats3, 'euclidean', 100, 100, len_dict)
save_pickle(stats3, 'euclidean_100_100')

stats4 = euclidean_stats(img_repr_df, 500, 20)
plot_analysis_results_norm(stats4, 'euclidean', 500, 20, len_dict)
save_pickle(stats4, 'euclidean_500_20')

In [None]:
from sklearn.neighbors import NearestNeighbors
from sklearn.linear_model import LinearRegression

def compute_intrinsic_dimension(data_array):
    """Estimate the intrinsic dimension of `data_array` with the TwoNN estimator.

    For each point, mu = r2 / r1 is the ratio of the distances to its second and
    first nearest neighbours. The dimension is the slope (through the origin) of
    -log(1 - F_emp(mu)) against log(mu), where F_emp is the empirical CDF of mu.

    Args:
        data_array: array-like of shape (n_samples, n_features).

    Returns:
        float: the estimated intrinsic dimension.
    """
    # querying with the fitted data: column 0 is the point itself, columns 1-2 give r1, r2
    nn = NearestNeighbors(n_neighbors=3, algorithm='kd_tree', n_jobs=-1).fit(data_array)
    nn_distances, _ = nn.kneighbors(data_array)
    r1 = nn_distances[:, 1]
    r2 = nn_distances[:, 2]
    # duplicated points give r1 == 0 and an undefined ratio; TwoNN needs distinct neighbours
    valid = r1 > 0
    mu = r2[valid] / r1[valid]
    n = mu.shape[0]
    f_emp = np.zeros(n, dtype=float)
    f_emp[np.argsort(mu)] = np.arange(n) / n
    x = np.log(mu).reshape(-1, 1)
    # uses `f_emp` (the original referenced the undefined name `F_emp`)
    y = -np.log(1. - f_emp).reshape(-1, 1)
    l = LinearRegression(fit_intercept=False, n_jobs=1).fit(x, y)
    return l.coef_[0, 0]

def plot_components(data_array, nn_model, images=None, axs=None,
                    thumb_frac=0.05, colormap='gray'):
    """Scatter the 2-D embedding `nn_model.fit_transform(data_array)`, optionally with thumbnails.

    Args:
        data_array: samples to embed (first dimension = samples).
        nn_model: estimator with `fit_transform` returning at least 2 columns (e.g. Isomap).
        images: optional per-sample images drawn as thumbnails at their projected position.
        axs: matplotlib Axes to draw on; defaults to `plt.gca()`.
        thumb_frac: thumbnails closer than this fraction of the embedding extent are skipped.
        colormap: colormap for the thumbnails.
    """
    from matplotlib import offsetbox

    if axs is None:
        axs = plt.gca()
    # embed the given data (the original transformed the undefined global `x`)
    proj = nn_model.fit_transform(data_array)
    axs.plot(proj[:, 0], proj[:, 1], '.k')
    if images is not None:
        min_dist_2 = (thumb_frac * max(proj.max(0) - proj.min(0))) ** 2
        shown_images = np.array([2 * proj.max(0)])
        for i in range(data_array.shape[0]):
            dist = np.sum((proj[i] - shown_images) ** 2, 1)
            if np.min(dist) < min_dist_2:
                # don't show points that are too close
                continue
            shown_images = np.vstack([shown_images, proj[i]])
            image_box = offsetbox.AnnotationBbox(
                offsetbox.OffsetImage(images[i], cmap=colormap), proj[i])
            # draw on the target Axes (the original used the undefined name `ax`)
            axs.add_artist(image_box)

In [None]:
# 2-D Isomap embedding of the saved feature tensor
from sklearn.manifold import Isomap

xx = torch.load('S_features_resnet50_4096')
mod = Isomap(n_components=2)
fig, ax = plt.subplots(figsize=(10, 10))
plot_components(xx, mod, axs=ax)

In [None]:
from sklearn.cluster import KMeans
from matplotlib import pyplot as plt

K_VALUES = range(2, 20)
KMEANS_SEED = 0  # fixed seed so the elbow curve is reproducible across runs

X = xx
distorsions = []
for k in K_VALUES:
    k_means = KMeans(n_clusters=k, random_state=KMEANS_SEED)
    k_means.fit(X)
    distorsions.append(k_means.inertia_)

fig = plt.figure(figsize=(15, 5))
plt.plot(K_VALUES, distorsions)
plt.grid(True)
plt.title('Elbow curve')

In [None]:
# second difference of the distortion (distorsions[0] is k=2); dy2[i] is centred on k = i + 3
dy2 = np.diff(distorsions, n=2)
fig = plt.figure(figsize=(15, 5))
plt.plot(range(3, 3 + len(dy2)), dy2)
plt.grid(True)
plt.title('Elbow curve')

In [None]:
# second difference of log-distortion (distorsions[0] is k=2); dy2_log[i] is centred on k = i + 3
dy2_log = np.diff(np.log(distorsions), n=2)
fig = plt.figure(figsize=(15, 5))
# plot the log curve computed above (the original plotted `dy2` by mistake)
plt.plot(range(3, 3 + len(dy2_log)), dy2_log)
plt.grid(True)
plt.title('Elbow curve')