# Summary

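Generate external mutation validation datasets (Protherm and Humsavar) by joining each mutation set with the UniParc cross-reference table and the domain adjacency-matrix dataset.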

----

# Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.autograd import Variable

In [None]:
%run _imports.ipynb

In [None]:
%run _settings.ipynb

In [None]:
%run _spark.ipynb

In [None]:
import random
from typing import NamedTuple

import h5py
from sklearn import metrics

In [None]:
import pagnn
importlib.reload(pagnn)

# Parameters

In [None]:
# The notebook name doubles as the name of its output directory.
NOTEBOOK_NAME = 'mutation_datasets'
# Absolute path of that directory inside the current working directory;
# created on first run, left untouched if it already exists.
NOTEBOOK_PATH = Path.cwd().joinpath(NOTEBOOK_NAME)
NOTEBOOK_PATH.mkdir(exist_ok=True)

In [None]:
# Show the data root used below; raises KeyError here if DATABIN_DIR is unset.
os.environ['DATABIN_DIR']

# Datasets

## Gene3D domains

In [None]:
# Load GENE3D_DOMAINS from the pickle written under generate_datasets/.
# NOTE: pickle.load can execute arbitrary code; only load files you produced
# yourself or otherwise trust.
with open('generate_datasets/gene3d_domains.pickle', 'rb') as fin:
    GENE3D_DOMAINS = pickle.load(fin)

## Training / validation / test domains

In [None]:
# Load the training / validation / test domain splits pickled under
# generate_datasets/.
# NOTE: pickle.load can execute arbitrary code; only load files you produced
# yourself or otherwise trust.
with open('generate_datasets/training_domains.pickle', 'rb') as fin:
    TRAINING_DOMAINS = pickle.load(fin)

with open('generate_datasets/validation_domains.pickle', 'rb') as fin:
    VALIDATION_DOMAINS = pickle.load(fin)

with open('generate_datasets/test_domains.pickle', 'rb') as fin:
    TEST_DOMAINS = pickle.load(fin)

## Training / validation / test parquet files

In [None]:
# Load the training / validation / test parquet-file collections pickled
# under generate_datasets/.
# NOTE: pickle.load can execute arbitrary code; only load files you produced
# yourself or otherwise trust.
with open('generate_datasets/training_parquet_files.pickle', 'rb') as fin:
    TRAINING_PARQUET_FILES = pickle.load(fin)

with open('generate_datasets/validation_parquet_files.pickle', 'rb') as fin:
    VALIDATION_PARQUET_FILES = pickle.load(fin)

with open('generate_datasets/test_parquet_files.pickle', 'rb') as fin:
    TEST_PARQUET_FILES = pickle.load(fin)

## Output folders

In [None]:
# 'validation' sub-directory of the notebook directory; created if missing,
# then echoed as the cell output.
VALIDATION_DIR = NOTEBOOK_PATH / 'validation'
VALIDATION_DIR.mkdir(exist_ok=True)
VALIDATION_DIR

In [None]:
# 'protherm' sub-directory of the notebook directory; created if missing,
# then echoed as the cell output.
PROTHERM_DIR = NOTEBOOK_PATH / 'protherm'
PROTHERM_DIR.mkdir(exist_ok=True)
PROTHERM_DIR

In [None]:
# 'humsavar' sub-directory of the notebook directory; created if missing,
# then echoed as the cell output.
HUMSAVAR_DIR = NOTEBOOK_PATH / 'humsavar'
HUMSAVAR_DIR.mkdir(exist_ok=True)
HUMSAVAR_DIR

# Datasets

## External validation dataset

### Uniparc xref

In [None]:
# Location of the UniParc cross-reference parquet data under DATABIN_DIR;
# echoed as the cell output.
uniparc_xref_file = op.join(
    os.environ['DATABIN_DIR'],
    'uniparc',
    'v0.1.0',
    'uniparc_xref.parquet',
)
uniparc_xref_file

In [None]:
!ls {uniparc_xref_file} -lSh

In [None]:
# Open the UniParc xref parquet data as a Spark DataFrame (all columns).
# `ds` is a scratch name that later cells rebind to other queries.
ds = spark.sql(f"""\
SELECT *
FROM parquet.`{uniparc_xref_file}`
""")

In [None]:
# Preview the first 10 rows as a pandas DataFrame.
ds.limit(10).toPandas()

In [None]:
# 'UniProtKB/Swiss-Prot' | 'UniProtKB/TrEMBL'
# active = 'Y'
# db_id = ''

### Adjacency matrix

In [None]:
# Show the adjacency-matrix parquet location. The name is not defined in this
# notebook; presumably it comes from _settings.ipynb -- confirm.
ADJACENCY_MATRIX_PATH

In [None]:
# Open the adjacency-matrix parquet data as a Spark DataFrame (all columns).
# This rebinds the scratch name `ds`.
ds = spark.sql(f"""\
SELECT *
FROM parquet.`{ADJACENCY_MATRIX_PATH}`
""")

In [None]:
# Preview the first 10 rows as a pandas DataFrame.
ds.limit(10).toPandas()

### Protherm dataset

In [None]:
# Location of the Protherm parquet file under DATABIN_DIR.
protherm_file = op.join(
    os.environ['DATABIN_DIR'],
    'protein_folding_energy',
    'v0.1.0',
    'protherm_star.parquet',
)

#### Examine file

In [None]:
# Read the Protherm parquet file into pandas and use the stored
# '__index_level_0__' column as the index.
protherm = (
    pq.read_table(protherm_file)
    .to_pandas()
    .set_index('__index_level_0__')
)

In [None]:
# Preview the first rows and print the (rows, columns) shape.
display(protherm.head())
print(protherm.shape)

In [None]:
# Drop rows missing any of uniprot_id, uniprot_mutation or ddg_exp.
# NOTE(review): in the cells visible here the Spark query reads `protherm_file`
# directly, so this filter only affects the in-memory frame -- confirm intent.
protherm = protherm.dropna(subset=['uniprot_id', 'uniprot_mutation', 'ddg_exp'])

In [None]:
# Shape after dropping incomplete rows.
protherm.shape

#### Spark query

In [None]:
# Show the adjacency-matrix parquet location used by the join below.
ADJACENCY_MATRIX_PATH

In [None]:
# Join every Protherm mutation to its domain rows:
#   protherm (ds) -> UniParc xref on uniprot_id = db_id
#                 -> adjacency-matrix rows (ud) on uniparc_id.
# Only active ('Y') Swiss-Prot / TrEMBL cross-references are kept.
# The file paths are interpolated directly into the SQL text; they come from
# the environment / notebook settings, not from user input.
ds = spark.sql(f"""\
SELECT
    ds.uniprot_id, ds.uniprot_mutation, ds.ddg_exp,
    
    xref.uniparc_id,
    
    ud.sequence, ud.database_id, ud.domain_start, ud.domain_end, ud.__index_level_0__ domain_index,
       
    ud.structure_id, ud.model_id, ud.chain_id,
    ud.pc_identity, ud.alignment_length, ud.mismatches, ud.gap_opens, 
    ud.q_start, ud.q_end, ud.s_start, ud.s_end,
    
    ud.qseq, ud.sseq,
    ud.residue_idx_1_corrected, ud.residue_idx_2_corrected

FROM parquet.`{protherm_file}` ds
JOIN parquet.`{uniparc_xref_file}` xref ON (uniprot_id = db_id)
JOIN parquet.`{ADJACENCY_MATRIX_PATH}` ud USING (uniparc_id)
WHERE (xref.db_type = 'UniProtKB/Swiss-Prot' OR xref.db_type = 'UniProtKB/TrEMBL')
AND xref.active = 'Y'
""")

In [None]:
# Materialise the joined Protherm rows as parquet inside PROTHERM_DIR,
# replacing any output from a previous run.
ds.write.mode('overwrite').parquet(
    (PROTHERM_DIR / 'protherm_2.parquet').as_posix()
)

#### Remove mutations outside domains

In [None]:
# Read the Spark-written parquet output back into a single pandas DataFrame.
protherm_df = (
    pq.ParquetDataset((PROTHERM_DIR / 'protherm_2.parquet').as_posix())
    .read()
    .to_pandas()
)

In [None]:
# Preview two rows and print the row count before filtering.
display(protherm_df.head(2))
print(protherm_df.shape[0])

In [None]:
# Pick up edits to pagnn.utils without restarting the kernel.
# NOTE(review): the next cell calls `pagnn.filter_mismatch_mutations`. If that
# name is re-exported into `pagnn` via `from ... import`, reloading only
# `pagnn.utils` does not rebind it -- confirm, or reload `pagnn` as well.
importlib.reload(pagnn.utils)

In [None]:
# Filter the joined rows with pagnn.filter_mismatch_mutations. Per the section
# heading this removes mutations that fall outside the domain -- confirm the
# exact rule against the pagnn implementation.
protherm_df = pagnn.filter_mismatch_mutations(protherm_df)

In [None]:
# Row count after filtering.
protherm_df.shape[0]

In [None]:
# Histogram of the ddg_exp column of the remaining rows.
plt.hist(protherm_df['ddg_exp'])

In [None]:
# Write the filtered Protherm rows (without the pandas index) to the notebook
# directory.
# NOTE(review): the file name spells 'validaton' (sic). It is left unchanged
# because other code may read this exact path -- rename both sides together.
# NOTE(review): newer pyarrow releases may reject or deprecate version='2.0';
# confirm against the installed pyarrow.
pq.write_table(
    pa.Table.from_pandas(protherm_df, preserve_index=False),
    NOTEBOOK_PATH.joinpath('protherm_validaton_dataset.parquet').as_posix(),
    version='2.0',
    flavor='spark',
)

In [None]:
# Show the Protherm working directory. It holds the intermediate
# protherm_2.parquet; the final dataset above is written to NOTEBOOK_PATH.
PROTHERM_DIR

### Humsavar dataset

In [None]:
# Location of the Humsavar parquet file under DATABIN_DIR.
humsavar_file = op.join(
    os.environ['DATABIN_DIR'],
    'mutation_sets',
    'v0.1.0',
    'humsavar.parquet',
)

#### Examine file

In [None]:
# Read the Humsavar parquet file into pandas and use the stored
# '__index_level_0__' column as the index.
humsavar = (
    pq.read_table(humsavar_file)
    .to_pandas()
    .set_index('__index_level_0__')
)

In [None]:
# Preview the first rows and print the (rows, columns) shape.
display(humsavar.head())
print(humsavar.shape)

In [None]:
# Drop rows missing any of uniprot_id, uniprot_mutation or type_of_variant.
# NOTE(review): in the cells visible here the Spark query reads `humsavar_file`
# directly, so this filter only affects the in-memory frame -- confirm intent.
humsavar = humsavar.dropna(subset=['uniprot_id', 'uniprot_mutation', 'type_of_variant'])

In [None]:
# Shape after dropping incomplete rows.
humsavar.shape

#### Spark query

In [None]:
# Join every Humsavar variant to its domain rows:
#   humsavar (ds) -> UniParc xref on uniprot_id = db_id
#                 -> adjacency-matrix rows (ud) on uniparc_id.
# Only active ('Y') Swiss-Prot / TrEMBL cross-references are kept.
# The file paths are interpolated directly into the SQL text; they come from
# the environment / notebook settings, not from user input.
ds = spark.sql(f"""\
SELECT
    ds.uniprot_id, ds.uniprot_mutation, ds.type_of_variant,
    
    xref.uniparc_id,
    
    ud.sequence, ud.database_id, ud.domain_start, ud.domain_end, ud.__index_level_0__ domain_index,
       
    ud.structure_id, ud.model_id, ud.chain_id,
    ud.pc_identity, ud.alignment_length, ud.mismatches, ud.gap_opens, 
    ud.q_start, ud.q_end, ud.s_start, ud.s_end,
    
    ud.qseq, ud.sseq,
    ud.residue_idx_1_corrected, ud.residue_idx_2_corrected

FROM parquet.`{humsavar_file}` ds
JOIN parquet.`{uniparc_xref_file}` xref ON (uniprot_id = db_id)
JOIN parquet.`{ADJACENCY_MATRIX_PATH}` ud USING (uniparc_id)
WHERE (xref.db_type = 'UniProtKB/Swiss-Prot' OR xref.db_type = 'UniProtKB/TrEMBL')
AND xref.active = 'Y'
""")

In [None]:
# ds.limit(10).toPandas()

In [None]:
# Materialise the joined Humsavar rows as parquet inside HUMSAVAR_DIR,
# replacing any output from a previous run.
ds.write.mode('overwrite').parquet(
    (HUMSAVAR_DIR / 'humsavar_2.parquet').as_posix()
)

#### Remove mutations outside domains

In [None]:
# Read the Spark-written parquet output back into a single pandas DataFrame.
humsavar_df = (
    pq.ParquetDataset((HUMSAVAR_DIR / 'humsavar_2.parquet').as_posix())
    .read()
    .to_pandas()
)

In [None]:
# Preview two rows and print the row count before filtering.
display(humsavar_df.head(2))
print(humsavar_df.shape[0])

In [None]:
# Pick up edits to pagnn.utils without restarting the kernel.
# NOTE(review): the next cell calls `pagnn.filter_mismatch_mutations`. If that
# name is re-exported into `pagnn` via `from ... import`, reloading only
# `pagnn.utils` does not rebind it -- confirm, or reload `pagnn` as well.
importlib.reload(pagnn.utils)

In [None]:
# Filter the joined rows with pagnn.filter_mismatch_mutations. Per the section
# heading this removes mutations that fall outside the domain -- confirm the
# exact rule against the pagnn implementation.
humsavar_df = pagnn.filter_mismatch_mutations(humsavar_df)

In [None]:
# Row count after filtering.
humsavar_df.shape[0]

In [None]:
# Tally the remaining rows per type_of_variant label.
Counter(humsavar_df['type_of_variant'])

In [None]:
# Keep only variants labelled 'Disease' or 'Polymorphism'.
# .copy() makes the result an independent frame, so the later `score_exp`
# column assignment does not trigger pandas' SettingWithCopyWarning.
humsavar_df = humsavar_df[humsavar_df['type_of_variant'].isin({'Disease', 'Polymorphism'})].copy()

In [None]:
# Row count after restricting to the two labels.
print(humsavar_df.shape[0])

In [None]:
# Binary target column: 1 where type_of_variant is 'Disease', 0 otherwise.
humsavar_df['score_exp'] = humsavar_df['type_of_variant'].eq('Disease').astype(int)

In [None]:
# Write the labelled Humsavar rows (without the pandas index) to the notebook
# directory.
# NOTE(review): the file name spells 'validaton' (sic). It is left unchanged
# because other code may read this exact path -- rename both sides together.
# NOTE(review): newer pyarrow releases may reject or deprecate version='2.0';
# confirm against the installed pyarrow.
pq.write_table(
    pa.Table.from_pandas(humsavar_df, preserve_index=False),
    NOTEBOOK_PATH.joinpath('humsavar_validaton_dataset.parquet').as_posix(),
    version='2.0',
    flavor='spark',
)