In [None]:
# # This Python 3 environment comes with many helpful analytics libraries installed
# # It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# # For example, here's several helpful packages to load

# import numpy as np # linear algebra
# import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# # Input data files are available in the read-only "../input/" directory
# # For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# # You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# # You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# import os
# for dir, _, files in os.walk('kaggle/input'):

In [None]:
# !pip list
# !nvidia-smi
# !sudo find / -name 'libcudart.so.11.0'
# !conda install cudatoolkit
# %cd /usr/local/cuda-12.1/targets/x86_64-linux/lib
# %ls

In [None]:
## install torch 2.2.1 and cuda 12.1
!pip install torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cu121
!pip install lightning
# !pip install  dgl -f https://data.dgl.ai/wheels/torch-2.1/cu118/repo.html
!pip install  dgl -f https://data.dgl.ai/wheels/torch-2.2/cu121/repo.html
# !pip install wandb
# !pip install scikit-learn
# !pip install transformers

## install for torch 2.1.0 and cuda 11.8
# !pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118
# !pip install lightning
# !pip install  dgl -f https://data.dgl.ai/wheels/torch-2.1/cu118/repo.html

## Export environment variables

In [None]:
# Select DGL's PyTorch backend; this must run before `dgl` is first imported (models cell below).
%env DGLBACKEND=pytorch

# Constants (ported from param.py)

In [None]:
# import os
import platform
from typing import Literal

OS_PLATFORM = platform.system()

SP_LABELS = dict(NO_SP=0, SP=1, LIPO=2, TAT=3, PILIN=4, TATLIPO=5)
ORGANISMS = dict(EUKARYA=0, POSITIVE=1, NEGATIVE=2, ARCHAEA=3)

"""
DATA PREPARATION
"""
USE_PREPARED_DATA = False
TRAIN_PATH = 'data/sp_data/train_set.fasta'
BENCHMARK_PATH = 'data/sp_data/benchmark_set_sp5.fasta'

USE_SPLIT_DATASET = False
ON_ORGANISM: Literal['eukarya', 'others'] = 'others'  # use when you set `USE_SPLIT_DATASET=True`

"""
MODEL AND TRAINER CONFIGURATION
"""
# Training
MODEL_TYPE = "gconv"
DATA_TYPE: Literal['aa', 'smiles', 'graph'] = 'graph'
CONF_TYPE = 'default'
EPOCHS = 100
# ENV = 'kaggle'
USE_ORGANISM = True

# MODEL_TYPE = "transformer"
# DATA_TYPE: Literal['aa', 'smiles', 'graph'] = 'aa'
# CONF_TYPE = 'default'
# EPOCHS = 1
# # ENV = 'kaggle'
# USE_ORGANISM = True

BATCH_SIZE = 8
LEARNING_RATE = 1e-7
NUM_WORKERS = 0  # set to 0 because of some random_seeding reason
FREEZE_PRETRAINED = False

DEVICES: list[int] | str | int = 'auto'
ACCELERATOR = 'auto'

ENABLE_CHECKPOINTING = True

# Testing
CHECKPOINT: str = "bert_pretrained-aa-default-1_epochs=100.ckpt"

# DEVICE = 'cpu'  # use when applying old training process

"""
LOGGER CONFIGURATION
"""
USE_LOGGER = False
INPUT_DIR = '/kaggle/input'
WORKING_DIR = '/kaggle/working'

LOG_DIR = 'logs'  # relative path

# Directory _callbacks_

In [None]:
# import all callback in callback_utils.py here
from pathlib import Path
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

# Checkpoint file name: <model>-<data>-<conf>-<used_org>_epochs=<epochs>[_v<ver>].ckpt
filename = '-'.join([MODEL_TYPE, DATA_TYPE, CONF_TYPE, str(int(USE_ORGANISM))]) + f'_epochs={EPOCHS}'

# Keep only the best checkpoint by validation loss, written under <WORKING_DIR>/checkpoints.
model_checkpoint = ModelCheckpoint(
    dirpath=str(Path(WORKING_DIR) / 'checkpoints'),
    filename=filename,
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    every_n_epochs=1,
    save_on_train_epoch_end=True,
    enable_version_counter=True,
)
# Join the name and the version suffix with '_' instead of Lightning's default '-'.
model_checkpoint.CHECKPOINT_JOIN_CHAR = '_'

# Stop after 11 epochs without any decrease in validation loss (or when it is not finite).
early_stopping = EarlyStopping(
    monitor='val_loss',
    mode='min',
    min_delta=0.0,
    patience=11,
    check_finite=True,
    verbose=True,
)


# Directory _configs_

In [None]:
# import all utils in config_utils.py
import json

from transformers import BertConfig


def load_config(model, data, conf_type):
    """Load the configuration for ``model`` trained on data type ``data``.

    Args:
        model: model type key, e.g. ``"bert_pretrained"``, ``"bert"``, ``"gconv"``.
        data: data type key; selects the ``sppredictor-config-<data>`` directory
            under ``INPUT_DIR``.
        conf_type: configuration variant used in the file name
            ``<model>_config_<conf_type>.json``. It is ignored for ``"bert"``, which
            always loads ``bert_config_default.json``. It is also ignored for
            ``"bert_pretrained"``, which fetches the ``Rostlab/prot_bert`` config.

    Returns:
        A ``BertConfig`` for the two BERT models, otherwise the parsed JSON object.

    Raises:
        FileNotFoundError: if the required JSON config file does not exist.
    """
    config_dir = Path(INPUT_DIR, f'sppredictor-config-{data}')

    if model == "bert_pretrained":
        # from_pretrained is a classmethod; no throw-away default instance needed.
        return BertConfig.from_pretrained("Rostlab/prot_bert")

    if model == "bert":
        with open(config_dir / f'{model}_config_default.json') as f:
            return BertConfig(**json.load(f))

    conf_path = config_dir / f'{model}_config_{conf_type}.json'
    # pathlib instead of os.path: `os` is not imported anywhere in this notebook.
    if not conf_path.is_file():
        raise FileNotFoundError(f"Config file does not exist: {conf_path}")
    with open(conf_path, 'r') as f:
        return json.load(f)




## Metrics

In [None]:
from typing import Any, Literal, Optional

import torch
from torchmetrics import MatthewsCorrCoef, Metric
from torchmetrics.classification import MulticlassMatthewsCorrCoef, BinaryMatthewsCorrCoef, MultilabelMatthewsCorrCoef
from torchmetrics.utilities.enums import ClassificationTask


def _matthews_corrcoef_non_average(confmat: torch.Tensor):
    mcc = []
    tps = torch.diag(confmat)
    fps = torch.sum(confmat, dim=0) - tps
    fns = torch.sum(confmat, dim=1) - tps
    tns = torch.sum(confmat) - (tps + fns + fps)

    for tp, fp, fn, tn in zip(tps, fps, fns, tns):
        numerator = (tp * tn - fp * fn)
        denominator = torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
        if denominator == 0:
            mcc.append(0)
        else:
            mcc.append((numerator / denominator).item())
    return torch.tensor(mcc, device=confmat.device)


class MulticlassMatthewsCorrCoefNoneAverage(MulticlassMatthewsCorrCoef):
    """Multiclass MCC that reports one score per class instead of a single aggregate.

    The confusion-matrix state (``self.confmat``) is accumulated by the parent
    metric; only the final reduction is replaced.
    """

    def compute(self) -> Any:
        per_class_mcc = _matthews_corrcoef_non_average(confmat=self.confmat)
        return per_class_mcc


class MCC(MatthewsCorrCoef):
    """Factory for Matthews correlation coefficient metrics.

    ``__new__`` returns a concrete torchmetrics metric rather than an ``MCC`` instance:

    * ``average="micro"``: the stock binary / multiclass / multilabel MCC for ``task``.
    * any other ``average`` (e.g. ``None``): ``MulticlassMatthewsCorrCoefNoneAverage``
      (one score per class). Only ``task="multiclass"`` is supported in this mode.

    Raises:
        ValueError: for an unsupported task/average combination, or when the
            ``num_classes`` / ``num_labels`` required by the task is not an ``int``.
    """

    def __new__(  # type: ignore[misc]
            cls,
            task: Literal["binary", "multiclass", "multilabel"],
            threshold: float = 0.5,
            num_classes: Optional[int] = None,
            num_labels: Optional[int] = None,
            average: Literal["micro"] | None = 'micro',
            ignore_index: Optional[int] = None,
            validate_args: bool = True,
            **kwargs: Any,
    ) -> Metric:
        """Initialize task and average metric."""
        task = ClassificationTask.from_str(task)
        kwargs["ignore_index"] = ignore_index
        kwargs["validate_args"] = validate_args

        if average != "micro":
            # Per-class (non-averaged) scores are implemented for multiclass only.
            if task != ClassificationTask.MULTICLASS:
                raise ValueError(f"Not handled value: {task}")
            if not isinstance(num_classes, int):
                raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
            return MulticlassMatthewsCorrCoefNoneAverage(num_classes, **kwargs)

        if task == ClassificationTask.BINARY:
            return BinaryMatthewsCorrCoef(threshold, **kwargs)
        if task == ClassificationTask.MULTICLASS:
            if not isinstance(num_classes, int):
                raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
            return MulticlassMatthewsCorrCoef(num_classes, **kwargs)
        if task == ClassificationTask.MULTILABEL:
            if not isinstance(num_labels, int):
                raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
            return MultilabelMatthewsCorrCoef(num_labels, threshold, **kwargs)
        raise ValueError(f"Not handled value: {task}")

# Directory _models_

In [None]:
from dgl.nn.pytorch import GraphConv
from itertools import islice
# Neural Network Layers

import math

from torch import nn


class OrganismEmbedding(nn.Module):
    """Frozen, deterministic embedding table for organism ids.

    The table is ``randn(num_orgs, e_dim)`` drawn from a private generator seeded
    with 0. It is therefore identical across runs without touching the global RNG.

    Args:
        num_orgs: number of organism ids (rows).
        e_dim: embedding size (columns).
    """

    def __init__(self, num_orgs: int = 4, e_dim: int = 512):
        super().__init__()
        self.num_orgs = num_orgs
        self.embedding_dim = e_dim
        # A local generator gives the same table as seeding torch globally with 0.
        # It avoids resetting the RNG that every layer built afterwards uses for its
        # weight init (and that data shuffling uses).
        generator = torch.Generator().manual_seed(0)
        oe = torch.randn(num_orgs, e_dim, generator=generator)
        self.organism_embedding = nn.Embedding.from_pretrained(oe, freeze=True)

    def forward(self, x):
        """Look up the (non-trainable) embedding of organism ids ``x``."""
        return self.organism_embedding(x)


class InputEmbedding(nn.Module):
    """Learned token embedding scaled by ``sqrt(d_model)`` (as in the original Transformer).

    Args:
        vocab_size: number of token ids.
        d_model: embedding size.
    """

    def __init__(self, vocab_size: int = 100, d_model: int = 512):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.input_embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        scale = math.sqrt(self.d_model)
        return self.input_embedding(x) * scale


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to batch-first inputs, followed by dropout.

    Args:
        d_model: feature size. Odd values are supported: the last sine column has
            no cosine partner.
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence the precomputed table covers.

    Raises:
        ValueError: in ``forward`` when the input sequence is longer than ``max_len``.
    """

    def __init__(self, d_model: int = 512, dropout: float = 0.1, max_len: int = 2048):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(0, max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # Odd d_model has one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # (1, max_len, d_model): broadcasts over the batch dimension.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(f"sequence length {seq_len} exceeds max_len {self.pe.size(1)}")
        return self.dropout(x + self.pe[:, :seq_len, :])


class LinearPositionalEmbedding(nn.Module):
    """Learned positional embedding: an ``nn.Embedding`` lookup followed by dropout.

    Args:
        vocab_size: number of embedding rows (presumably the number of distinct
            position ids — confirm with callers).
        d_model: embedding size.
        dropout: dropout probability applied to the looked-up embeddings.
    """

    def __init__(self, vocab_size: int, d_model: int = 512, dropout: float = 0.1):
        super().__init__()
        self.pe = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        # Apply the configured dropout (it was built in __init__ but never used).
        return self.dropout(self.pe(x))


class TransformerEncoder(nn.Module):
    """Stack of batch-first ``nn.TransformerEncoderLayer``s pooled to the first token.

    Input ``(batch, seq, d_model)``; output ``(batch, d_model)``, the encoded
    representation at position 0 (the [CLS] slot).
    """

    def __init__(self, d_model: int = 512, nhead: int = 8, num_layers: int = 6):
        super().__init__()
        self.encoder = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True),
            num_layers=num_layers,
        )

    def forward(self, x, mask=None):
        """``mask`` is passed as ``src_key_padding_mask`` (marks positions to ignore)."""
        encoded = self.encoder(x, src_key_padding_mask=mask)
        # Mean pooling over all tokens (torch.mean(encoded, dim=1)) was tried; the [CLS] slot is used instead.
        return encoded[:, 0, :]


class ConvolutionalEncoder(nn.Module):
    """Four ``Dropout -> Conv1d -> ReLU -> MaxPool1d(3, stride=2)`` stages.

    Channels go embedding_dim -> n_base -> 4*n_base -> n_base -> embedding_dim.
    Input and output are channels-first ``(batch, channels, length)`` as ``nn.Conv1d`` requires.
    """

    def __init__(
            self,
            embedding_dim: int = 512,
            dropout: float = 0.1,
            kernel_size: int = 3,
            stride: int = 1,
            padding: int = 0,
            n_base: int = 1024,
    ):
        super().__init__()

        def stage(in_channels: int, out_channels: int) -> nn.Sequential:
            # Layer order is part of the state_dict layout (conv<k>.1.weight / .bias).
            return nn.Sequential(
                nn.Dropout(p=dropout),
                nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                          kernel_size=kernel_size, stride=stride, padding=padding),
                nn.ReLU(),
                nn.MaxPool1d(kernel_size=3, stride=2),
            )

        self.conv1 = stage(embedding_dim, n_base)
        self.conv2 = stage(n_base, n_base * 4)
        self.conv3 = stage(n_base * 4, n_base)
        self.conv4 = stage(n_base, embedding_dim)

    def forward(self, x):
        for block in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = block(x)
        return x


class LSTMEncoder(nn.Module):
    """Batch-first, unidirectional multi-layer LSTM returning ``(out, h_n, c_n)``.

    ``random_init`` is only stored on the instance. The LSTM is always called
    without an explicit initial state.
    """

    def __init__(
            self,
            embedding_dim: int = 512,
            hidden_size: int = 1024,
            n_layers: int = 4,
            dropout: float = 0.1,
            random_init: bool = False
    ):
        super().__init__()
        self.random_init = random_init
        lstm_kwargs = dict(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=n_layers,
            dropout=dropout,
            batch_first=True,
            bidirectional=False,
        )
        self.lstm = nn.LSTM(**lstm_kwargs)

    def forward(self, x):
        output, (last_hidden, last_cell) = self.lstm(x)
        return output, last_hidden, last_cell


class StackedBiLSTMEncoder(nn.Module):
    """Batch-first, bidirectional multi-layer LSTM returning ``(out, h_n, c_n)``.

    ``out`` is ``(batch, seq, 2 * hidden_size)``.

    Args:
        random_init: if True, an explicit initial ``(h_0, c_0)`` is passed to the
            LSTM. Despite the name, it is all zeros. It is built per call with the
            input's actual batch size, device and dtype.
    """

    def __init__(
            self,
            embedding_dim: int = 512,
            hidden_size: int = 1024,
            n_layers: int = 4,
            dropout: float = 0.1,
            random_init: bool = False
    ):
        super().__init__()
        self.random_init = random_init
        self.n_layers = n_layers
        self.hidden_size = hidden_size

        self.bilstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            bidirectional=True,
            num_layers=n_layers,
            batch_first=True,
            dropout=dropout
        )

    def forward(self, x):
        if self.random_init:
            out, (h_n, c_n) = self.bilstm(x, self._zero_state(x))
        else:
            out, (h_n, c_n) = self.bilstm(x)
        return out, h_n, c_n

    def _zero_state(self, x):
        """Zero ``(h_0, c_0)`` matching batch-first input ``x`` (assumes a 3-D batched input)."""
        # Built per call: a state fixed at the global BATCH_SIZE on CPU failed on a
        # short final batch and on GPU inputs.
        shape = (self.n_layers * 2, x.size(0), self.hidden_size)
        return x.new_zeros(shape), x.new_zeros(shape)


class ParallelBiLSTMEncoder(nn.Module):
    """Placeholder for a parallel Bi-LSTM encoder; not implemented yet."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Fail loudly instead of silently returning None into downstream layers.
        raise NotImplementedError("ParallelBiLSTMEncoder is not implemented yet")


class GraphConvEncoder(nn.Module):
    def __init__(self, d_model: int = 512, n_layers=2, dropout: float = 0.1, use_relu_act: bool = True,
                 d_hidden: int = 1024, use_special_tokens: bool = False):
        super().__init__()
        self.d_model = d_model
        self.use_special_tokens = use_special_tokens
        self.act = None
        if use_relu_act:
            self.act = nn.ReLU()
        self.convs = nn.ModuleList()
        convFirst = GraphConv(20, d_hidden, norm='both', bias=True, activation=self.act, allow_zero_in_degree=True)
        self.convs.append(convFirst)

        for i in range(1, n_layers - 1):
            convIn = GraphConv(d_hidden, d_hidden, norm='both', bias=True, activation=self.act,
                               allow_zero_in_degree=True)
            self.convs.append(convIn)

        convLast = GraphConv(d_hidden, d_model, norm='both', bias=True, activation=self.act,
                             allow_zero_in_degree=True)
        self.convs.append(convLast)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, h):
        for (i, conv) in enumerate(self.convs):
            h = conv(x, h)
        # h = self.return_batch(h, x.batch_num_nodes(), max_len='longest')
        if self.use_special_tokens:
            h, mask = self.return_batch_plus(h, x.batch_num_nodes(), max_len='longest')
            # h = torch.reshape(h, (-1, 20, self.d_model))
            h = self.dropout(h)
            return h, mask
        else:
            h = self.return_batch(h, x.batch_num_nodes(), max_len='longest')
            h = self.dropout(h)
            return h

    def return_batch(self, h, batch_num_nodes, max_len: str | int = 70):
        device = h.get_device()
        tmp = [list(islice(iter(h), 0, num_nodes)) for num_nodes in batch_num_nodes]
        ret = []
        if max_len == "longest":
            max_len = torch.max(batch_num_nodes).item()
        if not isinstance(max_len, int):
            raise ValueError('Use `int` or "longest"')

        for i, sample in enumerate(tmp):
            if len(sample) > max_len:
                sample = sample[:max_len]
            else:
                while len(sample) < max_len:
                    pad = torch.zeros(self.d_model, device=device)
                    sample.append(pad)
            ret.append(torch.stack(sample))

        return torch.stack(ret)

    def return_batch_plus(self, h, batch_num_nodes, max_len: str | int = 70):
        device = h.get_device()
        tmp = [list(islice(iter(h), 0, num_nodes)) for num_nodes in batch_num_nodes]
        ret = []
        if max_len == "longest":
            max_len = torch.max(batch_num_nodes).item()
        if not isinstance(max_len, int):
            raise ValueError('Use `int` or "longest"')
        mask = torch.zeros((len(batch_num_nodes), max_len + 2), dtype=torch.float, device=device)
        mask[:, -1] = float('-inf')
        for i, sample in enumerate(tmp):
            if len(sample) > max_len:
                sample = sample[:max_len]
            else:
                while len(sample) < max_len:
                    pad = torch.zeros(self.d_model, device=device)
                    sample.append(pad)
                    mask[i, len(sample)] = float('-inf')
            sample = self.add_special_tokens(sample, device)
            ret.append(torch.stack(sample))

        return torch.stack(ret), mask

    def add_special_tokens(self, sample, device):
        cls = torch.zeros(self.d_model, device=device)  # begin of sentence [CLS]
        eos = torch.zeros(self.d_model, device=device)  # end of sentence [EOS]
        return [cls, *sample, eos]


class Classifier(nn.Module):
    """Two-layer MLP head: ``Linear(d_model, d_ff) -> ReLU -> Linear(d_ff, num_class)``.

    Returns raw logits (no softmax is applied).
    """

    def __init__(self, num_class: int, d_model: int = 512, d_ff: int = 2048):
        super().__init__()
        self.ff1 = nn.Linear(d_model, d_ff)
        self.ff2 = nn.Linear(d_ff, num_class)
        self.act1 = nn.ReLU()

    def forward(self, x):
        hidden = self.act1(self.ff1(x))
        return self.ff2(hidden)

In [None]:
# models

from torch import nn

""" Transformer """


class TransformerClassifier(nn.Module):
    """Token embedding + sinusoidal positions + Transformer encoder + MLP head.

    Expects ``config`` keys: vocab_size, d_model, dropout, max_len, nhead, num_layers.
    Produces ``len(SP_LABELS)`` logits per sequence.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.input_embedding = InputEmbedding(vocab_size=config['vocab_size'], d_model=config['d_model'])
        self.positional_encoding = PositionalEncoding(
            d_model=config['d_model'], dropout=config['dropout'], max_len=config['max_len'])
        self.encoder = TransformerEncoder(
            d_model=config['d_model'], nhead=config['nhead'], num_layers=config['num_layers'])
        self.classifier = Classifier(d_model=config['d_model'], num_class=len(SP_LABELS))

    def forward(self, x, mask=None):
        tokens = self.positional_encoding(self.input_embedding(x))
        pooled = self.encoder(tokens, mask)
        return self.classifier(pooled)


class TransformerOrganismClassifier(nn.Module):
    """Transformer classifier whose pooled sequence vector is concatenated with an
    organism embedding before the MLP head.

    Expects ``config`` keys: vocab_size, d_model, dropout, max_len, nhead, num_layers.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.input_embedding = InputEmbedding(vocab_size=config['vocab_size'], d_model=config['d_model'])
        self.positional_encoding = PositionalEncoding(
            d_model=config['d_model'], dropout=config['dropout'], max_len=config['max_len'])
        self.encoder = TransformerEncoder(
            d_model=config['d_model'], nhead=config['nhead'], num_layers=config['num_layers'])
        self.organism_embedding = OrganismEmbedding(num_orgs=len(ORGANISMS), e_dim=config['d_model'])
        # Classifier input = sequence vector (d_model) + organism vector (d_model).
        self.classifier = Classifier(d_model=config['d_model'] * 2, num_class=len(SP_LABELS))

    def forward(self, x, org, mask=None):
        tokens = self.positional_encoding(self.input_embedding(x))
        seq_vec = self.encoder(tokens, mask)
        org_vec = self.organism_embedding(org)
        return self.classifier(torch.cat((seq_vec, org_vec), dim=1))


""" CNN """


class ConvolutionalClassifier(nn.Module):
    """Token embedding -> ConvolutionalEncoder -> flatten -> MLP head.

    NOTE(review): the classifier input size 119808 is hard-coded. It must equal the
    flattened encoder output, so it only fits one combination of sequence length,
    d_model and conv settings. Consider deriving it from the config.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.input_embedding = InputEmbedding(vocab_size=config['vocab_size'], d_model=config['d_model'])
        self.conv_encoder = ConvolutionalEncoder(
            embedding_dim=config['d_model'], kernel_size=config['kernel_size'], n_base=config['n_base'])
        self.flatten = nn.Flatten()
        self.classifier = Classifier(num_class=len(SP_LABELS), d_model=119808)

    def forward(self, x):
        channels_first = self.input_embedding(x).transpose(1, 2)  # Conv1d wants (batch, d_model, seq)
        features = self.flatten(self.conv_encoder(channels_first))
        return self.classifier(features)


class ConvolutionalOrganismClassifier(nn.Module):
    """CNN classifier where the organism embedding is appended as one extra position
    of the conv feature map before flattening.

    NOTE(review): the classifier input size 120832 is hard-coded; it must equal the
    flattened (conv positions + 1) * d_model.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.input_embedding = InputEmbedding(vocab_size=config['vocab_size'], d_model=config['d_model'])
        self.conv_encoder = ConvolutionalEncoder(
            embedding_dim=config['d_model'], kernel_size=config['kernel_size'], n_base=config['n_base'])
        self.flatten = nn.Flatten()
        self.organism_embedding = OrganismEmbedding(num_orgs=len(ORGANISMS), e_dim=config['d_model'])
        self.classifier = Classifier(num_class=len(SP_LABELS), d_model=120832)

    def forward(self, x, org):
        conv_out = self.conv_encoder(self.input_embedding(x).transpose(1, 2))
        positions = conv_out.transpose(1, 2)                   # (batch, length', d_model)
        org_step = self.organism_embedding(org).unsqueeze(1)   # (batch, 1, d_model)
        combined = torch.cat((positions, org_step), dim=1)
        return self.classifier(self.flatten(combined))


""" LSTM """


class LSTMClassifier(nn.Module):
    """Token embedding -> unidirectional LSTM -> MLP head on the last time step's output.

    NOTE(review): if inputs are right-padded, the last time step may be padding — confirm.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.input_embedding = InputEmbedding(vocab_size=config['vocab_size'], d_model=config['d_model'])
        self.encoder = LSTMEncoder(
            embedding_dim=config['d_model'],
            hidden_size=config['hidden_size'],
            n_layers=config['n_layers'],
            dropout=config['dropout'],
        )
        self.classifier = Classifier(num_class=len(SP_LABELS), d_model=config['hidden_size'])

    def forward(self, x):
        outputs, _h_n, _c_n = self.encoder(self.input_embedding(x))
        return self.classifier(outputs[:, -1, :])


class LSTMOrganismClassifier(nn.Module):
    """LSTM classifier: the last time step's output (hidden_size) is concatenated with an
    organism embedding (hidden_size) before the MLP head.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.input_embedding = InputEmbedding(vocab_size=config['vocab_size'], d_model=config['d_model'])
        self.encoder = LSTMEncoder(
            embedding_dim=config['d_model'],
            hidden_size=config['hidden_size'],
            n_layers=config['n_layers'],
            dropout=config['dropout'],
        )
        hidden = config['hidden_size']
        self.organism_embedding = OrganismEmbedding(num_orgs=len(ORGANISMS), e_dim=hidden)
        self.classifier = Classifier(num_class=len(SP_LABELS), d_model=hidden * 2)

    def forward(self, x, org):
        outputs, _h_n, _c_n = self.encoder(self.input_embedding(x))
        last_step = outputs[:, -1, :]
        return self.classifier(torch.cat((last_step, self.organism_embedding(org)), dim=1))


""" Stacked Bi-LTSM """


class StackedBiLSTMClassifier(nn.Module):
    """Token embedding -> stacked Bi-LSTM -> MLP head on the last time step (2 * hidden_size).

    NOTE(review): at the last time step the backward direction has seen only one
    token; pooling from ``h_n`` may be intended — confirm.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.input_embedding = InputEmbedding(vocab_size=config['vocab_size'], d_model=config['d_model'])
        self.stacked_encoder = StackedBiLSTMEncoder(
            embedding_dim=config['d_model'],
            hidden_size=config['hidden_size'],
            n_layers=config['n_layers'],
            dropout=config['dropout'],
        )
        self.classifier = Classifier(num_class=len(SP_LABELS), d_model=config['hidden_size'] * 2)

    def forward(self, x):
        outputs, _h_n, _c_n = self.stacked_encoder(self.input_embedding(x))
        return self.classifier(outputs[:, -1, :])


class StackedBiLSTMOrganismClassifier(nn.Module):
    """Stacked Bi-LSTM classifier: the last time step (2 * hidden_size) is concatenated
    with an organism embedding of the same size before the MLP head.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.input_embedding = InputEmbedding(vocab_size=config['vocab_size'], d_model=config['d_model'])
        self.stacked_encoder = StackedBiLSTMEncoder(
            embedding_dim=config['d_model'],
            hidden_size=config['hidden_size'],
            n_layers=config['n_layers'],
            dropout=config['dropout'],
        )
        bi_hidden = config['hidden_size'] * 2
        self.organism_embedding = OrganismEmbedding(num_orgs=len(ORGANISMS), e_dim=bi_hidden)
        self.classifier = Classifier(num_class=len(SP_LABELS), d_model=bi_hidden * 2)

    def forward(self, x, org):
        outputs, _h_n, _c_n = self.stacked_encoder(self.input_embedding(x))
        last_step = outputs[:, -1, :]
        return self.classifier(torch.cat((last_step, self.organism_embedding(org)), dim=1))


""" CNN+Transformer """


class CNNTransformerClassifier(nn.Module):
    """Token embedding -> CNN (downsampling) -> sinusoidal positions -> Transformer -> MLP head.

    Expects ``config`` keys: vocab_size, d_model, kernel_size, dropout, max_len, nhead, num_layers.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.input_embedding = InputEmbedding(vocab_size=config['vocab_size'], d_model=config['d_model'])
        self.conv_encoder = ConvolutionalEncoder(embedding_dim=config['d_model'], kernel_size=config['kernel_size'])
        self.positional_encoding = PositionalEncoding(
            d_model=config['d_model'], dropout=config['dropout'], max_len=config['max_len'])
        self.trans_encoder = TransformerEncoder(
            d_model=config['d_model'], nhead=config['nhead'], num_layers=config['num_layers'])
        self.classifier = Classifier(d_model=config['d_model'], num_class=len(SP_LABELS))

    def forward(self, x):
        conv_out = self.conv_encoder(self.input_embedding(x).transpose(1, 2))  # channels-first for Conv1d
        tokens = self.positional_encoding(conv_out.transpose(1, 2))           # back to (batch, seq, d_model)
        return self.classifier(self.trans_encoder(tokens))


class CNNTransformerOrganismClassifier(nn.Module):
    """CNN + Transformer classifier with the pooled sequence vector concatenated to an
    organism embedding (both d_model) before the MLP head.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.input_embedding = InputEmbedding(vocab_size=config['vocab_size'], d_model=config['d_model'])
        self.conv_encoder = ConvolutionalEncoder(embedding_dim=config['d_model'], kernel_size=config['kernel_size'])
        self.positional_encoding = PositionalEncoding(
            d_model=config['d_model'], dropout=config['dropout'], max_len=config['max_len'])
        self.trans_encoder = TransformerEncoder(
            d_model=config['d_model'], nhead=config['nhead'], num_layers=config['num_layers'])
        # Classifier is built before the organism embedding (original construction order).
        self.classifier = Classifier(d_model=config['d_model'] * 2, num_class=len(SP_LABELS))
        self.organism_embedding = OrganismEmbedding(num_orgs=len(ORGANISMS), e_dim=config['d_model'])

    def forward(self, x, org):
        conv_out = self.conv_encoder(self.input_embedding(x).transpose(1, 2))
        tokens = self.positional_encoding(conv_out.transpose(1, 2))
        seq_vec = self.trans_encoder(tokens)
        org_vec = self.organism_embedding(org)
        return self.classifier(torch.cat((seq_vec, org_vec), dim=1))


""" ProtBERT """


class ProtBertClassifier(nn.Module):
    """BERT encoder ([CLS] hidden state) + MLP head over ``len(SP_LABELS)`` classes.

    With ``MODEL_TYPE == "bert_pretrained"`` the encoder weights are loaded from the
    ``Rostlab/prot_bert`` checkpoint. Otherwise the encoder is randomly initialised
    from ``config``. ``FREEZE_PRETRAINED`` freezes only loaded pretrained weights
    (same rule as ``ProtBertOrganismClassifier``).
    """

    def __init__(self, config):
        super().__init__()
        # Local import: no notebook cell above imports BertModel (only BertConfig).
        from transformers import BertModel

        self.config = config
        use_pretrained = MODEL_TYPE == "bert_pretrained"
        if use_pretrained:
            # BertModel(config=...) alone yields random weights, not ProtBERT's.
            self.bert = BertModel.from_pretrained("Rostlab/prot_bert", config=config)
        else:
            self.bert = BertModel(config=config)
        if FREEZE_PRETRAINED and use_pretrained:
            self.freeze_pretrained_layer()
        self.classifier = Classifier(num_class=len(SP_LABELS), d_model=config.hidden_size)

    def forward(self, x):
        x = self.bert(x)
        x = x.last_hidden_state[:, 0, :]  # [CLS] token
        x = self.classifier(x)
        return x

    def freeze_pretrained_layer(self):
        """Disable gradients for every BERT parameter (the head stays trainable)."""
        for param in self.bert.parameters():
            param.requires_grad = False


class ProtBertOrganismClassifier(nn.Module):
    """BERT [CLS] vector concatenated with an organism embedding, then an MLP head.

    With ``MODEL_TYPE == "bert_pretrained"`` the encoder weights are loaded from the
    ``Rostlab/prot_bert`` checkpoint (and frozen when ``FREEZE_PRETRAINED``).
    Otherwise the encoder is randomly initialised from ``config``.
    """

    def __init__(self, config):
        super().__init__()
        # Local import: no notebook cell above imports BertModel (only BertConfig).
        from transformers import BertModel

        self.config = config
        use_pretrained = MODEL_TYPE == "bert_pretrained"
        if use_pretrained:
            # BertModel(config=...) alone yields random weights, not ProtBERT's.
            self.bert = BertModel.from_pretrained("Rostlab/prot_bert", config=config)
        else:
            self.bert = BertModel(config=config)
        if FREEZE_PRETRAINED and use_pretrained:
            self.freeze_pretrained_layer()
        self.classifier = Classifier(num_class=len(SP_LABELS), d_model=config.hidden_size * 2)
        self.organism_embedding = OrganismEmbedding(num_orgs=len(ORGANISMS), e_dim=config.hidden_size)

    def forward(self, x, org):
        x = self.bert(x)
        x = x.last_hidden_state[:, 0, :]  # [CLS] token
        org = self.organism_embedding(org)
        inp = torch.cat((x, org), dim=1)
        out = self.classifier(inp)
        return out

    def freeze_pretrained_layer(self):
        """Disable gradients for every BERT parameter (the head stays trainable)."""
        for param in self.bert.parameters():
            param.requires_grad = False

""" GRAPH CONV """
class GraphConvClassifier(nn.Module):
    """Graph-convolution encoder, mean-pooled over each graph's real nodes, + MLP head.

    Expects ``config`` keys: d_model, dropout, use_relu_act, d_hidden.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.graphconv_encoder = GraphConvEncoder(
            d_model=config['d_model'],
            dropout=config['dropout'],
            use_relu_act=config['use_relu_act'],
            d_hidden=config['d_hidden'],
            use_special_tokens=False
        )
        # The pooled graph vector has d_model features (nothing is concatenated to it).
        # SP_LABELS is the notebook-level constant; no `params` module exists here.
        self.classifier = Classifier(
            d_model=config['d_model'],
            num_class=len(SP_LABELS)
        )

    def forward(self, x):
        """Classify batched DGL graph ``x`` whose node features are ``x.ndata['n_feat']``."""
        node_feats = self.graphconv_encoder(x, x.ndata['n_feat'])  # zero-padded to the largest graph
        # Divide by each graph's real node count. A plain mean over the padded axis
        # shrinks small graphs' vectors and makes predictions depend on batch composition.
        num_nodes = x.batch_num_nodes().to(device=node_feats.device, dtype=node_feats.dtype)
        pooled = node_feats.sum(dim=1) / num_nodes.clamp(min=1).unsqueeze(1)
        return self.classifier(pooled)


class GraphConvOrganismClassifier(nn.Module):
    """Graph-convolution classifier that also conditions on the organism.

    Graph encodings are mean-pooled over dim 1 and concatenated with an organism
    embedding of width ``d_model``. ``Classifier`` is sized for ``2 * d_model`` inputs.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        encoder_kwargs = {key: config[key] for key in ('d_model', 'dropout', 'use_relu_act', 'd_hidden')}
        self.graphconv_encoder = GraphConvEncoder(**encoder_kwargs)
        width = config['d_model']
        self.organism_embedding = OrganismEmbedding(num_orgs=len(ORGANISMS), e_dim=width)
        self.classifier = Classifier(d_model=width * 2, num_class=len(SP_LABELS))

    def forward(self, x, org):
        """Return class scores for the DGL graph ``x`` and organism ids ``org``."""
        pooled = self.graphconv_encoder(x, x.ndata['n_feat']).mean(dim=1)
        joint = torch.cat((pooled, self.organism_embedding(org)), dim=1)
        return self.classifier(joint)


""" GRAPH CONV TRANS """


class GraphConvTransformerClassifier(nn.Module):
    """Graph convolution followed by a transformer encoder, without organism input.

    The graph-conv encoder returns ``(states, mask)``. Both go to ``TransformerEncoder``,
    whose output feeds a ``Classifier`` of width ``d_model``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        encoder_kwargs = {key: config[key] for key in ('d_model', 'dropout', 'use_relu_act', 'd_hidden')}
        self.graphconv_encoder = GraphConvEncoder(**encoder_kwargs)
        transformer_kwargs = {key: config[key] for key in ('d_model', 'nhead', 'num_layers')}
        self.transformer_encoder = TransformerEncoder(**transformer_kwargs)
        self.classifier = Classifier(d_model=config['d_model'], num_class=len(SP_LABELS))

    def forward(self, x, mask=None):
        """Return class scores for the DGL graph ``x``.

        ``mask`` is accepted for interface compatibility but is ignored; the mask produced
        by the graph-conv encoder is used instead.
        """
        node_states, node_mask = self.graphconv_encoder(x, x.ndata['n_feat'])
        return self.classifier(self.transformer_encoder(node_states, node_mask))


class GraphConvTransformerOrganismClassifier(nn.Module):
    """Graph convolution and transformer encoder, conditioned on the organism.

    The transformer output is concatenated with an organism embedding of width
    ``d_model``. ``Classifier`` is sized for ``2 * d_model`` inputs.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        encoder_kwargs = {key: config[key] for key in ('d_model', 'dropout', 'use_relu_act', 'd_hidden')}
        self.graphconv_encoder = GraphConvEncoder(**encoder_kwargs)
        transformer_kwargs = {key: config[key] for key in ('d_model', 'nhead', 'num_layers')}
        self.transformer_encoder = TransformerEncoder(**transformer_kwargs)
        width = config['d_model']
        self.organism_embedding = OrganismEmbedding(num_orgs=len(ORGANISMS), e_dim=width)
        self.classifier = Classifier(d_model=width * 2, num_class=len(SP_LABELS))

    def forward(self, x, org, mask=None):
        """Return class scores for the DGL graph ``x`` and organism ids ``org``.

        ``mask`` is ignored; the graph-conv encoder supplies its own mask.
        """
        node_states, node_mask = self.graphconv_encoder(x, x.ndata['n_feat'])
        encoded = self.transformer_encoder(node_states, node_mask)
        joint = torch.cat((encoded, self.organism_embedding(org)), dim=1)
        return self.classifier(joint)



In [None]:
# import all model utils in model_utils.py
from transformers import BertModel

# Model-type identifiers understood by load_model.
# Sequence models
TRANSFORMER = 'transformer'
CNN = 'cnn'
LSTM = 'lstm'
STACKED_BILSTM = 'st_bilstm'
CNN_TRANSFORMER = 'cnn_trans'
# BERT models (randomly initialised / ProtBert-pretrained)
BERT = 'bert'
BERT_PRETRAINED = 'bert_pretrained'
# Graph models (load_model only builds these when data_type == 'graph')
GCONV = 'gconv'
GCONV_TRANSFORMER = 'gconv_trans'


def load_model(model_type, data_type, conf_type, use_organism=False):
    """Build the classifier selected by ``model_type``.

    Args:
        model_type: one of TRANSFORMER, CNN, STACKED_BILSTM, BERT, BERT_PRETRAINED, LSTM,
            CNN_TRANSFORMER, GCONV, GCONV_TRANSFORMER.
        data_type: input representation. The graph models (GCONV, GCONV_TRANSFORMER) are
            only built when it is 'graph'.
        conf_type: configuration variant forwarded to ``load_config``.
        use_organism: if True, build the organism-aware variant, whose ``forward`` also
            takes organism ids.

    Returns:
        The constructed classifier.

    Raises:
        ValueError: if ``model_type`` is unknown or not supported for ``data_type``.
    """
    config = load_config(model_type, data_type, conf_type)
    is_graph = data_type == 'graph'
    if use_organism:
        if model_type == TRANSFORMER:
            return TransformerOrganismClassifier(config)
        if model_type == CNN:
            return ConvolutionalOrganismClassifier(config)
        if model_type == STACKED_BILSTM:
            return StackedBiLSTMOrganismClassifier(config)
        if model_type == BERT or model_type == BERT_PRETRAINED:
            return ProtBertOrganismClassifier(config)
        if model_type == LSTM:
            return LSTMOrganismClassifier(config)
        if model_type == CNN_TRANSFORMER:
            return CNNTransformerOrganismClassifier(config)
        if model_type == GCONV and is_graph:
            return GraphConvOrganismClassifier(config)
        if model_type == GCONV_TRANSFORMER and is_graph:
            return GraphConvTransformerOrganismClassifier(config)
    else:
        if model_type == TRANSFORMER:
            return TransformerClassifier(config)
        if model_type == CNN:
            return ConvolutionalClassifier(config)
        if model_type == STACKED_BILSTM:
            return StackedBiLSTMClassifier(config)
        if model_type == BERT or model_type == BERT_PRETRAINED:
            return ProtBertClassifier(config)
        if model_type == LSTM:
            return LSTMClassifier(config)
        if model_type == CNN_TRANSFORMER:
            return CNNTransformerClassifier(config)
        if model_type == GCONV and is_graph:
            return GraphConvClassifier(config)
        if model_type == GCONV_TRANSFORMER and is_graph:
            return GraphConvTransformerClassifier(config)
    # Previously the exception was *returned*, so callers received an exception object
    # instead of a model and failed much later with a confusing error.
    raise ValueError(f"Unknown model_type type: {model_type!r} (data_type={data_type!r})")


# Directory _data_

In [None]:
import dgl
# code SPDataset (extends torch.utils.data.Dataset)

# from transformers import BertTokenizer
# from tokenizers import Tokenizer


from typing import Optional

from torch.utils.data import Dataset


class SPDataset(Dataset):
    """Signal-peptide dataset loaded from one or more JSON files.

    Every file must hold a list of records. Each record needs ``label`` and ``kingdom``
    keys, plus either ``from_list``/``to_list``/``adj_matrix`` (when
    ``data_type == 'graph'``) or ``smiles``/``aa_seq`` (any other ``data_type``).

    Items are ``(x, one_hot_label, organism_id)``:
      * for 'graph', ``x`` is a DGL graph with self-loops whose ``ndata['n_feat']`` is the
        record's ``adj_matrix``;
      * for 'aa', ``x`` is the record's ``aa_seq`` entry;
      * otherwise, ``x`` is the record's ``smiles`` entry.
    """

    def __init__(self, json_paths: Optional[list[str]], data_type: str):
        self.data_type = data_type
        # A bare str would otherwise be iterated character by character as paths.
        if json_paths is None or isinstance(json_paths, str):
            raise ValueError('provide path to dataset in list of str')
        records = pd.DataFrame(self._read_jsons(json_paths))
        self.length = len(records)
        self.labels = records['label'].tolist()
        self.organisms = records['kingdom'].tolist()
        if data_type != 'graph':
            self.smiles = records['smiles'].tolist()
            self.aa_seq = records['aa_seq'].tolist()
        else:
            self.from_list = records['from_list'].tolist()
            self.to_list = records['to_list'].tolist()
            self.adj_matrix = records['adj_matrix'].tolist()

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        org_id = torch.tensor(ORGANISMS[self.organisms[index]])
        one_hot = torch.zeros(len(SP_LABELS), dtype=torch.int64)
        one_hot[SP_LABELS[self.labels[index]]] = 1
        if self.data_type != 'graph':
            sample = self.aa_seq[index] if self.data_type == 'aa' else self.smiles[index]
        else:
            sample = dgl.add_self_loop(dgl.graph((self.from_list[index], self.to_list[index])))
            sample.ndata['n_feat'] = torch.tensor(self.adj_matrix[index], dtype=torch.float)
        return sample, one_hot.clone().detach(), org_id.clone().detach()

    @staticmethod
    def _read_jsons(json_paths: list[str]):
        """Concatenate the top-level lists stored in each JSON file, in path order."""
        records = []
        for json_path in json_paths:
            with open(json_path, 'r') as handle:
                records += json.load(handle)
        return records


#### DataLoader

In [None]:
from typing import List
from torch.utils.data import Subset, Sampler, DataLoader, RandomSampler


class SPDataLoader(DataLoader):
    """DataLoader with project-specific options.

    Args:
        dataset: dataset to load from.
        shuffle: shuffle the data. Combined with ``use_sp_sampler``, batches come from
            ``SPBatchRandomSampler`` instead of the default sampler.
        use_workers_init_fn: seed each worker with ``worker_id + current_epoch``.
        use_sp_sampler: see ``shuffle``.
        use_graph_collate_fn: batch DGL graphs with ``graph_collate_fn``.
        current_epoch: epoch number used for seeding.
        batch_size, num_workers, pin_memory: as for ``torch.utils.data.DataLoader``.
            Workers are persistent whenever ``num_workers > 0``.
    """

    def __init__(
            self,
            dataset,
            shuffle=False,
            use_workers_init_fn=False,
            use_sp_sampler=False,
            use_graph_collate_fn=False,
            current_epoch=0,
            batch_size=1,
            num_workers=0,
            pin_memory=False
    ):
        self.dataset = dataset
        self.current_epoch = current_epoch
        self.batch_size = batch_size
        common_kwargs = dict(
            dataset=dataset,
            num_workers=num_workers,
            persistent_workers=num_workers > 0,
            worker_init_fn=self.worker_init_fn if use_workers_init_fn else None,
            collate_fn=SPDataLoader.graph_collate_fn if use_graph_collate_fn else None,
            pin_memory=pin_memory,
        )
        if not (shuffle and use_sp_sampler):
            super().__init__(shuffle=shuffle, batch_size=batch_size, **common_kwargs)
            return
        # The batch sampler owns shuffling and batch size, so neither is passed here.
        batch_sampler = SPBatchRandomSampler(dataset, batch_size, current_epoch, shuffle=True)
        super().__init__(batch_sampler=batch_sampler, **common_kwargs)

    def worker_init_fn(self, worker_id):
        """Seed torch (CPU and all CUDA devices) for one worker from its id and the epoch."""
        worker_seed = self.current_epoch + worker_id
        torch.manual_seed(worker_seed)
        torch.cuda.manual_seed_all(worker_seed)

    @staticmethod
    def graph_collate_fn(batch):
        """Collate ``(graph, label, organism)`` items into a batched graph and stacked tensors."""
        graphs, labels, organisms = (list(column) for column in zip(*batch))
        batched_graph = dgl.batch(graphs)
        return batched_graph, torch.stack(labels), torch.stack(organisms)


class SPBatchRandomSampler(Sampler[List[int]]):
    """Batch sampler that yields lists of dataset indices drawn by a ``RandomSampler``.

    Args:
        dataset: dataset to sample from; must implement ``__len__``.
        batch_size: number of indices per yielded batch.
        current_epoch: when ``shuffle`` is set and no ``generator`` is given, used to
            reseed the global torch (and CUDA) RNG at construction time.
        valid_indices: optional sequence of dataset indices to draw from (default: all).
            Yielded indices always refer to ``dataset``, not to positions in this subset.
        shuffle: see ``current_epoch``.
        replacement, generator: forwarded to ``RandomSampler``.
        num_samples: number of indices to draw per pass. Defaults to
            ``len(valid_indices)``.
        drop_last: drop the final batch if it has fewer than ``batch_size`` indices.
    """

    def __init__(self, dataset: Dataset, batch_size: int, current_epoch: int, valid_indices=None, shuffle=False,
                 replacement: bool = False, num_samples: Optional[int] = None, generator=None, drop_last=False):
        super(Sampler, self).__init__()
        self.dataset = dataset  # dataset must implement __len__ method
        if valid_indices is None:
            valid_indices = range(len(dataset))
        self.valid_indices = valid_indices
        # Default to the size of the sampled subset, not the whole dataset; otherwise a
        # restricted `valid_indices` would be re-drawn to reach len(dataset) samples.
        if num_samples is None:
            num_samples = len(valid_indices)
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.drop_last = drop_last
        if shuffle and generator is None:
            # NOTE: this reseeds the *global* RNG as a side effect (kept for reproducibility
            # of existing runs); pass `generator` to avoid touching global state.
            torch.manual_seed(current_epoch)
            torch.cuda.manual_seed(current_epoch)
        self.standard_sampler = RandomSampler(data_source=Subset(dataset, valid_indices), replacement=replacement,
                                              num_samples=num_samples, generator=generator)

    def __iter__(self):
        batch = []
        for position in self.standard_sampler:
            # RandomSampler yields positions inside the Subset; map back to dataset indices.
            batch.append(int(self.valid_indices[position]))
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        # Must agree with __iter__: the trailing partial batch only counts without drop_last.
        if self.drop_last:
            return self.num_samples // self.batch_size
        return (self.num_samples + self.batch_size - 1) // self.batch_size

# Directory _tokenizer_

In [None]:
from transformers import GPT2TokenizerFast, BertTokenizer


# import utils from tokenizer_utils.py
def load_tokenizer(model_type, data_type):
    """Return the tokenizer for sequence inputs, or ``None`` for any other data type.

    Args:
        model_type: 'bert_pretrained' selects the ProtBert tokenizer
            ("Rostlab/prot_bert"). Every other model uses the project tokenizer file
            ``{INPUT_DIR}/sppredictor-tokenizer/tokenizer_<data_type>.json``, loaded as a
            ``GPT2TokenizerFast``; a '[PAD]' token is added if it has none.
        data_type: 'aa' or 'smiles' get a tokenizer; anything else (e.g. 'graph') gets
            ``None``.
    """
    if data_type not in ('aa', 'smiles'):
        return None
    if model_type == 'bert_pretrained':
        # Checked first so the local tokenizer file is not read and then discarded.
        return BertTokenizer.from_pretrained("Rostlab/prot_bert")
    tokenizer_path = str(Path(INPUT_DIR, 'sppredictor-tokenizer', f'tokenizer_{data_type}.json'))
    tokenizer = GPT2TokenizerFast(tokenizer_file=tokenizer_path)
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    return tokenizer

# Directory _lightning_module_

In [None]:
from typing import Dict
# code Lightning Data Module
from typing import Optional

import lightning as L
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS


# from data.sp_dataset import SPDataset


class SPDataModule(L.LightningDataModule):
    """LightningDataModule that builds ``SPDataset`` splits from ``INPUT_DIR``.

    All files live in ``{INPUT_DIR}/sppredictor-dataset``:
      * fit: training uses the train files of partitions 0 and 1; validation uses the
        test files of partitions 0 and 1;
      * test: the train and test files of partition 2.

    Args:
        data_type: forwarded to ``SPDataset``. 'graph' also enables the DGL collate
            function.
        batch_size: batch size of every dataloader.
        num_workers: worker processes per dataloader.
        use_prepare_data, use_split_dataset: stored but not used by this class.
    """

    def __init__(
            self,
            data_type: str,
            batch_size: int = 8,
            num_workers: int = 1,
            use_prepare_data: bool = False,
            use_split_dataset: bool = False
    ):
        super().__init__()
        self.current_training_epoch = 0
        self.save_hyperparameters()

        self.test_set = None
        self.val_set = None
        self.train_set = None

        self.data_type = data_type
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.use_prepare_data = use_prepare_data
        self.use_split_dataset = use_split_dataset
        self.persistent_workers = num_workers > 0
        self.use_graph_collate_fn = self.data_type == "graph"

    def state_dict(self) -> Dict[str, Any]:
        """Persist the trainer's current epoch in checkpoints."""
        return {'current_training_epoch': self.trainer.current_epoch}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.current_training_epoch = state_dict['current_training_epoch']

    @staticmethod
    def _dataset_paths(*file_names):
        """Return the str paths of ``file_names`` inside ``{INPUT_DIR}/sppredictor-dataset``."""
        return [str(Path(INPUT_DIR, 'sppredictor-dataset', name)) for name in file_names]

    def setup(self, stage: Optional[str] = None) -> None:
        # NOTE(review): the *graph* partition files are used for every data_type — confirm
        # they also contain the 'smiles'/'aa_seq' columns that non-graph SPDatasets read.
        if stage in ("fit", None):
            self.train_set = SPDataset(
                json_paths=self._dataset_paths('train_set_graph_partition_0.json',
                                               'train_set_graph_partition_1.json'),
                data_type=self.data_type)
        # Also built for "validate" so trainer.validate() has a dataset.
        if stage in ("fit", "validate", None):
            # Mirror the training split: test files of partitions 0 and 1, one each
            # (listing partition 0 twice duplicated its samples).
            self.val_set = SPDataset(
                json_paths=self._dataset_paths('test_set_graph_partition_0.json',
                                               'test_set_graph_partition_1.json'),
                data_type=self.data_type)
        if stage == "test":
            self.test_set = SPDataset(
                json_paths=self._dataset_paths('train_set_graph_partition_2.json',
                                               'test_set_graph_partition_2.json'),
                data_type=self.data_type)

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        return SPDataLoader(self.train_set, current_epoch=self.trainer.current_epoch, shuffle=True, use_sp_sampler=True,
                            use_workers_init_fn=False, use_graph_collate_fn=self.use_graph_collate_fn,
                            batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)

    def val_dataloader(self) -> EVAL_DATALOADERS:
        return SPDataLoader(self.val_set, current_epoch=self.trainer.current_epoch, shuffle=False, use_sp_sampler=False,
                            use_workers_init_fn=False, use_graph_collate_fn=self.use_graph_collate_fn,
                            batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)

    def test_dataloader(self) -> EVAL_DATALOADERS:
        # Same worker / pinning settings as the train and validation loaders.
        return SPDataLoader(self.test_set, batch_size=self.batch_size, shuffle=False, use_workers_init_fn=False,
                            use_sp_sampler=False, use_graph_collate_fn=self.use_graph_collate_fn,
                            num_workers=self.num_workers, pin_memory=True)



In [None]:
import numpy as np
from sklearn.metrics import classification_report
import os.path
from typing import Any, Dict

import lightning as L
import pandas as pd
from torch import optim, Tensor
from torch.nn import CrossEntropyLoss, Softmax
from torch.optim import Optimizer
from torchmetrics import F1Score, AveragePrecision, Recall


# from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, matthews_corrcoef
# from torcheval.metrics.functional import multiclass_auroc, multiclass_auprc, multiclass_f1_score


class SPModule(L.LightningModule):
    """LightningModule that trains and evaluates a signal-peptide classifier.

    It wraps the model returned by ``load_model``, the tokenizer from ``load_tokenizer``
    (``None`` for graph inputs), a class-weighted cross-entropy loss and per-class
    torchmetrics (F1, recall, MCC, average precision).

    Validation metrics are printed at the end of each validation epoch. Test metrics are
    computed per organism and in total, printed, and written under
    ``{WORKING_DIR}/out/results`` and ``{WORKING_DIR}/out/metrics``.

    Args:
        model_type: model identifier passed to ``load_model``.
        data_type: 'aa', 'smiles' or 'graph'.
        conf_type: configuration variant passed to ``load_model``.
        use_organism: if True, the model also receives organism ids.
        batch_size: passed to ``self.log`` for epoch aggregation.
        lr: learning rate; forced to 1e-5 for 'bert' / 'bert_pretrained'.
    """

    def __init__(
            self,
            model_type: str,
            data_type: str,
            conf_type: str = 'default',
            use_organism: bool = False,
            batch_size: int = 8,
            lr: float = 1e-7
    ):
        super().__init__()
        self.save_hyperparameters()

        # Module params
        self.model_type = model_type
        self.data_type = data_type
        self.conf_type = conf_type
        self.use_organism = use_organism
        self.batch_size = batch_size
        self.lr = lr
        if model_type == 'bert' or model_type == 'bert_pretrained':
            self.lr = 1e-5  # according to TSignal, the learning rate for BERT model is fixed to 1e-5
        self.checkpoint_name = ''

        # Per-class loss weights; the first label index is down-weighted.
        # NOTE(review): 6 entries hard-coded, so this assumes len(SP_LABELS) == 6.
        loss_weight = torch.tensor([0.15, 1, 1, 1, 1, 1], dtype=torch.float)
        self.loss_fn = CrossEntropyLoss(weight=loss_weight)

        self.tokenizer = load_tokenizer(model_type=model_type, data_type=data_type)

        self.model = load_model(
            model_type=model_type,
            data_type=data_type,
            conf_type=conf_type,
            use_organism=use_organism
        )

        # Per-class metrics (average=None -> one value per label).
        num_classes = len(SP_LABELS)
        self.f1 = F1Score(task='multiclass', num_classes=num_classes, average=None)
        self.recall = Recall(task="multiclass", num_classes=num_classes, average=None)
        self.mcc = MCC(task='multiclass', num_classes=num_classes, average=None)
        self.average_precision = AveragePrecision(task='multiclass', num_classes=num_classes, average=None)

        # Raw outputs gathered across steps; metrics are computed at epoch/test end.
        self.validation_outputs_lb = []
        self.validation_outputs_pred = []
        self.best_val_loss = 1e6

        self.test_outputs_lb_total = []
        self.test_outputs_pred_total = []
        # One bucket per organism id (ids come from ORGANISMS' values; assumes they are < 6).
        self.test_outputs_lb_organism = [[], [], [], [], [], []]
        self.test_outputs_pred_organism = [[], [], [], [], [], []]

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=0.1)
        return optimizer

    def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer):
        optimizer.zero_grad(set_to_none=True)

    def backward(self, loss: Tensor, *args: Any, **kwargs: Any):
        loss.backward()

    def tokenize_input(self, x):
        """Tokenize a batch of raw sequences, padded/truncated to the model's max length."""
        if self.model_type == 'bert' or self.model_type == 'bert_pretrained':
            max_length = self.model.config.max_position_embeddings
        else:
            max_length = self.model.config['max_len']
        encoded = self.tokenizer.batch_encode_plus(
            x,
            max_length=max_length,
            truncation=True,
            padding='max_length'
        )
        return torch.tensor(encoded['input_ids'], dtype=torch.int64, device=self.device)

    def base_step(self, batch, batch_idx):
        """Shared step: tokenize if needed, run the model and compute the loss.

        Returns ``(x, lb, pred, loss, organism)``; ``lb`` is the one-hot label tensor.
        """
        x, lb, organism = batch
        if self.tokenizer is not None:
            x = self.tokenize_input(x)
        if self.use_organism:
            pred = self.model(x, organism)
        else:
            pred = self.model(x)
        # One-hot labels are used as class-probability targets for CrossEntropyLoss.
        loss = self.loss_fn(pred.float(), lb.float())
        return x, lb, pred, loss, organism

    def training_step(self, batch, batch_idx):
        _, _, pred, loss, _ = self.base_step(batch, batch_idx)
        self.log('train_loss', loss, on_epoch=True, prog_bar=True, batch_size=self.batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        _, lb, pred, loss, _ = self.base_step(batch, batch_idx)
        self.validation_outputs_pred.extend(pred.tolist())
        self.validation_outputs_lb.extend(lb.tolist())
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, batch_size=self.batch_size)

    def on_validation_epoch_end(self):
        """Print loss and per-class metrics over the whole validation epoch, then reset."""
        if not self.validation_outputs_pred:
            # Nothing collected (e.g. no validation batches); argmax below would fail.
            return
        all_pred = torch.tensor(self.validation_outputs_pred, device=self.device)
        all_lb = torch.tensor(self.validation_outputs_lb, device=self.device)

        val_loss = self.loss_fn(all_pred.float(), all_lb.float())
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss

        all_lb = torch.argmax(all_lb, dim=1)

        self.f1.update(all_pred, all_lb)
        self.recall.update(all_pred, all_lb)
        self.mcc.update(all_pred, all_lb)
        self.average_precision.update(all_pred, all_lb)

        print(
            f"\nMetrics on validation set: "
            f"best_val_loss: {self.best_val_loss}, "
            f"f1: {self.f1.compute()}, "
            f"recall: {self.recall.compute()}, "
            f"mcc: {self.mcc.compute()}, "
            f"average_precision: {self.average_precision.compute()} \n"
        )

        self.validation_outputs_lb.clear()
        self.validation_outputs_pred.clear()
        self.f1.reset()
        self.recall.reset()
        self.mcc.reset()
        self.average_precision.reset()

    def test_step(self, batch, batch_idx):
        """Collect raw predictions and labels, overall and bucketed by organism id."""
        _, lb, pred, loss, organism = self.base_step(batch, batch_idx)

        pred_rows = pred.clone().detach().tolist()
        lb_rows = lb.tolist()

        self.test_outputs_pred_total.extend(pred_rows)
        self.test_outputs_lb_total.extend(lb_rows)

        for org_id, pred_row, lb_row in zip(organism.tolist(), pred_rows, lb_rows):
            self.test_outputs_pred_organism[org_id].append(pred_row)
            self.test_outputs_lb_organism[org_id].append(lb_row)

    def _compute_test_metrics(self, pred, lb):
        """Return ``[f1, recall, mcc, average_precision]`` as per-class percentage lists.

        Each metric is updated with ``(pred, lb)``, computed and then reset, so the
        result covers exactly this input.
        """
        results = []
        for metric in (self.f1, self.recall, self.mcc, self.average_precision):
            metric.update(pred, lb)
            results.append((metric.compute() * 100).tolist())
            metric.reset()
        return results

    def on_test_end(self) -> None:
        """Print and save per-organism and total test metrics, then clear the test buffers."""
        # Explicit dim: each row holds the class scores of one sample.
        softmax = Softmax(dim=1)

        total_pred = torch.tensor(self.test_outputs_pred_total, device=self.device)
        total_lb = torch.tensor(self.test_outputs_lb_total, device=self.device)
        # sklearn's order is (y_true, y_pred); swapping them exchanges precision and recall.
        print(classification_report(torch.argmax(total_lb, dim=1).tolist(),
                                    torch.argmax(total_pred, dim=1).tolist(),
                                    zero_division=0))

        # Slots are indexed by organism id; the TOTAL row is stored at index len(ORGANISMS).
        total_index = len(ORGANISMS)
        f1_test = [[], [], [], [], [], [], []]
        recall_test = [[], [], [], [], [], [], []]
        mcc_test = [[], [], [], [], [], [], []]
        average_precision_test = [[], [], [], [], [], [], []]
        for k, o in ORGANISMS.items():
            if not self.test_outputs_pred_organism[o]:
                # No samples for this organism: an empty tensor would crash argmax(dim=1).
                # Store NaNs, which _save_metrics_to_csv writes as 0.0.
                nan_row = [float('nan')] * len(SP_LABELS)
                f1_test[o], recall_test[o], mcc_test[o], average_precision_test[o] = (
                    list(nan_row), list(nan_row), list(nan_row), list(nan_row))
                print(f'\nNo test samples for {k}; metrics set to NaN.\n')
                continue

            all_pred = softmax(torch.tensor(self.test_outputs_pred_organism[o], device=self.device))
            all_lb = torch.argmax(torch.tensor(self.test_outputs_lb_organism[o], device=self.device), dim=1)

            self._save_results_to_txt(all_pred.clone().detach().cpu(), all_lb.clone().detach().cpu(), organism=k)

            f1_test[o], recall_test[o], mcc_test[o], average_precision_test[o] = \
                self._compute_test_metrics(all_pred, all_lb)

            print(
                f'\nMetrics on test set of {k}: '
                f'f1: {f1_test[o]}, '
                f'recall: {recall_test[o]}, '
                f'mcc: {mcc_test[o]}, '
                f'average_precision: {average_precision_test[o]} \n'
            )

        (f1_test[total_index], recall_test[total_index],
         mcc_test[total_index], average_precision_test[total_index]) = \
            self._compute_test_metrics(total_pred, torch.argmax(total_lb, dim=1))

        print(
            f'\nMetrics on test set of TOTAL: '
            f'f1: {f1_test[total_index]}, '
            f'recall: {recall_test[total_index]}, '
            f'mcc: {mcc_test[total_index]}, '
            f'average_precision: {average_precision_test[total_index]} \n'
        )

        metric_dict = {
            "f1_score": f1_test,
            "recall": recall_test,
            "mcc": mcc_test,
            "average_precision": average_precision_test,
        }

        print(metric_dict)

        self._save_metrics_to_csv(metric_dict)

        # Clear buffers so a later trainer.test() run does not include these predictions.
        self.test_outputs_pred_total.clear()
        self.test_outputs_lb_total.clear()
        for bucket in self.test_outputs_pred_organism + self.test_outputs_lb_organism:
            bucket.clear()

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        self.checkpoint_name = CHECKPOINT.split('.')[0]

    def _save_results_to_txt(self, test_prediction_results, test_true_results, organism):
        """Write per-organism predictions and (once) the true labels to ``{WORKING_DIR}/out/results``.

        Args:
            test_prediction_results: per-class probabilities; ``on_test_end`` has already
                applied softmax, so they are written as-is.
            test_true_results: integer class indices.
            organism: organism name used in the file names.

        The true-label file is only written if it does not exist yet.
        """
        results_dir = f"{WORKING_DIR}/out/results"
        os.makedirs(results_dir, exist_ok=True)

        pred_path = f'{WORKING_DIR}/out/results/{organism}_test_prediction_by_{self.model_type}.txt'
        true_path = f'{WORKING_DIR}/out/results/{organism}_test_true.txt'

        # No second softmax here: the inputs are already probabilities.
        np.savetxt(pred_path, test_prediction_results, fmt="%.4f")
        if not os.path.exists(true_path):
            np.savetxt(true_path, test_true_results, fmt="%d")

    def _save_metrics_to_csv(self, metric_dict):
        """Write one CSV per organism plus a TOTAL CSV (metrics x labels, rounded to 2 d.p.)."""
        os.makedirs(f"{WORKING_DIR}/out/metrics", exist_ok=True)

        for k, o in ORGANISMS.items():
            metrics_organisms = {
                "f1_score": metric_dict['f1_score'][o],
                "recall": metric_dict['recall'][o],
                "mcc": metric_dict['mcc'][o],
                "average_precision": metric_dict['average_precision'][o],
            }
            df = pd.DataFrame.from_dict(metrics_organisms).transpose().round(2)
            df.to_csv(str(Path(f'{WORKING_DIR}/out/metrics/{self.checkpoint_name}_test_{k}.csv')),
                      header=list(SP_LABELS.keys()), index_label='metrics', na_rep=str(0.0))

        total_index = len(ORGANISMS)
        metrics_total = {
            "f1_score": metric_dict['f1_score'][total_index],
            "recall": metric_dict['recall'][total_index],
            "mcc": metric_dict['mcc'][total_index],
            "average_precision": metric_dict['average_precision'][total_index],
        }
        df = pd.DataFrame.from_dict(metrics_total).transpose().round(2)
        df.to_csv(str(Path(f'{WORKING_DIR}/out/metrics/{self.checkpoint_name}_test_metrics_TOTAL.csv')),
                  header=list(SP_LABELS.keys()), index_label='metrics', na_rep=str(0.0))


In [None]:
# Training and Validation

import lightning as L
import torch
# from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.loggers import WandbLogger
from pathlib import Path
import os
import wandb


def train():
    """Train ``SPModule`` with the settings from the configuration cells.

    Training resumes from
    ``{INPUT_DIR}/checkpoints/<MODEL_TYPE>-<DATA_TYPE>-<CONF_TYPE>-<USE_ORGANISM>_epochs<EPOCHS>.ckpt``
    if that file exists.

    When ``USE_LOGGER`` is set, runs are logged to Weights & Biases. The W&B API key is
    taken from the ``WANDB_API_KEY`` environment variable (or wandb's own fallbacks),
    never from the notebook.
    """
    torch.set_float32_matmul_precision('medium')
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    torch.cuda.empty_cache()

    log_dir = str(Path(WORKING_DIR, LOG_DIR))
    logger = False
    if USE_LOGGER:
        # Credentials must not be hard-coded. With key=None, wandb.login falls back to
        # WANDB_API_KEY / ~/.netrc (or prompts). The key previously embedded in this cell
        # has been published and should be revoked.
        wandb.login(key=os.environ.get("WANDB_API_KEY"))
        logger = WandbLogger(save_dir=log_dir, project='SPPredictor')
        run_name = f'{MODEL_TYPE}_{DATA_TYPE}'
        if USE_ORGANISM:
            run_name = f'{MODEL_TYPE}_{DATA_TYPE}_use_organism'
        logger.experiment.name = run_name
        logger.experiment.config['batch_size'] = BATCH_SIZE

    resume_ckpt = f'{MODEL_TYPE}-{DATA_TYPE}-{CONF_TYPE}-{int(USE_ORGANISM)}_epochs{EPOCHS}.ckpt'
    checkpoint = str(Path(INPUT_DIR, 'checkpoints', resume_ckpt))
    if not os.path.exists(checkpoint):
        checkpoint = None

    sp_module = SPModule(
        model_type=MODEL_TYPE,
        data_type=DATA_TYPE,
        conf_type=CONF_TYPE,
        use_organism=USE_ORGANISM,
        batch_size=BATCH_SIZE,
        lr=LEARNING_RATE,
    )

    sp_data_module = SPDataModule(
        data_type=DATA_TYPE,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
    )

    # NOTE(review): `model_checkpoint` must be defined in an earlier cell; it is not
    # created in this function.
    trainer = L.Trainer(
        devices=DEVICES,
        accelerator=ACCELERATOR,
        max_epochs=EPOCHS,
        logger=logger,
        val_check_interval=1.0,
        callbacks=[model_checkpoint],
    )

    trainer.fit(sp_module, datamodule=sp_data_module, ckpt_path=checkpoint)

    if logger:
        wandb.finish(quiet=True)


def test():
    """Evaluate the checkpoint ``{INPUT_DIR}/checkpoints/{CHECKPOINT}`` on the test split.

    Both the LightningModule and the DataModule are restored from that checkpoint.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    torch.set_float32_matmul_precision('medium')

    ckpt_path = str(Path(f'{INPUT_DIR}/checkpoints/{CHECKPOINT}'))
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError("Path does not exist. Check checkpoint path again")

    restored_module = SPModule.load_from_checkpoint(checkpoint_path=ckpt_path)
    restored_data = SPDataModule.load_from_checkpoint(checkpoint_path=ckpt_path)

    evaluator = L.Trainer(devices=DEVICES, accelerator=ACCELERATOR, logger=False, enable_checkpointing=False)
    evaluator.test(restored_module, datamodule=restored_data)


In [None]:
# Run training with the settings from the configuration cells (see train()).
# To evaluate the checkpoint named by CHECKPOINT instead, uncomment test() in the next cell.
train()

In [None]:
# test()