In [1]:
import argparse
import logging
import os
import random
import sys
import time
import json

import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
from pytorch_lightning.plugins.environments import SLURMEnvironment
import torch
from flatten_dict import flatten, unflatten
import yaml
import ml_collections as mlc


In [2]:
from openfold.config import model_config
from openfold.data.data_modules import (
    OpenFoldDataModule,
    DummyDataLoader,
)
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants
from openfold.utils.argparse import remove_arguments
from openfold.utils.callbacks import (
    EarlyStoppingVerbose,
)
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss, lddt_ca
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
from openfold.utils.seed import seed_everything
from openfold.utils.superimposition import superimpose
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.validation_metrics import (
    drmsd,
    gdt_ts,
    gdt_ha,
)
from openfold.utils.import_weights import import_jax_weights_

from scripts.zero_to_fp32 import (
    get_fp32_state_dict_from_zero_checkpoint,
    get_global_step_from_zero_checkpoint,
)

from openfold.utils.logger import PerformanceLoggingCallback
from openfold.utils.config_check import enforce_config_constraints


In [3]:
from train_openfold import OpenFoldWrapper

In [4]:
    #  --wandb \
    #  --wandb_entity openfold \
    #  --wandb_project single_sequence_yiming  \
    #  --experiment_name no_msa_no_template \

In [5]:
args = dict(
    # train
    train_data_dir="/scratch/00946/zzhang/data/openfold/ls6-tacc/pdb_mmcif/mmcif_files",
    train_alignment_dir="/scratch/00946/zzhang/data/openfold/ls6-tacc/alignment_db",
    alignment_index_path="/scratch/00946/zzhang/data/openfold/ls6-tacc/gustaf/chain_lists/duplicated_super_fix.index",
    obsolete_pdbs_file_path="/scratch/00946/zzhang/data/openfold/ls6-tacc/pdb_mmcif/obsolete.dat",
    use_small_bfd=False,
    train_filter_path=None,
    train_chain_data_cache_path="/scratch/00946/zzhang/data/openfold/ls6-tacc/gustaf/prot_data_cache.json",
    # template
    template_mmcif_dir="/scratch/00946/zzhang/data/openfold/ls6-tacc/pdb_mmcif/mmcif_files",
    max_template_date="2021-10-01",
    template_release_dates_cache_path="/scratch/00946/zzhang/data/openfold/ls6-tacc/gustaf/mmcif_cache.json",
    kalign_binary_path="/usr/bin/kalign",
    # validation
    val_data_dir="/scratch/00946/zzhang/data/openfold/ls6-tacc/gustaf/val_set/data",
    val_alignment_dir="/scratch/00946/zzhang/data/openfold/ls6-tacc/gustaf/val_set/alignments",
    # model
    # {initial_training/fintuning/fintuning_no_template/model_1.1/model_1.2/model_1.1.1/model_1.1.2/model_1.2.1/model_1.2.2/model_1.2.3}
    config_stage="initial_training",
    # {ptm/None}
    config_ptm=False,
    # {train/inference_long_seq/None}
    config_mode="train",
    # {low_prec/None}
    config_lowprec=False,
    script_modules=False,
    # ditillation
    distillation_data_dir=None,
    distillation_alignment_dir=None,
    distillation_filter_path=None,
    distillation_alignment_index_path=None,
    _distillation_structure_index_path=None,
    distillation_chain_data_cache_path=None,
    # logging
    log_lr=True,
    checkpoint_every_epoch=True,
    output_dir="train_baseline/test",
    log_performance=False,
    wandb=True,
    wandb_entity="openfold",
    wandb_project="single_sequence_yiming",
    experiment_name="no_msa_no_template",
    wandb_id=None,
    # parallel
    gpus=3,
    num_nodes=1,
    replace_sampler_ddp=True,
    deepspeed_config_path="deepspeed_config.json",
    # trainer
    seed=42,
    train_epoch_len=126000,
    accumulate_grad_batches=3,
    num_sanity_val_steps=0,
    reload_dataloaders_every_n_epochs = 1,
    resume_from_ckpt=None,
    resume_model_weights_only=False,
    # early stopping
    early_stopping=False,
    min_delta=0,
    patience=3,
)


In [6]:
# with open("configs/baseline.yaml", "w") as f:
#     yaml.dump(args, f)

In [7]:
def enforce_arg_constrains(args):
    if args["seed"] is None and (
        (args["gpus"] is not None and args["gpus"] > 1)
        or (args["num_nodes"] is not None and args["num_nodes"] > 1)
    ):
        raise ValueError("For distributed training, --seed must be specified")

In [8]:
enforce_arg_constrains(args)

In [9]:
if args["seed"] is not None:
    seed_everything(args["seed"])

Global seed set to 42


In [10]:
config = {}
for preset in ["base"] + args["config_preset"].split("-"):
    if not preset:
        continue
    with open(f"configs/{preset}.json") as f:
        config = unflatten({**flatten(config), **flatten(json.load(f))})
# enforce_config_constraints(config)

In [11]:
config["data"]["data_module"]["data_loaders"] = {
    "batch_size": 1,
    "num_workers": 8,
    "pin_memory": True,
}
config["globals"]["chunk_size"] = None

In [12]:
# training mode (train_data_dir passed)
data_module = OpenFoldDataModule(
    config=mlc.ConfigDict(config["data"]),
    template_mmcif_dir=args["template_mmcif_dir"],
    max_template_date=args["max_template_date"],
    train_data_dir=args["train_data_dir"],
    train_alignment_dir=args["train_alignment_dir"],
    train_chain_data_cache_path=args["train_chain_data_cache_path"],
    distillation_data_dir=args["distillation_data_dir"],
    distillation_alignment_dir=args["distillation_alignment_dir"],
    distillation_chain_data_cache_path=args["distillation_chain_data_cache_path"],
    val_data_dir=args["val_data_dir"],
    val_alignment_dir=args["val_alignment_dir"],
    kalign_binary_path=args["kalign_binary_path"],
    train_filter_path=args["train_filter_path"],
    distillation_filter_path=args["distillation_filter_path"],
    obsolete_pdbs_file_path=args["obsolete_pdbs_file_path"],
    template_release_dates_cache_path=args["template_release_dates_cache_path"],
    batch_seed=args["seed"],
    train_epoch_len=args["train_epoch_len"],
    _distillation_structure_index_path=args["_distillation_structure_index_path"],
    alignment_index_path=args["alignment_index_path"],
    distillation_alignment_index_path=args["distillation_alignment_index_path"],
    predict_data_dir=None,
    predict_alignment_dir=None,
    # **vars(args)
)

data_module.prepare_data()
data_module.setup()

In [13]:
val_dl = data_module.val_dataloader()

In [14]:
sample = next(iter(val_dl))

In [15]:
type(sample)

dict

In [16]:
for k, v in sample.items():
    print(k, "\t\t\t", v.shape if hasattr(v, "shape") else type(v))

aatype 			 torch.Size([1, 209, 4])
residue_index 			 torch.Size([1, 209, 4])
seq_length 			 torch.Size([1, 4])
all_atom_positions 			 torch.Size([1, 209, 37, 3, 4])
all_atom_mask 			 torch.Size([1, 209, 37, 4])
resolution 			 torch.Size([1, 4])
is_distillation 			 torch.Size([1, 4])
template_aatype 			 torch.Size([1, 4, 209, 4])
template_all_atom_mask 			 torch.Size([1, 4, 209, 37, 4])
template_all_atom_positions 			 torch.Size([1, 4, 209, 37, 3, 4])
template_sum_probs 			 torch.Size([1, 4, 1, 4])
seq_mask 			 torch.Size([1, 209, 4])
msa_mask 			 torch.Size([1, 128, 209, 4])
msa_row_mask 			 torch.Size([1, 128, 4])
template_mask 			 torch.Size([1, 4, 4])
template_pseudo_beta 			 torch.Size([1, 4, 209, 3, 4])
template_pseudo_beta_mask 			 torch.Size([1, 4, 209, 4])
template_torsion_angles_sin_cos 			 torch.Size([1, 4, 209, 7, 2, 4])
template_alt_torsion_angles_sin_cos 			 torch.Size([1, 4, 209, 7, 2, 4])
template_torsion_angles_mask 			 torch.Size([1, 4, 209, 7, 4])
atom14_atom_exists 	

In [17]:
# if os.path.isdir(args["resume_from_ckpt"]):
#         sd = get_fp32_state_dict_from_zero_checkpoint(args["resume_from_ckpt"])
#     else:
#         sd = torch.load(args["resume_from_ckpt"])
#     sd = {k[len("module.") :]: v for k, v in sd.items()}
#     model_module.load_state_dict(sd)
#     logging.info("Successfully loaded model weights...")

# # TorchScript components of the model
# if args["script_modules"]:
#     script_preset_(model_module)

def get_model_basename(model_path):
    return os.path.splitext(os.path.basename(os.path.normpath(model_path)))[0]

path = "af2/params/params_model_2_ptm.npz"
model_basename = get_model_basename(path)
model_version = "_".join(model_basename.split("_")[1:])

In [18]:
model_module = OpenFoldWrapper(mlc.ConfigDict(config)).eval()

# from openfold_original.model.model import AlphaFold

# model_module.model = AlphaFold()

import_jax_weights_(
    model_module.model, path, version=model_version
)

In [19]:
sample["template_mask"].shape

torch.Size([1, 4, 4])

In [20]:
with torch.no_grad():
    batch = sample
    # At the start of validation, load the EMA weights
    if model_module.cached_weights is None:
        # model.state_dict() contains references to model weights rather
        # than copies. Therefore, we need to clone them before calling
        # load_state_dict().
        clone_param = lambda t: t.detach().clone()
        model_module.cached_weights = tensor_tree_map(clone_param, model_module.model.state_dict())
        model_module.model.load_state_dict(model_module.ema.state_dict()["params"])

    # Run the model
    outputs = model_module.model(batch)
    batch = tensor_tree_map(lambda t: t[..., -1], batch)

    # Compute loss and other metrics
    batch["use_clamped_fape"] = 0.0
    _, loss_breakdown = model_module.loss(outputs, batch, _return_breakdown=True)
    other_metrics = model_module._compute_validation_metrics(
        batch, outputs, superimposition_metrics=True
    )

torch.Size([1, 4, 209, 209, 64]) torch.Size([1, 209, 209, 128])


AttributeError: 'StructureModule' object has no attribute 'linear_q'

In [None]:
other_metrics, loss_breakdown

In [3]:
model_module = OpenFoldWrapper(config)
if args.resume_from_ckpt:
    if os.path.isdir(args.resume_from_ckpt):
        last_global_step = get_global_step_from_zero_checkpoint(
            args.resume_from_ckpt
        )
    else:
        sd = torch.load(args.resume_from_ckpt)
        last_global_step = int(sd["global_step"])
    model_module.resume_last_lr_step(last_global_step)
    logging.info("Successfully loaded last lr step...")
if args.resume_from_ckpt and args.resume_model_weights_only:
    if os.path.isdir(args.resume_from_ckpt):
        sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt)
    else:
        sd = torch.load(args.resume_from_ckpt)
    sd = {k[len("module.") :]: v for k, v in sd.items()}
    model_module.load_state_dict(sd)
    logging.info("Successfully loaded model weights...")

# TorchScript components of the model
if args.script_modules:
    script_preset_(model_module)

# data_module = DummyDataLoader("new_batch.pickle")
data_module = OpenFoldDataModule(
    config=config.data, batch_seed=args.seed, **vars(args)
)

data_module.prepare_data()
data_module.setup()

callbacks = []
if args.checkpoint_every_epoch:
    mc = ModelCheckpoint(
        every_n_epochs=1,
        auto_insert_metric_name=False,
        save_top_k=-1,
    )
    callbacks.append(mc)

if args.early_stopping:
    es = EarlyStoppingVerbose(
        monitor="val/lddt_ca",
        min_delta=args.min_delta,
        patience=args.patience,
        verbose=False,
        mode="max",
        check_finite=True,
        strict=True,
    )
    callbacks.append(es)

if args.log_performance:
    global_batch_size = args.num_nodes * args.gpus
    perf = PerformanceLoggingCallback(
        log_file=os.path.join(args.output_dir, "performance_log.json"),
        global_batch_size=global_batch_size,
    )
    callbacks.append(perf)

if args.log_lr:
    lr_monitor = LearningRateMonitor(logging_interval="step")
    callbacks.append(lr_monitor)

loggers = []
if args.wandb:
    wdb_logger = WandbLogger(
        name=args.experiment_name,
        save_dir=args.output_dir,
        id=args.wandb_id,
        project=args.wandb_project,
        **{"entity": args.wandb_entity},
    )
    loggers.append(wdb_logger)

if args.deepspeed_config_path is not None:
    strategy = DeepSpeedPlugin(
        config=args.deepspeed_config_path,
    )
    if args.wandb:
        wdb_logger.experiment.save(args.deepspeed_config_path)
        wdb_logger.experiment.save("openfold/config.py")
elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
    strategy = DDPPlugin(find_unused_parameters=False)
else:
    strategy = None

if args.wandb:
    freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
    os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
    wdb_logger.experiment.save(f"{freeze_path}")

trainer = pl.Trainer.from_argparse_args(
    args,
    default_root_dir=args.output_dir,
    strategy=strategy,
    callbacks=callbacks,
    logger=loggers,
)

if args.resume_model_weights_only:
    ckpt_path = None
else:
    ckpt_path = args.resume_from_ckpt

trainer.fit(
    model_module,
    datamodule=data_module,
    ckpt_path=ckpt_path,
)


IndentationError: unindent does not match any outer indentation level (<tokenize>, line 4)