# Summary

# Imports

In [None]:
import importlib
import os
import sys
from collections import Counter
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
%matplotlib inline

# Use the fully-qualified option name: in recent pandas releases the bare "max_columns"
# pattern matches more than one option and raises OptionError.
pd.set_option("display.max_columns", 100)

In [None]:
# Put the sibling ``src`` directory at the front of the import path (once), then
# import the project helper module and reload it to pick up on-disk edits.
SRC_PATH = (Path.cwd() / '..' / 'src').resolve(strict=True)

if not any(entry == SRC_PATH.as_posix() for entry in sys.path):
    sys.path.insert(0, SRC_PATH.as_posix())

import helper
importlib.reload(helper)

# Parameters

In [None]:
# Name of this notebook run: the CI job name when set, otherwise a fixed default.
NOTEBOOK_PATH = Path(os.getenv("CI_JOB_NAME", "add_adjacency_distances_test"))
NOTEBOOK_PATH

In [None]:
# Directory that receives every artifact written below: $OUTPUT_DIR when set, otherwise a
# directory named after the notebook, resolved to an absolute path and created if missing.
OUTPUT_PATH = Path(os.getenv('OUTPUT_DIR', NOTEBOOK_PATH.name)).resolve()
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
OUTPUT_PATH

In [None]:
# DEBUG is on whenever the "CI" variable is absent from the environment.
DEBUG = "CI" not in os.environ

# Job-array coordinates read from the environment; each stays None when its variable is unset.
TASK_ID = os.getenv("SLURM_ARRAY_TASK_ID")
if TASK_ID is not None:
    TASK_ID = int(TASK_ID)

# ORIGINAL_ARRAY_TASK_COUNT takes precedence over SLURM_ARRAY_TASK_COUNT when it is set and non-empty.
TASK_COUNT = os.getenv("ORIGINAL_ARRAY_TASK_COUNT") or os.getenv("SLURM_ARRAY_TASK_COUNT")
if TASK_COUNT is not None:
    TASK_COUNT = int(TASK_COUNT)

TASK_ID, TASK_COUNT

In [None]:
# Only when DEBUG is set: load IPython's autoreload extension and reload all modules
# before each cell execution (mode 2).
if DEBUG:
    %load_ext autoreload
    %autoreload 2

# `DATAPKG`

In [None]:
# Upstream data-package directories, all rooted at $DATAPKG_OUTPUT_DIR
# (a missing variable raises KeyError).
DATAPKG = {
    name: Path(os.environ['DATAPKG_OUTPUT_DIR']).joinpath(*parts)
    for name, parts in [
        ('training_dataset',
            ("adjacency-net-v2", "master", "training_dataset")),
        ('training_dataset_wdistances',
            ("adjacency-net-v2", "master", "training_dataset_wdistances")),
        ('pdb_mmcif_ffindex',
            ("pdb-ffindex", "master", "pdb_mmcif_ffindex", "pdb-mmcif")),
    ]
}

# Load data

In [None]:
# Path.glob yields entries in filesystem-dependent order; sort them so that the
# `parquet_files[:10]` sample taken below is the same on every run and machine.
parquet_files = sorted(
    DATAPKG['training_dataset_wdistances'].joinpath("adjacency_matrix.parquet").glob("*/*.parquet")
)

In [None]:
# Read only the first row group of each of the first ten files into a DataFrame.
dfs = []

for i, parquet_file in enumerate(parquet_files[:10]):
    print(i)
    dfs.append(pq.ParquetFile(parquet_file).read_row_group(0).to_pandas())

In [None]:
# Stack the sampled row groups into a single frame.
master_df = pd.concat(dfs)

## Save

In [None]:
output_file = OUTPUT_PATH.joinpath("example_rows.parquet")

# Write the sample only when the file does not exist yet; an existing file is left
# untouched, so it must be deleted by hand to refresh the sample.
if not output_file.is_file():
    table = pa.Table.from_pandas(master_df, preserve_index=False)
    pq.write_table(table, output_file)

# Process data

## Get distances

In [None]:
output_file = OUTPUT_PATH.joinpath("example_rows.parquet")

# Reload the sample from disk; this replaces the in-memory frame built above.
master_df = pq.read_table(output_file).to_pandas()

In [None]:
# Keep only rows whose `residue_idx_1_corrected` is not null (copy to detach from the parent frame).
master_df = master_df[
    (master_df['residue_idx_1_corrected'].notnull())
].copy()

### `get_aa_distances`

In [None]:
def get_aa_distances(seq, residue_idx_1_corrected, residue_idx_2_corrected):
    """Return per-residue zeros followed by the absolute index difference of each pair.

    Args:
        seq: Sequence; only its length is used (one leading zero per element).
        residue_idx_1_corrected: Array-like of first residue indices, one per pair.
        residue_idx_2_corrected: Array-like of second residue indices, same length.

    Returns:
        1-D array of length ``len(seq) + number of pairs``: ``len(seq)`` integer zeros
        followed by ``|idx1 - idx2|`` for every pair.
    """
    arr1 = np.array(residue_idx_1_corrected)
    arr2 = np.array(residue_idx_2_corrected)
    # The builtin ``int`` replaces the ``np.int`` alias, which was deprecated in
    # NumPy 1.20 and removed in 1.24; the resulting dtype is the same.
    aa_distances = np.hstack([np.zeros(len(seq), dtype=int), np.abs(arr1 - arr2)])
    return aa_distances

In [None]:
# One distance array per row, built from that row's sequence and its two index columns.
master_df["aa_distances"] = [
    get_aa_distances(*row)
    for row in master_df[["sequence", "residue_idx_1_corrected", "residue_idx_2_corrected"]].values
]

In [None]:
# Preview the per-row sequence-distance arrays.
master_df["aa_distances"].head()

In [None]:
# Concatenate the per-row arrays into one flat 1-D array covering every row.
aa_distances = np.hstack(master_df["aa_distances"].values)

### `get_cart_distances`

In [None]:
def get_cart_distances(seq, distances):
    """Return per-residue zeros followed by the given pairwise distances.

    Args:
        seq: Sequence; only its length is used (one leading zero per element).
        distances: Array-like of distances, one per pair.

    Returns:
        1-D float array: ``len(seq)`` zeros followed by ``distances``.
    """
    # The builtin ``float`` replaces the ``np.float`` alias, which was deprecated in
    # NumPy 1.20 and removed in 1.24; the resulting dtype (float64) is the same.
    cart_distances = np.hstack([np.zeros(len(seq), dtype=float), distances])
    return cart_distances

In [None]:
# One distance array per row, built from that row's sequence and its `distances` column.
master_df["cart_distances"] = [
    get_cart_distances(*row)
    for row in master_df[["sequence", "distances"]].values
]

In [None]:
# Preview the per-row Euclidean-distance arrays.
master_df["cart_distances"].head()

In [None]:
# Concatenate the per-row arrays into one flat 1-D array covering every row.
cart_distances = np.hstack(master_df["cart_distances"].values)

## Validate

In [None]:
# Both flat arrays must have the same length, since they are indexed together below.
assert len(aa_distances) == len(cart_distances)

## Shuffle

In [None]:
indices = np.arange(len(aa_distances))

# Shuffle every array in place with its own freshly seeded RandomState(42): equal-length
# arrays then receive the same permutation, which keeps them aligned element by element.
for _array in (indices, aa_distances, cart_distances):
    np.random.RandomState(42).shuffle(_array)

In [None]:
# Alignment check after shuffling: within the first 10,000 entries, every position with a
# zero sequence distance must also have a zero Euclidean distance.
assert (cart_distances[:10_000][aa_distances[:10_000] == 0] == 0).all()

## Save

In [None]:
# Persist the shuffled sequence distances.
np.save(OUTPUT_PATH.joinpath("aa_distances.npy"), aa_distances)

In [None]:
# Persist the shuffled Euclidean distances.
np.save(OUTPUT_PATH.joinpath("cart_distances.npy"), cart_distances)

# Machine learning

## Load data

In [None]:
# Reload from disk; this rebinds `aa_distances` to the saved array.
aa_distances = np.load(OUTPUT_PATH.joinpath("aa_distances.npy"))

In [None]:
# Reload from disk; this rebinds `cart_distances` to the saved array.
cart_distances = np.load(OUTPUT_PATH.joinpath("cart_distances.npy"))

## Functions

### `gen_barcode`

In [None]:
from numba import njit, prange

@njit(parallel=True)
def gen_barcode(distances, bins):
    """One-hot encode each distance by the first bin edge that exceeds it.

    Args:
        distances: Indexable 1-D collection of values to encode.
        bins: Indexable collection of upper bin edges, tested in order.
            NOTE(review): presumably sorted ascending — confirm at the call sites.

    Returns:
        int32 array of shape ``(len(distances), len(bins))``. Row ``i`` has a single 1 in
        the first column ``j`` for which ``distances[i] < bins[j]``; a value that is not
        below any edge produces an all-zero row.
    """
    barcode = np.zeros((len(distances), len(bins)), dtype=np.int32)
    # Rows are independent of each other, so the outer loop is a numba parallel range.
    for i in prange(len(distances)):
        a = distances[i]
        for j in range(len(bins)):
            if a < bins[j]:
                barcode[i, j] = 1
                # Only the first matching bin is marked.
                break
    return barcode

## Normalize seq distances

In [None]:
# Histogram of sequence distances, with everything above 100 clipped into the last bin.
plt.hist(np.clip(aa_distances, 0, 100), bins=50)
plt.xlabel("Amino acid distance")
# `plt.label` does not exist (AttributeError); the y-axis label is set with `plt.ylabel`.
plt.ylabel("Number of amino acid pairs")
None

In [None]:
def normalize_seq_distances(aa_distances):
    """Log-transform sequence distances and standardise them with fixed constants.

    Positive distances map to ``log(d) + 1`` and everything else maps to 0; the result is
    then shifted and scaled by the hard-coded mean and standard deviation below.
    """
    log_mean = 3.5567875815104903
    log_std = 1.7065822763411669

    is_positive = aa_distances > 0
    # log(0) is evaluated for non-positive entries before being masked out, so silence
    # the divide-by-zero warning it would raise.
    with np.errstate(divide='ignore'):
        shifted_log = np.log(aa_distances) + 1
        log_values = np.where(is_positive, shifted_log, 0)

    return (log_values - log_mean) / log_std

In [None]:
# Standardised log sequence distances for the whole data set.
aa_distances_corrected = normalize_seq_distances(aa_distances)

In [None]:
# The constants hard-coded in `normalize_seq_distances` must still fit this data:
# the normalised values are expected to have mean ~0 and standard deviation ~1.
assert np.isclose(aa_distances_corrected.mean(), 0)
assert np.isclose(aa_distances_corrected.std(), 1)

In [None]:
plt.hist(aa_distances_corrected, bins=50)
plt.xlabel("Amino acid distance (normalized)")
plt.ylabel("Number of amino acid pairs")
# Trailing None suppresses the cell's text output.
None

## Bin seq distances

In [None]:
# Hand-picked bin edges given as raw sequence distances and mapped into normalised space.
# NOTE(review): 100_000 presumably acts as a catch-all upper edge — confirm.
aa_quantile_custom = normalize_seq_distances(np.array([1, 4, 8, 14, 32, 100_000])).tolist()
aa_quantile_custom

In [None]:
# Upper edges of 2, 4 and 6 equal-population bins of the normalised sequence distances.
aa_quantile_2, aa_quantile_4, aa_quantile_6 = [
    np.quantile(aa_distances_corrected, np.linspace(0, 1, n_bins + 1)[1:]).tolist()
    for n_bins in (2, 4, 6)
]

# Replace each top edge (the data maximum) with the top edge of the custom bins.
aa_quantile_2[-1] = aa_quantile_4[-1] = aa_quantile_6[-1] = aa_quantile_custom[-1]

aa_quantile_2, aa_quantile_4, aa_quantile_6

In [None]:
# Time the encoding with the custom edges (3 repeats of 1 call each).
%timeit -n 1 -r 3 gen_barcode(aa_distances_corrected, aa_quantile_custom)

In [None]:
# Same timing with the 6-quantile edges.
%timeit -n 1 -r 3 gen_barcode(aa_distances_corrected, aa_quantile_6)

## Normalize cart distances

In [None]:
# Histogram of Euclidean distances, with everything above 14 clipped into the last bin.
plt.hist(np.clip(cart_distances, 0, 14), bins=50)
plt.xlabel("Euclidean distance")
plt.ylabel("Number of amino acid pairs")
# Trailing None suppresses the cell's text output.
None

In [None]:
def normalize_cart_distances(cart_distances):
    """Standardise Euclidean distances with the fixed mean and standard deviation below."""
    mean = 6.9936892028873965
    std = 3.528368101492991

    centered = cart_distances - mean
    return centered / std

In [None]:
# Standardised Euclidean distances for the whole data set.
cart_distances_corrected = normalize_cart_distances(cart_distances)

In [None]:
# The constants hard-coded in `normalize_cart_distances` must still fit this data:
# the normalised values are expected to have mean ~0 and standard deviation ~1.
assert np.isclose(cart_distances_corrected.mean(), 0)
assert np.isclose(cart_distances_corrected.std(), 1)

In [None]:
plt.hist(cart_distances_corrected, bins=50)
plt.xlabel("Euclidean distance (normalized)")
plt.ylabel("Number of amino acid pairs")
# Trailing None suppresses the cell's text output.
None

## Bin cart distances

In [None]:
# Hand-picked bin edges given as raw Euclidean distances and mapped into normalised space.
# NOTE(review): 100_000.0 presumably acts as a catch-all upper edge — confirm.
cart_quantile_custom = normalize_cart_distances(np.array([1.0, 2.0, 4.0, 6.2, 8.5, 100_000.0])).tolist()
cart_quantile_custom

In [None]:
# Upper edges of 2, 4 and 6 equal-population bins of the normalised Euclidean distances.
cart_quantile_2, cart_quantile_4, cart_quantile_6 = [
    np.quantile(cart_distances_corrected, np.linspace(0, 1, n_bins + 1)[1:]).tolist()
    for n_bins in (2, 4, 6)
]

# Replace each top edge (the data maximum) with the top edge of the custom bins.
cart_quantile_2[-1] = cart_quantile_4[-1] = cart_quantile_6[-1] = cart_quantile_custom[-1]

cart_quantile_2, cart_quantile_4, cart_quantile_6

In [None]:
%timeit -n 1 -r 3 gen_barcode(cart_distances, cart_quantile_6)

## Networks

In [None]:
class DistanceNet(nn.Module):
    """Two-layer perceptron mapping one scalar input to a ``barcode_size`` vector."""

    def __init__(self, barcode_size, hidden_layer_size=32):
        super().__init__()

        self.barcode_size = barcode_size
        self.hidden_layer_size = hidden_layer_size

        self.linear1 = nn.Linear(1, hidden_layer_size)
        self.linear2 = nn.Linear(hidden_layer_size, barcode_size)

        # Custom initialisation is left disabled; the nn.Linear defaults are used.
        # self.reset_parameters()

    def reset_parameters(self):
        """Redraw ``linear1`` weight and bias from a normal with std sqrt(hidden_layer_size)."""
        stdv = np.sqrt(self.hidden_layer_size)
        for parameter in (self.linear1.weight, self.linear1.bias):
            parameter.data.normal_(0, stdv)

    def forward(self, x):
        """Apply linear1 and linear2, each followed by a leaky ReLU."""
        hidden = F.leaky_relu(self.linear1(x))
        return F.leaky_relu(self.linear2(hidden))

In [None]:
class TwoDistanceNet(nn.Module):
    """Two-layer perceptron mapping a pair of inputs to a ``barcode_size`` vector."""

    def __init__(self, barcode_size, hidden_layer_size=64):
        super().__init__()

        self.hidden_layer_size = hidden_layer_size
        self.barcode_size = barcode_size

        self.linear1 = nn.Linear(2, hidden_layer_size)
        self.linear2 = nn.Linear(hidden_layer_size, barcode_size)

        # Custom initialisation is left disabled; the nn.Linear defaults are used.
        # self.reset_parameters()

    def reset_parameters(self):
        """Redraw ``linear1`` weight and bias from a normal with std sqrt(hidden_layer_size)."""
        stdv = np.sqrt(self.hidden_layer_size)
        for parameter in (self.linear1.weight, self.linear1.bias):
            parameter.data.normal_(0, stdv)

    def forward(self, x):
        """Apply linear1 and linear2, each followed by a leaky ReLU."""
        hidden = F.leaky_relu(self.linear1(x))
        return F.leaky_relu(self.linear2(hidden))

In [None]:
# Scratch value. NOTE(review): presumably a reference standard deviation for a 64-unit
# layer; it is not assigned or used anywhere below — confirm whether it can be removed.
np.sqrt(2 / 64)

In [None]:
# Training targets: one-hot bin membership under the custom edges, as float32 for PyTorch.
# Each assert verifies that every row has exactly one bin set, i.e. no value fell past the last edge.
aa_distances_onehot = gen_barcode(aa_distances_corrected, aa_quantile_custom).astype(np.float32)
assert (aa_distances_onehot.sum(axis=1) == 1).all()

cart_distances_onehot = gen_barcode(cart_distances_corrected, cart_quantile_custom).astype(np.float32)
assert (cart_distances_onehot.sum(axis=1) == 1).all()

In [None]:
# Hyperparameters shared by all three models.
learning_rate = 1e-4  #0.00005
betas = (0.5, 0.9)
batch_size = 64

# One model per input: sequence distance, Euclidean distance, and both together
# (the joint model emits both 6-wide encodings side by side, hence 12 outputs).
model_a = DistanceNet(6)
model_c = DistanceNet(6)
model_ac = TwoDistanceNet(12)

# One Adam optimizer per model, all with the same settings.
optimizer_a, optimizer_c, optimizer_ac = (
    torch.optim.Adam(model.parameters(), lr=learning_rate, betas=betas)
    for model in (model_a, model_c, model_ac)
)

# Mean squared error against the one-hot targets (nn.BCELoss was the alternative tried).
loss_fn = nn.MSELoss()

# Per-step loss history, appended to by the training loop.
losses_a, losses_c, losses_ac = [], [], []

In [None]:
# Train the three models side by side on consecutive, non-overlapping mini-batches.
#
# Never step past the end of the data: slicing beyond the last sample yields empty
# batches, whose mean-squared-error loss is NaN. The step count is therefore capped at
# the number of batches actually available (the last one may be partial).
n_samples = len(aa_distances_corrected)
n_steps = min(30_000, (n_samples + batch_size - 1) // batch_size)

for t in range(n_steps):
    if t % 2_000 == 0:
        print(t)

    t_slice = slice(t * batch_size, (t + 1) * batch_size)

    # Inputs: one column per distance, as float32 tensors of shape (batch, 1) / (batch, 2).
    X_a = torch.from_numpy(aa_distances_corrected[t_slice]).to(torch.float32).unsqueeze(1)
    X_c = torch.from_numpy(cart_distances_corrected[t_slice]).to(torch.float32).unsqueeze(1)
    X_ac = torch.cat([X_a, X_c], 1)

    # Targets: the one-hot encodings; the joint model predicts both, side by side.
    Y_a = torch.from_numpy(aa_distances_onehot[t_slice, :])
    Y_c = torch.from_numpy(cart_distances_onehot[t_slice, :])
    Y_ac = torch.cat([Y_a, Y_c], 1)

    Y_a_pred = model_a(X_a)
    Y_c_pred = model_c(X_c)
    Y_ac_pred = model_ac(X_ac)

    # Record each loss as a plain Python float so the history holds no tensor memory.
    loss_a = loss_fn(Y_a_pred, Y_a)
    losses_a.append(loss_a.item())

    loss_c = loss_fn(Y_c_pred, Y_c)
    losses_c.append(loss_c.item())

    loss_ac = loss_fn(Y_ac_pred, Y_ac)
    losses_ac.append(loss_ac.item())

    optimizer_a.zero_grad()
    optimizer_c.zero_grad()
    optimizer_ac.zero_grad()

    loss_a.backward()
    loss_c.backward()
    loss_ac.backward()

    optimizer_a.step()
    optimizer_c.step()
    optimizer_ac.step()

In [None]:
# Sequence distance
test_seq_distances = normalize_seq_distances(np.linspace(0, 50, 50)).astype(np.float32).reshape(-1, 1)

img = model_a(torch.from_numpy(test_seq_distances)).data.numpy()

with plt.rc_context(rc={"font.size": 12}):
    fig, ax = plt.subplots(figsize=(14, 2.5))
    im = ax.imshow(img.T, aspect=2, extent=[-0.5, 50 - 0.5, 5 - 0.5, - 0.5])
    fig.colorbar(im)
    plt.xlabel("Sequence distance")
    plt.ylabel("One-hot encoding")

In [None]:
# Euclidean distance
test_cart_distances = normalize_cart_distances(np.linspace(0, 12, 50)).astype(np.float32).reshape(-1, 1)

img = model_c(torch.from_numpy(test_cart_distances)).data.numpy()

with plt.rc_context(rc={"font.size": 12}):
    fig, ax = plt.subplots(figsize=(14, 2.5))
    im = ax.imshow(img.T, aspect=12/50*2, extent=[-0.5, 12 - 0.5, 5 - 0.5, - 0.5])
    fig.colorbar(im)
    plt.xlabel("Euclidean distance")
    plt.ylabel("One-hot encoding")

In [None]:
# Joint model: one figure per fixed Euclidean distance i (0..12), sweeping the sequence
# distance along the x-axis.
for i in range(13):
    # Sequence - Euclidean distance
    test_seq_cart_distances = np.hstack([
        test_seq_distances,
        np.ones((len(test_seq_distances), 1), dtype=np.float32) * normalize_cart_distances(i),
    ])

    img = model_ac(torch.from_numpy(test_seq_cart_distances)).detach().numpy()

    with plt.rc_context(rc={"font.size": 12}):
        fig, ax = plt.subplots(figsize=(14, 5))
        # The y extent is derived from the number of model outputs so that every row is
        # centred on its integer index (it was hard-coded for 11 rows; the model emits 12).
        im = ax.imshow(img.T, aspect=12/50*8, extent=[-0.5, 50 - 0.5, img.shape[1] - 0.5, - 0.5])
        fig.colorbar(im)
        # The swept variable is the sequence distance; the Euclidean distance is fixed
        # per figure and therefore reported in the title.
        ax.set_title(f"Euclidean distance = {i}")
        plt.xlabel("Sequence distance")
        plt.ylabel("One-hot encoding")
    plt.show()

In [None]:
# Destination directory for the trained weights. It defaults to the original shared
# location; set the MODEL_DATA_DIR environment variable to save somewhere else.
MODEL_DATA_PATH = Path(
    os.getenv(
        "MODEL_DATA_DIR",
        "/home/kimlab1/database_data/datapkg/adjacency-net-v2/src/model_data",
    )
)

# Save each model's state dict under its established file name.
for model, state_file in [
    (model_a, "seq_barcode_model.state"),
    (model_c, "cart_barcode_model.state"),
    (model_ac, "seq_cart_barcode_model.state"),
]:
    torch.save(model.state_dict(), MODEL_DATA_PATH.joinpath(state_file).as_posix())

### `linear1`

In [None]:
# Distribution of the first-layer weights (left) and biases (right) of the sequence model.
layer = model_a.linear1
weight_values = layer.weight.detach().numpy()
bias_values = layer.bias.detach().numpy()

fg, axs = plt.subplots(1, 2, figsize=(14, 4))
axs[0].hist(weight_values.reshape(-1))
axs[1].hist(bias_values.reshape(-1))

print(weight_values.std())
print(bias_values.std())

In [None]:
# Distribution of the first-layer weights (left) and biases (right) of the joint model.
layer = model_ac.linear1
weight_values = layer.weight.detach().numpy()
bias_values = layer.bias.detach().numpy()

fg, axs = plt.subplots(1, 2, figsize=(14, 4))
axs[0].hist(weight_values.reshape(-1))
axs[1].hist(bias_values.reshape(-1))

print(weight_values.std())
print(bias_values.std())

### `linear2`

In [None]:
# Distribution of the second-layer weights (left) and biases (right) of the sequence model.
layer = model_a.linear2
weight_values = layer.weight.detach().numpy()
bias_values = layer.bias.detach().numpy()

fg, axs = plt.subplots(1, 2, figsize=(14, 4))
axs[0].hist(weight_values.reshape(-1))
axs[1].hist(bias_values.reshape(-1))

print(weight_values.std())
print(bias_values.std())

In [None]:
# Distribution of the second-layer weights (left) and biases (right) of the joint model.
layer = model_ac.linear2
weight_values = layer.weight.detach().numpy()
bias_values = layer.bias.detach().numpy()

fg, axs = plt.subplots(1, 2, figsize=(14, 4))
axs[0].hist(weight_values.reshape(-1))
axs[1].hist(bias_values.reshape(-1))

print(weight_values.std())
print(bias_values.std())

In [None]:
# Per-step training loss of the joint (sequence + Euclidean) model.
losses_array = np.hstack(losses_ac)
plt.plot(losses_array)

In [None]:
# Scratch: a 10x10 identity matrix with a leading dimension added, shape (1, 10, 10).
# Note that this rebinds `t`, which the training loop uses as its step counter.
t = torch.eye(10, dtype=torch.float32).unsqueeze(0)
t

In [None]:
# 1-D max pooling with kernel size 3; ceil_mode=True rounds the output length up.
mp = nn.MaxPool1d(3, ceil_mode=True)

In [None]:
# Pool the identity tensor along its last dimension.
mp(t)

In [None]:
# Shape after adding one more leading dimension.
t.unsqueeze(0).shape