In [1]:
import os, dotenv, sys
from pathlib import Path
# Make the vendored project packages importable; paths are relative to the
# notebook's working directory and are appended in this order.
sys.path.extend(['bacili_detection/src', 'bacili_detection/detr'])

import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from annotations.object_detection.object_detection import ImageForObjectDetection, Rect
from annotations.object_detection.dataset import DatasetForObjectDetection
from annotations import db
from bacili_detection.src.dataset.preprocessing import mask_filter, tile_coords
from bacili_detection.detr.util.misc import collate_fn
from bacili_detection.detr.datasets.tb_bacillus import TBBacilliDataset, make_ds_transforms
from torch.utils.data import DataLoader
import torchvision.transforms as T
import torch


# Load variables from the local .env file, then open the annotations DB session
# using the DATABASE_URI environment variable (None is passed through when the
# variable is unset).
dotenv.load_dotenv('.env')
database_uri = os.environ.get("DATABASE_URI")
session = db.get_session(database_uri)


In [15]:
# Split the "train" artifacts of the "Bacilli Detection" project into a holdout
# half and an incremental-training half by attaching a tag to each artifact.
#
# The cell writes to the database, so it is written to be safe to re-run:
#   * artifacts that already carry one of the two split tags are skipped, so a
#     second execution never assigns an artifact to both splits;
#   * the query is ordered and the RNG is seeded, so the same artifacts land in
#     the same split on every fresh run;
#   * all new tags are committed in a single transaction.
SPLIT_SEED = 42
HOLDOUT_TAG = "holdout"
INCREMENTAL_TAG = "incremental_training"

train_artifacts = (
    session.query(db.Artifact)
    .join(db.Project)
    .join(db.ArtifactTag, isouter=True)
    .where(db.Project.name == "Bacilli Detection")
    .where(db.ArtifactTag.tag == "train")
    .group_by(db.Artifact.id)
    .order_by(db.Artifact.id)  # stable order => reproducible split
    .all()
)

print("Found {} train artifacts".format(len(train_artifacts)))

# ids of artifacts that were already assigned to a split by an earlier run
already_split_ids = {
    artifact_id
    for (artifact_id,) in session.query(db.ArtifactTag.artifact_id)
    .where(db.ArtifactTag.tag.in_([HOLDOUT_TAG, INCREMENTAL_TAG]))
    .all()
}
untagged_artifacts = [
    artifact for artifact in train_artifacts if artifact.id not in already_split_ids
]

# take out half of the (not yet split) artifacts as the holdout set;
# indices refer to positions in `untagged_artifacts`
rng = np.random.default_rng(SPLIT_SEED)
inds = np.arange(len(untagged_artifacts))
holdout_artifacts_inds = rng.permutation(len(untagged_artifacts))[: len(untagged_artifacts) // 2]
holdout_positions = set(holdout_artifacts_inds.tolist())  # O(1) membership checks

# tag the chosen half as holdout and the rest as incremental_training
for position, artifact in enumerate(untagged_artifacts):
    split_tag = HOLDOUT_TAG if position in holdout_positions else INCREMENTAL_TAG
    session.add(db.ArtifactTag(tag=split_tag, artifact_id=artifact.id))
session.commit()

print(
    "Tagged {} holdout and {} incremental_training artifacts ({} already split, skipped)".format(
        len(holdout_positions),
        len(untagged_artifacts) - len(holdout_positions),
        len(train_artifacts) - len(untagged_artifacts),
    )
)

Found 202 train artifacts


In [19]:
# Build one dataset per split tag and report the length of each.
holdout_ds = TBBacilliDataset('holdout', db_session=session)
holdout_count = len(holdout_ds)
print("Found {} holdout artifacts".format(holdout_count))

train_cl_ds = TBBacilliDataset('incremental_training', db_session=session)
train_cl_count = len(train_cl_ds)
print("Found {} incremental_training artifacts".format(train_cl_count))

# Disabled reset helper, kept for reference: when uncommented it deletes every
# "holdout" and "incremental_training" tag from the artifacts behind both
# datasets and commits, i.e. it undoes the split tagging.
#
# for imod in (holdout_ds + train_cl_ds)._images:
#     artifact = imod.artifact
#     for tag in artifact.tags:
#         if tag.tag == "holdout":
#             # tag.tag = "N/A"
#             session.delete(tag)
#         if tag.tag == "incremental_training":
#             # tag.tag = "train"
#             session.delete(tag)
#
# session.commit()

Found 101 holdout artifacts
Found 121 incremental_training artifacts
