# Summary

# Imports

In [1]:
import concurrent.futures
import itertools
import importlib
import logging
import multiprocessing
import os
import os.path as op
import pickle
import subprocess
import sys
import tempfile
from functools import partial
from pathlib import Path

from kmtools import py_tools, sequence_tools

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import seaborn as sns
import sqlalchemy as sa
import torch
import torch.nn as nn
import torch.nn.functional as F
from numba import njit, prange
from scipy import stats

from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.optim import lr_scheduler
import time
import copy

import pagnn.models.dcn
from pagnn.datavargan import dataset_to_datavar
from pagnn.models.common import AdjacencyConv, SequenceConv, SequentialMod
from pagnn.utils import expand_adjacency_tensor, padding_amount, reshape_internal_dim
from pagnn.dataset import dataset_to_gan, row_to_dataset

from kmtools import py_tools, sequence_tools

from torch.utils.data import Dataset
from torchvision import transforms

from pagnn.types import DataRow, DataSetGAN

In [2]:
%matplotlib inline

In [3]:
pd.set_option("max_columns", 100)

In [4]:
SRC_PATH = Path.cwd().joinpath('..', 'src').resolve(strict=True)

if SRC_PATH.as_posix() not in sys.path:
    sys.path.insert(0, SRC_PATH.as_posix())

import helper
importlib.reload(helper)

<module 'helper' from '/home/kimlab1/database_data/datapkg/adjacency-net-v2/src/helper/__init__.py'>

# Parameters

In [5]:
NOTEBOOK_PATH = Path('validation_protherm_dataset')
NOTEBOOK_PATH

PosixPath('validation_protherm_dataset')

In [6]:
OUTPUT_PATH = Path(os.getenv('OUTPUT_DIR', NOTEBOOK_PATH.name)).resolve()
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
OUTPUT_PATH

PosixPath('/home/kimlab1/database_data/datapkg/adjacency-net-v2/notebooks/validation_protherm_dataset')

In [7]:
proc = subprocess.run(["git", "rev-parse", "--short", "HEAD"], stdout=subprocess.PIPE)
GIT_REV = proc.stdout.decode().strip()
GIT_REV

'16ca70d'

In [8]:
TASK_ID = os.getenv("SLURM_ARRAY_TASK_ID")
TASK_COUNT = os.getenv("ORIGINAL_ARRAY_TASK_COUNT") or os.getenv("SLURM_ARRAY_TASK_COUNT")
NETWORK_NAME = os.getenv("CI_COMMIT_SHA")

TASK_ID = int(TASK_ID) if TASK_ID is not None else None
TASK_COUNT = int(TASK_COUNT) if TASK_COUNT is not None else None

TASK_ID, TASK_COUNT, NETWORK_NAME

(None, None, None)

In [9]:
DEBUG = "CI" not in os.environ    
DEBUG

True

In [10]:
if DEBUG:
    NETWORK_NAME = ",".join([
        "7b4ff1af3ec63a01fa415435420c554be1fecbb0",  # test74
    ])
else:
    assert NETWORK_NAME is not None
    
NETWORK_NAME

'7b4ff1af3ec63a01fa415435420c554be1fecbb0'

In [11]:
if DEBUG:
    %load_ext autoreload
    %autoreload 2

## `DATAPKG`

In [12]:
DATAPKG = {}

In [13]:
DATAPKG["protherm_validaton_dataset"] = Path(os.environ["DATAPKG_OUTPUT_DIR"]).joinpath(
    "adjacency-net-v2", "v0.2", "protherm_dataset", "protherm_validaton_dataset.parquet"
)

# Load data

In [14]:
input_file = DATAPKG["protherm_validaton_dataset"].resolve(strict=True)
input_file

PosixPath('/home/kimlab1/database_data/datapkg_output_dir/adjacency-net-v2/v0.2/protherm_dataset/protherm_validaton_dataset.parquet')

In [15]:
input_df = pq.read_table(input_file).to_pandas()
input_df.head(2)

Unnamed: 0,filename_wt,chain_id,mutation,cartesian_ddg_beta_nov15_cart_1,ddg_exp,cartesian_ddg_beta_nov16_cart_1,cartesian_ddg_score12_cart_1,cartesian_ddg_talaris2013_cart_1,cartesian_ddg_talaris2014_cart_1,ddg_monomer_soft_rep_design_1,local_filename_wt,structure_id,model_id,qseq,residue_idx_1_corrected,residue_idx_2_corrected,distances,mutation_matches_sequence,qseq_mutation
0,/home/kimlab2/database_data/biological-data-wa...,A,G44S,-1.808667,-0.53,-0.701,0.088,-0.289667,-0.633667,-2.384,/home/kimlab1/database_data/datapkg/adjacency-...,107l,0,MNIFEMLRIDEGLRLKIYKDTEGYYTIGIGHLLTKSPSLNAAKGEL...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 66, 70, 91...","[1.3463744749991646, 4.7287436401778065, 6.389...",True,MNIFEMLRIDEGLRLKIYKDTEGYYTIGIGHLLTKSPSLNAAKSEL...
1,/home/kimlab2/database_data/biological-data-wa...,A,A120M,2.617667,-0.2,0.354,0.56,-0.069,-0.188,2.472,/home/kimlab1/database_data/datapkg/adjacency-...,160l,0,MNIFEMLRIDEGLRLKIYKDTEGYYTIGIGHLLTKSPSLNAAKSEL...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 66, 70, 91, 92...","[1.345760255620924, 4.727237880524551, 6.42311...",True,MNIFEMLRIDEGLRLKIYKDTEGYYTIGIGHLLTKSPSLNAAKSEL...


In [16]:
stats.spearmanr(input_df['cartesian_ddg_beta_nov16_cart_1'].values, input_df['ddg_exp'])

SpearmanrResult(correlation=0.5913166222602743, pvalue=0.0)

# Workflow

## Load master network

In [17]:
%run trained_networks.ipynb

In [18]:
TRAINED_NETWORKS[NETWORK_NAME]

{'network_state': PosixPath('/home/kimlab1/database_data/datapkg_output_dir/adjacency-net-v2/7b4ff1af3ec63a01fa415435420c554be1fecbb0/train_network/models/model_000000006753.state'),
 'network_file': PosixPath('/home/kimlab1/database_data/datapkg_output_dir/adjacency-net-v2/7b4ff1af3ec63a01fa415435420c554be1fecbb0/train_network/model.py'),
 'network_info': {'network_name': 'DCN_7b4ff1af3ec63a01fa415435420c554be1fecbb0',
  'network_settings': {}},
 'stats_db': PosixPath('/home/kimlab1/database_data/datapkg_output_dir/adjacency-net-v2/7b4ff1af3ec63a01fa415435420c554be1fecbb0/train_network/stats.db')}

In [19]:
network_info = TRAINED_NETWORKS[NETWORK_NAME]['network_info']
network_file = TRAINED_NETWORKS[NETWORK_NAME]['network_file']
network_state = Path(TRAINED_NETWORKS[NETWORK_NAME]['network_state'])

runpy.run_path(network_file)

Net = getattr(pagnn.models.dcn, network_info["network_name"])
net = Net(**network_info["network_settings"])
net.load_state_dict(torch.load(network_state.as_posix()))
net.eval()

DCN_7b4ff1af3ec63a01fa415435420c554be1fecbb0(
  (layer_1): SequentialMod(
    (0): PairwiseConv(
      (seq_cart_barcode_model): DistanceNet(
        (linear1): Linear(in_features=2, out_features=64, bias=True)
        (linear2): Linear(in_features=64, out_features=12, bias=True)
      )
      (spatial_conv): Conv1d(26, 64, kernel_size=(2,), stride=(2,), bias=False)
    )
    (1): ReLU()
  )
  (layer_n): Sequential(
    (0): Conv1d(64, 128, kernel_size=(3,), stride=(2,), padding=(1,))
    (1): RepeatPad()
    (2): MaxPool1d(kernel_size=64, stride=64, padding=0, dilation=1, ceil_mode=False)
    (3): Conv1d(128, 1, kernel_size=(1,), stride=(1,))
  )
)

## Define TL network

In [20]:
class TupleToDataSet:
    
    def __init__(self, dataset_to_datavar):
        self.dataset_to_datavar = dataset_to_datavar
        
    def __call__(self, tup):
        # DataRow
        row_pos = DataRow(
            sequence=tup.sequence,
            adjacency_idx_1=tup.adjacency_idx_1,
            adjacency_idx_2=tup.adjacency_idx_2,
            distances=tup.distances,
            target=0
        )
        row_neg = DataRow(
            sequence=tup.sequence_mut,
            adjacency_idx_1=tup.adjacency_idx_1,
            adjacency_idx_2=tup.adjacency_idx_2,
            distances=tup.distances,
            target=tup.ddg_exp,
        )

        # DataSet
        permute_offset = pagnn.dataset.get_offset(len(tup.sequence.replace('-', '')), np.random.RandomState())
        dataset_pos = dataset_to_gan(row_to_dataset(row_pos, permute_offset=permute_offset))
        dataset_neg = dataset_to_gan(row_to_dataset(row_neg, permute_offset=permute_offset))

        assert dataset_pos.adjs == dataset_neg.adjs
        dataset = DataSetGAN(
            dataset_pos.seqs + dataset_neg.seqs,
            dataset_neg.adjs,
            dataset_neg.meta,
        )
        
        # DataVar
        datavar = self.dataset_to_datavar(dataset)
        return datavar, torch.tensor([tup.ddg_exp], dtype=torch.float32)

In [21]:
class ProthermData(Dataset):
    
    def __init__(self, input_file, transform) -> None:
        input_df = pq.read_table(input_file).to_pandas()
        input_df['sequence'] = input_df['qseq']
        # input_df['sequence_mut'] = input_df.apply(mutate_sequence, axis=1)
        input_df['sequence_mut'] = input_df['qseq_mutation']
        input_df['adjacency_idx_1'] = input_df['residue_idx_1_corrected']
        input_df['adjacency_idx_2'] = input_df['residue_idx_2_corrected']
        
        columns = ["sequence", "sequence_mut", "adjacency_idx_1", "adjacency_idx_2", "distances", "ddg_exp"]
        self.tuples = list(input_df[columns].itertuples())
        
        self.transform = transform

    def __len__(self):
        return len(self.tuples)
        
    def __getitem__(self, index):
        tup = self.tuples[index]
        datapoint = self.transform(tup)
        return datapoint

In [22]:
class ProthermTransferLearner(nn.Module):
    
    def __init__(self, master_model) -> None:
        super().__init__()

        self.master_model = copy.deepcopy(master_model)      
        self.master_model.eval()

        for param in self.master_model.parameters():
            param.requires_grad = False
        
        for param in self.master_model.layer_n.parameters():
            param.requires_grad = True
            
#         for param in self.master_model.layer_n[3].parameters():
#             param.requires_grad = True
        
        input_size = self.master_model.hidden_size
        if False:
            input_size *= 2  # For wt and mut
            hidden_size = input_size * 2
        else:
            hidden_size = input_size
        

        self.layer_n = nn.Sequential(
#             nn.Conv1d(
#                 input_size,
#                 hidden_size,
#                 kernel_size=self.master_model.kernel_size,
#                 stride=self.master_model.stride,
#                 padding=self.master_model.padding,
#                 bias=True,
#             ),
            nn.MaxPool1d(4000),
            nn.Conv1d(hidden_size, 1, kernel_size=1, bias=True),
        )

    def forward(self, seq, adjs):
        x_wt = seq[0:1]
        x_wt = self.master_model.layer_1(x_wt, adjs[0][0])

        x_mut = seq[1:2]
        x_mut = self.master_model.layer_1(x_mut, adjs[0][0])

        if False:
            x = torch.cat([x_wt, x_mut], dim=1)
            x = self.layer_n(x)
        elif False:
            x_wt = self.master_model.layer_n(x_wt)
            x_mut = self.master_model.layer_n(x_mut)
            x = (x_wt - x_mut).sum()
        else:
            x = x_wt - x_mut
            x = self.master_model.layer_n(x)
            x = x.sum()

        # Layer N
        return x
    
    def dataset_to_datavar(self, *args, **kwargs):
        return self.master_model.dataset_to_datavar(*args, **kwargs)

## Train new network

In [23]:
net_tl = ProthermTransferLearner(net).cuda()

In [24]:
dataset = ProthermData(
    input_file,
    transform=transforms.Compose([
        TupleToDataSet(net_tl.dataset_to_datavar),
    ]),
)

In [25]:
dv, target = dataset[0]

for dv, target in dataset:
    out = net_tl(dv.seqs.cuda(), [[dv.adjs[0].cuda()]])
    print(out.squeeze(), target)
    break

tensor(-1.3797, device='cuda:0', grad_fn=<SqueezeBackward0>) tensor([-0.5300])


In [26]:
dataset_train, dataset_val = torch.utils.data.random_split(
    dataset, [int(len(dataset) * 0.70), len(dataset) - int(len(dataset) * 0.70)]
)

dataset_sizes = {
    'train': len(dataset_train),
    'val': len(dataset_val),
}

dataloaders = {
    "train": DataLoader(dataset_train, batch_size=64, shuffle=False, num_workers=0, collate_fn=list),
    "val": DataLoader(dataset_train, batch_size=1, shuffle=False, num_workers=0, collate_fn=list),
}

In [27]:
out = next(iter(dataloaders['train']))

In [28]:
len(out)

64

In [29]:
device = torch.device("cuda")


def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    model.eval()
    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch, num_epochs - 1))
        print("-" * 10)

        optimizer.zero_grad()

        # Each epoch has a training and validation phase
        for phase in ["train", "val"]:
            #             if phase == 'train':
            #                 if scheduler is not None:
            #                     scheduler.step()
            #                 model.train()  # Set model to training mode
            #             else:
            #                 model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            pred_list = []
            target_list = []

            # Iterate over data.
            for i, batch in enumerate(dataloaders[phase]):
                for j, (dv, target) in enumerate(batch):
                    dv = dv._replace(seqs=dv.seqs.cuda(), adjs=[dv.adjs[0].cuda()])
                    target = target.to(device)

                    # forward
                    # track history if only in train
                    with torch.set_grad_enabled(phase == "train"):
                        preds = model(dv.seqs, [dv.adjs])
                        if False:
                            preds_diff = preds[0] - preds[1]
                            preds_diff_sum = preds_diff.sum().squeeze()
                        else:
                            preds_diff_sum = preds.squeeze()

                        loss = criterion(preds_diff_sum, target)

                    # statistics
                    running_loss += loss.item()
                    running_corrects += torch.mean(torch.abs(preds_diff_sum - target.data))
                    pred_list.append(preds_diff_sum.cpu().data.numpy())
                    target_list.append(target.cpu().data.numpy())

                # backward + optimize only if in training phase
                if phase == "train":
                    loss.backward()
                    optimizer.step()
                    optimizer.zero_grad()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            pearson_corr = stats.pearsonr(np.hstack(pred_list), np.hstack(target_list))[0]
            spearman_corr = stats.spearmanr(np.hstack(pred_list), np.hstack(target_list))[0]

            print(
                f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} "
                f"Pearson {pearson_corr:.4f} Spearman {spearman_corr:.4f}"
            )

            # deep copy the model
            if phase == "val" and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print("Training complete in {:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60))
    print("Best val Acc: {:4f}".format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [30]:
net_tl = ProthermTransferLearner(net)
net_tl = net_tl.to(device="cuda")

criterion = nn.L1Loss()

# Observe that all parameters are being optimized
optimizer_ft = optim.Adam(net_tl.master_model.layer_n.parameters(), lr=0.001)

exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

In [31]:
model_ft = train_model(net_tl, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25)

Epoch 0/24
----------
train Loss: 3.1717 Acc: 3.1717 Pearson 0.0957 Spearman 0.0739


KeyboardInterrupt: 

In [None]:
ds = dataset[4]
net_tl(ds.seqs, [ds.adjs]).mean(2).squeeze()

In [None]:
net_tl.

In [None]:
for ds in dataset:
    print(net_tl(ds.seqs, [ds.adjs]))

In [None]:
dataset_neg.adjs

In [None]:
len(dataset_pos.adjs)

In [None]:
# try transformations like shuffle by same amount
dataloader = DataLoader(dataset, batch_size=4,
                        shuffle=True, num_workers=4)


In [None]:
next(dataloader)

In [None]:
df = input_df[['sequence', 'sequence_mut', 'adjacency_idx_1', 'adjacency_idx_2', 'distances', 'ddg_exp']].copy()

from pagnn.dataset import dataset_to_gan, row_to_dataset

for row in df.itertuples():
    row_pos = DataRow
    dataset = dataset_to_gan(row_to_dataset(row, 0))
    datavar = net_tl.dataset_to_datavar(dataset)
    outputs = net_tl(datavar.seqs, [datavar.adjs])


In [None]:
input_df

# Run network

In [None]:
%run trained_networks.ipynb

In [None]:
def mutate_sequence(row):
    sequence = row['sequence']
    wt = row['mutation'][0]
    pos = int(row['mutation'][1:-1])
    mut = row['mutation'][-1]
    sequence_mut = sequence[:pos - 1] + mut + sequence[pos:]
    assert len(sequence) == len(sequence_mut)
    return sequence_mut

In [None]:
input_df['sequence'] = input_df['qseq']
input_df['sequence_mut'] = input_df.apply(mutate_sequence, axis=1)
# input_df['sequence_mut'] = input_df['qseq_mutation']
input_df['adjacency_idx_1'] = input_df['residue_idx_1_corrected']
input_df['adjacency_idx_2'] = input_df['residue_idx_2_corrected']

In [None]:
for network_name in NETWORK_NAME.split(','):
    input_df[f'{network_name}_wt'] = helper.predict_with_network(
        input_df[['sequence', 'adjacency_idx_1', 'adjacency_idx_2', 'distances']]
            .copy(),
        network_state=TRAINED_NETWORKS[network_name]['network_state'],
        network_info=TRAINED_NETWORKS[network_name]['network_info'],
    )
    input_df[f'{network_name}_mut'] = helper.predict_with_network(
        input_df[['sequence_mut', 'adjacency_idx_1', 'adjacency_idx_2', 'distances']]
            .rename(columns={'sequence_mut': 'sequence'}).copy(),
        network_state=TRAINED_NETWORKS[network_name]['network_state'],
        network_info=TRAINED_NETWORKS[network_name]['network_info'],
    )

In [None]:
for network_name in NETWORK_NAME.split(','):
    input_df[f'{network_name}_change'] = (
        input_df[f'{network_name}_mut'] -
        input_df[f'{network_name}_wt']
    )

## Save to cache

In [None]:
table = pa.Table.from_pandas(input_df, preserve_index=True)
pq.write_table(
    table,
    OUTPUT_PATH.joinpath("validation_protherm_dataset.parquet"),
    version='2.0',
    flavor='spark',
)

# Analyze