In [1]:
import os.path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import xarray as xr

from pydicom import dcmread

import src.df_filters as dff
import src.file_traversal
import src.make_df as make_df
import src.process_data as proc_data

from src.file_traversal import read_config
from src.process_data import batch_processor
from src.process_data import process_images


In [2]:
# Project configuration, read via the project's read_config helper.
config = read_config('config')

# Read the labels CSV named in the config and hand it straight to
# fill_labels_df; only the filled frame is kept under the name `labels`.
labels = make_df.fill_labels_df(pd.read_csv(config['labels_path']))

# Table of scan files, built by fill_scan_df from the labels and the config.
scan_paths = make_df.fill_scan_df(labels, config)

In [3]:
# Summary statistics for the numeric columns of `labels`.
# MGMT_value only takes values between 0 and 1 (it appears to be a binary
# target), so its mean reads as the fraction of positive cases.
labels.describe()

Unnamed: 0,MGMT_value
count,585.0
mean,0.524786
std,0.499813
min,0.0
25%,0.0
50%,1.0
75%,1.0
max,1.0


In [4]:
# Summary statistics for the numeric columns of `scan_paths`
# (presumably one row per discovered image file -- TODO confirm).
scan_paths.describe()

Unnamed: 0,image_id
count,96766.0
mean,104.180559
std,72.989321
min,1.0
25%,45.0
50%,95.0
75%,151.0
max,400.0


In [5]:
# Test 1: every labelled patient must appear in scan_paths.
# The set difference holds the patient ids present in the labels index but
# absent from scan_paths['pid']; it has to be empty.
missing_pids = np.setdiff1d(labels.index, scan_paths['pid'].unique())
assert missing_pids.size == 0, (
    f'{missing_pids.size} labelled patient(s) have no scans: {missing_pids[:10]}'
)
missing_pids  # still displayed: an empty array when the check passes

array([], dtype=object)

In [None]:
# Gather the array shape of one representative image per patient.
shape_records = []
for _pid, patient_rows in scan_paths.groupby('pid'):
    # Pick the row sitting at the positional midpoint of this patient's rows.
    middle = len(patient_rows) // 2
    middle_path = patient_rows['file_path'].iloc[middle]
    shape_records.append(np.shape(proc_data.read_dicom(middle_path)))

# One row per patient; 'x' is axis 0 and 'y' is axis 1 of the image array.
im_dims = pd.DataFrame(shape_records, columns=['x', 'y'])

# Count how often each (x, y) shape occurs, split into square and
# non-square images.
is_square = im_dims['x'] == im_dims['y']
squares = im_dims[is_square].value_counts()
rectangles = im_dims[~is_square].value_counts()

In [87]:
# Bar chart of how many patients have each square image size.
# `squares` is indexed by (x, y) pairs with x == y, so the x component alone
# identifies each bar. Array-like inputs are named 'x' and 'y' by plotly
# express, which is what the `labels` mapping refers to.
px.bar(
    x=[x for x, y in squares.index],
    y=squares.values,
    labels={'x': 'Side length (pixels)', 'y': 'Number of patients'},
    title='Square image sizes (middle slice per patient)',
)

In [86]:
# Grouped bar chart: how often each value occurs in the 'x' and 'y' columns
# of im_dims (one row per patient).
# NOTE(review): 'x' is axis 0 of np.shape(image), which for a row-major image
# array is the row count -- confirm that 'Width' and 'Height' are not swapped.
x_dat = im_dims['x'].value_counts()
y_dat = im_dims['y'].value_counts()
fig = go.Figure(data=[
    go.Bar(name='Width', x=x_dat.index, y=x_dat.values),
    go.Bar(name='Height', x=y_dat.index, y=y_dat.values),
])
fig.update_layout(
    barmode='group',
    title='Image dimensions (middle slice per patient)',
    xaxis_title='Size (pixels)',
    yaxis_title='Number of patients',
)
fig.show()

In [6]:
# Load the scan of a single example patient (pid '01010') through the
# project's read_full_3d_scan helper, for manual inspection.
scan = proc_data.read_full_3d_scan(scan_paths, '01010')

In [7]:
# Inspect the loaded example scan (rendered only when this expression is the
# last statement of its cell).
scan

from src.process_data import batch_processor

@batch_processor(batch_size=50)
def process_scans(batch):
    """Load the full 3-D scan of every patient in ``batch``.

    Each element of ``batch`` is passed, together with the notebook-level
    ``scan_paths`` table, to ``proc_data.read_full_3d_scan``.

    Returns:
        list: the loaded scans, in the same order as ``batch``.
    """
    return [proc_data.read_full_3d_scan(scan_paths, pid) for pid in batch]

import os.path
# The first 500 entries of the labels index, wrapped as a 1-D DataArray
# along a 'pid' dimension.
patients = xr.DataArray(labels.index[:500], dims='pid')

# Write every batch produced by process_scans to its own netCDF file, named
# after the batch's `.name`.
for processed_batch in process_scans(patients):
    file_name = f'{processed_batch.name}.nc'
    save_path = os.path.join('data/netcdf/', file_name)
    processed_batch.to_netcdf(path=save_path)

# Slices whose variance over (x, y) is at or below this value are treated as
# basically blank and dropped.
BLANK_SLICE_VARIANCE = 4E-4


@batch_processor(batch_size=50)
def process_scans(batch):
    """Normalise, trim and save the 3-D scan of every patient in ``batch``.

    For each patient the scan is loaded with ``proc_data.read_full_3d_scan``
    (using the notebook-level ``scan_paths`` table), divided by its own
    maximum, stripped of near-blank slices and written to
    ``data/processed/<scan name>.nc``.

    Returns:
        list: the processed scans, in the same order as ``batch``.

    NOTE(review): how ``batch_processor`` combines or forwards this list to
    the caller is not visible in this notebook -- confirm against its source.
    """
    processed_batch = []
    for patient in batch:
        scan = proc_data.read_full_3d_scan(scan_paths, patient)

        # Scale by the scan's own maximum due to non-uniform brightness.
        # (A scan whose maximum is 0 becomes all-NaN and is trimmed away
        # entirely by the variance filter below.)
        scan = scan / scan.max()

        # Trim slices whose variance across x and y is too small to hold any
        # image content. The boolean mask is applied positionally to the
        # first dimension, so this assumes the slice dimension comes first
        # -- TODO confirm.
        slice_var = scan.var(dim=['x', 'y'])
        scan = scan[slice_var > BLANK_SLICE_VARIANCE]

        # Save the processed data as netCDF.
        save_path = os.path.join('data/processed/', f'{scan.name}.nc')
        scan.to_netcdf(path=save_path)

        processed_batch.append(scan)
    return processed_batch

# The remaining entries of the labels index (position 500 onward), wrapped
# as a 1-D DataArray along a 'pid' dimension.
patients = xr.DataArray(labels.index[500:], dims='pid')

# Track the largest value seen across all processed batches.
max_val = 0
for processed_batch in process_scans(patients):
    # Call .max() (rather than storing the bound method) and compare against
    # the running maximum, so max_val only ever increases.
    batch_max = float(processed_batch.max())
    if batch_max > max_val:
        max_val = batch_max
        print('Updated max to', max_val)
    print(f'Processed batch: {processed_batch}')