In [None]:
import argparse
import os
import warnings
import yaml
import sys
sys.path.append('/home/pasti/PycharmProjects/Continual_Nanodet/')
import pytorch_lightning as pl
import torch
import numpy as np
from pytorch_lightning.callbacks import TQDMProgressBar
from pytorch_lightning.accelerators import find_usable_cuda_devices
from torchvision import datasets
from torch.utils.data import Dataset, DataLoader
from nanodet.data.collate import naive_collate
from nanodet.data.dataset import build_dataset
from nanodet.evaluator import build_evaluator
from nanodet.trainer.task import TrainingTask
from nanodet.trainer.pseudo_lab_task import PseudoLabelTrainingTask
from nanodet.trainer.dist_task import DistTrainingTask
from torchvision.transforms import ToTensor, ToPILImage
from nanodet.util import (
    NanoDetLightningLogger,
    cfg,
    convert_old_model,
    env_utils,
    load_config,
    load_model_weight,
    mkdir,
)

#Notebook-wide logger and a fixed seed so that runs are repeatable
SEED = 9
logger = NanoDetLightningLogger('test')
pl.seed_everything(SEED)

def create_exp_cfg(yml_path, task):
    """Write the per-task experiment configuration derived from a base YAML file.

    The base configuration at ``yml_path`` is loaded, its save directory, class
    lists and head size are overridden for ``task``, and the result is written
    to ``cfg/task<task>.yml`` (the ``cfg`` directory is created if missing).

    Task 0 trains and validates on the first 15 classes. Task ``k`` (1..5)
    trains only on class number ``15 + k``, validates on the first ``15 + k``
    classes and starts from ``models/task<k-1>/model_last.ckpt``.

    Args:
        yml_path: Path of the base YAML configuration file.
        task: Task index, from 0 (base task) to 5 (last incremental task).

    Returns:
        The path of the task configuration file that was written.

    Raises:
        ValueError: If ``task`` is outside the range covered by the class list.
    """
    all_names = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
    n_base_classes = 15
    last_task = len(all_names) - n_base_classes
    #Reject out-of-range tasks up front: a negative index would otherwise
    #silently select the wrong classes instead of failing
    if not 0 <= task <= last_task:
        raise ValueError(
            "task must be between 0 and {}, got {!r}".format(last_task, task)
        )
    #Load the YAML file
    with open(yml_path, 'r') as file:
        temp_cfg = yaml.safe_load(file)
    #Save dir of the model
    temp_cfg['save_dir'] = 'models/task' + str(task)
    #The head is sized for all 20 classes whatever the task
    #(an equivalent 'aux_head' override used to be commented out here)
    temp_cfg['model']['arch']['head']['num_classes'] = len(all_names)
    #If base task, training and testing classes are the same
    if task == 0:
        #Two separate slices on purpose: sharing one list object between both
        #keys would make the YAML dumper emit an anchor/alias pair
        temp_cfg['data']['train']['class_names'] = all_names[:n_base_classes]
        temp_cfg['data']['val']['class_names'] = all_names[:n_base_classes]
    #Else, training only on task specific class, and testing on all classes seen so far
    else:
        temp_cfg['data']['train']['class_names'] = [all_names[n_base_classes - 1 + task]]
        temp_cfg['data']['val']['class_names'] = all_names[:n_base_classes + task]
        #Start from the weights produced by the previous task
        temp_cfg['schedule']['load_model'] = 'models/task' + str(task-1) + '/model_last.ckpt'

    temp_cfg_name = 'cfg/task' + str(task) + '.yml'
    print(temp_cfg_name)
    #Make sure the output directory exists before writing into it
    os.makedirs(os.path.dirname(temp_cfg_name), exist_ok=True)
    #Save the new configuration file
    with open(temp_cfg_name, 'w') as file:
        yaml.safe_dump(temp_cfg, file)
    return temp_cfg_name

In [None]:
###LEARNING STREAM###
#task 0: train on first 15 classes, test on 15 classes
#task 1: train on class n°16, test on 16 classes
#task 2: train on class n°17, test on 17 classes
#task 3: train on class n°18, test on 18 classes
#task 4: train on class n°19, test on 19 classes
#task 5: train on class n°20, test on 20 classes
#opt_epochs = [60, 80, 40, 60 ,40]

#The stream starts at task 1: every incremental task is initialised from
#'models/task<k-1>/model_last.ckpt' (see create_exp_cfg), so the task 0
#checkpoint must already exist. Set FIRST_TASK = 0 to train the base task too.
FIRST_TASK = 1
LAST_TASK = 5
MAX_EPOCHS = 100
VAL_EVERY_N_EPOCHS = 10
#GPU indices used for GPU training; set to None to use cfg.device.gpu_ids instead.
#Ignored when the configuration asks for CPU training (gpu_ids == -1).
GPU_DEVICES = [2]

for task in range(FIRST_TASK, LAST_TASK + 1):
    logger = NanoDetLightningLogger('run_logs/task'+str(task))
    logger.info("Starting task" + str(task))
    logger.info("Setting up data...")
    #Create the task configuration file based on the task number and load the configuration
    create_exp_cfg('cfg/VOC.yml', task)
    load_config(cfg, 'cfg/task' + str(task) + '.yml')
    #Build datasets and dataloaders based on the task configuration file
    train_dataset = build_dataset(cfg.data.train, "train")
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.device.batchsize_per_gpu,
        shuffle=True,
        num_workers=cfg.device.workers_per_gpu,
        pin_memory=True,
        collate_fn=naive_collate,
        drop_last=True,
    )
    val_dataset = build_dataset(cfg.data.val, "test")
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=cfg.device.batchsize_per_gpu,
        shuffle=False,
        num_workers=cfg.device.workers_per_gpu,
        pin_memory=True,
        collate_fn=naive_collate,
        drop_last=False,
    )
    evaluator = build_evaluator(cfg.evaluator, val_dataset)

    #Create the model based on the task configuration file
    logger.info("Creating models")
    if task == 0:
        TrainTask = TrainingTask(cfg, evaluator)
    else:
        TrainTask = PseudoLabelTrainingTask(cfg, evaluator)
        #Load the model weights if task is not 0
        if "load_model" in cfg.schedule:
            #Load onto the CPU first so the checkpoint does not depend on the
            #GPU it was saved from; the weights are then copied into the model
            ckpt = torch.load(cfg.schedule.load_model, map_location="cpu")
            load_model_weight(TrainTask.model, ckpt, logger)
            logger.info("Loaded model weight from {}".format(cfg.schedule.load_model))

    #Resume from this task's own last checkpoint only when the config asks for it
    model_resume_path = (
        os.path.join(cfg.save_dir, "model_last.ckpt")
        if "resume" in cfg.schedule
        else None
    )
    #Select the accelerator and the devices that are really handed to the Trainer
    precision = cfg.device.precision
    strategy = None
    if cfg.device.gpu_ids == -1:
        logger.info("Using CPU training")
        accelerator, devices = "cpu", None
    else:
        accelerator = "gpu"
        devices = GPU_DEVICES if GPU_DEVICES is not None else cfg.device.gpu_ids

    #Go distributed only when several devices are actually used; the isinstance
    #check keeps a single integer device id from reaching len()
    if isinstance(devices, (list, tuple)) and len(devices) > 1:
        strategy = "ddp"
        env_utils.set_multi_processing(distributed=True)

    trainer = pl.Trainer(
        default_root_dir=cfg.save_dir,
        max_epochs=MAX_EPOCHS,
        check_val_every_n_epoch=VAL_EVERY_N_EPOCHS,
        accelerator=accelerator,
        devices=devices,
        log_every_n_steps=cfg.log.interval,
        num_sanity_val_steps=0,
        callbacks=[TQDMProgressBar(refresh_rate=0)],# TrainTask.early_stop_callback],
        logger=logger,
        benchmark=cfg.get("cudnn_benchmark", True),
        gradient_clip_val=cfg.get("grad_clip", 0.0),
        strategy=strategy,
        precision=precision,
    )
    trainer.fit(TrainTask, train_dataloader, val_dataloader, ckpt_path=model_resume_path)

## DEBUGGING

In [None]:
#Single-task setup used for the debugging cells below: same steps as the
#training loop, but for one task and with fixed batch sizes
task = 0
DEBUG_TRAIN_BATCH_SIZE = 1
DEBUG_VAL_BATCH_SIZE = 64

logger = NanoDetLightningLogger('run_logs/task'+str(task))
logger.info("Starting task" + str(task))
logger.info("Setting up data...")
#Create the task configuration file based on the task number and load the configuration
create_exp_cfg('cfg/VOC.yml', task)
load_config(cfg, 'cfg/task' + str(task) + '.yml')

#Loader options shared by the training and validation loaders
#(read after load_config so they reflect the task configuration)
shared_loader_kwargs = dict(
    num_workers=cfg.device.workers_per_gpu,
    pin_memory=True,
    collate_fn=naive_collate,
)
#Build datasets and dataloaders based on the task configuration file
train_dataset = build_dataset(cfg.data.train, "train")
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=DEBUG_TRAIN_BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    **shared_loader_kwargs,
)
val_dataset = build_dataset(cfg.data.val, "test")
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=DEBUG_VAL_BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    **shared_loader_kwargs,
)
evaluator = build_evaluator(cfg.evaluator, val_dataset)

#Create the model based on the task configuration file
logger.info("Creating models")
if task == 0:
    TrainTask = TrainingTask(cfg, evaluator)
else:
    TrainTask = DistTrainingTask(cfg, evaluator)
    #Load the model weights if task is not 0
    if "load_model" in cfg.schedule:
        #Load onto the CPU first so the checkpoint does not depend on the GPU
        #it was saved from; the weights are then copied into the modules
        ckpt = torch.load(cfg.schedule.load_model, map_location="cpu")
        #Both the trained model and its `teacher` start from the same checkpoint
        load_model_weight(TrainTask.model, ckpt, logger)
        load_model_weight(TrainTask.teacher, ckpt, logger)
        logger.info("Loaded model weight from {}".format(cfg.schedule.load_model))


In [None]:
#Fetch a few batches to sanity-check the loader and print their image ids;
#`batch` is left bound to the last batch fetched for the cells below
N_DEBUG_BATCHES = 3
#zip stops on the exhausted range first, so no extra batch is pulled from the loader
for i, batch in zip(range(N_DEBUG_BATCHES), train_dataloader):
    print(batch['img_info']['id'])

In [None]:
#Inspect the ground-truth boxes of `batch`, the last batch fetched by the loader loop above
gt_boxes = batch["gt_bboxes"]
print(gt_boxes)

In [None]:
#Run the model stage by stage (backbone -> fpn -> head) on the debug batch
#and post-process the head output into detections
img = batch["img"]
print(img)
#Inference only: the cells below just read `dets`, so skip building the autograd graph
#NOTE(review): the model keeps its current train/eval mode; call
#TrainTask.model.eval() first if inference-mode layers are wanted -- confirm
with torch.no_grad():
    stud_feat = TrainTask.model.backbone(img)
    stud_fpn_feat = TrainTask.model.fpn(stud_feat)
    head_out = TrainTask.model.head(stud_fpn_feat)
    dets = TrainTask.model.head.post_process(head_out, batch)

In [None]:
#Turn confident detections into pseudo-labels and add them to the batch ground truth.
#NOTE: `batch` is modified in place, so running this cell twice on the same
#batch appends the same detections twice.
SCORE_THRESHOLD = 0.3  #minimum detection score for a box to be kept

#NOTE(review): assumes `dets` yields images in the same order as the entries of
#batch["gt_bboxes"] / batch["gt_labels"] -- confirm against post_process
for sample_idx, img_dets in enumerate(dets.values()):
    boxes = []
    labels = []
    for label, bboxes in img_dets.items():
        for bbox in bboxes:
            #The last element of a detection is its score, the first four are
            #the x0, y0, x1, y1 coordinates (truncated to integers)
            if bbox[-1] > SCORE_THRESHOLD:
                boxes.append([int(coord) for coord in bbox[:4]])
                labels.append(label)

    #Stack along axis 0 so every box stays one row of four values; np.append
    #without `axis` would flatten both arrays into a single 1-D array
    old_boxes = np.asarray(batch["gt_bboxes"][sample_idx]).reshape(-1, 4)
    new_boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4)
    batch["gt_bboxes"][sample_idx] = np.concatenate([old_boxes, new_boxes], axis=0)

    #Build the new labels with the existing dtype: an empty list would
    #otherwise default to float64 and turn integer labels into floats
    old_labels = np.asarray(batch["gt_labels"][sample_idx]).reshape(-1)
    new_labels = np.asarray(labels, dtype=old_labels.dtype)
    batch["gt_labels"][sample_idx] = np.concatenate([old_labels, new_labels])
print(len(batch["gt_bboxes"]))