In [2]:
from pathlib import Path
import numpy as np
import pandas as pd
import h5py

import torch
from torch.utils.data import Dataset

In [3]:
from dataset import ProcessedLigandPocketDataset
from pathlib import Path
import torch_geometric.transforms as T

datadir = '../data/docking_results/processed_crossdock_noH_full_temp'

data_transform = None

train_dataset = ProcessedLigandPocketDataset(Path(datadir, 'train.npz'), transform=data_transform)
test_dataset = ProcessedLigandPocketDataset(Path(datadir, 'test.npz'), transform=data_transform)
val_dataset = ProcessedLigandPocketDataset(Path(datadir, 'val.npz'), transform=data_transform)



In [4]:
train_dataset[0].keys()

dict_keys(['names', 'prompt_labels', 'ref_lig_coords', 'ref_lig_one_hot', 'ref_lig_bonds', 'ref_lig_mask', 'num_ref_lig_atoms', 'pocket_coords', 'pocket_one_hot', 'pocket_mask', 'num_pocket_nodes', 'opt_lig_coords', 'opt_lig_one_hot', 'opt_lig_bond', 'opt_lig_mask', 'num_opt_lig_atoms'])

In [5]:
len(train_dataset)

4639

In [7]:
train_dataset[84]['prompt_labels'].size()


torch.Size([11, 3])

In [6]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=8, num_workers=24, collate_fn=train_dataset.collate_fn, shuffle=False, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=8, num_workers=24, collate_fn=val_dataset.collate_fn, shuffle=False,pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=8, num_workers=24, collate_fn=test_dataset.collate_fn, shuffle=False,pin_memory=True)

In [7]:
for batch in train_loader:
    print(batch.keys())
    break

dict_keys(['names', 'prompt_labels', 'ref_lig_coords', 'ref_lig_one_hot', 'ref_lig_bonds', 'ref_lig_mask', 'num_ref_lig_atoms', 'pocket_coords', 'pocket_one_hot', 'pocket_mask', 'num_pocket_nodes', 'opt_lig_coords', 'opt_lig_one_hot', 'opt_lig_bond', 'opt_lig_mask', 'num_opt_lig_atoms'])


In [8]:
for i, batch in enumerate(train_loader):
    if i == 1:  # 0-based index, so this is the second batch
        print(batch['names'])
        break

['2hoz_A_rec_2hoz_pmp_lig_tt_min/2hoz_A_rec_2hoz_pmp_lig_tt_min_0_pocket10.pdb_2hoz_A_rec_2hoz_pmp_lig_tt_min/2hoz_A_rec_2hoz_pmp_lig_tt_min_0.sdf', '2hoz_A_rec_2hoz_pmp_lig_tt_min/2hoz_A_rec_2hoz_pmp_lig_tt_min_0_pocket10.pdb_2hoz_A_rec_2hoz_pmp_lig_tt_min/2hoz_A_rec_2hoz_pmp_lig_tt_min_0.sdf', '4y7g_A_rec_4y7f_48x_lig_tt_min/4y7g_A_rec_4y7f_48x_lig_tt_min_0_pocket10.pdb_4y7g_A_rec_4y7f_48x_lig_tt_min/4y7g_A_rec_4y7f_48x_lig_tt_min_0.sdf', '4y7g_A_rec_4y7f_48x_lig_tt_min/4y7g_A_rec_4y7f_48x_lig_tt_min_0_pocket10.pdb_4y7g_A_rec_4y7f_48x_lig_tt_min/4y7g_A_rec_4y7f_48x_lig_tt_min_0.sdf', '5ugp_A_rec_5vrw_ttp_lig_tt_docked/5ugp_A_rec_5vrw_ttp_lig_tt_docked_4_pocket10.pdb_5ugp_A_rec_5vrw_ttp_lig_tt_docked/5ugp_A_rec_5vrw_ttp_lig_tt_docked_4.sdf', '5ugp_A_rec_5vrw_ttp_lig_tt_docked/5ugp_A_rec_5vrw_ttp_lig_tt_docked_4_pocket10.pdb_5ugp_A_rec_5vrw_ttp_lig_tt_docked/5ugp_A_rec_5vrw_ttp_lig_tt_docked_4.sdf', '5ugp_A_rec_5vrw_ttp_lig_tt_docked/5ugp_A_rec_5vrw_ttp_lig_tt_docked_4_pocket10.pdb_5ug

In [9]:
from constants import dataset_params, FLOAT_TYPE, INT_TYPE
def get_ligand_and_pocket(data,virtual_nodes):
    ref_ligand = {
        'x': data['ref_lig_coords'].to('cuda', FLOAT_TYPE),
        'one_hot': data['ref_lig_one_hot'].to('cuda', FLOAT_TYPE),
        'size': data['num_ref_lig_atoms'].to('cuda', INT_TYPE),
        'mask': data['ref_lig_mask'].to('cuda', INT_TYPE),
    }
    if virtual_nodes:
        ref_ligand['num_virtual_atoms'] = data['num_virtual_atoms'].to('cuda', INT_TYPE)
    
    opt_ligand = {
        'x': data['opt_lig_coords'].to('cuda', FLOAT_TYPE),
        'one_hot': data['opt_lig_one_hot'].to('cuda', FLOAT_TYPE),
        'size': data['num_opt_lig_atoms'].to('cuda', INT_TYPE),
        'mask': data['opt_lig_mask'].to('cuda', INT_TYPE),
    }
    if virtual_nodes:
        opt_ligand['num_virtual_atoms'] = data['num_virtual_atoms'].to('cuda', INT_TYPE)

    pocket = {
        'x': data['pocket_coords'].to('cuda', FLOAT_TYPE),
        'one_hot': data['pocket_one_hot'].to('cuda', FLOAT_TYPE),
        'size': data['num_pocket_nodes'].to('cuda', INT_TYPE),
        'mask': data['pocket_mask'].to('cuda', INT_TYPE)
    }

    atom_num_1 = ref_ligand['one_hot'].shape[0]
    atom_num_2 = pocket['one_hot'].shape[0]
    additional_tensor_1 = torch.tensor([[1, 0]]).repeat(atom_num_1, 1).to('cuda')
    additional_tensor_2 = torch.tensor([[0, 1]]).repeat(atom_num_2, 1).to('cuda')
    ref_ligand['one_hot'] = torch.cat((ref_ligand['one_hot'], additional_tensor_1), dim=1)
    pocket['one_hot'] = torch.cat([pocket['one_hot'],additional_tensor_2],dim =1)

    return ref_ligand, pocket, opt_ligand

In [10]:
ref_ligand, pocket, opt_ligand = get_ligand_and_pocket(batch, virtual_nodes=False)

In [11]:
pocket['size']

tensor([361, 361, 209, 209, 338, 338, 338, 338], device='cuda:0')

In [16]:
opt_ligand['one_hot'].size()

torch.Size([176, 11])

In [13]:
pocket['mask'] = torch.cat([ref_ligand['mask'],pocket['mask']],dim =0)
pocket['x'] = torch.cat([ref_ligand['x'],pocket['x']],dim =0)
pocket['one_hot'] = torch.cat([ref_ligand['one_hot'],pocket['one_hot']],dim =0)
pocket['size'] = ref_ligand['size'] + pocket['size']

In [14]:
xh_lig = torch.cat([opt_ligand['x'], opt_ligand['one_hot']], dim=1)

In [15]:
xh_lig.size()

torch.Size([176, 14])

In [15]:
opt_ligand['size']

tensor([15, 17, 13, 13, 30, 28, 30, 30], device='cuda:0')

In [16]:
def sigma(gamma, target_tensor):
        """Computes sigma given gamma."""
        return inflate_batch_array(torch.sqrt(torch.sigmoid(gamma)),
                                        target_tensor)
    
def inflate_batch_array(array, target):
    """
    Inflates the batch array (array) with only a single axis
    (i.e. shape = (batch_size,), or possibly more empty axes
    (i.e. shape (batch_size, 1, ..., 1)) to match the target shape.
    """
    target_shape = (array.size(0),) + (1,) * (len(target.size()) - 1)
    return array.view(target_shape)
class PositiveLinear(torch.nn.Module):
    """Linear layer with weights forced to be positive."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 weight_init_offset: int = -2):
        super(PositiveLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = torch.nn.Parameter(
            torch.empty((out_features, in_features)))
        if bias:
            self.bias = torch.nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.weight_init_offset = weight_init_offset
        self.reset_parameters()

    def reset_parameters(self) -> None:
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        with torch.no_grad():
            self.weight.add_(self.weight_init_offset)

        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            torch.nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        positive_weight = F.softplus(self.weight)
        return F.linear(input, positive_weight, self.bias)

class GammaNetwork(torch.nn.Module):
    """The gamma network models a monotonic increasing function.
    Construction as in the VDM paper."""
    def __init__(self):
        super().__init__()

        self.l1 = PositiveLinear(1, 1)
        self.l2 = PositiveLinear(1, 1024)
        self.l3 = PositiveLinear(1024, 1)

        self.gamma_0 = torch.nn.Parameter(torch.tensor([-5.]))
        self.gamma_1 = torch.nn.Parameter(torch.tensor([10.]))
        self.show_schedule()

    def show_schedule(self, num_steps=50):
        t = torch.linspace(0, 1, num_steps).view(num_steps, 1)
        gamma = self.forward(t)
        print('Gamma schedule:')
        print(gamma.detach().cpu().numpy().reshape(num_steps))

    def gamma_tilde(self, t):
        l1_t = self.l1(t)
        return l1_t + self.l3(torch.sigmoid(self.l2(l1_t)))

    def forward(self, t):
        zeros, ones = torch.zeros_like(t), torch.ones_like(t)
        # Not super efficient.
        gamma_tilde_0 = self.gamma_tilde(zeros)
        gamma_tilde_1 = self.gamma_tilde(ones)
        gamma_tilde_t = self.gamma_tilde(t)

        # Normalize to [0, 1]
        normalized_gamma = (gamma_tilde_t - gamma_tilde_0) / (
                gamma_tilde_1 - gamma_tilde_0)

        # Rescale to [gamma_0, gamma_1]
        gamma = self.gamma_0 + (self.gamma_1 - self.gamma_0) * normalized_gamma

        return gamma
   
import math
import torch.nn.functional as F
gamma = GammaNetwork()
gamma.to('cuda')
lowest_t = 0
time_step = 1000
t_int = torch.randint(
    lowest_t, time_step + 1, size=(ref_ligand['size'].size(0), 1),
    device=ref_ligand['x'].device).float()
s_int = t_int - 1  # previous timestep
# Masks: important to compute log p(x | z0).
t_is_zero = (t_int == 0).float()
t_is_not_zero = 1 - t_is_zero
# Normalize t to [0, 1]. Note that the negative
# step of s will never be used, since then p(x | z0) is computed.
s = s_int / time_step
t = t_int / time_step
# Compute gamma_s and gamma_t via the network.
gamma_s = inflate_batch_array(gamma(s), opt_ligand['x'])
gamma_t = inflate_batch_array(gamma(t), opt_ligand['x'])
sigma_s = sigma(gamma_s, xh_lig)
lig_mask = opt_ligand["mask"]
sigma_s[lig_mask].size()


Gamma schedule:
[-5.         -4.6940923  -4.3876495  -4.0813856  -3.7749429  -3.4688568
 -3.1627707  -2.8566847  -2.5505986  -2.2445126  -1.9382484  -1.6318059
 -1.3255415  -1.0198119  -0.71336937 -0.4072833  -0.10101891  0.20488882
  0.5109749   0.81706095  1.1235032   1.4294114   1.7358537   2.0415835
  2.3480263   2.6537557   2.9603763   3.266284    3.5729046   3.8786345
  4.184721    4.4908066   4.796715    5.1028004   5.408887    5.7149725
  6.0212374   6.327324    6.6334095   6.9396734   7.2455816   7.551489
  7.8577538   8.16384     8.470104    8.775833    9.082276    9.388006
  9.694096   10.        ]


torch.Size([176, 1])

In [17]:
xh_lig = torch.cat([opt_ligand['x'], opt_ligand['one_hot']], dim=1)
xh_pocket = torch.cat([pocket['x'], pocket['one_hot']], dim=1)
zs_lig, zs_pocket = xh_lig, xh_pocket

shift_net_out_ligand = torch.ones(234,14)
shift_net_out_ligand = shift_net_out_ligand.to(zs_lig.device)

ligand_mask = opt_ligand["mask"]
pocket_mask = pocket["mask"]
zt_lig = zs_lig + shift_net_out_ligand[ligand_mask]*zs_lig*(1/time_step) + sigma_s[ligand_mask]*(1.0/time_step)**(1/2)*torch.randn_like(zs_lig)


In [18]:
def sigma(gamma, target_tensor):
        """Computes sigma given gamma."""
        return inflate_batch_array(torch.sqrt(torch.sigmoid(gamma)),
                                        target_tensor)

def alpha(gamma, target_tensor):
    """Computes alpha given gamma."""
    return inflate_batch_array(torch.sqrt(torch.sigmoid(-gamma)),
                                    target_tensor)
def sample_gaussian(size, device):
        x = torch.randn(size, device=device)
        return x

In [19]:
alpha_t = alpha(gamma_t, xh_lig)
sigma_t = sigma(gamma_t, xh_lig)

# Sample zt ~ Normal(alpha_t x, sigma_t)
eps_lig = sample_gaussian(
    size=(len(lig_mask), 14),
    device=lig_mask.device)

# Sample z_t given x, h for timestep t, from q(z_t | x, h)
z_t_lig = alpha_t[lig_mask] * xh_lig + sigma_t[lig_mask] * eps_lig

In [20]:

from constants import dataset_params
import utils
from utils import AppendVirtualNodes

dataset_info = dataset_params['crossdock_full']
histogram_file = Path(datadir, 'size_distribution.npy')
histogram = np.load(histogram_file).tolist()

lig_type_encoder = dataset_info['atom_encoder']
lig_type_decoder = dataset_info['atom_decoder']
pocket_type_encoder = dataset_info['aa_encoder']
pocket_type_decoder = dataset_info['aa_decoder']

virtual_nodes = False
data_transform = None
max_num_nodes = len(histogram) - 1
if virtual_nodes:
    # symbol = 'virtual'

    symbol = 'Ne'  # visualize as Neon atoms
    lig_type_encoder[symbol] = len(lig_type_encoder)
    data_transform = utils.AppendVirtualNodes(
        max_num_nodes, lig_type_encoder, symbol)
    
    virtual_atom = lig_type_encoder[symbol]
    lig_type_decoder.append(symbol)


    # Update dataset_info dictionary. This is necessary for using the
    # visualization functions.
    dataset_info['atom_encoder'] = lig_type_encoder
    dataset_info['atom_decoder'] = lig_type_decoder

atom_nf = len(lig_type_decoder)
aa_nf = len(pocket_type_decoder)

In [21]:
max_num_nodes

70

In [22]:
lig_type_decoder

['C', 'N', 'O', 'S', 'B', 'Br', 'Cl', 'P', 'I', 'F', 'others']

In [23]:
train_dataset = ProcessedLigandPocketDataset(Path(datadir, 'train.npz'), transform=data_transform)
test_dataset = ProcessedLigandPocketDataset(Path(datadir, 'test.npz'), transform=data_transform)
val_dataset = ProcessedLigandPocketDataset(Path(datadir, 'val.npz'), transform=data_transform)


train_loader = DataLoader(train_dataset, batch_size=8, num_workers=24, collate_fn=train_dataset.collate_fn, shuffle=False, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=8, num_workers=24, collate_fn=val_dataset.collate_fn, shuffle=False,pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=8, num_workers=24, collate_fn=test_dataset.collate_fn, shuffle=False,pin_memory=True)

In [24]:
from equivariant_diffusion.dynamics import EGNNDynamics

x_dims = 3
joint_nf =4

condition_vector = torch.tensor([0, 0, 1], dtype=torch.int, device='cuda')
condition_vector = condition_vector.unsqueeze(0).expand(2, -1)


net_dynamics = EGNNDynamics(
    atom_nf = atom_nf,
    residue_nf = aa_nf,
    n_dims = x_dims,
    joint_nf = joint_nf,
    device='cuda',
    hidden_nf=2,
    act_fn=torch.nn.SiLU(),
    n_layers= 2 ,
    attention= True,
    tanh=True,
    norm_constant=1,
    inv_sublayers=1,
    sin_embedding=False,
    normalization_factor=100,
    aggregation_method= 'sum' ,
    edge_cutoff_ligand=10,
    edge_cutoff_pocket=4,
    edge_cutoff_interaction=4,
    update_pocket_coords= False,
    reflection_equivariant=True,
    edge_embedding_dim=8,
    condition_vector = True
)


In [25]:
histogram_file = Path(datadir, 'size_distribution.npy')
histogram_file


PosixPath('../data/docking_results/processed_crossdock_noH_full_temp/size_distribution.npy')

In [26]:
from equivariant_diffusion.conditional_model import ConditionalDDPM


x_dims = 3
joint_nf =4


cddpm = ConditionalDDPM(
            dynamics = net_dynamics,
            atom_nf = atom_nf,
            residue_nf = aa_nf,
            n_dims = x_dims,
            timesteps= 100,
            noise_schedule = 'polynomial_2',
            noise_precision = 5.0e-4,
            loss_type = 'l2',
            norm_values = [1, 4],
            size_histogram = histogram,
            virtual_node_idx=lig_type_encoder[symbol] if virtual_nodes else None
    )




Entropy of n_nodes: H[N] 8.910039901733398


In [27]:
dynamics = net_dynamics,
atom_nf = atom_nf,
residue_nf = aa_nf,
n_dims = x_dims,
timesteps= 100,
noise_schedule = 'polynomial_2',
noise_precision = 5.0e-4,
loss_type = 'l2',
norm_values = [1, 4],
size_histogram = histogram,
virtual_node_idx=lig_type_encoder[symbol] if virtual_nodes else None

In [28]:
def get_prompts(data):
    # 创建一个张量 [0, 0, 1]
    prompts = torch.tensor(data['prompt_labels']).to('cuda', INT_TYPE)
    
    return prompts

In [29]:
loss_type = 'l2'

In [30]:
import torch
from tqdm import tqdm
import utils
from utils import AppendVirtualNodes

# 假设你已经定义了模型、优化器和其他超参数
optimizer = torch.optim.Adam(cddpm.parameters(), lr=0.001)  # 选择合适的学习率
num_epochs = 1
device = 'cuda'

for epoch in range(num_epochs):
    cddpm.train()  # 设置模型为训练模式
    cddpm.to(device)
    total_loss = 0
    pbar = tqdm(train_loader, desc=f'Epoch {epoch}', leave=False)

    for batch in pbar:
        # 假设 batch 是一个列表，包含多个样本
        # 将每个样本移动到设备并组合成字典
        # batch={key:batch[key].cuda() for key in batch}

        
        optimizer.zero_grad()  # 清空梯度

        ref_ligand, pocket, opt_ligand = get_ligand_and_pocket(batch,virtual_nodes)
        prompt_labels = get_prompts(batch)
        pocket['mask'] = torch.cat([ref_ligand['mask'],pocket['mask']],dim =0)
        pocket['x'] = torch.cat([ref_ligand['x'],pocket['x']],dim =0)
        pocket['one_hot'] = torch.cat([ref_ligand['one_hot'],pocket['one_hot']],dim =0)
        pocket['size'] = ref_ligand['size'] + pocket['size']


        loss_terms = cddpm(ref_ligand, pocket, opt_ligand, prompt_labels, return_info=False)

        # delta_log_px, error_t_lig, error_t_pocket, SNR_weight, \
        # loss_0_x_ligand, loss_0_x_pocket, loss_0_h, neg_log_const_0, \
        # kl_prior, log_pN, t_int, xh_lig_hat, info = \
        #     cddpm(ref_ligand, pocket, opt_ligand, prompt_labels, return_info=True)
        
        delta_log_px, error_t_lig, error_t_pocket, SNR_weight, \
        loss_0_x_ligand, loss_0_x_pocket, loss_0_h, neg_log_const_0, \
        kl_prior, t_int, xh_lig_hat, info = \
            cddpm(ref_ligand, pocket, opt_ligand, prompt_labels, return_info=True)

        if loss_type == 'l2':
            actual_ligand_size = ref_ligand['size'] - ref_ligand['num_virtual_atoms'] if virtual_nodes else ref_ligand['size']

            # normalize loss_t
            denom_lig = x_dims * actual_ligand_size + \
                        cddpm.atom_nf * ref_ligand['size']
            error_t_lig = error_t_lig / denom_lig
            denom_pocket = (x_dims + cddpm.residue_nf) * pocket['size']
            error_t_pocket = error_t_pocket / denom_pocket
            loss_t = 0.5 * (error_t_lig + error_t_pocket)


            # normalize loss_0
            loss_0_x_ligand = loss_0_x_ligand / (x_dims * actual_ligand_size)
            loss_0_x_pocket = loss_0_x_pocket / (x_dims * pocket['size'])
            loss_0 = loss_0_x_ligand + loss_0_x_pocket + loss_0_h


        nll = loss_t + loss_0 + kl_prior

        # print("loss", nll)
        nll = nll.mean()

        # 反向传播
        nll.backward()

        # 更新参数
        optimizer.step()

        # 累加损失
        total_loss += nll.item()
        pbar.set_postfix(nll_loss=nll.item())  # 更新进度条的后缀信息

    print(f'Epoch {epoch}, Average NLL Loss: {total_loss / len(train_loader)}')

  prompts = torch.tensor(data['prompt_labels']).to('cuda', INT_TYPE)
                                                                          

Epoch 0, Average NLL Loss: 0.5761266875369796




In [31]:
import torch
from tqdm import tqdm
import os

# 假设你已经定义了模型、优化器和其他超参数
optimizer = torch.optim.Adam(cddpm.parameters(), lr=0.001)  # 选择合适的学习率
num_epochs = 1
device = 'cuda'
save_dir = '../checkpoints/cddpm'  # 模型保存的文件夹路径

# 创建保存目录（如果不存在）
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

# 训练循环
for epoch in range(num_epochs):
    # 设置模型为训练模式
    cddpm.train()
    cddpm.to(device)
    total_loss = 0
    pbar = tqdm(train_loader, desc=f'Epoch {epoch}', leave=False)

    # 训练阶段
    for batch in pbar:
        # batch = {key: batch[key].to(device) for key in batch}

        optimizer.zero_grad()  # 清空梯度

        # 提取配体和口袋数据
        ref_ligand, pocket, opt_ligand = get_ligand_and_pocket(batch,virtual_nodes)
        prompt_labels = get_prompts(batch)

        pocket['mask'] = torch.cat([ref_ligand['mask'],pocket['mask']],dim =0)
        pocket['x'] = torch.cat([ref_ligand['x'],pocket['x']],dim =0)
        pocket['one_hot'] = torch.cat([ref_ligand['one_hot'],pocket['one_hot']],dim =0)
        pocket['size'] = ref_ligand['size'] + pocket['size']

        # 计算损失，返回 (nll, info)
        delta_log_px, error_t_lig, error_t_pocket, SNR_weight, \
        loss_0_x_ligand, loss_0_x_pocket, loss_0_h, neg_log_const_0, \
        kl_prior, t_int, xh_lig_hat, info = cddpm(ref_ligand, pocket, opt_ligand, prompt_labels, return_info=True)

        if loss_type == 'l2':
            actual_ligand_size = ref_ligand['size'] - ref_ligand['num_virtual_atoms'] if virtual_nodes else ref_ligand['size']

            # normalize loss_t
            denom_lig = x_dims * actual_ligand_size + \
                        cddpm.atom_nf * ref_ligand['size']
            error_t_lig = error_t_lig / denom_lig
            denom_pocket = (x_dims + cddpm.residue_nf) * pocket['size']
            error_t_pocket = error_t_pocket / denom_pocket
            loss_t = 0.5 * (error_t_lig + error_t_pocket)

            # normalize loss_0
            loss_0_x_ligand = loss_0_x_ligand / (x_dims * actual_ligand_size)
            loss_0_x_pocket = loss_0_x_pocket / (x_dims * pocket['size'])
            loss_0 = loss_0_x_ligand + loss_0_x_pocket + loss_0_h

        nll = loss_t + loss_0 + kl_prior

        # print("loss", nll)
        nll = nll.mean()

        # 反向传播
        nll.backward()

        # 更新参数
        optimizer.step()

        # 累加损失
        total_loss += nll.item()
        pbar.set_postfix(nll_loss=nll.item())  # 更新进度条的后缀信息

    print(f'Epoch {epoch}, Average NLL Loss: {total_loss / len(train_loader)}')

    # 验证阶段
    cddpm.eval()  # 设置模型为评估模式
    val_loss = 0
    with torch.no_grad():  # 不需要计算梯度
        for batch in val_loader:
            # batch = {key: batch[key].to(device) for key in batch}
            
            ref_ligand, pocket, opt_ligand = get_ligand_and_pocket(batch, virtual_nodes)
            prompt_labels = get_prompts(batch)

            pocket['mask'] = torch.cat([ref_ligand['mask'],pocket['mask']],dim =0)
            pocket['x'] = torch.cat([ref_ligand['x'],pocket['x']],dim =0)
            pocket['one_hot'] = torch.cat([ref_ligand['one_hot'],pocket['one_hot']],dim =0)
            pocket['size'] = ref_ligand['size'] + pocket['size']

            delta_log_px, error_t_lig, error_t_pocket, SNR_weight, \
            loss_0_x_ligand, loss_0_x_pocket, loss_0_h, neg_log_const_0, \
            kl_prior, t_int, xh_lig_hat, info = cddpm(ref_ligand, pocket, opt_ligand, prompt_labels, return_info=True)

            if loss_type == 'l2':
                actual_ligand_size = ref_ligand['size'] - ref_ligand['num_virtual_atoms'] if virtual_nodes else ref_ligand['size']

                # normalize loss_t
                denom_lig = x_dims * actual_ligand_size + \
                            cddpm.atom_nf * ref_ligand['size']
                error_t_lig = error_t_lig / denom_lig
                denom_pocket = (x_dims + cddpm.residue_nf) * pocket['size']
                error_t_pocket = error_t_pocket / denom_pocket
                loss_t = 0.5 * (error_t_lig + error_t_pocket)

                # normalize loss_0
                loss_0_x_ligand = loss_0_x_ligand / (x_dims * actual_ligand_size)
                loss_0_x_pocket = loss_0_x_pocket / (x_dims * pocket['size'])
                loss_0 = loss_0_x_ligand + loss_0_x_pocket + loss_0_h

            nll = loss_t + loss_0 + kl_prior
            nll = nll.mean()  # 将 nll 转换为标量
            val_loss += nll.item()

    print(f'Epoch {epoch}, Validation Loss: {val_loss / len(val_loader)}')

    # 每个 epoch 后保存模型
    torch.save(cddpm.state_dict(), os.path.join(save_dir, f'cddpm_epoch_{epoch}.pth'))

# 最终测试阶段
cddpm.eval()  # 设置模型为评估模式
test_loss = 0
with torch.no_grad():  # 不需要计算梯度
    for batch in test_loader:
        # batch = {key: batch[key].to(device) for key in batch}

        ref_ligand, pocket, opt_ligand = get_ligand_and_pocket(batch, virtual_nodes)
        prompt_labels = get_prompts(batch)
        
        pocket['mask'] = torch.cat([ref_ligand['mask'],pocket['mask']],dim =0)
        pocket['x'] = torch.cat([ref_ligand['x'],pocket['x']],dim =0)
        pocket['one_hot'] = torch.cat([ref_ligand['one_hot'],pocket['one_hot']],dim =0)
        pocket['size'] = ref_ligand['size'] + pocket['size']

        delta_log_px, error_t_lig, error_t_pocket, SNR_weight, \
        loss_0_x_ligand, loss_0_x_pocket, loss_0_h, neg_log_const_0, \
        kl_prior, t_int, xh_lig_hat, info = cddpm(ref_ligand, pocket, opt_ligand, prompt_labels,return_info=True)

        if loss_type == 'l2':
            actual_ligand_size = ref_ligand['size'] - ref_ligand['num_virtual_atoms'] if virtual_nodes else ref_ligand['size']

            # normalize loss_t
            denom_lig = x_dims * actual_ligand_size + \
                        cddpm.atom_nf * ref_ligand['size']
            error_t_lig = error_t_lig / denom_lig
            denom_pocket = (x_dims + cddpm.residue_nf) * pocket['size']
            error_t_pocket = error_t_pocket / denom_pocket
            loss_t = 0.5 * (error_t_lig + error_t_pocket)

            # normalize loss_0
            loss_0_x_ligand = loss_0_x_ligand / (x_dims * actual_ligand_size)
            loss_0_x_pocket = loss_0_x_pocket / (x_dims * pocket['size'])
            loss_0 = loss_0_x_ligand + loss_0_x_pocket + loss_0_h

        nll = loss_t + loss_0 + kl_prior
        nll = nll.mean()  # 将 nll 转换为标量
        test_loss += nll.item()

print(f'Test Loss: {test_loss / len(test_loader)}')

# 最终保存模型
torch.save(cddpm.state_dict(), os.path.join(save_dir, 'cddpm_final2.pth'))

Epoch 0:   0%|          | 0/580 [00:00<?, ?it/s]

  prompts = torch.tensor(data['prompt_labels']).to('cuda', INT_TYPE)
                                                                          

Epoch 0, Average NLL Loss: 0.5159810288199063




Epoch 0, Validation Loss: 1.0647622297207515
Test Loss: 1.06418603244755
