In [None]:
# !pip install albumentations
# !pip install -U jupyterlab==3.0.16
# !pip install ipywidgets # --user

In [None]:
# imports
import sys
import torch
from determined.pytorch import PyTorchTrial,PyTorchTrialContext
from model_def import DistNetTrial
import matplotlib.pyplot as plt
from matplotlib import patches
import numpy as np
import warnings

In [None]:
# supress annoying warnings
warnings.filterwarnings('ignore')

In [None]:
# Trial configuration for a local smoke test of DistNetTrial.
# DATA_DIR is the single place to change when running on another machine.
# NOTE(review): train and val point at the same directory — presumably DistNetTrial
# splits it internally; confirm.
DATA_DIR = '/irad_mounts/lambda-quad-5-data/beholder/'

config = {
    'hyperparameters': {
        'learning_rate': 1e-6,
        'weight_decay': 1e-4,
        'global_batch_size': 4,
        'lambda_d': 10,
        'lambda_k': 0,
    },
    'data': {
        'train_data_dir': DATA_DIR,
        'val_data_dir': DATA_DIR,
        'make_local': True,
    },
}

# Build a trial context from the config and instantiate the trial under test.
context = PyTorchTrialContext.from_config(config)
testTrial = DistNetTrial(context)

In [None]:
# Build the training and validation loaders and report len() of each.
# An empty loader usually means a wrong data directory; fail here with a clear
# message instead of letting the training cell silently run zero batches.
train_loader = testTrial.build_training_data_loader()
val_loader = testTrial.build_validation_data_loader()

print(f"train_loader length: {len(train_loader)}")
print(f"val_loader length:   {len(val_loader)}")
assert len(train_loader) > 0, "training loader is empty - check config['data']['train_data_dir']"
assert len(val_loader) > 0, "validation loader is empty - check config['data']['val_data_dir']"

In [None]:
# Display one training sample with its boxes overlaid, to sanity-check images and labels.
idx = 4000  # sample to show; must be < len(train_loader.dataset)
train_dataset = train_loader.dataset
assert 0 <= idx < len(train_dataset), (
    f"idx={idx} is out of range for a training set of {len(train_dataset)} samples")

# `distances` is unpacked only to match the sample tuple; it is not plotted here.
image, boxes, distances, classes = train_dataset[idx]
image, boxes, classes = image.numpy(), boxes.numpy(), classes.numpy()

fig, ax = plt.subplots(figsize=(20, 16))
# Assumes a channel-first image; move the channel axis last for imshow.
ax.imshow(np.transpose(image, (1, 2, 0)))
for box in boxes:
    # Columns 0/2 are x coordinates and 1/3 are y coordinates; presumably
    # (x_min, y_min, x_max, y_max) — TODO confirm against the dataset.
    x0, y0, x1, y1 = box[:4]
    rect = patches.Rectangle((x0, y0), x1 - x0, y1 - y0,
                             linewidth=2, edgecolor='r', facecolor='none', label='object')
    ax.add_patch(rect)
ax.set_title(f'train sample {idx}: {len(boxes)} boxes')
plt.show()

In [None]:
# Smoke-test training on a few batches to make sure the forward/backward pass works.
# batch_idx starts at 0, so a bare `batch_idx % 100 == 0` check followed by `break`
# stops after the very first batch; the batch limit is therefore made explicit here.
MAX_BATCHES = 1   # set to None to train on the whole epoch
LOG_EVERY = 100   # print metrics every LOG_EVERY batches (and on the last batch run)

# Calling train_batch directly bypasses Determined's harness, which normally puts the
# models in train mode and moves each batch to the trial's device; do both here.
for net in getattr(context, 'models', []):
    net.train()

for epoch in range(1):
    for batch_idx, batch in enumerate(train_loader):
        batch = context.to_device(batch)
        metrics = testTrial.train_batch(batch, epoch, batch_idx)
        reached_limit = MAX_BATCHES is not None and batch_idx + 1 >= MAX_BATCHES
        if batch_idx % LOG_EVERY == 0 or reached_limit:
            print(metrics)
        if reached_limit:
            break

In [None]:
# Run one full pass over the validation loader to make sure evaluation works.
# Calling evaluate_full_dataset directly bypasses Determined's harness, which normally
# switches models to eval mode and disables autograd. Do that here so dropout/batch-norm
# behave as in validation and no gradient graph is built. Restore train mode afterwards
# so the training cell can be re-run.
wrapped_models = getattr(context, 'models', [])
for net in wrapped_models:
    net.eval()
try:
    with torch.no_grad():
        metrics = testTrial.evaluate_full_dataset(val_loader)
finally:
    for net in wrapped_models:
        net.train()
print(metrics)

In [None]:
# looks like we're good to go!