In [1]:
import os

# These variables are exported BEFORE torch (and the native threading / CUDA
# runtimes it loads) is imported, so that those runtimes can see them when
# they initialize. Setting them after `import torch` risks being ignored.
os.environ['OMP_NUM_THREADS'] = '20'
os.environ['OPENBLAS_NUM_THREADS'] = '20'
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import time, psutil, hydra, torch
from hydra.utils import to_absolute_path
from omegaconf import DictConfig, OmegaConf
from torch.nn.parallel import DistributedDataParallel
from torch.utils.tensorboard import SummaryWriter
from modulus import Module
from modulus.models.diffusion import UNet, EDMPrecondSR
from modulus.distributed import DistributedManager
from modulus.metrics.diffusion import RegressionLoss, ResLoss
from modulus.launch.logging import PythonLogger, RankZeroLoggingWrapper
from modulus.launch.utils import load_checkpoint, save_checkpoint
from datasets.dataset import init_train_valid_datasets_from_config
from helpers.train_helpers import (
    set_patch_shape,
    set_seed,
    configure_cuda_for_consistent_precision,
    compute_num_accumulation_rounds,
    handle_and_clip_gradients,
    is_time_for_periodic_task,
)
### center-crop data to a smaller spatial size
def datacrop(data, size):
    """Return the centered ``size`` x ``size`` window of axes 2 and 3 of ``data``.

    The crop offset is computed independently for each of the two cropped
    axes, so inputs whose axes 2 and 3 differ in length are still cropped
    around their true center.

    Args:
        data: Array/tensor indexable as ``data[:, :, a:b, c:d]`` and exposing
            ``.shape``. Assumed to be laid out as (batch, channel, H, W) —
            TODO confirm against the dataset.
        size: Edge length of the square window to extract.

    Returns:
        The sliced view ``data[:, :, top:top + size, left:left + size]``.

    Raises:
        ValueError: If ``size`` exceeds the length of axis 2 or axis 3
            (a negative start index would otherwise silently return a
            wrongly-sized slice).
    """
    height, width = data.shape[2], data.shape[3]
    if size > height or size > width:
        raise ValueError(
            f"Crop size {size} exceeds input spatial shape ({height}, {width})"
        )
    top = height // 2 - size // 2
    left = width // 2 - size // 2
    return data[:, :, top:top + size, left:left + size]
###

In [2]:
from omegaconf import OmegaConf
import hydra
from hydra import initialize, compose

## Parse the configuration file
# Compose the Hydra config named 'config_training_diffusion' from the 'conf'
# config path and collect the names of its top-level sections.
with initialize(config_path='conf', version_base='1.2'):
    mainconf = compose(config_name='config_training_diffusion')
    confkey = [section for section in mainconf.keys()]

# Echo every top-level section so the effective settings appear in the cell output.
print('============CorrDiff Downscale Conf============')
print(f"Conf list: {confkey}")
for confn in confkey:
    section_cfg = mainconf[confn]
    print(f"\nConf Para --> {confn}")
    print(f"<<{section_cfg}>>")

# `cfg` is the name the rest of the notebook uses for the composed config.
cfg = mainconf
dataset_cfg = OmegaConf.to_container(cfg.dataset)


Conf list: ['dataset', 'model', 'training', 'validation']

Conf Para --> dataset
<<{'type': 'cwb', 'data_path': '/code/2023-01-24-cwb-4years.zarr', 'in_channels': [0, 1, 2, 3, 4, 9, 10, 11, 12, 17, 18, 19], 'out_channels': [0, 17, 18, 19], 'img_shape_x': 448, 'img_shape_y': 448, 'add_grid': True, 'ds_factor': 4, 'min_path': None, 'max_path': None, 'global_means_path': None, 'global_stds_path': None}>>

Conf Para --> model
<<{'name': 'diffusion', 'hr_mean_conditioning': False}>>

Conf Para --> training
<<{'hp': {'training_duration': 200000000, 'total_batch_size': 256, 'batch_size_per_gpu': 2, 'lr': 0.0002, 'grad_clip_threshold': None, 'lr_decay': 1, 'lr_rampup': 10000000}, 'perf': {'fp_optimizations': 'fp32', 'dataloader_workers': 4, 'songunet_checkpoint_level': 0}, 'io': {'regression_checkpoint_path': 'checkpoints/regression.mdlus', 'print_progress_freq': 1000, 'save_checkpoint_freq': 5000, 'validation_freq': 5000, 'validation_steps': 10}}>>

Conf Para --> validation
<<{'train': False,



In [3]:
### Notebook-local overrides applied on top of the composed config
# Dataset section: data location and selected output channels.
cfg.dataset.data_path = '/home/sprixin/test/zhangmy/cwa_dataset/cwa_dataset.zarr'
cfg.dataset.out_channels = [0, 1, 2, 3]
# Training hyper-parameters.
cfg.training.hp.training_duration = 8000000
cfg.training.hp.batch_size_per_gpu = 16
cfg.training.hp.lr = 0.00015
cfg.training.hp.lr_decay = 0.97
# Performance and I/O settings.
cfg.training.perf.dataloader_workers = 10
cfg.training.io.regression_checkpoint_path = '/home/sprixin/test/zhangmy/corrdiff/outputs/regression/checkpoints_regression/UNet.0.2000128.mdlus'
###
# Bring up the distributed environment; `dist` is queried below for the
# rank, world size and device.
DistributedManager.initialize()
dist = DistributedManager()

# Initialize loggers (the TensorBoard writer is created on rank 0 only)
if dist.rank == 0:
    writer = SummaryWriter(log_dir="tensorboard")
logger = PythonLogger("main")  # General python logger
logger0 = RankZeroLoggingWrapper(logger, dist)  # Rank 0 logger

# Resolve and parse configs
OmegaConf.resolve(cfg)
dataset_cfg = OmegaConf.to_container(cfg.dataset)  # TODO needs better handling
validation_dataset_cfg = (
    OmegaConf.to_container(cfg.validation_dataset)
    if hasattr(cfg, "validation_dataset")
    else None
)

# Derive the floating-point mode flags from the single config string.
fp_optimizations = cfg.training.perf.fp_optimizations
enable_amp = fp_optimizations.startswith("amp")
fp16 = fp_optimizations == "fp16"
# float16 only for "amp-fp16"; every other value maps to bfloat16.
amp_dtype = torch.bfloat16 if fp_optimizations != "amp-fp16" else torch.float16
logger.info(f"Saving the outputs in {os.getcwd()}")

# Seed with the process rank, then configure CUDA/cuDNN for consistent precision
set_seed(dist.rank)
configure_cuda_for_consistent_precision()

[17:59:09 - main - INFO] [94mSaving the outputs in /home/sprixin/test/zhangmy/corrdiff[0m


In [4]:

# Instantiate the dataset
# Options handed to the dataset factory for building its data loaders.
data_loader_kwargs = dict(
    pin_memory=True,
    num_workers=cfg.training.perf.dataloader_workers,
    prefetch_factor=2,
)
# The factory's return value is unpacked as four items, in this order:
# training dataset, training iterator, validation dataset, validation iterator.
datasets_and_iterators = init_train_valid_datasets_from_config(
    dataset_cfg,
    data_loader_kwargs,
    batch_size=cfg.training.hp.batch_size_per_gpu,
    seed=0,
    validation_dataset_cfg=validation_dataset_cfg,
)
dataset, dataset_iterator, validation_dataset, validation_dataset_iterator = (
    datasets_and_iterators
)


In [5]:
# Parse image configuration & update model args
# Channel counts and image shape are all read from the training dataset.
dataset_channels = len(dataset.input_channels())
img_shape = dataset.image_shape()
img_out_channels = len(dataset.output_channels())
# The model input starts out with exactly the dataset's input channels.
img_in_channels = dataset_channels
# High-res mean conditioning is left disabled here:
# if cfg.model.hr_mean_conditioning:
#     img_in_channels += img_out_channels

In [6]:
# Parse the patch shape
# Patch dimensions are only read from the config for the patched model;
# otherwise both stay None and set_patch_shape decides the final values.
uses_patches = cfg.model.name == "patched_diffusion"
patch_shape_x = cfg.training.hp.patch_shape_x if uses_patches else None
patch_shape_y = cfg.training.hp.patch_shape_y if uses_patches else None
patch_shape = (patch_shape_y, patch_shape_x)
img_shape, patch_shape = set_patch_shape(img_shape, patch_shape)
logger0.info(
    "Patch-based training disabled"
    if patch_shape == img_shape
    else "Patch-based training enabled"
)
# interpolate global channel if patch-based model is used
if img_shape[1] != patch_shape[1]:
    img_in_channels += dataset_channels

if cfg.model.name not in ("regression", "diffusion", "patched_diffusion"):
    raise ValueError("Invalid model")

# default parameters for all networks
model_args = dict(
    img_out_channels=img_out_channels,
    # "img_resolution": list(img_shape),
    img_resolution=img_shape[0],
    use_fp16=fp16,
)
# default parameters for different network types, keyed by cfg.model.name
standard_model_cfgs = dict(
    regression=dict(
        img_channels=4,
        N_grid_channels=4,
        embedding_type="zero",
    ),
    diffusion=dict(
        img_channels=img_out_channels,
        gridtype="sinusoidal",
        N_grid_channels=4,
    ),
    patched_diffusion=dict(
        img_channels=img_out_channels,
        gridtype="learnable",
        N_grid_channels=100,
    ),
)

[17:59:13 - main - INFO] [94mPatch-based training disabled[0m


In [7]:
# Layer the per-architecture defaults, then any config-file overrides,
# on top of the shared model arguments.
model_args.update(standard_model_cfgs[cfg.model.name])
if hasattr(cfg.model, "model_args"):  # override defaults from config file
    model_args.update(OmegaConf.to_container(cfg.model.model_args))

# Work on a copy and pull the grid/embedding settings out of it, so the
# remaining entries are the ones kept in `model_args_para`.
model_args_para = dict(model_args)
N_grid_channels = model_args_para.pop('N_grid_channels')
if cfg.model.name != "regression":
    gridtype = model_args_para.pop('gridtype')
else:
    embedding_type = model_args_para.pop('embedding_type')
###
# Crop edge length; it also replaces the dataset-derived image resolution.
cropsize = 64
model_args_para['img_resolution'] = cropsize
###

In [8]:
# Pick the network class: UNet for regression, EDMPrecondSR for the
# diffusion / patched diffusion variants.
model_cls = UNet if cfg.model.name == "regression" else EDMPrecondSR
model = model_cls(
    img_in_channels=img_in_channels,  # + model_args["N_grid_channels"],
    **model_args_para,
)
# Training mode, gradients enabled, moved to this process's device.
model.train()
model.requires_grad_(True)
model.to(dist.device)

# Enable distributed data parallel if applicable
if dist.world_size > 1:
    ddp_options = dict(
        device_ids=[dist.local_rank],
        broadcast_buffers=True,
        output_device=dist.device,
        find_unused_parameters=dist.find_unused_parameters,
    )
    model = DistributedDataParallel(model, **ddp_options)

# Load the regression checkpoint if applicable
# The frozen regression network is later handed to the diffusion loss, so a
# missing checkpoint setting is reported here instead of surfacing as a
# NameError on `regression_net` when the loss is built.
if hasattr(cfg.training.io, "regression_checkpoint_path"):
    regression_checkpoint_path = to_absolute_path(
        cfg.training.io.regression_checkpoint_path
    )
    if not os.path.exists(regression_checkpoint_path):
        raise FileNotFoundError(
            f"Expected a regression checkpoint but none was found at: {regression_checkpoint_path}"
        )
    regression_net = Module.from_checkpoint(regression_checkpoint_path)
    # Inference mode, no gradients: this network is not trained here.
    regression_net.eval().requires_grad_(False).to(dist.device)
    logger0.success("Loaded the pre-trained regression model")
elif cfg.model.name in ("diffusion", "patched_diffusion"):
    raise ValueError(
        "training.io.regression_checkpoint_path must be set when model.name is "
        f"'{cfg.model.name}': the diffusion loss needs the pre-trained regression network"
    )


###
# Both the image shape and the patch shape are overridden with the square
# crop size (tuples are immutable, so sharing one object is safe).
img_shape = patch_shape = (cropsize, cropsize)
###

[17:59:26 - main - INFO] [92mLoaded the pre-trained regression model[0m


In [9]:
# Instantiate the loss function
# `patch_num` falls back to 1 when the config does not define it.
patch_num = getattr(cfg.training.hp, "patch_num", 1)
if cfg.model.name == "regression":
    loss_fn = RegressionLoss()
elif cfg.model.name in ("diffusion", "patched_diffusion"):
    # The residual loss receives the frozen regression network together
    # with the current image and patch shapes.
    res_loss_kwargs = dict(
        regression_net=regression_net,
        img_shape_x=img_shape[1],
        img_shape_y=img_shape[0],
        patch_shape_x=patch_shape[1],
        patch_shape_y=patch_shape[0],
        patch_num=patch_num,
        # hr_mean_conditioning=cfg.model.hr_mean_conditioning,
    )
    loss_fn = ResLoss(**res_loss_kwargs)

# Instantiate the optimizer
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=cfg.training.hp.lr,
    betas=[0.9, 0.999],
    eps=1e-8,
)

# Record the current time to measure the duration of subsequent operations.
start_time = time.time()

# Compute the number of required gradient accumulation rounds
# It is automatically used if batch_size_per_gpu * dist.world_size < total_batch_size
batch_size_per_gpu = cfg.training.hp.batch_size_per_gpu
batch_gpu_total, num_accumulation_rounds = compute_num_accumulation_rounds(
    cfg.training.hp.total_batch_size,
    batch_size_per_gpu,
    dist.world_size,
)
logger0.info(f"Using {num_accumulation_rounds} gradient accumulation rounds")


[17:59:28 - main - INFO] [94mUsing 16 gradient accumulation rounds[0m


In [10]:
## Resume training from previous checkpoints if exists
# All ranks synchronize first so they attempt the load together.
if dist.world_size > 1:
    torch.distributed.barrier()
try:
    cur_nimg = load_checkpoint(
        path=f"checkpoints_{cfg.model.name}",
        models=model,
        optimizer=optimizer,
        device=dist.device,
    )
except Exception as exc:
    # Still best-effort (fall back to a fresh start), but only for ordinary
    # errors — KeyboardInterrupt/SystemExit now propagate — and the reason is
    # logged on every rank instead of being silently discarded.
    logger.warning(
        f"Could not resume from checkpoints_{cfg.model.name} ({exc!r}); starting from 0 images"
    )
    cur_nimg = 0



In [11]:
# Announce the configured training length (in images) on rank 0.
total_images = cfg.training.hp.training_duration
logger0.info(f"Training for {total_images} images...")
# Completion flag, initially False (presumably flipped by the training loop — not visible here).
done = False

[17:59:32 - main - INFO] [94mTraining for 8000000 images...[0m


In [12]:
# Remember where (image count) and when this tick started.
tick_start_nimg = cur_nimg
tick_start_time = time.time()
# Compute & accumulate gradients
optimizer.zero_grad(set_to_none=True)
loss_accum = 0
for _ in range(num_accumulation_rounds):
    img_clean, img_lr, labels = next(dataset_iterator)
    ###
    # Center-crop both images to the crop size before moving them to the device.
    img_clean = datacrop(img_clean, size=cropsize)
    img_lr = datacrop(img_lr, size=cropsize)
    ###
    img_clean = img_clean.to(dist.device).to(torch.float32).contiguous()
    img_lr = img_lr.to(dist.device).to(torch.float32).contiguous()
    labels = labels.to(dist.device).contiguous()
    # Use the AMP dtype derived from cfg.training.perf.fp_optimizations.
    # The previous hard-coded torch.float32 is not a supported CUDA autocast
    # dtype, so an "amp-*" setting was effectively ignored.
    with torch.autocast("cuda", dtype=amp_dtype, enabled=enable_amp):
        loss = loss_fn(
            net=model,
            img_clean=img_clean,
            img_lr=img_lr,
            labels=labels,
            augment_pipe=None,
        )
    # Sum the loss and normalize by the per-GPU batch size.
    loss = loss.sum() / batch_size_per_gpu
    # Running mean over accumulation rounds (for reporting); gradients
    # themselves accumulate through repeated backward() calls.
    loss_accum += loss / num_accumulation_rounds
    loss.backward()