In [None]:
from training_libs import retrievePytorchDetectionLibs
from inference_training import Configuration, ImageDataset
from inference_training import initCudaEnvironment, createTransforms
from inference_training import drawImageAndFeatureMasks
from inference_training import exportOnnxModel, writeONNXMeta, loadONNX
from inference_training import trainModel, saveModel, loadModel
from inference_training import createModelInstance, testInference

In [None]:
# Make the PyTorch detection helper libraries available before the rest of the
# notebook runs. NOTE(review): the exact side effects (download/copy/path setup)
# live in training_libs and are not visible here — confirm there.
retrievePytorchDetectionLibs()

In [None]:
# CUDA setup for this run: a single device, restricted to device "0".
# clearCudaDeviceCount is left disabled, exactly as before.
cudaOptions = {
    "numCudaDevices": 1,
    "visibleCudaDevices": "0",
    "clearCudaDeviceCount": False,
}
initCudaEnvironment(**cudaOptions)

In [None]:
# Train on the GPU, or on the CPU if a GPU is not available
# (the selected device is reported by Configuration().device).
config = Configuration()
print("Device: " + str(config.device))

# Dataset locations: replace these placeholders before running the notebook.
trainDirectory = "<PATH TO TRAIN FILES>"
testDirectory = "<PATH TO TEST FILES>"

config.setDatasetPaths(trainPath=trainDirectory, testPath=testDirectory)
config.setFilePrefix("foliage_")
config.setModelName("foliage")
config.setInputSizes(inputWidth=250, inputHeight=250)
config.setInputCellSize(cellSizeM=0.25, minCellSizeM=0.1, maxCellSizeM=0.5)

config.setVersion(20250121)

print("Version: " + str(config.version))

# Class legend as a single table: the position in this list is the class
# index passed to addLegendEntry, and index 0 is the background class.
# Keeping names, indices and colors together guarantees that numClasses
# (below) always matches the number of legend entries.
classLegend = [
    ("Background", "#00000000"),
    ("Deciduous Tree", "#00ffbf"),
    ("Pine Tree", "#12d900"),
    ("Heather", "#f3a6b2"),
    ("Hedge", "#8d5a99"),
    ("Shrubbery", "#e80004"),
    ("Reed", "#f8ff20"),
    ("Flowerbed", "#b7484b"),
    ("Deciduous Tree (Leafless)", "#e6994d"),
]

# numClasses includes the background: 8 feature classes + 1 background = 9.
config.setModelInfo(channels=3, numClasses=len(classLegend),
                    bboxOverlap=True, bboxPerImage=250, reuseModel=False)
# NOTE(review): a single epoch looks like a smoke-test setting; raise it for
# a real training run.
config.setEpochs(1)

# This text is stored in the exported ONNX model, so keep it user-ready.
description = "Inference model to detect deciduous trees, pine trees, " \
    "heather, hedges, plants, reed, shrubbery, flowerbeds. " \
    "Additionally regions of deciduous trees without leaves can be detected."
config.setOnnxInfo(producer="Tygron", description=description)

for classIndex, (className, classColor) in enumerate(classLegend):
    config.addLegendEntry(className, classIndex, classColor)

config.setOnnxMetaData(scoreThreshold=0.2,
                       maskThreshold=0.3,
                       strideFraction=0.5)

config.setTensorInfo(tensorName='input_A:RGB_normalized', batchAmount=1)
# The boolean arguments select the training (True) vs. test (False) split
# and the matching transforms.
trainingDataset = ImageDataset(config, True, createTransforms(True))
testDataset = ImageDataset(config, False, createTransforms(False))

print("Train Image count: " + str(len(trainingDataset)))
print("Test Image count: " + str(len(testDataset)))

# Run a quiet validation first; only when it fails, repeat it with the flag
# set to True. NOTE(review): presumably that flag enables per-file reporting
# of the inconsistencies — confirm in inference_training.
if not trainingDataset.validateFiles(False):
    print("Inconsistent training dataset ")
    trainingDataset.validateFiles(True)

if not testDataset.validateFiles(False):
    print("Inconsistent test dataset ")
    testDataset.validateFiles(True)

print("Pytorch model name " + config.getPytorchModelFileName())
print("Onnx file name " + config.getOnnxFileName())

In [None]:
# Spot-check one training sample: print its label list, then render it via
# drawImageAndFeatureMasks (presumably the image overlaid with its masks).
imageNumber = 5
labelList = trainingDataset.getLabelList(imageNumber)
print(labelList)
drawImageAndFeatureMasks(config, trainingDataset, imageNumber)

In [None]:
# Set to True to skip training and reuse a previously saved PyTorch model
# stored at config.getPytorchModelFileName().
loadExistingModel = False

if not loadExistingModel:
    # Train from scratch, then persist the result for later reuse.
    model = trainModel(config, trainingDataset, testDataset)
    saveModel(config, model, path=config.getPytorchModelFileName())
else:
    # Build an empty model instance and fill it with the saved weights.
    model = createModelInstance(config)
    loadModel(config, model, path=config.getPytorchModelFileName())

In [None]:
# Which test image to run a sample inference on.
# NOTE(review): assumes testInference treats imageNumber as a 0-based index
# into testDataset — confirm in inference_training.
testImageNumber = 88
# Fail early with a clear message instead of an opaque indexing error when the
# test set is smaller than expected.
assert 0 <= testImageNumber < len(testDataset), (
    f"testImageNumber {testImageNumber} is out of range for a test dataset "
    f"of {len(testDataset)} images")

# Put the model in inference mode before predicting (for a torch.nn.Module
# this switches layers such as dropout/batch-norm to evaluation behavior).
model.eval()
testPrediction = testInference(config, model=model,
                               dataset=testDataset, imageNumber=testImageNumber)

In [None]:
# Export the trained (or loaded) model to ONNX.
# NOTE(review): presumably written to config.getOnnxFileName() — confirm.
exportOnnxModel(config, model)

In [None]:
# Write the configured metadata into the exported ONNX file.
# NOTE(review): presumably this is what was set via setOnnxInfo,
# setOnnxMetaData and addLegendEntry above — confirm in inference_training.
writeONNXMeta(config)

In [None]:
# Reload the exported ONNX model and show its metadata to verify the export.
onnx_model = loadONNX(config)
metadataProps = onnx_model.metadata_props
print(f"metadata_props={metadataProps}")