# Imports

In [3]:
import sys

sys.path.append("./src")
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy.stats import pearsonr
from sklearn.linear_model import Lasso, LinearRegression
from sklearn.preprocessing import PolynomialFeatures
from src.data_storage import random_mutant_dataset, virus_seqs
import pandas as pd
from itertools import product
import torch
import os
from src.model_config import ModelConfig, model_collections
from src.data_storage import random_mutant_dataset, virus_seqs
import data
from torch.utils.data.dataloader import DataLoader
import torch.nn.functional as F
from matplotlib.colors import ListedColormap, LinearSegmentedColormap

# Linear regression model

## Feature extraction

In [46]:
def one_hot_encode(s, num_classes=4):
    """One-hot encode a sequence of integer class indices.

    Args:
        s: Iterable of ints in ``[0, num_classes)`` (in this notebook, indices
            of nucleotides in the string "AUCG").
        num_classes: Number of classes, i.e. rows of the output (default 4).

    Returns:
        A float array of shape ``(num_classes, len(s))`` whose column ``j`` is
        the one-hot vector for ``s[j]``. An empty ``s`` gives shape
        ``(num_classes, 0)``.

    Raises:
        ValueError: if any index lies outside ``[0, num_classes)``. Without this
            check, ``-1`` (what ``str.find`` returns for an unknown character)
            would silently index the last row and be encoded as a valid class.
    """
    indices = np.asarray(s, dtype=int)
    if indices.size and (indices.min() < 0 or indices.max() >= num_classes):
        raise ValueError(
            f"indices must lie in [0, {num_classes}); "
            f"got min={indices.min()}, max={indices.max()}"
        )
    # Row k of the identity is the one-hot vector for class k; transpose so
    # each sequence element becomes a column.
    return np.eye(num_classes)[indices].T

# Build the pairwise (second-order) feature matrix for the Lasso model and cache
# it to disk. Steps:
# 1. One-hot encode each sequence and flatten it position-major.
# 2. Expand it with sklearn's degree-2 interaction features.
# 3. Drop the products of two indicators at the same position. One-hot columns
#    of a position are mutually exclusive, so those products are always 0.
NUCLEOTIDES = "AUCG"

freq_df = pd.read_csv("data/wt_filter_gate.csv")

# Validate up front: str.find returns -1 for an unknown character, and index -1
# would silently be one-hot encoded as the last nucleotide (G).
if freq_df.empty:
    raise ValueError("data/wt_filter_gate.csv contains no sequences")
invalid_seq = ~freq_df.seq.str.fullmatch(f"[{NUCLEOTIDES}]+", na=False)
if invalid_seq.any():
    raise ValueError(
        f"{int(invalid_seq.sum())} sequence(s) contain characters outside {NUCLEOTIDES!r}"
    )
if freq_df.seq.str.len().nunique() != 1:
    raise ValueError("all sequences must have the same length")

# One-hot encoding: a (4, seq_len) matrix per sequence, one column per position.
freq_df["input_arr"] = freq_df.seq.apply(lambda s: [NUCLEOTIDES.find(nt) for nt in s])
freq_df["input_arr"] = freq_df.input_arr.apply(one_hot_encode)
# Flatten column-major so the 4 indicators of a position are adjacent:
# [pos0_A, pos0_U, pos0_C, pos0_G, pos1_A, ...].
freq_df["input_arr"] = freq_df.input_arr.apply(
    lambda arr: arr.flatten(order="F").reshape((1, -1))
)
X_linear = np.vstack(freq_df.input_arr.tolist())
n_linear = X_linear.shape[1]
n_positions = n_linear // len(NUCLEOTIDES)

# Second-order features in sklearn's order: [1, x_0..x_{n-1}, x_a*x_b for a < b].
poly = PolynomialFeatures(degree=2, interaction_only=True)
X_poly = poly.fit_transform(X_linear)

# Keep each output column whose factors all lie on distinct positions.
# poly.powers_[k, f] is the exponent of input f in output column k. Multiplying
# by a feature->position indicator counts the factors each output draws from
# every position; a count of 2 marks a same-position product.
# This replaces a hand-rolled counter that began at index 26 instead of the
# first interaction column (1 + n_linear = 105). That version removed real
# first-order features and kept the zero columns.
feature_position = np.eye(n_positions)[np.arange(n_linear) // len(NUCLEOTIDES)]
mask = (poly.powers_ @ feature_position).max(axis=1) <= 1
X = X_poly[:, mask]

# Layout relied on by the epistasis-map cells:
# [bias | 4*L first-order | 16*C(L, 2) cross-position products].
expected_n_features = (
    1 + n_linear + len(NUCLEOTIDES) ** 2 * n_positions * (n_positions - 1) // 2
)
assert X.shape[1] == expected_n_features, (X.shape, expected_n_features)

print("Shape of input:", X[:1].shape)

y = freq_df.score.values
np.save("results/epistasis_linear_features_first_order.npy", {"X": X, "y": y})

Shape of input: (1, 5305)


## Training

In [None]:
# Train the LR model: L1-regularised (Lasso) linear regression on the pairwise features.
# Column 0 of X is the constant column produced by PolynomialFeatures (default
# include_bias=True), hence fit_intercept=False. Its weight coef_[0] acts as
# the intercept, but note that Lasso also penalises it.
# allow_pickle is required because np.save stored a dict. Only load files this
# notebook produced itself.
data_np = np.load("results/epistasis_linear_features_first_order.npy", allow_pickle=True).item()
X, y = data_np["X"], data_np["y"]

alpha = 0.0001
# Hold out the last ceil(n / 10) samples for evaluation (the same tail the
# evaluation cell selects with `-n // 10`) and train on everything before them,
# so no sample falls between the two splits.
# NOTE(review): the split is positional, which assumes the CSV rows are not
# ordered by score or by any other structure. Confirm this, or shuffle with a
# fixed seed.
n_test = -(-len(y) // 10)
n_train = len(y) - n_test
reg = Lasso(alpha=alpha, fit_intercept=False).fit(X[:n_train, :], y[:n_train])

## Evaluation

In [None]:
def _holdout(a):
    """Return the evaluation split of ``a``: its last ceil(len(a) / 10) rows."""
    n_holdout = -(-len(a) // 10)
    return a[len(a) - n_holdout:]


# Pearson correlation between predicted and measured scores on the held-out tail.
y_pred = reg.predict(_holdout(X))
pearson_r = pearsonr(y_pred.flatten(), _holdout(y))[0]
print("Pearson r:", pearson_r)

## Epistasis map

In [49]:
# Per-position first-order importance. coef_[0] (the bias column) is skipped.
# The next 4 * 26 weights are reshaped column-major, so each column holds the 4
# consecutive one-hot weights of one position. The absolute weights are summed
# per position and min-max scaled to [0, 1].
first_order_abs = np.abs(reg.coef_[1 : 4 * 26 + 1])
independent_coef = first_order_abs.reshape((4, -1), order="F").sum(axis=0)
coef_span = independent_coef.max() - independent_coef.min()
# Guard the degenerate all-equal case (e.g. Lasso shrinking every weight to 0).
# Otherwise it would divide by zero and turn the whole diagonal into NaN.
if coef_span > 0:
    independent_coef = (independent_coef - independent_coef.min()) / coef_span
else:
    independent_coef = np.zeros_like(independent_coef)

In [50]:
# Pairwise (epistatic) weights: |coef_| of every interaction feature, i.e.
# everything after the bias column and the 4 * 26 first-order terms, min-max
# scaled to [0, 1].
interaction_coef = np.abs(reg.coef_[4 * 26 + 1 :])
coef_span = interaction_coef.max() - interaction_coef.min()
# Guard the degenerate all-equal case (e.g. Lasso zeroing every interaction).
# Otherwise it would divide by zero and fill the map with NaN.
if coef_span > 0:
    interaction_coef = (interaction_coef - interaction_coef.min()) / coef_span
else:
    interaction_coef = np.zeros_like(interaction_coef)


In [None]:
# Assemble the position x position epistasis matrix:
# - the diagonal holds each position's normalised first-order weight;
# - entry (i, j), i < j, sums the 16 normalised |coefficients| of the
#   nucleotide pairs linking position i to position j.
# Only the upper triangle (plus the diagonal) is filled.
n_pos = len(independent_coef)
n_linear = 4 * n_pos
# `interaction_coef` is laid out as the row-major upper triangle of the
# (n_linear x n_linear) feature-pair matrix, restricted to pairs on different
# positions: for position i, each nucleotide a, then every nucleotide of
# positions i+1..n_pos-1.
pair_rows, pair_cols = np.triu_indices(n_linear, k=1)
cross_position = pair_rows // 4 != pair_cols // 4
n_expected = int(cross_position.sum())
if len(interaction_coef) != n_expected:
    raise ValueError(
        f"expected {n_expected} interaction coefficients for {n_pos} positions, "
        f"got {len(interaction_coef)}"
    )
feature_mat = np.zeros((n_linear, n_linear))
feature_mat[pair_rows[cross_position], pair_cols[cross_position]] = interaction_coef
# Sum each 4x4 nucleotide block into one position-pair entry.
interaction_mat = feature_mat.reshape(n_pos, 4, n_pos, 4).sum(axis=(1, 3))
interaction_mat[np.diag_indices(n_pos)] = independent_coef
def cre_color(r, g, b):
    """Convert an 8-bit RGB triple (0-255 per channel) to matplotlib's 0-1 floats.

    Returns a list ``[r, g, b]`` scaled by 1/255, so 255 maps exactly to 1.0.
    Dividing by 256 would cap every channel at 255/256.
    """
    return [channel / 255 for channel in (r, g, b)]

# Hand-picked pink/mauve palette starting from near-white.
# NOTE(review): `cmap1` is not used below; the heatmap is drawn with 'viridis'.
# Pass `cmap=cmap1` instead if the custom palette was intended.
colors = [
    cre_color(248, 248, 248),
    cre_color(200, 140, 165),
    cre_color(190, 135, 160),
    cre_color(185, 130, 155),
    cre_color(180, 120, 145),
    cre_color(165, 105, 120),
    cre_color(150, 100, 115),
    cre_color(145, 95, 105),
    cre_color(140, 90, 100),
    cre_color(135, 85, 95),
    cre_color(130, 80, 90),
    cre_color(125, 75, 85),
]

cmap1 = LinearSegmentedColormap.from_list("mycmap", colors)

fig, ax = plt.subplots(1, 1, figsize=(7, 7))
sns.heatmap(
    np.abs(interaction_mat),
    vmax=2,
    vmin=0.2,
    cmap="viridis",
    square=True,
    ax=ax,
    cbar_kws={
        "ticks": [0, 1, 2],
        "pad": 0.1,
        "fraction": 0.05,
        "aspect": 3,
        "shrink": 3,
        "location": "right",
    },
)
# Move ticks and tick labels to the top and right edges. matplotlib expects
# booleans here. The old 'on'/'off' strings are no longer translated, and 'off'
# is a non-empty (truthy) string, so the bottom/left ticks stayed visible.
ax.tick_params(
    top=True, bottom=False, right=True, left=False,
    labeltop=True, labelbottom=False, labelright=True, labelleft=False,
)
# Label every 4th position, 1-based, centred on its heatmap cell.
tick_positions = np.arange(0, interaction_mat.shape[0], 4)
ax.set_xticks(tick_positions + 0.5)
ax.set_yticks(tick_positions + 0.5)
ax.set_xticklabels(tick_positions + 1)
ax.set_yticklabels(tick_positions + 1)
# Keep text as text (not paths) in the exported SVG.
plt.rcParams["svg.fonttype"] = "none"
fig.savefig("results/epistasis_WT.svg", format="svg", bbox_inches="tight")

In [54]:
# Export the position x position matrix with 1-based position labels.
# `pd` comes from the imports cell at the top of the notebook.
position_labels = [f"Pos{i+1}" for i in range(interaction_mat.shape[0])]
interaction_df = pd.DataFrame(
    interaction_mat, index=position_labels, columns=position_labels
)
interaction_df.to_csv("results/interaction_matrix.csv")
