In [1]:
import sys
import os
from torchdrug.utils import comm, pretty
from torchdrug import data, core, utils
from torch.utils import data as torch_data

from IPython import get_ipython
def is_notebook() -> bool:
    try:
        shell = get_ipython().__class__.__name__
        if shell == 'ZMQInteractiveShell':
            return True   # Jupyter notebook or qtconsole
        elif shell == 'TerminalInteractiveShell':
            return False  # Terminal running IPython
        else:
            return False  # Other type (?)
    except NameError:
        return False      # Probably standard Python interpreter

if is_notebook():
    sys.path.append('/home/zhiqiang/PEER_Benchmark')
else:
    sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from peer import protbert, util, flip
from script.run_single import *

In [2]:
# train the model, same as PEER code
args = parse_args()

args.config = '/home/zhiqiang/PEER_Benchmark/config/single_task/ESM/gb1_ESM_fix.yaml' \
    if is_notebook() else os.path.realpath(args.config)
cfg = util.load_config(args.config)
if cfg.dataset["class"] != "Fluorescence":
    cfg.dataset["split"] = args.split

In [3]:
set_seed(args.seed)
output_dir = util.create_working_directory(cfg)
logger = util.get_root_logger()
if comm.get_rank() == 0:
    logger.warning("Config file: %s" % args.config)
    logger.warning(pprint.pformat(cfg))
    logger.warning("Output dir: %s" % output_dir)
    shutil.copyfile(args.config, os.path.basename(args.config))
os.chdir(output_dir)

16:54:59   Config file: /home/zhiqiang/PEER_Benchmark/config/single_task/ESM/gb1_ESM_fix.yaml
16:54:59   {'dataset': {'atom_feature': None,
             'bond_feature': None,
             'class': 'GB1',
             'path': '~/scratch/protein-datasets/',
             'split': 'two_vs_rest',
             'transform': {'class': 'Compose',
                           'transforms': [{'class': 'ProteinView',
                                           'view': 'residue'}]}},
 'engine': {'batch_size': 32, 'gpus': [0]},
 'eval_metric': 'spearmanr',
 'fix_encoder': True,
 'optimizer': {'class': 'Adam', 'lr': 5e-05},
 'output_dir': '~/scratch/torchprotein_output/',
 'task': {'class': 'PropertyPrediction',
          'criterion': 'mse',
          'metric': ['mae', 'rmse', 'spearmanr'],
          'model': {'class': 'ESM',
                    'model': 'ESM-1b',
                    'path': '~/scratch/protein-model-weights/esm-model-weights/',
                    'readout': 'mean'},
          'normaliz

In [4]:
solver = build_solver(cfg, logger)
# build dataset
_dataset = core.Configurable.load_config_dict(cfg.dataset)
if "test_split" in cfg:
    train_set, valid_set, test_set = _dataset.split(['train', 'valid', cfg.test_split])
else:
    train_set, valid_set, test_set = _dataset.split()
if comm.get_rank() == 0:
    logger.warning(_dataset)
    logger.warning("#train: %d, #valid: %d, #test: %d" % (len(train_set), len(valid_set), len(test_set)))

# build task model
if cfg.task["class"] in ["PropertyPrediction", "InteractionPrediction"]:
    cfg.task.task = _dataset.tasks
task = core.Configurable.load_config_dict(cfg.task)

# fix the pre-trained encoder if specified
fix_encoder = cfg.get("fix_encoder", False)
fix_encoder2 = cfg.get("fix_encoder2", False)
if fix_encoder:
    for p in task.model.parameters():
        p.requires_grad = False
if fix_encoder2:
    for p in task.model2.parameters():
        p.requires_grad = False

16:54:59   Extracting /home/zhiqiang/scratch/protein-datasets/gb1/splits.zip to /home/zhiqiang/scratch/protein-datasets/gb1


Loading /home/zhiqiang/scratch/protein-datasets/gb1/splits/two_vs_rest.csv: 100%|██████████| 8734/8734 [00:00<00:00, 156473.26it/s]
Constructing proteins from sequences: 100%|██████████| 8733/8733 [00:06<00:00, 1317.17it/s]

16:55:06   GB1(
  #sample: 8733
  #task: 1
)
16:55:06   #train: 381, #valid: 43, #test: 8309





16:55:23   Preprocess training set
16:55:25   {'batch_size': 32,
 'class': 'core.Engine',
 'gpus': [0],
 'gradient_interval': 1,
 'log_interval': 100,
 'logger': 'logging',
 'num_worker': 0,
 'optimizer': {'amsgrad': False,
               'betas': (0.9, 0.999),
               'capturable': False,
               'class': 'optim.Adam',
               'eps': 1e-08,
               'foreach': None,
               'lr': 5e-05,
               'maximize': False,
               'weight_decay': 0},
 'scheduler': None,
 'task': {'class': 'tasks.PropertyPrediction',
          'criterion': 'mse',
          'graph_construction_model': None,
          'metric': ['mae', 'rmse', 'spearmanr'],
          'mlp_batch_norm': False,
          'mlp_dropout': 0,
          'model': {'class': 'models.ESM',
                    'model': 'ESM-1b',
                    'path': '~/scratch/protein-model-weights/esm-model-weights/',
                    'readout': 'mean'},
          'normalization': False,
          'num

Loading /home/zhiqiang/scratch/protein-datasets/gb1/splits/two_vs_rest.csv: 100%|██████████| 8734/8734 [00:00<00:00, 183694.30it/s]
Constructing proteins from sequences: 100%|██████████| 8733/8733 [00:06<00:00, 1316.32it/s]

16:55:31   GB1(
  #sample: 8733
  #task: 1
)
16:55:31   #train: 381, #valid: 43, #test: 8309





In [5]:
# build solver
cfg.optimizer.params = task.parameters()
optimizer = core.Configurable.load_config_dict(cfg.optimizer)
if not "scheduler" in cfg:
    scheduler = None
else:
    cfg.scheduler.optimizer = optimizer
    scheduler = core.Configurable.load_config_dict(cfg.scheduler)

solver = core.Engine(task, train_set, valid_set, test_set, optimizer, scheduler, **cfg.engine)
if "lr_ratio" in cfg:
    cfg.optimizer.params = [
        {'params': solver.model.model.parameters(), 'lr': cfg.optimizer.lr * cfg.lr_ratio},
        {'params': solver.model.mlp.parameters(), 'lr': cfg.optimizer.lr}
    ]
    optimizer = core.Configurable.load_config_dict(cfg.optimizer)
    solver.optimizer = optimizer
if "checkpoint" in cfg:
    solver.load(cfg.checkpoint, load_optimizer=False)

16:55:48   Preprocess training set
16:55:49   {'batch_size': 32,
 'class': 'core.Engine',
 'gpus': [0],
 'gradient_interval': 1,
 'log_interval': 100,
 'logger': 'logging',
 'num_worker': 0,
 'optimizer': {'amsgrad': False,
               'betas': (0.9, 0.999),
               'capturable': False,
               'class': 'optim.Adam',
               'eps': 1e-08,
               'foreach': None,
               'lr': 5e-05,
               'maximize': False,
               'weight_decay': 0},
 'scheduler': None,
 'task': {'class': 'tasks.PropertyPrediction',
          'criterion': 'mse',
          'graph_construction_model': None,
          'metric': ['mae', 'rmse', 'spearmanr'],
          'mlp_batch_norm': False,
          'mlp_dropout': 0,
          'model': {'class': 'models.ESM',
                    'model': 'ESM-1b',
                    'path': '~/scratch/protein-model-weights/esm-model-weights/',
                    'readout': 'mean'},
          'normalization': False,
          'num

In [None]:
# solver, best_epoch = train_and_validate(cfg, solver)
"""start solver, best_epoch = train_and_validate(cfg, solver)"""
step = math.ceil(cfg.train.num_epoch / 10)
best_score = float("-inf")
best_epoch = -1

# if not cfg.train.num_epoch > 0:
#     return solver, best_epoch

for i in range(0, cfg.train.num_epoch, step):
    kwargs = cfg.train.copy()
    kwargs["num_epoch"] = min(step, cfg.train.num_epoch - i)
    solver.model.split = "train"
    # solver.train(**kwargs)
    """start solver.train(**kwargs)"""
    from torch import nn
    from itertools import islice
    batch_per_epoch = None
    num_epoch = kwargs["num_epoch"]
    sampler = torch_data.DistributedSampler(solver.train_set, solver.world_size, solver.rank)
    dataloader = data.DataLoader(solver.train_set, solver.batch_size, sampler=sampler, num_workers=solver.num_worker)
    batch_per_epoch = batch_per_epoch or len(dataloader)
    model = solver.model
    model.split = "train"
    model.train()

    for epoch in solver.meter(num_epoch):
        sampler.set_epoch(epoch)

        metrics = []
        start_id = 0
        # the last gradient update may contain less than gradient_interval batches
        gradient_interval = min(batch_per_epoch - start_id, solver.gradient_interval)

        all_loss = 0
        for batch_id, batch in enumerate(islice(dataloader, batch_per_epoch)):
            if solver.device.type == "cuda":
                batch = utils.cuda(batch, device=solver.device)

            # loss_1, metric = model(batch)
            from torch.nn import functional as F
            from torchdrug.layers import functional

            all_loss = torch.tensor(0, dtype=torch.float32, device=model.device)
            metric = {}

            # pred = model.predict(batch, all_loss, metric)
            graph = batch["graph"]
            if model.graph_construction_model:
                graph = model.graph_construction_model(graph)
            output = model.model(graph, graph.node_feature.float(), all_loss=all_loss, metric=metric)
            pred = model.mlp(output["graph_feature"])
            if model.normalization:
                pred = pred * model.std + model.mean

            # if all([t not in batch for t in self.task]):
            #     # unlabeled data
            #     return all_loss, metric

            target = model.target(batch)
            labeled = ~torch.isnan(target)
            target[~labeled] = 0

            for criterion, weight in model.criterion.items():
                loss = F.mse_loss(pred, target, reduction="mean")

                name = tasks._get_criterion_name(criterion)
                metric[name] = loss
            loss_1, metric = loss, metric

            # if not loss.requires_grad:
            #     raise RuntimeError("Loss doesn't require grad. Did you define any loss in the task?")
            loss = loss_1 / gradient_interval
            loss.backward()
            metrics.append(metric)
            all_loss += loss
            # print(batch_id, "Two loss: ", round(loss.item(), 3), round(all_loss.item(), 3))

            solver.optimizer.step()
            # print("optimiser")
            # if batch_id - start_id + 1 == gradient_interval:
            #     solver.optimizer.step()
            #     solver.optimizer.zero_grad()
            #
            #     metric = utils.stack(metrics, dim=0)
            #     metric = utils.mean(metric, dim=0)
            #     if solver.world_size > 1:
            #         metric = comm.reduce(metric, op="mean")
            #     solver.meter.update(metric)
            #
            #     metrics = []
            #     start_id = batch_id + 1
            #     gradient_interval = min(batch_per_epoch - start_id, solver.gradient_interval)
        print("Loss: ", all_loss)

        # if solver.scheduler:
        #     # False
        #     solver.scheduler.step()
    """end solver.train(**kwargs)"""


    solver.model.split = "valid"
    # metric = solver.evaluate("valid")
    """start metric = solver.evaluate("valid")"""
    split, log = "valid", True
    if comm.get_rank() == 0:
        logger.warning(pretty.separator)
        logger.warning("Evaluate on %s" % split)
    test_set = getattr(solver, "%s_set" % split)
    sampler = torch_data.DistributedSampler(test_set, solver.world_size, solver.rank)
    dataloader = data.DataLoader(test_set, solver.batch_size, sampler=sampler, num_workers=solver.num_worker)
    model = solver.model
    model.split = split

    model.eval()
    preds = []
    targets = []
    for batch in dataloader:
        if solver.device.type == "cuda":
            batch = utils.cuda(batch, device=solver.device)

        pred, target = model.predict_and_target(batch)
        preds.append(pred)
        targets.append(target)

    pred = utils.cat(preds)
    target = utils.cat(targets)
    if solver.world_size > 1:
        pred = comm.cat(pred)
        target = comm.cat(target)
    metric = model.evaluate(pred, target)
    if log:
        solver.meter.log(metric, category="%s/epoch" % solver)
    solver.batch_size = cfg.engine.batch_size
    """end metric = solver.evaluate("valid")"""

    score = []
    for k, v in metric.items():
        if k.startswith(cfg.eval_metric):
            if "root mean squared error" in cfg.eval_metric:
                score.append(-v)
            else:
                score.append(v)
    score = sum(score) / len(score)
    if score > best_score:
        # print("Update best epoch.")
        if best_epoch > -1:
            # remove old best epoch model parameters before saving new parameters
            to_remove_path = os.path.expanduser("model_epoch_%d.pth" % best_epoch)
            # print('removing old parameters at %s' % to_remove_path)
            os.remove(to_remove_path)

        best_score = score
        best_epoch = solver.epoch

        # only save epoch model parameters when it is another best epoch
        solver.save("model_epoch_%d.pth" % solver.epoch)
# clean GPU memory
clean_gpu_memory()
solver.load("model_epoch_%d.pth" % best_epoch)

if comm.get_rank() == 0:
    logger.warning("Best epoch on valid: %d" % best_epoch)
"""end solver, best_epoch = train_and_validate(cfg, solver)"""

In [None]:
# # test(cfg, solver)
# """start test(cfg, solver)"""
# if "test_batch_size" in cfg:
#     solver.batch_size = cfg.test_batch_size
# solver.model.split = "valid"
# solver.evaluate("valid")
# solver.model.split = "test"
# solver.evaluate("test")
# """end test(cfg, solver)"""