In [1]:
from sklearn.model_selection import KFold, train_test_split
import time
import orjson
import sys
import shutil
import os.path
import os
import random
import numpy as np
from torch_geometric.loader import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn as nn
from random import shuffle
from pymatgen.core import Structure
sys.path.append("../../")
from utils.save_and_load import load_from_json, save_to_json
from utils.experiment_tracking import log_mean_std_based_on_test_metrics

In [2]:
!pip install -e ./
working_dir = os.getcwd()  # save current path

Obtaining file:///home/danya/mof/shg_ml/gnn_cmp/SevenNet
  Installing build dependencies ... [?25ldone
[?25h  Checking if build backend supports build_editable ... [?25ldone
[?25h  Getting requirements to build editable ... [?25ldone
[?25h  Preparing editable metadata (pyproject.toml) ... [?25ldone
Building wheels for collected packages: sevenn
  Building editable for sevenn (pyproject.toml) ... [?25ldone
[?25h  Created wheel for sevenn: filename=sevenn-0.10.3-0.editable-py3-none-any.whl size=34166 sha256=0c32cf20b790efc013544821956680f5eda07e1266ceeaa5fabfc90cb654d3c5
  Stored in directory: /tmp/pip-ephem-wheel-cache-trw1urbm/wheels/cd/ea/f2/ea047805855e8039a6d78b61fce4768457fb7dd50dfca0804f
Successfully built sevenn
Installing collected packages: sevenn
  Attempting uninstall: sevenn
    Found existing installation: sevenn 0.10.3
    Uninstalling sevenn-0.10.3:
      Successfully uninstalled sevenn-0.10.3
Successfully installed sevenn-0.10.3


In [3]:
import sevenn
# Display the version string of the imported `sevenn` package as the cell output.
sevenn.__version__

'0.10.3'

In [4]:
import sevenn.util as util
from sevenn.nn.scale import SpeciesWiseRescale
import os
from sevenn.train.graph_dataset import SevenNetGraphDataset

import torch
from ase.build import bulk, molecule

from sevenn.atom_graph_data import AtomGraphData
from sevenn.train.dataload import unlabeled_atoms_to_graph
from sevenn.util import model_from_checkpoint, pretrained_name_to_path
from sevenn.train.trainer import Trainer
import torch.optim.lr_scheduler as scheduler
from copy import deepcopy
from sevenn.error_recorder import ErrorRecorder
from sevenn.train.loss import ScalarLoss
from sevenn.sevenn_logger import Logger
from aim import Run

In [5]:
# from aim.pytorch import get_model_layers 
from aim import Distribution

# Move tensor from GPU to CPU
def get_pt_tensor(t):
    """Return ``t`` on the CPU if it reports living on a CUDA device.

    Anything that has no ``is_cuda`` attribute, or whose ``is_cuda`` is
    falsy, is handed back unchanged (the very same object).
    """
    lives_on_gpu = getattr(t, 'is_cuda', False)
    if not lives_on_gpu:
        return t
    return t.cpu()

def get_model_layers(model, dt, parent_name=None):
    """Gather weight/bias arrays of every leaf layer of ``model``.

    The module tree is walked recursively. Each child gets the key
    ``<name>.<ClassName>``; when ``parent_name`` is given it is prepended as
    ``<parent_name>__<name>.<ClassName>``. Children that themselves have
    children are descended into (using their key as the new parent name),
    all other children become entries of the result.

    Args:
        model: module exposing ``named_children()``.
        dt: name of the attribute read from each parameter
            (the notebook's commented-out usage passes ``'grad'``).
        parent_name: key prefix of the enclosing module, ``None`` at the root.

    Returns:
        dict mapping leaf-layer key -> dict that may contain ``'weight'``
        and/or ``'bias'`` NumPy arrays. A leaf without those parameters, or
        whose ``dt`` attribute is ``None``, maps to an empty dict.
    """
    collected = {}
    for child_name, child in model.named_children():
        prefix = f'{parent_name}__{child_name}' if parent_name else child_name
        qualified = f'{prefix}.{type(child).__name__}'

        # Containers are not recorded themselves, only their descendants.
        if any(True for _ in child.named_children()):
            collected.update(get_model_layers(child, dt, qualified))
            continue

        entry = collected[qualified] = {}
        for attr in ('weight', 'bias'):
            param = getattr(child, attr, None)
            if param is None:
                continue
            values = getattr(param, dt, None)
            if values is not None:
                entry[attr] = get_pt_tensor(values).numpy()

    return collected

In [6]:
def finetune(
    from_model_path: str = '7net-0',
    dataset_name: str = 'dataset_of_max_abs_shg',
    random_seed: int = 42,
):
    """Fine-tune a pretrained SevenNet model on an SHG dataset with 10-fold CV.

    An Aim ``Run`` is opened for the experiment, the cross-validation is
    executed, and the run is closed afterwards, also when training raises.

    Args:
        from_model_path: name of the pretrained SevenNet model, resolved with
            ``util.pretrained_name_to_path``.
        dataset_name: stem of the JSON file under ``../../data/final_data/``.
        random_seed: seed for ``random``, NumPy, torch, the K-fold split and
            the train/validation split.

    Side effects:
        * writes ``checkpoint_fine_tuned.pth`` into the module-level
          ``working_dir``,
        * writes ``train_info.json`` into the current directory and logs it
          as an artifact of the run,
        * sets the ``CUBLAS_WORKSPACE_CONFIG`` environment variable.
    """
    run = Run(
        experiment=f'SevenNet FT from {from_model_path} on {dataset_name}',
        log_system_params=True,
    )
    try:
        _finetune_cross_validation(run, from_model_path, dataset_name, random_seed)
    finally:
        # Close the run even if a fold fails so it is not left open.
        run.close()


def _finetune_cross_validation(run, from_model_path, dataset_name, random_seed):
    """Run the 10-fold fine-tuning and log everything to the given Aim ``run``."""
    dataset_shg = load_from_json(f'../../data/final_data/{dataset_name}.json')

    checkpoint_path = util.pretrained_name_to_path(from_model_path)
    # The model is reloaded per fold below; here only the cutoff is needed.
    _, pretrained_config = util.model_from_checkpoint(checkpoint_path)
    cutoff = pretrained_config['cutoff']  # 7net-0 uses 5.0 Angstrom cutoff

    # The SHG value is stored in the `total_energy` field of every graph.
    dataset = [
        AtomGraphData.from_numpy_dict(
            dict(
                **unlabeled_atoms_to_graph(
                    atoms=Structure.from_dict(entry['structure']).to_ase_atoms(),
                    cutoff=cutoff,
                ),
                total_energy=entry['shg'],
                crystal_idx=idx,
            )
        )
        for idx, entry in enumerate(dataset_shg.values())
    ]

    dataset_keys = list(dataset_shg.keys())
    print(f'# graphs: {len(dataset)}')
    print(dataset[0])

    # Dataset description, computed once because it is identical for all folds.
    targets = [graph['total_energy'].data.cpu().item() for graph in dataset]
    dataset_key_target = dict(zip(dataset_keys, targets))
    # Fixed: this used to duplicate dataset_key_target (key -> target).
    dataset_key_index = {key: idx for idx, key in enumerate(dataset_keys)}
    dataset_index_key_target = {
        idx: dict(target=target, key=key)
        for idx, (key, target) in enumerate(zip(dataset_keys, targets))
    }

    # enable deterministic learning
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)
    random.seed(random_seed)
    np.random.seed(random_seed)
    # Plain-Python replacement of the `%env` magic, so the function also works
    # outside IPython. cuBLAS needs this setting for deterministic algorithms.
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

    k_folds = 10
    total_epoch = 40
    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=random_seed)
    data_range = np.arange(0, len(dataset))

    hparams = dict(rescale_atomic_energy=True, all_requires_grad=True, batch_size=1)

    fold_partition = {'fold': {}}
    fold_partition_names = {'fold': {}}
    crystal_predictions_log_for_all_folds = {'fold': {}}
    run.set_artifacts_uri(f'file://{os.getcwd()}/artifacts/')
    for fold, (train_val_idx, test_idx) in enumerate(kfold.split(data_range)):
        print(f'FOLD {fold}')

        # With 10 folds, 8/9 of the non-test part gives an 8:1:1 split.
        train_idx, val_idx = train_test_split(
            train_val_idx, train_size=8 / 9, random_state=random_seed
        )
        # Re-seeds torch's default generator; the samplers below share it.
        g = torch.manual_seed(random_seed)

        print(f'# graphs for training: {len(train_idx)}')
        print(f'# graphs for validation: {len(val_idx)}')

        # log fold partition (by index and by dataset key)
        fold_partition['fold'][fold] = dict(
            train_idx=train_idx.tolist(),
            val_idx=val_idx.tolist(),
            test_idx=test_idx.tolist(),
        )
        fold_partition_names['fold'][fold] = dict(
            train_names=[dataset_keys[x] for x in train_idx.tolist()],
            val_names=[dataset_keys[x] for x in val_idx.tolist()],
            test_names=[dataset_keys[x] for x in test_idx.tolist()],
        )

        train_loader = DataLoader(
            dataset,
            batch_size=hparams['batch_size'],
            sampler=SubsetRandomSampler(train_idx, g),
        )
        valid_loader = DataLoader(
            dataset,
            sampler=SubsetRandomSampler(val_idx, g),
        )
        test_loader = DataLoader(
            dataset,
            sampler=SubsetRandomSampler(test_idx, g),
        )

        # Start every fold from the pretrained weights.
        model, config = util.model_from_checkpoint(checkpoint_path)

        run['hparams'] = hparams
        # set all params to trainable; model.parameters() already recurses
        # into every sub-module, so a single loop is sufficient.
        if hparams['all_requires_grad']:
            for param in model.parameters():
                param.requires_grad = True

        # To enhance training speed, we will overwrite shift scale module to trainable
        # By making energy shift trainable, error quickly converges.
        shift_scale_module = model._modules['rescale_atomic_energy']
        shift = shift_scale_module.shift.tolist()
        scale = shift_scale_module.scale.tolist()
        model._modules['rescale_atomic_energy'] = SpeciesWiseRescale(
            shift=shift,
            scale=scale,
            train_shift_scale=hparams['rescale_atomic_energy'],
        )

        for name, module in model.named_children():
            print(
                f'Layer: {name}, Type: {type(module).__name__}, Requires Grad: {any(p.requires_grad for p in module.parameters())}'
            )
            for sub_name, sub_module in module.named_children():
                print(
                    f'  Sub-layer: {sub_name}, Type: {type(sub_module).__name__}, Requires Grad: {any(p.requires_grad for p in sub_module.parameters())}'
                )

        run['model_parameters_count'] = sum(p.numel() for p in model.parameters())
        run['model_trainable_parameters_count'] = sum(
            p.numel() for p in model.parameters() if p.requires_grad
        )
        print(model)

        config.update(
            {
                'optimizer': 'adam',
                'optim_param': {'lr': 0.01},
                'scheduler': 'linearlr',
                'scheduler_param': {
                    'start_factor': 1.0,
                    'total_iters': total_epoch,
                    'end_factor': 0.1,
                },
                'is_ddp': False,  # 7net-0 was trained with ddp=True.
            }
        )
        # Convert arrays and devices to plain types before storing on the run.
        run['config'] = {
            k: (
                v.tolist()
                if isinstance(v, np.ndarray)
                else str(v)
                if isinstance(v, torch.device)
                else v
            )
            for k, v in config.items()
        }
        trainer = Trainer.from_config(model, config)
        # NOTE(review): `run`, `fold`, `crystal_predictions_log` and the
        # `epoch`/`subset` arguments are presumably consumed by the locally
        # installed (editable) SevenNet Trainer - confirm against that code.
        trainer.run = run
        trainer.fold = fold
        # Replace the configured losses with a single Huber ScalarLoss, weight 1.0.
        loss = ScalarLoss()
        loss.assign_criteria(nn.HuberLoss())
        trainer.loss_functions = [(loss, 1.0)]

        print(trainer.loss_functions)
        print(trainer.optimizer)
        print(trainer.scheduler)

        # Keep only the metrics whose name contains "energy".
        train_recorder = ErrorRecorder.from_config(config)
        choosen_metrics = []
        for metric in train_recorder.metrics:
            print(metric)
            if 'energy' in str(metric.name).lower():
                print('energy in ', str(metric.name).lower())
                choosen_metrics.append(metric)
        print('Choosen metrics:')
        for metric in choosen_metrics:
            print(metric)
        train_recorder.metrics = choosen_metrics
        valid_recorder = deepcopy(train_recorder)
        test_recorder = deepcopy(train_recorder)

        logger = Logger()
        logger.screen = True
        with logger:
            t = time.time()
            for epoch in range(1, total_epoch + 1):
                logger.timer_start('epoch')
                logger.writeline(
                    f'Epoch {epoch}/{total_epoch}  Learning rate: {trainer.get_lr():.6f}'
                )
                trainer.run_one_epoch(
                    train_loader,
                    is_train=True,
                    error_recorder=train_recorder,
                    epoch=epoch,
                    subset='train',
                )
                trainer.run_one_epoch(
                    valid_loader,
                    is_train=False,
                    error_recorder=valid_recorder,
                    epoch=epoch,
                    subset='val',
                )
                trainer.scheduler_step()
                # epoch_forward returns the averaged error over one epoch, then resets.
                train_err = train_recorder.epoch_forward()
                valid_err = valid_recorder.epoch_forward()
                logger.bar()
                logger.write_full_table([train_err, valid_err], ['Train', 'Valid'])
                logger.timer_end('epoch', message=f'Epoch {epoch} elapsed')
                # Optional: per-layer gradient histograms could be tracked here
                # with get_model_layers(model, 'grad') and aim.Distribution.
            run.track(time.time() - t, name='train_time', context={'fold': fold})

            t = time.time()
            trainer.run_one_epoch(
                test_loader,
                is_train=False,
                error_recorder=test_recorder,
                subset='test',
            )
            run.track(time.time() - t, name='test_time', context={'fold': fold})
            test_err = test_recorder.epoch_forward()
            logger.bar()
            logger.write_full_table([test_err], ['Test'])

        crystal_predictions_log_for_all_folds['fold'][fold] = (
            trainer.crystal_predictions_log.copy()
        )

    # Written once after the loop: only the last fold's model is saved.
    trainer.write_checkpoint(
        os.path.join(working_dir, 'checkpoint_fine_tuned.pth'),
        config=config,
        epoch=total_epoch,
    )

    log_mean_std_based_on_test_metrics(run)
    jsons_to_log = dict(
        fold_partition=fold_partition,
        fold_partition_names=fold_partition_names,
        dataset_key_target=dataset_key_target,
        dataset_key_index=dataset_key_index,
        dataset_index_key_target=dataset_index_key_target,
        crystal_predictions_log_for_all_folds=crystal_predictions_log_for_all_folds,
    )
    json_name = 'train_info.json'
    save_to_json(jsons_to_log, json_name)
    run.log_artifact(json_name)

In [7]:
# finetune()

In [8]:
# finetune('7net-l3i5')

In [9]:
# Fine-tune the default pretrained model ('7net-0') on the 'base_dataset_of_eff_shg' dataset.
finetune(dataset_name='base_dataset_of_eff_shg')

# graphs: 974
AtomGraphData(
  x=[24],
  edge_index=[2, 756],
  pos=[24, 3],
  node_attr=[24],
  atomic_numbers=[24],
  edge_vec=[756, 3],
  cell_lattice_vectors=[3, 3],
  pbc_shift=[756, 3],
  cell_volume=404.682861328125,
  num_atoms=24,
  data_info={},
  total_energy=0.33507025241851807,
  crystal_idx=0
)
env: CUBLAS_WORKSPACE_CONFIG=:4096:8
FOLD 0
# graphs for training: 778
# graphs for validation: 98
Layer: edge_embedding, Type: EdgeEmbedding, Requires Grad: True
  Sub-layer: basis_function, Type: BesselBasis, Requires Grad: True
  Sub-layer: cutoff_function, Type: XPLORCutoff, Requires Grad: False
  Sub-layer: spherical, Type: SphericalEncoding, Requires Grad: False
Layer: onehot_idx_to_onehot, Type: OnehotEmbedding, Requires Grad: False
Layer: onehot_to_feature_x, Type: IrrepsLinear, Requires Grad: True
  Sub-layer: linear, Type: Linear, Requires Grad: True
Layer: 0_self_connection_intro, Type: SelfConnectionLinearIntro, Requires Grad: True
  Sub-layer: linear, Type: Linear, Req

In [10]:
# finetune('7net-l3i5', dataset_name='base_dataset_of_eff_shg')