In [None]:
import torch

from pathlib import Path
from torch.utils.data import DataLoader

from src.utils import load_config
from src.data.dataset import ObjectDetectionDataset
from src.data.entry import read_entries_from_directory, split_entries_train_val_test
from src.data.visualize import plot_entries_original_and_annotated
from src.trainer import Trainer

# Report the runtime environment first, so the notebook output records which
# PyTorch build ran and whether CUDA was usable.
for label, value in (
    ("Pytorch version:", torch.__version__),
    ("CUDA enabled:", torch.cuda.is_available()),
):
    print(label, value)

In [None]:
# Pick the GPU only when CUDA is actually usable. `is_available` has to be
# *called*: the bare function object is always truthy, which would select
# "cuda" even on CPU-only machines.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): machine-specific absolute path — point this at the local copy
# of the dataset before running on another machine.
DATASET_DIR = Path("C:/Users/robert/Desktop/sem1/NN/datasets/DetectionPatches_512x512_ALL")

# Passed to Trainer as `save_dir`.
CHECKPOINT_DIR = Path("models")

# Read by `load_config` in the next cell.
CONFIG_PATH = Path("config.json")

# Forwarded to every ObjectDetectionDataset as `single_class`.
SINGLE_CLASS = True

In [None]:
# Load the run configuration from CONFIG_PATH.
config = load_config(CONFIG_PATH)
# The same seed is reused below when splitting entries into train/val/test.
seed = config["seed"]
# Seed PyTorch's global RNG so torch-driven randomness is repeatable across runs.
# NOTE(review): only torch is seeded here; other RNGs (e.g. `random`, numpy) are
# not touched in this notebook — confirm whether the project code relies on them.
torch.manual_seed(seed)
# Bare last expression: display the loaded configuration as the cell output.
config

In [None]:
# Collect every entry found under DATASET_DIR.
entries = read_entries_from_directory(DATASET_DIR)

# Partition the entries into three subsets; the seed taken from the config
# is passed so the split can be reproduced.
(
    train_entries,
    val_entries,
    test_entries,
) = split_entries_train_val_test(entries, seed=seed)

In [None]:
# Build one dataset per split (train, val, test — in that order), all sharing
# the same single-class flag and configuration.
train_dataset, val_dataset, test_dataset = (
    ObjectDetectionDataset(split_entries, single_class=SINGLE_CLASS, config=config)
    for split_entries in (train_entries, val_entries, test_entries)
)

# Summarise each split: total size plus the positive / negative counts the
# dataset exposes.
print(f"Number TRAIN of entries: {len(train_dataset)} | positive {train_dataset.num_positive} | negative {train_dataset.num_negative}")
print(f"Number VAL of entries: {len(val_dataset)} | positive {val_dataset.num_positive} | negative {val_dataset.num_negative}")
print(f"Number TEST of entries: {len(test_dataset)} | positive {test_dataset.num_positive} | negative {test_dataset.num_negative}")

In [None]:
# Visual sanity check: show a single training entry, original next to annotated.
plot_entries_original_and_annotated(
    train_dataset.entries,
    samples_to_display=1,
)

In [None]:
batch_size = config["batch_size"]

# Training batches are reshuffled every epoch so the model does not see the
# samples in the same fixed order each time. The shuffle order is drawn from
# torch's RNG, which this notebook seeds via `torch.manual_seed`.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Validation and test keep a fixed order so evaluation is repeatable.
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Assemble the trainer from the configuration, the target device, the three
# data loaders (train, val, test) and the checkpoint output directory.
trainer = Trainer(
    config=config,
    device=DEVICE,
    dataloaders=(train_dataloader, val_dataloader, test_dataloader),
    save_dir=CHECKPOINT_DIR,
    # No checkpoint path is supplied.
    # NOTE(review): presumably this means training starts from scratch — confirm in Trainer.
    checkpoint_path=None,
)

In [None]:
# Run training with the trainer configured above.
# NOTE(review): presumably long-running and writes checkpoints to CHECKPOINT_DIR — confirm in Trainer.
trainer.fit()

In [None]:
# Evaluate the trained model on the *training* loader.
# NOTE(review): test_dataloader is built above but never evaluated explicitly in
# this notebook — confirm whether the held-out test split was intended here.
trainer.evaluate(train_dataloader)