# In [2]:  (Jupyter cell prompt left over from the notebook export)
import os
import numpy as np
import sys
import logging
from time import time
from tensorboardX import SummaryWriter
import argparse

import torch
from torch.optim.lr_scheduler import StepLR
from loss import SimpleLoss, DiscriminativeLoss

# from data.dataset import semantic_dataset
from data.const import NUM_CLASSES
from evaluation.iou import get_batch_iou
from evaluation.angle_diff import calc_angle_diff
from model import get_model
from evaluate import onehot_encoding, eval_iou
from data.vector_map import *

from data.dataset import *


def write_log(writer, ious, title, counter):
    """Record the mean IoU and every per-class IoU as scalars.

    Args:
        writer: Object exposing ``add_scalar(tag, value, step)``.
        ious: 1-D tensor of IoU values, one entry per class index.
        title: Tag prefix for the scalars (the callers in this file use
            ``'train'`` and ``'eval'``).
        counter: Step value attached to every scalar written.
    """
    # The mean deliberately skips index 0 and averages the remaining classes.
    mean_tag = f'{title}/iou'
    mean_iou = torch.mean(ious[1:])
    writer.add_scalar(mean_tag, mean_iou, counter)

    # Per-class curves include index 0 as well.
    for class_idx, class_iou in enumerate(ious):
        class_tag = f'{title}/class_{class_idx}/iou'
        writer.add_scalar(class_tag, class_iou, counter)

def semantic_dataset(version, dataroot, data_conf, bsz, nworkers):
    """Build the training and validation data loaders.

    Both datasets are created from the ``version`` and ``dataroot`` supplied
    by the caller, so the command-line options control the training split as
    well as the validation split.

    Args:
        version: Dataset version string forwarded to ``mbmapdatatest``.
        dataroot: Dataset root directory forwarded to ``mbmapdatatest``.
        data_conf: Data configuration dict forwarded to ``mbmapdatatest``.
        bsz: Batch size used by both loaders.
        nworkers: Number of DataLoader worker processes for both loaders.

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """
    train_dataset = mbmapdatatest(version, dataroot, data_conf, is_train=True)
    val_dataset = mbmapdatatest(version, dataroot, data_conf, is_train=False)

    # Training batches are shuffled and the trailing partial batch is dropped;
    # validation keeps dataset order and every sample.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=bsz, shuffle=True, num_workers=nworkers, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=bsz, shuffle=False, num_workers=nworkers)
    return train_loader, val_loader

def train(args):
    """Train the model and checkpoint every epoch that improves validation IoU.

    Logs go to ``<logdir>/results.log`` and to stdout; scalars go to a
    tensorboardX ``SummaryWriter`` rooted at ``args.logdir``. After each epoch
    the model is evaluated with ``eval_iou`` and, when the average IoU of
    classes 1-3 beats the best value so far, its ``state_dict`` is saved as
    ``<logdir>/model<epoch>.pt``.

    Args:
        args: Namespace with the attributes read below: ``logdir``,
            ``image_size``, ``xbound``, ``ybound``, ``zbound``, ``dbound``,
            ``thickness``, ``angle_class``, ``version``, ``dataroot``, ``bsz``,
            ``nworkers``, ``model``, ``instance_seg``, ``embedding_dim``,
            ``direction_pred``, ``finetune``, ``modelf``, ``lr``,
            ``weight_decay``, ``pos_weight``, ``delta_v``, ``delta_d``,
            ``nepochs``, ``max_grad_norm``, ``scale_seg``, ``scale_var``,
            ``scale_dist`` and ``scale_direction``.

    Raises:
        ValueError: If ``args.finetune`` is set but ``args.modelf`` is None.
    """
    # makedirs (not mkdir) so nested log directories work and an existing
    # directory is not an error.
    os.makedirs(args.logdir, exist_ok=True)
    logging.basicConfig(filename=os.path.join(args.logdir, "results.log"),
                        filemode='w',
                        format='%(asctime)s: %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.INFO)
    logging.getLogger('shapely.geos').setLevel(logging.CRITICAL)

    # Echo every log record to stdout in addition to the log file.
    logger = logging.getLogger()
    logger.addHandler(logging.StreamHandler(sys.stdout))

    data_conf = {
        'num_channels': NUM_CLASSES + 1,
        'image_size': args.image_size,
        'xbound': args.xbound,
        'ybound': args.ybound,
        'zbound': args.zbound,
        'dbound': args.dbound,
        'thickness': args.thickness,
        'angle_class': args.angle_class,
    }

    # One device object replaces the separate CUDA / CPU code paths.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_loader, val_loader = semantic_dataset(args.version, args.dataroot, data_conf, args.bsz, args.nworkers)
    model = get_model(args.model, data_conf, args.instance_seg, args.embedding_dim, args.direction_pred, args.angle_class)

    if args.finetune:
        if args.modelf is None:
            raise ValueError("--finetune requires --modelf to point at a checkpoint file")
        # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only
        # host; the model is moved to the target device afterwards.
        model.load_state_dict(torch.load(args.modelf, map_location='cpu'), strict=False)
        # Only parameters under 'bevencode.up' / 'bevencode.layer3' stay trainable.
        for name, param in model.named_parameters():
            param.requires_grad = 'bevencode.up' in name or 'bevencode.layer3' in name
    model.to(device)

    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # Multiply the learning rate by 0.1 every 10 epochs.
    sched = StepLR(opt, 10, 0.1)
    writer = SummaryWriter(logdir=args.logdir)

    loss_fn = SimpleLoss(args.pos_weight).to(device)
    embedded_loss_fn = DiscriminativeLoss(args.embedding_dim, args.delta_v, args.delta_d).to(device)
    direction_loss_fn = torch.nn.BCELoss(reduction='none')

    model.train()
    counter = 0
    last_idx = len(train_loader) - 1
    best_iou = 0.0

    for epoch in range(args.nepochs):
        for batchi, (imgs, trans, rots, intrins, post_trans, post_rots, lidar_data, lidar_mask, car_trans,
                     yaw_pitch_roll, semantic_gt, instance_gt, direction_gt, sample_token) in enumerate(train_loader):
            t0 = time()
            opt.zero_grad()
            # Only the first sample token of the batch is passed to the model.
            semantic, embedding, direction = model(imgs.to(device), trans.to(device), rots.to(device),
                                                   intrins.to(device), post_trans.to(device),
                                                   post_rots.to(device), lidar_data.to(device),
                                                   lidar_mask.to(device), car_trans.to(device),
                                                   yaw_pitch_roll.to(device), sample_token[0])
            semantic_gt = semantic_gt.to(device).float()
            instance_gt = instance_gt.to(device)

            seg_loss = loss_fn(semantic, semantic_gt)
            if args.instance_seg:
                var_loss, dist_loss, reg_loss = embedded_loss_fn(embedding, instance_gt)
            else:
                var_loss = 0
                dist_loss = 0
                reg_loss = 0

            if args.direction_pred:
                direction_gt = direction_gt.to(device)
                # Mask derived from channel 0 of the direction ground truth:
                # positions where that channel is 1 contribute nothing.
                lane_mask = (1 - direction_gt[:, 0]).unsqueeze(1)
                direction_loss = direction_loss_fn(torch.softmax(direction, 1), direction_gt)
                direction_loss = (direction_loss * lane_mask).sum() / (lane_mask.sum() * direction_loss.shape[1] + 1e-6)
                angle_diff = calc_angle_diff(direction, direction_gt, args.angle_class)
            else:
                direction_loss = 0
                angle_diff = 0

            # reg_loss is logged below but is not part of the optimized objective.
            final_loss = seg_loss * args.scale_seg + var_loss * args.scale_var + dist_loss * args.scale_dist + direction_loss * args.scale_direction
            final_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            opt.step()
            counter += 1
            t1 = time()

            # Report training metrics every 10 optimizer steps.
            if counter % 10 == 0:
                intersects, union = get_batch_iou(onehot_encoding(semantic), semantic_gt)
                iou = intersects / (union + 1e-7)
                logger.info(f"TRAIN[{epoch:>3d}]: [{batchi:>4d}/{last_idx}]    "
                            f"Time: {t1-t0:>7.4f}    "
                            f"Loss: {final_loss.item():>7.4f}    "
                            f"IOU: {np.array2string(iou[1:].numpy(), precision=3, floatmode='fixed')}")

                write_log(writer, iou, 'train', counter)
                writer.add_scalar('train/step_time', t1 - t0, counter)
                writer.add_scalar('train/seg_loss', seg_loss, counter)
                writer.add_scalar('train/var_loss', var_loss, counter)
                writer.add_scalar('train/dist_loss', dist_loss, counter)
                writer.add_scalar('train/reg_loss', reg_loss, counter)
                writer.add_scalar('train/direction_loss', direction_loss, counter)
                writer.add_scalar('train/final_loss', final_loss, counter)
                writer.add_scalar('train/angle_diff', angle_diff, counter)

        # TODO: investigate running evaluation in parallel with training.
        iou = eval_iou(model, val_loader)
        logger.info(f"EVAL[{epoch:>2d}]:    "
                    f"IOU: {np.array2string(iou[1:].numpy(), precision=3, floatmode='fixed')}")

        # Logged exactly once per epoch, whether or not this epoch is a new best.
        write_log(writer, iou, 'eval', counter)

        # Model selection uses the average IoU of classes 1-3 (index 0 is skipped).
        iou_sum = iou[1] + iou[2] + iou[3]
        overall_iou = iou_sum / 3

        # Only epochs that improve on the best score so far are checkpointed.
        if overall_iou > best_iou:
            best_iou = overall_iou
            model_name = os.path.join(args.logdir, f"model{epoch}.pt")
            torch.save(model.state_dict(), model_name)
            logger.info(f"{model_name} saved")
            # The root logger already echoes to stdout, so no separate print is needed.
            logger.info(f'The new best Epoch found !!! epoch num is : {epoch}, val iou: {iou_sum} , overall IOU: {overall_iou:.4f}')

        # NOTE(review): presumably eval_iou switches the model to eval mode; restore training mode.
        model.train()

        sched.step()

    # Flush and release the TensorBoard event file.
    writer.close()


class mbmapdatatest(HDMapNetSemanticDataset):
    """HDMapNetSemanticDataset whose ``vector_map`` is an ``mbmap`` instance."""

    def __init__(self, version, dataroot, data_conf, is_train):
        """Initialize the base dataset, then replace its vector map.

        Args:
            version: Forwarded to ``HDMapNetSemanticDataset``.
            dataroot: Forwarded to the base class and to ``mbmap``.
            data_conf: Forwarded to ``HDMapNetSemanticDataset``.
            is_train: Forwarded to ``HDMapNetSemanticDataset``.
        """
        super().__init__(version, dataroot, data_conf, is_train)
        # patch_size / canvas_size are read from self after the base
        # constructor has run.
        local_map = mbmap(dataroot, patch_size=self.patch_size, canvas_size=self.canvas_size)
        self.vector_map = local_map

class mbmap(VectorizedLocalMap):
    """VectorizedLocalMap whose map list adds 'singapore-hollandvillage-mbfake'."""

    def __init__(self,
                 dataroot,
                 patch_size,
                 canvas_size,
                 line_classes=None,
                 ped_crossing_classes=None,
                 contour_classes=None,
                 sample_dist=1,
                 num_samples=250,
                 padding=False,
                 normalize=False,
                 fixed_num=-1):
        """Forward all arguments to ``VectorizedLocalMap`` and extend ``MAPS``.

        Args:
            dataroot: Forwarded unchanged to ``VectorizedLocalMap``.
            patch_size: Forwarded unchanged to ``VectorizedLocalMap``.
            canvas_size: Forwarded unchanged to ``VectorizedLocalMap``.
            line_classes: Defaults to ``['road_divider', 'lane_divider']``.
            ped_crossing_classes: Defaults to ``['ped_crossing']``.
            contour_classes: Defaults to ``['road_segment', 'lane']``.
            sample_dist: Forwarded unchanged to ``VectorizedLocalMap``.
            num_samples: Forwarded unchanged to ``VectorizedLocalMap``.
            padding: Forwarded unchanged to ``VectorizedLocalMap``.
            normalize: Forwarded unchanged to ``VectorizedLocalMap``.
            fixed_num: -1 means no fixed num.
        """
        # None sentinels avoid sharing one mutable default list across instances.
        if line_classes is None:
            line_classes = ['road_divider', 'lane_divider']
        if ped_crossing_classes is None:
            ped_crossing_classes = ['ped_crossing']
        if contour_classes is None:
            contour_classes = ['road_segment', 'lane']

        super().__init__(dataroot,
                         patch_size,
                         canvas_size,
                         line_classes=line_classes,
                         ped_crossing_classes=ped_crossing_classes,
                         contour_classes=contour_classes,
                         sample_dist=sample_dist,
                         num_samples=num_samples,
                         padding=padding,
                         normalize=normalize,
                         fixed_num=fixed_num)
        # NOTE(review): MAPS is assigned after the parent constructor, as in
        # the original code. If VectorizedLocalMap.__init__ reads self.MAPS
        # while initializing, the extra 'mbfake' map will not be picked up
        # there -- confirm against the parent class.
        self.MAPS = ['boston-seaport', 'singapore-hollandvillage',
                     'singapore-onenorth', 'singapore-queenstown', 'singapore-hollandvillage-mbfake']

def _build_parser():
    """Create the command-line parser for the training script.

    Returns:
        argparse.ArgumentParser: Parser exposing the logging, dataset, model,
        training, finetune, data, embedding, direction and loss options.
    """
    cli = argparse.ArgumentParser(description='HDMapNet training.')
    add = cli.add_argument

    # logging config
    add("--logdir", type=str, default='./runs')

    # nuScenes config
    add('--dataroot', type=str, default='dataset/nuScenes/')
    add('--version', type=str, default='v1.0-mini', choices=['v1.0-trainval', 'v1.0-mini'])

    # model config
    add("--model", type=str, default='HDMapNet_cam')

    # training config
    add("--nepochs", type=int, default=50)
    add("--max_grad_norm", type=float, default=5.0)
    add("--pos_weight", type=float, default=2.13)
    add("--bsz", type=int, default=4)
    add("--nworkers", type=int, default=10)
    add("--lr", type=float, default=1e-3)
    add("--weight_decay", type=float, default=1e-7)

    # finetune config
    add('--finetune', action='store_true')
    add('--modelf', type=str, default=None)

    # data config
    add("--thickness", type=int, default=5)
    add("--image_size", nargs=2, type=int, default=[128, 352])
    add("--xbound", nargs=3, type=float, default=[-30.0, 30.0, 0.15])
    add("--ybound", nargs=3, type=float, default=[-15.0, 15.0, 0.15])
    add("--zbound", nargs=3, type=float, default=[-10.0, 10.0, 20.0])
    add("--dbound", nargs=3, type=float, default=[4.0, 45.0, 1.0])

    # embedding config
    add('--instance_seg', action='store_true')
    add("--embedding_dim", type=int, default=16)
    add("--delta_v", type=float, default=0.5)
    add("--delta_d", type=float, default=3.0)

    # direction config
    add('--direction_pred', action='store_true')
    add('--angle_class', type=int, default=36)

    # loss config
    add("--scale_seg", type=float, default=1.0)
    add("--scale_var", type=float, default=1.0)
    add("--scale_dist", type=float, default=1.0)
    add("--scale_direction", type=float, default=0.2)

    return cli


# Example invocation:
#   train.py --dataroot /home/qichen/projects/Nuscene-full/ --instance_seg --direction_pred --version v1.0-trainval
#   --model lift_splat --logdir ./log --xbound -50.0 50.0 0.5 --ybound -50.0 50.0 0.5
if __name__ == '__main__':
    args = _build_parser().parse_args()
    train(args)


# --- Captured notebook output (not part of the script) ---
# usage: ipykernel_launcher.py [-h] [--logdir LOGDIR] [--dataroot DATAROOT]
#                              [--version {v1.0-trainval,v1.0-mini}]
#                              [--model MODEL] [--nepochs NEPOCHS]
#                              [--max_grad_norm MAX_GRAD_NORM]
#                              [--pos_weight POS_WEIGHT] [--bsz BSZ]
#                              [--nworkers NWORKERS] [--lr LR]
#                              [--weight_decay WEIGHT_DECAY] [--finetune]
#                              [--modelf MODELF] [--thickness THICKNESS]
#                              [--image_size IMAGE_SIZE IMAGE_SIZE]
#                              [--xbound XBOUND XBOUND XBOUND]
#                              [--ybound YBOUND YBOUND YBOUND]
#                              [--zbound ZBOUND ZBOUND ZBOUND]
#                              [--dbound DBOUND DBOUND DBOUND] [--instance_seg]
#                              [--embedding_dim EMBEDDING_DIM]
#                              [--delta_v DELTA_V] [--delta_d DELTA_D]
#
#
# AttributeError: 'tuple' object has no attribute 'tb_frame'