In [None]:
import torch
from torch import nn
import os
import math
import matplotlib.pyplot as plt
import numpy as np
import sys

# Make the directory above the notebook's working directory importable,
# so that `from src import ...` in later cells resolves.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
sys.path.append(project_root)

# Choose the compute device: Apple MPS first, then CUDA, otherwise the CPU.
if torch.backends.mps.is_available():
    device_name = "mps"
elif torch.cuda.is_available():
    device_name = "cuda"
else:
    device_name = "cpu"
device = torch.device(device_name)
print(f"Using {device} device")

In [None]:
from src import get_data_loaders

# Number of sequence files to load. This value is also used as the model's
# output width in the training cell (presumably one output track per file — confirm in src).
faste_files_to_load = 37
normalize = False
# math.inf is passed through to get_data_loaders; presumably it means "no cap on
# the number of examples" — confirm in src.
num_train_val_data = math.inf
num_test_data = math.inf

train_loader, val_loader, test_loader = get_data_loaders(
    faste_files_to_load=faste_files_to_load,
    normalize=normalize,
    train_val_data_to_load=num_train_val_data,
    test_data_to_load=num_test_data,
)

# Move each loader to `device` under its OWN name. The previous version assigned all
# three results to `train_loader`, which would make training consume the test set.
# A plain torch.utils.data.DataLoader has no `.to()` method; in that case the loader is
# kept as-is and batches are moved to `device` where they are consumed.
train_loader, val_loader, test_loader = (
    loader.to(device) if hasattr(loader, "to") else loader
    for loader in (train_loader, val_loader, test_loader)
)

In [None]:
from src import SimpleCNN
from src import train_model

# Output directory for this run; the evaluation cell also saves its figure here.
save_dir = os.path.join(os.getcwd(), 'Model_SimpleCNN')
os.makedirs(save_dir, exist_ok=True)

# Architecture hyperparameters, passed as lists (presumably one entry per conv layer — confirm in SimpleCNN).
num_kernels = [512]  # alternative previously tried: [32, 16, 8]
kernel_size = [256]
dropout = 0.1
model = SimpleCNN(
    num_kernels=num_kernels,
    kernel_size=kernel_size,
    dropout=dropout,
    output_size=faste_files_to_load,  # one output per loaded file
).to(device)

lr = 0.0001
epochs = 500
patience = 10  # presumably early-stopping patience in epochs — confirm in train_model
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# log_input=True: the loss treats the model output as a log-rate, so predictions must be
# exponentiated (torch.exp) before being compared with observed counts.
loss_fn = nn.PoissonNLLLoss(log_input=True, full=True)
# NOTE(review): train_model's return value is discarded. If it saves or returns a best
# checkpoint, the cells below evaluate the final in-memory weights instead — confirm.
train_model(train_loader, val_loader, model, optimizer, loss_fn, epochs, save_dir, patience)

In [None]:
# Spot-check a single test example: print the predicted vs. observed value for each output.
model.eval()  # disable dropout for inference
with torch.no_grad():
    X, y = next(iter(test_loader))
    # The input must live on the same device as the model. The model outputs log-rates
    # (PoissonNLLLoss(log_input=True)), so exponentiate to compare on the count scale,
    # matching the evaluation cell below.
    y_pred = torch.exp(model(X.to(device))).cpu()

i = 2  # index of the example within the first test batch to display
print('Tissue: Predicted, True')
for s, (y_p, y_t) in enumerate(zip(y_pred[i], y[i])):
    print(f'{s}: {y_p:.3f}, {y_t:.3f}')

In [None]:
# Evaluate on the full test set: collect predicted and observed values for every
# (example, output) pair, then scatter them against each other.
model.eval()  # disable dropout for inference
pred_list = []
labels_list = []
with torch.no_grad():
    for X, y in test_loader:
        # The model outputs log-rates (PoissonNLLLoss(log_input=True)); exponentiate to get rates.
        y_pred = torch.exp(model(X.to(device)))
        pred_list.append(torch.flatten(y_pred).cpu().numpy())
        labels_list.append(torch.flatten(y).cpu().numpy())

labels = np.concatenate(labels_list)
predictions = np.concatenate(pred_list)

pearson_r = np.corrcoef(labels, predictions)[0, 1]

plt.scatter(labels, predictions)
plt.xlabel("Experiment Coverage")
plt.ylabel("Predicted Coverage")
plt.title("Model Accuracy on Test Set (Chromosome 5)")
plt.text(0.1, 0.9, f"r = {pearson_r:.2f}", transform=plt.gca().transAxes)
# Fit the regression line on ALL test points. Previously only the last batch's loop
# variables (y, y_pred) were used.
m, b = np.polyfit(labels, predictions, 1)
x_min, x_max = plt.gca().get_xlim()
X_plot = np.linspace(x_min, x_max, 100)
plt.plot(X_plot, m * X_plot + b, '-', color='red', label='linear fit')
plt.plot(X_plot, X_plot, '--', color='blue', label='y = x')
plt.legend()  # without this the line labels above are never shown
plt.savefig(os.path.join(save_dir, 'Accuracy.png'), dpi=300)
plt.show()