In [1]:
from pathlib import Path
from typing import List
import sys
sys.path.append('.')
import hydra
import numpy as np
import torch
import omegaconf
import pytorch_lightning as pl
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import seed_everything, Callback
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from pytorch_lightning.loggers import WandbLogger

from diffcsp.common.utils import log_hyperparameters, PROJECT_ROOT

import wandb

In [3]:
def build_callbacks(cfg: DictConfig) -> List[Callback]:
    callbacks: List[Callback] = []

    if "lr_monitor" in cfg.logging:
        hydra.utils.log.info("Adding callback <LearningRateMonitor>")
        callbacks.append(
            LearningRateMonitor(
                logging_interval=cfg.logging.lr_monitor.logging_interval,
                log_momentum=cfg.logging.lr_monitor.log_momentum,
            )
        )

    if "early_stopping" in cfg.train:
        hydra.utils.log.info("Adding callback <EarlyStopping>")
        callbacks.append(
            EarlyStopping(
                monitor=cfg.train.monitor_metric,
                mode=cfg.train.monitor_metric_mode,
                patience=cfg.train.early_stopping.patience,
                verbose=cfg.train.early_stopping.verbose,
            )
        )

    if "model_checkpoints" in cfg.train:
        hydra.utils.log.info("Adding callback <ModelCheckpoint>")
        callbacks.append(
            ModelCheckpoint(
                dirpath=Path(HydraConfig.get().run.dir),
                monitor=cfg.train.monitor_metric,
                mode=cfg.train.monitor_metric_mode,
                save_top_k=cfg.train.model_checkpoints.save_top_k,
                verbose=cfg.train.model_checkpoints.verbose,
                save_last=cfg.train.model_checkpoints.save_last,
            )
        )

    return callbacks

In [2]:
PROJECT_ROOT

PosixPath('/blue/hennig/pawanprakash/diffusion/MaterialsDiffusion')

In [5]:
cfg = None  # Declare cfg globally


@hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")
def main(config: DictConfig):
    global cfg
    cfg = config
    run(cfg)
    return cfg

def run(cfg: DictConfig):
    # Use cfg as needed
    return cfg

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")


In [6]:
type(cfg)

NoneType

In [3]:
from diffcsp.pl_data.dataset import CrystDataset

  from .autonotebook import tqdm as notebook_tqdm
The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")


In [16]:
dataset = CrystDataset('vall',path = '/blue/hennig/pawanprakash/diffusion/MaterialsDiffusion/data/mp_20/val.csv',
                       prop='band_gap',niggli=True,primitive=True,graph_method='crystalnn',preprocess_workers=1,
                       lattice_scale_method='scale_length',save_path='/blue/hennig/pawanprakash/diffusion/MaterialsDiffusion/data/mp_20/val_ori.pt',
                       tolerance=0.1,use_space_group=False,use_pos_index=False)

In [21]:
itd = iter(dataset)

In [22]:
dataset.prop

'band_gap'

In [23]:
next(itd)

Data(edge_index=[2, 32], y=[1, 1], frac_coords=[4, 3], atom_types=[4], lengths=[1, 3], angles=[1, 3], to_jimages=[32, 3], num_atoms=4, num_bonds=32, num_nodes=4)

In [25]:
next(next(itd))

TypeError: 'Data' object is not an iterator

In [19]:
from diffcsp.common.data_utils import get_scaler_from_data_list


In [20]:
dataset.scaler = get_scaler_from_data_list(dataset.cached_data,key='band_gap')

  X = torch.tensor(X, dtype=torch.float)


In [15]:
from torch_geometric.data import DataLoader
dataloader = DataLoader(dataset,batch_size = 10)



In [18]:
dataset.scaler

StandardScalerTorch(means: inf, stds: nan)

In [31]:
dataset.prop

'tcad'

In [24]:
for i, d in enumerate(dataloader):
    print(i,d)

0 DataBatch(edge_index=[2, 410], y=[10, 1], frac_coords=[46, 3], atom_types=[46], lengths=[10, 3], angles=[10, 3], to_jimages=[410, 3], num_atoms=[10], num_bonds=[10], num_nodes=46, batch=[46], ptr=[11])
1 DataBatch(edge_index=[2, 406], y=[10, 1], frac_coords=[42, 3], atom_types=[42], lengths=[10, 3], angles=[10, 3], to_jimages=[406, 3], num_atoms=[10], num_bonds=[10], num_nodes=42, batch=[42], ptr=[11])
2 DataBatch(edge_index=[2, 478], y=[10, 1], frac_coords=[48, 3], atom_types=[48], lengths=[10, 3], angles=[10, 3], to_jimages=[478, 3], num_atoms=[10], num_bonds=[10], num_nodes=48, batch=[48], ptr=[11])
3 DataBatch(edge_index=[2, 448], y=[10, 1], frac_coords=[45, 3], atom_types=[45], lengths=[10, 3], angles=[10, 3], to_jimages=[448, 3], num_atoms=[10], num_bonds=[10], num_nodes=45, batch=[45], ptr=[11])
4 DataBatch(edge_index=[2, 488], y=[10, 1], frac_coords=[48, 3], atom_types=[48], lengths=[10, 3], angles=[10, 3], to_jimages=[488, 3], num_atoms=[10], num_bonds=[10], num_nodes=48, ba

so dataloader has the data now

In [19]:
from diffcsp.pl_modules.diffusion_w_type import CSPDiffusion, SinusoidalTimeEmbeddings

In [27]:
from diffcsp.common.utils import log_hyperparameters

In [44]:
import diffcsp.pl_modules.cspnet

In [45]:
from diffcsp.pl_modules.cspnet import CSPNet

In [47]:
cspdecoder = CSPNet

In [48]:
import hydra

In [49]:
hydra.utils.instantiate

<function hydra._internal.instantiate._instantiate2.instantiate(config: Any, *args: Any, **kwargs: Any) -> Any>

In [50]:
import math, copy

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from typing import Any, Dict

import hydra
import omegaconf
import pytorch_lightning as pl
from torch_scatter import scatter
from torch_scatter.composite import scatter_softmax
from torch_geometric.utils import to_dense_adj, dense_to_sparse
from tqdm import tqdm

from diffcsp.common.utils import PROJECT_ROOT
from diffcsp.common.data_utils import (
    EPSILON, cart_to_frac_coords, mard, lengths_angles_to_volume, lattice_params_to_matrix_torch,
    frac_to_cart_coords, min_distance_sqr_pbc)

from diffcsp.pl_modules.diff_utils import d_log_p_wrapped_normal

MAX_ATOMIC_NUM=100


class BaseModule(pl.LightningModule):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__()
        # populate self.hparams with args and kwargs automagically!
        self.save_hyperparameters()
        if hasattr(self.hparams, "model"):
            self._hparams = self.hparams.model


    def configure_optimizers(self):
        opt = hydra.utils.instantiate(
            self.hparams.optim.optimizer, params=self.parameters(), _convert_="partial"
        )
        if not self.hparams.optim.use_lr_scheduler:
            return [opt]
        scheduler = hydra.utils.instantiate(
            self.hparams.optim.lr_scheduler, optimizer=opt
        )
        return {"optimizer": opt, "lr_scheduler": scheduler, "monitor": "val_loss"}


### Model definition

class SinusoidalTimeEmbeddings(nn.Module):
    """ Attention is all you need. """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings


class CSPDiffusion(BaseModule):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        
        self.decoder = hydra.utils.instantiate(cspdecoder, latent_dim = self.hparams.latent_dim + self.hparams.time_dim, pred_type = True, smooth = True)
        self.beta_scheduler = hydra.utils.instantiate(self.hparams.beta_scheduler)
        self.sigma_scheduler = hydra.utils.instantiate(self.hparams.sigma_scheduler)
        self.time_dim = self.hparams.time_dim
        self.time_embedding = SinusoidalTimeEmbeddings(self.time_dim)
        self.keep_lattice = self.hparams.cost_lattice < 1e-5
        self.keep_coords = self.hparams.cost_coord < 1e-5
        self.p_uncond = self.hparams.p_uncond
        self.guide_w = self.hparams.guide_w



    def forward(self, batch):

        batch_size = batch.num_graphs
        times = self.beta_scheduler.uniform_sample_t(batch_size, self.device)
        time_emb = self.time_embedding(times)

        alphas_cumprod = self.beta_scheduler.alphas_cumprod[times]
        beta = self.beta_scheduler.betas[times]

        c0 = torch.sqrt(alphas_cumprod)
        c1 = torch.sqrt(1. - alphas_cumprod)

        sigmas = self.sigma_scheduler.sigmas[times]
        sigmas_norm = self.sigma_scheduler.sigmas_norm[times]

        lattices = lattice_params_to_matrix_torch(batch.lengths, batch.angles)
        frac_coords = batch.frac_coords

        rand_l, rand_x = torch.randn_like(lattices), torch.randn_like(frac_coords)

        input_lattice = c0[:, None, None] * lattices + c1[:, None, None] * rand_l
        sigmas_per_atom = sigmas.repeat_interleave(batch.num_atoms)[:, None]
        sigmas_norm_per_atom = sigmas_norm.repeat_interleave(batch.num_atoms)[:, None]
        input_frac_coords = (frac_coords + sigmas_per_atom * rand_x) % 1.

        gt_atom_types_onehot = F.one_hot(batch.atom_types - 1, num_classes=MAX_ATOMIC_NUM).float()

        rand_t = torch.randn_like(gt_atom_types_onehot)

        atom_type_probs = (c0.repeat_interleave(batch.num_atoms)[:, None] * gt_atom_types_onehot + c1.repeat_interleave(batch.num_atoms)[:, None] * rand_t)

        if self.keep_coords:
            input_frac_coords = frac_coords

        if self.keep_lattice:
            input_lattice = lattices


        # Need to apply property here, but before need to bernoulli sample
        property_indicator = torch.bernoulli(torch.ones(batch_size)*(1.-self.p_uncond))
        property_indicator = property_indicator.to(self.device)
        property_train = torch.squeeze(batch.y)

        pred_l, pred_x, pred_t = self.decoder(time_emb, atom_type_probs, input_frac_coords, input_lattice, batch.num_atoms, batch.batch, property_train, property_indicator)

        tar_x = d_log_p_wrapped_normal(sigmas_per_atom * rand_x, sigmas_per_atom) / torch.sqrt(sigmas_norm_per_atom)

        loss_lattice = F.mse_loss(pred_l, rand_l)
        loss_coord = F.mse_loss(pred_x, tar_x)
        loss_type = F.mse_loss(pred_t, rand_t)


        loss = (
            self.hparams.cost_lattice * loss_lattice +
            self.hparams.cost_coord * loss_coord + 
            self.hparams.cost_type * loss_type)

        return {
            'loss' : loss,
            'loss_lattice' : loss_lattice,
            'loss_coord' : loss_coord,
            'loss_type' : loss_type
        }

    @torch.no_grad()
    def sample(self, batch, band_gap, diff_ratio = 1.0, step_lr = 1e-5):


        batch_size = batch.num_graphs

        l_T, x_T = torch.randn([batch_size, 3, 3]).to(self.device), torch.rand([batch.num_nodes, 3]).to(self.device)

        t_T = torch.randn([batch.num_nodes, MAX_ATOMIC_NUM]).to(self.device)


        if self.keep_coords:
            x_T = batch.frac_coords

        if self.keep_lattice:
            l_T = lattice_params_to_matrix_torch(batch.lengths, batch.angles)
        

        traj = {self.beta_scheduler.timesteps : {
            'num_atoms' : batch.num_atoms,
            'atom_types' : t_T,
            'frac_coords' : x_T % 1.,
            'lattices' : l_T
        }}

        for t in tqdm(range(self.beta_scheduler.timesteps, 0, -1)):

            times = torch.full((batch_size, ), t, device = self.device)

            time_emb = self.time_embedding(times)
            
            alphas = self.beta_scheduler.alphas[t]
            alphas_cumprod = self.beta_scheduler.alphas_cumprod[t]

            sigmas = self.beta_scheduler.sigmas[t]
            sigma_x = self.sigma_scheduler.sigmas[t]
            sigma_norm = self.sigma_scheduler.sigmas_norm[t]

            c0 = 1.0 / torch.sqrt(alphas)
            c1 = (1 - alphas) / torch.sqrt(1 - alphas_cumprod)

            x_t = traj[t]['frac_coords']
            l_t = traj[t]['lattices']
            t_t = traj[t]['atom_types']

            if self.keep_coords:
                x_t = x_T

            if self.keep_lattice:
                l_t = l_T

            # Corrector

            rand_l = torch.randn_like(l_T) if t > 1 else torch.zeros_like(l_T)
            rand_t = torch.randn_like(t_T) if t > 1 else torch.zeros_like(t_T)
            rand_x = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)

            step_size = step_lr * (sigma_x / self.sigma_scheduler.sigma_begin) ** 2
            std_x = torch.sqrt(2 * step_size)
            # with context
            pred_l1, pred_x1, pred_t1 = self.decoder(time_emb, t_t, x_t, l_t, batch.num_atoms, batch.batch, band_gap, torch.ones(batch_size).to(self.device))
            pred_x1 = pred_x1 * torch.sqrt(sigma_norm)
            # without context
            pred_l2, pred_x2, pred_t2 = self.decoder(time_emb, t_t, x_t, l_t, batch.num_atoms, batch.batch, band_gap, torch.zeros(batch_size).to(self.device))
            pred_x2 = pred_x2 * torch.sqrt(sigma_norm)
            
            # weighted score
            pred_x = (1+self.guide_w)*pred_x1 - self.guide_w*pred_x2
            x_t_minus_05 = x_t - step_size * pred_x + std_x * rand_x if not self.keep_coords else x_t

            l_t_minus_05 = l_t

            t_t_minus_05 = t_t


            # Predictor

            rand_l = torch.randn_like(l_T) if t > 1 else torch.zeros_like(l_T)
            rand_t = torch.randn_like(t_T) if t > 1 else torch.zeros_like(t_T)
            rand_x = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)

            adjacent_sigma_x = self.sigma_scheduler.sigmas[t-1] 
            step_size = (sigma_x ** 2 - adjacent_sigma_x ** 2)
            std_x = torch.sqrt((adjacent_sigma_x ** 2 * (sigma_x ** 2 - adjacent_sigma_x ** 2)) / (sigma_x ** 2))   

            # with context
            pred_l1, pred_x1, pred_t1 = self.decoder(time_emb, t_t, x_t, l_t, batch.num_atoms, batch.batch, band_gap, torch.ones(batch_size).to(self.device))
            pred_x1 = pred_x1 * torch.sqrt(sigma_norm)
            # without context
            pred_l2, pred_x2, pred_t2 = self.decoder(time_emb, t_t, x_t, l_t, batch.num_atoms, batch.batch, band_gap, torch.zeros(batch_size).to(self.device))
            pred_x2 = pred_x2 * torch.sqrt(sigma_norm)

            ## weighted score
            pred_x = (1+self.guide_w)*pred_x1 - self.guide_w*pred_x2
            pred_l = (1+self.guide_w)*pred_l1 - self.guide_w*pred_l2
            pred_t = (1+self.guide_w)*pred_t1 - self.guide_w*pred_t2

            x_t_minus_1 = x_t_minus_05 - step_size * pred_x + std_x * rand_x if not self.keep_coords else x_t

            l_t_minus_1 = c0 * (l_t_minus_05 - c1 * pred_l) + sigmas * rand_l if not self.keep_lattice else l_t

            t_t_minus_1 = c0 * (t_t_minus_05 - c1 * pred_t) + sigmas * rand_t

            traj[t - 1] = {
                'num_atoms' : batch.num_atoms,
                'atom_types' : t_t_minus_1,
                'frac_coords' : x_t_minus_1 % 1.,
                'lattices' : l_t_minus_1              
            }

        traj_stack = {
            'num_atoms' : batch.num_atoms,
            'atom_types' : torch.stack([traj[i]['atom_types'] for i in range(self.beta_scheduler.timesteps, -1, -1)]).argmax(dim=-1) + 1,
            'all_frac_coords' : torch.stack([traj[i]['frac_coords'] for i in range(self.beta_scheduler.timesteps, -1, -1)]),
            'all_lattices' : torch.stack([traj[i]['lattices'] for i in range(self.beta_scheduler.timesteps, -1, -1)])
        }

        return traj[0], traj_stack



    def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor:

        output_dict = self(batch)

        loss_lattice = output_dict['loss_lattice']
        loss_coord = output_dict['loss_coord']
        loss_type = output_dict['loss_type']
        loss = output_dict['loss']


        self.log_dict(
            {'train_loss': loss,
            'lattice_loss': loss_lattice,
            'coord_loss': loss_coord,
            'type_loss': loss_type},
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )

        if loss.isnan():
            return None

        return loss

    def validation_step(self, batch: Any, batch_idx: int) -> torch.Tensor:

        output_dict = self(batch)

        log_dict, loss = self.compute_stats(output_dict, prefix='val')

        self.log_dict(
            log_dict,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )
        return loss

    def test_step(self, batch: Any, batch_idx: int) -> torch.Tensor:

        output_dict = self(batch)

        log_dict, loss = self.compute_stats(output_dict, prefix='test')

        self.log_dict(
            log_dict,
        )
        return loss

    def compute_stats(self, output_dict, prefix):

        loss_lattice = output_dict['loss_lattice']
        loss_coord = output_dict['loss_coord']
        loss_type = output_dict['loss_type']
        loss = output_dict['loss']

        log_dict = {
            f'{prefix}_loss': loss,
            f'{prefix}_lattice_loss': loss_lattice,
            f'{prefix}_coord_loss': loss_coord,
            f'{prefix}_type_loss': loss_type,
        }

        return log_dict, loss

    

In [51]:
class BaseModule(pl.LightningModule):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__()
        # populate self.hparams with args and kwargs automagically!
        self.save_hyperparameters()
        if hasattr(self.hparams, "model"):
            self._hparams = self.hparams.model


    def configure_optimizers(self):
        opt = hydra.utils.instantiate(
            self.hparams.optim.optimizer, params=self.parameters(), _convert_="partial"
        )
        if not self.hparams.optim.use_lr_scheduler:
            return [opt]
        scheduler = hydra.utils.instantiate(
            self.hparams.optim.lr_scheduler, optimizer=opt
        )
        return {"optimizer": opt, "lr_scheduler": scheduler, "monitor": "val_loss"}

In [52]:
base = BaseModule()

In [53]:
base.hparams



In [57]:
# Define the necessary hyperparameters manually
hparams = {
    'latent_dim': 128,  # Example value
    'time_dim': 64,  # Example value
    'cost_lattice': 1e-6,
    'cost_coord': 1e-6,
    'cost_type': 1e-6,
    'p_uncond': 0.1,
    'guide_w': 0.5,
    # Mocked classes or replace them with actual instantiated objects
    'decoder': CSPNet,  # Adjust as needed
    'beta_scheduler': YourBetaScheduler(),  # Replace with your actual scheduler
    'sigma_scheduler': YourSigmaScheduler(),  # Replace with your actual scheduler
    'optim': {
        'optimizer': torch.optim.Adam,  # Use actual optimizer class
        'use_lr_scheduler': False,  # Example: No learning rate scheduler
        'lr_scheduler': None  # Optional if `use_lr_scheduler` is False
    }
}

NameError: name 'YourBetaScheduler' is not defined

In [56]:
cspdif = CSPDiffusion()

AttributeError: 'AttributeDict' object has no attribute 'latent_dim'

In [22]:
def run(cfg: DictConfig) -> None:

    if cfg.train.deterministic:
        seed_everything(cfg.train.random_seed)

    if cfg.train.pl_trainer.fast_dev_run:
        hydra.utils.log.info(
            f"Debug mode <{cfg.train.pl_trainer.fast_dev_run=}>. "
            f"Forcing debugger friendly configuration!"
        )
        # Debuggers don't like GPUs nor multiprocessing
        cfg.train.pl_trainer.gpus = 0
        cfg.data.datamodule.num_workers.train = 0
        cfg.data.datamodule.num_workers.val = 0
        cfg.data.datamodule.num_workers.test = 0

        # Switch wandb mode to offline to prevent online logging
        cfg.logging.wandb.mode = "offline"
    print('PAST 1')
    # Hydra run directory
    hydra_dir = Path(HydraConfig.get().run.dir)

    # Instantiate datamodule
    hydra.utils.log.info(f"Instantiating <{cfg.data.datamodule._target_}>")
    datamodule: pl.LightningDataModule = hydra.utils.instantiate(
        cfg.data.datamodule, _recursive_=False
    )
    print('PAST 2')
    # Instantiate model
    hydra.utils.log.info(f"Instantiating <{cfg.model._target_}>")
    model: pl.LightningModule = hydra.utils.instantiate(
        cfg.model,
        optim=cfg.optim,
        data=cfg.data,
        logging=cfg.logging,
        _recursive_=False,
    )

    # Pass scaler from datamodule to model
    hydra.utils.log.info(f"Passing scaler from datamodule to model <{datamodule.scaler}>")
    if datamodule.scaler is not None:
        model.lattice_scaler = datamodule.lattice_scaler.copy()
        model.scaler = datamodule.scaler.copy()
    torch.save(datamodule.lattice_scaler, hydra_dir / 'lattice_scaler.pt')
    torch.save(datamodule.scaler, hydra_dir / 'prop_scaler.pt')
    # Instantiate the callbacks
    callbacks: List[Callback] = build_callbacks(cfg=cfg)
    print('PAST 3')
    # Logger instantiation/configuration
    wandb_logger = None
    if "wandb" in cfg.logging:
        hydra.utils.log.info("Instantiating <WandbLogger>")
        wandb_config = cfg.logging.wandb
        wandb_logger = WandbLogger(
            **wandb_config,
            settings=wandb.Settings(start_method="fork"),
            tags=cfg.core.tags,
        )
        hydra.utils.log.info("W&B is now watching <{cfg.logging.wandb_watch.log}>!")
        wandb_logger.watch(
            model,
            log=cfg.logging.wandb_watch.log,
            log_freq=cfg.logging.wandb_watch.log_freq,
        )

    # Store the YaML config separately into the wandb dir
    yaml_conf: str = OmegaConf.to_yaml(cfg=cfg)
    (hydra_dir / "hparams.yaml").write_text(yaml_conf)

    # Load checkpoint (if exist)
    ckpts = list(hydra_dir.glob('*.ckpt'))
    if len(ckpts) > 0:
        ckpt_epochs = np.array([int(ckpt.parts[-1].split('-')[0].split('=')[1]) for ckpt in ckpts])
        ckpt = str(ckpts[ckpt_epochs.argsort()[-1]])
        hydra.utils.log.info(f"found checkpoint: {ckpt}")
    else:
        ckpt = None
          
    hydra.utils.log.info("Instantiating the Trainer")
    print('PAST 4')
    trainer = pl.Trainer(
        default_root_dir=hydra_dir,
        logger=wandb_logger,
        callbacks=callbacks,
        deterministic=cfg.train.deterministic,
        check_val_every_n_epoch=cfg.logging.val_check_interval,
        # progress_bar_refresh_rate=cfg.logging.progress_bar_refresh_rate,
        # resume_from_checkpoint=ckpt,
        **cfg.train.pl_trainer,
    )

    log_hyperparameters(trainer=trainer, model=model, cfg=cfg)

    hydra.utils.log.info("Starting training!")
    print('PAST 5')
    trainer.fit(model=model, datamodule=datamodule)

    hydra.utils.log.info("Starting testing!")
    trainer.test(datamodule=datamodule)

    # Logger closing to release resources/avoid multi-run conflicts
    if wandb_logger is not None:
        wandb_logger.experiment.finish()

In [23]:
@hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")
def main(cfg: omegaconf.DictConfig):
    run(cfg)

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")
