In [1]:
import random

import mlconfig
import mlflow
import numpy as np
import torch

import src

def manual_seed(seed=0, deterministic=False):
    """Seed every random number generator the training code may draw from.

    Follows https://pytorch.org/docs/stable/notes/randomness.html, which
    recommends seeding Python's ``random``, NumPy and PyTorch together.

    Args:
        seed: Integer seed applied to all three generators.
        deterministic: If True, additionally force cuDNN to use deterministic
            algorithms and disable its benchmark autotuner. Off by default
            because it can slow down GPU training; turn it on when exact
            run-to-run reproducibility matters more than speed.
    """
    random.seed(seed)
    np.random.seed(seed)
    # Per the PyTorch docs, torch.manual_seed seeds the generators on all devices.
    torch.manual_seed(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def main(config_path='./configs/config.yaml', data_root='../data'):
    """Train the model described by an mlconfig YAML file and log the run to MLflow.

    Args:
        config_path: Path to the mlconfig YAML file. The file itself is logged
            as an MLflow artifact and its flattened contents as MLflow params.
        data_root: Directory passed as ``root`` to the configured dataset
            factory for both the ``'train'`` and ``'test'`` list files.

    Returns:
        The model object after ``trainer.fit()`` has run, so callers can
        inspect it.
    """
    config = mlconfig.load(config_path)

    # An explicit, context-managed run is always ended, even if training
    # raises. Re-running the cell therefore starts a fresh run instead of
    # logging (possibly changed) params into a run left active by a previous
    # execution.
    with mlflow.start_run():
        mlflow.log_artifact(config_path)
        mlflow.log_params(config.flat())

        # Seed before building the model and loaders so weight init (and,
        # presumably, any loader shuffling) is reproducible.
        manual_seed()

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = config.model().to(device)
        optimizer = config.optimizer(model.parameters())
        scheduler = config.scheduler(optimizer)
        train_loader = config.dataset(root=data_root, list_file='train')
        test_loader = config.dataset(root=data_root, list_file='test')

        trainer = config.trainer(device, model, optimizer, scheduler, train_loader, test_loader)
        trainer.fit()

    return model


if __name__ == '__main__':
    # In a Jupyter kernel __name__ is '__main__', so this runs every time the
    # cell executes; the guard only matters if this code is exported to a module.
    main()

In [None]:
from torchsummary import summary

# `model` only exists inside main(), so rebuild it from the same config here.
# This keeps the cell runnable on a fresh kernel (Restart & Run All).
SUMMARY_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
summary_model = mlconfig.load('./configs/config.yaml').model().to(SUMMARY_DEVICE)
# torchsummary.summary defaults to device='cuda'; pass the device the model
# actually lives on so CPU-only machines work too.
# NOTE(review): (3, 224, 224) is the per-sample input shape used for the
# summary; confirm it matches what the configured model/dataset expect.
summary(summary_model, (3, 224, 224), device=SUMMARY_DEVICE)