# Transformer

## Structure

In [1]:
import torch.nn as nn
import torch
import math, os
import torch.nn.functional as F
import warnings
# Silence every Python warning for the rest of the session (equivalent to
# warnings.filterwarnings("ignore") with no message/module filter).
warnings.simplefilter("ignore")
# Working directory captured once; the notebook builds its data, parameter and
# result paths relative to it.
current_dir = os.getcwd()


class TransformerKP(nn.Module):
    """Regression network: fully-connected stack -> attention encoder -> scalar head.

    Args:
        input_dim: number of input features per sample.
        hidden_dim_list: output widths of the successive ``nn.Linear`` layers of the
            fully-connected stack; the last entry is the encoder's input width.
        encoder_dim_list: ``(num, dim)`` pairs, one per encoder stage. Each stage
            projects to ``num * dim`` features, views them as ``num`` tokens of width
            ``dim`` and runs self-attention over them.
        drop_ratio, norm_fun, act_fun: dropout probability, normalisation name and
            activation name, passed to both the FC stack and the encoder.
        encoder_with_res, encoder_norm, encoder_drop_ratio, num_heads, residual_coef:
            passed through to the encoder's attention stages.
        device: accepted for backward compatibility but NOT used by the module;
            move the model with ``.to(device)`` instead.

    Defaults are tuples (immutable) rather than lists so they cannot be shared
    and mutated across instances; lists are still accepted.
    """

    def __init__(self, input_dim, hidden_dim_list=(1024, 256), encoder_dim_list=((8, 32), (1, 64)),
                 drop_ratio=0.12, norm_fun='batch_norm',
                 act_fun='gelu', encoder_with_res=False, encoder_norm=None,
                 encoder_drop_ratio=0.0, num_heads=1, residual_coef=1.0, device=torch.device('cuda:0' if torch.cuda.is_available() else "cpu")):
        super(TransformerKP, self).__init__()

        # fc blocks
        self.fc_layers = FCBlock(input_dim, hidden_dim_list, norm_fun, act_fun, drop_ratio)

        # encoder layers
        self.encoder_layers = TransformerBlock(hidden_dim_list, encoder_dim_list, drop_ratio, norm_fun,
                                               act_fun, encoder_with_res, encoder_norm, encoder_drop_ratio,
                                               num_heads, residual_coef)

        # The encoder flattens its last stage back to num * dim features.
        last_num, last_dim = encoder_dim_list[-1][0], encoder_dim_list[-1][1]
        self.output_layer = nn.Linear(last_num * last_dim, 1)

    def forward(self, x):
        """Map a (batch, input_dim) feature tensor to (batch, 1) predictions."""
        x = x.to(torch.float32)
        x = self.fc_layers(x)
        x = self.encoder_layers(x)
        output = self.output_layer(x)
        return output


class FCBlock(nn.Module):
    """Fully-connected stack.

    Layer 0 maps ``input_dim -> hidden_dim_list[0]`` and is followed by activation
    and dropout only. Every following layer maps ``hidden_dim_list[i] ->
    hidden_dim_list[i + 1]`` and is followed by activation, normalisation and
    dropout. ``norm_fun`` is 'batch_norm', 'layer_norm', or anything else for no
    normalisation.
    """

    def __init__(self, input_dim, hidden_dim_list, norm_fun, act_fun, drop_ratio):
        super(FCBlock, self).__init__()

        # Resolve the activation with the same helper TransformerBlock uses. The
        # previous `'gelu' else ReLU` rule silently turned 'silu'/'lrelu' into ReLU
        # here while the encoder used SiLU/LeakyReLU. Unknown names now raise
        # instead of falling back to ReLU.
        self.act = return_act_fun(act_fun)

        if norm_fun == 'batch_norm':
            self.norm_fc = nn.ModuleList(nn.BatchNorm1d(dim) for dim in hidden_dim_list[1:])
        elif norm_fun == 'layer_norm':
            self.norm_fc = nn.ModuleList(nn.LayerNorm(dim) for dim in hidden_dim_list[1:])
        else:
            self.norm_fc = nn.ModuleList(nn.Identity() for dim in hidden_dim_list[1:])

        self.drop = nn.Dropout(p=drop_ratio)

        # fc blocks
        self.fc_layers = nn.ModuleList()
        self.fc_layers.append(nn.Linear(input_dim, hidden_dim_list[0]))

        # `norm_fcs` holds the very same module objects as `norm_fc`. Both attributes
        # are kept so the module tree and state_dict keys stay unchanged.
        self.norm_fcs = nn.ModuleList()

        for in_dim, out_dim, norm in zip(hidden_dim_list[:-1], hidden_dim_list[1:], self.norm_fc):
            self.fc_layers.append(nn.Linear(in_dim, out_dim))
            self.norm_fcs.append(norm)

    def forward(self, x):
        # First layer: no normalisation.
        x = self.drop(self.act(self.fc_layers[0](x)))

        # Remaining layers: Linear -> act -> norm -> dropout.
        for fc_layer, norm in zip(self.fc_layers[1:], self.norm_fcs):
            x = self.drop(norm(self.act(fc_layer(x))))

        return x


class TransformerBlock(nn.Module):
    """Sequence of attention encoder stages built from ``encoder_dim_list``.

    For each ``(num, dim)`` pair the stage is:
    Linear(in -> num * dim) -> activation -> normalisation -> dropout,
    view as (batch, num, dim), SelfAttention over the ``num`` tokens, then
    Flatten back to (batch, num * dim). The first stage's input width is
    ``hidden_dim_list[-1]``; each later stage takes the previous ``num * dim``.
    """

    def __init__(self, hidden_dim_list, encoder_dim_list, drop_ratio, norm_fun,
                 act_fun, encoder_with_res, encoder_norm, encoder_drop_ratio, num_heads, residual_coef):
        super(TransformerBlock, self).__init__()

        # One activation / dropout instance, shared by every stage's projection.
        self.act = return_act_fun(act_fun)
        self.drop = nn.Dropout(p=drop_ratio)

        self.encoder_layers = nn.ModuleList()

        in_features = hidden_dim_list[-1]
        for num_tokens, token_dim in encoder_dim_list:
            out_features = num_tokens * token_dim

            if norm_fun == 'batch_norm':
                norm_layer = nn.BatchNorm1d(out_features)
            elif norm_fun == 'layer_norm':
                norm_layer = nn.LayerNorm(out_features)
            else:
                norm_layer = nn.Identity()
            # Reassigned on every stage, so `self.norm_fun` ends up aliasing the last
            # stage's norm. It stays a registered submodule, so the module tree /
            # state_dict layout is unchanged.
            self.norm_fun = norm_layer

            # Built in this order (projection before attention) to keep parameter
            # initialisation order, and therefore seeded weights, unchanged.
            self.encoder_layers.extend([
                nn.Sequential(nn.Linear(in_features, out_features), self.act, norm_layer, self.drop),
                ReshapeLayer(num_tokens, token_dim),
                SelfAttention(token_dim, encoder_with_res, encoder_norm, encoder_drop_ratio, num_heads, residual_coef),
                nn.Flatten(),
            ])
            in_features = out_features

    def forward(self, x):
        out = x
        for stage in self.encoder_layers:
            out = stage(out)
        return out


def return_act_fun(name):
    """Return a new activation module for the short name ``name``.

    Supported names: 'silu', 'relu', 'gelu', 'lrelu' (LeakyReLU with slope 0.01).
    A fresh module is built on every call.

    Raises:
        ValueError: if ``name`` is not one of the supported names.
    """
    factories = (
        ('silu', nn.SiLU),
        ('relu', nn.ReLU),
        ('gelu', nn.GELU),
        ('lrelu', lambda: nn.LeakyReLU(negative_slope=0.01)),
    )
    for key, factory in factories:
        if name == key:
            return factory()
    raise ValueError(f"Unsupported activation function: {name}")


class ReshapeLayer(nn.Module):
    """View each sample as ``num`` tokens of width ``dim``: output is (batch, num, dim).

    Uses ``Tensor.view``, so each sample must hold exactly ``num * dim`` elements
    and the input must be viewable without a copy.
    """

    def __init__(self, num, dim):
        super().__init__()
        self.num, self.dim = num, dim

    def forward(self, x):
        batch = x.shape[0]
        return x.view(batch, self.num, self.dim)


class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention on a (batch, seq, output_dim) tensor.

    Q/K/V are bias-free projections with square ``output_dim x output_dim`` weights,
    split into ``num_heads`` heads of width ``output_dim // num_heads``. After
    attention the optional steps run in this order: dropout
    (``encoder_drop_ratio > 0``), LayerNorm (truthy ``encoder_norm``), then the
    residual ``+ residual_coef * x`` (``encoder_with_res``). The output shape
    equals the input shape.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``output_dim``.
    """

    def __init__(self, output_dim, encoder_with_res=False, encoder_norm=None, encoder_drop_ratio=0.0, num_heads=1, residual_coef=1.0):
        super(SelfAttention, self).__init__()
        # Validate with real exceptions: `assert` is stripped under `python -O`,
        # num_heads=0 used to crash with ZeroDivisionError, and a negative
        # num_heads slipped through the old divisibility assert.
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if output_dim % num_heads != 0:
            raise ValueError("output_dim must be divisible by num_heads")

        self.output_dim = output_dim
        self.encoder_with_res = encoder_with_res
        self.residual_coef = residual_coef
        self.num_heads = num_heads
        self.encoder_norm = nn.LayerNorm(output_dim) if encoder_norm else None
        self.dropout = nn.Dropout(p=encoder_drop_ratio) if encoder_drop_ratio > 0 else None

        # Projection weights; allocated uninitialised and filled by init_weights().
        self.WQ = nn.Parameter(torch.empty(output_dim, output_dim))
        self.WK = nn.Parameter(torch.empty(output_dim, output_dim))
        self.WV = nn.Parameter(torch.empty(output_dim, output_dim))

        self.init_weights()

    def init_weights(self):
        nn.init.xavier_uniform_(self.WQ)
        nn.init.xavier_uniform_(self.WK)
        nn.init.xavier_uniform_(self.WV)

    def forward(self, x):
        batch_size, seq_length, _ = x.size()
        head_dim = self.output_dim // self.num_heads

        # Linear projections, split into heads: (batch, seq, heads, head_dim)
        Q = torch.matmul(x, self.WQ).view(batch_size, seq_length, self.num_heads, head_dim)
        K = torch.matmul(x, self.WK).view(batch_size, seq_length, self.num_heads, head_dim)
        V = torch.matmul(x, self.WV).view(batch_size, seq_length, self.num_heads, head_dim)

        # -> (batch, heads, seq, head_dim)
        Q = Q.permute(0, 2, 1, 3)
        K = K.permute(0, 2, 1, 3)
        V = V.permute(0, 2, 1, 3)

        # Scaled dot-product attention over the sequence axis.
        QK = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(head_dim)
        QK = F.softmax(QK, dim=-1)

        # Merge heads back to (batch, seq, output_dim).
        output = torch.matmul(QK, V).permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, self.output_dim)

        if self.dropout is not None:
            output = self.dropout(output)

        if self.encoder_norm is not None:
            output = self.encoder_norm(output)

        # Residual is added after the (optional) norm.
        if self.encoder_with_res:
            output = output + self.residual_coef * x

        return output

  from .autonotebook import tqdm as notebook_tqdm


## Train model

In [2]:
import pandas as pd
import random
import numpy as np
import torch.optim as optim
import torch.utils.data as Data
from hyperopt import fmin, tpe, hp, Trials, space_eval  # 超参数搜索
import json
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from scipy.stats import pearsonr
from sklearn.model_selection import train_test_split, KFold
import os


# Single seed used by every RNG below (and reused for the data split / CV folds).
random_state = 66
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                 torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(random_state)
del _seed_fn

def return_scores(y_true, y_pred):
    """Return ``(rmse, mae, r2, pcc)`` for flattened targets and predictions.

    ``pcc`` is the Pearson correlation coefficient from ``scipy.stats.pearsonr``.
    """
    truth, guess = np.ravel(y_true), np.ravel(y_pred)
    return (
        np.sqrt(mean_squared_error(truth, guess)),
        mean_absolute_error(truth, guess),
        r2_score(truth, guess),
        pearsonr(truth, guess)[0],
    )


def return_data_loader(x, y, batch_size, shuffle=True, seed=66):
    """Wrap ``x``/``y`` as float32 tensors in a ``TensorDataset`` and return a ``DataLoader``.

    Side effect: reseeds Python's ``random``, NumPy and torch (CPU and all CUDA
    devices) with ``seed`` before building the loader.
    """
    for reseed in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        reseed(seed)

    features = torch.FloatTensor(x)
    targets = torch.FloatTensor(y)
    dataset = Data.TensorDataset(features, targets)
    return Data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

def return_x_y(df_filtered):
    """Build ``(x, y)`` arrays from a dataframe, dropping rows whose label is NaN.

    ``x`` is the horizontal stack of the protein embedding column, the substrate
    embedding column and, depending on the module-level switches, the 'ph'/'t'
    columns (``use_t_ph_embedding``) and the 'mw'/'logp' columns (``use_mw_logp``).
    ``y`` is the ``label_name`` column.

    NOTE(review): only NaN labels are filtered; NaNs in the auxiliary columns
    would pass through into ``x`` — confirm the dataset has none.
    """
    targets = df_filtered[label_name].values
    keep = ~np.isnan(targets)

    extra_columns = []
    if use_t_ph_embedding:
        extra_columns += ['ph', 't']
    if use_mw_logp:
        extra_columns += ['mw', 'logp']
    auxiliary_data = [df_filtered[col].values.reshape(-1, 1) for col in extra_columns]

    embeddings = [np.array(df_filtered[col].tolist()) for col in (protein_column, substrate_column)]
    features = np.hstack(embeddings + auxiliary_data)

    return features[keep], targets[keep]


def train_one_epoch(model, optimizer, train_loader):
    """Run one optimisation pass over ``train_loader`` with MSE loss.

    Reads the module-level ``device`` and ``clip_value``. Gradients are clipped
    element-wise to ``[-clip_value, clip_value]`` before every optimizer step.

    Returns:
        (sum of per-batch losses / number of batches, model)
    """
    criterion = torch.nn.MSELoss()
    model.train()
    running_loss = torch.zeros(1).to(device)  # accumulated on-device
    optimizer.zero_grad()

    trainable = [p for p in model.parameters() if p.requires_grad]
    for step, batch in enumerate(train_loader):
        inputs = batch[0].to(device)
        targets = batch[1].to(device)

        loss = criterion(model(inputs).float().squeeze(), targets.float())
        loss.backward()
        running_loss += loss.detach()

        torch.nn.utils.clip_grad_value_(trainable, clip_value=clip_value)
        optimizer.step()
        optimizer.zero_grad()

    return running_loss.item() / (step + 1), model


def evaluate_model(model, data_loader, mode='search'):
    """Evaluate ``model`` on ``data_loader`` without gradients.

    Reads the module-level ``device``.

    Returns:
        mode == 'search': sum of per-batch MSE losses divided by ``len(data_loader)``.
        any other mode: ``(predictions, labels)`` as lists of per-sample numpy arrays.
    """
    model.eval()
    predictions, targets = [], []
    collect = mode != 'search'

    with torch.no_grad():
        criterion = torch.nn.MSELoss()
        total_loss = torch.zeros(1).to(device)  # accumulated on-device

        for batch in data_loader:
            inputs = batch[0].to(device)
            labels = batch[1].to(device)
            outputs = model(inputs)

            total_loss += criterion(outputs.float().squeeze(), labels.float()).detach()

            if collect:
                predictions.extend(outputs.cpu().numpy())
                targets.extend(labels.cpu().numpy())

    torch.cuda.empty_cache()  # release cached, unused CUDA memory

    if mode == 'search':
        return total_loss.item() / len(data_loader)
    return predictions, targets

def search_model(params, train_x, train_y, val_x, val_y):
    """Train a fresh TransformerKP on one CV fold and return its best validation loss.

    Uses AdamW with cosine decay of the learning rate from ``params['lr']`` towards
    ``params['lr'] * params['lrf']`` over ``params['epochs']`` epochs. Training stops
    early once the validation loss has failed to improve for more than the
    module-level ``patience`` consecutive epochs.

    Returns:
        The lowest validation loss seen (as returned by ``evaluate_model`` in
        'search' mode).
    """
    # data loader
    train_loader = return_data_loader(train_x, train_y, batch_size=params['batch_size'], shuffle=True, seed=random_state)
    val_loader = return_data_loader(val_x, val_y, batch_size=params['batch_size'], shuffle=False, seed=random_state)

    model = TransformerKP(
                input_dim=len(train_x[0]),
                hidden_dim_list=params['hidden_dim_list'],
                encoder_dim_list=params['encoder_dim_list'],
                drop_ratio=params['drop_ratio'],
                norm_fun=params['norm_fun'],
                act_fun=params['act_fun'],
                encoder_with_res=params['encoder_with_res'],
                encoder_norm=params['encoder_norm'],
                encoder_drop_ratio=params['encoder_drop_ratio'],
                num_heads=params['num_heads'],
                residual_coef=params['residual_coef']
            ).to(device)

    # optimizer + cosine LR schedule (multiplier goes from 1 down to lrf)
    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(pg, lr=params['lr'], weight_decay=5E-5)
    lf = lambda x: ((1 + math.cos(x * math.pi / params['epochs'])) / 2) * (1 - params['lrf']) + params['lrf']  # cosine
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    # `np.Inf` was removed in NumPy 2.0; `np.inf` works on every NumPy version.
    best_loss = np.inf
    best_epoch, patience_nums = 0, 0

    for epoch_idx in range(params['epochs']):
        # train
        train_loss, model = train_one_epoch(model, optimizer, train_loader)
        scheduler.step()

        # evaluate
        val_loss = evaluate_model(model, val_loader, mode='search')
        if epoch_idx % 100 == 0:
            print(f'Epoch {epoch_idx} Train loss {train_loss:.3f} Val loss {val_loss:.3f}')

        # compare
        if val_loss <= best_loss:
            best_epoch = epoch_idx
            best_loss = val_loss
            patience_nums = 0
        else:
            patience_nums += 1

        if patience_nums > patience:
            break

    # print Log (the epoch reported is the best epoch, not the stop epoch)
    if patience_nums > patience:
        print(f'Early stopped at epoch {best_epoch} best_val_loss {best_loss:.3f}')
    else:
        print(f'Stopped at epoch {best_epoch} best_val_loss {best_loss:.3f}')

    return best_loss


def _search_params(params):
    """Hyperopt objective: mean best-validation loss of ``params`` across the CV folds.

    Splits the module-level ``df_train_val`` with the module-level ``kf`` and runs
    ``search_model`` on every fold. ``search_best_param`` passes this function to
    ``fmin``, which minimises the returned value.
    """
    print(params)
    # Report the real number of folds instead of a hard-coded 5.
    n_folds = kf.get_n_splits()
    val_loss_list = []
    for fold_idx, (train_index, val_index) in enumerate(kf.split(df_train_val), start=1):
        print(f"Fold: {fold_idx}/{n_folds}")
        df_train = df_train_val.iloc[train_index]
        df_val = df_train_val.iloc[val_index]

        train_x, train_y = return_x_y(df_train)
        val_x, val_y = return_x_y(df_val)

        val_loss = search_model(params, train_x, train_y, val_x, val_y)
        val_loss_list.append(val_loss)

    val_loss_mean = np.mean(val_loss_list, axis=0)
    print(f"val MSE loss mean: {val_loss_mean:.5f}\n")

    return val_loss_mean


def search_best_param(max_evals):
    """Run a TPE hyper-parameter search and save the best configuration as JSON.

    Each trial is scored by ``_search_params`` (mean CV validation loss). The
    winning configuration, decoded from hyperopt indices to actual values with
    ``space_eval``, is written to the module-level ``params_json_path`` and returned.
    """
    space = {
        "lr": hp.uniform("lr", 1e-4, 1e-3),
        'lrf': hp.choice('lrf', [0.01]),
        "drop_ratio": hp.uniform("drop_ratio", 0, 0.5),
        'hidden_dim_list': hp.choice('hidden_dim_list', [
            (2048, 1024, 256),
            (1024, 256),
            (2048, 256)
        ]),
        'encoder_dim_list': hp.choice('encoder_dim_list', [
            [(8, 32), (1, 64)],
            [(8, 32), (2, 64), (1, 64)],
            [(4, 64), (2, 128), (1, 128)],
            [(4, 64), (2, 128), (1, 64)]
        ]),
        'norm_fun': hp.choice('norm_fun', ['batch_norm', 'layer_norm']),
        'act_fun': hp.choice('act_fun', ['gelu', 'relu']),
        'encoder_with_res': hp.choice('encoder_with_res', [False, True]),
        'encoder_norm': hp.choice('encoder_norm', ['layer_norm', None]),
        "encoder_drop_ratio": hp.uniform("encoder_drop_ratio", 0, 0.5),
        'num_heads': hp.choice('num_heads', [1, 2, 4]),
        'residual_coef': hp.uniform('residual_coef', 0, 1.0),
        'batch_size': hp.choice('batch_size', [256, 512, 1024, 2048]),
        'epochs': hp.choice('epochs', [200, 300, 400]),
    }

    # Create the output directory BEFORE the (long) search. Otherwise a missing
    # directory makes the final `open(..., 'w')` fail and the search result is lost.
    params_dir = os.path.dirname(params_json_path)
    if params_dir:
        os.makedirs(params_dir, exist_ok=True)

    trials = Trials()
    print(f'[Info] Starting parameter search with MSE_Loss...')
    best_params = fmin(fn=_search_params, space=space, algo=tpe.suggest, max_evals=max_evals, trials=trials)
    best_params = space_eval(space, best_params)

    # Save the best params to JSON (tuples are stored as JSON lists)
    with open(params_json_path, 'w') as json_file:
        json.dump(best_params, json_file)

    return best_params


# config
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

# Feature switches: append the 'ph'/'t' and 'mw'/'logp' columns to the embeddings.
use_t_ph_embedding = True
use_mw_logp = True

# Search / training knobs
search_max_evals = 60   # hyperopt trials (max_evals passed to fmin)
patience = 30           # early stopping: epochs without val improvement tolerated
clip_value = 0.8        # element-wise gradient clipping bound

# Column / model names
protein_column = 'prott5'
substrate_column = 'molebert'
input_model = 'transformer'
label_name = 'logkcatkm'

# Data: 80/20 (train+val)/test split, then 5-fold CV on the train+val part.
# pd.read_pickle unpickles the file — only point it at trusted data.
df_input = pd.read_pickle(f'{current_dir}/../../data_process/dataset/df_all_log_transformed.pkl')
df_train_val, df_test = train_test_split(df_input, test_size=0.2, random_state=random_state)
kf = KFold(n_splits=5, shuffle=True, random_state=random_state)

# Reuse previously searched hyper-parameters if present; otherwise search
# (search_best_param also writes them to params_json_path).
params_json_path = f'{current_dir}/model_dict/{input_model}_params.json'
if not os.path.exists(params_json_path):
    params = search_best_param(search_max_evals)
else:
    with open(params_json_path) as json_file:
        params = json.load(json_file)

print(f'Best params:{params}\n')

# Train: for each CV fold, retrain with the selected hyper-parameters, restore the
# weights from the best validation epoch, then score them on the fold's validation
# split and on the shared held-out test set.
val_scores_list, test_scores_list = [], []
fold_results = []
n_folds = kf.get_n_splits()

# The held-out test set is the same for every fold, so build its arrays once.
test_x, test_y = return_x_y(df_test)

for fold_idx, (train_index, val_index) in enumerate(kf.split(df_train_val), start=1):
    print(f"Fold: {fold_idx}/{n_folds}")
    df_train = df_train_val.iloc[train_index]
    df_val = df_train_val.iloc[val_index]

    train_x, train_y = return_x_y(df_train)
    val_x, val_y = return_x_y(df_val)

    # data loader
    train_loader = return_data_loader(train_x, train_y, batch_size=params['batch_size'], shuffle=True, seed=random_state)
    val_loader = return_data_loader(val_x, val_y, batch_size=params['batch_size'], shuffle=False, seed=random_state)
    test_loader = return_data_loader(test_x, test_y, batch_size=params['batch_size'], shuffle=False, seed=random_state)

    model = TransformerKP(
                input_dim=len(train_x[0]),
                hidden_dim_list=params['hidden_dim_list'],
                encoder_dim_list=params['encoder_dim_list'],
                drop_ratio=params['drop_ratio'],
                norm_fun=params['norm_fun'],
                act_fun=params['act_fun'],
                encoder_with_res=params['encoder_with_res'],
                encoder_norm=params['encoder_norm'],
                encoder_drop_ratio=params['encoder_drop_ratio'],
                num_heads=params['num_heads'],
                residual_coef=params['residual_coef']
            ).to(device)

    # optimizer + cosine LR schedule (multiplier goes from 1 down to lrf)
    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(pg, lr=params['lr'], weight_decay=5E-5)
    lf = lambda x: ((1 + math.cos(x * math.pi / params['epochs'])) / 2) * (1 - params['lrf']) + params['lrf']  # cosine
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    # `np.Inf` was removed in NumPy 2.0; `np.inf` works on every NumPy version.
    best_loss = np.inf
    best_epoch, patience_nums, best_state = 0, 0, None

    # train
    for epoch_idx in range(params['epochs']):
        train_loss, model = train_one_epoch(model, optimizer, train_loader)
        scheduler.step()

        val_loss = evaluate_model(model, val_loader, mode='search')

        # compare
        if val_loss <= best_loss:
            # Copy the weights (including BatchNorm running stats). Keeping a
            # reference to `model` would only alias the live model, which keeps
            # training, so the "best" model would really be the last-epoch one.
            best_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
            best_epoch = epoch_idx
            best_loss = val_loss
            patience_nums = 0
        else:
            patience_nums += 1

        if patience_nums > patience:
            print(f'Early stopped at epoch {best_epoch} best_val_loss {best_loss:.3f}')
            break
        if epoch_idx % 50 == 0:
            print(f"[Epoch {epoch_idx} fold {fold_idx} {label_name}] Train loss {train_loss:.3f} Val loss {val_loss:.3f}")

    # Restore the best-validation weights. best_state is still None only if no
    # validation loss ever compared <= the running best (e.g. all NaN); then
    # the final weights are used.
    if best_state is not None:
        model.load_state_dict(best_state)
    best_model = model

    val_pred, val_labels = evaluate_model(best_model, val_loader, mode='val')
    test_pred, test_labels = evaluate_model(best_model, test_loader, mode='test')

    # scores: (rmse, mae, r2, pcc)
    val_scores = return_scores(val_labels, val_pred)
    test_scores = return_scores(test_labels, test_pred)
    val_scores_list.append(val_scores)
    test_scores_list.append(test_scores)

    # fold
    fold_results.append([fold_idx, *val_scores, *test_scores])

# mean
val_scores_mean = np.mean(val_scores_list, axis=0)
test_scores_mean = np.mean(test_scores_list, axis=0)

print(f"Dimension of x: {train_x.shape[1]}")
print(f"[Val_mean] rmse {val_scores_mean[0]:.4f} mae {val_scores_mean[1]:.4f} r2 {val_scores_mean[2]:.4f} pcc {val_scores_mean[3]:.4f} "
      f"[Test_mean] rmse {test_scores_mean[0]:.4f} mae {test_scores_mean[1]:.4f} r2 {test_scores_mean[2]:.4f} pcc {test_scores_mean[3]:.4f}\n")

# save cv results; create the output directory so to_excel cannot fail on a fresh checkout
df_cv_results = pd.DataFrame(fold_results, columns=[
    "Fold",
    "Val_RMSE", "Val_MAE", "Val_R2", "Val_PCC",
    "Test_RMSE", "Test_MAE", "Test_R2", "Test_PCC"])
results_dir = f"{current_dir}/results"
os.makedirs(results_dir, exist_ok=True)
df_cv_results.to_excel(f"{results_dir}/{input_model}_cv_results.xlsx", index=False)
print("Results saved")

[Info] Starting parameter search with MSE_Loss...
  0%|          | 0/60 [00:00<?, ?trial/s, best loss=?]                                                      {'act_fun': 'relu', 'batch_size': 256, 'drop_ratio': 0.17175376482442656, 'encoder_dim_list': ((8, 32), (2, 64), (1, 64)), 'encoder_drop_ratio': 0.4277067000542414, 'encoder_norm': None, 'encoder_with_res': False, 'epochs': 300, 'hidden_dim_list': (1024, 256), 'lr': 0.0008343603673047218, 'lrf': 0.01, 'norm_fun': 'batch_norm', 'num_heads': 2, 'residual_coef': 0.7753877506491463}
  0%|          | 0/60 [00:00<?, ?trial/s, best loss=?]                                                      Fold: 1/5
  0%|          | 0/60 [00:00<?, ?trial/s, best loss=?]                                                      Epoch 0 Train loss 21.174 Val loss 21.857
  0%|          | 0/60 [00:07<?, ?trial/s, best loss=?]                                                      Epoch 100 Train loss 8.718 Val loss 10.694
  0%|          | 0/60 [01:22<