In [1]:
import sys
import os

import torch
from torchdrug.utils import comm, pretty
from torchdrug import data, core, utils
from torch.utils import data as torch_data

from IPython import get_ipython
def is_notebook() -> bool:
    """Report whether the code is running inside a Jupyter notebook or qtconsole.

    The decision is based on the class name of the object returned by
    ``get_ipython()``: only ``'ZMQInteractiveShell'`` counts as a notebook.
    A terminal IPython shell, any other shell type, and an environment in
    which ``get_ipython`` is not defined (``NameError``) all yield ``False``.

    Returns:
        bool: ``True`` only for the ZMQ-based interactive shell.
    """
    try:
        shell_name = get_ipython().__class__.__name__
    except NameError:
        # get_ipython is not defined: probably the standard Python interpreter
        return False
    # Terminal IPython and every unrecognised shell type are "not a notebook"
    return shell_name == 'ZMQInteractiveShell'

# Make the PEER_Benchmark repository root importable (needed for `peer` and `script`).
if is_notebook():
    # A notebook kernel has no __file__, so the root must be given explicitly.
    # The PEER_BENCHMARK_ROOT environment variable overrides the historical default.
    _peer_root = os.environ.get('PEER_BENCHMARK_ROOT', '/home/zhiqiang/PEER_Benchmark')
else:
    # abspath first: with a relative __file__ such as "run.py" the double dirname
    # would otherwise collapse to '' (the current directory) instead of the parent.
    _peer_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Guard against duplicate entries when this cell is executed more than once.
if _peer_root not in sys.path:
    sys.path.append(_peer_root)

from peer import protbert, util, flip
from script.run_single import *

In [2]:
# train the model, same as PEER code
args = parse_args()

if is_notebook():
    # Inside a notebook the config path is pinned instead of taken from the parsed arguments.
    args.config = '/home/zhiqiang/PEER_Benchmark/config/single_task/ESM/gb1_ESM_fix.yaml'
else:
    args.config = os.path.realpath(args.config)

cfg = util.load_config(args.config)

# Every dataset class except "Fluorescence" takes its split from the parsed arguments.
dataset_class = cfg.dataset["class"]
if dataset_class != "Fluorescence":
    cfg.dataset["split"] = args.split

In [3]:
# Seed the run, create the working directory and logger, then record the run setup.
# NOTE(review): set_seed, pprint and shutil are not imported explicitly in this notebook;
# presumably they arrive through `from script.run_single import *` — confirm.
set_seed(args.seed)
output_dir = util.create_working_directory(cfg)
logger = util.get_root_logger()
# Only the rank-0 process logs the setup and copies the config file.
if comm.get_rank() == 0:
    logger.warning("Config file: %s" % args.config)
    logger.warning(pprint.pformat(cfg))
    logger.warning("Output dir: %s" % output_dir)
    # The destination is a bare file name, so the copy lands in the current working directory.
    # NOTE(review): this runs before os.chdir(output_dir) below — confirm that
    # create_working_directory has already switched into output_dir.
    shutil.copyfile(args.config, os.path.basename(args.config))
os.chdir(output_dir)

12:13:05   Config file: /home/zhiqiang/PEER_Benchmark/config/single_task/ESM/gb1_ESM_fix.yaml
12:13:05   {'dataset': {'atom_feature': None,
             'bond_feature': None,
             'class': 'GB1',
             'path': '~/scratch/protein-datasets/',
             'split': 'two_vs_rest',
             'transform': {'class': 'Compose',
                           'transforms': [{'class': 'ProteinView',
                                           'view': 'residue'}]}},
 'engine': {'batch_size': 32, 'gpus': [1]},
 'eval_metric': 'spearmanr',
 'fix_encoder': True,
 'optimizer': {'class': 'Adam', 'lr': 5e-05},
 'output_dir': '~/scratch/torchprotein_output/',
 'task': {'class': 'PropertyPrediction',
          'criterion': 'mse',
          'metric': ['mae', 'rmse', 'spearmanr'],
          'model': {'class': 'ESM',
                    'model': 'ESM-1b',
                    'path': '~/scratch/protein-model-weights/esm-model-weights/',
                    'readout': 'mean'},
          'normaliz

In [4]:
# The next cells unroll build_solver(cfg, logger) step by step instead of calling it.
# solver = build_solver(cfg, logger)
# def build_solver(cfg, logger):
# build dataset
# Instantiate the dataset object described by the `dataset` section of the config.
_dataset = core.Configurable.load_config_dict(cfg.dataset)

12:13:05   Extracting /home/zhiqiang/scratch/protein-datasets/gb1/splits.zip to /home/zhiqiang/scratch/protein-datasets/gb1


Loading /home/zhiqiang/scratch/protein-datasets/gb1/splits/two_vs_rest.csv: 100%|██████████| 8734/8734 [00:00<00:00, 154091.96it/s]
Constructing proteins from sequences: 100%|██████████| 8733/8733 [00:06<00:00, 1261.84it/s]


In [5]:
# Split the dataset into train/valid/test.
# When the config defines "test_split", it is passed as the third split name.
if "test_split" not in cfg:
    dataset_splits = _dataset.split()
else:
    dataset_splits = _dataset.split(['train', 'valid', cfg.test_split])
train_set, valid_set, test_set = dataset_splits

# Only the rank-0 process reports the dataset summary and split sizes.
if comm.get_rank() == 0:
    logger.warning(_dataset)
    split_sizes = (len(train_set), len(valid_set), len(test_set))
    logger.warning("#train: %d, #valid: %d, #test: %d" % split_sizes)

12:13:12   GB1(
  #sample: 8733
  #task: 1
)
12:13:12   #train: 381, #valid: 43, #test: 8309


In [6]:
# build task model
# For the two prediction task classes below, the task list is taken from the dataset.
task_class = cfg.task["class"]
if task_class in ("PropertyPrediction", "InteractionPrediction"):
    cfg.task.task = _dataset.tasks
task = core.Configurable.load_config_dict(cfg.task)

In [7]:
# Display the module tree of the encoder held by the task.
task.model

EvolutionaryScaleModeling(
  (model): ProteinBertModel(
    (embed_tokens): Embedding(33, 1280, padding_idx=1)
    (layers): ModuleList(
      (0): TransformerLayer(
        (self_attn): MultiheadAttention(
          (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=1280, out_features=5120, bias=True)
        (fc2): Linear(in_features=5120, out_features=1280, bias=True)
        (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      )
      (1): TransformerLayer(
        (self_attn): MultiheadAttention(
          (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (v_proj)

In [8]:
# Check the concrete class of the encoder.
type(task.model)

torchdrug.models.esm.EvolutionaryScaleModeling

In [9]:
def count_parameters(model):
    """Count the elements of all trainable parameters of ``model``.

    Args:
        model: object exposing ``parameters()``; each yielded parameter must
            provide ``requires_grad`` and ``numel()``.

    Returns:
        int: sum of ``numel()`` over the parameters whose ``requires_grad`` is
        truthy (``0`` when there are none).
    """
    total = 0
    for param in model.parameters():
        # Frozen parameters are skipped
        if param.requires_grad:
            total += param.numel()
    return total

In [10]:
# fix the pre-trained encoder if specified
fix_encoder = cfg.get("fix_encoder", False)
fix_encoder2 = cfg.get("fix_encoder2", False)
if fix_encoder:
    # Freeze every parameter of the first encoder
    for p in task.model.parameters():
        p.requires_grad_(False)
if fix_encoder2:
    # Freeze every parameter of the second encoder
    for p in task.model2.parameters():
        p.requires_grad_(False)

# build solver
# The optimizer is first created over all task parameters.
cfg.optimizer.params = task.parameters()
optimizer = core.Configurable.load_config_dict(cfg.optimizer)

# A scheduler is only built when the config has a "scheduler" section.
if "scheduler" in cfg:
    cfg.scheduler.optimizer = optimizer
    scheduler = core.Configurable.load_config_dict(cfg.scheduler)
else:
    scheduler = None

solver = core.Engine(task, train_set, valid_set, test_set, optimizer, scheduler, **cfg.engine)

# With "lr_ratio" set, the optimizer is rebuilt with two parameter groups:
# the encoder at lr * lr_ratio and the MLP head at the plain lr.
if "lr_ratio" in cfg:
    base_lr = cfg.optimizer.lr
    encoder_group = {'params': solver.model.model.parameters(), 'lr': base_lr * cfg.lr_ratio}
    head_group = {'params': solver.model.mlp.parameters(), 'lr': base_lr}
    cfg.optimizer.params = [encoder_group, head_group]
    optimizer = core.Configurable.load_config_dict(cfg.optimizer)
    solver.optimizer = optimizer

# Optionally restore a checkpoint, without its optimizer state.
if "checkpoint" in cfg:
    solver.load(cfg.checkpoint, load_optimizer=False)

# return solver

12:13:31   Preprocess training set
12:13:32   {'batch_size': 32,
 'class': 'core.Engine',
 'gpus': [1],
 'gradient_interval': 1,
 'log_interval': 100,
 'logger': 'logging',
 'num_worker': 0,
 'optimizer': {'amsgrad': False,
               'betas': (0.9, 0.999),
               'capturable': False,
               'class': 'optim.Adam',
               'eps': 1e-08,
               'foreach': None,
               'lr': 5e-05,
               'maximize': False,
               'weight_decay': 0},
 'scheduler': None,
 'task': {'class': 'tasks.PropertyPrediction',
          'criterion': 'mse',
          'graph_construction_model': None,
          'metric': ['mae', 'rmse', 'spearmanr'],
          'mlp_batch_norm': False,
          'mlp_dropout': 0,
          'model': {'class': 'models.ESM',
                    'model': 'ESM-1b',
                    'path': '~/scratch/protein-model-weights/esm-model-weights/',
                    'readout': 'mean'},
          'normalization': False,
          'num

In [11]:
# Display the module tree of the task model held by the solver.
solver.model

PropertyPrediction(
  (model): EvolutionaryScaleModeling(
    (model): ProteinBertModel(
      (embed_tokens): Embedding(33, 1280, padding_idx=1)
      (layers): ModuleList(
        (0): TransformerLayer(
          (self_attn): MultiheadAttention(
            (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=1280, out_features=5120, bias=True)
          (fc2): Linear(in_features=5120, out_features=1280, bias=True)
          (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        )
        (1): TransformerLayer(
          (self_attn): MultiheadAttention(
            (k_proj): Linear

In [12]:
# This cell unrolls the start of train_and_validate(cfg, solver) for the first chunk only (i = 0).
# solver, best_epoch = train_and_validate(cfg, solver)
# def train_and_validate(cfg, solver):
# Chunk size in epochs: one tenth of the configured num_epoch, rounded up.
step = math.ceil(cfg.train.num_epoch / 10)
best_score = float("-inf")
best_epoch = -1

# if not cfg.train.num_epoch > 0:
#     return solver, best_epoch

# for i in range(0, cfg.train.num_epoch, step):
i = 0
kwargs = cfg.train.copy()
# Cap the chunk so it never runs past the configured number of epochs.
kwargs["num_epoch"] = min(step, cfg.train.num_epoch - i)
solver.model.split = "train"

In [13]:
# solver.train(**kwargs)
# This cell unrolls the setup part of that call: sampler, data loader and model wrapping.
from torch import nn
from itertools import islice

num_epoch = kwargs["num_epoch"]
batch_per_epoch = None

sampler = torch_data.DistributedSampler(solver.train_set, solver.world_size, solver.rank)
dataloader = data.DataLoader(solver.train_set, solver.batch_size, sampler=sampler, num_workers=solver.num_worker)
# No explicit batch count was given, so fall back to the loader length
if not batch_per_epoch:
    batch_per_epoch = len(dataloader)

model = solver.model
if solver.world_size > 1:
    # Wrap for multi-process training; device_ids is only passed on CUDA devices
    ddp_options = {"find_unused_parameters": True}
    if solver.device.type == "cuda":
        ddp_options["device_ids"] = [solver.device]
    model = nn.parallel.DistributedDataParallel(model, **ddp_options)

In [14]:
# Switch the model to training mode before the unrolled epoch loop.
model.train()

# for epoch in solver.meter(num_epoch):
epoch = 0
# Tell the distributed sampler which epoch is being iterated.
sampler.set_epoch(epoch)

metrics = []
start_id = 0
# the last gradient update may contain less than gradient_interval batches
gradient_interval = min(batch_per_epoch - start_id, solver.gradient_interval)

In [15]:
# for batch_id, batch in enumerate(islice(dataloader, batch_per_epoch)):
# This cell unrolls the first iteration of that loop.
batch_id = 0
# Advance the loader only as far as the requested batch; building a list of
# every batch in the epoch just to index one of them is wasted work.
batch = next(islice(dataloader, batch_id, batch_per_epoch))
if solver.device.type == "cuda":
    batch = utils.cuda(batch, device=solver.device)

# loss, metric = model(batch)
from torch.nn import functional as F
from torchdrug.layers import functional
all_loss = torch.tensor(0, dtype=torch.float32, device=model.device)
metric = {}

# pred = model.predict(batch, all_loss, metric)
graph = batch["graph"]
# Optional graph construction step, applied only when the task defines one
if model.graph_construction_model:
    graph = model.graph_construction_model(graph)
# Encoder forward pass, then the MLP head on the graph-level features
output = model.model(graph, graph.node_feature.float(), all_loss=all_loss, metric=metric)
pred = model.mlp(output["graph_feature"])
if model.normalization:
    print(model.normalization)
    # Rescale the predictions with the task's stored std and mean
    pred = pred * model.std + model.mean

In [16]:
# Inspect entry 20 of the graph-level features returned by the encoder.
output["graph_feature"][20]

tensor([ 0.0095,  0.1165,  0.0243,  ..., -0.0237, -0.0295,  0.1617],
       device='cuda:1')

In [31]:
# Inspect the first 265 rows of the residue-level features returned by the encoder.
# presumably these rows belong to the first protein of the batch — TODO confirm
output["residue_feature"][:265]

tensor([[-0.2205,  0.0943,  0.2996,  ..., -0.1174, -0.1577, -0.0845],
        [-0.3352,  0.2010, -0.1284,  ..., -0.1173, -0.0746,  0.3168],
        [-0.2199,  0.0658,  0.0383,  ..., -0.2070,  0.0414,  0.0626],
        ...,
        [-0.0138,  0.4636, -0.0310,  ..., -0.1492, -0.0804, -0.2891],
        [ 0.0220,  0.2933, -0.0931,  ...,  0.0600, -0.2177, -0.1240],
        [ 0.2823,  0.3091, -0.0823,  ...,  0.0716, -0.1686,  0.1461]],
       device='cuda:1')

In [35]:
# Inspect the first item of the batched graph.
graph[0]

Protein(num_atom=0, num_bond=0, num_residue=265, device='cuda:1')

In [50]:
# Input residue feature of graph[0] at index 264.
graph[0].residue_feature[264]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
       device='cuda:1')

In [52]:
# Input residue feature of graph[0] at index 263.
graph[0].residue_feature[263]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
       device='cuda:1')

In [54]:
# Encoder output row 264 of the residue-level features.
output["residue_feature"][264]

tensor([ 0.2823,  0.3091, -0.0823,  ...,  0.0716, -0.1686,  0.1461],
       device='cuda:1')

In [55]:
# Encoder output row 263 of the residue-level features.
output["residue_feature"][263]

tensor([ 0.0220,  0.2933, -0.0931,  ...,  0.0600, -0.2177, -0.1240],
       device='cuda:1')

In [63]:
# Encoder output row 264 + 265 of the residue-level features.
# presumably residue 264 of the second protein, assuming every protein contributes 265 rows — TODO confirm
output["residue_feature"][264+265]

tensor([ 0.2807,  0.3122, -0.0870,  ...,  0.0736, -0.1740,  0.1466],
       device='cuda:1')

In [64]:
# Encoder output row 263 + 265 of the residue-level features.
# presumably residue 263 of the second protein, assuming every protein contributes 265 rows — TODO confirm
output["residue_feature"][263+265]

tensor([ 0.0187,  0.2989, -0.0911,  ...,  0.0542, -0.2230, -0.1222],
       device='cuda:1')

In [66]:
# Input residue feature of graph[1] at index 264.
graph[1].residue_feature[264]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
       device='cuda:1')

In [67]:
# Input residue feature of graph[1] at index 263.
graph[1].residue_feature[263]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
       device='cuda:1')

In [61]:
# Display the innermost module: the `model` attribute of the encoder inside the task.
solver.model.model.model

ProteinBertModel(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0): TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
    (1): TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj):

In [60]:
# Display the token embedding layer of the innermost module.
solver.model.model.model.embed_tokens

Embedding(33, 1280, padding_idx=1)

In [17]:
# Display the current predictions of the MLP head.
pred

tensor([[0.0093],
        [0.0094],
        [0.0075],
        [0.0094],
        [0.0091],
        [0.0072],
        [0.0080],
        [0.0085],
        [0.0083],
        [0.0096],
        [0.0078],
        [0.0080],
        [0.0080],
        [0.0099],
        [0.0103],
        [0.0093],
        [0.0108],
        [0.0100],
        [0.0075],
        [0.0082],
        [0.0086],
        [0.0096],
        [0.0090],
        [0.0090],
        [0.0069],
        [0.0085],
        [0.0092],
        [0.0092],
        [0.0082],
        [0.0065],
        [0.0057],
        [0.0101]], device='cuda:1', grad_fn=<AddmmBackward0>)

In [18]:
# Print the structure of the MLP prediction head.
print(model.mlp)

MultiLayerPerceptron(
  (layers): ModuleList(
    (0): Linear(in_features=1280, out_features=1280, bias=True)
    (1): Linear(in_features=1280, out_features=1, bias=True)
  )
)


In [19]:
# Display the module tree of the model used in the unrolled training step.
model

PropertyPrediction(
  (model): EvolutionaryScaleModeling(
    (model): ProteinBertModel(
      (embed_tokens): Embedding(33, 1280, padding_idx=1)
      (layers): ModuleList(
        (0): TransformerLayer(
          (self_attn): MultiheadAttention(
            (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=1280, out_features=5120, bias=True)
          (fc2): Linear(in_features=5120, out_features=1280, bias=True)
          (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        )
        (1): TransformerLayer(
          (self_attn): MultiheadAttention(
            (k_proj): Linear

In [20]:
# Output dimension reported by the encoder.
model.model.output_dim

1280

In [21]:
# (num_mlp_layer - 1) copies of the encoder output_dim followed by the summed num_class.
# presumably the layer sizes the MLP head was built with — TODO confirm against the task class
[model.model.output_dim] * (model.num_mlp_layer - 1) + [sum(model.num_class)]

[1280, 1]

In [22]:
# Check the task's mlp_batch_norm setting.
model.mlp_batch_norm

False

In [23]:
# Check the task's mlp_dropout setting.
model.mlp_dropout

0

In [24]:
# Re-run encoder and MLP head on the same batch with the model in training mode.
model.train()
output = model.model(graph, graph.node_feature.float(), all_loss=all_loss, metric=metric)
pred = model.mlp(output["graph_feature"])

In [25]:
# Display the predictions from the training-mode forward pass.
pred

tensor([[0.0093],
        [0.0094],
        [0.0075],
        [0.0094],
        [0.0091],
        [0.0072],
        [0.0080],
        [0.0085],
        [0.0083],
        [0.0096],
        [0.0078],
        [0.0080],
        [0.0080],
        [0.0099],
        [0.0103],
        [0.0093],
        [0.0108],
        [0.0100],
        [0.0075],
        [0.0082],
        [0.0086],
        [0.0096],
        [0.0090],
        [0.0090],
        [0.0069],
        [0.0085],
        [0.0092],
        [0.0092],
        [0.0082],
        [0.0065],
        [0.0057],
        [0.0101]], device='cuda:1', grad_fn=<AddmmBackward0>)

In [26]:
# Re-run encoder and MLP head on the same batch with the model in evaluation mode.
# NOTE(review): no torch.no_grad() is used here, so gradients are still tracked.
model.eval()
output = model.model(graph, graph.node_feature.float(), all_loss=all_loss, metric=metric)
pred = model.mlp(output["graph_feature"])

In [27]:
# Display the predictions from the evaluation-mode forward pass.
pred

tensor([[0.0093],
        [0.0094],
        [0.0075],
        [0.0094],
        [0.0091],
        [0.0072],
        [0.0080],
        [0.0085],
        [0.0083],
        [0.0096],
        [0.0078],
        [0.0080],
        [0.0080],
        [0.0099],
        [0.0103],
        [0.0093],
        [0.0108],
        [0.0100],
        [0.0075],
        [0.0082],
        [0.0086],
        [0.0096],
        [0.0090],
        [0.0090],
        [0.0069],
        [0.0085],
        [0.0092],
        [0.0092],
        [0.0082],
        [0.0065],
        [0.0057],
        [0.0101]], device='cuda:1', grad_fn=<AddmmBackward0>)