# Train File
This is a copy of the original `train.py` file, with simple notebook-oriented adaptations for faster debugging.

## 1. Install libraries if necessary

In [1]:
# !pip install horovod
# !pip install mpi4py

## 2. Import libraries

In [2]:
# Copyright (c) 2020 Uber Technologies, Inc.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

os.umask(0)
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
import argparse
import numpy as np
import random
import sys
import time
import shutil
from importlib import import_module
from numbers import Number

from tqdm.notebook import tqdm
import torch
from torch.utils.data import Sampler, DataLoader
import horovod.torch as hvd


from torch.utils.data.distributed import DistributedSampler

from utils import Logger, load_pretrain

from mpi4py import MPI

## 3. Initialization

In [3]:
# Distributed setup: grab the MPI world communicator, start Horovod, and pin
# this process to the GPU whose index equals its Horovod local rank.
comm = MPI.COMM_WORLD
hvd.init()
torch.cuda.set_device(hvd.local_rank())

# abspath('') resolves to the current working directory, so root_path is its
# parent; it is put first on sys.path so modules located there can be imported.
root_path = os.path.dirname(os.path.abspath(''))
sys.path.insert(0, root_path)


# Command-line interface (parsed later with an explicit argument list).
parser = argparse.ArgumentParser(description="Fuse Detection in Pytorch")
parser.add_argument(
    "-m",
    "--model",
    default="lanegcn",
    type=str,
    metavar="MODEL",
    help="model name",
)
parser.add_argument("--eval", action="store_true")
parser.add_argument(
    "--resume",
    default="",
    type=str,
    metavar="RESUME",
    help="checkpoint path",
)
# NOTE: the default below is a machine-specific absolute path.
parser.add_argument(
    "--weight",
    default="/home/jinwei/wise/LaneGCN/results/lanegcn/17.000.ckpt",
    type=str,
    metavar="WEIGHT",
    help="checkpoint path",
)

_StoreAction(option_strings=['--weight'], dest='weight', nargs=None, const=None, default='/home/jinwei/wise/LaneGCN/results/lanegcn/17.000.ckpt', type=<class 'str'>, choices=None, help='checkpoint path', metavar='WEIGHT')

## 4. Support Functions

In [4]:
def worker_init_fn(pid):
    """Seed the NumPy and ``random`` generators of one data-loading worker.

    The NumPy seed is ``hvd.rank() * 1024 + int(pid)``, which is distinct for
    every (rank, worker) pair as long as worker ids stay below 1024. The
    ``random`` module is then seeded with a value drawn from the freshly
    seeded NumPy generator.

    Args:
        pid: Worker identifier; converted with ``int``.
    """
    np.random.seed(hvd.rank() * 1024 + int(pid))
    random.seed(np.random.randint(2 ** 32 - 1))


def train(epoch, config, train_loader, net, loss, post_process, opt, val_loader=None):
    """Train ``net`` for one pass over ``train_loader``.

    ``epoch`` is advanced by ``1 / len(train_loader)`` after every batch, so it
    is a fractional epoch counter. Checkpointing, metric display and
    validation are triggered from the iteration count derived from it.

    Args:
        epoch: Epoch counter at the start of this pass.
        config: Mapping read for ``"save_freq"``, ``"val_iters"``,
            ``"batch_size"``, ``"num_epochs"`` and ``"save_dir"``.
        train_loader: Loader whose sampler provides ``set_epoch``.
        net: Model; called as ``net(data)``.
        loss: Callable ``loss(output, data)`` returning a mapping with a
            ``"loss"`` entry that supports ``backward()``.
        post_process: Object called as ``post_process(output, data)`` and
            providing ``append`` and ``display``.
        opt: Optimizer wrapper; ``opt.step(epoch)`` returns the value shown
            as the learning rate.
        val_loader: Optional validation loader. When ``None`` validation is
            skipped.

    Returns:
        None. Returns early once ``epoch`` reaches ``config["num_epochs"]``.
    """
    train_loader.sampler.set_epoch(int(epoch))
    net.train()
    num_batches = len(train_loader)
    epoch_per_batch = 1.0 / num_batches
    save_iters = int(np.ceil(config["save_freq"] * num_batches))
    # NOTE: hard-coded for notebook debugging. It replaces the config-driven
    # value below and equals the 368 batches per epoch seen in this run, so
    # metrics are displayed once per epoch.
    #     display_iters = int(
    #         config["display_iters"] / (hvd.size() * config["batch_size"])
    #     )
    display_iters = 368
    print('display_iters: ' +str(display_iters))
    # Integer truncation can produce 0 when the global batch is larger than
    # config["val_iters"]; clamp to 1 so the modulo below cannot divide by zero.
    val_iters = max(1, int(config["val_iters"] / (hvd.size() * config["batch_size"])))

    start_time = time.time()
    metrics = dict()
    # The progress bar is disabled on every rank except 0 (non-zero is truthy).
    for data in tqdm(train_loader, total=num_batches, disable=hvd.rank()):
        epoch += epoch_per_batch
        data = dict(data)

        output = net(data)
        loss_out = loss(output, data)
        post_out = post_process(output, data)
        post_process.append(metrics, loss_out, post_out)

        opt.zero_grad()
        loss_out["loss"].backward()
        lr = opt.step(epoch)

        num_iters = int(np.round(epoch * num_batches))
        finished = epoch >= config["num_epochs"]

        # Only rank 0 writes checkpoints.
        if hvd.rank() == 0 and (num_iters % save_iters == 0 or finished):
            save_ckpt(net, opt, config["save_dir"], epoch)

        if num_iters % display_iters == 0:
            dt = time.time() - start_time
            # sync() is collective: every rank must call it, only rank 0 prints.
            metrics = sync(metrics)
            if hvd.rank() == 0:
                post_process.display(metrics, dt, epoch, lr)
            start_time = time.time()
            metrics = dict()

        # Validate on the periodic schedule and once at the very end; a single
        # condition avoids running validation twice on the final step.
        if val_loader is not None and (finished or num_iters % val_iters == 0):
            val(config, val_loader, net, loss, post_process, epoch)

        if finished:
            return


def val(config, data_loader, net, loss, post_process, epoch):
    """Evaluate ``net`` on ``data_loader`` and display the gathered metrics.

    The network is put in eval mode, every batch is run under
    ``torch.no_grad()``, the per-rank metrics are combined with ``sync`` and
    rank 0 displays them. The network is switched back to train mode before
    returning.

    Args:
        config: Not read by this function.
        data_loader: Iterable of batches; each batch is converted with ``dict``.
        net: Model; called as ``net(batch)``.
        loss: Callable ``loss(output, batch)``.
        post_process: Object called as ``post_process(output, batch)`` and
            providing ``append`` and ``display``.
        epoch: Value forwarded to ``post_process.display``.
    """
    net.eval()

    tic = time.time()
    collected = dict()
    for batch in data_loader:
        batch = dict(batch)
        with torch.no_grad():
            prediction = net(batch)
            batch_loss = loss(prediction, batch)
            batch_post = post_process(prediction, batch)
            post_process.append(collected, batch_loss, batch_post)

    elapsed = time.time() - tic
    # Collective call: executed on every rank, displayed on rank 0 only.
    collected = sync(collected)
    if hvd.rank() == 0:
        post_process.display(collected, elapsed, epoch)
    net.train()


def save_ckpt(net, opt, save_dir, epoch):
    """Write a checkpoint of the model and optimizer state to ``save_dir``.

    The file is named from ``epoch`` with three decimals (for example
    ``1.500.ckpt``) and holds a dict with the keys ``"epoch"``,
    ``"state_dict"`` and ``"opt_state"``.

    Args:
        net: Object providing ``state_dict()``; its tensors are moved to CPU
            before saving.
        opt: Wrapper whose ``opt`` attribute provides ``state_dict()``.
        save_dir: Target directory; created (with parents) if missing.
        epoch: Numeric epoch value stored in the file and used for its name.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_dir, exist_ok=True)

    state_dict = net.state_dict()
    # Values are replaced in place so the mapping object returned by
    # state_dict() (and anything attached to it) is what gets saved.
    for key in state_dict:
        state_dict[key] = state_dict[key].cpu()

    save_name = "%3.3f.ckpt" % epoch
    torch.save(
        {"epoch": epoch, "state_dict": state_dict, "opt_state": opt.opt.state_dict()},
        os.path.join(save_dir, save_name),
    )


def sync(data):
    """Gather ``data`` from every MPI rank and combine the entries per key.

    The keys are taken from the first gathered dict. For each key the
    accumulator starts as an empty list when that first value is a list and
    as ``0`` otherwise; every rank's value is then added with ``+=``, so lists
    are concatenated and other values are summed.

    Args:
        data: Dict contributed by the calling rank.

    Returns:
        A new dict holding the combined value for each key.
    """
    gathered = comm.allgather(data)
    merged = dict()
    reference = gathered[0]
    for key in reference:
        merged[key] = [] if isinstance(reference[key], list) else 0
        for part in gathered:
            merged[key] += part[key]
    return merged

## 5. Main program

In [5]:
# Main driver: seed the RNGs, build the model, optionally load a checkpoint,
# then either run one validation pass (--eval) or train for the remaining
# epochs.

# Each rank seeds every generator with its own Horovod rank.
seed = hvd.rank()
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Import all settings for experiment.
# An explicit empty argument list makes argparse ignore sys.argv, so every
# option takes its declared default.
args = parser.parse_args(args=[])
# print(args)
# The model module is looked up by name and must provide get_model(), which
# returns the seven objects unpacked below.
model = import_module(args.model)
config, Dataset, collate_fn, net, loss, post_process, opt = model.get_model()
# print(net)

if config["horovod"]:
    # Wrap the inner optimizer so it becomes a Horovod DistributedOptimizer.
    opt.opt = hvd.DistributedOptimizer(
        opt.opt, named_parameters=net.named_parameters()
    )

# --resume takes precedence over --weight. NOTE(review): --weight is declared
# with a non-empty default, so with default arguments this branch always runs.
if args.resume or args.weight:
    ckpt_path = args.resume or args.weight
    # Relative checkpoint paths are resolved against the save directory.
    if not os.path.isabs(ckpt_path):
        ckpt_path = os.path.join(config["save_dir"], ckpt_path)
    # The map_location callable returns the storage unchanged, i.e. tensors
    # are not moved to another device while loading.
    ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    load_pretrain(net, ckpt["state_dict"])
    # Only --resume restores the epoch counter and the optimizer state;
    # --weight loads the network weights alone.
    if args.resume:
        config["epoch"] = ckpt["epoch"]
        opt.load_state_dict(ckpt["opt_state"])

if args.eval:
    # Data loader for evaluation
    dataset = Dataset(config["val_split"], config, train=False)
    val_sampler = DistributedSampler(
        dataset, num_replicas=hvd.size(), rank=hvd.rank()
    )
    val_loader = DataLoader(
        dataset,
        batch_size=config["val_batch_size"],
        num_workers=config["val_workers"],
        sampler=val_sampler,
        collate_fn=collate_fn,
        pin_memory=True,
    )

    # Make every rank use rank 0's weights, then run one validation pass.
    # 999 is passed as the epoch value that val() forwards to display().
    hvd.broadcast_parameters(net.state_dict(), root_rank=0)
    val(config, val_loader, net, loss, post_process, 999)
else:
    # Create log and copy all code
    save_dir = config["save_dir"]
    log = os.path.join(save_dir, "log")
    # Filesystem setup and stdout redirection happen on rank 0 only.
    if hvd.rank() == 0:
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        # From here on, rank 0's print output goes through Logger.
        sys.stdout = Logger(log)

        # Copy every *.py file found directly in root_path (no recursion)
        # into <save_dir>/files.
        src_dirs = [root_path]
        dst_dirs = [os.path.join(save_dir, "files")]
        for src_dir, dst_dir in zip(src_dirs, dst_dirs):
            files = [f for f in os.listdir(src_dir) if f.endswith(".py")]
            if not os.path.exists(dst_dir):
                os.makedirs(dst_dir)
            for f in files:
                shutil.copy(os.path.join(src_dir, f), os.path.join(dst_dir, f))

    # Data loader for training
    dataset = Dataset(config["train_split"], config, train=True)
    train_sampler = DistributedSampler(
        dataset, num_replicas=hvd.size(), rank=hvd.rank()
    )
    train_loader = DataLoader(
        dataset,
        batch_size=config["batch_size"],
        num_workers=config["workers"],
        sampler=train_sampler,
        collate_fn=collate_fn,
        pin_memory=True,
        worker_init_fn=worker_init_fn,
        drop_last=True,
    )
    print("Data loader for training done.")

    # Data loader for evaluation
    # (`dataset` is rebound here; train_loader keeps its own reference.)
    dataset = Dataset(config["val_split"], config, train=False)
    val_sampler = DistributedSampler(dataset, num_replicas=hvd.size(), rank=hvd.rank())
    val_loader = DataLoader(
        dataset,
        batch_size=config["val_batch_size"],
        num_workers=config["val_workers"],
        sampler=val_sampler,
        collate_fn=collate_fn,
        pin_memory=True,
    )
    print("Data loader for evaluation done.")


    # Start every rank from rank 0's model weights and optimizer state.
    hvd.broadcast_parameters(net.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(opt.opt, root_rank=0)

    epoch = config["epoch"]
    print("epoch number: " +str(epoch))
    # Rounded up so a fractional starting epoch still gets a final pass.
    remaining_epochs = int(np.ceil(config["num_epochs"] - epoch))
    print("remaining epoch number: " +str(remaining_epochs))
    for i in range(remaining_epochs):
        # NOTE: this prints the loop index, not the absolute epoch (epoch + i).
        print('epoch number: '+str(i))
        train(epoch + i, config, train_loader, net, loss, post_process, opt, val_loader)

Data loader for training done.


  0%|          | 0/368 [00:00<?, ?it/s]

Data loader for evaluation done.
epoch number: 0
remaining epoch number: 36
epoch number: 0
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 1.000, lr 0.00100, time 166.63
loss 6572.7516 0.1974 6572.5542, ade1 293.1707, fde1 528.3082, ade 279.1145, fde 478.5155

epoch number: 1
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 2.000, lr 0.00100, time 167.70
loss 8378.1532 0.1986 8377.9546, ade1 291.6299, fde1 528.9738, ade 283.4040, fde 495.7129

epoch number: 2
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 3.000, lr 0.00100, time 167.43
loss 7349.2120 0.1996 7349.0123, ade1 285.8035, fde1 504.3552, ade 281.1135, fde 486.3497

epoch number: 3
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 4.000, lr 0.00100, time 167.47
loss 7756.8565 0.1975 7756.6590, ade1 285.6861, fde1 507.1596, ade 281.2123, fde 482.3486

epoch number: 4
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 5.000, lr 0.00100, time 167.43
loss 7199.5845 0.1990 7199.3855, ade1 284.5266, fde1 506.7756, ade 277.4088, fde 463.7691

epoch number: 5
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 6.000, lr 0.00100, time 167.40
loss 7436.7854 0.2027 7436.5827, ade1 284.3772, fde1 505.6396, ade 277.3804, fde 455.2974

epoch number: 6
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 7.000, lr 0.00100, time 167.83
loss 8878.3184 0.1977 8878.1207, ade1 279.6692, fde1 475.3388, ade 273.5251, fde 440.5836

epoch number: 7
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 8.000, lr 0.00100, time 168.15
loss 8754.6989 0.1981 8754.5007, ade1 277.2729, fde1 469.9166, ade 266.6991, fde 419.0268

epoch number: 8
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 9.000, lr 0.00100, time 163.54
loss 8291.4486 0.1984 8291.2502, ade1 285.4881, fde1 500.7003, ade 274.6588, fde 443.8308

epoch number: 9
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 10.000, lr 0.00100, time 162.67
loss 7566.4558 0.1975 7566.2583, ade1 282.1814, fde1 493.2731, ade 271.4242, fde 441.0500

epoch number: 10
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 11.000, lr 0.00100, time 162.94
loss 7425.2403 0.1976 7425.0427, ade1 274.9346, fde1 470.3916, ade 263.3489, fde 408.7680

epoch number: 11
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 12.000, lr 0.00100, time 162.75
loss 7025.4481 0.2045 7025.2436, ade1 275.8184, fde1 463.0755, ade 266.5092, fde 419.9170

epoch number: 12
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 13.000, lr 0.00100, time 162.41
loss 8390.2170 0.2001 8390.0169, ade1 269.7005, fde1 438.1502, ade 261.8715, fde 407.6747

epoch number: 13
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 14.000, lr 0.00100, time 163.54
loss 8288.9389 0.2021 8288.7368, ade1 270.3782, fde1 439.5570, ade 263.0730, fde 411.6027

epoch number: 14
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 15.000, lr 0.00100, time 162.31
loss 7517.0940 0.2078 7516.8861, ade1 267.3867, fde1 436.5807, ade 261.2051, fde 401.4845

epoch number: 15
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 16.000, lr 0.00100, time 162.87
loss 6116.9390 0.1988 6116.7402, ade1 264.0275, fde1 423.1031, ade 257.9866, fde 390.2590

epoch number: 16
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 17.000, lr 0.00100, time 163.42
loss 7855.4706 0.2002 7855.2704, ade1 264.8955, fde1 424.4250, ade 259.0879, fde 394.3051

epoch number: 17
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 18.000, lr 0.00100, time 162.98
loss 7026.3876 0.1952 7026.1924, ade1 267.4177, fde1 428.1321, ade 263.9785, fde 396.0187

epoch number: 18
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 19.000, lr 0.00100, time 162.49
loss 7646.7830 0.2006 7646.5824, ade1 266.5034, fde1 425.3403, ade 258.6068, fde 386.9594

epoch number: 19
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 20.000, lr 0.00100, time 163.53
loss 7229.5251 0.1995 7229.3256, ade1 255.4402, fde1 384.6261, ade 247.9569, fde 345.7057

epoch number: 20
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 21.000, lr 0.00100, time 162.30
loss 8675.7275 0.1958 8675.5316, ade1 258.6728, fde1 399.1656, ade 250.6961, fde 356.0137

epoch number: 21
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 22.000, lr 0.00100, time 162.89
loss 6772.6772 0.1995 6772.4776, ade1 255.1154, fde1 397.7869, ade 247.8518, fde 351.0736

epoch number: 22
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 23.000, lr 0.00100, time 162.10
loss 7120.0486 0.2020 7119.8466, ade1 258.2842, fde1 408.4549, ade 248.8359, fde 360.1351

epoch number: 23
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 24.000, lr 0.00100, time 162.33
loss 8792.2872 0.2002 8792.0870, ade1 269.6835, fde1 437.9901, ade 256.3692, fde 389.5270

epoch number: 24
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 25.000, lr 0.00100, time 162.46
loss 7610.6559 0.2003 7610.4556, ade1 259.8780, fde1 401.6173, ade 248.7979, fde 364.2578

epoch number: 25
display_iters: 368


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 27.000, lr 0.00100, time 163.04
loss 6806.2784 0.2005 6806.0779, ade1 250.8220, fde1 377.5180, ade 240.5563, fde 343.5787

epoch number: 27
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 28.000, lr 0.00100, time 163.30
loss 7209.1291 0.1988 7208.9303, ade1 265.2723, fde1 430.0607, ade 257.3038, fde 395.9982

epoch number: 28
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 29.000, lr 0.00100, time 162.80
loss 6639.8843 0.2035 6639.6808, ade1 279.1158, fde1 471.8730, ade 271.9838, fde 438.5145

epoch number: 29
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 30.000, lr 0.00100, time 163.11
loss 7949.1572 0.1982 7948.9590, ade1 261.3596, fde1 414.3690, ade 251.1667, fde 374.8115

epoch number: 30
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 31.000, lr 0.00100, time 163.31
loss 7296.2110 0.2015 7296.0095, ade1 249.7135, fde1 380.0193, ade 240.9956, fde 338.6531

epoch number: 31
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 32.000, lr 0.00100, time 163.33
loss 6153.6842 0.1946 6153.4896, ade1 258.8593, fde1 401.5928, ade 244.9024, fde 353.9257

epoch number: 32
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 33.000, lr 0.00010, time 162.60
loss 7459.2869 0.1896 7459.0972, ade1 252.6957, fde1 381.5049, ade 236.5711, fde 333.2677

epoch number: 33
display_iters: 368


  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 34.000, lr 0.00010, time 162.91
loss 7304.5381 0.1880 7304.3501, ade1 251.7997, fde1 380.2095, ade 236.2893, fde 332.3808

epoch number: 34
display_iters: 368
************************* Validation, time 112.40 *************************
loss 5534.2417 0.1936 5534.0481, ade1 247.6481, fde1 376.9258, ade 234.8548, fde 328.7293



  0%|          | 0/368 [00:00<?, ?it/s]

Epoch 35.000, lr 0.00010, time 275.97
loss 6401.7595 0.1914 6401.5682, ade1 250.5596, fde1 378.2934, ade 235.8819, fde 330.7509

epoch number: 35
display_iters: 368
Epoch 36.000, lr 0.00010, time 163.08
loss 6141.0460 0.1923 6140.8537, ade1 245.3643, fde1 371.6467, ade 234.5219, fde 329.0718

************************* Validation, time 119.83 *************************
loss 5509.7776 0.1925 5509.5851, ade1 245.1316, fde1 371.2634, ade 233.8586, fde 327.8026

