Importing libraries

In [None]:
import yaml
import os, sys
import cv2
import time

from detectron2.utils.logger import setup_logger
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.data.datasets import register_coco_instances
from detectron2.utils.visualizer import ColorMode

In [None]:
# Make the parent directory importable so that the `helpers` package imported in the
# next cells can be found. NOTE(review): assumes the notebook is run from a direct
# sub-directory of the repository root — TODO confirm.
sys.path.insert(0, "..")

In [None]:
import json

import pandas as pd
import geopandas as gpd

from joblib import Parallel, delayed
from tqdm import tqdm

from helpers.misc import reformat_xyz, scale_polygon
from helpers import COCO

In [None]:
from helpers.detectron2 import LossEvalHook # , CocoTrainer
from helpers.detectron2 import dt2predictions_to_list

# Prepare_data.py
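Preparing the labels: keeping the relevant road attributes, fixing invalid geometries and splitting them into ground-truth and other labels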

In [None]:
# Input / output locations of the label-preparation step.
# Every path lives under the same "processed" data folder, which is factored out once
# so that relocating the data only requires changing a single line.
PROCESSED_DIR = "/mnt/data-01/gsalamin/proj-roadsurf-b/02_Data/processed"

ROADS = os.path.join(PROCESSED_DIR, "shapefiles_gpkg", "roads_polygons.shp")  # road polygons used as labels
TILES = os.path.join(PROCESSED_DIR, "json", "tiles_aoi_z17.geojson")  # AoI tiles (file name suggests zoom 17)

OUTPUT_DIR = os.path.join(PROCESSED_DIR, "json")

# paths of every file written by this section, listed at the end
written_files = []


In [None]:
# Load the road polygons and the AoI tiles as GeoDataFrames (roads first, then tiles).
roads, tiles = gpd.read_file(ROADS), gpd.read_file(TILES)

In [None]:
# TODO: write a function valid_geom(poly_gdf, correct=False, gdf_obj_name=None)

# Stop here if any road polygon is invalid.
# An explicit `if` is used instead of an `assert` wrapped in try/except: asserts are
# stripped when Python runs with -O, which would silently disable this check.
if (~roads.is_valid).any():
    print("Some geometries for the roads are invalid")
    sys.exit(1)

In [None]:
# Road attributes that are not needed to build the labels.
DROPPED_ROAD_COLUMNS = [
    'DATUM_AEND', 'DATUM_ERST', 'ERSTELLUNG', 'ERSTELLU_1',
    'REVISION_J', 'REVISION_M', 'GRUND_AEND', 'HERKUNFT', 'HERKUNFT_J',
    'HERKUNFT_M', 'REVISION_Q', 'WANDERWEGE', 'VERKEHRSBE',
    'BEFAHRBARK', 'EROEFFNUNG', 'STUFE', 'RICHTUNGSG',
    'KREISEL', 'EIGENTUEME', 'VERKEHRS_1', 'NAME',
    'TLM_STRASS', 'STRASSENNA', 'SHAPE_Leng',
]

# Labels = roads with renamed category/width columns, a constant 'road' supercategory,
# reprojected to WGS84 (EPSG:4326).
labels_gdf = (
    roads
    .rename(columns={'BELAGSART': 'CATEGORY', 'road_width': 'WIDTH'})
    .drop(columns=DROPPED_ROAD_COLUMNS)
    .assign(SUPERCATEGORY='road')
    .to_crs(epsg=4326)
)


In [None]:
# Quick look at the attribute columns left after the rename / drop above.
labels_gdf.columns

In [None]:
# The reprojected labels and the tiles must share the same CRS for the spatial join below.
# Explicit `if` instead of `assert` so that the check also runs under `python -O`.
if labels_gdf.crs != tiles.crs:
    print(f"CRS mismatching: labels' CRS = {labels_gdf.crs} != OK_tiles' CRS = {tiles.crs}")
    sys.exit(1)

In [None]:
# Repair invalid reprojected label geometries with a zero-width buffer.
n_invalid_labels = (labels_gdf.is_valid == False).sum()
if n_invalid_labels != 0:
    print(f"There are {n_invalid_labels} invalid geometries for the reprojected labels.")

    print("Correction of the roads presenting an invalid geometry with a buffer of 0 m...")
    # work on a copy so that the GeoDataFrame read earlier is not modified in place
    labels_gdf = labels_gdf.copy()
    to_fix = labels_gdf.is_valid == False
    labels_gdf.loc[to_fix, 'geometry'] = labels_gdf.loc[to_fix, 'geometry'].buffer(0)

In [None]:
# Keep the labels intersecting at least one tile. With an inner join, a label intersecting
# several tiles appears once per tile (deduplicated in the next cell).
GT_labels_gdf = gpd.sjoin(labels_gdf, tiles, how='inner', op='intersects')

In [None]:

# Make sure no object is counted more than once when it intersects multiple tiles.
# The spatial join repeats the label's index label once per intersecting tile, so the
# deduplication is done on the index. Value-based drop_duplicates would also merge two
# distinct but identical labels, which breaks the len(labels) == len(GT) + len(OTH) check
# below. Running it in place on a column-selected slice also triggers SettingWithCopyWarning.
GT_labels_gdf = GT_labels_gdf.loc[~GT_labels_gdf.index.duplicated(keep='first'), labels_gdf.columns].copy()
OTH_labels_gdf = labels_gdf[~labels_gdf.index.isin(GT_labels_gdf.index)]

# In the current case, OTH_labels_gdf should be empty

In [None]:
# Every label must end up either in the ground-truth set or in the "other" set.
# Explicit `if` instead of `assert` so that the check also runs under `python -O`.
if len(labels_gdf) != len(GT_labels_gdf) + len(OTH_labels_gdf):
    print(f"Something went wrong when splitting labels into Ground Truth Labels and Other Labels. Total no. of labels = {len(labels_gdf)}; no. of Ground Truth Labels = {len(GT_labels_gdf)}; no. of Other Labels = {len(OTH_labels_gdf)}")
    sys.exit(1)


In [None]:
GT_LABELS_GEOJSON = os.path.join(OUTPUT_DIR, 'ground_truth_labels.geojson')
OTH_LABELS_GEOJSON = os.path.join(OUTPUT_DIR, 'other_labels.geojson')

# The ground-truth labels are always written...
GT_labels_gdf.to_file(GT_LABELS_GEOJSON, driver='GeoJSON')
written_files.append(GT_LABELS_GEOJSON)

# ...whereas the "other" labels file is only written when it would not be empty.
has_oth_labels = not OTH_labels_gdf.empty
if has_oth_labels:
    OTH_labels_gdf.to_file(OTH_LABELS_GEOJSON, driver='GeoJSON')
    written_files.append(OTH_LABELS_GEOJSON)

In [None]:
# Sanity check: ground-truth labels whose geometry is still invalid
# (presumably empty after the buffer(0) repair above).
GT_labels_gdf[GT_labels_gdf.is_valid==False]


# generate_tilesets.py

Definition of functions

In [None]:
def bounds_to_bbox(bounds):
    """Format bounds as a comma-separated "xmin,ymin,xmax,ymax" string.

    Only the first four items of ``bounds`` (xmin, ymin, xmax, ymax) are used; each one
    is rendered with its default string formatting.
    """
    corners = (bounds[0], bounds[1], bounds[2], bounds[3])
    return "{},{},{},{}".format(*corners)

In [None]:
def get_COCO_image_and_segmentations(tile, labels, COCO_license_id, output_dir):
    """Build the COCO image record and the polygon segmentations of one tile.

    Parameters
    ----------
    tile : tuple
        ``(index, row)`` pair as yielded by ``GeoDataFrame.iterrows()``. The row must
        provide ``img_file``, ``dataset``, ``geometry`` and ``id``.
    labels : GeoDataFrame
        Labels clipped against the tile geometry (presumably in the tile's CRS).
    COCO_license_id : int
        License id stored in the COCO image record.
    output_dir : str
        Directory the image path is made relative to.

    Returns
    -------
    tuple
        ``(COCO_image, segmentations)``. ``segmentations`` is a list of flat
        ``[x0, y0, x1, y1, ...]`` lists in pixel coordinates.

    Raises
    ------
    Exception
        If a scaled label falls outside the tile's pixel extent.
    """
    _id, _tile = tile

    coco_obj = COCO.COCO()

    # The image is referenced through its dataset-specific folder ('all' -> trn/val/tst/oth).
    # NOTE(review): str.replace swaps *every* 'all' substring of the path, not only the folder name.
    this_tile_dirname = os.path.relpath(_tile['img_file'].replace('all', _tile['dataset']), output_dir)
    this_tile_dirname = this_tile_dirname.replace('\\', '/')  # should the dirname be generated from Windows

    COCO_image = coco_obj.image(output_dir, this_tile_dirname, COCO_license_id)
    segmentations = []

    if len(labels) == 0:
        return (COCO_image, segmentations)

    # Tile bounds used directly (no round trip through the "xmin,ymin,xmax,ymax" string).
    xmin, ymin, xmax, ymax = _tile['geometry'].bounds
    width, height = COCO_image['width'], COCO_image['height']

    # note the .explode() which turns Multipolygon into Polygons
    clipped_labels_gdf = gpd.clip(labels, _tile['geometry']).explode()

    for label in clipped_labels_gdf.itertuples():
        scaled_poly = scale_polygon(label.geometry, xmin, ymin, xmax, ymax, width, height)
        # drop the last point (presumably the ring's closing point, which repeats the first one)
        scaled_poly = scaled_poly[:-1]

        segmentation = my_unpack(scaled_poly)

        # The segmentation is flattened as [x0, y0, x1, y1, ...]: x values are checked
        # against the width and y values against the height. A single check against
        # min(width, height) is only right for square tiles.
        xs, ys = segmentation[0::2], segmentation[1::2]
        inside_tile = (
            len(segmentation) > 0
            and min(xs) >= 0 and min(ys) >= 0
            and max(xs) <= width and max(ys) <= height
        )
        if not inside_tile:
            raise Exception(f"Label boundaries exceed this tile size! Tile ID = {_tile['id']}")

        segmentations.append(segmentation)

    return (COCO_image, segmentations)


In [None]:
def make_hard_link(row):
    """Hard-link a tile image from the 'all' folder into its dataset folder.

    The destination path is ``row.img_file`` with 'all' replaced by ``row.dataset``.
    Missing destination directories are created, and an existing destination file is
    replaced. When the destination equals the source (no 'all' in the path), nothing
    is done: removing the "destination" would delete the source image itself.

    Parameters
    ----------
    row : object
        Record exposing ``img_file`` (source path, str) and ``dataset`` attributes,
        e.g. a DataFrame row.

    Returns
    -------
    None

    Raises
    ------
    FileNotFoundError
        If ``row.img_file`` is not a path to an existing file (including NaN for tiles
        without image metadata).
    """
    src_file = row.img_file

    # isinstance guard: a NaN path would otherwise make os.path.isfile raise a TypeError
    if not isinstance(src_file, str) or not os.path.isfile(src_file):
        raise FileNotFoundError(f'File not found: {src_file}')

    dst_file = src_file.replace('all', row.dataset)

    if os.path.abspath(dst_file) == os.path.abspath(src_file):
        # the image is already where it should be
        return None

    dirname = os.path.dirname(dst_file)
    if dirname:
        # exist_ok avoids the exists()/makedirs() check-then-create race
        os.makedirs(dirname, exist_ok=True)

    if os.path.exists(dst_file):
        os.remove(dst_file)

    os.link(src_file, dst_file)

    return None

def img_md_record_to_tile_id(img_md_record):
    """Derive the "(x, y, z)" tile id from an image-metadata record.

    The image file name (``img_md_record.img_file``) is expected to look like
    ``<z>_<x>_<y>.<ext>``. The part before the first dot is split on underscores.
    """
    stem = os.path.basename(img_md_record.img_file).partition('.')[0]
    z, x, y = stem.split('_')
    return "({}, {}, {})".format(x, y, z)

def check_aoi_tiles(aoi_tiles_gdf):
    '''
    Check that the AoI tiles have a unique, well-formatted 'id' column and normalize it in place.

    Each id must be accepted by the function reformat_xyz. Both "(<x>, <y>, <z>)" and
    "<x>, <y>, <z>" are tolerated, and missing surrounding parentheses are added id by id,
    so that every id ends up formatted as "(<x>, <y>, <z>)".

    Parameters
    ----------
    aoi_tiles_gdf : GeoDataFrame
        AoI tiles. Its 'id' column is modified in place.

    Raises
    ------
    Exception
        If the 'id' column is missing, contains duplicates, or cannot be parsed by reformat_xyz.
    '''
    if 'id' not in aoi_tiles_gdf.columns:
        raise Exception("No 'id' column was found in the AoI tiles dataset.")
    if aoi_tiles_gdf['id'].duplicated().any():
        raise Exception("The 'id' column in the AoI tiles dataset should not contain any duplicate.")

    try:
        aoi_tiles_gdf.apply(reformat_xyz, axis=1)
    except Exception as e:
        raise Exception("IDs do not seem to be well-formatted. Here's how they must look like: (<integer 1>, <integer 2>, <integer 3>), e.g. (<x>, <y>, <z>).") from e

    # Add the missing parentheses only to the ids lacking them: prefixing/suffixing the whole
    # column as soon as *one* id lacks them would double them on the ids that already have them.
    ids = aoi_tiles_gdf['id']
    ids = ids.where(ids.str.startswith('('), '(' + ids)
    ids = ids.where(ids.str.endswith(')'), ids + ')')
    aoi_tiles_gdf['id'] = ids

    return

def my_unpack(list_of_tuples):
    """Flatten an iterable of tuples into one list, e.g. [(x0, y0), (x1, y1)] -> [x0, y0, x1, y1]."""
    # cf. https://www.geeksforgeeks.org/python-convert-list-of-tuples-into-list/
    flat = []
    for group in list_of_tuples:
        flat.extend(group)
    return flat

Definition of constants

In [None]:
# Load the "generate_tilesets.py" section of the object-detector YAML configuration.
# NOTE(review): absolute, user-specific path — consider making it configurable.
with open("/home/gsalamin/Documents/GitHub/object-detector/scripts/config_test_NIR.yaml") as fp:
        cfg = yaml.load(fp, Loader=yaml.FullLoader)["generate_tilesets.py"]

In [None]:
# Unpack the "generate_tilesets.py" configuration into module-level constants.
DEBUG_MODE = cfg['debug_mode']

OUTPUT_DIR = cfg['output_folder']

_ortho_ws_cfg = cfg['datasets']['orthophotos_web_service']
ORTHO_WS_TYPE = _ortho_ws_cfg['type']
ORTHO_WS_URL = _ortho_ws_cfg['url']
ORTHO_WS_SRS = _ortho_ws_cfg['srs']
# 'layers' is optional: None when absent. It was previously left undefined, so any later
# use raised a NameError, or silently reused a value from an earlier kernel run.
ORTHO_WS_LAYERS = _ortho_ws_cfg.get('layers')

AOI_TILES_GEOJSON = cfg['datasets']['aoi_tiles_geojson']

# optional label files; None means "not provided"
GT_LABELS_GEOJSON = cfg['datasets'].get('ground_truth_labels_geojson')
OTH_LABELS_GEOJSON = cfg['datasets'].get('other_labels_geojson')

SAVE_METADATA = True
OVERWRITE = cfg['overwrite']
TILE_SIZE = cfg['tile_size']
N_JOBS = cfg['n_jobs']

_coco_md_cfg = cfg['COCO_metadata']
COCO_YEAR = _coco_md_cfg['year']
COCO_VERSION = _coco_md_cfg['version']
COCO_DESCRIPTION = _coco_md_cfg['description']
COCO_CONTRIBUTOR = _coco_md_cfg['contributor']
COCO_URL = _coco_md_cfg['url']
COCO_LICENSE_NAME = _coco_md_cfg['license']['name']
COCO_LICENSE_URL = _coco_md_cfg['license']['url']
COCO_CATEGORY_NAME = _coco_md_cfg['category']['name']
COCO_CATEGORY_SUPERCATEGORY = _coco_md_cfg['category']['supercategory']

In [None]:
# Show which ground-truth label file the configuration points to (None when not provided).
print(GT_LABELS_GEOJSON)

In [None]:
# let's make the output directory in case it doesn't exist
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)

written_files = []

Loading datasets

In [None]:
# ------ Loading datasets

aoi_tiles_gdf = gpd.read_file(AOI_TILES_GEOJSON)
try:
    # also normalizes the 'id' column of aoi_tiles_gdf in place
    check_aoi_tiles(aoi_tiles_gdf)
except Exception as e:
    print(f"AoI tiles check failed. Exception: {e}")
    sys.exit(1)

if GT_LABELS_GEOJSON:
    print("Loading Ground Truth Labels as a GeoPandas DataFrame...")
    gt_labels_gdf = gpd.read_file(GT_LABELS_GEOJSON)
    print(f"...done. {len(gt_labels_gdf)} records were found.")

    # Invalid GT geometries are repaired with a zero-width buffer. An explicit `if` is used
    # instead of assert/except: under `python -O` the assert is stripped, so the repair
    # would never run.
    invalid_gt_mask = ~gt_labels_gdf.is_valid
    if invalid_gt_mask.any():
        print("Some geometries for the ground truth labels are invalid.")
        gt_labels_gdf = gt_labels_gdf.copy()
        gt_labels_gdf.loc[invalid_gt_mask, 'geometry'] = gt_labels_gdf.loc[invalid_gt_mask, 'geometry'].buffer(0)

if OTH_LABELS_GEOJSON:
    print("Loading Other Labels as a GeoPandas DataFrame...")
    oth_labels_gdf = gpd.read_file(OTH_LABELS_GEOJSON)
    print(f"...done. {len(oth_labels_gdf)} records were found.")

    # unlike the GT labels, invalid "other" labels are not repaired: processing stops
    if (~oth_labels_gdf.is_valid).any():
        print("Some geometries for the other labels are invalid.")
        sys.exit(1)


print("Generating the list of tasks to be executed (one task per tile)...")

DEBUG_MODE_LIMIT = 100
if DEBUG_MODE:
    print(f"Debug mode: ON => Only {DEBUG_MODE_LIMIT} tiles will be processed.")

    if GT_LABELS_GEOJSON:
        if aoi_tiles_gdf.crs != gt_labels_gdf.crs:
            print("CRS Mismatch between AoI tiles and ground truth labels.")
            sys.exit(1)
        # tiles covering at least one GT label (sjoin columns removed, one row per tile)
        aoi_tiles_intersecting_gt_labels = gpd.sjoin(
            aoi_tiles_gdf, gt_labels_gdf, how='inner', op='intersects'
        )[aoi_tiles_gdf.columns].drop_duplicates()

    if OTH_LABELS_GEOJSON:
        if aoi_tiles_gdf.crs != oth_labels_gdf.crs:
            print("CRS Mismatch between AoI tiles and other labels.")
            sys.exit(1)
        # tiles covering at least one OTH label (sjoin columns removed, one row per tile)
        aoi_tiles_intersecting_oth_labels = gpd.sjoin(
            aoi_tiles_gdf, oth_labels_gdf, how='inner', op='intersects'
        )[aoi_tiles_gdf.columns].drop_duplicates()

    # sampling tiles according to whether GT and/or OTH labels are provided
    if GT_LABELS_GEOJSON and OTH_LABELS_GEOJSON:
        aoi_tiles_gdf = pd.concat([
            aoi_tiles_intersecting_gt_labels.head(DEBUG_MODE_LIMIT//2),  # a sample of tiles covering GT labels
            aoi_tiles_intersecting_oth_labels.head(DEBUG_MODE_LIMIT//4),  # a sample of tiles covering OTH labels
            aoi_tiles_gdf  # the entire tileset, so as to also have tiles covering no label at all (duplicates will be dropped)
        ])
    elif GT_LABELS_GEOJSON:
        aoi_tiles_gdf = pd.concat([
            aoi_tiles_intersecting_gt_labels.head(DEBUG_MODE_LIMIT*3//4),
            aoi_tiles_gdf
        ])
    elif OTH_LABELS_GEOJSON:
        aoi_tiles_gdf = pd.concat([
            aoi_tiles_intersecting_oth_labels.head(DEBUG_MODE_LIMIT*3//4),
            aoi_tiles_gdf
        ])
    # without any label file, the whole tileset is simply truncated below

    aoi_tiles_gdf = aoi_tiles_gdf.drop_duplicates().head(DEBUG_MODE_LIMIT).copy()


In [None]:
# Display the ground-truth labels that are still invalid after loading
# (presumably empty after the buffer(0) repair).
# NOTE(review): raises NameError when no ground-truth label file is configured.
gt_labels_gdf[gt_labels_gdf.is_valid==False]

Labels treatment: splitting the AoI tiles into training (trn), validation (val), test (tst) and other (oth) tiles

In [None]:
# Tiles intersecting at least one ground-truth label are the candidate GT tiles.
# Explicit `if` instead of `assert` so that the CRS check also runs under `python -O`.
if aoi_tiles_gdf.crs != gt_labels_gdf.crs:
    print("CRS Mismatch between AoI tiles and labels.")
    sys.exit(1)

GT_tiles_gdf = gpd.sjoin(aoi_tiles_gdf, gt_labels_gdf, how='inner', op='intersects')
# remove columns generated by the Spatial Join; a tile covering several labels appears
# once per label, hence the deduplication
GT_tiles_gdf = GT_tiles_gdf[aoi_tiles_gdf.columns].drop_duplicates()

In [None]:
# remove tiles including at least one "oth" label (if applicable)
if OTH_LABELS_GEOJSON:
    tmp_GT_tiles_gdf = GT_tiles_gdf.copy()
    tiles_to_remove_gdf = gpd.sjoin(tmp_GT_tiles_gdf, oth_labels_gdf, how='inner', op='intersects')
    GT_tiles_gdf = tmp_GT_tiles_gdf[~tmp_GT_tiles_gdf.id.astype(str).isin(tiles_to_remove_gdf.id.astype(str))].copy()
    del tmp_GT_tiles_gdf

# OTH tiles = AoI tiles which are not GT
OTH_tiles_gdf = aoi_tiles_gdf[ ~aoi_tiles_gdf.id.astype(str).isin(GT_tiles_gdf.id.astype(str)) ].copy()
OTH_tiles_gdf['dataset'] = 'oth'


In [None]:
# Sanity check: the GT and OTH tile counts add up to the number of AoI tiles.
assert( len(aoi_tiles_gdf) == len(GT_tiles_gdf) + len(OTH_tiles_gdf) )

In [None]:
# 70%, 15%, 15% split of the GT tiles into trn / val / tst (fixed random_state for reproducibility)
gt_tile_ids = GT_tiles_gdf.id.astype(str)

trn_tiles_ids = GT_tiles_gdf.sample(frac=.7, random_state=1).id.astype(str).values.tolist()

# half of the remaining 30% goes to validation...
remaining_tiles_gdf = GT_tiles_gdf[~gt_tile_ids.isin(trn_tiles_ids)]
val_tiles_ids = remaining_tiles_gdf.sample(frac=.5, random_state=1).id.astype(str).values.tolist()

# ...and whatever is left goes to test
tst_tiles_ids = GT_tiles_gdf[~gt_tile_ids.isin(trn_tiles_ids + val_tiles_ids)].id.astype(str).values.tolist()

for split_name, split_ids in (('trn', trn_tiles_ids), ('val', val_tiles_ids), ('tst', tst_tiles_ids)):
    GT_tiles_gdf.loc[gt_tile_ids.isin(split_ids), 'dataset'] = split_name

assert( len(GT_tiles_gdf) == len(trn_tiles_ids) + len(val_tiles_ids) + len(tst_tiles_ids) )

In [None]:
# Recombine the labelled (trn / val / tst) and unlabelled (oth) tiles into a single tileset.
split_aoi_tiles_gdf = pd.concat([GT_tiles_gdf, OTH_tiles_gdf])

# let's free up some memory
del GT_tiles_gdf, OTH_tiles_gdf

In [None]:
assert( len(split_aoi_tiles_gdf) == len(aoi_tiles_gdf) ) # it means that all the tiles were actually used

SPLIT_AOI_TILES_GEOJSON = os.path.join(OUTPUT_DIR, 'split_aoi_tiles.geojson')

# The file is recorded in written_files (and announced) only when the write succeeded;
# a failed write is reported through its error message only.
try:
    split_aoi_tiles_gdf.to_file(SPLIT_AOI_TILES_GEOJSON, driver='GeoJSON')
except Exception as e:
    print(e)
else:
    written_files.append(SPLIT_AOI_TILES_GEOJSON)
    print(f'...done. A file was written {SPLIT_AOI_TILES_GEOJSON}')


In [None]:
# Image metadata stored in the output folder. The next cell uses its keys as image file
# paths ('img_file'). NOTE(review): presumably {img_file: {metadata fields}} — confirm.
IMG_METADATA_FILE = os.path.join(OUTPUT_DIR, 'img_metadata.json')

with open(IMG_METADATA_FILE) as f:
    img_metadata_dict=json.load(f)

In [None]:
# One row per image: the metadata keys (image paths) become the 'img_file' column, and the
# "(x, y, z)" tile id is derived from each image file name.
img_md_df = (
    pd.DataFrame.from_dict(img_metadata_dict, orient='index')
    .reset_index()
    .rename(columns={"index": "img_file"})
)
img_md_df['id'] = img_md_df.apply(img_md_record_to_tile_id, axis=1)

split_aoi_tiles_with_img_md_gdf = split_aoi_tiles_gdf.merge(img_md_df, on='id', how='left')

# With the left merge, a tile without image metadata gets img_file = NaN, which cannot be
# hard-linked: report those tiles explicitly instead of crashing on the first one.
tiles_without_img = split_aoi_tiles_with_img_md_gdf['img_file'].isna()
if tiles_without_img.any():
    print(f"{tiles_without_img.sum()} AoI tile(s) have no image metadata in {IMG_METADATA_FILE}.")
    sys.exit(1)

# Hard-link every image into its dataset folder (a loop rather than .apply: only the
# side effect matters, there is no result to collect).
for tile_row in split_aoi_tiles_with_img_md_gdf.itertuples(index=False):
    make_hard_link(tile_row)

In [None]:
# Labels used for the COCO annotations: a copy of the ground-truth labels with a default
# RangeIndex (reset_index() keeps the previous index as a column).
labels_gdf = gt_labels_gdf.copy().reset_index()

Creating the COCO annotations

In [None]:
# Build and write one COCO annotation file per dataset (one per value of the 'dataset' column).
for dataset in split_aoi_tiles_with_img_md_gdf.dataset.unique():
    
    coco = COCO.COCO()
    coco.set_info(the_year=COCO_YEAR, 
                    the_version=COCO_VERSION, 
                    the_description=f"{COCO_DESCRIPTION} - {dataset} dataset", 
                    the_contributor=COCO_CONTRIBUTOR, 
                    the_url=COCO_URL)
    
    coco_license = coco.license(the_name=COCO_LICENSE_NAME, the_url=COCO_LICENSE_URL)
    coco_license_id = coco.insert_license(coco_license)

    # TODO: read (super)category from the labels dataset
    # For now a single category (taken from the config) is used for every annotation.
    coco_category = coco.category(the_name=COCO_CATEGORY_NAME, the_supercategory=COCO_CATEGORY_SUPERCATEGORY)                      
    coco_category_id = coco.insert_category(coco_category)
    
    # NOTE(review): dropna() drops a tile when *any* of its columns is NaN, not only a missing
    # image, so such tiles silently get no COCO image. Confirm this is intended.
    tmp_tiles_gdf = split_aoi_tiles_with_img_md_gdf[split_aoi_tiles_with_img_md_gdf.dataset == dataset].dropna()
    #tmp_tiles_gdf = tmp_tiles_gdf.to_crs(epsg=3857)
    
    # labels are clipped with the tile geometries, so both must share the same CRS
    if len(labels_gdf) > 0:
        assert(labels_gdf.crs == tmp_tiles_gdf.crs)
    
    tiles_iterator = tmp_tiles_gdf.sort_index().iterrows()

    # One task per tile. Each task receives the full labels_gdf and returns
    # (COCO image record, list of segmentations); joblib returns results in input order.
    results = Parallel(n_jobs=N_JOBS, backend="loky") \
                    (delayed(get_COCO_image_and_segmentations) \
                    (tile, labels_gdf, coco_license_id, OUTPUT_DIR) \
                    for tile in tqdm( tiles_iterator, total=len(tmp_tiles_gdf) ))
    
    for result in results:
        coco_image, segmentations = result
        coco_image_id = coco.insert_image(coco_image)

        # each segmentation (one polygon) becomes its own non-crowd annotation of this image
        for segmentation in segmentations:

            coco_annotation = coco.annotation(coco_image_id,
                coco_category_id,
                [segmentation],
                the_iscrowd=0
            )

            coco.insert_annotation(coco_annotation)
    
    COCO_file = os.path.join(OUTPUT_DIR, f'COCO_{dataset}.json')
    with open(COCO_file, 'w') as fp:
        json.dump(coco.to_json(), fp)
    written_files.append(COCO_file)

In [None]:
# List every file written by the tileset-generation section.
print("The following files were written. Let's check them out!")
for written_file in written_files:
    print(written_file)

# train_model.py

In [None]:
# from detectron2.engine.hooks import HookBase
from detectron2.engine import DefaultTrainer
from detectron2.data import build_detection_test_loader, build_detection_train_loader, DatasetMapper
from detectron2.evaluation import COCOEvaluator
# from detectron2.utils import comm
# from detectron2.utils.logger import log_every_n_seconds


Definition of the constants

In [None]:

# Model-zoo config whose pretrained checkpoint initialises the training.
MODEL_ZOO_CHECKPOINT_URL = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml"

MODEL_PTH_FILE = None

# COCO annotation files, resolved relative to WORKING_DIR (the notebook chdir's there below)
COCO_TRN_FILE = "COCO_trn.json"
COCO_VAL_FILE = "COCO_val.json"
COCO_TST_FILE = "COCO_tst.json"

DETECTRON2_CFG_FILE = '/home/gsalamin/Documents/GitHub/object-detector/scripts/det2_config_test_NIR.yaml'

WORKING_DIR = "/mnt/data-01/gsalamin/proj-roadsurf-b/02_Data/processed/obj_detector"
SAMPLE_TAGGED_IMG_SUBDIR = "sample_training_images"
LOG_SUBDIR = "logs"
OUTPUT_DIR = WORKING_DIR + '/tests'

os.chdir(WORKING_DIR)
# let's make the output directories (relative to WORKING_DIR) in case they don't exist;
# exist_ok avoids the exists()/makedirs() check-then-create race
for sub_dir in (SAMPLE_TAGGED_IMG_SUBDIR, LOG_SUBDIR):
    os.makedirs(sub_dir, exist_ok=True)

written_files = []

In [None]:
# ---- register datasets
register_coco_instances("trn_dataset", {}, COCO_TRN_FILE, "")
register_coco_instances("val_dataset", {}, COCO_VAL_FILE, "")
register_coco_instances("tst_dataset", {}, COCO_TST_FILE, "")


In [None]:

registered_datasets = ['trn_dataset', 'val_dataset', 'tst_dataset']

# dataset prefixes, e.g. 'trn_dataset' -> 'trn'
registered_datasets_prefixes = [name.partition('_')[0] for name in registered_datasets]

In [None]:
# Draw the annotations on (at most) the first 4 images of each registered dataset,
# to visually check the COCO files.
for dataset in registered_datasets:

    # Fetched once per dataset: per the detectron2 docs, DatasetCatalog.get() calls the
    # registered loader (i.e. re-reads the COCO file) on every call.
    dataset_dicts = DatasetCatalog.get(dataset)

    for d in dataset_dicts[:4]:
        img_stem, img_ext = os.path.splitext(os.path.basename(d["file_name"]))
        # Only the extension is swapped. A plain str.replace('tif', 'png') turned '.tiff'
        # into '.pngf' (unknown to cv2.imwrite) and hit any 'tif' inside the file name.
        if img_ext.lower() in ('.tif', '.tiff'):
            img_ext = '.png'
        output_filename = "tagged_" + img_stem + img_ext

        img = cv2.imread(d["file_name"])
        if img is None:
            # cv2.imread signals unreadable files by returning None
            raise FileNotFoundError(f"Could not read image {d['file_name']}")

        # OpenCV loads BGR while the Visualizer expects RGB, hence the channel flips
        visualizer = Visualizer(img[:, :, ::-1], metadata=MetadataCatalog.get(dataset), scale=1.0)

        vis = visualizer.draw_dataset_dict(d)
        output_path = os.path.join(SAMPLE_TAGGED_IMG_SUBDIR, output_filename)
        cv2.imwrite(output_path, vis.get_image()[:, :, ::-1])
        written_files.append(os.path.join(WORKING_DIR, output_path))
            


In [None]:
# List the sample annotated images written above.
print("The following files were written. Let's check them out!")
for written_file in written_files:
    print(written_file)

In [None]:
# cf. https://detectron2.readthedocs.io/modules/config.html#config-references
# NOTE: this rebinds `cfg`, which held the YAML config dict of the tileset-generation section.
cfg = get_cfg()
cfg.merge_from_file(DETECTRON2_CFG_FILE)
# training outputs go to the "logs" sub-directory of WORKING_DIR
cfg.OUTPUT_DIR = LOG_SUBDIR

In [None]:
# Initialise from the model-zoo checkpoint matching MODEL_ZOO_CHECKPOINT_URL
# (MODEL_PTH_FILE is not used here).
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(MODEL_ZOO_CHECKPOINT_URL)

In [None]:
class CocoTrainer(DefaultTrainer):
    '''
    DefaultTrainer with a COCO evaluator, an explicit training data loader and an extra
    LossEvalHook that reports the loss on the evaluation dataset.
    '''

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        '''
        Adding an evaluator for the test set, because it is not included by default.

        Results are written to ``output_folder`` (default: ``<cfg.OUTPUT_DIR>/inference``),
        which is created if needed.
        '''
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")

        # Create the folder the evaluator actually writes to, rather than an unrelated
        # "COCO_eval" folder.
        os.makedirs(output_folder, exist_ok=True)

        return COCOEvaluator(dataset_name, cfg, False, output_folder)

    @classmethod
    def build_train_loader(cls, cfg):
        '''
        Build the training data loader with the default DatasetMapper.

        TODO: images with more than 3 channels (e.g. an "RGBNir" cfg.INPUT.FORMAT) need a custom
        mapper, cf. https://detectron2.readthedocs.io/en/latest/tutorials/data_loading.html and
        https://detectron2.readthedocs.io/en/latest/_modules/detectron2/data/detection_utils.html#read_image
        '''
        mapper = DatasetMapper(cfg, is_train=True)
        return build_detection_train_loader(cfg, mapper=mapper)

    def build_hooks(self):
        '''
        Add a LossEvalHook to the default hooks.

        1- The hook is called every cfg.TEST.EVAL_PERIOD steps.
        2- When called, it runs inference on the first evaluation dataset (cfg.DATASETS.TEST[0]).
        3- It computes the loss the same way as during training and stores its mean over the dataset.
        '''
        hooks = super().build_hooks()

        # Inserted before the last default hook. NOTE(review): presumably the periodic writer,
        # so that the evaluation loss gets logged — confirm for the detectron2 version in use.
        hooks.insert(-1,
            LossEvalHook(
                self.cfg.TEST.EVAL_PERIOD,
                self.model,
                # is_train=True keeps the annotations in the mapped inputs, needed to compute a loss
                build_detection_test_loader(self.cfg, self.cfg.DATASETS.TEST[0], DatasetMapper(self.cfg, True))
            )
        )

        return hooks

    

In [None]:
# Instantiate the trainer from the detectron2 config assembled in the previous cells.
trainer = CocoTrainer(cfg)

In [None]:
# resume=False: start from cfg.MODEL.WEIGHTS instead of resuming from a previous checkpoint.
trainer.resume_or_load(resume=False)
trainer.train()

# Test if the cell works without the print
