# Train DL

> Deep neural nets for PSSM

## Overview

**Utilities**

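`seed_everything(seed=123)` — Seeds Python's `random`, NumPy, and PyTorch (CPU and CUDA) and sets `PYTHONHASHSEED`. It also puts cuDNN in deterministic mode and disables cuDNN benchmarking. This reduces nondeterminism on GPU but does not fully eliminate it.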

```python
seed_everything(
    seed=42,  # random seed for reproducibility
)
```
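One way to confirm that seeding works is to reseed and draw again. The sketch below uses a hypothetical `seed_cpu_rngs` helper that mirrors only the Python and NumPy part of `seed_everything`, with no PyTorch, so it runs without a GPU.

```python
import os
import random

import numpy as np


def seed_cpu_rngs(seed=123):
    "Hypothetical helper: the Python/NumPy subset of seed_everything."
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)  # only affects Python processes launched afterwards
    np.random.seed(seed)


def draw():
    # one draw from each seeded generator
    return random.random(), np.random.rand(3)


seed_cpu_rngs(42)
a = draw()
seed_cpu_rngs(42)
b = draw()
print(a[0] == b[0], np.array_equal(a[1], b[1]))  # True True
```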

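`init_weights(m, leaky=0.)` — Initializes the weights of `Conv1d`/`Conv2d`/`Conv3d` layers with Kaiming normal initialization (`a=leaky`) and leaves other layers untouched. Pass it to `model.apply()`, which calls it on every submodule.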

```python
model = CNN1D(ni=1024, nf=230).apply(
    init_weights,  # initializes Conv layers with Kaiming normal
)
```

---

**Layer Builders**

`lin_wn(ni, nf, dp=0.1, act=nn.SiLU)` — Creates a block of BatchNorm, Dropout, a weight-normalized linear layer, and an optional activation, applied in that order.

```python
layer = lin_wn(
    ni=1024,       # input features
    nf=512,        # output features  
    dp=0.1,        # dropout probability
    act=nn.SiLU,   # activation function (None to disable)
)
```
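Weight normalization reparameterizes each weight as $w = g\,v/\lVert v\rVert$. By default (`dim=0`) there is one norm per output unit, which is one row of a linear weight. Convolution weights work the same way, with the norm taken over all axes except the first. The NumPy sketch below is not the PyTorch implementation. It decomposes a `(nf, ni)` weight into `g` and `v` and then recovers it.

```python
import numpy as np


def weight_norm_decompose(w):
    "Split w (nf, ni) into magnitude g (nf, 1) and direction v, one norm per output row."
    g = np.linalg.norm(w, axis=1, keepdims=True)
    v = w.copy()
    return g, v


def weight_norm_compose(g, v):
    "Rebuild the effective weight w = g * v / ||v||, with ||v|| taken per output row."
    return g * v / np.linalg.norm(v, axis=1, keepdims=True)


rng = np.random.default_rng(0)
w = rng.normal(size=(3, 10))  # shaped like nn.Linear(10, 3).weight
g, v = weight_norm_decompose(w)
print(np.allclose(weight_norm_compose(g, v), w))      # True

# Rescaling v changes nothing: only g controls the magnitude
print(np.allclose(weight_norm_compose(g, 5 * v), w))  # True
```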

`conv_wn(ni, nf, ks=3, stride=1, padding=1, dp=0.1, act=nn.ReLU)` — Creates a block of BatchNorm, Dropout, a weight-normalized 1D convolution, and an optional activation, applied in that order.

```python
layer = conv_wn(
    ni=256,        # input channels
    nf=512,        # output channels
    ks=5,          # kernel size
    stride=1,      # stride
    padding=2,     # padding
    dp=0.1,        # dropout probability
    act=nn.ReLU,   # activation function
)
```
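Whether a `conv_wn` block preserves sequence length follows from the standard `Conv1d` output-length formula $L_\text{out} = \lfloor (L_\text{in} + 2p - k)/s \rfloor + 1$ (dilation 1). With `stride=1`, choosing `padding=(ks-1)//2` for an odd `ks` keeps the length unchanged. That is why `ks=5` is paired with `padding=2` and `ks=3` with `padding=1`. A small sketch of the formula:

```python
def conv1d_out_len(l_in, ks, stride=1, padding=0, dilation=1):
    "Output length of nn.Conv1d for input length l_in."
    return (l_in + 2 * padding - dilation * (ks - 1) - 1) // stride + 1


for ks, pad in [(3, 1), (5, 2)]:
    print(ks, pad, conv1d_out_len(16, ks, stride=1, padding=pad))  # length stays 16

print(conv1d_out_len(16, ks=3, stride=2, padding=1))  # 8
```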

---

**Models**

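`MLP(num_features, num_targets, hidden_units=[512, 218], dp=0.2)` — Builds a multi-layer perceptron of Linear, BatchNorm, and PReLU blocks followed by a linear output layer. `dp` is currently unused because the Dropout layers are commented out.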

```python
model = MLP(
    num_features=1024,          # input dimension (e.g., T5 embeddings)
    num_targets=230,            # output dimension (23 AA × 10 positions)
    hidden_units=[512, 256],    # list of hidden layer sizes
    dp=0.2,                     # dropout rate (currently unused)
)
```
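The layer sizes chain as `num_features -> hidden_units[0] -> ... -> hidden_units[-1] -> num_targets`. The pure-Python sketch below uses hypothetical helpers that are not part of the module. It lists the `nn.Linear` shapes and counts trainable parameters, with two per BatchNorm channel (affine weight and bias) and one per default `nn.PReLU()`. The count should agree with `sum(p.numel() for p in model.parameters())`.

```python
def mlp_linear_shapes(num_features, num_targets, hidden_units=(512, 218)):
    "(in, out) pairs of each nn.Linear in MLP, in order."
    sizes = [num_features, *hidden_units, num_targets]
    return list(zip(sizes[:-1], sizes[1:]))


def mlp_param_count(num_features, num_targets, hidden_units=(512, 218)):
    "Trainable parameters: Linear weights and biases, BatchNorm affine params, one PReLU slope per block."
    shapes = mlp_linear_shapes(num_features, num_targets, hidden_units)
    linear = sum(i * o + o for i, o in shapes)
    bn_prelu = sum(2 * h + 1 for h in hidden_units)  # one BatchNorm1d + PReLU per hidden layer
    return linear + bn_prelu


print(mlp_linear_shapes(1024, 230, [512, 256]))  # [(1024, 512), (512, 256), (256, 230)]
print(mlp_param_count(1024, 230, [512, 256]))    # 716776
```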

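`CNN1D(ni, nf, amp_scale=16)` — Expands the input to `256*amp_scale` features with a weight-normalized linear layer and reshapes them to `(batch, 256, amp_scale)`. It then applies two convolution blocks joined by a multiplicative skip connection, max-pools, flattens, and projects to `nf` outputs.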

```python
model = CNN1D(
    ni=1024,        # input features
    nf=230,         # output features (flattened PSSM)
    amp_scale=16,   # amplification factor for feature expansion
).apply(init_weights)
```
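Tracing the tensor shapes shows how `amp_scale` sets the size of the head. The sketch below is a hypothetical helper that only does the arithmetic. It follows `CNN1D` with its hard-coded channel counts (256, then 512):

- the linear layer expands the input to `256*amp_scale` features;
- `View` reshapes them to `(256, amp_scale)`;
- the padded convolutions keep the length, and `AdaptiveAvgPool1d` halves it;
- `MaxPool1d(kernel_size=4, stride=2, padding=1)` halves it again before flattening.

The flattened size matches the `cha_po_2` that `CNN1D` passes to its final `lin_wn` for any `amp_scale` of 4 or more. Below 4, the max-pool output length in this arithmetic drops to 0.

```python
def cnn1d_shapes(amp_scale=16, cha_1=256, cha_3=512):
    "Hypothetical helper: per-sample shape trace of CNN1D's forward pass (batch dim omitted)."
    hidden = cha_1 * amp_scale                  # self.lin output size
    after_view = (cha_1, amp_scale)             # View(-1, cha_1, amp_scale)
    pooled = hidden // (cha_1 * 2)              # AdaptiveAvgPool1d output length (cha_po_1)
    after_conv = (cha_3, pooled)                # padded convs keep the length
    after_max = (pooled + 2 * 1 - 4) // 2 + 1   # MaxPool1d(kernel_size=4, stride=2, padding=1)
    flat = cha_3 * after_max                    # nn.Flatten
    head_in = (hidden // (cha_1 * 4)) * cha_3   # cha_po_2: input size of the final lin_wn
    return dict(hidden=hidden, after_view=after_view, after_conv=after_conv,
                flat=flat, head_in=head_in)


print(cnn1d_shapes(16))
# {'hidden': 4096, 'after_view': (256, 16), 'after_conv': (512, 8), 'flat': 2048, 'head_in': 2048}
print(all(cnn1d_shapes(a)['flat'] == cnn1d_shapes(a)['head_in'] for a in range(4, 65)))  # True
```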

`PSSM_model(n_features, n_targets, model='MLP')` — Wraps `MLP` or `CNN1D` and reshapes the flat output to `(batch, 23, positions)`. The outputs are raw logits; a softmax over dim 1 turns them into per-position amino-acid probabilities.

```python
model = PSSM_model(
    n_features=1024,   # input feature dimension
    n_targets=230,     # total targets (must be divisible by 23)
    model='CNN',       # 'MLP' or 'CNN' architecture
)
# Output shape: (batch, 23, 10) for 10 positions
```

---

**Dataset**

`GeneralDataset(df, feat_col, target_col=None, A=23, dtype=np.float32)` — PyTorch Dataset that extracts features and reshapes targets to `(23, L)` PSSM matrices.

```python
ds = GeneralDataset(
    df=train_df,           # DataFrame with features and targets
    feat_col=feat_col,     # Index/list of feature column names
    target_col=target_col, # Index/list of target columns (None for test mode)
    A=23,                  # number of amino acids (including pS, pT, pY)
    dtype=np.float32,      # data type for tensors
)
# Each item is (X, y) with y.shape == (23, L); in test mode (target_col=None) only X
```
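`GeneralDataset` assumes the target columns are flattened amino-acid first: column `a*L + p` holds amino acid `a` at position `p`, so a row-major `reshape(-1, A, L)` recovers the matrix. The NumPy sketch below uses hypothetical column labels (`aa{a}_pos{p}`, not the real target names) to show this mapping. It also shows the transpose needed if the columns were flattened position-first instead.

```python
import numpy as np

A, L = 23, 10  # amino acids x positions

# Hypothetical column labels in AA-first order: aa0_pos0, aa0_pos1, ..., aa22_pos9
cols = [f'aa{a}_pos{p}' for a in range(A) for p in range(L)]

# One synthetic row whose value encodes (a, p) as a*100 + p, so the mapping is easy to check
y_flat = np.array([[a * 100 + p for a in range(A) for p in range(L)]], dtype=np.float32)

y = y_flat.reshape(-1, A, L)          # what GeneralDataset does
print(y.shape)                        # (1, 23, 10)
print(y[0, 3, 7], cols[3 * L + 7])    # 307.0 aa3_pos7

# Position-first columns (pos0_aa0, pos0_aa1, ...) need a reshape followed by a transpose
y_pos_first = y.transpose(0, 2, 1).reshape(1, -1)   # simulate position-first flattening
y_back = y_pos_first.reshape(-1, L, A).transpose(0, 2, 1)
print(np.array_equal(y_back, y))      # True
```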

---

**Loss Function**

`CE(logits, target_probs)` — Cross-entropy loss with soft labels. Applies `log_softmax` over the amino-acid axis (dim 1), computes the cross-entropy against the target distribution at each position, and averages over batch and positions.

```python
loss = CE(
    logits=model_output,      # (B, 23, 10) raw logits
    target_probs=target_pssm, # (B, 23, 10) target probabilities (sum to 1 per position)
)
```
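Per position, the loss is $-\sum_a p_a \log \operatorname{softmax}(z)_a$, averaged over batch and positions. The NumPy re-implementation below is a sketch for checking, not the exported function. It makes two properties easy to verify:

- constant logits give $\log 23$ for any target;
- the loss is never below the target's entropy, with equality when the prediction equals the target.

```python
import numpy as np


def log_softmax(z, axis=1):
    "Numerically stable log-softmax along the amino-acid axis."
    z = z - z.max(axis=axis, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))


def ce_np(logits, target_probs):
    "Soft-label cross-entropy like CE: sum over AA (axis 1), mean over batch and positions."
    return -(target_probs * log_softmax(logits)).sum(axis=1).mean()


rng = np.random.default_rng(0)
B, A, L = 4, 23, 10
target = rng.dirichlet(np.ones(A), size=(B, L)).transpose(0, 2, 1)  # (B, 23, 10), columns sum to 1

print(np.isclose(ce_np(np.zeros((B, A, L)), target), np.log(A)))  # True: uniform prediction

entropy = -(target * np.log(target)).sum(axis=1).mean()
print(np.isclose(ce_np(np.log(target), target), entropy))          # True: prediction equals target
```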

---

**Metrics**

`KLD(logits, target_probs)` — Kullback-Leibler divergence KL(p || q) between the target distribution (p) and the predicted softmax distribution (q), averaged over batch and positions.

```python
kl_div = KLD(
    logits=model_output,      # (B, 23, 10) raw logits
    target_probs=target_pssm, # (B, 23, 10) target probabilities
)
```
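KLD and CE differ only by the target entropy, which does not depend on the model. Per position, with $q = \operatorname{softmax}(z)$:

$$
\mathrm{KL}(p \,\|\, q) = \sum_{a} p_a \left(\log p_a - \log q_a\right) = \underbrace{-\sum_a p_a \log q_a}_{\mathrm{CE}(p,\,q)} - \underbrace{\left(-\sum_a p_a \log p_a\right)}_{H(p)}
$$

Minimizing `CE` and minimizing `KLD` therefore give the same gradients with respect to the logits. The `1e-8` added inside $\log p_a$ only changes the logit-independent term. It keeps entries with $p_a = 0$ finite, and those entries are multiplied by zero anyway.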

`JSD(logits, target_probs)` — Jensen-Shannon divergence between the target and predicted distributions, averaged over batch and positions. Unlike KLD it is symmetric and bounded (by ln 2).

```python
js_div = JSD(
    logits=model_output,      # (B, 23, 10) raw logits  
    target_probs=target_pssm, # (B, 23, 10) target probabilities
)
```
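JSD compares both distributions to their midpoint $m = (p+q)/2$, which makes it symmetric and bounded by $\ln 2$ in nats. The NumPy sketch below re-implements it on probabilities directly; it is not the exported function. It checks those properties.

```python
import numpy as np

EPS = 1e-8  # same epsilon as the PyTorch version


def kl(p, q):
    "KL(p || q) summed over the amino-acid axis (axis 1)."
    return (p * (np.log(p + EPS) - np.log(q + EPS))).sum(axis=1)


def jsd_probs(p, q):
    "Jensen-Shannon divergence per (sample, position), averaged like JSD."
    m = 0.5 * (p + q)
    return (0.5 * (kl(p, m) + kl(q, m))).mean()


rng = np.random.default_rng(1)
p = rng.dirichlet(np.ones(23), size=(2, 10)).transpose(0, 2, 1)  # (2, 23, 10)
q = rng.dirichlet(np.ones(23), size=(2, 10)).transpose(0, 2, 1)

print(np.isclose(jsd_probs(p, q), jsd_probs(q, p)))   # True: symmetric
print(np.isclose(jsd_probs(p, p), 0.0))               # True: identical distributions

# Disjoint one-hot distributions reach the upper bound ln 2
a = np.zeros((1, 23, 1)); a[0, 0, 0] = 1.0
b = np.zeros((1, 23, 1)); b[0, 1, 0] = 1.0
print(np.isclose(jsd_probs(a, b), np.log(2)))         # True
```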

---

**Training**

`train_dl(df, feat_col, target_col, split, model_func, ...)` — Trains a model on a single train/valid split using fastai's `Learner` with the one-cycle policy, reporting KLD and JSD on the validation set.

```python
target, pred = train_dl(
    df=df,                     # full DataFrame
    feat_col=feat_col,         # feature column names
    target_col=target_col,     # target column names
    split=split0,              # (train_idx, valid_idx) tuple
    model_func=get_cnn,        # callable returning fresh model
    n_epoch=10,                # number of training epochs
    bs=32,                     # batch size
    lr=3e-3,                   # peak lr for one-cycle (ignored if lr_find=True)
    loss=CE,                   # loss function
    save='my_model',           # save to models/my_model.pth
    sampler=None,              # accepted but currently unused
    lr_find=True,              # run lr_find and train with its suggested lr
)
# Returns (target_df, pred_df) for the validation rows; pred_df holds softmax probabilities
```
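`fit_one_cycle` treats `lr` as the peak learning rate, not a constant one. The sketch below models fastai's schedule with its defaults `div=25`, `div_final=1e5` and `pct_start=0.25`, and omits momentum scheduling. The rate warms up along a cosine from `lr/25` to `lr` over the first 25% of training steps, then anneals along a cosine to `lr/1e5`.

```python
import math


def sched_cos(start, end, pos):
    "Cosine interpolation from start (pos=0) to end (pos=1)."
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2


def one_cycle_lr(pos, lr_max, div=25.0, div_final=1e5, pct_start=0.25):
    "Learning rate at training progress pos in [0, 1] (sketch of the 1cycle lr schedule)."
    if pos < pct_start:
        return sched_cos(lr_max / div, lr_max, pos / pct_start)
    return sched_cos(lr_max, lr_max / div_final, (pos - pct_start) / (1 - pct_start))


for pos in (0.0, 0.25, 0.5, 1.0):
    print(pos, one_cycle_lr(pos, lr_max=3e-3))
```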

`train_dl_cv(df, feat_col, target_col, splits, model_func, save=None, **kwargs)` — Cross-validation wrapper that trains one model per fold and concatenates the out-of-fold (OOF) validation predictions.

```python
oof = train_dl_cv(
    df=df,                     # full DataFrame
    feat_col=feat_col,         # feature column names
    target_col=target_col,     # target column names
    splits=splits,             # list of (train_idx, valid_idx) tuples
    model_func=get_cnn,        # callable returning fresh model
    save='cnn',                # saves as cnn_fold0.pth, cnn_fold1.pth, ...
    n_epoch=10,                # passed to train_dl
    lr=3e-3,                   # passed to train_dl
)
# Returns a DataFrame of all OOF predictions plus an 'nfold' column, sorted by index
```
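`splits` is a list of `(train_idx, valid_idx)` pairs. `train_dl` selects rows with `df.loc`, so the entries are index labels, which equal row positions after `reset_index`. The commented setup cells build them with `get_splits(pspa_info, group='subfamily', nfold=5)`, presumably so that no kinase subfamily appears in both train and valid. The NumPy sketch below is a hypothetical stand-in, not `get_splits` itself. It shuffles the groups, assigns them to folds round-robin (without balancing fold sizes), and checks the no-leakage property on toy group labels.

```python
import numpy as np


def group_kfold_splits(groups, nfold=5, seed=0):
    "Hypothetical helper: (train_idx, valid_idx) pairs with every group confined to one fold."
    groups = np.asarray(groups)
    uniq = np.unique(groups)
    rng = np.random.default_rng(seed)
    rng.shuffle(uniq)
    fold_of_group = {g: i % nfold for i, g in enumerate(uniq)}  # round-robin over shuffled groups
    fold = np.array([fold_of_group[g] for g in groups])
    idx = np.arange(len(groups))
    return [(idx[fold != k], idx[fold == k]) for k in range(nfold)]


groups = ['CDK', 'CDK', 'MAPK', 'GSK', 'MAPK', 'PKA', 'PKC', 'PKC', 'CK2', 'AKT']  # toy labels
splits = group_kfold_splits(groups, nfold=3)
for tr, va in splits:
    print(sorted(set(groups[i] for i in va)))
```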

---

**Prediction**

`predict_dl(df, feat_col, target_col, model_func, model_pth)` — Loads a saved model and generates predictions for a DataFrame.

```python
preds = predict_dl(
    df=test_df,                # DataFrame to predict
    feat_col=feat_col,         # feature column names
    target_col=target_col,     # used for output column names
    model_func=get_cnn,        # must match saved architecture
    model_pth='cnn_fold0',     # model name (without .pth)
)
# Returns a DataFrame of softmax probabilities indexed like test_df, with target_col as columns
```
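Both `predict_dl` and `train_dl` turn logits of shape `(B, 23, L)` into probabilities with a softmax over the amino-acid axis (`dim=1`), then flatten each sample AA-first. Every output row can therefore be reshaped back to `(23, L)`, with each column summing to 1. The NumPy sketch below shows this post-processing, with random logits standing in for model output.

```python
import numpy as np


def softmax(z, axis=1):
    "Numerically stable softmax along the amino-acid axis."
    e = np.exp(z - z.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


B, A, L = 5, 23, 10
logits = np.random.default_rng(2).normal(size=(B, A, L))  # stand-in for model output

flat = softmax(logits).reshape(B, -1)    # (5, 230): one row per sample, AA-first columns
pssm = flat[0].reshape(A, L)             # back to (23, 10) for the first sample
print(flat.shape, np.allclose(pssm.sum(axis=0), 1.0))  # (5, 230) True
```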

## Setup

In [None]:
#| default_exp dnn

In [None]:
#| export
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
import fastcore.all as fc

import numpy as np, pandas as pd
import os, random
from katlas.data import *
from katlas.train import *
from katlas.pssm import *
from fastai.vision.all import *  # also provides nn, plt, View, DataLoader, DataLoaders, Learner

## Utils

In [None]:
#| export
def seed_everything(seed=123):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
seed_everything()

In [None]:
#| export
def_device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
def_device

'cpu'

## Load Data

In [None]:
# df=pd.read_parquet('paper/kinase_domain/train/pspa_t5.parquet')

In [None]:
# info=Data.get_kinase_info()

# info = info[info.pseudo=='0']

# info = info[info.kd_ID.notna()]

# subfamily_map = info[['kd_ID','subfamily']].drop_duplicates().set_index('kd_ID')['subfamily']

# pspa_info = pd.DataFrame(df.index.tolist(),columns=['kinase'])

# pspa_info['subfamily'] = pspa_info.kinase.map(subfamily_map)

# splits = get_splits(pspa_info, group='subfamily',nfold=5)

# split0 = splits[0]

In [None]:
# df=df.reset_index()

In [None]:
# df.columns

In [None]:
# # column name of feature and target
# feat_col = df.columns[df.columns.str.startswith('T5_')]
# target_col = df.columns[~df.columns.isin(feat_col)][1:]

In [None]:
# feat_col

In [None]:
# target_col

## Dataset

In [None]:
#| export
class GeneralDataset(Dataset):
    def __init__(self,
                 df,
                 feat_col,            # list/Index of feature columns (e.g., 100 cols)
                 target_col=None,     # list/Index of flattened PSSM cols; AA-first; A=23
                 A: int = 23,
                 dtype=np.float32):
        """
        If target_col is None -> test mode, returns only X.
        Otherwise returns (X, y) where y has shape (23, L), L inferred from target columns.
        """
        self.test = target_col is None
        self.aa = A

        # Features
        self.X = df[feat_col].to_numpy(dtype=dtype, copy=True)

        self.y = None
        if not self.test:
            y_flat = df[target_col].to_numpy(dtype=dtype, copy=True)

            total = y_flat.shape[1]
            if total % A != 0:
                raise ValueError(f"Target columns ({total}) not divisible by A={A}; cannot infer L.")
            self.position = total // self.aa

            # AA-first layout: column a*L + p holds amino acid a at position p, so a row-major reshape gives (N, 23, L)
            self.y = y_flat.reshape(-1, A, self.position)
            # If the columns are position-first (column-major, e.g. from pandas unstack), use instead:
            # self.y = y_flat.reshape(-1, self.position, self.aa).transpose(0, 2, 1)

        self.len = len(df)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        X = torch.from_numpy(self.X[index])        # (feat_dim,)
        if self.test: return X
        y = torch.from_numpy(self.y[index])        # (23, L)
        return X, y

In [None]:
# # dataset
# ds = GeneralDataset(df,feat_col,target_col)

In [None]:
# len(ds)

In [None]:
# dl = DataLoader(ds, batch_size=64, shuffle=True)

In [None]:
# xb,yb = next(iter(dl))

# xb.shape,yb.shape

## Models

### MLP

In [None]:
#| export
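def MLP(num_features,
        num_targets,
        hidden_units = [512, 218],
        dp = 0.2): # currently unused: the Dropout layers below are commented out
    "MLP of Linear -> BatchNorm1d -> PReLU blocks followed by a Linear output layer."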
    # Start with the first layer from num_features to the first hidden layer
    layers = [
        nn.Linear(num_features, hidden_units[0]),
        nn.BatchNorm1d(hidden_units[0]),
        # nn.Dropout(dp),
        nn.PReLU()
    ]
    
    # Loop over hidden units to create intermediate layers
    for i in range(len(hidden_units) - 1):
        layers.extend([
            nn.Linear(hidden_units[i], hidden_units[i+1]),
            nn.BatchNorm1d(hidden_units[i+1]),
            # nn.Dropout(dp),
            nn.PReLU()
        ])
    
    # Add the output layer
    layers.append(nn.Linear(hidden_units[-1], num_targets))
    
    model = nn.Sequential(*layers)
    
    return model

In [None]:
# n_feature = len(feat_col)
# n_target = len(target_col)

In [None]:
# model = MLP(n_feature, n_target)

In [None]:
# model(xb)

### CNN1D

In [None]:
#| export
def lin_wn(ni,nf,dp=0.1,act=nn.SiLU):
    "BatchNorm -> Dropout -> weight-normalized Linear -> optional activation."
    layers = [
            nn.BatchNorm1d(ni),
            nn.Dropout(dp),
            nn.utils.parametrizations.weight_norm(nn.Linear(ni, nf)) 
    ]
    if act: layers.append(act())
    return nn.Sequential(*layers)

In [None]:
lin_wn(10,3)

Sequential(
  (0): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (1): Dropout(p=0.1, inplace=False)
  (2): ParametrizedLinear(
    in_features=10, out_features=3, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): _WeightNorm()
      )
    )
  )
  (3): SiLU()
)

In [None]:
#| export
def conv_wn(ni, nf, ks=3, stride=1, padding=1, dp=0.1,act=nn.ReLU):
    "BatchNorm -> Dropout -> weight-normalized Conv1d -> optional activation."
    layers = [
        nn.BatchNorm1d(ni),
        nn.Dropout(dp),
        nn.utils.parametrizations.weight_norm(nn.Conv1d(ni, nf, ks, stride, padding)) 
        ]
    if act: layers.append(act())
    return nn.Sequential(*layers)

In [None]:
#| export
class CNN1D(nn.Module):
    
    def __init__(self, ni, nf, amp_scale = 16):
        super().__init__()

        cha_1,cha_2,cha_3 = 256,512,512
        hidden_size = cha_1*amp_scale

        cha_po_1 = hidden_size//(cha_1*2)
        cha_po_2 = (hidden_size//(cha_1*4)) * cha_3
        
        self.lin = lin_wn(ni,hidden_size)
        
        # (bs, 256, amp_scale)
        self.view = View(-1,cha_1,amp_scale)
        
        self.conv1 = nn.Sequential(
            conv_wn(cha_1, cha_2, ks=5, stride=1, padding=2, dp=0.1),
            nn.AdaptiveAvgPool1d(output_size = cha_po_1),
            conv_wn(cha_2, cha_2, ks=3, stride=1, padding=1, dp=0.1))
        
        self.conv2 = nn.Sequential(
            conv_wn(cha_2, cha_2, ks=3, stride=1, padding=1, dp=0.3),
            conv_wn(cha_2, cha_3, ks=5, stride=1, padding=2, dp=0.2))
        
        self.head = nn.Sequential(
            nn.MaxPool1d(kernel_size=4, stride=2, padding=1),
            nn.Flatten(),
            lin_wn(cha_po_2,nf,act=None) )


    def forward(self, x):
        # expand features to 256*amp_scale (4096 with the default amp_scale=16)
        x = self.lin(x)
        
        # reshape to (bs, 256, amp_scale) for Conv1d
        x = self.view(x)

        x = self.conv1(x)
        
        x_s = x  # for skip connection (multiply)
        x = self.conv2(x)
        x = x * x_s

        # Final block
        x = self.head(x)

        return x

In [None]:
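#| export
def init_weights(m, leaky=0.):
    "Initialize any Conv layer with Kaiming normal init, including weight-normed ones."
    if not isinstance(m, (nn.Conv1d,nn.Conv2d,nn.Conv3d)): return
    if nn.utils.parametrize.is_parametrized(m, 'weight'):
        # m.weight is recomputed on each access, so assign instead: right_inverse then updates g and v
        m.weight = nn.init.kaiming_normal_(torch.empty_like(m.weight), a=leaky)
    else: nn.init.kaiming_normal_(m.weight, a=leaky)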

In [None]:
# model = CNN1D(n_feature,n_target).apply(init_weights)

In [None]:
# model(xb).shape

## Wrapper

In [None]:
#| export
class PSSM_model(nn.Module):
    def __init__(self, 
                 n_features,
                 n_targets,
                 model='MLP'):
        super().__init__()
        self.n_features=n_features
        self.n_targets=n_targets
        self.n_aa = 23
        if self.n_targets % self.n_aa != 0: raise ValueError(f"n_targets ({n_targets}) must be divisible by n_aa ({self.n_aa}).")
        self.n_positions = self.n_targets//self.n_aa
        
        if model =='MLP': self.model=MLP(self.n_features, self.n_targets)
        elif model =='CNN': self.model=CNN1D(self.n_features, self.n_targets).apply(init_weights)
        else: raise ValueError('model must be MLP or CNN.')
    def forward(self,x):
        logits = self.model(x).reshape(-1, self.n_aa,self.n_positions)
        return logits

In [None]:
# model = PSSM_model(n_feature,n_target)

In [None]:
# logits= model(xb)

In [None]:
# logits.shape

In [None]:
# def get_mlp(): return PSSM_model(n_feature,n_target,model='MLP')

# def get_cnn(): return PSSM_model(n_feature,n_target,model='CNN')

## Loss

In [None]:
#| export
def CE(logits: torch.Tensor,
       target_probs: torch.Tensor,
      ):
    """
    Cross-entropy with soft labels, averaged over batch and positions.
    logits:       (B, 23, L) raw scores
    target_probs: (B, 23, L), each column (over AA) sums to 1
    """
    logp = F.log_softmax(logits, dim=1)              # (B, 23, L)
    ce   = -(target_probs * logp).sum(dim=1)         # (B, L)
    return ce.mean()

In [None]:
# CE(logits,yb)

## Metrics

In [None]:
#| export
def KLD(logits: torch.Tensor,
        target_probs: torch.Tensor,
       ):
    """
    KL divergence KL(p || q) between target_probs (p) and softmax(logits) (q), averaged over batch and positions.

    logits:       (B, 23, L)
    target_probs: (B, 23, L), each column (over AA) sums to 1
    """
    logq = F.log_softmax(logits, dim=1)    # log q(x)
    logp = torch.log(target_probs + 1e-8)  # log p(x); epsilon keeps zero entries finite
    kl   = (target_probs * (logp - logq)).sum(dim=1)   # (B, L)
    return kl.mean()

In [None]:
# KLD(logits,yb)

In [None]:
#| export
def JSD(logits: torch.Tensor,
        target_probs: torch.Tensor,
       ):
    """
    Jensen-Shannon divergence between target_probs (p) and softmax(logits) (q), averaged over batch and positions.

    logits:       (B, 23, L)
    target_probs: (B, 23, L), each column (over AA) sums to 1
    """
    # p, q distributions
    q = F.softmax(logits, dim=1)                # q(x)
    p = target_probs
    m = 0.5 * (p + q)                           # midpoint distribution

    # logs (with epsilon for stability)
    logp = torch.log(p + 1e-8)
    logq = torch.log(q + 1e-8)
    logm = torch.log(m + 1e-8)

    # KL(p||m) and KL(q||m)
    kld_pm = (p * (logp - logm)).sum(dim=1)
    kld_qm = (q * (logq - logm)).sum(dim=1)

    jsd = 0.5 * (kld_pm + kld_qm)               # (B, L)
    return jsd.mean()

In [None]:
# JSD(logits,yb)

## Trainer

In [None]:
#| export
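def train_dl(df,
             feat_col,
             target_col,
             split, # tuple of numpy arrays with the train and valid index labels
             model_func, # zero-argument function returning a fresh pytorch model
             n_epoch = 4, # number of epochs
             bs = 32, # batch size
             lr = 1e-2, # peak lr for fit_one_cycle; ignored when lr_find is True
             loss = CE, # loss function
             save = None, # if given, saves to models/{save}.pth
             sampler = None, # currently unused
             lr_find = False, # if True, train with the lr suggested by lr_find
            ):
    "Train on one train/valid split with fit_one_cycle; return (target, pred) DataFrames for the validation rows."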
    
    train = df.loc[split[0]]
    valid = df.loc[split[1]]
    
    train_ds = GeneralDataset(train, feat_col, target_col)
    valid_ds = GeneralDataset(valid, feat_col, target_col)
    
    dls = DataLoaders.from_dsets(train_ds, valid_ds, bs=bs, num_workers=min(fc.defaults.cpus, 4))

    model = model_func()
    learn = Learner(dls.to(def_device), model.to(def_device), loss, 
                    metrics= [KLD,JSD]
                    # cbs = [GradientClip(1.0)] # clip gradients to stabilize training
                   )
    
    if lr_find:
        # train with the learning rate suggested by lr_find (valley heuristic)
        lr = learn.lr_find().valley
        plt.show()
        plt.close()

        
    print('lr in training is', lr)
    learn.fit_one_cycle(n_epoch,lr) #cbs = [SaveModelCallback(fname = 'best')] # save best model
    
    if save is not None:
        learn.save(save)
        
    pred,target = learn.get_preds()

    # flatten AA-first (row-major), matching the column order GeneralDataset assumes
    pred  = F.softmax(pred, dim=1).reshape(len(valid),-1)
    target = target.reshape(len(valid),-1)

    # position-first (column-major) alternative
    # pred  = F.softmax(pred, dim=1).permute(0, 2, 1).reshape(len(valid),-1)
    # target = target.permute(0, 2, 1).reshape(len(valid),-1)
    
    pred = pd.DataFrame(pred.detach().cpu().numpy(),index=valid.index,columns=target_col)
    target = pd.DataFrame(target.detach().cpu().numpy(),index=valid.index,columns=target_col)
    
    return target, pred

In [None]:
# target, pred = train_dl(df, 
#                         feat_col, 
#                         target_col,
#                         split0, 
#                         model_func=get_cnn,
#                         n_epoch=1,
#                         lr = 3e-3,
#                         lr_find=True,
#                         save = 'test')

In [None]:
# pred

In [None]:
# pred_pssm = recover_pssm(pred.iloc[0])
# pred_pssm.sum()

## Predict

In [None]:
#| export
def predict_dl(df,
               feat_col,
               target_col,
               model_func, # model architecture; must match the saved weights
               model_pth, # file name in models/, without the .pth extension
              ):
    "Predict a dataframe with a saved deep learning model; returns softmax probabilities."

    test_dset = GeneralDataset(df,feat_col)
    test_dl = DataLoader(test_dset,bs=512)

    model = model_func()

    # Learner is only used here to locate and load models/{model_pth}.pth
    learn = Learner(None, model.to(def_device), loss_func=1)
    learn.load(model_pth,weights_only=False)

    learn.model.eval()

    preds = []
    with torch.no_grad():  # inference only: skip building the autograd graph
        for data in test_dl:
            inputs = data.to(def_device)
            pred = learn.model(inputs)

            # softmax over AA (dim 1), then flatten AA-first to match target_col
            pred  = F.softmax(pred, dim=1).reshape(len(pred),-1)
            # pred  = F.softmax(pred, dim=1).permute(0, 2, 1).reshape(len(pred),-1)

            preds.append(pred.cpu().numpy())

    preds = np.concatenate(preds)
    preds = pd.DataFrame(preds,index=df.index,columns=target_col)

    return preds

In [None]:
# test = df.loc[split0[1]].copy()

In [None]:
# test_pred = predict_dl(test, 
#                feat_col, 
#                target_col,
#                model_func=get_cnn, # model architecture
#                model_pth='test', # only name, not with .pth
#               )

In [None]:
# test_pred.columns

In [None]:
# pssm_pred = recover_pssm(test_pred.iloc[0])
# pssm_pred.sum()

In [None]:
# plot_heatmap(pssm_pred)

## CV train
> cross-validation

In [None]:
#| export
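def train_dl_cv(df,
                feat_col,
                target_col,
                splits, # list of (train_idx, valid_idx) tuples
                model_func, # zero-argument function returning a fresh model, e.g. lambda: PSSM_model(n_feature, n_target)
                save:str=None, # if given, fold k is saved to models/{save}_fold{k}.pth
                **kwargs # passed on to train_dl
                ):
    "Cross-validation: train one model per split and return the concatenated out-of-fold predictions."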
    
    OOF = []
    
    for fold,split in enumerate(splits):

        print(f'------fold{fold}------')
        
        fname = f'{save}_fold{fold}' if save is not None else None
        
        
        # train model
        target, pred = train_dl(df,feat_col,target_col, split, model_func ,save=fname,**kwargs)

        pred['nfold'] = fold
        OOF.append(pred)
        

    # Concatenate OOF from each fold to a new dataframe
    oofs = pd.concat(OOF).sort_index()
    
    return oofs

In [None]:
# oof = train_dl_cv(df,feat_col,target_col,
#                   splits = splits,
#                   model_func = get_cnn,
#                   n_epoch=1,lr=3e-3,save='cnn')

In [None]:
# oof.nfold.value_counts()

## Export -

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()