# CS525 FINAL PROJECT
### GROUP MEMBERS: Ted Monyak, Jack Forman




In [1]:
import logomaker
import os
import math
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sys
import torch
from torch import nn
from torch.utils.data import DataLoader

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))

from src import get_data_loaders

# Pick the fastest available backend: Apple-silicon MPS first, then CUDA,
# otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    _backend = "mps"
elif torch.cuda.is_available():
    _backend = "cuda"
else:
    _backend = "cpu"
device = torch.device(_backend)
print(f"Using {device} device")

Using mps device


# Introduction

### Dataset
##### Training, Validation, Testing Dataset

##### Biological Dataset

To examine the biological relevance and generalizability of our models, we curated an additional dataset from [Ensembl Plants](https://ftp.ensemblgenomes.ebi.ac.uk/pub/plants/release-61/fasta/arabidopsis_thaliana/pep/), consisting of approximately 27,000 protein-coding genes from the *Arabidopsis thaliana* genome. The aim of this dataset is to investigate the chromatin accessibility of these genes across various plant tissues. This is biologically significant because different tissues express distinct sets of proteins, and the chromatin accessibility of a gene plays a key role in regulating its expression. 

We expect the results to reveal two main patterns:  
1. A core set of genes with consistent accessibility across all tissues  
2. A set of genes exhibiting tissue-specific accessibility profiles



In [3]:
# Settings handed to the project's get_data_loaders helper.
data_dir = os.path.join(os.getcwd(), 'Data', 'Parsed_Data_Window')
batch_size = 128
normalize = False
faste_files_to_load = 37  # how many .faste coverage files to load
# math.inf is used as the sample cap for both splits (presumably meaning
# "no limit" — confirm against get_data_loaders).
train_val_data_to_load = test_data_to_load = math.inf

loader_kwargs = dict(
    data_dir=data_dir,
    train_val_data_to_load=train_val_data_to_load,
    test_data_to_load=test_data_to_load,
    batch_size=batch_size,
    faste_files_to_load=faste_files_to_load,
    normalize=normalize,
)
train_loader, val_loader, test_loader = get_data_loaders(**loader_kwargs)

Loading sequences from sequences.fasta
Loading coverage from SRX391990.faste
Loading coverage from SRX9770779.faste
Loading coverage from SRX9770784.faste
Loading coverage from SRX9770786.faste
Loading coverage from SRX391992.faste
Loading coverage from SRX391996.faste
Loading coverage from SRX9770782.faste
Loading coverage from SRX1098138.faste
Loading coverage from SRX9770780.faste
Loading coverage from SRX391994.faste
Loading coverage from SRX9770787.faste
Loading coverage from SRX391993.faste
Loading coverage from SRX9770778.faste
Loading coverage from SRX391991.faste
Loading coverage from SRX9770785.faste
Loading coverage from SRX9770781.faste
Loading coverage from SRX391995.faste
Loading coverage from SRX391997.faste
Loading coverage from SRX9770783.faste
Loading coverage from SRX1098137.faste
Loading coverage from SRX1098135.faste
Loading coverage from SRX1096550.faste
Loading coverage from SRX9770774.faste
Loading coverage from SRX9770789.faste
Loading coverage from SRX9770790.

### Model Architectures

Several neural network architectures were tested to evaluate their effectiveness in predicting chromatin accessibility. In the simplest case—the **Simple CNN**—a single convolutional layer was followed by a fully connected layer. This minimal architecture offers advantages in terms of interpretability and computational efficiency. However, its simplicity may limit its ability to capture higher-order patterns in the sequence data that could be important for accurately modeling chromatin accessibility.

A more complex architecture—the **Deep CNN**—consisted of three convolutional layers followed by three fully connected layers. This model has significantly greater capacity to learn complex features and hierarchical patterns within the sequence data, while still retaining a relatively straightforward structure. Its depth allows it to capture more nuanced relationships that a simpler model may miss.






# Methods



## Build Model

In [4]:
class LocalDnaCnn(nn.Module):
    """CNN regressor from 4-channel sequence windows to one value per coverage track.

    The network is a stack of ``Conv1d -> ReLU -> Dropout -> MaxPool1d(2)``
    groups (one per entry of ``num_kernels`` / ``kernel_size``). The output is
    flattened and passed through a three-layer fully connected head.

    Input is expected as ``(batch, 4, seq_len)``; the 4 channels are presumably
    one-hot A/C/G/T — confirm against the data loader.

    Parameters
    ----------
    num_kernels : sequence of int
        Output channels of each convolutional layer.
    kernel_size : sequence of int
        Kernel width of each convolutional layer. It must have the same
        length as ``num_kernels``.
    dropout : float
        Dropout probability after every conv layer and every hidden linear
        layer.
    transformer_heads, transformer_layers : int
        Stored on the instance. They are only used to build ``self.transformer``
        when ``build_transformer`` is True.
    seq_len : int, optional
        Input sequence length. When given, the head's input width is derived
        from it and the conv stack. When omitted, the legacy hard-coded width
        4704 is used. That value matches the default 4-layer stack for inputs
        of 2487-2502 positions.
    num_outputs : int, optional
        Width of the final layer. It defaults to the notebook-level
        ``faste_files_to_load`` for backward compatibility.
    build_transformer : bool
        If True, also construct the TransformerEncoder the original version
        always built. It is NOT used in ``forward``. The option exists only to
        reproduce the old parameter layout, e.g. to load a legacy
        ``state_dict`` with ``strict=True``.

    Raises
    ------
    ValueError
        If the kernel lists are empty or differ in length, or if ``seq_len``
        is too short for the conv stack.
    """

    # Flattened conv-output width the original implementation hard-coded
    # (32 channels x 147 positions for the default stack).
    _LEGACY_FLAT_FEATURES = 4704

    def __init__(self, num_kernels=(320, 128, 64, 32), kernel_size=(10, 10, 10, 10),
                 dropout=0, transformer_heads=4, transformer_layers=3,
                 seq_len=None, num_outputs=None, build_transformer=False):
        super().__init__()
        if len(num_kernels) != len(kernel_size):
            raise ValueError(
                f"num_kernels ({len(num_kernels)}) and kernel_size "
                f"({len(kernel_size)}) must have the same length")
        if len(num_kernels) == 0:
            raise ValueError("at least one convolutional layer is required")
        self.input_channels = 4
        # Copy so instances never share (or mutate) the caller's/default lists.
        self.num_kernels = list(num_kernels)
        self.kernel_size = list(kernel_size)
        self.dropout = dropout
        self.transformer_heads = transformer_heads
        self.transformer_layers = transformer_layers

        # One [Conv1d, ReLU, Dropout, MaxPool1d(2)] group per conv layer.
        groups = []
        in_channels = self.input_channels
        for out_channels, width in zip(self.num_kernels, self.kernel_size):
            groups.append([
                nn.Conv1d(in_channels=in_channels,
                          out_channels=out_channels,
                          kernel_size=width),
                nn.ReLU(),
                nn.Dropout(p=self.dropout),
                nn.MaxPool1d(kernel_size=2),
            ])
            in_channels = out_channels
        # Keep the original module layout. The first group's modules sit
        # directly in conv_block, so conv_block[0] is the first Conv1d. Each
        # later group is a nested Sequential. This preserves state_dict keys
        # such as "conv_block.0.weight" and "conv_block.4.0.weight".
        self.conv_block = nn.Sequential(*groups[0])
        for group in groups[1:]:
            self.conv_block.append(nn.Sequential(*group))

        # The original always allocated this encoder (~80M parameters with
        # d_model=2144 and 3 layers) although forward() never uses it, so it
        # bloated the optimizer's parameter list and every saved model. It is
        # now opt-in.
        self.transformer = (
            nn.TransformerEncoder(
                nn.TransformerEncoderLayer(2144, nhead=self.transformer_heads),
                self.transformer_layers)
            if build_transformer else None)
        self.attention_weights = None  # placeholder; nothing in this class sets it

        flat_features = self._flat_features(seq_len)
        if num_outputs is None:
            # Backward compatibility: the original read this notebook global.
            num_outputs = faste_files_to_load
        self.num_outputs = num_outputs
        self.regression_block = nn.Sequential(
            nn.Linear(flat_features, 1024),
            nn.ReLU(),
            nn.Dropout(p=self.dropout),
            nn.Linear(1024, 128),
            nn.ReLU(),
            nn.Dropout(p=self.dropout),
            nn.Linear(128, num_outputs),
        )

    def _flat_features(self, seq_len):
        """Return the flattened conv-output width for inputs of length ``seq_len``.

        Each layer is an unpadded, stride-1 Conv1d (length - k + 1) followed
        by MaxPool1d(2) (floor division by 2). If ``seq_len`` is None, the
        legacy constant is returned.
        """
        if seq_len is None:
            return self._LEGACY_FLAT_FEATURES
        length = seq_len
        for width in self.kernel_size:
            length = (length - width + 1) // 2
        if length < 1:
            raise ValueError(
                f"seq_len={seq_len} is too short for the convolutional stack")
        return self.num_kernels[-1] * length

    def forward(self, x):
        """Map a ``(batch, 4, seq_len)`` tensor to ``(batch, num_outputs)``.

        Outputs are unconstrained real values. The notebook trains them with
        ``PoissonNLLLoss(log_input=True)``, i.e. as log-rates.
        """
        x = self.conv_block(x)
        x = torch.flatten(x, 1)
        return self.regression_block(x)

# Train Model

### Training Functions

In [5]:
def train_epoch(dataloader, model, loss_fn, optimizer, epoch, device=None):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Parameters
    ----------
    dataloader : sized iterable of (X, y) batches
    model : torch.nn.Module
        Model to train. It is switched to training mode.
    loss_fn : callable
        ``loss_fn(y_pred, y)`` must return a scalar tensor.
    optimizer : torch.optim.Optimizer
        Optimizer over ``model``'s parameters. It is stepped once per batch.
    epoch : int
        Only used to print the training loss every 10th epoch.
    device : torch.device, optional
        Where each batch is moved. Defaults to the device of the model's first
        parameter, so the function no longer relies on a notebook global.

    Returns
    -------
    float
        Average of ``loss.item()`` over all batches.

    Raises
    ------
    ValueError
        If ``dataloader`` yields no batches.
    """
    if device is None:
        device = next(model.parameters()).device
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("train_epoch received an empty dataloader")
    total_loss = 0.0
    # train() enables dropout (and batch-norm statistic updates, if present).
    model.train()
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(X), y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    mean_loss = total_loss / num_batches
    if epoch % 10 == 0:
        print(f"training loss: {mean_loss:>7f}")
    return mean_loss

def validation(dataloader, model, loss_fn, epoch, device=None):
    """Evaluate ``model`` on ``dataloader`` without gradients and return the mean batch loss.

    Parameters
    ----------
    dataloader : sized iterable of (X, y) batches
    model : torch.nn.Module
        Model to evaluate. It is switched to eval mode, which disables dropout.
    loss_fn : callable
        ``loss_fn(y_pred, y)`` must return a scalar tensor.
    epoch : int
        Only used to print the validation loss every 10th epoch.
    device : torch.device, optional
        Where each batch is moved. Defaults to the device of the model's first
        parameter, so the function no longer relies on a notebook global.

    Returns
    -------
    float
        Average of the per-batch losses.

    Raises
    ------
    ValueError
        If ``dataloader`` yields no batches.
    """
    if device is None:
        device = next(model.parameters()).device
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("validation received an empty dataloader")
    model.eval()
    validation_loss = 0.0
    # no_grad skips autograd bookkeeping: less memory, and no accidental
    # gradient accumulation during evaluation.
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            validation_loss += loss_fn(model(X), y).item()
    validation_loss /= num_batches
    if epoch % 10 == 0:
        print(f"Validation Loss: {validation_loss:>8f} \n")
    return validation_loss

def train_model(train_loader, val_loader, model, optimizer, loss_fn, lr_scheduler, epochs, overfit_ratio=0.85):
    """Train ``model`` epoch by epoch, stop early on clear overfitting, then plot the losses.

    Each epoch runs ``train_epoch`` on ``train_loader`` and ``validation`` on
    ``val_loader``. Training stops early once the epoch's training loss is
    below its validation loss and ``train / val < overfit_ratio``. Otherwise
    ``lr_scheduler.step()`` is called at the end of the epoch.

    Parameters
    ----------
    train_loader, val_loader : iterables of (X, y) batches
    model : torch.nn.Module being trained
    optimizer : torch.optim.Optimizer over ``model``'s parameters
    loss_fn : callable(y_pred, y) -> scalar tensor
    lr_scheduler : learning-rate scheduler, stepped once per non-stopping epoch
    epochs : int, maximum number of epochs
    overfit_ratio : float, early-stopping threshold on the train/val loss ratio

    Returns None. The per-epoch loss curves are shown with matplotlib.
    """

    def plot_loss(train_loss, validation_loss):
        """Plot per-epoch training and validation loss on one small axes."""
        fig, ax = plt.subplots(figsize=(4, 3))
        ax.plot(np.arange(len(train_loss)), train_loss, label='Training')
        ax.plot(np.arange(len(validation_loss)), validation_loss, label='Validation')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title('Loss per epoch')
        ax.legend()
        plt.show()

    train_loss = []
    validation_loss = []
    for t in range(epochs):
        print(f"Epoch {t}\n-------------------------------")
        train_loss.append(train_epoch(train_loader, model, loss_fn, optimizer, t))
        validation_loss.append(validation(val_loader, model, loss_fn, t))

        latest_train, latest_val = train_loss[-1], validation_loss[-1]
        # Early stop when training loss is well below validation loss.
        # `latest_val != 0` only guards the division; the stopping rule is
        # otherwise the same as before.
        if (latest_train < latest_val and latest_val != 0
                and latest_train / latest_val < overfit_ratio):
            print(f"Training loss {latest_train} is well below validation loss {latest_val}")
            break
        lr_scheduler.step()

    print("Done!")
    plot_loss(train_loss, validation_loss)


### Train

In [None]:
# Hyperparameters for this training run.
num_kernels = [320, 128, 64, 32]
kernel_size = [10, 10, 10, 10]
dropout = 0.3
transformer_heads = 4
transformer_layers = 3
lr = 0.001
decay_rate = 0.97  # ExponentialLR gamma: lr is multiplied by this at each scheduler step
seed = 0
# faste_files_to_load is deliberately NOT redefined here. It is set in the
# data-loading cell, and the model's output width must match the number of
# coverage files actually loaded there.

# Seed torch's RNGs so weight initialisation and dropout masks repeat across
# Restart & Run All (GPU/MPS kernels may still be non-deterministic).
torch.manual_seed(seed)

model = LocalDnaCnn(num_kernels=num_kernels,
                    kernel_size=kernel_size,
                    dropout=dropout,
                    transformer_heads=transformer_heads,
                    transformer_layers=transformer_layers).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# log_input=True: the model outputs log-rates; full=True adds the Stirling term.
loss_fn = nn.PoissonNLLLoss(log_input=True, full=True)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decay_rate)
train_model(train_loader=train_loader,
            val_loader=val_loader,
            model=model,
            optimizer=optimizer,
            loss_fn=loss_fn,
            lr_scheduler=lr_scheduler,
            epochs=30,
            overfit_ratio=0.85)



Epoch 0
-------------------------------
training loss: 0.917787
Validation Loss: 0.920516 

Epoch 1
-------------------------------
Epoch 2
-------------------------------


### Save Model

In [None]:
model_fname = 'Ted_Models/model.pth'
# torch.save does not create missing directories, so make sure it exists.
os.makedirs(os.path.dirname(model_fname), exist_ok=True)
# Pickles the whole nn.Module (class reference + weights). Loading it later
# therefore requires LocalDnaCnn to be defined in the loading session.
torch.save(model, model_fname)

# Results

### Validate Model

In [None]:
model_to_load = model_fname
# The checkpoint is a fully pickled module (torch.save(model, ...)), so
# weights_only must be False; PyTorch >= 2.6 defaults to True and would refuse
# it. Only load trusted files: unpickling can execute arbitrary code.
# map_location places the tensors on the active device.
model = torch.load(model_to_load, map_location=device, weights_only=False)
model.eval()

pred_list = []
labels_list = []
with torch.no_grad():
    for X, y in test_loader:
        # The model emits log-rates (trained with PoissonNLLLoss(log_input=True)),
        # so exponentiate to obtain predicted coverage.
        y_pred = torch.exp(model(X.to(device)))
        pred_list.append(torch.flatten(y_pred).cpu().numpy())
        labels_list.append(torch.flatten(y).cpu().numpy())
labels = np.concatenate(labels_list)
predictions = np.concatenate(pred_list)

pearson_r = np.corrcoef(labels, predictions)[0, 1]

fig, ax = plt.subplots()
ax.scatter(labels, predictions)
ax.set_xlabel("Experiment Coverage")
ax.set_ylabel("Predicted Coverage")
ax.set_title("Model Accuracy on Test Set (Chromosome 5)")
ax.text(0.1, 0.9, f"r = {pearson_r:.2f}", transform=ax.transAxes)
# Least-squares straight line of predictions vs. labels, drawn across the
# current x-range.
line_coeffs = np.polyfit(labels, predictions, 1)
x_fit = np.linspace(*ax.get_xlim(), 100)
ax.plot(x_fit, np.polyval(line_coeffs, x_fit), color='red')
fig.savefig('LocalDnaCnnAccuracy.png')
plt.show()

### Training

### Simple CNN
<style>
img {
    display: block;
    margin-left: 0;
    margin-right: auto;
    width: 30%;
    height: 30%;
}
</style>


Discussion about training accuracy  
![image](Model_SimpleCNN_Best/TrainingAccuracy.png)

Discussion about validation accuracy  
![image](Model_SimpleCNN_Best/ValAccuracy.png)

Discussion about testing accuracy  
![image](Model_SimpleCNN_Best/TestingAccuracy.png)

### Deep CNN
<style>
img {
    display: block;
    margin-left: 0;
    margin-right: auto;
    width: 30%;
    height: 30%;
}
</style>


Discussion about training accuracy  
![image](Model_DeepCNN_Best/TrainingAccuracy.png)

Discussion about validation accuracy  
![image](Model_DeepCNN_Best/ValAccuracy.png)

Discussion about testing accuracy  
![image](Model_DeepCNN_Best/TestingAccuracy.png)


### Transformer
<style>
img {
    display: block;
    margin-left: 0;
    margin-right: auto;
    width: 30%;
    height: 30%;
}
</style>
Discussion about training accuracy  
![image](Model_DNATranformer_Best/TrainingAccuracy.png)

Discussion about validation accuracy  
![image](Model_DNATranformer_Best/ValAccuracy.png)

Discussion about testing accuracy  
![image](Model_DNATranformer_Best/TestingAccuracy.png)


### Filters

In [None]:
def calc_entropy(row):
    """Shannon entropy (natural log) of one position's A/C/G/T values.

    Parameters
    ----------
    row : mapping such as a pandas Series or dict
        Has keys 'A', 'C', 'G', 'T' holding non-negative values.

    Returns
    -------
    float
        ``-sum(p * ln p)``, treating ``0 * ln 0`` as 0. Zero entries therefore
        no longer raise ``ValueError: math domain error``. Values are used
        as-is (not renormalised), so this is a true entropy only when they
        already sum to 1.

    Raises
    ------
    ValueError
        If any value is negative.
    """
    entropy = 0.0
    for base in ('A', 'C', 'G', 'T'):
        p = row[base]
        if p < 0:
            raise ValueError(f"negative value {p} for base {base}")
        if p > 0:  # lim_{p->0} p*ln(p) = 0, so zero terms contribute nothing
            entropy -= p * math.log(p)
    return entropy

def create_profile_logo(motif_probs):
    """Render a sequence logo from per-position nucleotide values using logomaker.

    Parameters
    ----------
    motif_probs : array-like with shape [k, 4]
        One row per motif position. The four columns are (re)labelled
        A, C, G, T, in that order, before drawing.

    Returns None; the logo is drawn via ``logomaker.Logo``.
    """
    frame = pd.DataFrame(motif_probs)
    frame.columns = list('ACGT')
    logomaker.Logo(frame)

# First-layer Conv1d weights, laid out (out_channels, in_channels=4, kernel_width).
# The 4 channels are assumed to be A, C, G, T in that order — TODO confirm
# against the encoding used by src.
weights = model.conv_block[0].weight.detach().cpu().numpy()


def _filter_entropy(filter_weights):
    """Sum of per-position calc_entropy for one filter, using |weights| as the values."""
    per_position = pd.DataFrame(filter_weights.T, columns=['A', 'C', 'G', 'T']).abs()
    return per_position.apply(calc_entropy, axis=1).sum()


filter_entropies = np.array([_filter_entropy(w) for w in weights])
# np.argmax returns the first maximum, like the old strict ">" scan. Unlike the
# old `max_h = 0` start, it does not fall back to filter 0 when every total is
# <= 0 (which happens when |weights| exceed 1, since -w*ln(w) < 0 for w > 1).
# NOTE(review): this selects the HIGHEST-entropy (least base-specific) filter;
# motif inspection usually targets the lowest-entropy filters — confirm intent.
max_i = int(np.argmax(filter_entropies))
max_h = filter_entropies[max_i]

motif = pd.DataFrame(weights[max_i].T, columns=['A', 'C', 'G', 'T'])
motif.plot.bar(stacked=True, use_index=True,
               title=f"Filter {max_i}: weight per position")

fig, ax = plt.subplots()
image = ax.imshow(weights[max_i], cmap="Blues")
ax.set_yticks(range(4))
ax.set_yticklabels(['A', 'C', 'G', 'T'])
ax.set_xlabel("Position in filter")
ax.set_title(f"Filter {max_i} weights")
fig.colorbar(image, ax=ax)

create_profile_logo(motif.abs())

### Gene Differences Across Tissues

# Discussion