In [1]:
import os
# GIT_PYTHON_REFRESH is read by GitPython when it is first imported; "quiet"
# asks it not to warn or raise if no usable git executable is found.
# NOTE(review): presumably set here, ahead of the import cell, because mlflow
# pulls in GitPython on import — confirm.
os.environ["GIT_PYTHON_REFRESH"] = "quiet"

In [2]:
# %reset -f
import torch
import mlflow
import mlflow.pytorch
from torchsummary import summary
from pipeline import monai_utils
from monai.handlers import StatsHandler
from mlflow.tracking import MlflowClient
from pipeline.monai_utils import parseInputs

In [3]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# `device` is read as a notebook-level global by train().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [4]:
def print_auto_logged_info(r):
    """Print a short summary of an MLflow run.

    Prints, in this order: the run id, the artifact paths listed under the
    run's "model" directory, the logged params, the logged metrics, and the
    run's tags with every key that starts with "mlflow." left out.

    Args:
        r: run object exposing ``info.run_id`` and ``data.tags``,
            ``data.params`` and ``data.metrics`` (train() passes the result
            of ``mlflow.get_run``).
    """
    # Keep only tags whose key is outside the "mlflow." namespace.
    user_tags = {}
    for key, value in r.data.tags.items():
        if key.startswith("mlflow."):
            continue
        user_tags[key] = value

    run_id = r.info.run_id
    listing = MlflowClient().list_artifacts(run_id, "model")
    artifact_paths = [entry.path for entry in listing]

    report = (
        ("run_id: {}", run_id),
        ("artifacts: {}", artifact_paths),
        ("params: {}", r.data.params),
        ("metrics: {}", r.data.metrics),
        ("tags: {}", user_tags),
    )
    for template, value in report:
        print(template.format(value))

In [5]:
def train(cfg, overrides=None):
    """Run one training session inside an MLflow run and log the outcome.

    Datasets, loaders, model, optimizer, criterion and the training engine
    are all obtained from ``cfg``. The model is moved to the notebook-level
    ``device`` before the optimizer is created. Once the engine has run, the
    model is logged to MLflow under the artifact path "model" and a summary
    of the run is printed via ``print_auto_logged_info``.

    Args:
        cfg: configuration object providing ``config``, ``update_from_args``
            and the ``get_*`` factory methods called below.
        overrides: optional value handed to ``cfg.update_from_args``; the
            call is skipped when this is ``None``.
    """
    with mlflow.start_run() as active_run:
        # Overrides are applied before the configuration is logged.
        if overrides is not None:
            cfg.update_from_args(overrides)
        mlflow.log_params(cfg.config)

        train_dataset, valid_dataset = cfg.get_dataset()
        train_loader, valid_loader = cfg.get_loaders(train_dataset, valid_dataset)

        model = cfg.get_model_instance().to(device)
        optimizer = cfg.get_optimizer_instance(model)
        criterion = cfg.get_criterion_instance()

        # Handlers are handed to the engine as a one-element tuple.
        handlers = (
            StatsHandler(
                tag_name="train_loss",
                output_transform=lambda output: output["loss"],
            ),
        )
        trainer = cfg.get_train_engine(
            device,
            train_loader,
            valid_loader,
            model,
            None,  # inferer: none is supplied
            optimizer,
            criterion,
            handlers,
        )
        trainer.run()

        mlflow.pytorch.log_model(model, "model")
        print_auto_logged_info(mlflow.get_run(run_id=active_run.info.run_id))

In [6]:
# parseInputs is imported from pipeline.monai_utils; its two results are
# unpacked into the config object and the override arguments passed to train().
# NOTE(review): where the inputs are parsed from is not visible in this
# notebook — confirm against pipeline.monai_utils.
cfg, my_args = parseInputs()

In [7]:
# Launch training. my_args is passed as `overrides`, which train() applies to
# cfg (when not None) before logging cfg.config to MLflow.
train(cfg = cfg,overrides = my_args)

  0%|                                                                                          | 0/130 [00:00<?, ?it/s]

Reading the training data:


100%|████████████████████████████████████████████████████████████████████████████████| 130/130 [00:25<00:00,  5.03it/s]
  0%|                                                                                           | 0/20 [00:00<?, ?it/s]

Reading the validation data:


100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:04<00:00,  4.73it/s]


Training Results - Epoch: 1  Avg accuracy: 0.90 Avg loss: 0.46
Validation Results - Epoch: 1  Avg accuracy: 0.80 Avg loss: 0.58
Training Results - Epoch: 2  Avg accuracy: 0.99 Avg loss: 0.20
Validation Results - Epoch: 2  Avg accuracy: 0.75 Avg loss: 0.52
Training Results - Epoch: 3  Avg accuracy: 0.99 Avg loss: 0.07
Validation Results - Epoch: 3  Avg accuracy: 0.75 Avg loss: 0.47
run_id: 4ef567df07e64bc5880be4e355a068b5
artifacts: ['model/MLmodel', 'model/conda.yaml', 'model/data']
params: {'cfgname': 'exploration/monai_abc.json', 'CRITERION': "{'CRITERION_TYPE': 'torch.nn.CrossEntropyLoss'}", 'DATASET': "{'DATA_TYPE': 'monai.data.CacheDataset', 'VOLUME_SHAPE': [128, 128, 128], 'DATASET_PATH': 'D:/iitm/IU/IU_04/IU_APWS/BRATS2017/Brats17TrainingData/', 'TRANSFORMS_KEYS': ['img', 'img', 'img', 'img', 'img', ['img', 'label']], 'TRANSFORMS_DICT': 'None'}", 'LOADER': "{'LOADER_TYPE': 'monai.data.DataLoader', 'LOADER_ARGS': {'Train': {'batch_size': 6, 'shuffle': True}, 'Valid': {'batch_size