## **Training Pipeline**
This notebook trains the diffusion model with the project's `ModelTrainer` (`reconstruction_deep_network.trainer.trainer`), configured via `trainer/trainer_config.yaml`.

In [None]:
# Show the GPUs visible on this machine (via nvidia-smi) before starting training.
!nvidia-smi


In [None]:
import os
import sys

# Make the project package importable from this notebook. Set the
# RECONSTRUCTION_REPO_ROOT environment variable when the repository lives
# elsewhere; the membership guard keeps re-runs of this cell from appending
# duplicate entries to sys.path.
REPO_ROOT = os.environ.get(
    "RECONSTRUCTION_REPO_ROOT",
    "/home/jupyter-group3/reconstruction/reconstruction-deep-network",
)
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

In [None]:
import numpy as np
import os
import torch
from torch.utils.data import Subset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from argparse import ArgumentParser
import yaml

import reconstruction_deep_network
from reconstruction_deep_network.data_loader.custom_loader import CustomDataLoader
from reconstruction_deep_network.trainer.trainer import ModelTrainer

# The allocator config only takes effect if set before the first CUDA
# allocation. setdefault keeps a value already exported in the environment
# instead of silently overwriting it.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:512")
# Trade float32 matmul precision for speed ('medium' lets PyTorch use
# lower-precision internal computation for float32 matmuls).
torch.set_float32_matmul_precision('medium')

In [None]:
# Keep cuDNN on and let it benchmark/auto-select convolution algorithms
# (typically a speed-up when input shapes do not change between batches).
torch.backends.cudnn.enabled = torch.backends.cudnn.benchmark = True

In [None]:
# Resolve project paths relative to the imported package location.
module_path = reconstruction_deep_network.__path__[0]
root_dir = os.path.dirname(module_path)
data_path = os.path.join(root_dir, "data", "v1")
text_embeddings = os.path.join(data_path, "text_embeddings")
null_embeddings = os.path.join(text_embeddings, "null")
# Ensure the null-embedding directory exists (presumably used by the data
# pipeline for cached null text embeddings — TODO confirm). exist_ok=True
# replaces the racy isdir()-then-makedirs() check.
os.makedirs(null_embeddings, exist_ok=True)
trainer_config_path = os.path.join(module_path, "trainer", "trainer_config.yaml")

In [None]:
def parse_args(args=None):
    """Build the training argument parser and parse ``args``.

    Parameters
    ----------
    args : list[str] | None
        Argument strings to parse. ``None`` falls back to argparse's default
        of ``sys.argv[1:]``; inside a Jupyter kernel that may contain the
        kernel's own flags, so notebooks should pass an explicit list.

    Returns
    -------
    argparse.Namespace
        All ``pl.Trainer`` flags plus the project-specific options below
        (including ``num_views``, which ``main`` reads).
    """
    parser = ArgumentParser()
    # Adds every pl.Trainer constructor argument (--max_epochs, --devices, ...).
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument("--main_config_path", type=str, dest="main_config_path")
    parser.add_argument("--train_metadata", type=str, dest="train_metadata")
    parser.add_argument("--val_metadata", type=str, dest="val_metadata")
    parser.add_argument("--num_views", type=int, dest="num_views")
    parser.add_argument("--num_workers", type=int, dest="num_workers")
    parser.add_argument("--exp_name", type=str, dest="exp_name")
    parser.add_argument("--batch_size", type=int, dest="batch_size")
    parser.add_argument("--n_epochs", type=int, dest="n_epochs")
    parser.add_argument("--learning_rate", type=float, dest="learning_rate")
    parser.add_argument("--ckpt_path", type=str, dest="ckpt_path")

    # Forward the caller's argument list instead of always reading sys.argv.
    return pl.Trainer.parse_argparser(parser.parse_args(args))

In [None]:
def _build_train_loader(args, batch_size):
    """Return a shuffled DataLoader over (a capped subset of) the train split."""
    train_dataset = CustomDataLoader(mode="train", debug=False,
                                     metadata_filename=args.train_metadata,
                                     num_views=args.num_views)
    # Cap the split for quick experiments (default 100, as before; set
    # args.train_subset_size = None for the full split). Clamp to the split
    # length so a small split does not yield out-of-range Subset indices.
    subset_size = getattr(args, "train_subset_size", 100)
    if subset_size is not None:
        n_samples = min(subset_size, len(train_dataset))
        train_dataset = Subset(train_dataset, list(range(n_samples)))
    print(f"Size of train dataset: {len(train_dataset)}")

    return torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=True)


def _load_weights(model, ckpt_path):
    """Load ``state_dict`` from a checkpoint file non-strictly and report key mismatches."""
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    result = model.load_state_dict(checkpoint['state_dict'], strict=False)
    # strict=False silently skips mismatched keys; surface them so loading the
    # wrong checkpoint does not go unnoticed.
    missing = getattr(result, "missing_keys", [])
    unexpected = getattr(result, "unexpected_keys", [])
    if missing:
        print(f"Missing keys when loading {ckpt_path}: {missing}")
    if unexpected:
        print(f"Unexpected keys when loading {ckpt_path}: {unexpected}")


def main(args):
    """Train ``ModelTrainer`` on the training split with PyTorch Lightning.

    Parameters
    ----------
    args : argparse.Namespace
        Lightning Trainer flags plus ``main_config_path``, ``train_metadata``,
        ``num_views``, ``num_workers``, ``batch_size``, ``n_epochs``,
        ``learning_rate`` and ``ckpt_path`` (the string ``"None"`` or ``None``
        means no checkpoint). An optional ``train_subset_size`` attribute caps
        the number of training samples (default 100; ``None`` = full split).
    """
    with open(args.main_config_path, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # NOTE(review): these overrides only change this in-memory dict. `config`
    # is not passed to ModelTrainer() below, so only batch_size (used for the
    # DataLoader) has a visible effect here. Confirm how ModelTrainer picks up
    # learning_rate / max_epochs.
    config["train"]["learning_rate"] = args.learning_rate
    config["train"]["max_epochs"] = args.n_epochs
    config["train"]["batch_size"] = args.batch_size

    # Validation (args.val_metadata) and TensorBoard logging (args.exp_name)
    # are currently disabled: only the train loader is passed to fit().
    train_loader = _build_train_loader(args, config["train"]["batch_size"])

    model_trainer = ModelTrainer()
    print(f"Training for {model_trainer.max_epochs} epochs...")
    print(f"Diffusion Training timesteps: {model_trainer.scheduler.num_train_timesteps}")
    # Lightning stops after args.max_epochs, which can differ from the value above.
    print(f"Lightning max_epochs: {getattr(args, 'max_epochs', None)}")

    ckpt_path = None if args.ckpt_path == "None" else args.ckpt_path
    if ckpt_path is not None:
        _load_weights(model_trainer, ckpt_path)

    # The filename template already spells out the metric names. With the
    # default auto_insert_metric_name=True, Lightning prepends them again and
    # produces names like "epoch=epoch=0-loss=train_loss=0.1843.ckpt".
    checkpoint_callback = ModelCheckpoint(
        save_top_k=1,
        monitor="train_loss",
        mode="min",
        save_last=True,
        filename='epoch={epoch}-loss={train_loss:.4f}',
        auto_insert_metric_name=False)

    training_pipeline = pl.Trainer.from_argparse_args(
        args,
        callbacks=[checkpoint_callback],
        # NOTE(review): the apex AMP backend requires NVIDIA apex to be
        # installed and is a legacy Lightning option; consider native AMP.
        amp_backend="apex",
        amp_level="O2")

    training_pipeline.fit(model_trainer, train_loader)

In [None]:
# Declare the project-specific options once, then parse an explicit argv list
# so the Jupyter kernel's own command-line flags are never handed to argparse.
cli_options = [
    ("--main_config_path", str),
    ("--train_metadata", str),
    ("--val_metadata", str),
    ("--num_views", int),
    ("--num_workers", int),
    ("--exp_name", str),
    ("--batch_size", int),
    ("--n_epochs", int),
    ("--learning_rate", float),
    ("--ckpt_path", str),
]

# Start from every pl.Trainer flag, then append the options above in order.
parser = pl.Trainer.add_argparse_args(ArgumentParser())
for option, option_type in cli_options:
    parser.add_argument(option, type=option_type, dest=option.lstrip("-"))

run_argv = [
    "--main_config_path", trainer_config_path,
    "--train_metadata", "ir-20231129-train-split",
    "--val_metadata", "ir-20231129-val-split",
    "--num_views", "1",
    "--num_workers", "12",
    "--exp_name", "ir-training-pipeline-test",
    "--batch_size", "1",
    "--n_epochs", "10",
    "--learning_rate", "0.0002",
    "--ckpt_path", "None",
]
args = pl.Trainer.parse_argparser(parser.parse_args(run_argv))

In [None]:
## set devices and epochs
# Lightning Trainer overrides applied on top of the parsed arguments.
args.accelerator = "gpu"
args.devices = 1
args.max_epochs = 30
# Keep main()'s config override (config["train"]["max_epochs"] = args.n_epochs)
# consistent with the epoch count Lightning will actually run.
args.n_epochs = args.max_epochs
# No validation loader is passed to fit(), so skip the sanity-check loop.
args.num_sanity_val_steps = 0
# To warm-start from saved weights, point args.ckpt_path at a .ckpt file, e.g.
# args.ckpt_path = "lightning_logs/version_N/checkpoints/<checkpoint>.ckpt"

In [None]:
# Sanity check before launching training: num_views is forwarded to
# CustomDataLoader, so fail fast if it was not parsed.
assert isinstance(args.num_views, int) and args.num_views >= 1, (
    f"num_views must be a positive int, got {args.num_views!r}")
args.num_views

In [None]:
# Launch training with the arguments configured in the cells above.
main(args=args)