In [1]:
import os
import csv
import time
import json
import torch
import random
import argparse
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from torch.utils.data import DataLoader

from model import DDG_RDE_Network, Codebook
from trainer import CrossValidation, recursive_to
from utils import set_seed, check_dir, eval_skempi_three_modes, save_code, load_config
from dataset import SkempiDataset
from torch.utils.data import random_split
from ipdb import set_trace

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
dataset = SkempiDataset(
            csv_path="/home/hongtan/pretrain_single/data/dataset/S2648/s2648.csv",
            pdb_dir="/home/hongtan/pretrain_single/data/dataset/s2648_s669/pdb",
            cache_dir="/home/hongtan/pretrain_single/data/dataset/S2648/cache",
            patch_size=128
        )

  return torch.load(io.BytesIO(b))


In [4]:
s543_dataset = SkempiDataset(
            csv_path="/home/hongtan/pretrain_single/data/dataset/S543/s543.csv",
            pdb_dir="/home/hongtan/pretrain_single/data/dataset/S543/s543_pdb",
            cache_dir="/home/hongtan/pretrain_single/data/dataset/S543/cache",
            patch_size=128
        )

Structures: 100%|██████████| 55/55 [01:32<00:00,  1.69s/it]


In [5]:
param = json.loads(open("/home/hongtan/pretrain_single/config/param_configs.json", 'r').read())
args = argparse.Namespace(**param)

In [6]:
model = DDG_RDE_Network(args)

In [7]:
import math

import torch
from torch.utils.data import DataLoader
from torch.utils.data._utils.collate import default_collate

DEFAULT_PAD_VALUES = {
    'aa': 21, 
    'chain_nb': -1, 
    'chain_id': ' ', 
}
class PaddingCollate(object):
    def __init__(self, length_ref_key='aa', pad_values=DEFAULT_PAD_VALUES):
        super().__init__()
        self.length_ref_key = length_ref_key
        self.pad_values = pad_values

    @staticmethod
    def _pad_last(x, n, value=0):
        if isinstance(x, torch.Tensor):
            assert x.size(0) <= n
            if x.size(0) == n:
                return x
            pad_size = [n - x.size(0)] + list(x.shape[1:])
            pad = torch.full(pad_size, fill_value=value).to(x)
            return torch.cat([x, pad], dim=0)
        elif isinstance(x, list):
            pad = [value] * (n - len(x))
            return x + pad
        else:
            return x

    @staticmethod
    def _get_common_keys(list_of_dict):
        keys = set(list_of_dict[0].keys())
        for d in list_of_dict[1:]:
            keys = keys.intersection(d.keys())
        return keys

    def _get_pad_value(self, key):
        return self.pad_values.get(key, 0)

    def __call__(self, data_list):
        # 过滤掉 None 的数据
        data_list = [data for data in data_list if data is not None]

        # 如果所有数据都被过滤掉，返回 None
        if not data_list:
            print("Warning: All data in batch are None.")
            return None

        try:
            max_length = max([data[self.length_ref_key].size(0) for data in data_list])
        except KeyError:
            print(f"KeyError: Key '{self.length_ref_key}' not found in one of the data items.")
            return None
        except ValueError:
            print("ValueError: One of the data items is empty.")
            return None

        max_length = math.ceil(max_length / 8) * 8

        exclude_keys = {'esm_embeddings_wt','esm_embeddings_mut'}
        data_list_padded = []
        for data in data_list:
            if data is None:
                continue  # 再次检查并跳过 None 数据

            data_padded = {}
            for k, v in data.items():
                if k in exclude_keys:
                    data_padded[k] = v  # 保持原样，不填充
                else:
                    data_padded[k] = self._pad_last(v, max_length, value=self._get_pad_value(k))
            data_list_padded.append(data_padded)

        # 如果经过填充的数据仍然是空的，返回 None
        if not data_list_padded:
            print("Warning: No valid data to collate after padding.")
            return None

        # 过滤掉任何仍包含 None 的数据
        data_list_padded = [d for d in data_list_padded if all(v is not None for v in d.values())]

        if not data_list_padded:
            print("Warning: All data in batch are invalid after filtering.")
            return None

        batch = default_collate(data_list_padded)
        return batch

In [8]:
dataset_size = len(dataset)
train_size = int(0.9 * dataset_size)  # 80% 用于训练集
val_size = int(0.1 * dataset_size)    # 10% 用于验证集
test_size = dataset_size - train_size - val_size  # 剩下的用于测试集

# 使用 random_split 来划分数据集
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])
test_dataset = s543_dataset
# 创建 PyGDataLoader 对象
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,collate_fn=PaddingCollate(),num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=True,collate_fn=PaddingCollate(),num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True,collate_fn=PaddingCollate(),num_workers=4)

In [9]:
vae_model = Codebook(args).to(device)
vae_model.load_state_dict(torch.load("/home/hongtan/pretrain_single/weight/checkpoint/vvvvae_model.ckpt", map_location=device))

  vae_model.load_state_dict(torch.load("/home/hongtan/pretrain_single/weight/checkpoint/vvvvae_model.ckpt", map_location=device))


<All keys matched successfully>

In [10]:
import torch
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr
from tqdm.auto import tqdm

# RMSE: Mean Squared Error (MSE), then square root
def compute_rmse(true_vals, pred_vals):
    mse = mean_squared_error(true_vals, pred_vals)
    return np.sqrt(mse)

# PCC: Pearson correlation coefficient
def compute_pcc(true_vals, pred_vals):
    return pearsonr(true_vals, pred_vals)[0]

# R²: Coefficient of determination
def compute_r2(true_vals, pred_vals):
    return r2_score(true_vals, pred_vals)

In [11]:
import torch
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr
from tqdm.auto import tqdm

# RMSE: Root Mean Squared Error
def compute_rmse(true_vals, pred_vals):
    mse = mean_squared_error(true_vals, pred_vals)
    return np.sqrt(mse)

# PCC: Pearson correlation coefficient
def compute_pcc(true_vals, pred_vals):
    return pearsonr(true_vals, pred_vals)[0]

# R²: Coefficient of determination (R-squared)
def compute_r2(true_vals, pred_vals):
    return r2_score(true_vals, pred_vals)

import torch
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr
from tqdm.auto import tqdm

# RMSE: Root Mean Squared Error
def compute_rmse(true_vals, pred_vals):
    mse = mean_squared_error(true_vals, pred_vals)
    return np.sqrt(mse)

# PCC: Pearson correlation coefficient
def compute_pcc(true_vals, pred_vals):
    return pearsonr(true_vals, pred_vals)[0]

# R²: Coefficient of determination (R-squared)
def compute_r2(true_vals, pred_vals):
    return r2_score(true_vals, pred_vals)

# Helper function to move data to device
def recursive_to(batch, device):
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    elif isinstance(batch, dict):
        return {k: recursive_to(v, device) for k, v in batch.items()}
    elif isinstance(batch, list):
        return [recursive_to(v, device) for v in batch]
    else:
        return batch

def train_model(model, train_loader, val_loader, test_loader,optimizer, epochs=50, device='cuda:0'):
    model.to(device)
    best_val_loss = float('inf')
    best_model_path = '/home/hongtan/pretrain_single/weight/best_model.pth'
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        # Training loop with progress bar
        with tqdm(total=len(train_loader), desc=f'Epoch {epoch+1}/{epochs}', dynamic_ncols=True) as pbar:
            for batch in train_loader:
                batch = recursive_to(batch, device)
                optimizer.zero_grad()

                loss, _ = model(batch,vae_model)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                pbar.set_postfix({'Train Loss': loss.item()})
                pbar.update(1)

        avg_train_loss = running_loss / len(train_loader)
        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}')
        
        # Validation loop
        val_rmse, val_pcc, val_r2, avg_val_loss = validate_model(model, val_loader, device)
        print(f'Epoch {epoch+1}/{epochs}, Val Loss: {avg_val_loss:.4f}, Val RMSE: {val_rmse:.4f}, Val PCC: {val_pcc:.4f}, Val R²: {val_r2:.4f}')
        
        test_rmse, test_pcc, test_r2, test_val_loss = test_model(model, test_loader, device)
        print(f'Epoch {epoch+1}/{epochs}, test RMSE: {test_rmse:.4f}, test PCC: {test_pcc:.4f}, test R²: {test_r2:.4f}')
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), best_model_path)
            print(f"Best model saved with Val Loss: {avg_val_loss:.4f}")

def validate_model(model, val_loader, device):
    model.eval()
    
    all_ddg_true = []
    all_ddg_pred = []
    val_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc='Validating', dynamic_ncols=True):
            batch = recursive_to(batch, device)
            
            # Forward pass
            loss, output_dict = model(batch,vae_model)
            val_loss += loss.item()

            # Collect true and predicted ddG values
            ddg_true = output_dict['ddG_true'].cpu().numpy()
            ddg_pred = output_dict['ddG_pred'].cpu().numpy()
            
            all_ddg_true.append(ddg_true)
            all_ddg_pred.append(ddg_pred)

    # Flatten lists of ddG values
    all_ddg_true = np.concatenate(all_ddg_true)
    all_ddg_pred = np.concatenate(all_ddg_pred)
    avg_val_loss = val_loss / len(val_loader)
    
    # Compute RMSE, PCC, and R²
    rmse = compute_rmse(all_ddg_true, all_ddg_pred)
    pcc = compute_pcc(all_ddg_true, all_ddg_pred)
    r2 = compute_r2(all_ddg_true, all_ddg_pred)
    
    return rmse, pcc, r2, avg_val_loss

def test_model(model, test_loader, device='cuda:0'):
    model.eval()
    all_ddg_true = []
    all_ddg_pred = []
    test_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(test_loader, desc='Testing', dynamic_ncols=True):
            batch = recursive_to(batch, device)
            
            # Forward pass
            loss, output_dict = model(batch,vae_model)
            test_loss += loss.item()

            # Collect true and predicted ddG values
            ddg_true = output_dict['ddG_true'].cpu().numpy()
            ddg_pred = output_dict['ddG_pred'].cpu().numpy()
            
            all_ddg_true.append(ddg_true)
            all_ddg_pred.append(ddg_pred)

    # Flatten lists of ddG values
    all_ddg_true = np.concatenate(all_ddg_true)
    all_ddg_pred = np.concatenate(all_ddg_pred)
    avg_test_loss = test_loss / len(test_loader)

    # Compute RMSE, PCC, and R²
    rmse = compute_rmse(all_ddg_true, all_ddg_pred)
    pcc = compute_pcc(all_ddg_true, all_ddg_pred)
    r2 = compute_r2(all_ddg_true, all_ddg_pred)
    
    print(f'Test Loss: {avg_test_loss:.4f}, Test RMSE: {rmse:.4f}, Test PCC: {pcc:.4f}, Test R²: {r2:.4f}')
    return rmse, pcc, r2, avg_test_loss


In [12]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.8, patience=10, min_lr=1e-6)


In [14]:
train_model(model, train_loader, val_loader,test_loader,optimizer = optimizer, epochs=100, device= torch.device('cuda:0'))


Epoch 1/100:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch 1/100:  59%|█████▊    | 44/75 [01:01<00:47,  1.54s/it, Train Loss=1.4] 

Skipping mutation K->S at ('A', 80): not found in seq_map
Skipping entry 1699 after transform.


Epoch 1/100: 100%|██████████| 75/75 [01:46<00:00,  1.42s/it, Train Loss=2.3]  


Epoch 1/100, Train Loss: 1.9622


Validating: 100%|██████████| 9/9 [00:08<00:00,  1.08it/s]


Epoch 1/100, Val Loss: 1.6048, Val RMSE: 1.2724, Val PCC: 0.4925, Val R²: 0.2002


Testing: 100%|██████████| 21/21 [00:12<00:00,  1.68it/s]


Test Loss: 2.4868, Test RMSE: 1.5759, Test PCC: 0.3511, Test R²: 0.0735
Epoch 1/100, test RMSE: 1.5759, test PCC: 0.3511, test R²: 0.0735
Best model saved with Val Loss: 1.6048


Epoch 2/100:  71%|███████   | 53/75 [00:49<00:25,  1.16s/it, Train Loss=1.13] 

Skipping mutation K->S at ('A', 80): not found in seq_map
Skipping entry 1699 after transform.


Epoch 2/100: 100%|██████████| 75/75 [01:08<00:00,  1.09it/s, Train Loss=2.45] 


Epoch 2/100, Train Loss: 1.6255


Validating: 100%|██████████| 9/9 [00:06<00:00,  1.40it/s]


Epoch 2/100, Val Loss: 1.5496, Val RMSE: 1.2656, Val PCC: 0.5165, Val R²: 0.2086


Testing: 100%|██████████| 21/21 [00:24<00:00,  1.17s/it]


Test Loss: 2.4596, Test RMSE: 1.5685, Test PCC: 0.3677, Test R²: 0.0822
Epoch 2/100, test RMSE: 1.5685, test PCC: 0.3677, test R²: 0.0822
Best model saved with Val Loss: 1.5496


Epoch 3/100:  77%|███████▋  | 58/75 [01:30<00:30,  1.77s/it, Train Loss=0.718]

Skipping mutation K->S at ('A', 80): not found in seq_map
Skipping entry 1699 after transform.


Epoch 3/100: 100%|██████████| 75/75 [01:46<00:00,  1.42s/it, Train Loss=1.3]  


Epoch 3/100, Train Loss: 1.5087


Validating: 100%|██████████| 9/9 [00:13<00:00,  1.47s/it]


Epoch 3/100, Val Loss: 1.5986, Val RMSE: 1.2650, Val PCC: 0.5167, Val R²: 0.2094


Testing: 100%|██████████| 21/21 [00:15<00:00,  1.40it/s]


Test Loss: 2.4451, Test RMSE: 1.5640, Test PCC: 0.3680, Test R²: 0.0875
Epoch 3/100, test RMSE: 1.5640, test PCC: 0.3680, test R²: 0.0875


Epoch 4/100:   0%|          | 0/75 [00:00<?, ?it/s]

Skipping mutation K->S at ('A', 80): not found in seq_map
Skipping entry 1699 after transform.


Epoch 4/100: 100%|██████████| 75/75 [01:11<00:00,  1.05it/s, Train Loss=2.45] 


Epoch 4/100, Train Loss: 1.3822


Validating: 100%|██████████| 9/9 [00:09<00:00,  1.00s/it]


Epoch 4/100, Val Loss: 1.6141, Val RMSE: 1.2747, Val PCC: 0.4991, Val R²: 0.1972


Testing: 100%|██████████| 21/21 [00:14<00:00,  1.46it/s]


Test Loss: 2.4320, Test RMSE: 1.5535, Test PCC: 0.3722, Test R²: 0.0997
Epoch 4/100, test RMSE: 1.5535, test PCC: 0.3722, test R²: 0.0997


Epoch 5/100:  28%|██▊       | 21/75 [00:23<00:51,  1.05it/s, Train Loss=1.2]  

Skipping mutation K->S at ('A', 80): not found in seq_map
Skipping entry 1699 after transform.


Epoch 5/100: 100%|██████████| 75/75 [01:06<00:00,  1.12it/s, Train Loss=0.854]


Epoch 5/100, Train Loss: 1.3376


Validating: 100%|██████████| 9/9 [00:07<00:00,  1.15it/s]


Epoch 5/100, Val Loss: 1.7063, Val RMSE: 1.2708, Val PCC: 0.5145, Val R²: 0.2022


Testing: 100%|██████████| 21/21 [00:14<00:00,  1.50it/s]


Test Loss: 2.3416, Test RMSE: 1.5295, Test PCC: 0.4137, Test R²: 0.1273
Epoch 5/100, test RMSE: 1.5295, test PCC: 0.4137, test R²: 0.1273


Epoch 6/100:   1%|▏         | 1/75 [00:01<02:26,  1.98s/it, Train Loss=2.15]

Skipping mutation K->S at ('A', 80): not found in seq_map


Epoch 6/100:   1%|▏         | 1/75 [00:03<02:26,  1.98s/it, Train Loss=1.85]

Skipping entry 1699 after transform.

Epoch 6/100:   3%|▎         | 2/75 [00:03<01:52,  1.54s/it, Train Loss=1.85]




Epoch 6/100: 100%|██████████| 75/75 [01:11<00:00,  1.04it/s, Train Loss=1.36] 


Epoch 6/100, Train Loss: 1.2234


Validating: 100%|██████████| 9/9 [00:08<00:00,  1.12it/s]


Epoch 6/100, Val Loss: 1.5157, Val RMSE: 1.2344, Val PCC: 0.5452, Val R²: 0.2472


Testing: 100%|██████████| 21/21 [00:19<00:00,  1.06it/s]


Test Loss: 2.4407, Test RMSE: 1.5595, Test PCC: 0.3888, Test R²: 0.0927
Epoch 6/100, test RMSE: 1.5595, test PCC: 0.3888, test R²: 0.0927
Best model saved with Val Loss: 1.5157


Epoch 7/100:  36%|███▌      | 27/75 [00:21<00:37,  1.29it/s, Train Loss=0.959]

Skipping mutation K->S at ('A', 80): not found in seq_map
Skipping entry 1699 after transform.


Epoch 7/100: 100%|██████████| 75/75 [01:15<00:00,  1.00s/it, Train Loss=3.85] 


Epoch 7/100, Train Loss: 1.1598


Validating: 100%|██████████| 9/9 [00:15<00:00,  1.71s/it]


Epoch 7/100, Val Loss: 1.4385, Val RMSE: 1.2313, Val PCC: 0.5711, Val R²: 0.2510


Testing: 100%|██████████| 21/21 [00:12<00:00,  1.65it/s]


Test Loss: 2.4577, Test RMSE: 1.5679, Test PCC: 0.3708, Test R²: 0.0828
Epoch 7/100, test RMSE: 1.5679, test PCC: 0.3708, test R²: 0.0828
Best model saved with Val Loss: 1.4385


Epoch 8/100:  41%|████▏     | 31/75 [00:47<01:16,  1.74s/it, Train Loss=0.798]

Skipping mutation K->S at ('A', 80): not found in seq_map
Skipping entry 1699 after transform.


Epoch 8/100: 100%|██████████| 75/75 [01:36<00:00,  1.29s/it, Train Loss=1.28] 


Epoch 8/100, Train Loss: 1.0022


Validating: 100%|██████████| 9/9 [00:17<00:00,  1.93s/it]


Epoch 8/100, Val Loss: 1.5499, Val RMSE: 1.2043, Val PCC: 0.5669, Val R²: 0.2835


Testing: 100%|██████████| 21/21 [00:13<00:00,  1.50it/s]


Test Loss: 2.3998, Test RMSE: 1.5517, Test PCC: 0.3948, Test R²: 0.1017
Epoch 8/100, test RMSE: 1.5517, test PCC: 0.3948, test R²: 0.1017


Epoch 9/100:  87%|████████▋ | 65/75 [01:22<00:20,  2.04s/it, Train Loss=0.859]

Skipping mutation K->S at ('A', 80): not found in seq_map
Skipping entry 1699 after transform.


Epoch 9/100: 100%|██████████| 75/75 [01:43<00:00,  1.38s/it, Train Loss=0.468]


Epoch 9/100, Train Loss: 0.9206


Validating: 100%|██████████| 9/9 [00:15<00:00,  1.73s/it]


Epoch 9/100, Val Loss: 1.5011, Val RMSE: 1.2163, Val PCC: 0.5778, Val R²: 0.2691


Testing: 100%|██████████| 21/21 [00:12<00:00,  1.63it/s]


Test Loss: 2.4368, Test RMSE: 1.5591, Test PCC: 0.3836, Test R²: 0.0932
Epoch 9/100, test RMSE: 1.5591, test PCC: 0.3836, test R²: 0.0932


Epoch 10/100:  45%|████▌     | 34/75 [01:07<00:37,  1.11it/s, Train Loss=0.628]

Skipping mutation K->S at ('A', 80): not found in seq_map
Skipping entry 1699 after transform.


Epoch 10/100: 100%|██████████| 75/75 [01:34<00:00,  1.27s/it, Train Loss=1.88] 


Epoch 10/100, Train Loss: 0.7741


Validating: 100%|██████████| 9/9 [00:08<00:00,  1.12it/s]


Epoch 10/100, Val Loss: 1.3101, Val RMSE: 1.1835, Val PCC: 0.5897, Val R²: 0.3081


Testing: 100%|██████████| 21/21 [00:15<00:00,  1.38it/s]


Test Loss: 2.4646, Test RMSE: 1.5671, Test PCC: 0.3710, Test R²: 0.0838
Epoch 10/100, test RMSE: 1.5671, test PCC: 0.3710, test R²: 0.0838
Best model saved with Val Loss: 1.3101


Epoch 11/100:  21%|██▏       | 16/75 [00:15<00:45,  1.29it/s, Train Loss=0.861]

Skipping mutation K->S at ('A', 80): not found in seq_map
Skipping entry 1699 after transform.

Epoch 11/100:  23%|██▎       | 17/75 [00:16<00:42,  1.37it/s, Train Loss=0.5]  




Epoch 11/100: 100%|██████████| 75/75 [01:08<00:00,  1.09it/s, Train Loss=0.282]


Epoch 11/100, Train Loss: 0.6736


Validating: 100%|██████████| 9/9 [00:06<00:00,  1.41it/s]


Epoch 11/100, Val Loss: 1.4755, Val RMSE: 1.2187, Val PCC: 0.5818, Val R²: 0.2663


Testing: 100%|██████████| 21/21 [00:12<00:00,  1.66it/s]


Test Loss: 2.5201, Test RMSE: 1.5890, Test PCC: 0.3580, Test R²: 0.0580
Epoch 11/100, test RMSE: 1.5890, test PCC: 0.3580, test R²: 0.0580


Epoch 12/100:  36%|███▌      | 27/75 [00:40<01:25,  1.77s/it, Train Loss=0.406]

Skipping mutation K->S at ('A', 80): not found in seq_map
Skipping entry 1699 after transform.


Epoch 12/100: 100%|██████████| 75/75 [01:49<00:00,  1.46s/it, Train Loss=0.312]


Epoch 12/100, Train Loss: 0.5295


Validating: 100%|██████████| 9/9 [00:06<00:00,  1.50it/s]


Epoch 12/100, Val Loss: 1.4214, Val RMSE: 1.1899, Val PCC: 0.5951, Val R²: 0.3005


Testing: 100%|██████████| 21/21 [00:13<00:00,  1.52it/s]


Test Loss: 2.4141, Test RMSE: 1.5508, Test PCC: 0.3961, Test R²: 0.1028
Epoch 12/100, test RMSE: 1.5508, test PCC: 0.3961, test R²: 0.1028


Epoch 13/100:  72%|███████▏  | 54/75 [00:47<00:15,  1.35it/s, Train Loss=0.547]

Skipping mutation K->S at ('A', 80): not found in seq_map
Skipping entry 1699 after transform.


Epoch 13/100: 100%|██████████| 75/75 [01:04<00:00,  1.16it/s, Train Loss=0.829]


Epoch 13/100, Train Loss: 0.4239


Validating: 100%|██████████| 9/9 [00:07<00:00,  1.27it/s]


Epoch 13/100, Val Loss: 1.5650, Val RMSE: 1.1636, Val PCC: 0.6264, Val R²: 0.3311


Testing: 100%|██████████| 21/21 [00:13<00:00,  1.51it/s]


Test Loss: 2.5596, Test RMSE: 1.5977, Test PCC: 0.3752, Test R²: 0.0477
Epoch 13/100, test RMSE: 1.5977, test PCC: 0.3752, test R²: 0.0477


Epoch 14/100:  73%|███████▎  | 55/75 [01:05<00:16,  1.22it/s, Train Loss=0.482]

Skipping mutation K->S at ('A', 80): not found in seq_map
Skipping entry 1699 after transform.


Epoch 14/100: 100%|██████████| 75/75 [01:26<00:00,  1.15s/it, Train Loss=0.226]


Epoch 14/100, Train Loss: 0.3343


Validating: 100%|██████████| 9/9 [00:15<00:00,  1.77s/it]


Epoch 14/100, Val Loss: 1.6660, Val RMSE: 1.2001, Val PCC: 0.5887, Val R²: 0.2884


Testing: 100%|██████████| 21/21 [00:20<00:00,  1.02it/s]


Test Loss: 2.5471, Test RMSE: 1.5937, Test PCC: 0.3642, Test R²: 0.0524
Epoch 14/100, test RMSE: 1.5937, test PCC: 0.3642, test R²: 0.0524


Epoch 15/100:   3%|▎         | 2/75 [00:05<02:55,  2.41s/it, Train Loss=0.308]

Skipping mutation K->S at ('A', 80): not found in seq_map
Skipping entry 1699 after transform.


Epoch 15/100: 100%|██████████| 75/75 [01:14<00:00,  1.01it/s, Train Loss=0.34] 


Epoch 15/100, Train Loss: 0.2784


Validating: 100%|██████████| 9/9 [00:06<00:00,  1.29it/s]


Epoch 15/100, Val Loss: 1.3507, Val RMSE: 1.1762, Val PCC: 0.5960, Val R²: 0.3165


Testing: 100%|██████████| 21/21 [00:27<00:00,  1.33s/it]


Test Loss: 2.5572, Test RMSE: 1.5977, Test PCC: 0.3633, Test R²: 0.0477
Epoch 15/100, test RMSE: 1.5977, test PCC: 0.3633, test R²: 0.0477


Epoch 16/100:  31%|███       | 23/75 [00:41<01:16,  1.47s/it, Train Loss=0.246]

Skipping mutation K->S at ('A', 80): not found in seq_map
Skipping entry 1699 after transform.


Epoch 16/100: 100%|██████████| 75/75 [01:41<00:00,  1.35s/it, Train Loss=0.0939]


Epoch 16/100, Train Loss: 0.2236


Validating: 100%|██████████| 9/9 [00:19<00:00,  2.20s/it]


Epoch 16/100, Val Loss: 1.3826, Val RMSE: 1.1885, Val PCC: 0.5981, Val R²: 0.3021


Testing: 100%|██████████| 21/21 [00:19<00:00,  1.10it/s]


Test Loss: 2.5678, Test RMSE: 1.6037, Test PCC: 0.3669, Test R²: 0.0405
Epoch 16/100, test RMSE: 1.6037, test PCC: 0.3669, test R²: 0.0405


Epoch 17/100:  43%|████▎     | 32/75 [00:32<00:49,  1.15s/it, Train Loss=0.237]

Skipping mutation K->S at ('A', 80): not found in seq_map
Skipping entry 1699 after transform.


Epoch 17/100: 100%|██████████| 75/75 [01:39<00:00,  1.32s/it, Train Loss=0.0884]


Epoch 17/100, Train Loss: 0.1937


Validating: 100%|██████████| 9/9 [00:08<00:00,  1.12it/s]


Epoch 17/100, Val Loss: 1.3077, Val RMSE: 1.1721, Val PCC: 0.6015, Val R²: 0.3213


Testing: 100%|██████████| 21/21 [00:17<00:00,  1.21it/s]


Test Loss: 2.4617, Test RMSE: 1.5674, Test PCC: 0.3768, Test R²: 0.0835
Epoch 17/100, test RMSE: 1.5674, test PCC: 0.3768, test R²: 0.0835
Best model saved with Val Loss: 1.3077


Epoch 18/100:  57%|█████▋    | 43/75 [00:53<00:51,  1.62s/it, Train Loss=0.21]  

Skipping mutation K->S at ('A', 80): not found in seq_map
Skipping entry 1699 after transform.


Epoch 18/100: 100%|██████████| 75/75 [01:44<00:00,  1.39s/it, Train Loss=0.159] 


Epoch 18/100, Train Loss: 0.1523


Validating: 100%|██████████| 9/9 [00:09<00:00,  1.08s/it]


Epoch 18/100, Val Loss: 1.3300, Val RMSE: 1.1353, Val PCC: 0.6276, Val R²: 0.3632


Testing: 100%|██████████| 21/21 [00:16<00:00,  1.26it/s]


Test Loss: 2.5058, Test RMSE: 1.5843, Test PCC: 0.3637, Test R²: 0.0635
Epoch 18/100, test RMSE: 1.5843, test PCC: 0.3637, test R²: 0.0635


Epoch 19/100:  67%|██████▋   | 50/75 [00:47<00:24,  1.02it/s, Train Loss=0.116] 

Skipping mutation K->S at ('A', 80): not found in seq_map
Skipping entry 1699 after transform.


Epoch 19/100: 100%|██████████| 75/75 [01:09<00:00,  1.08it/s, Train Loss=0.119] 


Epoch 19/100, Train Loss: 0.1256


Validating: 100%|██████████| 9/9 [00:09<00:00,  1.10s/it]


Epoch 19/100, Val Loss: 1.2663, Val RMSE: 1.1327, Val PCC: 0.6248, Val R²: 0.3662


Testing: 100%|██████████| 21/21 [00:14<00:00,  1.45it/s]


Test Loss: 2.4972, Test RMSE: 1.5829, Test PCC: 0.3544, Test R²: 0.0653
Epoch 19/100, test RMSE: 1.5829, test PCC: 0.3544, test R²: 0.0653
Best model saved with Val Loss: 1.2663


Epoch 20/100:   4%|▍         | 3/75 [00:06<02:07,  1.77s/it, Train Loss=0.0791]

Skipping mutation K->S at ('A', 80): not found in seq_mapSkipping entry 1699 after transform.



Epoch 20/100: 100%|██████████| 75/75 [01:07<00:00,  1.12it/s, Train Loss=0.0626]


Epoch 20/100, Train Loss: 0.1070


Validating: 100%|██████████| 9/9 [00:04<00:00,  1.91it/s]


Epoch 20/100, Val Loss: 1.3108, Val RMSE: 1.1602, Val PCC: 0.6166, Val R²: 0.3349


Testing: 100%|██████████| 21/21 [00:17<00:00,  1.22it/s]


Test Loss: 2.5124, Test RMSE: 1.5883, Test PCC: 0.3720, Test R²: 0.0589
Epoch 20/100, test RMSE: 1.5883, test PCC: 0.3720, test R²: 0.0589


Epoch 21/100:  67%|██████▋   | 50/75 [00:42<00:23,  1.05it/s, Train Loss=0.0697]

In [14]:
best_model_path = "/home/hongtan/pretrain_single/wt/best_model.pth"
model.to(device)
model.load_state_dict(torch.load(best_model_path))

# 在测试集上测试
test_model(model, test_loader, device='cuda:0')

  model.load_state_dict(torch.load(best_model_path))
Testing:   0%|          | 0/17 [00:00<?, ?it/s]

Skipping mutation I->L at ('A', 112): not found in seq_map
Skipping entry 144 after transform.


Testing: 100%|██████████| 17/17 [00:17<00:00,  1.02s/it]


Test Loss: 0.1540, Test RMSE: 0.3919, Test PCC: 0.9599, Test R²: 0.9198


(0.3918693, 0.9598762047998406, 0.9198229908943176, 0.15403935225570903)

Epoch 1/1000, Train Loss: 3.7755478318733506
Validating: 100%|██████████| 9/9 [00:15<00:00,  1.73s/it]
Validation Loss: 3.440200130144755
Epoch 1/1000, Val RMSE: 1.8626, Val PCC: 0.6669, Val R²: 0.4344
Epoch 2/1000, Train Loss: 2.769809875307204
Validating: 100%|██████████| 9/9 [00:18<00:00,  2.03s/it]
Validation Loss: 3.4820684327019586
Epoch 2/1000, Val RMSE: 1.8031, Val PCC: 0.6885, Val R²: 0.4700
Epoch 3/1000, Train Loss: 2.436043185523794
Validating: 100%|██████████| 9/9 [00:14<00:00,  1.66s/it]
Validation Loss: 3.3180822796291776
Epoch 3/1000, Val RMSE: 1.8264, Val PCC: 0.6820, Val R²: 0.4562
Epoch 4/1000, Train Loss: 2.396531761447086
Validating: 100%|██████████| 9/9 [00:19<00:00,  2.22s/it]
Validation Loss: 3.402861820326911
Epoch 4/1000, Val RMSE: 1.8588, Val PCC: 0.7065, Val R²: 0.4368
Epoch 5/1000, Train Loss: 2.2285565168042725
Validating: 100%|██████████| 9/9 [00:17<00:00,  1.97s/it]
Validation Loss: 3.3189518451690674
Epoch 5/1000, Val RMSE: 1.8318, Val PCC: 0.6858, Val R²: 0.4530
Epoch 6/1000, Train Loss: 2.437140565884264
Validating: 100%|██████████| 9/9 [00:17<00:00,  1.94s/it]
Validation Loss: 3.866072667969598
Epoch 6/1000, Val RMSE: 1.9175, Val PCC: 0.6353, Val R²: 0.4006
Epoch 7/1000, Train Loss: 2.5104055570650705
Validating: 100%|██████████| 9/9 [00:16<00:00,  1.81s/it]
Validation Loss: 4.321215364668104
Epoch 7/1000, Val RMSE: 2.0865, Val PCC: 0.5843, Val R²: 0.2903
Epoch 8/1000, Train Loss: 4.1250943533981905
Validating: 100%|██████████| 9/9 [00:16<00:00,  1.83s/it]
Validation Loss: 4.389447000291613
Epoch 8/1000, Val RMSE: 2.0988, Val PCC: 0.5502, Val R²: 0.2819
Epoch 9/1000, Train Loss: 3.9547084675559514
Validating: 100%|██████████| 9/9 [00:17<00:00,  1.98s/it]
Validation Loss: 4.163823180728489
Epoch 9/1000, Val RMSE: 2.0451, Val PCC: 0.5684, Val R²: 0.3181