In [None]:
import wandb
# Log in to Weights & Biases so the online wandb.init() run later in this
# notebook can upload its artifact and metrics.
wandb.login()

In [None]:
from diametery.line_fit import LineFit
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
from PIL import Image, ImageOps
import matplotlib.pyplot as plt

In [35]:
def load_task(file_path):
    """Load one dataset example: a grayscale image and its segmentation.

    Parameters
    ----------
    file_path : str
        Path to the image file. The segmentation is read from the sibling
        file ``file_path + '_seg'``, an ``np.load``-able archive that stores
        the label array under key ``'y'``.

    Returns
    -------
    im : np.ndarray
        The image converted to grayscale with ``ImageOps.grayscale``.
    seg : np.ndarray
        The ``'y'`` array from the segmentation archive.
    """
    # Context managers release both file handles deterministically instead of
    # leaving the PIL image and the NpzFile open until garbage collection.
    with Image.open(file_path) as img:
        im = np.array(ImageOps.grayscale(img))
    with np.load(file_path + '_seg') as seg_archive:
        seg: np.ndarray = seg_archive['y']
    return im, seg

def select_point_and_fiber(seg):
    """Pick a random foreground voxel and return the mask of the fiber it hits.

    Parameters
    ----------
    seg : np.ndarray
        Label array; entries > 0 are fiber labels, everything else is
        background. Any rank is accepted (the notebook passes a 3-D volume).

    Returns
    -------
    point : np.ndarray
        The first two coordinates of the sampled voxel.
    selected_seg : np.ndarray
        float32 array shaped like ``seg``: 1.0 where ``seg`` equals the sampled
        voxel's label, 0.0 elsewhere.

    Raises
    ------
    ValueError
        If ``seg`` contains no voxel with a label > 0.
    """
    possible_points = np.argwhere(seg > 0)
    n_points = possible_points.shape[0]
    if n_points == 0:
        raise ValueError("seg contains no foreground voxels (no label > 0)")
    # randint's upper bound is exclusive: passing n_points (not n_points - 1)
    # makes every foreground voxel selectable, including the last one, and
    # works when there is only a single foreground voxel.
    point = possible_points[np.random.randint(n_points)]
    fiber_id = seg[tuple(point)]
    selected_seg = (seg == fiber_id).astype(np.float32)
    return point[0:2], selected_seg

In [51]:
# Sanity-check the loaders on a single local example.
# Set DIAMETERY_DATASET_DIR to point at another copy of the dataset instead of
# editing this machine-specific default path.
dataset_path = os.environ.get(
    'DIAMETERY_DATASET_DIR',
    '/Users/carmenlopez/dev/diameterY/scratch/dataset_files_3D',
)
task_id = 'test0000'
file_path = os.path.join(dataset_path, task_id)
# Seed NumPy's global RNG so the sampled point (and hence fiber) is reproducible.
np.random.seed(0)
im, seg = load_task(file_path)
point, selected_seg = select_point_and_fiber(seg)
# Keep only the first two axes; this reshape succeeds only when all trailing
# axes have size 1 (e.g. a single-slice depth axis).
selected_seg = selected_seg.reshape((selected_seg.shape[0:2]))
print(selected_seg.shape)


(256, 256)


In [None]:
# Evaluate LineFit on every test example of the W&B dataset artifact, then log a
# per-example results table and the mean relative diameter error.
with wandb.init(project="diameterY", job_type="test", mode="online") as run:
    run.config.n_measurements = 30
    run.config.step_size = 0.3
    # Download the 3D fibers dataset artifact.
    dataset_artifact = run.use_artifact("rendered-fibers-mini:v0")
    dataset_dir = dataset_artifact.download("dataset_files_3D")
    model = LineFit(run.config.n_measurements, run.config.step_size)
    # Mask class id -> name for the W&B image overlay; identical for every image.
    class_labels = {0: 'bg', 1: 'measured_lines'}
    # Main example files only ('_params' / '_seg' are companion files), sorted so
    # the rows come out in a stable order regardless of filesystem listing order.
    test_files = sorted(
        f for f in os.listdir(dataset_dir)
        if f.startswith("test") and not f.endswith(("_params", "_seg"))
    )
    # NOTE(review): only example["x"] / example["d"] are needed here. The image
    # and segmentation loaders (load_task / select_point_and_fiber) were called
    # on the same path before but their results were never used, and the same
    # file cannot be read both by Image.open and np.load. Re-add them if a
    # point-conditioned prediction is introduced.
    rows = []
    for f in tqdm(test_files, desc="images"):
        file_path = os.path.join(dataset_dir, f)
        # Close the archive once this example's arrays have been read.
        with np.load(file_path) as example:
            diameter_pred, mask_meas_lines = model.predict(example["x"])
            wandb_im = wandb.Image(example["x"], caption="masks_measured_lines", masks={
                'measurements': {
                    'mask_data': mask_meas_lines,
                    'class_labels': class_labels,
                }})
            rows.append(dict(
                measured_lines=wandb_im,
                d=example["d"],
                d_pred=diameter_pred,
            ))
    if not rows:
        # Without rows the DataFrame has no columns and the error computation
        # below would fail with an unhelpful KeyError.
        raise RuntimeError(f"No test examples found in {dataset_dir!r}")
    df = pd.DataFrame(rows)
    # Relative absolute error |d_pred - d| / d; column/metric names are kept so
    # existing W&B reports continue to work.
    df["Error_abs"] = ((df["d_pred"] - df["d"]) / df["d"]).abs()
    mean_abs_error = df["Error_abs"].mean()
    artifact = wandb.Artifact("test_table", type="test-results")
    table = wandb.Table(dataframe=df)
    artifact.add(table, name="test-results")
    run.log_artifact(artifact)
    run.log(dict(mean_abs_error=mean_abs_error))
    
    