In [2]:
import pandas as pd
import numpy as np
from Bio import SeqIO
import sys
import json
import torch
import re
import os
import h5py
import lightning.pytorch as pl
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
warnings.filterwarnings('ignore')
np.random.seed(42)


In [3]:
from helper_fn_short_val import *


In [5]:
import yaml
# Load the tuning configuration. The context manager closes the file handle
# once parsing is done instead of leaving it open for the kernel's lifetime.
with open("hypertune.yaml") as config_file:
    config = yaml.safe_load(config_file)

## Read data

In [6]:
# Paths taken from the YAML configuration: the class-encoding file and the
# directories holding the label and embedding files.
class_encode_path, truncated_label_path, truncated_embed_path = (
    config[key]
    for key in ('class_encode_path', 'truncated_label_path', 'truncated_embed_path')
)

In [7]:
# Validation helper built from the class-encoding file path.
# Validation_tool is presumably provided by the star-import of helper_fn_short_val — confirm.
val_tool = Validation_tool(class_encode_path) # provide your own class_encode.json

In [8]:
# Number of entries in the class encoding; used below as the size of the
# model output layer and of the all-zero label arrays.
N_METALS = len(val_tool.class_encode)

In [9]:
# List the label files, take the held-out test file out of the list, and record
# where the positive file sits among the remaining entries.
# NOTE(review): os.listdir returns entries in arbitrary order; the index-based
# pairing with the embedding files elsewhere presumably relies on both
# directories listing in the same order — confirm.
labels = os.listdir(truncated_label_path)
LABEL_TEST_INX = [pos for pos, fname in enumerate(labels) if 'test' in fname][0]
test_label_name = labels.pop(LABEL_TEST_INX)
LABEL_POS_INX = [pos for pos, fname in enumerate(labels) if 'pos' in fname][0]

In [10]:
# Show the per-class counts for the positive label file.
val_tool.dataset_class_summary(f'{truncated_label_path}{labels[LABEL_POS_INX]}')


Zn(2+)                   |3399
Mg(2+)                   |1350
[4Fe-4S] cluster         |537
Ca(2+)                   |2758
Mn(2+)                   |679
a divalent metal cation  |8648
Fe cation                |686
[2Fe-2S] cluster         |227
Cu cation                |415
K(+)                     |97
Ni(2+)                   |47
Na(+)                    |79
Fe(3+)                   |82
iron-sulfur cluster      |791
Cu(2+)                   |16
Fe(2+)                   |15
Co(2+)                   |32
a metal cation           |9850
neg                      |1056314


### Read embedding files

In [129]:
# Same bookkeeping for the embedding files: list them, take out the test file,
# and record the index of the positive file among the remaining entries.
embeds = os.listdir(truncated_embed_path)
EMBED_TEST_INX = [pos for pos, fname in enumerate(embeds) if 'test' in fname][0]
test_embeds_name = embeds.pop(EMBED_TEST_INX)
EMBED_POS_INX = [pos for pos, fname in enumerate(embeds) if 'pos' in fname][0]

In [131]:
# Open every remaining embedding file read-only; the handles are kept open and
# shared by the dataset classes defined below.
h5files = [h5py.File(f"{truncated_embed_path}{fname}", 'r') for fname in embeds]

### Train Validation Split

In [134]:
# One five-fold split per remaining label file, in the same order as `labels`.
five_fold_splits = [val_tool.five_fold_val_split(f"{truncated_label_path}{fname}") for fname in labels]

not enough data to divide ['CHEBI:29033', 'Fe(2+)'], add to all split!


## Fold generation

In [140]:
# Load every label file and merge their entries into one lookup dict.
# If a key occurs in several files, the entry from the later file wins.
label_files = [np.load(f"{truncated_label_path}{fname}") for fname in labels]
label = {acc: arr for archive in label_files for acc, arr in archive.items()}

In [None]:
def fold_generator():
    """Yield the train/validation partition for each of the five CV folds.

    Reads the module-level ``five_fold_splits`` (one entry per label file, each
    indexed as ``split['fold<k>']['train' | 'test']``), ``LABEL_POS_INX`` and
    ``label``.

    Yields:
        tuple: ``(train_pos, train_neg, valid_acc, label_valid)`` where
            * ``train_pos`` is a copy of the positive split's training list,
            * ``train_neg`` is a list with one training list per non-positive
              split, in split order,
            * ``valid_acc`` is the positive split's test list followed by the
              test lists of every non-positive split,
            * ``label_valid`` is the subset of ``label`` whose keys are in
              ``valid_acc``.
    """
    n_splits = len(five_fold_splits)
    for fold in range(5):
        fold_key = f'fold{fold}'

        train_pos = five_fold_splits[LABEL_POS_INX][fold_key]['train'].copy()
        # The split index gets its own name: reusing the fold variable inside
        # the comprehension made every split read 'fold<split index>' instead
        # of the current fold.
        train_neg = [five_fold_splits[split][fold_key]['train'].copy()
                     for split in range(n_splits) if split != LABEL_POS_INX]

        valid_acc = five_fold_splits[LABEL_POS_INX][fold_key]['test'].copy()
        # Visit every split; stopping one short dropped the last split's
        # validation accessions whenever it was not the positive one.
        for split in range(n_splits):
            if split != LABEL_POS_INX:
                valid_acc.extend(five_fold_splits[split][fold_key]['test'])

        # Set membership keeps the filter linear in the size of `label`.
        valid_lookup = set(valid_acc)
        label_valid = {key: value for key, value in label.items() if key in valid_lookup}

        yield train_pos, train_neg, valid_acc, label_valid


## TFE model

In [241]:
class MyStreamDataset_over_sample_TFE:
    """Training dataset: all positive accessions plus one randomly chosen negative shard.

    On construction a shard index is drawn with ``np.random.randint`` and used
    to pick both the negative accession list and the negative embedding file,
    so ``neg_acc_ls`` and ``neg`` must be ordered consistently.

    Args:
        pos_acc: list of positive accessions to train on.
        neg_acc_ls: list of negative accession lists, one per shard.
        labels: mapping accession -> label array. A label of shape ``(1,)`` is
            replaced by an all-zero ``(N_METALS, protein_length)`` array.
        pos: mapping-like embedding store (h5 file) for the positives.
        neg: sequence of embedding stores, indexed like ``neg_acc_ls``.
            Must be provided even though it defaults to ``None``.
        precision: numpy dtype for the zero labels; also selects the torch
            dtype of the returned label (float16 if ``np.float16`` else float32).
    """

    def __init__(self, pos_acc, neg_acc_ls, labels, pos, neg=None, precision=np.float16):
        self.pos_f = pos  # h5 file
        self.labels = labels  # labels for both pos and neg
        self.index = np.random.randint(0, len(neg_acc_ls))
        print(f"neg index: {self.index}")
        self.neg_f = neg[self.index]
        self.acc_ls = pos_acc + neg_acc_ls[self.index]
        self.pos_acc = list(self.pos_f.keys())
        # Hashed copy of the keys: __getitem__ runs once per sample, and a list
        # scan there costs time proportional to the number of positives.
        self._pos_acc_set = set(self.pos_acc)
        # Width of one embedding row, read from the first positive entry.
        self.dim = self.pos_f[self.pos_acc[0]][()].shape[1]
        self.precision = precision

    def __len__(self):
        """Number of accessions (positives plus the selected negative shard)."""
        return len(self.acc_ls)

    def __getitem__(self, idx):
        """Return ``(embedding, label_tensor)`` for the accession at ``idx``."""
        acc = self.acc_ls[idx]
        if acc in self._pos_acc_set:
            embedding = self.pos_f[acc][()]
        else:
            embedding = self.neg_f[acc][()]

        label = self.labels[acc]

        prot_len = embedding.shape[0]

        # A (1,)-shaped label is a placeholder: expand it to all zeros.
        if label.shape == (1,):
            label = np.zeros((N_METALS, prot_len), dtype=self.precision)

        if self.precision == np.float16:
            torch_type = torch.float16
        else:
            torch_type = torch.float32
        return embedding, torch.tensor(label, dtype=torch_type)

    def padding(self, batch, maxlen):
        """Zero-pad each embedding to ``maxlen`` rows and build matching 0/1 masks.

        Returns:
            tuple: stacked float features ``(batch, maxlen, dim)`` and stacked
            long masks ``(batch, maxlen)`` with 1 on real positions.
        """
        batch_protein_feat = []
        batch_protein_mask = []
        for protein_feat in batch:
            padded_protein_feat = np.zeros((maxlen, self.dim))
            padded_protein_feat[:protein_feat.shape[0]] = protein_feat
            padded_protein_feat = torch.tensor(
                padded_protein_feat, dtype=torch.float)
            batch_protein_feat.append(padded_protein_feat)

            protein_mask = np.zeros(maxlen)
            protein_mask[:protein_feat.shape[0]] = 1
            protein_mask = torch.tensor(protein_mask, dtype=torch.long)
            batch_protein_mask.append(protein_mask)

        return torch.stack(batch_protein_feat), torch.stack(batch_protein_mask)

    def collate_fn(self, batch):
        """Collate samples into ``(padded_features, hstacked_labels, masks)``.

        Labels are concatenated along their second axis, so they line up with
        the unpadded positions selected by the masks.
        """
        features, labels = zip(*batch)
        max_batch_len = max([x.shape[0] for x in features])
        features, masks = self.padding(features, max_batch_len)
        return features, torch.hstack(labels), masks


In [242]:
class MyStreamDataset_TFE:
    """Dataset that looks each accession up across several embedding stores.

    Args:
        acc_ls: list of accessions to serve.
        labels: mapping accession -> label array. A label of shape ``(1,)`` is
            replaced by an all-zero ``(N_METALS, protein_length)`` array.
        files: sequence of mapping-like embedding stores (h5 files); the first
            one containing an accession is used.
        precision: numpy dtype for the zero labels; also selects the torch
            dtype of the returned label (float16 if ``np.float16`` else float32).
    """

    def __init__(self, acc_ls, labels, files, precision=np.float16):
        self.files = files  # h5 file
        self.acc_ls = acc_ls
        self.labels = labels  # labels for both pos and neg
        f0 = self.files[0]
        # Width of one embedding row, read from the first entry of the first file.
        self.dim = f0[list(f0.keys())[0]][()].shape[1]
        self.precision = precision
        # Key views per file index, used for the membership test in __getitem__.
        self.dc_ls = {i: files[i].keys() for i in range(len(files))}

    def __len__(self):
        """Number of accessions served by this dataset."""
        return len(self.acc_ls)

    def __getitem__(self, idx):
        """Return ``(embedding, label_tensor)`` for the accession at ``idx``.

        Raises:
            KeyError: if the accession is not present in any embedding file.
        """
        acc = self.acc_ls[idx]
        for i, f in self.dc_ls.items():
            if acc in f:
                embedding = self.files[i][acc][()]
                break
        else:
            # Without this the failure only surfaced later as an
            # UnboundLocalError on `embedding`.
            raise KeyError(f"accession {acc!r} not found in any embedding file")
        label = self.labels[acc]

        prot_len = embedding.shape[0]

        # A (1,)-shaped label is a placeholder: expand it to all zeros.
        if label.shape == (1,):
            label = np.zeros((N_METALS, prot_len), dtype=self.precision)

        if self.precision == np.float16:
            torch_type = torch.float16
        else:
            torch_type = torch.float32
        return embedding, torch.tensor(label, dtype=torch_type)

    def padding(self, batch, maxlen):
        """Zero-pad each embedding to ``maxlen`` rows and build matching 0/1 masks.

        Returns:
            tuple: stacked float features ``(batch, maxlen, dim)`` and stacked
            long masks ``(batch, maxlen)`` with 1 on real positions.
        """
        batch_protein_feat = []
        batch_protein_mask = []
        for protein_feat in batch:
            padded_protein_feat = np.zeros((maxlen, self.dim))
            padded_protein_feat[:protein_feat.shape[0]] = protein_feat
            padded_protein_feat = torch.tensor(
                padded_protein_feat, dtype=torch.float)
            batch_protein_feat.append(padded_protein_feat)

            protein_mask = np.zeros(maxlen)
            protein_mask[:protein_feat.shape[0]] = 1
            protein_mask = torch.tensor(protein_mask, dtype=torch.long)
            batch_protein_mask.append(protein_mask)

        return torch.stack(batch_protein_feat), torch.stack(batch_protein_mask)

    def collate_fn(self, batch):
        """Collate samples into ``(padded_features, hstacked_labels, masks)``."""
        features, labels = zip(*batch)
        max_batch_len = max([x.shape[0] for x in features])
        features, masks = self.padding(features, max_batch_len)
        return features, torch.hstack(labels), masks


In [248]:


class Self_Attention(nn.Module):
    """Multi-head dot-product attention with optional learned projections.

    The last dimension of the inputs is split into ``num_heads`` heads of size
    ``int(num_hidden / num_heads)``. When ``weight_matrix`` is False the
    q/k/v/output linear layers are still created but not applied.
    Scores are the raw ``q @ k^T`` products (no scaling factor is applied).
    """

    def __init__(self, num_hidden, num_heads=4, weight_matrix=False):
        super().__init__()
        self.num_heads = num_heads
        self.attention_head_size = int(num_hidden / num_heads)
        self.all_head_size = self.num_heads * self.attention_head_size
        # Layers are created in the order q, k, v, o.
        self.wq = nn.Linear(num_hidden, self.all_head_size)
        self.wk = nn.Linear(num_hidden, self.all_head_size)
        self.wv = nn.Linear(num_hidden, self.all_head_size)
        self.wo = nn.Linear(self.all_head_size, num_hidden)
        self.weight_matrix = weight_matrix

    def transpose_for_scores(self, x):
        """Split the last axis into heads and move the head axis to position 1."""
        head_shape = (*x.size()[:-1], self.num_heads, self.attention_head_size)
        return x.view(head_shape).permute(0, 2, 1, 3)

    def forward(self, q, k, v, mask=None):
        """Attend ``q`` over ``k``/``v``.

        ``mask`` (1 = keep, 0 = ignore) is turned into an additive bias of
        -10000 on ignored positions and broadcast over heads and query rows.
        """
        if self.weight_matrix:
            q, k, v = self.wq(q), self.wk(k), self.wv(v)
        q, k, v = (self.transpose_for_scores(t) for t in (q, k, v))

        scores = q @ k.transpose(-1, -2)
        if mask is not None:
            additive = (1.0 - mask) * -10000
            scores = scores + additive[:, None, None]
        weights = torch.softmax(scores, dim=-1)

        # Merge the heads back into a single trailing axis.
        context = (weights @ v).permute(0, 2, 1, 3).contiguous()
        merged = context.view(*context.size()[:-2], self.all_head_size)
        return self.wo(merged) if self.weight_matrix else merged


class PositionWiseFeedForward(nn.Module):
    """Two-layer feed-forward block: Linear -> leaky ReLU -> Linear.

    Expands the last axis from ``num_hidden`` to ``num_ff`` and projects it
    back to ``num_hidden``.
    """

    def __init__(self, num_hidden, num_ff):
        super().__init__()
        self.W_in = nn.Linear(num_hidden, num_ff, bias=True)
        self.W_out = nn.Linear(num_ff, num_hidden, bias=True)

    def forward(self, h_V):
        """Apply the block to ``h_V`` along its last axis."""
        return self.W_out(F.leaky_relu(self.W_in(h_V)))


class TransformerLayer(nn.Module):
    """Post-norm encoder layer: attention and feed-forward, each with a residual.

    Each sub-block output goes through dropout, is added to its input and then
    layer-normalised. The feed-forward inner width is ``4 * num_hidden``.
    """

    def __init__(self, num_hidden=64, num_heads=4, dropout=0.2):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.ModuleList(
            [nn.LayerNorm(num_hidden, eps=1e-6) for _ in range(2)])

        # Attention is built with its default weight_matrix=False.
        self.attention = Self_Attention(num_hidden, num_heads)
        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)

    def forward(self, h_V, mask=None):
        """Run the layer; when ``mask`` is given the output is multiplied by it."""
        attn_norm, ff_norm = self.norm

        attended = self.attention(h_V, h_V, h_V, mask)
        h_V = attn_norm(h_V + self.dropout(attended))

        h_V = ff_norm(h_V + self.dropout(self.dense(h_V)))

        if mask is None:
            return h_V
        # Zero the rows of masked-out positions.
        return mask.unsqueeze(-1) * h_V


class MetalBPredictor(nn.Module):
    """Per-position classifier: input MLP, transformer encoder stack, linear head.

    The head emits ``N_METALS`` logits per position. All parameters with more
    than one dimension are re-initialised with Xavier-uniform after
    construction.
    """

    def __init__(self, feature_dim, hidden_dim=64, num_encoder_layers=2, num_heads=4, dropout=0.2):
        super().__init__()

        self.input_block = nn.Sequential(
            nn.LayerNorm(feature_dim, eps=1e-6),
            nn.Linear(feature_dim, hidden_dim),
            nn.LeakyReLU(),
        )

        self.hidden_block = nn.Sequential(
            nn.LayerNorm(hidden_dim, eps=1e-6),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.LayerNorm(hidden_dim, eps=1e-6),
        )

        stack = []
        for _ in range(num_encoder_layers):
            stack.append(TransformerLayer(hidden_dim, num_heads, dropout))
        self.encoder_layers = nn.ModuleList(stack)

        self.dense = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(hidden_dim, N_METALS)

        # Xavier-initialise every weight matrix, leaving 1-D parameters alone.
        for param in self.parameters():
            if param.dim() <= 1:
                continue
            nn.init.xavier_uniform_(param)

    def forward(self, protein_feat, mask):
        """Return logits with the first two axes of the head output flattened."""
        h_V = self.hidden_block(self.input_block(protein_feat))

        for encoder in self.encoder_layers:
            h_V = encoder(h_V, mask)

        logits = self.out_proj(self.dropout(self.dense(h_V)))
        return logits.flatten(end_dim=1)

In [249]:
class TransformerModel(pl.LightningModule):
    """Lightning module that trains ``MetalBPredictor`` with a weighted BCE loss.

    Args:
        train_pos: positive training accessions, passed to the training dataset.
        train_neg: list of negative accession lists, passed to the training dataset.
        feature_dim, hidden_dim, num_encoder_layers, num_heads, dropout:
            forwarded to ``MetalBPredictor``.
        lr: learning rate for AdamW.
        label_weight: ``(weight_for_0, weight_for_1)`` applied element-wise in the loss.
        batch_size: batch size of the training dataloader.
        thres_tune: if True, ``on_test_epoch_end`` only stores ``y_true`` /
            ``y_pred`` on the module instead of reporting metrics.
    """

    def __init__(self, train_pos, train_neg, feature_dim, hidden_dim=64, num_encoder_layers=2, num_heads=4, dropout=0.2, lr=1e-3, label_weight=[0.228, 5.802], batch_size=32, thres_tune=False):
        super().__init__()
        self.encoder = MetalBPredictor(
            feature_dim, hidden_dim, num_encoder_layers, num_heads, dropout)
        # NOTE(review): this records every __init__ argument, including the
        # train_pos / train_neg accession lists — confirm that is wanted in checkpoints.
        self.save_hyperparameters()

        self.val_loss = 0
        self.test_loss = 0
        self.learning_rate = lr
        self.label_weight = label_weight
        self.batch_size = batch_size
        self.train_pos = train_pos
        self.train_neg = train_neg
        self.val_y = []
        self.val_pred = []
        self.test_y = []
        self.test_pred = []
        self.thres_tune = thres_tune

    def forward(self, x, mask):
        """Return the encoder logits with size-1 dimensions squeezed out."""
        x = self.encoder(x, mask)
        return torch.squeeze(x)

    def loss_fn(self, out, target):
        """Weighted BCE-with-logits: mean over axis 0, then summed over the rest."""
        weight = torch.ones_like(target)
        weight[target == 0] = self.label_weight[0]
        weight[target == 1] = self.label_weight[1]

        return torch.nn.BCEWithLogitsLoss(reduction='none', weight=weight)(out, target).mean(axis=0).sum()

    def configure_optimizers(self):
        """AdamW with cosine annealing (``T_max=64``)."""
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=self.learning_rate, weight_decay=1e-5)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=64)
        return [optimizer], [lr_scheduler]

    def training_step(self, batch, batch_idx):
        """Compute and log the loss on the unpadded positions of one batch."""
        x, y, masks = batch
        out = self(x, masks)
        # Keep only the rows whose flattened mask entry is 1 (real positions).
        out = out[(torch.flatten(masks) == 1), :]
        loss = self.loss_fn(out, y.T)
        self.log('train_loss', loss)
        return loss

    def train_dataloader(self):
        """Build a fresh training dataset and loader.

        Uses the module-level ``label``, ``h5files`` and ``EMBED_POS_INX``.
        The dataset draws a random negative shard each time it is constructed.
        """
        train_dataset = MyStreamDataset_over_sample_TFE(self.train_pos, self.train_neg, label, h5files[EMBED_POS_INX], [h5files[i] for i in range(len(h5files)) if i != EMBED_POS_INX], precision=np.float32)

        train_dataloader = torch.utils.data.DataLoader(
            train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=train_dataset.collate_fn, drop_last=True)
        return train_dataloader

    def validation_step(self, batch, batch_idx):
        """Accumulate loss, targets and logits for the epoch-end metrics."""
        x, y, masks = batch
        out = self(x, masks)
        out = out[(torch.flatten(masks) == 1), :]
        loss = self.loss_fn(out, y.T)
        self.val_loss += loss
        self.val_y.append(y)
        self.val_pred.append(out.T)

    def on_validation_epoch_start(self):
        """Reset the validation accumulators."""
        self.val_loss = 0
        self.val_y = []
        self.val_pred = []

    def on_validation_epoch_end(self):
        """Log mean validation loss plus the 'mean MCC' / 'mean AUPR' metrics at threshold 0.5."""
        y = torch.hstack(self.val_y)
        out = torch.hstack(self.val_pred)
        out = torch.sigmoid(out)
        y_true = y.detach().cpu().numpy()
        y_pred = out.detach().cpu().numpy()
        dc = val_tool.sum_full_metrics(y_true, y_pred, 0.5)
        # Average over the number of validation batches seen this epoch.
        self.val_loss = self.val_loss / len(self.val_y)
        loss = self.val_loss
        print(f'Validation loss: {loss}')
        self.log('val_loss', loss)
        self.log('MCC', dc['mean MCC'])
        self.log('AUPR', dc['mean AUPR'])

    def on_test_epoch_start(self):
        """Reset the test accumulators.

        Without this, a second test run on the same module would concatenate
        its batches onto the previous run's and report a mixed loss/metric.
        """
        self.test_loss = 0
        self.test_y = []
        self.test_pred = []

    def test_step(self, batch, batch_idx):
        """Accumulate loss, targets and logits for the epoch-end report."""
        x, y, masks = batch
        out = self(x, masks)
        out = out[(torch.flatten(masks) == 1), :]
        loss = self.loss_fn(out, y.T)
        self.test_loss += loss
        self.test_y.append(y)
        self.test_pred.append(out.T)

    def on_test_epoch_end(self):
        """Store raw predictions (``thres_tune``) or report metrics and log the mean test loss."""
        y = torch.hstack(self.test_y)
        out = torch.hstack(self.test_pred)
        out = torch.sigmoid(out)
        y_true = y.detach().cpu().numpy()
        y_pred = out.detach().cpu().numpy()

        if self.thres_tune:
            # Expose the arrays so a caller can tune the threshold afterwards.
            self.y_true = y_true
            self.y_pred = y_pred
        else:
            dc = val_tool.sum_2metrics(y_true, y_pred, 0.5)
            val_tool.plot_pr_curve(y_true, y_pred)
            self.test_loss = self.test_loss / len(self.test_y)
            self.log('test_loss', self.test_loss)


## CNN model

In [250]:

class MyStreamDataset_over_sample_CNN:
    """Training dataset for the CNN: all positives plus one random negative shard.

    On construction a shard index is drawn with ``np.random.randint`` and used
    to pick both the negative accession list and the negative embedding file,
    so ``neg_acc_ls`` and ``neg`` must be ordered consistently. Every sample is
    returned channel-first and padded to ``max_len`` positions.

    Args:
        pos_acc: list of positive accessions to train on.
        neg_acc_ls: list of negative accession lists, one per shard.
        labels: mapping accession -> label array. A label of shape ``(1,)`` is
            replaced by an all-zero ``(N_METALS, protein_length)`` array.
        pos: mapping-like embedding store (h5 file) for the positives.
        neg: sequence of embedding stores, indexed like ``neg_acc_ls``.
            Must be provided even though it defaults to ``None``.
        max_len: padded length; an embedding with more rows than this cannot
            be placed into the padded array.
        precision: numpy dtype of the padded features and zero labels; also
            selects the torch dtype (float16 if ``np.float16`` else float32).
    """

    def __init__(self, pos_acc, neg_acc_ls, labels, pos, neg=None, max_len=512, precision=np.float16):
        self.pos_f = pos  # h5 file
        self.labels = labels  # labels for both pos and neg
        self.index = np.random.randint(0, len(neg_acc_ls))
        print(f"neg index: {self.index}")
        self.neg_f = neg[self.index]
        self.acc_ls = pos_acc + neg_acc_ls[self.index]
        self.pos_acc = list(self.pos_f.keys())
        # Hashed copy of the keys: __getitem__ runs once per sample, and a list
        # scan there costs time proportional to the number of positives.
        self._pos_acc_set = set(self.pos_acc)
        # Width of one embedding row, read from the first positive entry.
        self.dim = self.pos_f[self.pos_acc[0]][()].shape[1]
        self.max_len = max_len
        self.precision = precision

    def __len__(self):
        """Number of accessions (positives plus the selected negative shard)."""
        return len(self.acc_ls)

    def __getitem__(self, idx):
        """Return ``(features, label, mask)`` for the accession at ``idx``.

        ``features`` is ``(dim, max_len)`` with the transposed embedding in the
        leading columns; ``mask`` is a bool vector marking those columns.
        """
        acc = self.acc_ls[idx]
        if acc in self._pos_acc_set:
            embedding = self.pos_f[acc][()]
        else:
            embedding = self.neg_f[acc][()]

        label = self.labels[acc]

        prot_len = embedding.shape[0]

        # A (1,)-shaped label is a placeholder: expand it to all zeros.
        if label.shape == (1,):
            label = np.zeros((N_METALS, prot_len), dtype=self.precision)

        if self.precision == np.float16:
            torch_type = torch.float16
        else:
            torch_type = torch.float32

        features = np.zeros((self.dim, self.max_len), dtype=self.precision)
        features[:, :prot_len] = np.transpose(embedding)
        mask = np.zeros((self.max_len), dtype=bool)
        mask[:prot_len] = True

        return torch.tensor(features, dtype=torch_type), torch.tensor(label, dtype=torch_type), torch.tensor(mask, dtype=torch.bool)

    def collate_fn(self, batch):
        """Stack features; concatenate labels (axis 1) and masks (flat) across the batch."""
        features, labels, masks = zip(*batch)
        return torch.stack(features), torch.hstack(labels), torch.hstack(masks)


In [251]:
class MyStreamDataset_CNN:
    """CNN dataset that looks each accession up across several embedding stores.

    Every sample is returned channel-first and padded to ``max_len`` positions.

    Args:
        acc_ls: list of accessions to serve.
        labels: mapping accession -> label array. A label of shape ``(1,)`` is
            replaced by an all-zero ``(N_METALS, protein_length)`` array.
        files: sequence of mapping-like embedding stores (h5 files); the first
            one containing an accession is used.
        max_len: padded length; an embedding with more rows than this cannot
            be placed into the padded array.
        precision: numpy dtype of the padded features and zero labels; also
            selects the torch dtype (float16 if ``np.float16`` else float32).
    """

    def __init__(self, acc_ls, labels, files, max_len=512, precision=np.float16):
        self.files = files  # h5 file
        self.acc_ls = acc_ls
        self.labels = labels  # labels for both pos and neg
        f0 = self.files[0]
        # Width of one embedding row, read from the first entry of the first file.
        self.dim = f0[list(f0.keys())[0]][()].shape[1]
        self.max_len = max_len
        self.precision = precision

    def __len__(self):
        """Number of accessions served by this dataset."""
        return len(self.acc_ls)

    def __getitem__(self, idx):
        """Return ``(features, label, mask)`` for the accession at ``idx``.

        Raises:
            KeyError: if the accession is not present in any embedding file.
        """
        acc = self.acc_ls[idx]
        for f in self.files:
            # Direct membership test on the store; building a list of all keys
            # for every file on every sample made each lookup scale with the
            # total number of stored accessions.
            if acc in f:
                embedding = f[acc][()]
                break
        else:
            # Without this the failure only surfaced later as an
            # UnboundLocalError on `embedding`.
            raise KeyError(f"accession {acc!r} not found in any embedding file")
        label = self.labels[acc]

        prot_len = embedding.shape[0]

        # A (1,)-shaped label is a placeholder: expand it to all zeros.
        if label.shape == (1,):
            label = np.zeros((N_METALS, prot_len), dtype=self.precision)

        features = np.zeros((self.dim, self.max_len), dtype=self.precision)
        features[:, :prot_len] = np.transpose(embedding)
        mask = np.zeros((self.max_len), dtype=bool)
        mask[:prot_len] = True
        if self.precision == np.float16:
            torch_type = torch.float16
        else:
            torch_type = torch.float32
        return torch.tensor(features, dtype=torch_type), torch.tensor(label, dtype=torch_type), torch.tensor(mask, dtype=torch.bool)

    def collate_fn(self, batch):
        """Stack features; concatenate labels (axis 1) and masks (flat) across the batch."""
        features, labels, masks = zip(*batch)
        return torch.stack(features), torch.hstack(labels), torch.hstack(masks)


In [259]:
class Conv1dModel(pl.LightningModule):
    """Lightning module: a stack of 1-D convolutions trained with a weighted BCE loss.

    The network is ``hidden_layer_num`` Conv1d+ELU pairs whose channel count
    starts at ``hidden_channel`` and halves after each layer, followed by a
    Conv1d producing ``N_METALS`` channels.

    Args:
        train_pos: positive training accessions, passed to the training dataset.
        train_neg: list of negative accession lists, passed to the training dataset.
        in_channels: number of input channels (embedding width).
        hidden_channel: output channels of the first convolution.
        kernel_size: kernel size of every convolution.
        hidden_layer_num: number of Conv1d+ELU pairs before the output convolution.
        lr: learning rate for AdamW.
        label_weight: ``(weight_for_0, weight_for_1)`` applied element-wise in the loss.
        batch_size: batch size of the training dataloader.
        thres_tune: if True, ``on_test_epoch_end`` only stores ``y_true`` /
            ``y_pred`` on the module instead of reporting metrics.
    """

    def __init__(self, train_pos, train_neg, in_channels, hidden_channel, kernel_size, hidden_layer_num, lr=1e-3, label_weight=[0.228, 5.802], batch_size=32, thres_tune=False):
        super().__init__()
        stride = 1
        # With stride 1 this padding keeps the sequence length for odd kernel
        # sizes; an even kernel size shortens the output by one per layer.
        padding = int((kernel_size - 1) / 2)
        modules = []
        in_channel_ = in_channels
        feature_channel_ = hidden_channel
        for i in range(hidden_layer_num):
            modules.append(torch.nn.Conv1d(in_channels=in_channel_, out_channels=feature_channel_, kernel_size=kernel_size,
                                           stride=stride, padding=padding))
            modules.append(torch.nn.ELU())
            in_channel_ = feature_channel_
            # Halve the width for the next layer.
            feature_channel_ = feature_channel_//2

        modules.append(torch.nn.Conv1d(in_channels=in_channel_, out_channels=N_METALS,
                                       kernel_size=kernel_size, stride=stride, padding=padding))

        self.conv1 = torch.nn.Sequential(*modules)

        # NOTE(review): this records every __init__ argument, including the
        # train_pos / train_neg accession lists — confirm that is wanted in checkpoints.
        self.save_hyperparameters()

        self.val_loss = 0
        self.test_loss = 0
        self.learning_rate = lr
        self.label_weight = label_weight
        self.batch_size = batch_size
        self.train_pos = train_pos
        self.train_neg = train_neg
        self.val_y = []
        self.val_pred = []
        self.test_y = []
        self.test_pred = []
        self.thres_tune = thres_tune

    def forward(self, x):
        """Return the convolution stack output with size-1 dimensions squeezed out."""
        x = self.conv1(x)
        return torch.squeeze(x)

    def loss_fn(self, out, target):
        """Weighted BCE-with-logits: mean over axis 0, then summed over the rest."""
        weight = torch.ones_like(target)
        weight[target == 0] = self.label_weight[0]
        weight[target == 1] = self.label_weight[1]
        return torch.nn.BCEWithLogitsLoss(reduction='none', weight=weight)(out, target).mean(axis=0).sum()

    def configure_optimizers(self):
        """AdamW with cosine annealing (``T_max=64``)."""
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=self.learning_rate, weight_decay=1e-5)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=64)
        return [optimizer], [lr_scheduler]

    def training_step(self, batch, batch_idx):
        """Compute and log the loss on the unpadded positions of one batch."""
        x, y, masks = batch
        out = self(x)
        # Lay the per-sample outputs side by side along the position axis, then
        # keep only the columns flagged by the flat mask.
        out = torch.hstack([i for i in out])[:, masks].T
        loss = self.loss_fn(out, y.T)
        self.log('train_loss', loss)
        return loss

    def train_dataloader(self):
        """Build a fresh training dataset and loader.

        Uses the module-level ``label``, ``h5files`` and ``EMBED_POS_INX``.
        The dataset draws a random negative shard each time it is constructed.
        """
        train_dataset = MyStreamDataset_over_sample_CNN(self.train_pos, self.train_neg, label, h5files[EMBED_POS_INX], [
                                                    h5files[i] for i in range(len(h5files)) if i != EMBED_POS_INX], precision=np.float32)

        train_dataloader = torch.utils.data.DataLoader(
            train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=train_dataset.collate_fn, drop_last=True)
        return train_dataloader

    def validation_step(self, batch, batch_idx):
        """Accumulate loss, targets and logits for the epoch-end metrics."""
        x, y, masks = batch
        out = self(x)
        out = torch.hstack([i for i in out])[:, masks]
        loss = self.loss_fn(out.T, y.T)
        self.val_loss += loss
        self.val_y.append(y)
        self.val_pred.append(out)

    def on_validation_epoch_start(self):
        """Reset the validation accumulators."""
        self.val_loss = 0
        self.val_y = []
        self.val_pred = []

    def on_validation_epoch_end(self):
        """Log mean validation loss plus the 'mean MCC' / 'mean AUPR' metrics at threshold 0.5."""
        y = torch.hstack(self.val_y)
        out = torch.hstack(self.val_pred)
        out = torch.sigmoid(out)
        y_true = y.detach().cpu().numpy()
        y_pred = out.detach().cpu().numpy()
        dc = val_tool.sum_full_metrics(y_true, y_pred, 0.5)
        # Average over the number of validation batches seen this epoch.
        self.val_loss = self.val_loss / len(self.val_y)
        loss = self.val_loss
        print(f'Validation loss: {loss}')
        self.log('val_loss', loss)
        self.log('MCC', dc['mean MCC'])
        self.log('AUPR', dc['mean AUPR'])

    def on_test_epoch_start(self):
        """Reset the test accumulators.

        Without this, a second test run on the same module would concatenate
        its batches onto the previous run's and report a mixed loss/metric.
        """
        self.test_loss = 0
        self.test_y = []
        self.test_pred = []

    def test_step(self, batch, batch_idx):
        """Accumulate loss, targets and logits for the epoch-end report."""
        x, y, masks = batch
        out = self(x)
        out = torch.hstack([i for i in out])[:, masks]
        loss = self.loss_fn(out.T, y.T)
        self.test_loss += loss
        self.test_y.append(y)
        self.test_pred.append(out)

    def on_test_epoch_end(self):
        """Store raw predictions (``thres_tune``) or report metrics and log the mean test loss."""
        y = torch.hstack(self.test_y)
        out = torch.hstack(self.test_pred)
        out = torch.sigmoid(out)
        y_true = y.detach().cpu().numpy()
        y_pred = out.detach().cpu().numpy()

        if self.thres_tune:
            # Expose the arrays so a caller can tune the threshold afterwards.
            self.y_true = y_true
            self.y_pred = y_pred
        else:
            dc = val_tool.sum_2metrics(y_true, y_pred, 0.5)
            val_tool.plot_pr_curve(y_true, y_pred)
            self.test_loss = self.test_loss / len(self.test_y)
            self.log('test_loss', self.test_loss)

## Hyperparameter tuning

In [262]:
import itertools

# Grid search: for every hyperparameter combination, train one model per CV
# fold, add up the AUPR reported after each fit and divide by 5, and track the
# best combination. Progress is appended to a per-model text log.
MODEL = config['model']
BATCH_SIZE = config['batch_size']

if MODEL == 'CNN2L':

    AUPR_ls = []

    # Embedding width, read from the first entry of the first embedding file.
    feature_dim = h5files[0][list(h5files[0].keys())[0]][()].shape[1]

    hidden_channel_ls = config[MODEL]['hidden_channel']
    hidden_layer_num_ls = config[MODEL]['hidden_layer_num']
    kernel_size_ls = config[MODEL]['kernel_size']
    lr_ls = config[MODEL]['lr']
    label_weight_ls = [tuple(i) for i in config[MODEL]['label_weight']]

    max_AUPR = 0
    # Defined up front so the summary below never hits an unbound name when no
    # combination scores above the initial max_AUPR of 0.
    best_AUPR_hyp = None

    for i in itertools.product(hidden_channel_ls, kernel_size_ls, hidden_layer_num_ls, lr_ls, label_weight_ls):

        one_hyp_AUPR = 0
        paras = f"feature_channel: {i[0]}, kernel_size: {i[1]}, hidden_layer_num: {i[2]}, lr: {i[3]}, label_weight: {i[4]}"
        print(paras)
        with open('cnn_hyper.txt', 'a') as f:
            f.write(paras + '\n')

        hidden_channel = i[0]
        kernel_size = i[1]
        hidden_layer_num = i[2]
        lr = i[3]
        label_weight = i[4]
        fold_n = 1
        for trp, trn, val, val_l in fold_generator():
            print(f"==================== fold {fold_n} ====================")

            # The CNN needs the channel-first, fixed-length samples and flat
            # boolean masks produced by MyStreamDataset_CNN; the TFE dataset
            # used here before yields a different layout and a long 2-D mask.
            val_dataset = MyStreamDataset_CNN(
                val, val_l, h5files, precision=np.float32)
            val_dataloader = torch.utils.data.DataLoader(
                val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=val_dataset.collate_fn, drop_last=True)
            model = Conv1dModel(trp, trn, feature_dim, hidden_channel,
                                    kernel_size, hidden_layer_num, lr=lr, label_weight=label_weight, batch_size=BATCH_SIZE)
            early_stopping = EarlyStopping(monitor='AUPR', patience=2, mode='max')
            checkpoint_callback = ModelCheckpoint(
                filename='{epoch}-{val_loss:.6f}-{MCC:.3f}-{AUPR:.3f}', save_top_k=3, monitor="AUPR", save_last=True, mode='max')
            # The training dataloader is rebuilt every epoch, so a new negative
            # shard is drawn each time.
            trainer = pl.Trainer(accelerator='gpu', min_epochs=32, max_epochs = 300, precision=32, callbacks=[early_stopping, checkpoint_callback], check_val_every_n_epoch=10, reload_dataloaders_every_n_epochs=1)
            trainer.fit(model=model, val_dataloaders=val_dataloader)
            one_hyp_AUPR += trainer.callback_metrics['AUPR']

            one_fold_summary = f"Fold {fold_n} val loss: {trainer.callback_metrics['val_loss']} AUPR: {trainer.callback_metrics['AUPR']}"
            print(one_fold_summary)
            # Same log file as the parameter line and the search summary
            # (this line used to go to 'cnn_hyper3.txt').
            with open('cnn_hyper.txt', 'a') as f:
                f.write(one_fold_summary + '\n')
            fold_n += 1

        # fold_generator yields five folds.
        one_hyp_AUPR /= 5
        AUPR_ls.append(one_hyp_AUPR)
        if one_hyp_AUPR > max_AUPR:
            max_AUPR = one_hyp_AUPR
            best_AUPR_hyp = i

        one_search_summary = f"====================Hyper {i} AUPR: {one_hyp_AUPR} current best AUPR: {max_AUPR} ====================" + \
            '\n' + \
            f"==================== Best AUPR hyper: {best_AUPR_hyp} ====================" + '\n'
        with open('cnn_hyper.txt', 'a') as f:
            f.write(one_search_summary)

        print(one_search_summary)


elif MODEL == 'TFE':
    AUPR_ls = []

    # Embedding width, read from the first entry of the first embedding file.
    feature_dim = h5files[0][list(h5files[0].keys())[0]][()].shape[1]
    hidden_dim_ls = config[MODEL]['hidden_dim']
    num_encoder_layers_ls = config[MODEL]['num_encoder_layers']
    num_heads_ls = config[MODEL]['num_heads']
    dropout_ls = config[MODEL]['dropout']
    lr_ls = config[MODEL]['lr']
    label_weight_ls = [tuple(i) for i in config[MODEL]['label_weight']]

    max_AUPR = 0
    # Defined up front so the summary below never hits an unbound name when no
    # combination scores above the initial max_AUPR of 0.
    best_AUPR_hyp = None

    for i in itertools.product(hidden_dim_ls, num_encoder_layers_ls, num_heads_ls, dropout_ls, lr_ls, label_weight_ls):

        one_hyp_AUPR = 0
        paras = f"hidden_dim: {i[0]}, num_encoder_layers: {i[1]}, num_heads: {i[2]}, dropout: {i[3]}, lr: {i[4]}, label_weight: {i[5]}"
        print(paras)
        with open('transformer_hyp.txt', 'a') as f:
            f.write(paras + '\n')
        hidden_dim = i[0]
        num_encoder_layers = i[1]
        num_heads = i[2]
        dropout = i[3]
        lr = i[4]
        label_weight = i[5]
        fold_n = 1
        for trp, trn, val, val_l in fold_generator():
            print(f"==================== fold {fold_n} ====================")

            val_dataset = MyStreamDataset_TFE(
                val, val_l, h5files, precision=np.float32)
            # NOTE(review): validation batch size is hard-coded to 16 here while
            # the CNN branch uses BATCH_SIZE — confirm this is intended.
            val_dataloader = torch.utils.data.DataLoader(
                val_dataset, batch_size=16, shuffle=False, collate_fn=val_dataset.collate_fn, drop_last=True)
            model = TransformerModel(trp, trn, feature_dim, hidden_dim,
                                    num_encoder_layers, num_heads, dropout, lr=lr, label_weight=label_weight, batch_size=BATCH_SIZE)
            early_stopping = EarlyStopping(monitor='AUPR', patience=2, mode='max')
            checkpoint_callback = ModelCheckpoint(
                filename='{epoch}-{val_loss:.6f}-{MCC:.3f}-{AUPR:.3f}', save_top_k=3, monitor="AUPR", save_last=True, mode='max')
            # The training dataloader is rebuilt every epoch, so a new negative
            # shard is drawn each time.
            trainer = pl.Trainer(accelerator='gpu', min_epochs=50, max_epochs = 300, precision=32, callbacks=[early_stopping, checkpoint_callback], check_val_every_n_epoch=10, reload_dataloaders_every_n_epochs=1)
            trainer.fit(model=model, val_dataloaders=val_dataloader)
            one_hyp_AUPR += trainer.callback_metrics['AUPR']

            one_fold_summary = f"Fold {fold_n} val loss: {trainer.callback_metrics['val_loss']} AUPR: {trainer.callback_metrics['AUPR']}"
            print(one_fold_summary)
            with open('transformer_hyp.txt', 'a') as f:
                f.write(one_fold_summary + '\n')
            fold_n += 1

        # fold_generator yields five folds.
        one_hyp_AUPR /= 5
        AUPR_ls.append(one_hyp_AUPR)
        if one_hyp_AUPR > max_AUPR:
            max_AUPR = one_hyp_AUPR
            best_AUPR_hyp = i

        one_search_summary = f"==================== Hyper {i} AUPR: {one_hyp_AUPR} current best AUPR: {max_AUPR} ====================" + \
            '\n' + \
            f"==================== Best AUPR hyper: {best_AUPR_hyp} ====================" + '\n'
        with open('transformer_hyp.txt', 'a') as f:
            f.write(one_search_summary)

        print(one_search_summary)