#### Instructions

The model can be loaded after first creating a conda environment from the env_for_fern_segmentation.txt file. You must have conda installed; run the following two lines in a terminal first.

In [None]:
# Run these two commands in a terminal (not in this notebook), then make sure the
# notebook kernel runs inside the activated environment:
# conda create --name herb_segmentation --file env_for_fern_segmentation.txt
# conda activate herb_segmentation

Note: fastai and PyTorch can be difficult to install. Expect to spend some time getting versions of fastai and PyTorch that match the env_for_fern_segmentation.txt specifications. This may mean installing older releases, because both libraries are under active development and new versions are released frequently. This notebook uses the fastai v1 API (`fastai.vision`, `fastai.callbacks`, `load_learner(path, file)`), which is not compatible with fastai v2 or later.

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
import shutil
from tqdm import tqdm
from PIL import Image as image_save
import itertools
import operator
import fastai
from fastai import *
from fastai.vision import *
from fastai.vision.models.wrn import wrn_22
import dask.dataframe as dd
import functools, traceback
from fastai.callbacks.hooks import *
from fastai.utils.mem import *

## Load the custom classes and the trained model "fern_segmentation.pkl"

In [None]:
class SegLabelListCustom(SegmentationLabelList):
    """Segmentation label list that opens every mask file with ``div=True``.

    Presumably referenced by name from the pickled learner, so keep the class
    name unchanged and define it before ``load_learner`` is called.
    """

    def open(self, fn):
        # NOTE(review): per the fastai v1 docs, div=True divides mask pixel values
        # by 255 (e.g. 0/255 masks become 0/1) -- confirm for your fastai version.
        mask = open_mask(fn, div=True)
        return mask


class SegItemListCustom(SegmentationItemList):
    """Segmentation item list whose labels are built as ``SegLabelListCustom``."""

    _label_cls = SegLabelListCustom

In [None]:
# Folder that CONTAINS fern_segmentation.pkl -- not the path to the .pkl file itself.
# load_learner(path, file) loads `path/file`, and the file name is passed separately.
path_to_pickle = '/path/to/the/folder/containing/the/pkl/file' # MUST CHANGE!

In [None]:
# Rebuild the exported learner from <path_to_pickle>/fern_segmentation.pkl.
# NOTE: unpickling can execute arbitrary code -- only load a .pkl from a trusted source.
seg_bot = load_learner(
    path=path_to_pickle,
    file='fern_segmentation.pkl',
)

## Running a single image through the model

In [None]:
# One image file to run through the model on its own (placeholder path).
path_img = Path(
    "this/is/a/folder/full/of/images/i/want/to/run/through/the/model/01452951.jpg"
)  # MUST CHANGE!

In [None]:
# Open the single image and run it through the loaded learner.
img = open_image(path_img)
img_mask_pred = seg_bot.predict(img)  # prediction for this one image

In [None]:
# NOTE(review): in fastai v1, Learner.predict returns a tuple
# (prediction, class-index tensor, raw model output). Displaying the whole tuple
# only prints reprs, including a large tensor, so show just the predicted mask
# (presumably an ImageSegment here -- confirm for your fastai version).
pred_segment = img_mask_pred[0]
pred_segment.show(figsize=(6, 6))

## Running a large batch of images through the model and saving the masked versions

In [None]:
# Folder of .jpg images to mask in bulk (placeholder path).
path_to_images = (
    'this/is/a/folder/full/of/images/i/want/to/run/through/the/model'
)  # MUST CHANGE!

In [None]:
# Build an unlabeled DataBunch over the .jpg files under `path_to_images`.
# - extensions must be a collection of suffixes: some fastai v1 versions iterate a
#   bare string character by character ('.', 'j', 'p', 'g') and then match no files.
# - split_none() keeps every image in a single (training) split; label_empty()
#   attaches empty labels because these images have no ground-truth masks.
# - NOTE(review): with an int `size`, fastai v1's default resize may centre-crop
#   non-square images to 256x256 -- confirm this is intended.
data_test = (
    ImageList.from_folder(path=path_to_images, extensions=[".jpg"])
    .split_none()
    .label_empty()
    .transform(tfms=None, size=256)
    .databunch(bs=64)
    .normalize(imagenet_stats)
)

In [None]:
# Inference batch size; lower it if you run out of (GPU) memory.
bs = 64
# data_test was built with a fixed bs=64, so rebuild its ordered loader with `bs`.
# Otherwise the loader's batches and the `bs`-sized filename slices used when
# saving would drift apart, and masks would be saved under the wrong names.
# NOTE(review): in fastai v1, fix_dl is the training split without shuffling or
# drop_last, and DeviceDataLoader.new keeps its batch transforms (normalization).
seg_bot.data.test_dl = data_test.fix_dl.new(batch_size=bs)

In [None]:
# Count of *full* `bs`-sized batches in the test set. Floor division, so a
# trailing partial batch is not counted.
number_of_batches = len(seg_bot.data.test_ds) // bs

In [None]:
# Output folder for the masked_<filename> images (placeholder path).
path_to_save_masked_images = (
    "this/is/a/folder/where/i/want/to/save/masked/images"
)  # MUST CHANGE!

In [None]:
# Run every test image through the model and save a "masked_<filename>" copy in
# which each pixel is multiplied by its predicted class index.
save_dir = Path(path_to_save_masked_images)
save_dir.mkdir(parents=True, exist_ok=True)  # fail early rather than on the first save

test_dl = seg_bot.data.test_dl
test_ds = seg_bot.data.test_ds
# The loader is expected not to shuffle (it is set from fix_dl), so its batches
# follow the order of test_ds. Walk filenames and images in lockstep with it.
test_filenames_iter = iter(test_ds.items)
test_images_iter = iter(test_ds)

# Iterate the loader itself (not a floor-divided batch count) so the final partial
# batch is processed too. Size each step by the batch actually returned, not by
# `bs`, so filenames stay aligned with predictions.
for batch in tqdm(test_dl):
    preds = seg_bot.pred_batch(batch=batch)
    # assumes preds is (batch, n_classes, H, W): argmax over classes -> (batch, H, W)
    pred_masks = preds.argmax(dim=1).cpu()
    n_in_batch = pred_masks.shape[0]
    pred_names = list(itertools.islice(test_filenames_iter, n_in_batch))
    orig_images = list(itertools.islice(test_images_iter, n_in_batch))
    assert len(pred_names) == len(orig_images) == n_in_batch, \
        "filenames/images out of step with predictions"
    for name, (orig_img, _label), pred_mask in zip(pred_names, orig_images, pred_masks):
        # Broadcast the per-pixel class index across the colour channels
        # (presumably 0 = background, 1 = fern -- confirm against the training masks).
        masked = orig_img.data.cpu().double() * pred_mask.unsqueeze(0).double()
        # CHW floats (assumed in [0, 1]) -> HWC uint8 in [0, 255] for PIL.
        pixels = (masked.permute(1, 2, 0)
                  .mul(255).add(0.5).clamp(0, 255)
                  .to('cpu', torch.uint8)
                  .numpy())
        image_save.fromarray(pixels).save(save_dir / f"masked_{name.name}")