In [1]:
import basicpy
from basicpy import datasets
from matplotlib import pyplot as plt
from skimage import io
from subprocess import check_output
from pathlib import Path
import pandas as pd
import time

In [2]:
import os

# Root folder of the publication analysis. Defaults to the original author's
# location; set the BASICPY_ANALYSIS_DIR environment variable to run elsewhere.
basedir = Path(os.environ.get(
    "BASICPY_ANALYSIS_DIR",
    "/work/fukai/basicpy/BaSiCPy/misc_notebooks/analysis_for_publication/",
))
# Folder where the TIFF inputs for Fiji and the fitted flat/dark fields
# produced by the BaSiCPy benchmark cells are written.
imagedir = basedir / "testdata_for_imagej"

In [7]:
# Benchmark the ImageJ/Fiji BaSiC plugin on every rescaled test dataset.
# Each dataset is exported once as a TIFF. For every trial and darkfield
# setting, Fiji is then run headless with `imagej_macro.py`. The elapsed time
# is parsed from the macro's stdout line that starts with "erapsed time"
# (spelling kept byte-for-byte because it must match what the macro prints).
FIJI_EXECUTABLE = "/opt/Fiji.app/ImageJ-linux64"
ELAPSED_TIME_MARKER = "erapsed time"
N_TRIALS = 10

imagedir.mkdir(parents=True, exist_ok=True)

# The exported TIFFs do not change between trials, so write each one only once.
tiff_paths = {}
for name in datasets.RESCALED_TEST_DATA_PROPS:
    filepath = str(imagedir / f"{name}.tif")
    io.imsave(filepath, datasets.fetch(name))
    print(filepath)
    tiff_paths[name] = filepath

fitting_res = []
for trial in range(N_TRIALS):
    for name, filepath in tiff_paths.items():
        for get_darkfield in [False, True]:
            # Pass an argv list (no shell) so that paths containing spaces or
            # quotes can neither break nor inject into the command line.
            macro_args = f"filename='{filepath}', get_darkfield={int(get_darkfield)}"
            res = check_output(
                [FIJI_EXECUTABLE, "--headless", "--run", "imagej_macro.py", macro_args]
            )
            stdout = res.decode()
            timing_lines = [
                line for line in stdout.splitlines()
                if line.startswith(ELAPSED_TIME_MARKER)
            ]
            if len(timing_lines) != 1:
                raise RuntimeError(
                    f"Expected exactly one '{ELAPSED_TIME_MARKER}' line from Fiji for "
                    f"{name} (get_darkfield={get_darkfield}), found "
                    f"{len(timing_lines)}. Full output:\n{stdout}"
                )
            # Split only on the first ':' so any further colons stay in the value part.
            erapsed_time = float(timing_lines[0].split(":", 1)[1])
            print(name, erapsed_time)
            fitting_res.append({
                "image_name": name,
                "trial": trial,
                "method": "ImageJ BaSiC",
                "get_darkfield": get_darkfield,
                "erapsed_time": erapsed_time,
            })
fitting_res_df = pd.DataFrame.from_records(fitting_res)
fitting_res_df.to_csv("imagej_benchmark.csv")

/work/fukai/basicpy/BaSiCPy/misc_notebooks/analysis_for_publication/testdata_for_imagej/cell_culture.tif




In [None]:
import jax

# Benchmark BaSiCPy (JAX backend) on the GPU.
# NOTE(review): changing `jax_platform_name` after JAX has initialised its
# backends may have no effect. Array placement is therefore also pinned with
# `jax.default_device`, so the device used does not depend on cell order.
jax.config.update('jax_platform_name', 'gpu')
N_TRIALS = 10

fitting_res = []
with jax.default_device(jax.devices("gpu")[0]):
    for trial in range(N_TRIALS):
        for name in datasets.RESCALED_TEST_DATA_PROPS.keys():
            images = datasets.fetch(name)
            for get_darkfield in [False, True]:
                b = basicpy.BaSiC(fitting_mode="approximate", get_darkfield=get_darkfield)
                # Untimed warm-up fit: absorbs one-off JIT compilation, so the
                # timed fit below measures steady-state fitting only.
                b.fit(images)
                # perf_counter is monotonic and high-resolution (time.time is not).
                # NOTE(review): assumes fit() returns only after the device work has
                # finished (e.g. results copied back to host). With JAX's
                # asynchronous dispatch this should be confirmed.
                start = time.perf_counter()
                b.fit(images)
                erapsed_time = time.perf_counter() - start
                suffix = "with_darkfield" if get_darkfield else "no_darkfield"
                io.imsave(imagedir / f"jax_gpu_{name}_flatfield_{suffix}.tif", b.flatfield)
                io.imsave(imagedir / f"jax_gpu_{name}_darkfield_{suffix}.tif", b.darkfield)
                fitting_res.append({
                    "image_name": name,
                    "trial": trial,
                    "method": "GPU BaSiCPy",
                    "get_darkfield": get_darkfield,
                    "erapsed_time": erapsed_time,
                })
fitting_res_df2 = pd.DataFrame.from_records(fitting_res)
fitting_res_df2.to_csv("gpu_benchmark.csv")

15.0010001659

In [None]:
import jax

# Benchmark BaSiCPy (JAX backend) on the CPU.
# NOTE(review): if the GPU cell already initialised JAX, updating
# `jax_platform_name` here may not take effect and this "CPU" run could
# silently execute on the GPU. Placement is therefore pinned explicitly to the
# CPU device with `jax.default_device`.
jax.config.update('jax_platform_name', 'cpu')
N_TRIALS = 10

fitting_res = []
with jax.default_device(jax.devices("cpu")[0]):
    for trial in range(N_TRIALS):
        for name in datasets.RESCALED_TEST_DATA_PROPS.keys():
            images = datasets.fetch(name)
            for get_darkfield in [False, True]:
                b = basicpy.BaSiC(fitting_mode="approximate", get_darkfield=get_darkfield)
                # Untimed warm-up fit: absorbs one-off JIT compilation, so the
                # timed fit below measures steady-state fitting only.
                b.fit(images)
                # perf_counter is monotonic and high-resolution (time.time is not).
                start = time.perf_counter()
                b.fit(images)
                erapsed_time = time.perf_counter() - start
                suffix = "with_darkfield" if get_darkfield else "no_darkfield"
                # Use a "jax_cpu_" prefix: the previous "jax_gpu_" prefix
                # overwrote the GPU cell's fitted fields.
                io.imsave(imagedir / f"jax_cpu_{name}_flatfield_{suffix}.tif", b.flatfield)
                io.imsave(imagedir / f"jax_cpu_{name}_darkfield_{suffix}.tif", b.darkfield)
                fitting_res.append({
                    "image_name": name,
                    "trial": trial,
                    "method": "CPU BaSiCPy",
                    "get_darkfield": get_darkfield,
                    "erapsed_time": erapsed_time,
                })
fitting_res_df3 = pd.DataFrame.from_records(fitting_res)
fitting_res_df3.to_csv("cpu_benchmark.csv")

In [None]:
# Stack the ImageJ, GPU and CPU timing tables into one long-format frame.
# Each source frame has its own 0..n-1 index. ignore_index=True renumbers the
# rows so the combined index (written to CSV later) has no duplicate labels.
fitting_res_df_all = pd.concat(
    [fitting_res_df, fitting_res_df2, fitting_res_df3],
    ignore_index=True,
)

In [None]:
# Persist the combined ImageJ / GPU / CPU timing table (index column included).
fitting_res_df_all.to_csv(path_or_buf="time_benchmark.csv")