In [1]:
!git clone https://github.com/wengong-jin/RefineGNN.git

Cloning into 'RefineGNN'...
remote: Enumerating objects: 119, done.[K
remote: Counting objects: 100% (37/37), done.[K
remote: Compressing objects: 100% (30/30), done.[K
remote: Total 119 (delta 11), reused 28 (delta 7), pack-reused 82 (from 1)[K
Receiving objects: 100% (119/119), 619.29 MiB | 16.10 MiB/s, done.
Resolving deltas: 100% (28/28), done.
Updating files: 100% (46/46), done.


Respository RefineGNN is structured like the following:
$$
RefineGNN/
├─ ckpts/
├─ data/
│  ├─ sabdab_2022_01/
│  │  ├─ test_data.jsonl
│  │  ├─ train_data.jsonl
│  │  ├─ val_data.jsonl
│  ├─ sabdab_2022_02/
├─ structgen/
│  ├─ __init__.py
│  ├─ data.py
│  ├─ hierarchical.py
├─ ab_train.py
$$

# Data processing

In [2]:
import os
import sys
sys.path.append('/content/RefineGNN')
os.chdir("/content/RefineGNN")
# !gunzip "data/sabdab/hcdr3_cluster/test_data.jsonl.gz"
# !gunzip "data/sabdab/hcdr3_cluster/train_data.jsonl.gz"
# !gunzip "data/sabdab/hcdr3_cluster/val_data.jsonl.gz"
!gunzip "data/sabdab_2022_01/test_data.jsonl.gz"
!gunzip "data/sabdab_2022_01/train_data.jsonl.gz"
!gunzip "data/sabdab_2022_01/val_data.jsonl.gz"

This shows how the data is formatted. For each entry in the train/val/test set, each pdb has antibody and antigen information including the masked region of the antibody that corresponds to the CDR loop that is predicted.

In [None]:
import json

file_path = "/content/RefineGNN/data/sabdab_2022_01/val_data.jsonl"

pdb_list = []
antibody_seq_list = []
antibody_cdr_list = []
antibody_coords_list = []
antibody_atypes_list = []
antigen_seq_list = []
antigen_coords_list = []
antigen_atypes_list = []

with open(file_path, 'r') as file:
    for line in file:
        record = json.loads(line)
        pdb_list.append(record.get("pdb"))
        antibody_seq_list.append(record.get("antibody_seq"))
        antibody_cdr_list.append(record.get("antibody_cdr"))
        antibody_coords_list.append(record.get("antibody_coords"))
        antibody_atypes_list.append(record.get("antibody_atypes"))
        antigen_seq_list.append(record.get("antigen_seq"))
        antigen_coords_list.append(record.get("antigen_coords"))
        antigen_atypes_list.append(record.get("antigen_atypes"))

data_format_summary = {
    "pdb": "PDB identifier for the structure",
    "antibody_seq": "Antibody amino acid sequence",
    "antibody_cdr": "Mapping of residues to complementarity-determining regions (CDRs)",
    "antibody_coords": "3D coordinates for each residue in the antibody",
    "antibody_atypes": "Atomic types for the antibody",
    "antigen_seq": "Antigen amino acid sequence",
    "antigen_coords": "3D coordinates for each residue in the antigen",
    "antigen_atypes": "Atomic types for the antigen"
}


data_format_summary, {
    "pdb_sample": pdb_list[:1],
    "antibody_seq_sample": antibody_seq_list[:1],
    "antibody_cdr_sample": antibody_cdr_list[:1],
    "antibody_coords_sample": antibody_coords_list[:1],
    "antibody_atypes_sample": antibody_atypes_list[:1],
    "antigen_seq_sample": antigen_seq_list[:1],
    "antigen_coords_sample": antigen_coords_list[:1],
    "antigen_atypes_sample": antigen_atypes_list[:1]
}


({'pdb': 'PDB identifier for the structure',
  'antibody_seq': 'Antibody amino acid sequence',
  'antibody_cdr': 'Mapping of residues to complementarity-determining regions (CDRs)',
  'antibody_coords': '3D coordinates for each residue in the antibody',
  'antibody_atypes': 'Atomic types for the antibody',
  'antigen_seq': 'Antigen amino acid sequence',
  'antigen_coords': '3D coordinates for each residue in the antigen',
  'antigen_atypes': 'Atomic types for the antigen'},
 {'pdb_sample': ['7f7e'],
  'antibody_seq_sample': ['EVQLLESGGGLVQPGGSLRLSCAASGFTFSSYAMSWVRQAPGKGLEWVSAIVGSGGSTYYADSVKGRFIISRDNSKNTLYLQMNSLRAEDTAVYYCAKSLIYGHYDILTGAYYFDYWGQGTLVTVSSASTKGPSVFPLAPSSKSTSGGTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKKVEP'],
  'antibody_cdr_sample': ['000000000000000000000000011111111000000000000000002222222200000000000000000000000000000000000000333333333333333333330000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000

For the dataset experiments, we truncate the full dataset to include 5%, 10%, 20%, 40%, 75%, and 100% of the original dataset by changing the ``fraction'' variable.

In [None]:
input_folder = "/content/RefineGNN/data/sabdab_2022_01"
output_folder = "/content/RefineGNN/data/sabdab_2022_02_05"
os.makedirs(output_folder, exist_ok=True)
files_to_process = ["train_data.jsonl", "val_data.jsonl", "test_data.jsonl"]

def process_file(input_file, output_file, fraction=0.05):
    """
    Process a JSONL file, reducing its size to a specified fraction.
    """
    with open(input_file, 'r') as infile:
        lines = infile.readlines()

    subset_size = max(1, int(len(lines) * fraction))
    subset = lines[:subset_size]

    with open(output_file, 'w') as outfile:
        outfile.writelines(subset)

for filename in files_to_process:
    input_file = os.path.join(input_folder, filename)
    output_file = os.path.join(output_folder, filename)

    if os.path.exists(input_file):
        process_file(input_file, output_file)
        print(f"Processed {filename} into {output_file}")
    else:
        print(f"File not found: {input_file}")


Processed train_data.jsonl into /content/RefineGNN/data/sabdab_2022_02_05/train_data.jsonl
Processed val_data.jsonl into /content/RefineGNN/data/sabdab_2022_02_05/val_data.jsonl
Processed test_data.jsonl into /content/RefineGNN/data/sabdab_2022_02_05/test_data.jsonl


This finally integrates the data processing so that it fits with the RefineGNN implementation of data.py. We redefine $\texttt{AntibodyDataset}$, $\texttt{StructureLoader}$, and $\texttt{abc}$ functions so that it is compatible with our new data.

In [None]:
from structgen.data_2 import AntibodyDataset2, StructureLoader2, completize

# Initialize dataset and loader
dataset = AntibodyDataset2(jsonl_file="/content/RefineGNN/data/sabdab_2022_01/train_data.jsonl", cdr_type="3", max_len=130)
loader = StructureLoader2(dataset.data, batch_tokens=100)

for batch in loader:
    print(len(batch))
    X_antibody, S_antibody, mask_antibody, X_antigen, S_antigen, mask_antigen = completize_debug(batch)
    print("Antibody coords shape:", X_antibody.shape)
    print("Antigen coords shape:", X_antigen.shape)
    print("Antibody sequence shape:", S_antibody.shape)
    print("Antigen sequence shape:", S_antigen.shape)
    break



1
Antibody coords shape: torch.Size([1, 130, 14, 3])
Antigen coords shape: torch.Size([1, 5, 14, 3])
Antibody sequence shape: torch.Size([1, 130])
Antigen sequence shape: torch.Size([1, 5])


# Running baseline train

In [None]:
# !python ab_train.py --cdr_type 3 --train_path data/sabdab/hcdr3_cluster/train_data.jsonl --val_path data/sabdab/hcdr3_cluster/val_data.jsonl --test_path data/sabdab/hcdr3_cluster/test_data.jsonl
!PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python ab_train_2.py --cdr_type 3 --train_path data/sabdab_2022_01/train_data.jsonl --val_path data/sabdab_2022_01/val_data.jsonl --test_path data/sabdab_2022_01/test_data.jsonl

Namespace(train_path='data/sabdab_2022_01/train_data.jsonl', val_path='data/sabdab_2022_01/val_data.jsonl', test_path='data/sabdab_2022_01/test_data.jsonl', save_dir='ckpts/tmp', load_model=None, cdr_type='3', hidden_size=256, batch_tokens=100, k_neighbors=9, block_size=8, update_freq=1, depth=4, vocab_size=21, num_rbf=16, dropout=0.1, lr=0.001, clip_norm=5.0, epochs=10, seed=7, anneal_rate=0.9, print_iter=50)
Training:4355, Validation:338, Test:351
  scaler = GradScaler()
  with autocast():
  with torch.cuda.amp.autocast():
Please either pass the dim explicitly or simply use torch.linalg.cross.
The default value of dim will change to agree with that of linalg.cross in a future release. (Triggered internally at ../aten/src/ATen/native/Cross.cpp:62.)
  n_2 = F.normalize(torch.cross(u_2, u_1), dim=-1)
  return fn(*args, **kwargs)
  7% 49/683 [00:42<08:19,  1.27it/s][50] Train PPL = 18.761
 14% 99/683 [01:22<07:32,  1.29it/s][100] Train PPL = 16.021
 22% 149/683 [02:03<09:57,  1.12s/it][1

# Loss function test

Changing MSE loss to L1 loss

In [7]:
# Changed all MSE loss to L1 loss
!PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python ab_train_2.py --cdr_type 3 --train_path data/sabdab_2022_01/train_data.jsonl --val_path data/sabdab_2022_01/val_data.jsonl --test_path data/sabdab_2022_01/test_data.jsonl

Namespace(train_path='data/sabdab_2022_01/train_data.jsonl', val_path='data/sabdab_2022_01/val_data.jsonl', test_path='data/sabdab_2022_01/test_data.jsonl', save_dir='ckpts/tmp', load_model=None, cdr_type='3', hidden_size=256, batch_tokens=100, k_neighbors=9, block_size=8, update_freq=1, depth=4, vocab_size=21, num_rbf=16, dropout=0.1, lr=0.001, clip_norm=5.0, epochs=10, seed=7, anneal_rate=0.9, print_iter=50)
Training:4355, Validation:338, Test:351
  scaler = GradScaler()
  with autocast():
  with torch.cuda.amp.autocast():
Please either pass the dim explicitly or simply use torch.linalg.cross.
The default value of dim will change to agree with that of linalg.cross in a future release. (Triggered internally at ../aten/src/ATen/native/Cross.cpp:62.)
  n_2 = F.normalize(torch.cross(u_2, u_1), dim=-1)
  return fn(*args, **kwargs)
  7% 49/683 [00:44<08:26,  1.25it/s][50] Train PPL = 18.210
 14% 99/683 [01:24<07:23,  1.32it/s][100] Train PPL = 15.953
 22% 149/683 [02:05<09:57,  1.12s/it][1

# Inverse probability map tests

In [None]:
# Need torch geometric for diffusion model
!pip install torch-scatter -f https://data.pyg.org/whl/torch-$(python3 -c "import torch; print(torch.__version__)")+cpu.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-$(python3 -c "import torch; print(torch.__version__)")+cpu.html
!pip install torch-cluster -f https://data.pyg.org/whl/torch-$(python3 -c "import torch; print(torch.__version__)")+cpu.html
!pip install torch-spline-conv -f https://data.pyg.org/whl/torch-$(python3 -c "import torch; print(torch.__version__)")+cpu.html
!pip install torch-geometric

Looking in links: https://data.pyg.org/whl/torch-2.5.1+cu121+cpu.html
Collecting torch-scatter
  Downloading torch_scatter-2.1.2.tar.gz (108 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m108.0/108.0 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: torch-scatter
  Building wheel for torch-scatter (setup.py) ... [?25l[?25hdone
  Created wheel for torch-scatter: filename=torch_scatter-2.1.2-cp310-cp310-linux_x86_64.whl size=3671568 sha256=658cc6a5c5d994c64b4cb42d17fc6fc0cdeb01b4bfa18e4fb624e264e373d97b
  Stored in directory: /root/.cache/pip/wheels/92/f1/2b/3b46d54b134259f58c8363568569053248040859b1a145b3ce
Successfully built torch-scatter
Installing collected packages: torch-scatter
Successfully installed torch-scatter-2.1.2
Looking in links: https://data.pyg.org/whl/torch-2.5.1+cu121+cpu.html
Collecting torch-sparse
  Downloading torch_sparse-0.6.18.tar.gz (209 kB)
[2

In [None]:
!PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python ab_train_2.py --cdr_type 3 --train_path data/sabdab_2022_02_05/train_data.jsonl --val_path data/sabdab_2022_02_05/val_data.jsonl --test_path data/sabdab_2022_02_05/test_data.jsonl

Namespace(train_path='data/sabdab_2022_02_05/train_data.jsonl', val_path='data/sabdab_2022_02_05/val_data.jsonl', test_path='data/sabdab_2022_02_05/test_data.jsonl', save_dir='ckpts/tmp', load_model=None, cdr_type='3', hidden_size=256, batch_tokens=100, k_neighbors=9, block_size=8, update_freq=1, depth=4, vocab_size=21, num_rbf=16, dropout=0.1, lr=0.001, clip_norm=5.0, epochs=10, seed=7, anneal_rate=0.9, print_iter=50)
Training:217, Validation:16, Test:17
  scaler = GradScaler()
  with autocast():
  with torch.cuda.amp.autocast():
Please either pass the dim explicitly or simply use torch.linalg.cross.
The default value of dim will change to agree with that of linalg.cross in a future release. (Triggered internally at ../aten/src/ATen/native/Cross.cpp:62.)
  n_2 = F.normalize(torch.cross(u_2, u_1), dim=-1)
  return fn(*args, **kwargs)
predicted_ca shape: torch.Size([5, 32, 3])
antigen_coords shape: torch.Size([5, 5, 14, 3])
Normalized inv_prob_map min: 0.016974281519651413, max: 0.03168

Example for how inverse probability map works with simulated data

In [None]:
import torch

# Dummy inputs to mimic the expected data
B, N, G = 2, 5, 3  # Batch size, Number of residues, Antigen residue count
predicted_ca = torch.rand((B, N, 3))
true_ca = torch.rand((B, N, 3))
antigen_coords = torch.rand((B, G, 3))

def generate_inverse_probability_map(antibody_coords, antigen_coords=None):
    print(f"Input antibody_coords shape: {antibody_coords.shape}")
    if antigen_coords is None:
        B, A, _ = antibody_coords.size()
        G = 10
        placeholder_dist = torch.ones((B, A, G), device=antibody_coords.device)
        inv_prob_map = torch.softmax(-placeholder_dist, dim=-1)
        print(f"Generated placeholder inv_prob_map shape: {inv_prob_map.shape}")
    else:
        print(f"Input antigen_coords shape: {antigen_coords.shape}")
        diff = antibody_coords[:, :, None, :] - antigen_coords[:, None, :, :]
        dist = torch.sqrt(torch.sum(diff ** 2, dim=-1) + 1e-8)
        inv_prob_map = torch.softmax(-dist, dim=-1)
        print(f"Generated inv_prob_map shape: {inv_prob_map.shape}")

    return inv_prob_map

def apply_antigen_constraint(predicted_ca, true_ca, antigen_coords=None):
    print(f"Predicted CA shape: {predicted_ca.shape}")
    print(f"True CA shape: {true_ca.shape}")

    if antigen_coords is None:
        B, N, _ = predicted_ca.size()
        G = 10
        antigen_coords = torch.zeros((B, G, 3), device=predicted_ca.device)
        print(f"Generated placeholder antigen_coords shape: {antigen_coords.shape}")

    inv_prob_map = generate_inverse_probability_map(predicted_ca, antigen_coords)
    pred_diff = predicted_ca[:, :, None, :] - antigen_coords[:, None, :, :]
    pred_dist = torch.sqrt(torch.sum(pred_diff**2, dim=-1) + 1e-8)

    print(f"Predicted pairwise distance shape: {pred_dist.shape}")
    weighted_dist = pred_dist * inv_prob_map
    loss = torch.mean(weighted_dist)
    print(f"Loss value: {loss.item()}")
    return loss

loss = apply_antigen_constraint(predicted_ca, true_ca, antigen_coords)
print(f"Final computed loss: {loss}")

Predicted CA shape: torch.Size([2, 5, 3])
True CA shape: torch.Size([2, 5, 3])
Input antibody_coords shape: torch.Size([2, 5, 3])
Input antigen_coords shape: torch.Size([2, 3, 3])
Generated inv_prob_map shape: torch.Size([2, 5, 3])
Predicted pairwise distance shape: torch.Size([2, 5, 3])
Loss value: 0.20863066613674164
Final computed loss: 0.20863066613674164


In [None]:
# First test with inverse_probability map
!PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python ab_train_2.py --cdr_type 3 --train_path data/sabdab_2022_01/train_data.jsonl --val_path data/sabdab_2022_01/val_data.jsonl --test_path data/sabdab_2022_01/test_data.jsonl

Namespace(train_path='data/sabdab_2022_01/train_data.jsonl', val_path='data/sabdab_2022_01/val_data.jsonl', test_path='data/sabdab_2022_01/test_data.jsonl', save_dir='ckpts/tmp', load_model=None, cdr_type='3', hidden_size=256, batch_tokens=100, k_neighbors=9, block_size=8, update_freq=1, depth=4, vocab_size=21, num_rbf=16, dropout=0.1, lr=0.001, clip_norm=5.0, epochs=10, seed=7, anneal_rate=0.9, print_iter=50)
Training:4355, Validation:338, Test:351
  scaler = GradScaler()
  with autocast():
  with torch.cuda.amp.autocast():
Please either pass the dim explicitly or simply use torch.linalg.cross.
The default value of dim will change to agree with that of linalg.cross in a future release. (Triggered internally at ../aten/src/ATen/native/Cross.cpp:62.)
  n_2 = F.normalize(torch.cross(u_2, u_1), dim=-1)
  return fn(*args, **kwargs)
  7% 49/683 [00:46<08:31,  1.24it/s][50] Train PPL = 17.560
 14% 99/683 [01:26<07:21,  1.32it/s][100] Train PPL = 16.003
 22% 149/683 [02:08<10:17,  1.16s/it][1

# Testing different sizes of dataset

Dataset experiment with 0.05, 0.1, 0.2, 0.4, 0.75, 1

In [None]:
# Dataset experiment - increasing dataset sizes
!PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python ab_train_2.py --cdr_type 3 --train_path data/sabdab_2022_02_05/train_data.jsonl --val_path data/sabdab_2022_02_05/val_data.jsonl --test_path data/sabdab_2022_02_05/test_data.jsonl

Namespace(train_path='data/sabdab_2022_02_05/train_data.jsonl', val_path='data/sabdab_2022_02_05/val_data.jsonl', test_path='data/sabdab_2022_02_05/test_data.jsonl', save_dir='ckpts/tmp', load_model=None, cdr_type='3', hidden_size=256, batch_tokens=100, k_neighbors=9, block_size=8, update_freq=1, depth=4, vocab_size=21, num_rbf=16, dropout=0.1, lr=0.001, clip_norm=5.0, epochs=10, seed=7, anneal_rate=0.9, print_iter=50)
Training:217, Validation:16, Test:17
  scaler = GradScaler()
  with autocast():
  with torch.cuda.amp.autocast():
Please either pass the dim explicitly or simply use torch.linalg.cross.
The default value of dim will change to agree with that of linalg.cross in a future release. (Triggered internally at ../aten/src/ATen/native/Cross.cpp:62.)
  n_2 = F.normalize(torch.cross(u_2, u_1), dim=-1)
  return fn(*args, **kwargs)
100% 36/36 [00:39<00:00,  1.10s/it]
100% 4/4 [00:03<00:00,  1.03it/s]
Epoch 0, Val PPL = 16.428, Val RMSD = 8.142
100% 36/36 [00:39<00:00,  1.10s/it]
100%

In [None]:
!PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python ab_train_2.py --cdr_type 3 --train_path data/sabdab_2022_02_10/train_data.jsonl --val_path data/sabdab_2022_02_10/val_data.jsonl --test_path data/sabdab_2022_02_10/test_data.jsonl

Namespace(train_path='data/sabdab_2022_02_10/train_data.jsonl', val_path='data/sabdab_2022_02_10/val_data.jsonl', test_path='data/sabdab_2022_02_10/test_data.jsonl', save_dir='ckpts/tmp', load_model=None, cdr_type='3', hidden_size=256, batch_tokens=100, k_neighbors=9, block_size=8, update_freq=1, depth=4, vocab_size=21, num_rbf=16, dropout=0.1, lr=0.001, clip_norm=5.0, epochs=10, seed=7, anneal_rate=0.9, print_iter=50)
Training:435, Validation:33, Test:35
  scaler = GradScaler()
  with autocast():
  with torch.cuda.amp.autocast():
Please either pass the dim explicitly or simply use torch.linalg.cross.
The default value of dim will change to agree with that of linalg.cross in a future release. (Triggered internally at ../aten/src/ATen/native/Cross.cpp:62.)
  n_2 = F.normalize(torch.cross(u_2, u_1), dim=-1)
  return fn(*args, **kwargs)
 70% 49/70 [00:42<00:18,  1.17it/s][50] Train PPL = 19.163
100% 70/70 [01:06<00:00,  1.06it/s]
100% 7/7 [00:08<00:00,  1.17s/it]
Epoch 0, Val PPL = 15.933

In [None]:
!PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python ab_train_2.py --cdr_type 3 --train_path data/sabdab_2022_02_20/train_data.jsonl --val_path data/sabdab_2022_02_20/val_data.jsonl --test_path data/sabdab_2022_02_20/test_data.jsonl

Namespace(train_path='data/sabdab_2022_02_20/train_data.jsonl', val_path='data/sabdab_2022_02_20/val_data.jsonl', test_path='data/sabdab_2022_02_20/test_data.jsonl', save_dir='ckpts/tmp', load_model=None, cdr_type='3', hidden_size=256, batch_tokens=100, k_neighbors=9, block_size=8, update_freq=1, depth=4, vocab_size=21, num_rbf=16, dropout=0.1, lr=0.001, clip_norm=5.0, epochs=10, seed=7, anneal_rate=0.9, print_iter=50)
Training:871, Validation:67, Test:70
  scaler = GradScaler()
  with autocast():
  with torch.cuda.amp.autocast():
Please either pass the dim explicitly or simply use torch.linalg.cross.
The default value of dim will change to agree with that of linalg.cross in a future release. (Triggered internally at ../aten/src/ATen/native/Cross.cpp:62.)
  n_2 = F.normalize(torch.cross(u_2, u_1), dim=-1)
  return fn(*args, **kwargs)
 36% 49/136 [00:41<01:05,  1.32it/s][50] Train PPL = 17.907
 73% 99/136 [01:30<00:34,  1.07it/s][100] Train PPL = 16.147
100% 136/136 [02:04<00:00,  1.10i

In [None]:
!PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python ab_train_2.py --cdr_type 3 --train_path data/sabdab_2022_02_40/train_data.jsonl --val_path data/sabdab_2022_02_40/val_data.jsonl --test_path data/sabdab_2022_02_40/test_data.jsonl

Namespace(train_path='data/sabdab_2022_02_40/train_data.jsonl', val_path='data/sabdab_2022_02_40/val_data.jsonl', test_path='data/sabdab_2022_02_40/test_data.jsonl', save_dir='ckpts/tmp', load_model=None, cdr_type='3', hidden_size=256, batch_tokens=100, k_neighbors=9, block_size=8, update_freq=1, depth=4, vocab_size=21, num_rbf=16, dropout=0.1, lr=0.001, clip_norm=5.0, epochs=10, seed=7, anneal_rate=0.9, print_iter=50)
Training:1742, Validation:135, Test:140
  scaler = GradScaler()
  with autocast():
  with torch.cuda.amp.autocast():
Please either pass the dim explicitly or simply use torch.linalg.cross.
The default value of dim will change to agree with that of linalg.cross in a future release. (Triggered internally at ../aten/src/ATen/native/Cross.cpp:62.)
  n_2 = F.normalize(torch.cross(u_2, u_1), dim=-1)
  return fn(*args, **kwargs)
 18% 49/274 [00:46<03:18,  1.14it/s][50] Train PPL = 18.794
 36% 99/274 [01:26<02:18,  1.27it/s][100] Train PPL = 16.032
 54% 149/274 [02:08<02:29,  1.

In [None]:
!PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python ab_train_2.py --cdr_type 3 --train_path data/sabdab_2022_02_75/train_data.jsonl --val_path data/sabdab_2022_02_75/val_data.jsonl --test_path data/sabdab_2022_02_75/test_data.jsonl

Namespace(train_path='data/sabdab_2022_02_75/train_data.jsonl', val_path='data/sabdab_2022_02_75/val_data.jsonl', test_path='data/sabdab_2022_02_75/test_data.jsonl', save_dir='ckpts/tmp', load_model=None, cdr_type='3', hidden_size=256, batch_tokens=100, k_neighbors=9, block_size=8, update_freq=1, depth=4, vocab_size=21, num_rbf=16, dropout=0.1, lr=0.001, clip_norm=5.0, epochs=10, seed=7, anneal_rate=0.9, print_iter=50)
Training:3266, Validation:253, Test:263
  scaler = GradScaler()
  with autocast():
  with torch.cuda.amp.autocast():
Please either pass the dim explicitly or simply use torch.linalg.cross.
The default value of dim will change to agree with that of linalg.cross in a future release. (Triggered internally at ../aten/src/ATen/native/Cross.cpp:62.)
  n_2 = F.normalize(torch.cross(u_2, u_1), dim=-1)
  return fn(*args, **kwargs)
 10% 49/513 [00:40<06:22,  1.21it/s][50] Train PPL = 19.705
 19% 99/513 [01:23<07:43,  1.12s/it][100] Train PPL = 16.186
 29% 149/513 [02:03<04:38,  1.

# Code reference

Since Colab deletes the saved variables and files after the session times out, the new hierarchical.py, data.py, and ab_train.py are below.

In [None]:
# data_2.py

import torch
from torch.utils.data import Dataset
import numpy as np
import json
import random

alphabet = '#ACDEFGHIKLMNPQRSTVWY'  # Amino acid alphabet
DUMMY = {
    'pdb': None,
    'antibody_seq': '#' * 10,
    'antigen_seq': '#' * 10,
    'antibody_coords': np.zeros((10, 3)) + np.nan,
    'antigen_coords': np.zeros((10, 3)) + np.nan,
    'antibody_cdr': '#' * 10,
    'antibody_atypes': [0] * 10,
    'antigen_atypes': [0] * 10,
}

class AntibodyDataset2:
    def __init__(self, jsonl_file, cdr_type='3', max_len=130):
        self.data = []
        with open(jsonl_file) as f:
            lines = f.readlines()
            for line in lines:
                entry = json.loads(line)

                # Skip entries without antibody CDRs
                if entry['antibody_cdr'] is None or cdr_type not in entry['antibody_cdr']:
                    continue

                # Truncate antibody information based on CDR location and max_len
                last_cdr = entry['antibody_cdr'].rindex(cdr_type)
                if last_cdr >= max_len - 1:
                    entry['antibody_seq'] = entry['antibody_seq'][last_cdr - max_len + 10 : last_cdr + 10]
                    entry['antibody_cdr'] = entry['antibody_cdr'][last_cdr - max_len + 10 : last_cdr + 10]
                    entry['antibody_coords'] = entry['antibody_coords'][last_cdr - max_len + 10 : last_cdr + 10]
                    entry['antibody_atypes'] = entry['antibody_atypes'][last_cdr - max_len + 10 : last_cdr + 10]
                else:
                    entry['antibody_seq'] = entry['antibody_seq'][:max_len]
                    entry['antibody_cdr'] = entry['antibody_cdr'][:max_len]
                    entry['antibody_coords'] = entry['antibody_coords'][:max_len]
                    entry['antibody_atypes'] = entry['antibody_atypes'][:max_len]

                # Truncate antigen information
                entry['antigen_seq'] = entry['antigen_seq'][:max_len]
                entry['antigen_coords'] = entry['antigen_coords'][:max_len]
                entry['antigen_atypes'] = entry['antigen_atypes'][:max_len]

                # Append valid entry
                if len(entry['antibody_seq']) > 0 and len(entry['antigen_seq']) > 0:
                    self.data.append(entry)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class StructureLoader2:
    def __init__(self, dataset, batch_tokens, interval_sort=0):
        self.dataset = dataset
        self.size = len(dataset)
        self.lengths = [len(dataset[i]['antibody_seq']) for i in range(self.size)]
        self.batch_tokens = batch_tokens

        if interval_sort > 0:
            cdr_type = str(interval_sort)
            self.lengths = [dataset[i]['antibody_cdr'].count(cdr_type) for i in range(self.size)]
            self.intervals = [
                (dataset[i]['antibody_cdr'].index(cdr_type), dataset[i]['antibody_cdr'].rindex(cdr_type))
                for i in range(self.size)
            ]
            sorted_ix = sorted(range(self.size), key=self.intervals.__getitem__)
        else:
            sorted_ix = np.argsort(self.lengths)

        # Cluster into batches of similar sizes
        clusters, batch = [], []
        for ix in sorted_ix:
            size = self.lengths[ix]
            if size * (len(batch) + 1) <= self.batch_tokens:
                batch.append(ix)
            else:
                clusters.append(batch)
                batch = [ix]
        if len(batch) > 0:
            clusters.append(batch)
        self.clusters = clusters

    def __len__(self):
        return len(self.clusters)

    def __iter__(self):
        np.random.shuffle(self.clusters)
        for b_idx in self.clusters:
            batch = [self.dataset[i] for i in b_idx]
            yield batch

"""
Data format explained in the beginning of the notebook
"""
def completize_data(batch):
    B = len(batch)
    L_antibody = max(len(b['antibody_seq']) for b in batch)  # Max antibody length
    L_antigen = max(len(b['antigen_seq']) for b in batch)    # Max antigen length

    X_antibody = np.zeros([B, L_antibody, 14, 3])
    S_antibody = np.zeros([B, L_antibody], dtype=np.int32)
    mask_antibody = np.zeros([B, L_antibody], dtype=np.float32)

    X_antigen = np.zeros([B, L_antigen, 14, 3])
    S_antigen = np.zeros([B, L_antigen], dtype=np.int32)

    for i, b in enumerate(batch):
        antibody_coords = np.array(b['antibody_coords'])
        X_antibody[i, :antibody_coords.shape[0], :, :] = antibody_coords
        S_antibody[i, :len(b['antibody_seq'])] = [alphabet.index(a) for a in b['antibody_seq']]
        mask_antibody[i, :len(b['antibody_seq'])] = 1.0

        antigen_coords = np.array(b['antigen_coords'])
        X_antigen[i, :antigen_coords.shape[0], :, :] = antigen_coords
        S_antigen[i, :len(b['antigen_seq'])] = [alphabet.index(a) for a in b['antigen_seq']]

    mask_antibody *= np.isfinite(np.sum(X_antibody, axis=(2, 3))).astype(np.float32)
    X_antibody[np.isnan(X_antibody)] = 0.0
    X_antigen[np.isnan(X_antigen)] = 0.0

    X_antibody = torch.from_numpy(X_antibody).float().cuda()
    S_antibody = torch.from_numpy(S_antibody).long().cuda()
    mask_antibody = torch.from_numpy(mask_antibody).float().cuda()

    X_antigen = torch.from_numpy(X_antigen).float().cuda()
    S_antigen = torch.from_numpy(S_antigen).long().cuda()

    return X_antibody, S_antibody, mask_antibody, X_antigen, S_antigen

In [None]:
# hierarchical_2.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from structgen.encoder import MPNEncoder
from structgen.data import alphabet
from structgen.utils import *
from structgen.protein_features import ProteinFeatures
from torch.utils.checkpoint import checkpoint
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GATConv


class HierarchicalEncoder2(nn.Module):

    def __init__(self, args, node_in, edge_in):
        super(HierarchicalEncoder2, self).__init__()
        self.node_in, self.edge_in = node_in, edge_in
        self.W_v = nn.Sequential(
                nn.Linear(self.node_in, args.hidden_size, bias=True),
                Normalize(args.hidden_size)
        )
        self.W_e = nn.Sequential(
                nn.Linear(self.edge_in, args.hidden_size, bias=True),
                Normalize(args.hidden_size)
        )
        self.layers = nn.ModuleList([
                MPNNLayer(args.hidden_size, args.hidden_size * 3, dropout=args.dropout)
                for _ in range(args.depth)
        ])
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def forward(self, V, E, hS, E_idx, mask):
        hS = hS.to(V.dtype)
        mask = mask.to(V.dtype)

        h_v = self.W_v(V)  # [B, N, H]
        h_e = self.W_e(E)  # [B, N, K, H]
        nei_s = gather_nodes(hS, E_idx)  # [B, N, K, H]

        # [B, N, 1] -> [B, N, K, 1] -> [B, N, K]
        vmask = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        h = h_v
        for layer in self.layers:
            nei_v = gather_nodes(h, E_idx)  # [B, N, K, H]
            nei_h = torch.cat([nei_v, nei_s, h_e], dim=-1)
            # h = layer(h, nei_h, mask_attend=vmask)  # [B, N, H]
            h = checkpoint(layer, h, nei_h, vmask)

            h = h * mask.unsqueeze(-1)  # [B, N, H]
        return h


class HierarchicalDecoder2(nn.Module):

    def __init__(self, args):
        super(HierarchicalDecoder2, self).__init__()
        self.cdr_type = args.cdr_type
        self.k_neighbors = args.k_neighbors
        self.block_size = args.block_size
        self.update_freq = args.update_freq
        self.hidden_size = args.hidden_size
        self.pos_embedding = PosEmbedding(16)

        self.features = ProteinFeatures(
                top_k=args.k_neighbors, num_rbf=args.num_rbf,
                features_type='full',
                direction='bidirectional'
        )
        self.node_in, self.edge_in = self.features.feature_dimensions['full']
        self.O_d0 = nn.Linear(args.hidden_size, 12)
        self.O_d = nn.Linear(args.hidden_size, 12)
        self.O_s = nn.Linear(args.hidden_size, args.vocab_size)
        self.W_s = nn.Embedding(args.vocab_size, args.hidden_size)

        self.struct_mpn = HierarchicalEncoder2(args, self.node_in, self.edge_in)
        self.seq_mpn = HierarchicalEncoder2(args, self.node_in, self.edge_in)
        self.init_mpn = HierarchicalEncoder2(args, 16, 32)
        self.rnn = nn.GRU(
                args.hidden_size, args.hidden_size, batch_first=True,
                num_layers=1, bidirectional=True
        )
        self.W_stc = nn.Sequential(
                nn.Linear(args.hidden_size * 2, args.hidden_size),
                nn.ReLU(),
        )
        self.W_seq = nn.Sequential(
                nn.Linear(args.hidden_size * 2, args.hidden_size),
                nn.ReLU(),
        )

        self.ce_loss = nn.CrossEntropyLoss(reduction='none')
        self.huber_loss = nn.SmoothL1Loss(reduction='none')
        # Toggle this to change to L1 loss
        self.mse_loss = nn.MSELoss(reduction='none')
        # self.mse_loss = nn.L1Loss(reduction='none')

        self.diffusion_model = DiffusionModel(hidden_dim=args.hidden_size)

        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def init_struct(self, B, N, K):
        # initial V
        pos = torch.arange(N).cuda()
        V = self.pos_embedding(pos.view(1, N, 1))  # [1, N, 1, 16]
        V = V.squeeze(2).expand(B, -1, -1)  # [B, N, 6]
        # initial E_idx
        pos = pos.unsqueeze(0) - pos.unsqueeze(1)     # [N, N]
        D_idx, E_idx = pos.abs().topk(k=K, dim=-1, largest=False)    # [N, K]
        E_idx = E_idx.unsqueeze(0).expand(B, -1, -1)  # [B, N, K]
        D_idx = D_idx.unsqueeze(0).expand(B, -1, -1)  # [B, N, K]
        # initial E
        E_rbf = self.features._rbf(3 * D_idx)
        E_pos = self.features.embeddings(E_idx)
        E = torch.cat((E_pos, E_rbf), dim=-1)
        return V, E, E_idx

    def init_coords(self, S, mask):
        B, N = S.size(0), S.size(1)
        K = min(self.k_neighbors, N)
        V, E, E_idx = self.init_struct(B, N, K)

        V = V.float()
        E = E.float()
        S = S.float()
        mask = mask.float()

        h = self.init_mpn(V, E, S, E_idx, mask)
        return self.predict_dist(self.O_d0(h))

    # Q: [B, N, H], K, V: [B, M, H]
    def attention(self, Q, context, cmask, W):
        att = torch.bmm(Q, context.transpose(1, 2))  # [B, N, M]
        att = att - 1e6 * (1 - cmask.unsqueeze(1))
        att = F.softmax(att, dim=-1)
        out = torch.bmm(att, context)  # [B, N, M] * [B, M, H]
        out = torch.cat([Q, out], dim=-1)
        return W(out)

    def predict_dist(self, X):
        X = X.view(X.size(0), X.size(1), 4, 3)
        X_ca = X[:, :, 1, :]
        dX = X_ca[:, None, :, :] - X_ca[:, :, None, :]
        D = torch.sum(dX ** 2, dim=-1)
        V = self.features._dihedrals(X)
        AD = self.features._AD_features(X[:,:,1,:])
        return X.detach().clone(), D, V, AD

    def mask_mean(self, X, mask, i):
        # [B, N, 4, 3] -> [B, 1, 4, 3] / [B, 1, 1, 1]
        X = X[:, i:i+self.block_size]
        if X.dim() == 4:
            mask = mask[:, i:i+self.block_size].unsqueeze(-1).unsqueeze(-1)
        else:
            mask = mask[:, i:i+self.block_size].unsqueeze(-1)
        return torch.sum(X * mask, dim=1, keepdims=True) / (mask.sum(dim=1, keepdims=True) + 1e-8)

    def make_X_blocks(self, X, l, r, mask):
        N = X.size(1)
        lblocks = [self.mask_mean(X, mask, i) for i in range(0, l, self.block_size)]
        rblocks = [self.mask_mean(X, mask, i) for i in range(r + 1, N, self.block_size)]
        bX = torch.cat(lblocks + [X[:, l:r+1]] + rblocks, dim=1)
        return bX.detach()

    def make_S_blocks(self, LS, S, RS, l, r, mask):
        N = S.size(1)
        hS = self.W_s(S.long())
        LS = [self.mask_mean(hS, mask, i) for i in range(0, l, self.block_size)]
        RS = [self.mask_mean(hS, mask, i) for i in range(r + 1, N, self.block_size)]
        bS = torch.cat(LS + [hS[:, l:r+1]] + RS, dim=1)
        lmask = [mask[:, i:i+self.block_size].amax(dim=1, keepdims=True) for i in range(0, l, self.block_size)]
        rmask = [mask[:, i:i+self.block_size].amax(dim=1, keepdims=True) for i in range(r + 1, N, self.block_size)]
        bmask = torch.cat(lmask + [mask[:, l:r+1]] + rmask, dim=1)
        return bS, bmask, len(LS), len(RS)

    def get_completion_mask(self, B, N, cdr_range):
        cmask = torch.zeros(B, N).cuda()
        for i, (l,r) in enumerate(cdr_range):
            cmask[i, l:r+1] = 1
        return cmask

    def remove_cdr_coords(self, X, cdr_range):
        X = X.clone()
        for i, (l,r) in enumerate(cdr_range):
            X[i, l:r+1, :, :] = 0
        return X.clone()

    def forward(self, true_X, true_S, true_cdr, mask, antigen_coords=None, antigen_seq=None):
        B, N = mask.size(0), mask.size(1)
        K = min(self.k_neighbors, N)

        # Ensure dtype consistency within mixed precision
        with torch.cuda.amp.autocast():
            cdr_range = [(cdr.index(self.cdr_type), cdr.rindex(self.cdr_type)) for cdr in true_cdr]
            T_min = min([l for l, r in cdr_range])
            T_max = max([r for l, r in cdr_range])
            cmask = self.get_completion_mask(B, N, cdr_range)
            smask = mask.clone()

            # Encode framework
            S = true_S.clone() * (1 - cmask.long())
            hS, _ = self.rnn(self.W_s(S.long()))
            LS, RS = hS[:, :, :self.hidden_size], hS[:, :, self.hidden_size:]
            hS, mask, offset, suffix = self.make_S_blocks(LS, S, RS, T_min, T_max, mask)
            cmask = torch.cat([cmask.new_zeros(B, offset), cmask[:, T_min:T_max+1], cmask.new_zeros(B, suffix)], dim=1)

            # Ground truth
            true_X = self.make_X_blocks(true_X, T_min, T_max, smask)
            true_V = self.features._dihedrals(true_X)
            true_AD = self.features._AD_features(true_X[:, :, 1, :])
            true_D, mask_2D = pairwise_distance(true_X, mask)
            true_D = true_D ** 2

            # Initialize
            sloss = 0.0
            X, D, V, AD = self.init_coords(hS, mask)
            X = X.detach().clone()
            dloss = self.huber_loss(D, true_D)
            vloss = self.mse_loss(V, true_V)
            aloss = self.mse_loss(AD, true_AD)

            if antigen_coords is not None:
                # For just the naive approach:
                # antigen_constraint_loss = self.apply_antigen_constraint(predicted_ca=X[:, :, 1, :], true_ca=true_X[:, :, 1, :], antigen_coords=antigen_coords)

                # For diffusion model
                antigen_constraint_loss = self.apply_antigen_constraint_diffusion(predicted_ca=X[:, :, 1, :], true_ca=true_X[:, :, 1, :], antigen_coords=antigen_coords)
            else:
                antigen_constraint_loss = 0.0

            for t in range(T_min, T_max + 1):
                # Prepare input
                V, E, E_idx = self.features(X, mask)
                hS = self.make_S_blocks(LS, S, RS, T_min, T_max, smask)[0]

                # Predict residue t
                h = self.seq_mpn(V, E, hS, E_idx, mask)
                h = self.attention(h, LS, smask, self.W_seq)
                logits = self.O_s(h[:, offset + t - T_min])
                logits = logits.float()
                snll = self.ce_loss(logits, true_S[:, t].long())
                sloss = sloss + torch.sum(snll * cmask[:, offset + t - T_min])

                # Teacher forcing on S
                S = S.clone()
                S[:, t] = true_S[:, t]

                # Iterative refinement
                if t % self.update_freq == 0:
                    h = self.struct_mpn(V, E, hS, E_idx, mask)
                    h = self.attention(h, LS, smask, self.W_stc)
                    X, D, V, AD = self.predict_dist(self.O_d(h))
                    X = X.detach().clone()
                    dloss = dloss + self.huber_loss(D, true_D)
                    vloss = vloss + self.mse_loss(V, true_V)
                    aloss = aloss + self.mse_loss(AD, true_AD)

            dloss = torch.sum(dloss * mask_2D) / mask_2D.sum()
            vloss = torch.sum(vloss * mask.unsqueeze(-1)) / mask.sum()
            aloss = torch.sum(aloss * mask.unsqueeze(-1)) / mask.sum()
            sloss = sloss.sum() / cmask.sum()
            loss = sloss + dloss + vloss + aloss + 10*antigen_constraint_loss
            return loss, sloss

    def log_prob(self, true_S, true_cdr, mask, antigen_coords=None, antigen_seq=None):
        B, N = mask.size(0), mask.size(1)
        K = min(self.k_neighbors, N)

        cdr_range = [(cdr.index(self.cdr_type), cdr.rindex(self.cdr_type)) for cdr in true_cdr]
        T_min = min([l for l, r in cdr_range])
        T_max = max([r for l, r in cdr_range])
        cmask = self.get_completion_mask(B, N, cdr_range)
        smask = mask.clone()

        # Initialize
        S = true_S.clone() * (1 - cmask.long())
        hS, _ = self.rnn(self.W_s(S))
        LS, RS = hS[:, :, :self.hidden_size], hS[:, :, self.hidden_size:]
        hS, mask, offset, suffix = self.make_S_blocks(LS, S, RS, T_min, T_max, mask)
        cmask = torch.cat([cmask.new_zeros(B, offset), cmask[:, T_min:T_max+1], cmask.new_zeros(B, suffix)], dim=1)

        # Placeholder for antigen-based adjustments (if required)
        if antigen_coords is not None and antigen_seq is not None:
            # Optionally process antigen_coords and antigen_seq here
            pass

        sloss = 0.0
        X = self.init_coords(hS, mask)[0]
        X = X.detach().clone()

        for t in range(T_min, T_max + 1):
            # Prepare input
            V, E, E_idx = self.features(X, mask)
            hS = self.make_S_blocks(LS, S, RS, T_min, T_max, smask)[0]

            # Predict residue t
            h = self.seq_mpn(V, E, hS, E_idx, mask)
            h = self.attention(h, LS, smask, self.W_seq)
            logits = self.O_s(h[:, offset + t - T_min])
            logits = logits.float()
            snll = self.ce_loss(logits, true_S[:, t].long())
            sloss = sloss + snll * cmask[:, offset + t - T_min]

            # Teacher forcing on S
            S = S.clone()
            S[:, t] = true_S[:, t]

            # Iterative refinement
            if t % self.update_freq == 0:
                h = self.struct_mpn(V, E, hS, E_idx, mask)
                h = self.attention(h, LS, smask, self.W_stc)
                X = self.predict_dist(self.O_d(h))[0]
                X = X.detach().clone()

        ppl = sloss / cmask.sum(dim=-1)
        sloss = sloss.sum() / cmask.sum()
        return ReturnType(nll=sloss, ppl=ppl, X=X, X_cdr=X[:, offset:offset+T_max-T_min+1])


    def generate(self, true_S, true_cdr, mask, return_ppl=False):
        B, N = mask.size(0), mask.size(1)
        K = min(self.k_neighbors, N)

        cdr_range = [(cdr.index(self.cdr_type), cdr.rindex(self.cdr_type)) for cdr in true_cdr]
        T_min = min([l for l,r in cdr_range])
        T_max = max([r for l,r in cdr_range])
        cmask = self.get_completion_mask(B, N, cdr_range)
        smask = mask.clone()

        # initialize
        S = true_S.clone() * (1 - cmask.long())
        hS, _ = self.rnn(self.W_s(S))
        LS, RS = hS[:, :, :self.hidden_size], hS[:, :, self.hidden_size:]
        hS, mask, offset, suffix = self.make_S_blocks(LS, S, RS, T_min, T_max, mask)
        cmask = torch.cat([cmask.new_zeros(B, offset), cmask[:, T_min:T_max+1], cmask.new_zeros(B, suffix)], dim=1)

        X = self.init_coords(hS, mask)[0]
        X = X.detach().clone()
        sloss = 0

        for t in range(T_min, T_max + 1):
            # Prepare input
            V, E, E_idx = self.features(X, mask)
            hS = self.make_S_blocks(LS, S, RS, T_min, T_max, smask)[0]

            # Predict residue t
            h = self.seq_mpn(V, E, hS, E_idx, mask)
            h = self.attention(h, LS, smask, self.W_seq)
            logits = self.O_s(h[:, offset + t - T_min])
            prob = F.softmax(logits, dim=-1)  # [B, 20]
            S[:, t] = torch.multinomial(prob, num_samples=1).squeeze(-1)  # [B, 1]
            sloss = sloss + self.ce_loss(logits, S[:, t]) * cmask[:, offset + t - T_min]

            # Iterative refinement
            h = self.struct_mpn(V, E, hS, E_idx, mask)
            h = self.attention(h, LS, smask, self.W_stc)
            X = self.predict_dist(self.O_d(h))[0]
            X = X.detach().clone()

        S = S.tolist()
        S = [''.join([alphabet[S[i][j]] for j in range(cdr_range[i][0], cdr_range[i][1] + 1)]) for i in range(B)]
        ppl = torch.exp(sloss / cmask.sum(dim=-1))
        return (S, ppl, X[:, offset:offset+T_max-T_min+1]) if return_ppl else S

    def generate_inverse_probability_map(self, antibody_coords, antigen_coords=None):
        """
        New function: Generate an inverse probability map based on distances.
        """
        if antigen_coords is not None:
            # Reduce antigen_coords along the third dimension (e.g., average over atoms)
            antigen_coords_reduced = antigen_coords.mean(dim=2)  # Shape becomes [B, 5, 3]
            # Compute pairwise distances
            dX = antibody_coords.unsqueeze(2) - antigen_coords_reduced.unsqueeze(1)  # Shape: [B, 32, 5, 3]
            distances = torch.norm(dX, dim=-1)  # Shape: [B, 32, 5]
            # Apply a Gaussian kernel to derive an inverse probability map
            sigma = 5.0
            inv_prob_map = torch.exp(-distances**2 / (2 * sigma**2))
            inv_prob_map = 1 - inv_prob_map
            return inv_prob_map
        else:
            return None

    def apply_antigen_constraint(self, predicted_ca, true_ca=None, antigen_coords=None):
        """
        New function: Apply antigen-based constraints as a loss term.
        """
        # print(f"predicted_ca shape: {predicted_ca.shape}")
        # if true_ca is not None:
        #     # print(f"true_ca shape: {true_ca.shape}")
        #     print("")
        # if antigen_coords is not None:
        #     # print(f"antigen_coords shape: {antigen_coords.shape}")
        #     print("")

        # Generate inverse probability map
        inv_prob_map = self.generate_inverse_probability_map(predicted_ca, antigen_coords)

        if inv_prob_map is not None:
            # Compute pairwise distances between predicted and true antibody C-alpha coordinates
            dX = predicted_ca - true_ca
            distances = torch.norm(dX, dim=-1)  # [B, N]

            # Apply the inverse probability map as a weight then aggregate distances
            weighted_distances = distances * inv_prob_map.mean(dim=-1)  # [B, N]
            antigen_constraint_loss = weighted_distances.mean()
            # print(f"antigen_constraint_loss: {antigen_constraint_loss}")

            return antigen_constraint_loss
        else:
            # print("inv_prob_map is None")
            return 0.0

    def generate_inverse_probability_map_diffusion(self, antibody_coords, antigen_coords=None):
        """
        New function: Generate an inverse probability map using a diffusion model.
        """
        if antigen_coords is not None:
            # Pool antigen atomic coordinates (e.g., mean along atom dimension)
            pooled_antigen_coords = antigen_coords.mean(dim=2)  # [B, N_antigen, 3]
            combined_coords = torch.cat([antibody_coords, pooled_antigen_coords], dim=1)  # [B, N_combined, 3]

            # Generate initial noisy map (Gaussian noise)
            B, N_combined, _ = combined_coords.size()
            noisy_map = torch.randn(B, N_combined, 3).to(combined_coords.device).view(-1, 3)  # [total_nodes, 3]

            edge_index, batch = self.construct_graph_edges(combined_coords, B)
            timesteps = torch.randint(0, 1000, (B,), device=combined_coords.device)  # [batch_size]
            inv_prob_map = self.diffusion_model(noisy_map, timesteps, edge_index, batch)

            return inv_prob_map.view(B, N_combined, -1)  # Reshape back to [B, N_combined, features]
        else:
            return None

    def apply_antigen_constraint_diffusion(self, predicted_ca, true_ca=None, antigen_coords=None, epoch=0, max_epochs=100):
        if antigen_coords is not None:
            # Generate inverse probability map
            inv_prob_map = self.generate_inverse_probability_map_diffusion(predicted_ca, antigen_coords)
            if inv_prob_map is not None:
                # Normalize probabilities
                inv_prob_map = F.softmax(inv_prob_map, dim=1)
                dX = predicted_ca - true_ca
                distances = torch.norm(dX, dim=-1)  # [B, N]
                # Slice probabilities for antibody nodes
                inv_prob_map_antibody = inv_prob_map[:, :predicted_ca.size(1)]
                weighted_distances = distances * inv_prob_map_antibody.mean(dim=-1)

                # With regularization:
                # regularization_weight = max(0.1, 0.01 * (1 - epoch / max_epochs))
                # antigen_constraint_loss = weighted_distances.mean() + regularization_weight * torch.mean((1 - inv_prob_map_antibody) ** 2)

                # Without regularization:
                antigen_constraint_loss = weighted_distances.mean()
                # print(f" antigen_constraint_loss: {antigen_constraint_loss}")

                return antigen_constraint_loss
            else:
                return 0.0
        else:
            return 0.0

    def construct_graph_edges(self, combined_coords, batch_size):
        """
        New function:
            Construct edges (graph will be a fully connected graph)
            between Ab/Ag and batch indices for the graph
            representation.

        Args:
            combined_coords: Combined antibody and antigen coordinates [B, N_combined, 3].
            batch_size: Number of graphs in the batch.

        Returns:
            edge_index: Edge list [2, total_edges].
            batch: Batch indices [total_nodes].
        """
        B, N, _ = combined_coords.size()
        total_nodes = B * N

        edge_index = []
        batch = []
        for i in range(B):
            nodes = torch.arange(i * N, (i + 1) * N, device=combined_coords.device)
            edges = torch.combinations(nodes, r=2).t()
            edge_index.append(edges)
            batch.extend([i] * N)

        edge_index = torch.cat(edge_index, dim=1)
        batch = torch.tensor(batch, device=combined_coords.device)
        return edge_index, batch

"""
New Diffusion model that learns the inverse probability matrix given antigen data
"""
class DiffusionModel(nn.Module):
    def __init__(self, hidden_dim):
        super(DiffusionModel, self).__init__()
        self.embedding_t = nn.Embedding(1000, hidden_dim)
        self.input_projection = nn.Linear(3, hidden_dim)
        self.gnn1 = GATConv(hidden_dim, hidden_dim)
        self.gnn2 = GATConv(hidden_dim, hidden_dim)
        self.output_projection = nn.Linear(hidden_dim, 3)

    def forward(self, X_t, t, edge_index, batch):
        """
        Process 3D graph data with GNN layers.

        Args:
            X_t: Noisy node features [total_nodes, 3].
            t: Time embeddings [batch_size].
            edge_index: Graph edges [2, total_edges].
            batch: Batch assignments [total_nodes].

        Returns:
            Denoised node features [total_nodes, 3].
        """
        t_emb = self.embedding_t(t)  # [batch_size, hidden_dim]
        t_emb_nodes = t_emb[batch]  # [total_nodes, hidden_dim]
        # Project input features and add time embeddings
        x = self.input_projection(X_t)  # [total_nodes, hidden_dim]
        x = x + t_emb_nodes

        # GNN layers
        x = F.relu(self.gnn1(x, edge_index))  # [total_nodes, hidden_dim]
        x = self.gnn2(x, edge_index)  # [total_nodes, hidden_dim]

        # Project to output space (3D coordinates)
        x = self.output_projection(x)  # [total_nodes, 3]
        return x


In [None]:
# ab_train_2.py

import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader

import json
import csv
import math, random, sys
import numpy as np
import argparse
import os
from torch.cuda.amp import GradScaler, autocast

from structgen import *
from tqdm import tqdm


def evaluate(model, loader, args):
    model.eval()
    val_nll = val_tot = 0.0
    val_rmsd = []
    with torch.no_grad():
        for hbatch in tqdm(loader):
            X_antibody, S_antibody, mask_antibody, X_antigen, S_antigen = completize_data(hbatch)
            antibody_cdr = [b['antibody_cdr'] for b in hbatch]

            for i in range(len(hbatch)):
                L = mask_antibody[i:i+1].sum().long().item()
                if L > 0:
                    out = model.log_prob(
                        S_antibody[i:i+1, :L].long(),
                        [antibody_cdr[i]],
                        mask_antibody[i:i+1, :L],
                        antigen_coords=X_antigen[i:i+1, :],
                        antigen_seq=S_antigen[i:i+1, :]
                    )
                    nll, X_pred = out.nll, out.X_cdr
                    val_nll += nll.item() * antibody_cdr[i].count(args.cdr_type)
                    val_tot += antibody_cdr[i].count(args.cdr_type)
                    l, r = antibody_cdr[i].index(args.cdr_type), antibody_cdr[i].rindex(args.cdr_type)
                    rmsd = compute_rmsd(
                        X_pred[:, :, 1, :],  # predicted alpha carbons
                        X_antibody[i:i+1, l:r+1, 1, :],  # ground truth alpha carbons
                        mask_antibody[i:i+1, l:r+1]
                    )
                    val_rmsd.append(rmsd.item())

    val_ppl = math.exp(val_nll / val_tot) if val_tot > 0 else float('inf')
    avg_rmsd = sum(val_rmsd) / len(val_rmsd) if val_rmsd else float('inf')
    return val_ppl, avg_rmsd

parser = argparse.ArgumentParser()
parser.add_argument('--train_path', default='data/sabdab_2022_01/train_data.jsonl')
parser.add_argument('--val_path', default='data/sabdab_2022_01/val_data.jsonl')
parser.add_argument('--test_path', default='data/sabdab_2022_01/test_data.jsonl')
parser.add_argument('--save_dir', default='ckpts/tmp')
parser.add_argument('--load_model', default=None)

parser.add_argument('--cdr_type', default='3')

parser.add_argument('--hidden_size', type=int, default=256)
parser.add_argument('--batch_tokens', type=int, default=100)
parser.add_argument('--k_neighbors', type=int, default=9)
parser.add_argument('--block_size', type=int, default=8)
parser.add_argument('--update_freq', type=int, default=1)
parser.add_argument('--depth', type=int, default=4)
parser.add_argument('--vocab_size', type=int, default=21)
parser.add_argument('--num_rbf', type=int, default=16)
parser.add_argument('--dropout', type=float, default=0.1)

parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--clip_norm', type=float, default=5.0)
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--seed', type=int, default=7)
parser.add_argument('--anneal_rate', type=float, default=0.9)
parser.add_argument('--print_iter', type=int, default=50)

args = parser.parse_args()
print(args)

os.makedirs(args.save_dir, exist_ok=True)

torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

loaders = []
for path in [args.train_path, args.val_path, args.test_path]:
    data = AntibodyDataset2(path, cdr_type=args.cdr_type)
    loader = StructureLoader2(data.data, batch_tokens=args.batch_tokens, interval_sort=int(args.cdr_type))
    loaders.append(loader)

loader_train, loader_val, loader_test = loaders

model = HierarchicalDecoder2(args).cuda()

optimizer = torch.optim.Adam(model.parameters())
if args.load_model:
    model_ckpt, opt_ckpt, model_args = torch.load(args.load_model)
    model = HierarchicalDecoder2(model_args).cuda()
    optimizer = torch.optim.Adam(model.parameters())
    model.load_state_dict(model_ckpt)
    optimizer.load_state_dict(opt_ckpt)

print('Training:{}, Validation:{}, Test:{}'.format(
    len(loader_train.dataset), len(loader_val.dataset), len(loader_test.dataset))
)

best_ppl, best_epoch = 100, -1

scaler = GradScaler()

for e in range(args.epochs):
    model.train()
    meter = 0

    for i, hbatch in enumerate(tqdm(loader_train)):
        optimizer.zero_grad()
        X_antibody, S_antibody, mask_antibody, X_antigen, S_antigen = completize_data(hbatch)
        antibody_cdr = [b['antibody_cdr'] for b in hbatch]

        # loss, snll = model(X_antibody, S_antibody, antibody_cdr, mask_antibody, antigen_coords=X_antigen, antigen_seq=S_antigen)

        # loss.backward()
        # optimizer.step()

        with autocast():
            loss, snll = model(
                X_antibody.float(), S_antibody.long(), antibody_cdr, mask_antibody,
                antigen_coords=X_antigen.float(), antigen_seq=S_antigen.float()
            )
        # loss.backward()
        # optimizer.step()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()


        meter += snll.exp().item()
        if (i + 1) % args.print_iter == 0:
            meter /= args.print_iter
            print(f'[{i + 1}] Train PPL = {meter:.3f}')
            meter = 0

    val_ppl, val_rmsd = evaluate(model, loader_val, args)
    ckpt = (model.state_dict(), optimizer.state_dict(), args)
    torch.save(ckpt, os.path.join(args.save_dir, f"model.ckpt.{e}"))
    print(f'Epoch {e}, Val PPL = {val_ppl:.3f}, Val RMSD = {val_rmsd:.3f}')

    if val_ppl < best_ppl:
        best_ppl = val_ppl
        best_epoch = e

if best_epoch >= 0:
    best_ckpt = os.path.join(args.save_dir, f"model.ckpt.{best_epoch}")
    model.load_state_dict(torch.load(best_ckpt)[0])

test_ppl, test_rmsd = evaluate(model, loader_test, args)
print(f'Test PPL = {test_ppl:.3f}, Test RMSD = {test_rmsd:.3f}')

In [None]:
# decoder.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from structgen.encoder import MPNEncoder
from structgen.data import alphabet
from structgen.utils import *
from structgen.protein_features import ProteinFeatures


class Decoder(nn.Module):

    def __init__(self, args, return_coords=True):
        super(Decoder, self).__init__()
        self.k_neighbors = args.k_neighbors
        self.depth = args.depth
        self.hidden_size = args.hidden_size
        self.augment_eps = args.augment_eps
        self.context = args.context
        self.return_coords = return_coords

        self.pos_embedding = PosEmbedding(16)
        self.features = ProteinFeatures(
                top_k=args.k_neighbors, num_rbf=args.num_rbf,
                features_type='dist',
                direction='forward'
        )
        self.node_in, self.edge_in = self.features.feature_dimensions['dist']
        self.O_nei = nn.Sequential(
                nn.Linear(args.hidden_size * 2 + 16, args.hidden_size),
                nn.ReLU(),
                nn.Linear(args.hidden_size, 1),
        )
        self.O_dist = nn.Sequential(
                nn.Linear(args.hidden_size * 2 + 16, args.hidden_size),
                nn.ReLU(),
                nn.Linear(args.hidden_size, 1),
        )
        self.O_s = nn.Linear(args.hidden_size, args.vocab_size)
        self.O_v = nn.Linear(args.hidden_size, self.node_in)
        self.O_e = nn.Linear(args.hidden_size, self.edge_in - self.features.num_positional_embeddings)

        self.struct_mpn = MPNEncoder(args, self.node_in, self.edge_in)
        self.seq_mpn = MPNEncoder(args, self.node_in, self.edge_in)

        if args.context:
            self.W_stc = nn.Sequential(
                    nn.Linear(args.hidden_size * 2, args.hidden_size),
                    nn.ReLU(),
            )
            self.W_seq = nn.Sequential(
                    nn.Linear(args.hidden_size * 2, args.hidden_size),
                    nn.ReLU(),
            )
            self.crnn = nn.GRU(
                    len(alphabet), args.hidden_size,
                    batch_first=True, num_layers=1,
                    dropout=args.dropout
            )

        self.ce_loss = nn.CrossEntropyLoss(reduction='none')
        self.bce_loss = nn.BCEWithLogitsLoss(reduction='none')
        # Change this to toggle L1 or MSE Loss
        self.mse_loss = nn.MSELoss(reduction='none')
        # self.mse_loss = nn.L1Loss(reduction='none')

        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    # Q: [B, N, H], K, V: [B, M, H]
    def attention(self, Q, context, W):
        context, cmask = context  # cmask: [B, M]
        att = torch.bmm(Q, context.transpose(1, 2))  # [B, N, M]
        att = att - 1e6 * (1 - cmask.unsqueeze(1))
        att = F.softmax(att, dim=-1)
        out = torch.bmm(att, context)  # [B, N, M] * [B, M, H]
        out = torch.cat([Q, out], dim=-1)
        return W(out)

    def encode_context(self, context):
        cS, cmask, crange = context
        cS = F.one_hot(cS, num_classes=len(alphabet)).float()
        cH, _ = self.crnn(cS)
        return (cH, cmask)

    def forward(self, X, S, L, mask, context=None, debug=False):
        # X: [B, N, 4, 3], S: [B, N], mask: [B, N]
        true_V, _, _ = self.features(X, mask)
        N, K = S.size(1), self.k_neighbors

        # data augmentation
        V, E, E_idx = self.features(
                X + self.augment_eps * torch.randn_like(X),
                mask
        )

        # run struct MPN
        h = self.struct_mpn(V, E, S, E_idx, mask)
        if self.context:
            context = self.encode_context(context)
            h = self.attention(h, context, self.W_stc)

        # predict node feature with h_{v-1}
        vout = self.O_v(h[:, :-1])
        vloss = self.mse_loss(vout, true_V[:, 1:]).mean(dim=-1)
        vloss = torch.sum(vloss * mask[:, 1:]) / mask[:, 1:].sum()

        # predict neighbors with h_{v-1}, h_u, E_pos
        E_next, nlabel, dlabel, nmask = get_nei_label(X, mask, K)  # [B, N-1, N]
        h_cur = h[:, :-1].unsqueeze(2).expand(-1,-1,N,-1)  # [B, N-1, N, H]
        h_pre = gather_nodes(h, E_next)  # [B, N-1, N, H]
        pos = torch.arange(1, N).cuda().view(1, -1, 1) - E_next  # [B, N-1, N]
        E_pos = self.pos_embedding(pos)  # [B, N-1, N, H]
        h_nei = torch.cat([h_cur, h_pre, E_pos], dim=-1)
        nout = self.O_nei(h_nei).squeeze(-1)  # [B, N-1, N]
        nloss = self.bce_loss(nout, nlabel.float())
        nloss = torch.sum(nloss * nmask) / nmask.sum()

        # predict neighbors distance
        dout = self.O_dist(h_nei).squeeze(-1)  # [B, N-1, N]
        dout = dout[:, :, :K]  # [B, N-1, K]
        dmask = nmask[:, :, :K]  # [B, N-1, K]
        dlabel = dlabel.clamp(max=20)
        dlabel = (dlabel[:, :, :K] - 10) / 10  # D in [0, 20]
        dloss = self.mse_loss(dout, dlabel)
        dloss = torch.sum(dloss * dmask) / dmask.sum()

        # sequence prediction
        h = self.seq_mpn(V, E, S, E_idx, mask)
        if self.context:
            h = self.attention(h, context, self.W_seq)

        sout = self.O_s(h)
        sloss = self.ce_loss(sout.view(-1, sout.size(-1)), S.view(-1))
        sloss = torch.sum(sloss * mask.view(-1)) / mask.sum()

        loss = sloss + nloss + vloss + dloss
        dout = dout * 10 + 10
        return (sout, vout, nout, dout) if debug else loss

    def expand_one_residue(self, h, V, E, E_idx, t):
        # predict node feature for t+1
        B, K = len(h), self.k_neighbors
        V[:, t+1] = self.O_v(h[:, t])

        # predict neighbors for t+1
        h_cur = h[:, t:t+1].expand(-1, t+1, -1)  # [B, t+1, H]
        h_pre = h[:, :t+1]  # [B, t+1, H]
        pos = t + 1 - torch.arange(t + 1).view(1, -1, 1).expand(B, -1, -1)  # [B, t+1, 1]
        E_pos = self.pos_embedding(pos.cuda()).squeeze(2)  # [B, t+1, H]
        h_nei = torch.cat([h_cur, h_pre, E_pos], dim=-1)
        nout = self.O_nei(h_nei).squeeze(-1)  # [B, t+1]

        if K <= t + 1:
            _, E_idx[:, t+1] = nout.topk(dim=-1, k=K, largest=True)
            nei_topk = E_idx[:, t+1]  # [B, K]
        else:
            E_idx[:, t+1, :t+1] *= 0
            E_idx[:, t+1, :t+1] += torch.arange(t, -1, -1).view(1,-1).cuda()
            nei_topk = E_idx[:, t+1, :t+1]  # [B, t+1]

        # predict neighbors distance
        # Positional encoding is relative!
        dout = self.O_dist(h_nei).squeeze(-1)  # [B, t+1]
        dout = dout * 10 + 10
        dout = gather_2d(dout, nei_topk)  # [B, t+1]
        rbf_vecs = self.features._rbf(dout.unsqueeze(1))  # [B, 1, t+1, H]
        pos_vecs = self.pos_embedding(nei_topk.unsqueeze(1) - t - 1)  # [B, 1, t+1] => [B, 1, t+1, H]
        E[:, t+1, :t+1] = torch.cat([pos_vecs, rbf_vecs], dim=-1).squeeze(1)  # [B, t+1, H]
        return nout, dout

    def log_prob(self, S, mask, context=None, debug=None):
        B, N = S.size(0), S.size(1)
        K = self.k_neighbors

        V = torch.zeros(B, N+1, self.node_in).cuda()
        V[:, :, :self.node_in // 2] = 1.  # cos(0) = 1
        E = torch.zeros(B, N+1, K, self.edge_in).cuda()
        E_idx = torch.zeros(B, N+1, K).long().cuda() + N - 1
        h_stc = [torch.zeros(B, N, self.hidden_size, requires_grad=True).cuda() for _ in range(self.depth + 1)]
        h_seq = [torch.zeros(B, N, self.hidden_size, requires_grad=True).cuda() for _ in range(self.depth + 1)]

        D = torch.zeros(B, N+1, K).cuda()
        log_prob = []
        if self.context:
            context = self.encode_context(context)

        for t in range(N):
            # run MPN
            h_seq = self.seq_mpn.inc_forward(V, E, S, E_idx, mask, h_seq, t)
            h_stc = self.struct_mpn.inc_forward(V, E, S, E_idx, mask, h_stc, t)

            h = h_seq[-1][:, t:t+1]
            if self.context:
                h = self.attention(h, context, self.W_seq)

            # predict residue for t
            logits = self.O_s(h.squeeze(1))
            lprob = F.log_softmax(logits, dim=-1)
            nll = F.nll_loss(lprob, S[:, t], reduction='none')
            log_prob.append(nll)

            # predict position for t + 1
            h = self.attention(h_stc[-1], context, self.W_stc) if self.context else h_stc[-1]
            V, E, E_idx = V.clone(), E.clone(), E_idx.clone()  # avoid inplace autograd error
            nout, dout = self.expand_one_residue(h, V, E, E_idx, t)
            V, E, E_idx = V.clone(), E.clone(), E_idx.clone()  # avoid inplace autograd error
            D[:, t+1, :dout.size(-1)] = dout

            if debug and t < N - 1:
                self.debug_decode(debug, logits, V, E, E_idx, mask, nout, dout, t)

        log_prob = torch.stack(log_prob, dim=1)  # [B, N]
        ppl = torch.sum(log_prob * mask, dim=-1) / mask.sum(dim=-1)
        log_prob = torch.sum(log_prob * mask) / mask.sum()
        if self.return_coords:
            X = fit_coords(D[:, :-1, :].detach(), E_idx[:, :-1, :].detach(), mask)
            X = X.unsqueeze(2).expand(-1,-1,4,-1)
            return ReturnType(nll=log_prob, ppl=ppl, X_cdr=X)
        else:
            return ReturnType(nll=log_prob, ppl=ppl, X_cdr=None)

    def generate(self, B, N, context=None, return_ppl=False):
        K = self.k_neighbors
        S = torch.zeros(B, N).long().cuda()
        mask = torch.ones(B, N).cuda()

        V = torch.zeros(B, N+1, self.node_in).cuda()
        V[:, :, :self.node_in // 2] = 1.  # cos(0) = 1
        E = torch.zeros(B, N+1, K, self.edge_in).cuda()
        E_idx = torch.zeros(B, N+1, K).long().cuda() + N - 1
        h_stc = [torch.zeros(B, N, self.hidden_size).cuda() for _ in range(self.depth + 1)]
        h_seq = [torch.zeros(B, N, self.hidden_size).cuda() for _ in range(self.depth + 1)]

        if self.context:
            context = self.encode_context(context)

        sloss = 0.
        for t in range(N):
            # run MPN
            h_seq = self.seq_mpn.inc_forward(V, E, S, E_idx, mask, h_seq, t)
            h_stc = self.struct_mpn.inc_forward(V, E, S, E_idx, mask, h_stc, t)

            h = h_seq[-1][:, t:t+1]
            if self.context:
                h = self.attention(h, context, self.W_seq)

            # predict residue for t
            logits = self.O_s(h.squeeze(1))
            prob = F.softmax(logits, dim=-1)  # [B, 20]
            S[:, t] = torch.multinomial(prob, num_samples=1).squeeze(-1)  # [B, 1]
            sloss = sloss + self.ce_loss(logits, S[:, t])

            # predict position for t + 1
            h = self.attention(h_stc[-1], context, self.W_stc) if self.context else h_stc[-1]
            nout, dout = self.expand_one_residue(h, V, E, E_idx, t)

        S = S.tolist()
        S = [''.join([alphabet[S[i][j]] for j in range(N)]) for i in range(B)]
        ppl = torch.exp(sloss / N)
        return (S, ppl) if return_ppl else S

    def debug_decode(self, debug_info, logits, V, E, E_idx, mask, nout, dout, t):
        X, L, true_logits, true_vout, true_nout, true_dout = debug_info[:7]
        true_V, true_E, true_E_idx = self.features(X, mask)

        print(t)
        ll = min(t + 1, self.k_neighbors)
        print('-------S-------')
        print(logits - true_logits[:, t])
        print('-------N-------')
        print(E_idx[:, t+1])
        print(true_E_idx[:, t+1])
        print(nout[:, :ll].sum() - true_nout[:, t, :ll].sum())
        print('-------V-------')
        print(V[:, t+1] - true_vout[:, t])
        print('-------E-------')
        print(dout[:, :ll].sum() - true_dout[:, t, :ll].sum())
        #print(E[:, t+1] - true_E[:, t+1])
        print('---------------')

        V[:, t+1] = true_V[:, t+1]
        E[:, t+1] = true_E[:, t+1]
        E_idx[:, t+1] = true_E_idx[:, t+1]
        input("Press Enter to continue...")
