forked from dptech-corp/Uni-Fold
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add script for benchmarking & benchmark results (dptech-corp#10)
* add script for benchmarking * code clean * add benchmark in memory cost * remove use_lma * add option for LMA
- Loading branch information
Showing
2 changed files
with
399 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
# modified from https://github.com/hpcaitech/FastFold/blob/main/benchmark/perf.py | ||
import argparse | ||
import os | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from fastfold.distributed import init_dap | ||
from fastfold.model.fastnn import Evoformer | ||
|
||
|
||
def main():
    """Benchmark Evoformer forward/backward time and peak GPU memory.

    Builds ``--layers`` FastFold Evoformer blocks (or OpenFold EvoformerBlock
    wrappers when ``--openfold`` is given), feeds them random MSA/pair
    tensors, runs ``--warmup-trials`` untimed iterations followed by
    ``--trials`` timed ones, and prints per-layer forward/backward time and
    the average peak memory of the timed trials.

    Raises:
        NotImplementedError: if no CUDA device is available.
    """
    parser = argparse.ArgumentParser(description='Evoformer Standalone Perf Benchmark')
    # Fixed: help text previously said 'batch size' (copy-paste from --batch-size).
    parser.add_argument("--dap-size", default=1, type=int,
                        help='Dynamic Axial Parallelism size')
    parser.add_argument('--batch-size', default=1, type=int, help='batch size')
    parser.add_argument('--msa-length', default=128, type=int,
                        help='Sequence Length of MSA')
    parser.add_argument('--res-length', default=256, type=int,
                        help='Sequence Length of Residues')
    parser.add_argument('--trials', default=50, type=int,
                        help='Number of Trials to Execute')
    parser.add_argument('--warmup-trials', default=5, type=int,
                        help='Warmup Trials to discard')
    parser.add_argument('--layers', default=4, type=int,
                        help='Evoformer Layers to Execute')
    parser.add_argument('--cm', default=256, type=int, help='MSA hidden dimension')
    parser.add_argument('--cz', default=128, type=int, help='Pair hidden dimension')
    parser.add_argument('--heads', default=8, type=int,
                        help='Number of Multihead Attention heads')
    parser.add_argument('--openfold', action='store_true',
                        help='Benchmark with Evoformer Implementation from OpenFold.')
    parser.add_argument('--openfold-lma', action='store_true',
                        help='set use_lma to True in openfold.')
    parser.add_argument('--fwd', action='store_true', help='Only execute Fwd Pass.')

    args = parser.parse_args()

    init_dap(args.dap_size)

    precision = torch.bfloat16
    if args.dap_size > 1:
        # (PyTorch issue) Currently All2All communication does not support the
        # Bfloat16 datatype in PyTorch, so fall back to fp16 under DAP.
        precision = torch.float16

    if not torch.cuda.is_available():
        raise NotImplementedError('Running on CPU is not supported')

    # Deterministic inputs across runs. The redundant cuda-availability guard
    # around manual_seed_all was dropped: the check above already raised.
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    if args.openfold:
        from openfold.model.evoformer import EvoformerBlock

        class OpenFoldEvoformer(nn.Module):
            """Adapter giving OpenFold's EvoformerBlock the same
            (node, pair, node_mask, pair_mask) call signature as FastFold's."""

            def __init__(self, d_node, d_pair):
                super(OpenFoldEvoformer, self).__init__()
                self.d_node = d_node
                self.d_pair = d_pair

                # Per-head hidden sizes: c_m/8 for MSA attention, c_z/4 for
                # pair attention (matching the 8/4 head counts below).
                self.c_hidden_msa_att = int(d_node / 8)
                self.c_hidden_pair_att = int(d_pair / 4)

                self.EvoformerBlock = EvoformerBlock(c_m=d_node,
                                                     c_z=d_pair,
                                                     c_hidden_msa_att=self.c_hidden_msa_att,
                                                     c_hidden_opm=self.c_hidden_msa_att,
                                                     c_hidden_mul=self.d_pair,
                                                     c_hidden_pair_att=self.c_hidden_pair_att,
                                                     no_heads_msa=8,
                                                     no_heads_pair=4,
                                                     transition_n=4,
                                                     msa_dropout=0.15,
                                                     pair_dropout=0.25,
                                                     inf=1e9,
                                                     eps=1e-10)

            def forward(self, node, pair, node_mask, pair_mask):
                node, pair = self.EvoformerBlock(node, pair, node_mask, pair_mask,
                                                 use_lma=args.openfold_lma)
                return node, pair

    # Build the layer stack on GPU at the benchmark precision.
    attn_layers = []
    for idx in range(0, args.layers):
        if args.openfold:
            attn_layers.append(OpenFoldEvoformer(d_node=args.cm, d_pair=args.cz))
        else:
            attn_layers.append(Evoformer(d_node=args.cm, d_pair=args.cz))
        attn_layers[idx].cuda()
        attn_layers[idx].to(dtype=precision)

    # One (fwd-start, bwd-start, bwd-stop) CUDA event triple per timed trial.
    start_evt_fwd = []
    start_evt_bwd = []
    stop_evt_bwd = []
    for recorded_trial in range(0, args.trials):
        start_evt_fwd.append(torch.cuda.Event(enable_timing=True))
        start_evt_bwd.append(torch.cuda.Event(enable_timing=True))
        stop_evt_bwd.append(torch.cuda.Event(enable_timing=True))

    # Under DAP the first sequence axis is sharded across ranks, hence the
    # division by dap_size; masks are full-size and need no gradients.
    inputs_node = torch.randn(args.batch_size,
                              args.msa_length // args.dap_size,
                              args.res_length,
                              args.cm,
                              dtype=precision,
                              device=torch.device("cuda")).requires_grad_(True)
    inputs_pair = torch.randn(args.batch_size,
                              args.res_length // args.dap_size,
                              args.res_length,
                              args.cz,
                              dtype=precision,
                              device=torch.device("cuda")).requires_grad_(True)
    node_mask = torch.ones((args.batch_size, args.msa_length, args.res_length),
                           dtype=precision,
                           device=torch.device("cuda")).requires_grad_(False)
    pair_mask = torch.ones((args.batch_size, args.res_length, args.res_length),
                           dtype=precision,
                           device=torch.device("cuda")).requires_grad_(False)

    total_used_mem_gb = 0
    for trial in range(0, args.trials + args.warmup_trials):
        layer_inputs = inputs_node, inputs_pair
        # evt_idx < 0 marks a warmup trial, excluded from timing and memory.
        evt_idx = trial - args.warmup_trials

        torch.distributed.barrier()
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        if evt_idx >= 0:
            start_evt_fwd[evt_idx].record()
        with torch.set_grad_enabled(not args.fwd):
            for lyr_idx in range(0, args.layers):
                layer_inputs = attn_layers[lyr_idx].forward(
                    *layer_inputs,
                    node_mask,
                    pair_mask,
                )

        torch.cuda.synchronize()

        if evt_idx >= 0:
            start_evt_bwd[evt_idx].record()

        if not args.fwd:
            # Scalar loss so a full backward pass runs through every layer.
            s = layer_inputs[0].mean() + layer_inputs[1].mean()
            s.backward()

        torch.cuda.synchronize()
        if evt_idx >= 0:
            # Fixed: warmup trials used to be added to total_used_mem_gb while
            # the printed average divided by args.trials only, overstating the
            # mean peak memory. Only timed trials are accumulated now.
            cur_cost_mem = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
            total_used_mem_gb += cur_cost_mem
            stop_evt_bwd[evt_idx].record()

    torch.cuda.synchronize()
    elapsed_time_fwd = 0.0
    elapsed_time_bwd = 0.0
    for evt_idx in range(0, args.trials):
        elapsed_time_fwd += start_evt_fwd[evt_idx].elapsed_time(start_evt_bwd[evt_idx])
        elapsed_time_bwd += start_evt_bwd[evt_idx].elapsed_time(stop_evt_bwd[evt_idx])

    print(
        "Input: {:4d}, {:4d}, {:4d}, ({:4d} {:4d}), Fwd Time / Layer: {:.3f} ms, Bwd Time / Layer: {:.3f} ms, Memory cost {:.3f} GB".format(
            args.batch_size,
            args.msa_length,
            args.res_length,
            args.cm,
            args.cz,
            elapsed_time_fwd / (args.trials * args.layers),
            elapsed_time_bwd / (args.trials * args.layers),
            total_used_mem_gb / (args.trials),
        )
    )
|
||
|
||
# Run the benchmark only when executed as a script, not on import.
if __name__ == '__main__':
    main()
Oops, something went wrong.