# Create masks to frame the problem as a segmentation task

Sometimes it is useful to have the classifications as pixel masks instead of bounding boxes.

In this notebook we rasterize the boxes of each classification into a single label mask (0 = no class, 1–4 = Sugar, Flower, Fish, Gravel) and save one PNG mask per classification.

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from pyclouds.imports import *
from pyclouds.plot import *
from pyclouds.zooniverse import *
from pyclouds.helpers import wh2xy

In [3]:
# Root directory of the image data; the masks are written below it.
# Machine-specific absolute path: change it when running on another machine.
# Keep the trailing slash in case paths are built by plain string concatenation.
IMGDIR = '/local/S.Rasp/cloud-classification/'

## Load datasets

In [4]:
# Preprocessed classifications ('clas_*') and per-box annotations ('annos_*'),
# presumably for the practice ('prac') and full ('full') runs -- confirm.
# NOTE: read_pickle unpickles arbitrary objects; only load files you trust.
clas_prac = pd.read_pickle('../../preprocessed_data/clas_prac.pkl')
clas_full = pd.read_pickle('../../preprocessed_data/clas_full.pkl')
annos_prac = pd.read_pickle('../../preprocessed_data/annos_prac.pkl')
annos_full = pd.read_pickle('../../preprocessed_data/annos_full.pkl')

In [None]:
# Unique subject ids in each classification set.
# NOTE(review): not used by any of the cells shown here.
subj_ids_prac = clas_prac.subject_ids.unique()
subj_ids_full = clas_full.subject_ids.unique()

## Create one mask per classification

In [None]:
# Every classification id that has at least one annotation row; one mask is made per id.
clas_ids = annos_full['classification_id'].unique()
len(clas_ids)

30310

In [None]:
# Display the class names. `classes` is not defined in this notebook; it presumably
# comes from one of the pyclouds star imports.
classes

['Sugar', 'Flower', 'Fish', 'Gravel']

In [None]:
# Pixel value used for each class in the masks; 0 stays free to mean "no class".
cl2id = {name: label for label, name in enumerate(classes, start=1)}

In [None]:
# Check the class-name -> pixel-value mapping.
cl2id

{'Sugar': 1, 'Flower': 2, 'Fish': 3, 'Gravel': 4}

In [None]:
def size(xywh):
    """Return the area of a box given as (x, y, width, height)."""
    width, height = xywh[2], xywh[3]
    return width * height

In [None]:
def create_mask(clas_id, annos_df, img_dir, img_size=(2100, 1400), mask_dir='masks',
                class_ids=None):
    """
    Rasterize the boxes of one classification into a label mask and save it as a PNG.

    Boxes within a single classification overlap relatively little, so we paint the
    largest box first and let progressively smaller boxes overwrite it.

    Parameters
    ----------
    clas_id : classification id to select from ``annos_df``.
    annos_df : DataFrame with the columns ``classification_id``, ``fn``, ``x``, ``y``,
        ``width``, ``height`` and ``tool_label``. Rows with NaN coordinates (a
        classification without boxes) are ignored, giving an all-zero mask.
    img_dir : root directory; the mask is written to ``img_dir/<mask_fn>``.
    img_size : shape of the mask array, indexed ``[x, y]``; it is transposed before
        saving.
    mask_dir : directory (relative to ``img_dir``) that receives the masks.
    class_ids : mapping from ``tool_label`` to pixel value; defaults to the
        notebook-level ``cl2id``.

    Returns
    -------
    (fn, mask_fn) : the source image file name and the mask path relative to ``img_dir``.
    """
    if class_ids is None:
        class_ids = cl2id

    ans = annos_df[annos_df.classification_id == clas_id]
    fn = ans.fn.iloc[0]
    # splitext removes just the extension. str.rstrip('.jpeg') would strip any trailing
    # run of the characters '.jpeg' and mangle names, e.g. 'image.jpeg' -> 'ima'.
    mask_fn = os.path.join(mask_dir, os.path.splitext(fn)[0] + '_' + str(clas_id) + '.png')
    # Build the directory and the file path the same way, so that a missing trailing
    # slash on img_dir cannot send the file somewhere other than the created directory.
    save_path = os.path.join(img_dir, mask_fn)
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    width, height = img_size
    mask = np.zeros(img_size, dtype=np.int8)

    coord_cols = ['x', 'y', 'width', 'height']
    boxes = ans.dropna(subset=coord_cols)
    box_list = [(list(r[coord_cols].astype(int)), r['tool_label'])
                for _, r in boxes.iterrows()]
    # Largest box first so that smaller boxes painted later overwrite it (sort is stable).
    box_list.sort(key=lambda b: size(b[0]), reverse=True)

    for xywh, label in box_list:
        x1, y1, x2, y2 = wh2xy(*xywh)
        # Clip to the image. A negative slice start wraps around to the far edge
        # (mask[-10:100] is empty), which would silently drop boxes that start
        # outside the image.
        x1, x2 = np.clip([x1, x2], 0, width)
        y1, y2 = np.clip([y1, y2], 0, height)
        mask[x1:x2, y1:y2] = class_ids[label]

    Image.fromarray(mask.T).save(save_path)
    return fn, mask_fn

In [None]:
# Collect (image file, mask file) pairs so masks can be matched back to their images.
fns, mask_fns = [], []
# Group once instead of letting create_mask scan all of annos_full for each of the
# ~30k ids. sort=False keeps the first-appearance order of `clas_ids`.
for c, clas_annos in tqdm(annos_full.groupby('classification_id', sort=False)):
    fn, mask_fn = create_mask(c, clas_annos, IMGDIR, mask_dir='masks2')
    fns.append(fn)
    mask_fns.append(mask_fn)

HBox(children=(IntProgress(value=0, max=30310), HTML(value='')))