In [8]:
import numpy as np
import pandas as pd
import torch
import torch.utils.data as td
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import dask.dataframe as dd

# Path stem of the K562 dataset; the ".csv" extension is appended at load time.
froot = './data/k562_main'

# dask builds a lazy frame here; nothing is read until a result is requested.
df = dd.read_csv(f"{froot}.csv")

In [9]:
# Preview the first rows as plain text to check the column layout.
print(df.head())

   seqnames    start      end strand  ensembl_gene_id  score      ctcf  \
0         1  3859709  3859709      +  ENSG00000169598    0.0 -0.010876   
1         1  3859710  3859710      +  ENSG00000169598    0.0 -0.010887   
2         1  3859711  3859711      +  ENSG00000169598    0.0 -0.010902   
3         1  3859712  3859712      +  ENSG00000169598    0.0 -0.010920   
4         1  3859713  3859713      +  ENSG00000169598    0.0 -0.010941   

   h3k36me3   h3k4me1  h3k79me2  ...       sj3       dms      rpts  wgbs  \
0  0.353765 -0.078256 -0.156547  ... -0.057178 -0.307549  0.249626   0.0   
1  0.347003 -0.077117 -0.155891  ... -0.057178 -0.307549  0.249626   0.0   
2  0.340295 -0.075994 -0.155236  ... -0.057178 -0.307549  0.249626   0.0   
3  0.333641 -0.074887 -0.154583  ... -0.057178 -0.307549  0.249626   0.0   
4  0.327043 -0.073795 -0.153930  ... -0.057178 -0.307549  0.249626   0.0   

       A         T         G         C  lambda_alphaj      zeta  
0 -0.625 -0.678443  1.954571 -0.

In [10]:
# All per-position rows that share a gene id form one group.
grouped_df = df.groupby('ensembl_gene_id')

# Row count per gene, turned into a frame whose count column is named
# 'gene_length' (reset_index() leaves the unnamed size column labelled 0).
group_lengths = (
    grouped_df
    .size()
    .reset_index()
    .rename(columns={0: 'gene_length'})
)

# Attach each gene's length to every one of its rows.
# NOTE(review): this rebinds `df`, so re-running the cell merges a second time.
df = df.merge(group_lengths, how="left", on="ensembl_gene_id")

In [22]:
column_names = list(df.columns)
# Positional slices over the column order: everything between the first six
# columns and the last seven is treated as a feature; the four columns before
# the final three are the nucleotide indicators.
feature_names = column_names[6:-7]
nucleotides = column_names[-7:-3]
print(feature_names, nucleotides, sep="\n")

# per-gene read counts, one Python list per gene
X_ji_df = grouped_df['score'].apply(list, meta=('score', 'object'))
# per-gene GLM simulated elongation rates, one Python list per gene
Z_ji_df = grouped_df['zeta'].apply(list, meta=('zeta', 'object'))

# Still lazy: these are dask arrays, not in-memory values.
X_ji, Z_ji = X_ji_df.to_dask_array(), Z_ji_df.to_dask_array()



#num_samples = len(X_ji)

#X_ji_result = X_ji.compute()
#Z_ji_result = Z_ji.compute()

#X_ji_list = X_ji_result.tolist()
#Z_ji_list = Z_ji_result.tolist()

#print(len(X_ji_list))

#print(num_samples)

['ctcf', 'h3k36me3', 'h3k4me1', 'h3k79me2', 'h3k9me1', 'h3k9me3', 'h4k20me1', 'sj5', 'sj3', 'dms', 'rpts', 'wgbs']
['A', 'T', 'G', 'C']


In [36]:
from sklearn.preprocessing import MinMaxScaler

# Y_ji holds, for every sample (gene), one list of values per feature:
# [
#   [[sample_1_feature_1], [sample_1_feature_2], [sample_1_feature_3]],
#   [[sample_2_feature_1], [sample_2_feature_2], [sample_2_feature_3]],
# ]


def _features_as_lists(gene_rows):
    """Return one list of per-position values for each feature column of a gene.

    Args:
        gene_rows: the rows of a single gene (one group of ``grouped_df``).

    Returns:
        A list with ``len(feature_names)`` entries; entry ``k`` is the list of
        values of ``feature_names[k]`` over the gene's rows.
    """
    # Rows are positions and columns are features; transpose so that the
    # outer list runs over features, as laid out in the comment above.
    return gene_rows[feature_names].to_numpy().T.tolist()


# Calling list() on a DataFrame group yields its column labels rather than its
# values, so the values are extracted explicitly by the helper instead.
Y_ji_df = grouped_df[feature_names].apply(_features_as_lists, meta=('features', 'object'))

Y_ji = Y_ji_df.to_dask_array()

# Materialise once; the result is an object array with one nested list per gene.
Y_ji_arr = Y_ji.compute()

In [None]:
# normalize features
scaler = MinMaxScaler()
# Y_ji_arr is the computed object array holding one nested list per gene.
# Going through .tolist() rebuilds it as a dense numeric array; converting the
# lazy Y_ji directly would recompute the graph and keep dtype=object, which the
# scaler cannot consume. Reading from Y_ji_arr also keeps this cell re-runnable.
# assumes every gene has the same number of positions - np.array raises
# ValueError on ragged input; TODO confirm against the data.
Y_ji = np.array(Y_ji_arr.tolist(), dtype=np.float64)
# reshape dataset to [num_samples, num_features * feature_length]
Y_ji_reshaped = Y_ji.reshape(Y_ji.shape[0], -1)
# NOTE(review): the scaler is fitted on all samples before the train/val/test
# split further down - confirm that this is intended.
normalized_Y_ji_reshaped = scaler.fit_transform(Y_ji_reshaped)
Y_ji = normalized_Y_ji_reshaped.reshape(Y_ji.shape)

In [None]:
# lambda_alphaj values of every gene, one in-memory Python list per gene.
# Built from the same grouping and with the same apply(list, meta=...) pattern
# as the score and zeta series, then computed explicitly: the lazy dask result
# has to be materialised before it can be turned into a list.
C_j = (
    grouped_df['lambda_alphaj']
    .apply(list, meta=('lambda_alphaj', 'object'))
    .compute()
    .tolist()
)

In [None]:
def custom_collate_fn(batch):
    """Bucket samples by length and stack every bucket field by field.

    Args:
        batch: iterable of samples, each an indexable sequence of tensors.

    Returns:
        A list with one entry per distinct length, in first-seen order. Each
        entry is a list of stacked tensors, one per field of the samples.
    """
    by_length = {}
    for item in batch:
        # The first field's length decides the bucket; all fields of one
        # sample are expected to share that length.
        by_length.setdefault(len(item[0]), []).append(item)

    # zip(*bucket) regroups the bucket's samples field-wise before stacking.
    return [
        [torch.stack(field) for field in zip(*bucket)]
        for bucket in by_length.values()
    ]

In [None]:
from torch.utils.data import Dataset, DataLoader, TensorDataset

class CustomDataset(Dataset):
    """Map-style dataset over four parallel, index-aligned collections.

    Each item is a dict with the keys 'Y_ji', 'X_ji', 'C_j' and 'Z_ji', every
    value converted to a float32 tensor.
    """

    def __init__(self, Y_ji, X_ji, C_j, Z_ji):
        """Store the four collections; entry `idx` of each belongs to sample `idx`."""
        self.Y_ji, self.X_ji = Y_ji, X_ji
        self.C_j, self.Z_ji = C_j, Z_ji

    def __len__(self):
        """Number of samples, taken from the length of X_ji."""
        return len(self.X_ji)

    def __getitem__(self, idx):
        """Return sample `idx` as a dict of float32 tensors."""
        return {
            name: torch.tensor(getattr(self, name)[idx], dtype=torch.float32)
            for name in ('Y_ji', 'X_ji', 'C_j', 'Z_ji')
        }

In [None]:
def _as_in_memory(values):
    """Return `values` computed eagerly when it is a lazy dask collection, else unchanged."""
    return values.compute() if hasattr(values, "compute") else values


# The per-gene score and zeta arrays are still lazy dask arrays with unknown
# chunk sizes, so len() and integer indexing fail on them. Materialise
# everything before the dataset (and random_split, which needs len()) sees it.
dataset = CustomDataset(
    _as_in_memory(Y_ji),
    _as_in_memory(X_ji),
    _as_in_memory(C_j),
    _as_in_memory(Z_ji),
)

# Fixed seed so the 50/25/25 train/validation/test split is the same on every run.
split_generator = torch.Generator().manual_seed(0)
trnset, valset, tstset = td.random_split(
    dataset, [0.5, 0.25, 0.25], generator=split_generator
)

# One gene per batch; only the training loader is shuffled.
trndl = DataLoader(trnset, batch_size=1, shuffle=True)
tstdl = DataLoader(tstset, batch_size=1, shuffle=False)
valdl = DataLoader(valset, batch_size=1, shuffle=False)

In [None]:
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence

class Model(nn.Module):
    """Single-layer LSTM whose outputs are mean-pooled over dim 1 and projected linearly."""

    def __init__(self, input_size, hidden_size, output_size):
        """Create the LSTM (batch_first) and the linear output layer.

        Args:
            input_size: size of the last dimension of the input.
            hidden_size: LSTM hidden width.
            output_size: width of the prediction vector.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, inputs):
        """Run the LSTM, average its outputs over dim 1, and apply the linear layer."""
        hidden_states, _ = self.lstm(inputs)
        return self.fc(hidden_states.mean(dim=1))

# input size: [50, 12, 2000]
input_size, hidden_size, output_size = 2000, 32, 2000

model = Model(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
)
print(model)

In [None]:
class CustomLoss(nn.Module):
    """Mean of X_ji * log(Z_ji) + C_j * exp(-Z_ji), with Z_ji floored at 1e-8."""

    def __init__(self):
        super().__init__()

    def forward(self, X_ji, C_j, Z_ji):
        """Return the scalar loss for one batch.

        Z_ji is clamped from below first so that the logarithm stays finite.
        """
        floor = 1e-8
        safe_Z = Z_ji.clamp(min=floor)
        log_term = X_ji * safe_Z.log()
        exp_term = C_j * (-safe_Z).exp()
        # averaging keeps the value comparable across batches of different size
        return torch.mean(log_term + exp_term)

In [None]:
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 100

loss_fn = CustomLoss()

# Per-epoch loss history: hist[0] holds the mean training loss of every epoch,
# hist[1] the mean validation loss. The loss-curve cell reads both series.
hist = [[], []]

for epoch in range(num_epochs):
    # --- training pass ---
    model.train()
    train_total, train_batches = 0.0, 0
    for batch in trndl:
        optimizer.zero_grad()
        Y_ji_batch = batch['Y_ji']
        X_ji_batch = batch['X_ji']
        C_j_batch = batch['C_j']
        outputs = model(Y_ji_batch)
        loss = loss_fn(X_ji_batch, C_j_batch, outputs)
        loss.backward()
        optimizer.step()
        train_total += loss.item()
        train_batches += 1

    # --- validation pass (no gradient tracking, no parameter updates) ---
    model.eval()
    valid_total, valid_batches = 0.0, 0
    with torch.no_grad():
        for batch in valdl:
            valid_outputs = model(batch['Y_ji'])
            valid_total += loss_fn(batch['X_ji'], batch['C_j'], valid_outputs).item()
            valid_batches += 1

    # max(..., 1) keeps the division defined when a loader yields no batches.
    train_loss = train_total / max(train_batches, 1)
    valid_loss = valid_total / max(valid_batches, 1)
    hist[0].append(train_loss)
    hist[1].append(valid_loss)

    # Report the epoch mean rather than whichever batch happened to come last.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss}")

In [None]:
from matplotlib import pyplot as plt
# hist is indexed as hist[0] = training-loss series and hist[1] =
# validation-loss series; both are plotted against a 1-based epoch axis.
epochs = range(1, len(hist[0]) + 1)
plt.plot(epochs, hist[0], label='train_loss')
plt.plot(epochs, hist[1], label='valid_loss')

plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
# Call show(); a bare `plt.show` only references the function without rendering.
plt.show()

In [None]:
# Per-batch error of the model's predictions against the 'Z_ji' targets.
mae = []
mse = []
model.eval()
# Evaluation only: skip gradient bookkeeping.
with torch.no_grad():
    for batch in tstdl:
        # Every batch is a dict keyed like the dataset items, so the fields
        # are looked up by name instead of unpacking the batch as a pair.
        inputs, labels = batch['Y_ji'], batch['Z_ji']
        #interpret_model(inputs, labels)
        outputs = model(inputs)
        mae.append(torch.mean(torch.abs(outputs - labels), dim=0))
        mse.append(torch.mean((outputs - labels)**2, dim=0))

# Average the per-batch errors into a single number each.
mean_mae = torch.mean(torch.stack(mae))
mean_mse = torch.mean(torch.stack(mse))
print("Overall Mean Absolute Error (MAE):", round(mean_mae.item(), 3))
print("Overall Mean Squared Error (MSE):", round(mean_mse.item(), 3))

In [None]:
model.eval()

# Take the first test batch; it is a dict of batched tensors.
inputs = next(iter(tstdl))
# len(inputs) would count the dict's keys, so measure the batch dimension of
# one of its tensors instead.
print("number of samples: " + str(len(inputs['Y_ji'])))

with torch.no_grad():
    outputs = model(inputs['Y_ji'])

print(outputs)

# Reference values to compare the predictions against.
targets = inputs['Z_ji']

In [None]:
# Compare target and prediction shapes, then show the raw targets.
print(targets.shape, outputs.shape, sep="\n")
print(targets)

In [None]:
num_points = 2000
indices = np.arange(num_points)

# Subset the data for indices 0 to 200
subset_indices = indices[:201]  # 0 to 200
subset_outputs = outputs[:, :201]  # Select the first 201 points for all samples
# A batch can hold fewer than the nine samples the grid has room for (the
# loaders use batch_size=1), so only index the samples that actually exist.
n_available = len(subset_outputs)

fig, axs = plt.subplots(3, 3, figsize=(15, 15))

for i in range(3):
    for j in range(3):
        sample_idx = i + j * 3  # samples fill the grid column by column
        ax = axs[i, j]
        if sample_idx >= n_available:
            # Nothing to draw here: hide the panel instead of raising IndexError.
            ax.set_visible(False)
            continue
        series = subset_outputs[sample_idx]
        # Trim the x axis in case a sample has fewer than 201 points.
        ax.plot(subset_indices[:len(series)], series)
        ax.set_ylim(-1, 2)
        ax.set_xlabel('Index')
        ax.set_ylabel('Elongation Rates')
        ax.set_title(f'Plot {sample_idx + 1}')

# Adjust layout
plt.tight_layout()

# Show the plot
plt.show()

In [None]:
indices = range(len(targets[0]))
# A batch can hold fewer than nine samples (the loaders use batch_size=1), so
# only index the samples that exist in both tensors.
n_available = min(len(targets), len(outputs))

fig, axs = plt.subplots(3, 3, figsize=(15, 15))
legend_handles, legend_labels = [], []
for i in range(3):
    for j in range(3):
        sample_idx = i + j * 3  # samples fill the grid column by column
        ax = axs[i, j]
        if sample_idx >= n_available:
            # Nothing to draw here: hide the panel instead of raising IndexError.
            ax.set_visible(False)
            continue
        ax.scatter(indices, targets[sample_idx], s=5, label='GLM Elongation Rate')
        ax.scatter(indices, outputs[sample_idx], s=5, label='NN Elongation Rate')
        ax.set_ylim(-1, 2)
        # Label every panel; pyplot-level xlabel/ylabel would only reach the
        # most recently created axes.
        ax.set_xlabel('Index')
        ax.set_ylabel('Elongation Rates')
        if not legend_handles:
            legend_handles, legend_labels = ax.get_legend_handles_labels()

# One shared legend for the whole figure, taken from the first drawn panel.
if legend_handles:
    fig.legend(legend_handles, legend_labels, loc='lower center', ncol=2)

plt.show()