In [None]:
import pickle 
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt 
import omicronscala
import spym
import xarray
import os
from pathlib import Path
from tqdm import tqdm
from multiprocessing import Pool

def load_stm():
    """Load the cleaned STM image table from ``clean_stm.pkl``.

    The pickle is read from the current working directory. The rest of this
    notebook uses the result as a pandas DataFrame indexed by image ID.

    NOTE: unpickling can execute arbitrary code; only load a file you
    produced yourself.
    """
    raw = Path('clean_stm.pkl').read_bytes()
    return pickle.loads(raw)

def imgTrain(path, df, ID):
    """Render one STM image to ``data/train/<ID>.png`` at 224x224 px.

    The source file name comes from ``df.loc[ID].ImageOriginalName`` and is
    resolved inside the directory ``path``. The forward Z channel is corrected
    with spym (plane, align, plane, zero to mean). It is then drawn with the
    ``afmhot`` colormap on an axis that fills the whole canvas, with no
    ticks or colorbar. A 4 in x 4 in figure saved at 56 dpi gives
    224 x 224 px.

    Args:
        path: Directory containing the original Omicron Scala files.
        df: Table indexed by image ID with an ``ImageOriginalName`` column.
        ID: Index label of the image to render.
    """
    plt.ioff()  # figures are only written to disk, never shown
    file = df.loc[ID].ImageOriginalName
    # Join as directory/file so ``path`` works with or without a trailing slash.
    ds = omicronscala.to_dataset(Path(path) / file)
    tf = ds.Z_Forward
    # Return values are discarded, so these corrections are presumably applied
    # to ``tf`` in place (NOTE(review): confirm against spym docs).
    tf.spym.plane()
    tf.spym.align()
    tf.spym.plane()
    tf.spym.fixzero(to_mean=True)
    fig = plt.figure(figsize=(4, 4))
    try:
        # Axis spans the full figure so the image has no margins.
        axis = fig.add_axes([0., 0., 1., 1.])
        axis.set_axis_off()
        tf.plot(ax=axis, cmap='afmhot', add_colorbar=False)
        Path('data/train').mkdir(parents=True, exist_ok=True)
        # ``aspect`` is not a savefig option (rejected by matplotlib >= 3.5).
        fig.savefig('data/train/{}.png'.format(ID), dpi=56)
    finally:
        # Always release the figure, even if plotting/saving fails, so a long
        # batch run does not accumulate open figures.
        plt.close(fig)

def make_train(df, path):
    """Render every image in ``df`` into the training folder, sequentially.

    A failure on one image is reported (ID and error) and skipped, so a
    single corrupt file does not abort the whole run.

    Args:
        df: Table indexed by image ID; passed to ``imgTrain`` for lookups.
        path: Directory containing the original image files.
    """
    for ID in tqdm(df.index, position=0, leave=True):
        try:
            # Use the frame we were given (not a global), so this works in
            # worker processes that only receive their own chunk.
            imgTrain(path, df, ID)
        except Exception as e:
            print('{}: {}'.format(ID, e))

def multicore_train(df, path, ncores):
    """Build the training set in parallel across ``ncores`` processes.

    ``df`` is split into ``ncores`` contiguous row chunks. Each chunk is
    handed to ``make_train`` together with ``path``.

    Args:
        df: Table indexed by image ID.
        path: Directory containing the original image files.
        ncores: Number of worker processes (and chunks).
    """
    # Split row positions rather than the DataFrame itself; np.array_split on
    # a DataFrame relies on DataFrame.swapaxes, deprecated in pandas 2.1.
    chunks = [df.iloc[idx] for idx in np.array_split(np.arange(len(df)), ncores)]
    with Pool(ncores) as pool:
        # make_train needs both arguments; plain map would drop ``path``.
        pool.starmap(make_train, [(chunk, path) for chunk in chunks])

def autocomplete_train(df, path):
    """Render any training image whose PNG is missing from ``data/train``.

    Each ID in ``df.index`` is checked against ``data/train/<ID>.png``,
    relative to the current working directory. That is the same location
    ``imgTrain`` writes to. Missing images are regenerated with ``imgTrain``.
    IDs that still fail are printed with the error that caused the failure.

    Args:
        df: Table indexed by image ID.
        path: Directory containing the original image files.
    """
    out_dir = Path('data/train')
    for ID in df.index:
        if (out_dir / '{}.png'.format(ID)).is_file():
            continue
        try:
            imgTrain(path, df, ID)
        except Exception as e:
            print('{}: {}'.format(ID, e))

In [None]:
if __name__ == "__main__":
    # Guard so worker processes started with the "spawn" method (Windows, and
    # macOS by default) do not re-run this driver when they import the module.
    # Inside a notebook __name__ is "__main__", so the cell still runs.
    stm = load_stm()
    path = "path_to_images"  # directory holding the original image files
    # Do not start more worker processes than this machine has cores.
    ncores = min(24, os.cpu_count() or 1)
    multicore_train(stm, path, ncores)
    autocomplete_train(stm, path)