# Preparation

In [7]:
from Bio import SeqIO
from collections import defaultdict
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from typing import List
import torch.nn as nn
import os
import json
import numpy as np

from genome import parse_gff

## Util functions

### Load the *S. cerevisiae* genome FASTA and GFF annotation

In [16]:
# Lazily parse the genome FASTA. SeqIO.parse returns a one-shot iterator of records,
# so it can only be consumed once.
fasta_file = SeqIO.parse("../data/genome/fasta_file.fsa", format="fasta")

In [None]:
# Report each sequence's identifier next to its length. A bare list of lengths
# does not say which FASTA record each number belongs to.
for record in SeqIO.parse("../data/genome/fasta_file.fsa", "fasta"):
    print(record.id, len(record))

In [None]:

# Parse CDS coordinates from the GFF annotation, then spot-check two genes.
cds_coords = parse_gff("../data/genome/gff_file.gff")
for gene_id in ("YDL075W", "YDL045C"):
    print(cds_coords[gene_id])

# Investigating the data

Please add a shortcut to the Shared folder in your Drive so that you can access its files from this notebook.

In [32]:
# Width of each embedding vector. It must match the last axis of the embedding
# arrays (chrI.npy is inspected below and has shape (230218, 768)).
EMBEDDING_SIZE: int = 768

In [None]:
# Load the condition -> samples mapping (presumably condition name -> list of
# sample IDs, given how `conditions` and `samples` are derived below; confirm).
# A `with` block closes the file handle; json.load(open(...)) would leak it.
with open("samples.json", encoding="utf-8") as samples_file:
    condition_samples = json.load(samples_file)
condition_samples

In [8]:
def load_sample(sample_id: str):
  """Load the per-strand signal archives for one sample.

  Parameters
  ----------
  sample_id : str
    Sample identifier used to build the file names under ../data/waern_2013/.

  Returns
  -------
  dict
    Keys "Sense" and "Antisense", each mapped to the object returned by
    ``np.load`` for ``<sample_id>.sense_bp1.npz`` / ``<sample_id>.antisense_bp1.npz``.
    Index it by array name to read the data.
  """
  sense_track = np.load(f"../data/waern_2013/{sample_id}.sense_bp1.npz")
  antisense_track = np.load(f"../data/waern_2013/{sample_id}.antisense_bp1.npz")
  return dict(Sense=sense_track, Antisense=antisense_track)

In [None]:
# Condition names in alphabetical order (iterating a dict yields its keys).
conditions = sorted(condition_samples)
conditions[:5]

In [None]:
# Every sample ID across all conditions, flattened into one sorted list.
samples = sorted(
    sample_id
    for condition_sample_ids in condition_samples.values()
    for sample_id in condition_sample_ids
)
samples[:5]

In [2]:
import numpy as np  # already imported at the top; kept so this cell also runs on its own

# Raw embeddings for chromosome I (shape is checked in the next cell).
raw = np.load("../data/embeddings/chrI.npy")
# np.load on an .npz returns a lazily read archive that keeps its file open.
# Read the "X" array eagerly, then close the archive instead of leaking the handle.
with np.load("../data/prepared/chrI.npz") as prepared_archive:
    processed = prepared_archive["X"]

In [3]:
# Sanity check: the second axis should equal EMBEDDING_SIZE (768).
raw.shape

(230218, 768)

In [None]:
# Group gene IDs by the chromosome recorded in their CDS coordinates.
chromosome_genes = defaultdict(list)
for gene, coords in cds_coords.items():
  chromosome = coords['chromosome']
  chromosome_genes[chromosome].append(gene)

# Peek at the first few genes on chromosome I.
chromosome_genes['chrI'][:5]

In [51]:
# Chromosome-level split: chrVIII is held out entirely, and the remaining
# chromosomes are divided into 5 cross-validation folds of 3.
holdout = ["chrVIII"]
folds = [
    ['chrI', 'chrII', 'chrIII'],
    ['chrIV', 'chrV', 'chrVI'],
    ['chrVII', 'chrIX', 'chrX'],
    ['chrXI', 'chrXII', 'chrXIII'],
    ['chrXIV', 'chrXV', 'chrXVI'],
]

# Guard against leakage between splits: no chromosome may appear in two folds,
# or in both a fold and the holdout set.
assert len({c for fold in folds for c in fold}) == sum(len(fold) for fold in folds), \
    "a chromosome is listed in more than one fold"
assert not set(holdout).intersection(c for fold in folds for c in fold), \
    "a holdout chromosome is also used in a fold"

In [None]:
# Dataset over the first cross-validation fold (chrI, chrII, chrIII); a copy is
# passed so the shared `folds` list is never aliased.
# NOTE(review): CustomDataset is not defined in the visible cells. On a fresh kernel
# this fails until its definition cell is restored.
ds = CustomDataset(list(folds[0]))

In [None]:
# Inspect the dataset's X attribute (presumably the input features; confirm in CustomDataset).
ds.X

In [None]:
# Shuffled mini-batches of 64 examples drawn from `ds`.
loader = DataLoader(dataset=ds, batch_size=64, shuffle=True)

In [104]:
# Pull a single batch to check tensor shapes. iter() is the idiomatic way to get
# the DataLoader's iterator, rather than calling the __iter__ dunder directly.
X, Y = next(iter(loader))

In [None]:
# Batch input shape. LinearModel flattens every dimension after the batch axis.
X.shape

In [105]:
class LinearModel(nn.Module):
  """
  Linear regression head mapping a window of embeddings to per-condition outputs.

  Each example's window is flattened and passed through a single ``nn.Linear``
  layer. The model applies no L2 penalty itself, so this is plain linear
  regression. It only becomes ridge regression if the optimizer adds weight decay
  (e.g. ``torch.optim.SGD(..., weight_decay=...)``).

  Parameters
  ----------
  window_size : int
    Size of the embedding context window (number of embedding vectors per example).
  n_conditions : int
    Number of output values predicted per example.
  embedding_size : int, optional
    Width of each embedding vector. When omitted, falls back to the
    notebook-level ``EMBEDDING_SIZE`` constant.
  """
  def __init__(self, window_size: int = 500, n_conditions: int = 18, embedding_size=None):
    super().__init__()
    if embedding_size is None:
      # Resolve at construction time so the module-level constant is read lazily.
      embedding_size = EMBEDDING_SIZE
    self.linear = nn.Linear(window_size * embedding_size, n_conditions)

  def forward(self, x):
    """Flatten every dimension after the batch axis, then apply the linear layer."""
    x = torch.flatten(x, start_dim=1)
    return self.linear(x)

In [None]:
# Hold out the first CV fold for testing and train on the remaining folds.
# Defining the split here keeps "Restart & Run All" working; otherwise the cell
# relies on variables left over from the cross-validation loop further down.
test_chromosomes = list(folds[0])
train_chromosomes = [c for k, fold in enumerate(folds) if k != 0 for c in fold]

train_dataset = CustomDataset(train_chromosomes)
test_dataset = CustomDataset(test_chromosomes)

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# Evaluation does not need shuffling; a fixed order keeps test batches deterministic.
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [110]:
# Fresh model, mean-squared-error loss, and SGD with momentum.
# NOTE(review): no torch seed is set, so weight initialization differs between runs.
model = LinearModel()
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)

In [111]:
def train_one_epoch(epoch_index, dataloader=None, net=None, opt=None, criterion=None, report_every=1000):
    """Run one pass of gradient descent over the training batches.

    Parameters
    ----------
    epoch_index : int
        Index of the current epoch. Not used in the computation; kept for
        logging (e.g. a TensorBoard step) and backward compatibility.
    dataloader : iterable of (inputs, labels), optional
        Training batches. Defaults to the notebook global ``train_dataloader``.
    net, opt, criterion : optional
        Model, optimizer and loss function. Default to the notebook globals
        ``model``, ``optimizer`` and ``loss_fn``.
    report_every : int
        Print the mean batch loss every ``report_every`` batches.

    Returns
    -------
    float
        Mean batch loss over the final reporting window. If the number of
        batches is not a multiple of ``report_every``, the trailing shorter
        window is used, so short epochs no longer report 0.0. Returns 0.0 only
        when the loader yields no batches.
    """
    # Fall back to the notebook globals so existing calls keep working.
    dataloader = train_dataloader if dataloader is None else dataloader
    net = model if net is None else net
    opt = optimizer if opt is None else opt
    criterion = loss_fn if criterion is None else criterion

    net.train()  # make sure training-mode behavior (dropout, batch-norm updates) is on
    running_loss = 0.
    window_batches = 0
    last_loss = 0.

    for i, (inputs, labels) in enumerate(dataloader):
        # Gradients accumulate across backward() calls, so clear them every batch.
        opt.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        opt.step()

        running_loss += loss.item()
        window_batches += 1
        if window_batches == report_every:
            last_loss = running_loss / report_every  # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            running_loss = 0.
            window_batches = 0

    # Count a trailing partial window, including an epoch shorter than one window.
    if window_batches:
        last_loss = running_loss / window_batches

    return last_loss

In [None]:
# One training pass over `train_dataloader`; the cell displays the returned loss.
train_one_epoch(epoch_index=0)

In [None]:
# Build train/test datasets per cross-validation fold: fold i is the test split,
# and every other fold is pooled for training.
for i, test_chromosomes in enumerate(folds):
  train_chromosomes = []
  for j in range(len(folds)):
    if j == i:
      continue
    train_chromosomes.extend(folds[j])

  train_dataset = CustomDataset(train_chromosomes)
  test_dataset = CustomDataset(test_chromosomes)

  train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
  test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True)

  print(train_dataset, test_dataset)
  # NOTE(review): exploratory early exit, so only the first fold is built.
  # Remove this `break` to iterate over every fold.
  break

In [None]:
# Plot sense and antisense signal tracks over an 1,800 bp window of chromosome IV.
# NOTE(review): `data`, `data_antisense` and `plot_region_with_signals` are not
# defined in the visible cells. On a fresh kernel this fails until their defining
# cells are restored.
data_tracks = dict(Sense=data, Antisense=data_antisense)
plot_region_with_signals(
    cds_coords=cds_coords,
    data_tracks=data_tracks,
    chromosome='chrIV',
    region_start=399000,
    region_end=400800,
)