# Summary

# Imports

In [None]:
import concurrent.futures
import itertools
import importlib
import logging
import multiprocessing
import os
import os.path as op
import pickle
import subprocess
import sys
import tempfile
from functools import partial
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import seaborn as sns
import sqlalchemy as sa
import torch
import torch.nn as nn
import torch.nn.functional as F
from numba import njit, prange
from scipy import stats

import pagnn.models.dcn
from pagnn.datavargan import dataset_to_datavar
from pagnn.models.common import AdjacencyConv, SequenceConv, SequentialMod
from pagnn.utils import expand_adjacency_tensor, padding_amount, reshape_internal_dim
from pagnn.dataset import dataset_to_gan, row_to_dataset

from kmtools import py_tools, sequence_tools

In [None]:
%matplotlib inline

In [None]:
# Use the fully qualified option key: the bare "max_columns" pattern also
# matches "styler.render.max_columns" in pandas versions that define it, and
# an ambiguous pattern makes set_option raise OptionError.
pd.set_option("display.max_columns", 100)

In [None]:
SRC_PATH = Path.cwd().joinpath('..', 'src').resolve(strict=True)

if SRC_PATH.as_posix() not in sys.path:
    sys.path.insert(0, SRC_PATH.as_posix())

import helper
importlib.reload(helper)

# Parameters

In [None]:
NOTEBOOK_PATH = Path('validation_protherm_dataset')
NOTEBOOK_PATH

In [None]:
OUTPUT_PATH = Path(os.getenv('OUTPUT_DIR', NOTEBOOK_PATH.name)).resolve()
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
OUTPUT_PATH

In [None]:
proc = subprocess.run(["git", "rev-parse", "--short", "HEAD"], stdout=subprocess.PIPE)
GIT_REV = proc.stdout.decode().strip()
GIT_REV

In [None]:
# Job-array coordinates and the network identifier are read from the environment.
NETWORK_NAME = os.getenv("CI_COMMIT_SHA")
TASK_ID = os.getenv("SLURM_ARRAY_TASK_ID")
TASK_COUNT = os.getenv("ORIGINAL_ARRAY_TASK_COUNT") or os.getenv("SLURM_ARRAY_TASK_COUNT")

# Environment values are strings; convert the ones that are set to integers
# and leave unset ones as None.
if TASK_ID is not None:
    TASK_ID = int(TASK_ID)
if TASK_COUNT is not None:
    TASK_COUNT = int(TASK_COUNT)

TASK_ID, TASK_COUNT, NETWORK_NAME

In [None]:
DEBUG = "CI" not in os.environ    
DEBUG

In [None]:
if DEBUG:
    NETWORK_NAME = ",".join([
        "7b4ff1af3ec63a01fa415435420c554be1fecbb0",  # test74
    ])
else:
    assert NETWORK_NAME is not None
    
NETWORK_NAME

In [None]:
if DEBUG:
    %load_ext autoreload
    %autoreload 2

# `DATAPKG`

In [None]:
DATAPKG = {}

In [None]:
DATAPKG['protherm_validaton_dataset'] = (
    Path(os.environ['DATAPKG_OUTPUT_DIR'])
    .joinpath("adjacency-net-v2", "v0.2", "protherm_dataset", "protherm_validaton_dataset.parquet")
)

# Load data

In [None]:
input_file = (
    DATAPKG['protherm_validaton_dataset']
    .resolve(strict=True)
)
input_file

In [None]:
input_df = pq.read_table(input_file).to_pandas()
input_df.head(2)

In [None]:
stats.spearmanr(input_df['cartesian_ddg_beta_nov16_cart_1'].values, input_df['ddg_exp'])

# Workflow

In [None]:
%run trained_networks.ipynb

In [None]:
TRAINED_NETWORKS[NETWORK_NAME]['network_state']

In [None]:
TRAINED_NETWORKS[NETWORK_NAME]

In [None]:
# runpy is used below but is not among this notebook's imports; import it here
# so the cell does not rely on another cell or notebook having imported it.
import runpy

# NOTE(review): NETWORK_NAME can be a comma-separated list (it is split on ","
# in the "Run network" section); these lookups assume it names one network.
network_info = TRAINED_NETWORKS[NETWORK_NAME]['network_info']
network_file = TRAINED_NETWORKS[NETWORK_NAME]['network_file']
network_state = Path(TRAINED_NETWORKS[NETWORK_NAME]['network_state'])

# Execute the network definition file.
# NOTE(review): presumably this is what makes the class looked up on
# pagnn.models.dcn below available — confirm.
runpy.run_path(network_file)

# Instantiate the architecture named in network_info and restore its weights.
Net = getattr(pagnn.models.dcn, network_info["network_name"])
net = Net(**network_info["network_settings"])
net.load_state_dict(torch.load(network_state.as_posix()))
net.eval()

In [None]:
class ProthermTransferLearner(nn.Module):
    """Wrap a trained network so that only its ``layer_n`` submodule is trainable.

    Every parameter of ``master_model`` is frozen, after which the parameters
    of ``master_model.layer_n`` are re-enabled for gradient updates.

    NOTE(review): ``master_model`` is switched to eval mode here, but it is
    registered as a submodule, so a later ``.train()`` call on this wrapper
    will put it back into training mode — confirm whether that is intended.
    """

    def __init__(self, master_model) -> None:
        """Freeze ``master_model`` except for its ``layer_n`` submodule.

        Args:
            master_model: Module exposing a ``layer_n`` submodule and a
                ``dataset_to_datavar`` method.
        """
        super().__init__()

        self.master_model = master_model
        self.master_model.eval()

        # Freeze everything first, then selectively unfreeze layer_n.
        for param in self.master_model.parameters():
            param.requires_grad = False

        for param in self.master_model.layer_n.parameters():
            param.requires_grad = True

    def forward(self, seq, adjs):
        """Run the wrapped network on ``seq`` and ``adjs`` and return its output."""
        # Call the module rather than ``.forward`` so that any hooks registered
        # on the wrapped model are executed.
        return self.master_model(seq, adjs)

    def dataset_to_datavar(self, *args, **kwargs):
        """Delegate to ``master_model.dataset_to_datavar`` with the same arguments."""
        return self.master_model.dataset_to_datavar(*args, **kwargs)

In [None]:
net_tl = ProthermTransferLearner(net)

In [None]:
from torch.utils.data import Dataset
from torchvision import transforms

from pagnn.types import DataRow

In [None]:
        # NOTE(review): unfinished scratch cell. Two assignments below have no
        # right-hand side and the indentation is inconsistent, so this cell is
        # not valid Python as written. `row_pos` and `row` are not defined at
        # notebook top level by any earlier cell shown here.
        offset = np.random.randint(3, len(row_pos.sequence) - 3)

        datavar_pos = 
        datavar_neg = 
            dataset = dataset_to_gan(row_to_dataset(row, 0))
    datavar = net_tl.dataset_to_datavar(dataset)
    outputs = net_tl(datavar.seqs, [datavar.adjs])

    

In [None]:
# Single-step transform pipeline. The list and the call are closed here: the
# cell was left with an unterminated bracket, which is a syntax error.
transform = transforms.Compose([
    dataset_to_datavar,
])

In [None]:
class TupleToDataSet:
    """Callable transform that builds two datasets from one input record.

    One dataset is built from the record's ``sequence`` and one from its
    ``sequence_mut``; both share the same adjacency indices and distances.

    NOTE(review): ``__call__`` returns before ``self.dataset_to_datavar`` is
    used, so the callable stored by ``__init__`` is currently never invoked.
    """
    
    def __init__(self, dataset_to_datavar):
        """Store the callable intended to convert a dataset into a datavar."""
        self.dataset_to_datavar = dataset_to_datavar
        
    def __call__(self, tup):
        """Build the datasets for ``tup``.

        Args:
            tup: Record exposing ``sequence``, ``sequence_mut``,
                ``adjacency_idx_1``, ``adjacency_idx_2``, ``distances`` and
                ``ddg_exp`` attributes.

        Returns:
            The tuple ``(dataset_pos, dataset_neg, tup.ddg_exp)``.
        """
        # DataRow for the unmodified sequence (target fixed to 0) and for the
        # mutated sequence (target set to the record's ddg_exp).
        row_pos = DataRow(
            sequence=tup.sequence,
            adjacency_idx_1=tup.adjacency_idx_1,
            adjacency_idx_2=tup.adjacency_idx_2,
            distances=tup.distances,
            target=0
        )
        row_neg = DataRow(
            sequence=tup.sequence_mut,
            adjacency_idx_1=tup.adjacency_idx_1,
            adjacency_idx_2=tup.adjacency_idx_2,
            distances=tup.distances,
            target=tup.ddg_exp,
        )

        # DataSet
        # NOTE(review): the same RandomState instance is handed to both calls.
        # If row_to_dataset draws from it, the second call sees an advanced
        # state, so the two permutations may differ — confirm.
        random_state = np.random.RandomState()
        dataset_pos = dataset_to_gan(row_to_dataset(row_pos, permute=True, random_state=random_state))
        dataset_neg = dataset_to_gan(row_to_dataset(row_neg, permute=True, random_state=random_state))

        # A later notebook cell unpacks this 3-tuple from `dataset[0]`.
        return dataset_pos, dataset_neg, tup.ddg_exp
        # NOTE(review): everything below is unreachable because of the return
        # above. It looks like the intended final form (merge both datasets and
        # convert to a datavar) — confirm before re-enabling. The `dataset`
        # tuple built here carries no adjacency entry.
        assert dataset_pos.adjs == dataset_neg.adjs
        dataset = (
            dataset_pos.seqs + dataset_neg.seqs,
            dataset_neg.targets,
            dataset_neg.meta,
        )
        
        # DataVar
        datavar = self.dataset_to_datavar(dataset)
        return datavar

In [None]:
class ProthermData(Dataset):
    """Dataset over the rows of a parquet file, exposed through ``transform``.

    Each item is the result of applying ``transform`` to one row tuple.
    """

    def __init__(self, input_file, transform) -> None:
        """Read ``input_file`` and cache its rows as named tuples.

        Args:
            input_file: Path of the parquet file to read.
            transform: Callable applied to a row tuple on every item access.
        """
        frame = pq.read_table(input_file).to_pandas()

        # Expose the source columns under the names the transform reads.
        # `sequence_mut` is copied from the `qseq_mutation` column.
        aliases = {
            'sequence': 'qseq',
            'sequence_mut': 'qseq_mutation',
            'adjacency_idx_1': 'residue_idx_1_corrected',
            'adjacency_idx_2': 'residue_idx_2_corrected',
        }
        for target_name, source_name in aliases.items():
            frame[target_name] = frame[source_name]

        selected = frame[
            ["sequence", "sequence_mut", "adjacency_idx_1", "adjacency_idx_2", "distances", "ddg_exp"]
        ]
        self.tuples = [*selected.itertuples()]
        self.transform = transform

    def __len__(self):
        """Return the number of cached rows."""
        return len(self.tuples)

    def __getitem__(self, index):
        """Return the transformed row stored at ``index``."""
        return self.transform(self.tuples[index])

dataset = ProthermData(
    input_file,
    transform=transforms.Compose([
        TupleToDataSet(net_tl.dataset_to_datavar),
    ]),
)

In [None]:
row_pos

In [None]:
dataset_pos, dataset_neg, target = dataset[0]
dataset_pos.adjs

In [None]:
dataset_neg.adjs

In [None]:
len(dataset_pos.adjs)

In [None]:
# DataLoader is not imported anywhere in this notebook (only Dataset is), so
# import it where it is used.
from torch.utils.data import DataLoader

# try transformations like shuffle by same amount
dataloader = DataLoader(dataset, batch_size=4,
                        shuffle=True, num_workers=4)


In [None]:
# A DataLoader is an iterable, not an iterator: calling next() on it directly
# raises TypeError, so obtain an iterator first to fetch a single batch.
next(iter(dataloader))

In [None]:
# NOTE(review): 'sequence', 'sequence_mut', 'adjacency_idx_1' and
# 'adjacency_idx_2' are added to the top-level input_df only in the
# "Run network" section further down, so this selection relies on that cell
# having been run first.
df = input_df[['sequence', 'sequence_mut', 'adjacency_idx_1', 'adjacency_idx_2', 'distances', 'ddg_exp']].copy()

from pagnn.dataset import dataset_to_gan, row_to_dataset

# Scratch loop: convert each row to a dataset and run it through net_tl.
# NOTE(review): `dataset` below rebinds the name that held the ProthermData
# instance created in an earlier cell.
for row in df.itertuples():
    # NOTE(review): this binds the DataRow class itself (no instance is
    # created), and the name is not used again inside the loop.
    row_pos = DataRow
    dataset = dataset_to_gan(row_to_dataset(row, 0))
    datavar = net_tl.dataset_to_datavar(dataset)
    outputs = net_tl(datavar.seqs, [datavar.adjs])


In [None]:
input_df

# Run network

In [None]:
%run trained_networks.ipynb

In [None]:
def mutate_sequence(row):
    """Apply the point mutation in ``row['mutation']`` to ``row['sequence']``.

    The mutation string is parsed as
    ``<wild-type residue><1-based position><mutant residue>``, e.g. ``"A12G"``.

    Args:
        row: Mapping-like object (for example a DataFrame row) with
            ``'sequence'`` and ``'mutation'`` entries.

    Returns:
        A copy of the sequence with the residue at the given position replaced
        by the mutant residue.

    Raises:
        ValueError: If the position lies outside the sequence, or the residue
            found at that position differs from the wild-type residue named
            in the mutation string.
    """
    sequence = row['sequence']
    mutation = row['mutation']
    wt = mutation[0]
    pos = int(mutation[1:-1])
    mut = mutation[-1]

    # Reject positions that slicing would otherwise accept silently (negative
    # values) or that would lengthen the sequence (zero / past the end).
    if not 1 <= pos <= len(sequence):
        raise ValueError(
            f"Mutation {mutation!r}: position {pos} is outside a sequence of length {len(sequence)}."
        )
    # A wild-type mismatch means the mutation does not refer to this sequence's
    # numbering; substituting anyway would yield a meaningless mutant.
    if sequence[pos - 1] != wt:
        raise ValueError(
            f"Mutation {mutation!r}: expected {wt!r} at position {pos}, found {sequence[pos - 1]!r}."
        )

    sequence_mut = sequence[:pos - 1] + mut + sequence[pos:]
    assert len(sequence) == len(sequence_mut)
    return sequence_mut

In [None]:
input_df['sequence'] = input_df['qseq']
input_df['sequence_mut'] = input_df.apply(mutate_sequence, axis=1)
# input_df['sequence_mut'] = input_df['qseq_mutation']
input_df['adjacency_idx_1'] = input_df['residue_idx_1_corrected']
input_df['adjacency_idx_2'] = input_df['residue_idx_2_corrected']

In [None]:
# For every requested network, store its predictions for the original
# sequences ("_wt" column) and for the mutated sequences ("_mut" column).
structure_columns = ['adjacency_idx_1', 'adjacency_idx_2', 'distances']
for network_name in NETWORK_NAME.split(','):
    trained = TRAINED_NETWORKS[network_name]
    for suffix, sequence_column in (('wt', 'sequence'), ('mut', 'sequence_mut')):
        # The helper is always given the sequence under the column name
        # 'sequence', whichever source column it comes from.
        features = (
            input_df[[sequence_column] + structure_columns]
            .rename(columns={sequence_column: 'sequence'})
            .copy()
        )
        input_df[f'{network_name}_{suffix}'] = helper.predict_with_network(
            features,
            network_state=trained['network_state'],
            network_info=trained['network_info'],
        )

In [None]:
# Per-network change score: prediction for the mutant minus prediction for
# the original sequence.
for network_name in NETWORK_NAME.split(','):
    wt_scores = input_df[f'{network_name}_wt']
    mut_scores = input_df[f'{network_name}_mut']
    input_df[f'{network_name}_change'] = mut_scores - wt_scores

## Save to cache

In [None]:
# Write the augmented frame (index preserved) to the notebook's output folder.
output_file = OUTPUT_PATH.joinpath("validation_protherm_dataset.parquet")
table = pa.Table.from_pandas(input_df, preserve_index=True)
pq.write_table(table, output_file, version='2.0', flavor='spark')

# Analyze