In [None]:
import pickle 
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt 
import omicronscala
import spym
import xarray
import os
from pathlib import Path
from tqdm import tqdm
from multiprocessing import Pool
from functools import partial


def load_stm():
    """Return the cleaned STM metadata unpickled from ``clean_stm.pkl``.

    The file is looked up relative to the current working directory.
    NOTE(review): ``pickle.load`` can execute arbitrary code — only use this
    with a trusted, locally produced pickle.
    """
    pickle_file = 'clean_stm.pkl'
    with open(pickle_file, 'rb') as handle:
        return pickle.load(handle)


def save_train_img(imgs_path, df, img_id):
    """Render one STM scan as a 224x224 px training image.

    Looks up ``img_id`` in ``df``, loads the scan file named by that row's
    ``ImageOriginalName`` from ``imgs_path``, applies spym's align/plane/
    fixzero processing to the ``Z_Forward`` channel and writes the plot
    (no axes, no colorbar, ``afmhot`` colormap) to
    ``data/train/<img_id>.png`` relative to the current working directory.

    Parameters
    ----------
    imgs_path : str or path-like
        Directory containing the raw scan files.
    df : pandas.DataFrame
        Metadata table indexed by image ID with an ``ImageOriginalName``
        column.
    img_id
        Index label of the row to render.
    """
    plt.ioff()
    scan_name = df.loc[img_id].ImageOriginalName
    # Path join works whether or not imgs_path has a trailing separator.
    ds = omicronscala.to_dataset(Path(imgs_path) / scan_name)
    tf = ds.Z_Forward
    # Return values are ignored, so these spym calls presumably modify tf in
    # place (alignment, plane removal, zero at the mean) — verify against spym.
    tf.spym.align()
    tf.spym.plane()
    tf.spym.fixzero(to_mean=True)

    out_dir = Path('data/train')
    out_dir.mkdir(parents=True, exist_ok=True)

    # 4 in * 56 dpi = 224 px per side.
    fig = plt.figure(figsize=(4, 4))
    try:
        # Axes filling the whole canvas so the image has no margins.
        axis = plt.Axes(fig, [0., 0., 1., 1.])
        axis.set_axis_off()
        fig.add_axes(axis)
        tf.plot(ax=axis, cmap='afmhot', add_colorbar=False)
        fig.savefig(out_dir / '{}.png'.format(img_id), dpi=56)
    finally:
        # Always release the figure, even if plotting fails, so batch runs
        # over many scans do not accumulate open figures.
        plt.close(fig)


def make_training_set(imgs_path, df):
    """Sequentially render a training image for every row of ``df``.

    Best-effort: if one image fails, the error is printed together with the
    image ID and processing continues with the next row.

    Parameters
    ----------
    imgs_path : str or path-like
        Directory containing the raw scan files.
    df : pandas.DataFrame
        Metadata table (or a slice of it) indexed by image ID.
    """
    for img_id in tqdm(df.index, position=0, leave=True):
        try:
            # Use the df we were given (not a global), so each worker in
            # multicore_train only touches its own chunk.
            save_train_img(imgs_path, df, img_id)
        except Exception as e:
            print('{}: {}'.format(img_id, e))


def multicore_train(imgs_path, df, workers):
    """Build the training set in parallel across ``workers`` processes.

    ``df`` is split into ``workers`` contiguous, near-equal row chunks, and
    each chunk is passed to ``make_training_set`` in a worker process.

    Parameters
    ----------
    imgs_path : str or path-like
        Directory containing the raw scan files.
    df : pandas.DataFrame
        Metadata table indexed by image ID.
    workers : int
        Number of worker processes (and chunks).
    """
    # Split positional indices rather than the DataFrame itself:
    # np.array_split(df, ...) goes through DataFrame.swapaxes, which recent
    # pandas deprecates.
    chunks = [df.iloc[idx] for idx in np.array_split(np.arange(len(df)), workers)]
    func = partial(make_training_set, imgs_path)
    # map() blocks until every chunk is done, so the terminate() performed by
    # the context manager on exit does not cut off any work.
    with Pool(workers) as pool:
        pool.map(func, chunks)


def autocomplete_train(df, imgs_path):
    """Render any training images that are missing from ``data/train``.

    Checks ``data/train/<id>.png`` (relative to the current working
    directory, the same location ``save_train_img`` writes to) for every
    index label of ``df``, and renders only the missing ones. Failures are
    printed with their image ID and skipped.

    Parameters
    ----------
    df : pandas.DataFrame
        Metadata table indexed by image ID.
    imgs_path : str or path-like
        Directory containing the raw scan files.
    """
    out_dir = Path('data/train')
    for img_id in df.index:
        if (out_dir / '{}.png'.format(img_id)).is_file():
            continue
        try:
            save_train_img(imgs_path, df, img_id)
        except Exception as e:
            # Best-effort retry pass: report and move on, but do not swallow
            # KeyboardInterrupt/SystemExit like a bare except would.
            print('{}: {}'.format(img_id, e))

In [None]:
# Build the training set: a parallel pass first, then a sequential pass that
# fills in any images that are still missing.
# The __main__ guard keeps worker processes from re-running this under the
# "spawn" start method (default on Windows/macOS).
if __name__ == '__main__':
    stm = load_stm()
    path = "path_to_images"  # directory holding the raw scan files
    n_cores = 24
    multicore_train(path, stm, n_cores)  # signature: (imgs_path, df, workers)
    autocomplete_train(stm, path)