In [2]:
import os

from alphafold3_pytorch import collate_inputs_to_batched_atom_input
from alphafold3_pytorch.alphafold3 import Alphafold3
from alphafold3_pytorch.inputs import (
    PDBDataset,
    molecule_to_atom_input,
    pdb_input_to_molecule_input,
)
from alphafold3_pytorch.data.weighted_pdb_sampler import WeightedPDBSampler


In [3]:
# Test a PDBDataset constructed using a WeightedPDBSampler.

# Single root for all on-disk inputs, so the location is changed in one place only.
# data_test = os.path.join("data", "test")
PDB_DATA_DIR = "/cpfs01/projects-HDD/cfff-6f3a36a0cd1e_HDD/public/protein/datasets/AF3/data/pdb_data"
data_test = os.path.join(PDB_DATA_DIR, "data_caches", "200_mmcif")

# Cluster-mapping CSVs consumed by the sampler: one interface mapping plus
# one chain mapping per molecule category.
interface_mapping_path = os.path.join(data_test, "interface_cluster_mapping.csv")
chain_mapping_paths = [
    os.path.join(data_test, "ligand_chain_cluster_mapping.csv"),
    os.path.join(data_test, "nucleic_acid_chain_cluster_mapping.csv"),
    os.path.join(data_test, "peptide_chain_cluster_mapping.csv"),
    os.path.join(data_test, "protein_chain_cluster_mapping.csv"),
]

sampler = WeightedPDBSampler(
    chain_mapping_paths=chain_mapping_paths,
    interface_mapping_path=interface_mapping_path,
    batch_size=1,
)

dataset = PDBDataset(
    folder=os.path.join(PDB_DATA_DIR, "merged_mmcifs"),
    sampler=sampler,
    sample_type="default",
    crop_size=128,
)

# Last expression of the cell: displays the number of entries in the dataset.
len(dataset)

[32m2024-09-04 08:15:13.784[0m | [1mINFO    [0m | [36malphafold3_pytorch.data.weighted_pdb_sampler[0m:[36m__init__[0m:[36m225[0m - [1mPrecomputing chain and interface weights. This may take several minutes to complete.[0m
[32m2024-09-04 08:15:13.786[0m | [1mINFO    [0m | [36malphafold3_pytorch.data.weighted_pdb_sampler[0m:[36m__init__[0m:[36m258[0m - [1mFinished precomputing chain and interface weights.[0m


208

In [4]:
# dataset[0]

PDBInput(mmcif_filepath='/cpfs01/projects-HDD/cfff-6f3a36a0cd1e_HDD/public/protein/datasets/AF3/data/pdb_data/merged_mmcifs/21/521p-assembly1.cif', biomol=None, chains=('A', 'B'), cropping_config=None, msa_dir=None, templates_dir=None, add_atom_ids=False, add_atompair_ids=False, directed_bonds=False, training=None, distillation=False, resolution=None, max_msas_per_chain=None, max_templates_per_chain=None, num_templates_per_chain=None, kalign_binary_path=None, extract_atom_feats_fn=<function default_extract_atom_feats_fn at 0x7f2a62caa4d0>, extract_atompair_feats_fn=<function default_extract_atompair_feats_fn at 0x7f2a62caa560>)

In [None]:
# pdb_input = dataset[0]
# print(11111)

11111


In [None]:
# dataset[0]

PDBInput(mmcif_filepath='/cpfs01/projects-HDD/cfff-6f3a36a0cd1e_HDD/public/protein/datasets/AF3/data/pdb_data/merged_mmcifs/sl/6sl5-assembly1.cif', biomol=None, chains=('E', 'F'), cropping_config=None, msa_dir=None, templates_dir=None, add_atom_ids=False, add_atompair_ids=False, directed_bonds=False, training=None, resolution=None, max_msas_per_chain=None, extract_atom_feats_fn=<function default_extract_atom_feats_fn at 0x7f0acc654820>, extract_atompair_feats_fn=<function default_extract_atompair_feats_fn at 0x7f0acc6548b0>)

In [12]:
# Pull one entry from the dataset and convert it to a molecule-level input.
# NOTE(review): the two recorded `dataset[0]` outputs in this notebook show different
# files, so the index may not select a fixed structure when a sampler is attached — confirm.
pdb_input = dataset[5]
mol_input = pdb_input_to_molecule_input(pdb_input=pdb_input)

In [5]:
from alphafold3_pytorch.utils.utils import default, exists, first

# Run the full PDBInput -> MoleculeInput -> AtomInput -> batched input pipeline,
# but only for the structure whose file id matches TARGET_FILE_ID.
TARGET_FILE_ID = '2mtz-assembly1'

error_cnt=0
for i in range(len(dataset)):
    # Fetch the entry exactly once and reuse it below. The recorded outputs of this
    # notebook suggest indexing is sampler-driven (the same index yields different
    # files), so indexing a second time could convert a different structure than the
    # one that was just filtered and printed.
    sampled_input = dataset[i]
    filepath = sampled_input.mmcif_filepath
    # File id is the basename without its extension, e.g. ".../2mtz-assembly1.cif" -> "2mtz-assembly1".
    file_id = os.path.splitext(os.path.basename(filepath))[0] if exists(filepath) else None
    if file_id != TARGET_FILE_ID:
        continue
    print(f"Processing:{file_id}")
    mol_input = pdb_input_to_molecule_input(pdb_input=sampled_input)
    atom_input = molecule_to_atom_input(mol_input)
    batched_atom_input = collate_inputs_to_batched_atom_input([atom_input], atoms_per_window=27)


Processing:6mw0-assembly1


: 

In [14]:
# Walk the dataset, try the PDBInput -> MoleculeInput conversion on every entry,
# report pass/fail per entry, and stop as soon as two failures have been seen.
error_cnt = 0
for idx in range(len(dataset)):
    entry = dataset[idx]
    cif_path = entry.mmcif_filepath

    # Derive the file id (basename without extension) when a path is available.
    if exists(cif_path):
        stem, _ext = os.path.splitext(os.path.basename(cif_path))
    else:
        stem = None

    try:
        mol_input = pdb_input_to_molecule_input(pdb_input=entry)
    except Exception as err:
        print(f"Error in {idx}:{stem}")
        print(f'Exception: {err}')
        error_cnt += 1
        if error_cnt == 2:
            break
    else:
        print(f"pass data:{idx} | {stem}")


pass data:0 | 214d-assembly1
pass data:1 | 5vvr-assembly1
pass data:2 | 1n4r-assembly1
pass data:3 | 108d-assembly1
pass data:4 | 4joe-assembly1
Error in 5:5my9-assembly1
Exception: index 493 is out of bounds for dimension 0 with size 493
pass data:6 | 308d-assembly1
pass data:7 | 207d-assembly1
pass data:8 | 5nj8-assembly1
pass data:9 | 209d-assembly1
pass data:10 | 1mt8-assembly1
pass data:11 | 315d-assembly1
pass data:12 | 7nhl-assembly1
pass data:13 | 207d-assembly1
pass data:14 | 220l-assembly1
Error in 15:7nhm-assembly1
Exception: The size of tensor a (6465) must match the size of tensor b (6478) at non-singleton dimension 0


In [None]:
# Convert the entry returned by dataset[0] into a molecule-level input for the next cells.
mol_input = pdb_input_to_molecule_input(pdb_input=dataset[0])

In [None]:
# Lower the molecule-level input to an atom-level input, then collate it into a batch of one.
atom_input = molecule_to_atom_input(mol_input)
# atoms_per_window=27 is the same value passed to the Alphafold3 constructor later in this notebook.
batched_atom_input = collate_inputs_to_batched_atom_input([atom_input], atoms_per_window=27)

In [None]:
# Summarise every field of the batched input: print the shape for values that
# have one, and the raw value otherwise (e.g. None or the filepath tuple).
batch_fields = batched_atom_input.dict()
for key, value in batch_fields.items():
    # Explicit attribute check instead of a bare `except:`, which would also
    # swallow KeyboardInterrupt and hide unrelated errors.
    if hasattr(value, "shape"):
        print(key, value.shape)
    else:
        print(key, value)
print(batch_fields['filepath'])

atom_inputs torch.Size([1, 25240, 3])
molecule_ids torch.Size([1, 3176])
molecule_atom_lens torch.Size([1, 3176])
atompair_inputs torch.Size([1, 935, 27, 54, 5])
additional_molecule_feats torch.Size([1, 3176, 5])
is_molecule_types torch.Size([1, 3176, 5])
is_molecule_mod torch.Size([1, 3176, 4])
additional_msa_feats torch.Size([1, 1, 3176, 2])
additional_token_feats torch.Size([1, 3176, 33])
templates None
msa None
token_bonds torch.Size([1, 3176, 3176])
atom_ids None
atom_parent_ids torch.Size([1, 25240])
atompair_ids None
template_mask None
msa_mask None
atom_pos torch.Size([1, 25240, 3])
missing_atom_mask torch.Size([1, 25240])
molecule_atom_indices torch.Size([1, 3176])
distogram_atom_indices torch.Size([1, 3176])
atom_indices_for_frame torch.Size([1, 3176, 3])
distance_labels None
resolved_labels torch.Size([1, 25240])
resolution torch.Size([1])
chains torch.Size([1, 2])
filepath ('/cpfs01/projects-HDD/cfff-6f3a36a0cd1e_HDD/public/protein/datasets/AF3/data/pdb_data/merged_mmcifs/j

In [None]:
# Build a deliberately shallow Alphafold3 model (depth 1-2 everywhere) on the GPU
# and run one forward/backward pass on the batched input.
alphafold3 = Alphafold3(
    dim_atom_inputs=3,
    dim_atompair_inputs=5,
    atoms_per_window=27,
    dim_template_feats=44,
    num_dist_bins=38,
    confidence_head_kwargs={"pairformer_depth": 1},
    template_embedder_kwargs={"pairformer_stack_depth": 1},
    msa_module_kwargs={"depth": 1},
    pairformer_stack={"depth": 2},
    diffusion_module_kwargs={
        "atom_encoder_depth": 1,
        "token_transformer_depth": 1,
        "atom_decoder_depth": 1,
    },
).cuda()

# Move every present model input to the GPU; absent (None) inputs are passed through unchanged.
input_data = {
    name: (None if field is None else field.cuda())
    for name, field in batched_atom_input.model_forward_dict().items()
}

loss = alphafold3(**input_data)
loss.backward()

OutOfMemoryError: CUDA out of memory. Tried to allocate 4.81 GiB. GPU 0 has a total capacity of 79.35 GiB of which 3.47 GiB is free. Process 7256 has 75.88 GiB memory in use. Of the allocated memory 70.26 GiB is allocated by PyTorch, and 5.13 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

tensor(8.9960, device='cuda:0', grad_fn=<AddBackward0>)

In [None]:
# from pathlib import Path
# folder= os.path.join("/cpfs01/projects-HDD/cfff-6f3a36a0cd1e_HDD/public/protein/datasets/AF3/data/pdb_data", "test_mmcifs")
# if isinstance(folder, str):
#     folder = Path(folder)
# sampler_pdb_ids = set(sampler.mappings.get_column("pdb_id").to_list())
# files = {
#     os.path.splitext(os.path.basename(filepath.name))[0]: filepath
#     for filepath in folder.glob(os.path.join("**", "*.cif"))
#     if os.path.splitext(os.path.basename(filepath.name))[0] in sampler_pdb_ids
# }
# files

{}

173540


: 