### Install and download

In [1]:
# Query the NVIDIA driver; an error here means no GPU is reachable from this runtime.
!nvidia-smi

NVIDIA-SMI has failed because it couldn't communicate with the NVIDIA driver. Make sure that the latest NVIDIA driver is installed and running.



### Importing

In [2]:
!pip install einops
!pip install pytorch_lightning
import os; os.getpid()
from scipy.stats import zscore
import torch
import copy
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
import torch
from IPython.display import display
import torch.nn as nn
from torchvision.utils import make_grid
from torchvision.utils import save_image
from IPython.display import Image
import matplotlib.pyplot as plt
import numpy as np
import random
import pandas as pd
from tqdm import tqdm_notebook
import seaborn as sns
import matplotlib.pyplot as plt
from torch.nn.modules.activation import ReLU
from torch.optim import Adam
from tqdm import tqdm_notebook
from torchvision.utils import save_image
import matplotlib
import math
from inspect import isfunction
from functools import partial
import scipy
from scipy.special import rel_entr
from torch import nn, einsum
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange
from torch import nn, einsum
import torch.nn.functional as F
import matplotlib.animation as animation
import matplotlib.image as mpimg
import glob
from PIL import Image
import pytorch_lightning as pl 
import imageio
from torch.autograd import Variable
from torch.autograd import grad as torch_grad

%matplotlib inline

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
[K     |████████████████████████████████| 41 kB 332 kB/s 
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.0
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch_lightning
  Downloading pytorch_lightning-1.8.1-py3-none-any.whl (798 kB)
[K     |████████████████████████████████| 798 kB 12.1 MB/s 
Collecting torchmetrics>=0.7.0
  Downloading torchmetrics-0.10.2-py3-none-any.whl (529 kB)
[K     |████████████████████████████████| 529 kB 64.3 MB/s 
[?25hCollecting lightning-utilities==0.3.*
  Downloading lightning_utilities-0.3.0-py3-none-any.whl (15 kB)
Collecting fire
  Downloading fire-0.4.0.tar.gz (87 kB)
[K     |████████████████████████████████| 87 kB 5.6 MB/s 
Building wheels for collected packages: fire
  Building wheel f

### Dataset 

In [3]:
class SequenceDatasetBase(Dataset):
    """Map-style dataset of fixed-length DNA sequences with a cell-type label.

    The data file is a tab-separated table that must contain a ``raw_sequence``
    column (strings over the alphabet A/C/T/G, all of length
    ``sequence_length``) and a ``component`` column (the label returned with
    each sequence).

    Args:
        data_path: Path of the tab-separated file, read with ``pd.read_csv``.
        sequence_length: Required length of every sequence in the file.
        sequence_encoding: One of ``"polar"``, ``"onehot"`` or ``"ordinal"``.
        sequence_transform: Optional callable applied to the encoded sequence.
        cell_type_transform: Optional callable applied to the label.

    Raises:
        ValueError: If the file contains characters outside the alphabet or
            sequences whose length differs from ``sequence_length``.
    """

    def __init__(self, data_path, sequence_length=200, sequence_encoding="polar", sequence_transform=None, cell_type_transform=None):
        super().__init__()
        self.data = pd.read_csv(data_path, sep="\t")
        self.sequence_length = sequence_length
        self.sequence_encoding = sequence_encoding
        self.sequence_transform = sequence_transform
        self.cell_type_transform = cell_type_transform
        self.alphabet = ["A", "C", "T", "G"]
        self.check_data_validity()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(encoded_sequence, cell_type)`` for the row at position ``index``.

        Raises:
            IndexError: If ``index`` is out of range.
        """
        # Positional lookup (.iloc) instead of label lookup: past the last row
        # it raises IndexError rather than KeyError, which is what the
        # sequence-iteration protocol (`iter(dataset)`, `list(dataset)`) needs
        # in order to stop cleanly.
        current_seq = self.data["raw_sequence"].iloc[index]

        # No special-casing of 'N': check_data_validity() has already rejected
        # every character outside the alphabet, and silently returning None
        # here would only surface later as an obscure collate error.
        X_seq = self.encode_sequence(current_seq, encoding=self.sequence_encoding)

        # Reading cell component at current index
        X_cell_type = self.data["component"].iloc[index]

        if self.sequence_transform is not None:
            X_seq = self.sequence_transform(X_seq)
        if self.cell_type_transform is not None:
            X_cell_type = self.cell_type_transform(X_cell_type)

        return X_seq, X_cell_type

    def check_data_validity(self):
        """
        Checks if the data is valid.

        Raises:
            ValueError: If any sequence contains a character outside
                ``self.alphabet``, or if the sequences do not all have length
                ``self.sequence_length``.
        """
        if not set("".join(self.data["raw_sequence"])).issubset(set(self.alphabet)):
            raise ValueError("Sequence contains invalid characters.")

        uniq_raw_seq_len = self.data["raw_sequence"].str.len().unique()
        if len(uniq_raw_seq_len) != 1 or uniq_raw_seq_len[0] != self.sequence_length:
            raise ValueError("The sequence length does not match the data.")

    def encode_sequence(self, seq, encoding):
        """
        Encodes a sequence using the given encoding scheme ("polar", "onehot", "ordinal").

        Returns:
            For "polar" and "onehot": an array of shape
            ``(len(alphabet), sequence_length)``; "polar" replaces the zeros of
            the one-hot matrix with -1. For "ordinal": a 1-D array holding the
            alphabet index of each nucleotide.

        Raises:
            ValueError: If ``encoding`` is not one of the supported schemes.
        """
        if encoding == "polar":
            seq = self.one_hot_encode(seq).T
            seq[seq == 0] = -1
        elif encoding == "onehot":
            seq = self.one_hot_encode(seq).T
        elif encoding == "ordinal":
            seq = np.array([self.alphabet.index(n) for n in seq])
        else:
            raise ValueError(f"Unknown encoding scheme: {encoding}")
        return seq

    # Function for one hot encoding each line of the sequence dataset
    def one_hot_encode(self, seq):
        """
        One-hot encoding a sequence

        Returns an array of shape ``(sequence_length, len(alphabet))`` with a
        single 1 per encoded position; columns follow ``self.alphabet`` order.
        """
        seq_array = np.zeros((self.sequence_length, len(self.alphabet)))
        for position, nucleotide in enumerate(seq):
            seq_array[position, self.alphabet.index(nucleotide)] = 1
        return seq_array




In [4]:
class SequenceDatasetTrain(SequenceDatasetBase):
    """Training split; every argument is forwarded to SequenceDatasetBase."""

    def __init__(self, data_path="", **kwargs):
        base_kwargs = dict(kwargs, data_path=data_path)
        super().__init__(**base_kwargs)

class SequenceDatasetValidation(SequenceDatasetBase):
    """Validation split; every argument is forwarded to SequenceDatasetBase."""

    def __init__(self, data_path="", **kwargs):
        base_kwargs = dict(kwargs, data_path=data_path)
        super().__init__(**base_kwargs)

class SequenceDatasetTest(SequenceDatasetBase):
    """Test split; every argument is forwarded to SequenceDatasetBase."""

    def __init__(self, data_path="", **kwargs):
        base_kwargs = dict(kwargs, data_path=data_path)
        super().__init__(**base_kwargs)


class SequenceDataModule(pl.LightningDataModule):
    """Lightning data module wrapping the train/validation/test sequence datasets.

    Only the splits for which a path is given are configured; the dataloader
    hook of every other split is left as ``None``.

    Args:
        train_path: Path of the training file, or ``None`` to skip the split.
        val_path: Path of the validation file, or ``None`` to skip the split.
        test_path: Path of the test file, or ``None`` to skip the split.
        sequence_length: Forwarded to the dataset classes.
        sequence_encoding: Forwarded to the dataset classes.
        sequence_transform: Forwarded to the dataset classes.
        cell_type_transform: Forwarded to the dataset classes.
        batch_size: Batch size passed to every ``DataLoader``.
        num_workers: Worker count passed to every ``DataLoader``.
    """

    def __init__(
        self,
        train_path=None,
        val_path=None,
        test_path=None,
        sequence_length=200,
        sequence_encoding="polar",
        sequence_transform=None,
        cell_type_transform=None,
        batch_size=None,
        num_workers=1
    ):
        super().__init__()
        self.datasets = dict()
        # Shadow the inherited hooks with None; they are re-enabled below only
        # for the splits that were actually given a path.
        # NOTE(review): presumably this makes Lightning treat the missing
        # splits as "not implemented" — confirm against the Lightning version in use.
        self.train_dataloader, self.val_dataloader, self.test_dataloader = None, None, None

        if train_path:
            self.datasets["train"] = train_path
            self.train_dataloader = self._train_dataloader

        if val_path:
            self.datasets["validation"] = val_path
            self.val_dataloader = self._val_dataloader

        if test_path:
            self.datasets["test"] = test_path
            self.test_dataloader = self._test_dataloader

        self.sequence_length = sequence_length
        self.sequence_encoding = sequence_encoding
        self.sequence_transform = sequence_transform
        self.cell_type_transform = cell_type_transform
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        """Instantiate the dataset of every configured split.

        Args:
            stage: Accepted because Lightning calls ``setup(stage=...)``;
                without it the hook raises ``TypeError`` under a Trainer.
                It is ignored: all configured splits are always built, so a
                manual ``setup()`` call keeps working as before.
        """
        if "train" in self.datasets:
            self.train_data = self._build_dataset(SequenceDatasetTrain, "train")
        if "validation" in self.datasets:
            self.val_data = self._build_dataset(SequenceDatasetValidation, "validation")
        if "test" in self.datasets:
            self.test_data = self._build_dataset(SequenceDatasetTest, "test")

    def _build_dataset(self, dataset_cls, split):
        """Create ``dataset_cls`` for ``split`` with the module-wide encoding settings."""
        return dataset_cls(
            data_path=self.datasets[split],
            sequence_length=self.sequence_length,
            sequence_encoding=self.sequence_encoding,
            sequence_transform=self.sequence_transform,
            cell_type_transform=self.cell_type_transform
        )

    def _make_dataloader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using the module-wide loader settings."""
        return DataLoader(dataset,
                          self.batch_size,
                          shuffle=shuffle,
                          num_workers=self.num_workers,
                          pin_memory=True)

    def _train_dataloader(self):
        return self._make_dataloader(self.train_data, shuffle=True)

    def _val_dataloader(self):
        # Evaluation data is not shuffled: the order has no effect on the
        # metrics and a fixed order keeps evaluation runs comparable.
        return self._make_dataloader(self.val_data, shuffle=False)

    def _test_dataloader(self):
        return self._make_dataloader(self.test_data, shuffle=False)

### Get the data

In [5]:
! wget https://www.dropbox.com/s/db6up7c0d4jwdp4/train_all_classifier_WM20220916.csv.gz
! gunzip -d /content/train_all_classifier_WM20220916.csv.gz

--2022-11-13 15:25:59--  https://www.dropbox.com/s/db6up7c0d4jwdp4/train_all_classifier_WM20220916.csv.gz
Resolving www.dropbox.com (www.dropbox.com)... 162.125.2.18, 2620:100:601d:18::a27d:512
Connecting to www.dropbox.com (www.dropbox.com)|162.125.2.18|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: /s/raw/db6up7c0d4jwdp4/train_all_classifier_WM20220916.csv.gz [following]
--2022-11-13 15:26:00--  https://www.dropbox.com/s/raw/db6up7c0d4jwdp4/train_all_classifier_WM20220916.csv.gz
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://ucdf184ab69171cb593834733b47.dl.dropboxusercontent.com/cd/0/inline/BwpJaLiPROHOtErUSWehuRtgE188NF5Gff-ZHmPLIoRAPm313DS-iajmccjkGblZeEZR1pzfvcAlBoKhIdzgaAaJaeg6Hz2IBRcbQaa6AhhshUnFX6RvsyT-a1FVvt7ushC5u3CWiUSpIe8peQF1hy1ToCGYMR741MU5W5ZdK4u8pA/file# [following]
--2022-11-13 15:26:01--  https://ucdf184ab69171cb593834733b47.dl.dropboxusercontent.com/cd/0/inline/Bw

In [6]:
# Load the downloaded table for a visual check. The path is relative to the
# working directory, matching the path used when the dataset object is built.
data_test = pd.read_csv('./train_all_classifier_WM20220916.csv', sep="\t")
data_test

Unnamed: 0.1,Unnamed: 0,seqname,start,end,DHS_width,summit,total_signal,numsamples,numpeaks,C1,...,C10,C11,C12,C13,C14,C15,C16,raw_sequence,component,proportion
0,1241720,chr16,68843660,68843880,220,68843790,122.770678,61,61,0.101076,...,0.005301,0.016703,0.000000,0.000000,0.000000,0.000000,0.000000,GAGGCATTGAAGCTGCTGCTGAGCCCGGGAGGTGAGAGGACGCATC...,0,0.767372
1,2251755,chr3,143634500,143634720,220,143634610,0.780678,1,1,0.000000,...,0.000000,0.000000,0.000000,0.008636,0.000000,0.000000,0.000000,CTCTCCAACTTTTTCCCTGAGTATTGCCAGCACACTTTTAATCTCC...,12,0.869445
2,3136863,chr7,156928220,156928441,221,156928330,145.069295,32,32,0.000000,...,0.046526,0.002177,0.008559,0.000000,0.106442,0.000000,0.000000,CTTCCTGATAAGATCTCAGGAGCTGGGCAAGTGGCTCAAGTATGTG...,13,0.585111
3,2234828,chr3,130738277,130738580,303,130738460,13.140313,10,10,0.000531,...,0.000000,0.000000,0.000000,0.000370,0.000000,0.043161,0.000000,TGAGGAACATAAGCACATAAAATATAATCTAGAAGTTGGTGCTGAG...,14,0.961271
4,3060272,chr7,95784860,95785160,300,95785010,17.523798,7,7,0.000000,...,0.011486,0.000000,0.000000,0.000000,0.036866,0.000000,0.000000,CCAGGTTCTGCCATTCACTTGGGGCCAGCATAAACAAGGGGGCAGG...,13,0.762448
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
159995,2953063,chr7,2602077,2602454,377,2602280,1.058019,2,2,0.000000,...,0.018323,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,CAGAGCTCACTGGACCGGGAAGTGAGGGGAGGGCATCCCAGCAGAG...,9,0.928153
159996,3205276,chr8,52933200,52933440,240,52933320,11.820170,2,2,0.002420,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000282,0.000000,CAGTAAAAGTTTATCACCAGCAGAATGCACTTAAAATATTAAGTGA...,0,0.664603
159997,2024911,chr22,20858820,20859336,516,20859160,5.475445,4,4,0.004008,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,GCAGTGGGGCTCCTCCTTCTGTTTCCCAGACCGAGAGCCGCGCCGG...,5,0.707226
159998,2598954,chr5,72442240,72442480,240,72442350,4.672712,5,5,0.000000,...,0.000000,0.000714,0.003835,0.000000,0.000000,0.000000,0.000000,GAAGTCTCTGGGAAGTGTCCTGGAAGCCACAGAAATGGTGAGTTCT...,3,0.809984


### Encode data and test loader looping

In [7]:
# Build the dataset directly from the downloaded table; the constructor reads
# the file and validates alphabet and sequence length before returning.
encode_data = SequenceDatasetBase(data_path="./train_all_classifier_WM20220916.csv",
                                                      sequence_length=200, sequence_encoding="polar",
                                                      sequence_transform=None, cell_type_transform=None)

In [8]:
# Pull the first item to check the shape of one encoded sequence.
it = iter(encode_data)
s = next(it) # tuple (encoded sequence, cell-type label); the sequence array has shape (4, 200)
print(s[0].shape)

(4, 200)


### WGAN-GP

https://www.biorxiv.org/content/10.1101/2022.07.26.501466v1.full.pdf#page=10
Generative Adversarial Network
To train a GAN model, we used Wasserstein GAN architecture with gradient penalty similar to earlier work. 
The model consists of two parts; generator and discriminator. Generator takes noise as input (size is 128), 
followed by a dense layer with 64,000 (500 * 128) units with ELU activation, a reshape layer (500, 128), 
a convolution tower of 5 convolution blocks with skip connections, 
a 1D convolution layer with 4 filters with kernel width 1, and finally a SOFTMAX activation layer. 
The output of the generator is a 500 × 4 matrix, which represents one-hot encoded DNA sequence. 

Discriminator takes 500 bp one-hot encoded DNA sequence as input (real or fake), 
followed by a 1D convolution layer with 128 filters with kernel width 1, 
a convolution tower of 5 convolution blocks with skip connections, a flatten layer, 
and finally a dense layer with 1 unit.
Each block in the convolution tower consists of a RELU activation layer 
followed by 1D convolution with 128 filters with kernel width 5. 
The noise is generated by the numpy.random.normal(0, 1, (batch_size, 128)) command. We used a batch size of 128. 
For every train_on_batch iteration of the generator, we performed 10 train_on_batch iteration for the discriminator. 
We used Adam optimizer with learning_rate of 0.0001, beta_1 of 0.5, and beta_2 of 0.9. 
We trained the models for around 260,000 batch training iteration for KC and 
around 160,000 batch training iteration for MEL.


In [9]:
# Model dimensions and training hyperparameters for the WGAN-GP cells below.
BATCH_SIZE = 16  # NOTE(review): the paper excerpt above uses a batch size of 128 — confirm 16 is intended
SIZE_OF_INPUT = 128  # length of the noise vector fed to the generator's dense layer
SIZE_OF_FEATURE_MAP = 128  # channel count assumed when sizing the discriminator's dense layer
DNA_BP = 200  # sequence length in base pairs
SIZE_OF_HIDDEN_LAYERS = DNA_BP*SIZE_OF_INPUT  # output width of the generator's dense layer
NUM_OF_1D_CONV_FILTERS = 4  # NOTE(review): not referenced by the cells visible here
NUM_OF_CONV_1D = 5  # number of passes through the residual convolution

NUM_EPOCHS = 6000
LEARNING_RATE = 0.7  # NOTE(review): the paper excerpt above states Adam with learning_rate 0.0001 — confirm
MOMENTUM = 0.9  # NOTE(review): not referenced by the cells visible here

USE_CUDA = torch.cuda.is_available()

In [10]:
class Generator(nn.Module):
    """WGAN generator mapping latent noise to a per-position nucleotide distribution.

    Pipeline: dense layer -> reshape to ``(batch, channels, seq_len)`` ->
    ``num_blocks`` residual passes through one width-5 convolution ->
    width-1 convolution to ``out_channels`` -> softmax over the channel axis.

    Every dimension can be passed explicitly. An argument left as ``None``
    falls back to the corresponding notebook-level constant, so a bare
    ``Generator()`` builds the same layers as before.

    Args:
        latent_dim: Size of the input noise vector (default ``SIZE_OF_INPUT``).
        seq_len: Length of the generated sequence (default ``DNA_BP``).
        channels: Channel count of the convolution tower.
        num_blocks: Number of residual passes (default ``NUM_OF_CONV_1D``).
        out_channels: Number of output channels, one per nucleotide.

    NOTE(review): the same ``conv_1d`` weights are reused for every residual
    pass, whereas the paper excerpt in this notebook describes a tower of 5
    convolution blocks. The sharing is kept so that parameter names and shapes
    stay unchanged — confirm whether independent blocks were intended.
    """

    def __init__(self, latent_dim=None, seq_len=None, channels=128, num_blocks=None, out_channels=4):
        super(Generator, self).__init__()

        # Module-level constants are looked up only when a value is not given.
        self.latent_dim = SIZE_OF_INPUT if latent_dim is None else latent_dim
        self.seq_len = DNA_BP if seq_len is None else seq_len
        self.channels = channels
        self.num_blocks = NUM_OF_CONV_1D if num_blocks is None else num_blocks

        self.linears = nn.Sequential(
            nn.Linear(self.latent_dim, self.channels * self.seq_len),
            nn.ReLU(),  # replace ELU with RELU
            nn.Dropout()
        )

        self.relu = nn.ReLU()
        self.conv_1d = nn.Conv1d(in_channels=self.channels, out_channels=self.channels,
                                 kernel_size=5, stride=1, padding=2, dilation=1)
        self.final_conv_1d = nn.Conv1d(in_channels=self.channels, out_channels=out_channels, kernel_size=1)

    def sample_latent(self, num_samples):
        """Draw ``num_samples`` standard-normal latent vectors of shape ``(num_samples, latent_dim)``.

        This is the hook the Trainer in this notebook calls on its generator.
        """
        return torch.randn(num_samples, self.latent_dim)

    def forward(self, x):
        """Map noise of shape ``(batch, latent_dim)`` to ``(batch, out_channels, seq_len)``."""
        dense_output = self.linears(x)
        output = torch.reshape(dense_output, (dense_output.shape[0], self.channels, self.seq_len))

        # Residual tower. The sum is built out of place: ReLU's backward pass
        # reads its own output, so `output += residual` directly after the
        # ReLU makes autograd raise "modified by an inplace operation".
        for _ in range(self.num_blocks):
            output = output + self.relu(self.conv_1d(output))

        output = self.final_conv_1d(output)
        # Softmax over the channel axis: one distribution per sequence position.
        return F.softmax(output, dim=1)

In [11]:
# Smoke test: push one batch of Gaussian noise through an untrained generator.
Gen_test=Generator()
normal_dist=torch.randn(BATCH_SIZE,128)
gen_dna=Gen_test(normal_dist) # output has shape (batch, 4, 200): a softmax over the 4 channels at each position

Accessible elements are typically around 200bp in length, and using larger regions could conflate things by combining multiple neighboring sites.

For initial testing purposes, we will use @meuleman's 200 bp DNA dataset.

In [12]:
# Shape of the generated batch: (batch, channels, sequence length).
gen_dna.shape

torch.Size([16, 4, 200])

What they say about the discriminator:


Discriminator takes 500 bp one-hot encoded DNA sequence as input (real or fake), followed by a 1D convolution layer with 128 filters with kernel width 1, a convolution tower of 5 convolution blocks with skip connections, a flatten layer, and finally a dense layer with 1 unit. Each block in the convolution tower consists of a RELU activation layer followed by 1D convolution with 128 filters with kernel width 5.

In [13]:
class Discriminator(nn.Module):
    """WGAN critic scoring a ``(batch, in_channels, seq_len)`` sequence batch.

    Pipeline: width-1 convolution to ``channels`` -> ``num_blocks`` residual
    passes through one width-5 convolution -> flatten -> dense layer with a
    single output.

    Every dimension can be passed explicitly. An argument left as ``None``
    falls back to the corresponding notebook-level constant, so a bare
    ``Discriminator()`` keeps the same parameter names and shapes as before.

    Args:
        seq_len: Length of the input sequences (default ``DNA_BP``).
        channels: Channel count of the tower (default ``SIZE_OF_FEATURE_MAP``).
        num_blocks: Number of residual passes (default ``NUM_OF_CONV_1D``).
        in_channels: Channels of the input, one per nucleotide.

    NOTE(review): the same ``conv_1d`` weights are reused for every residual
    pass, whereas the description quoted in this notebook speaks of a tower of
    5 convolution blocks — confirm whether independent blocks were intended.
    """

    def __init__(self, seq_len=None, channels=None, num_blocks=None, in_channels=4):
        super(Discriminator, self).__init__()

        # Module-level constants are looked up only when a value is not given.
        self.seq_len = DNA_BP if seq_len is None else seq_len
        self.channels = SIZE_OF_FEATURE_MAP if channels is None else channels
        self.num_blocks = NUM_OF_CONV_1D if num_blocks is None else num_blocks

        self.first_conv = nn.Conv1d(in_channels=in_channels, out_channels=self.channels, kernel_size=1)
        self.conv_1d = nn.Conv1d(in_channels=self.channels, out_channels=self.channels,
                                 kernel_size=5, stride=1, padding=2, dilation=1)
        # The critic head is a plain dense layer, as in the description quoted
        # above ("finally a dense layer with 1 unit"). A ReLU would clamp the
        # Wasserstein score at zero and Dropout would make it stochastic, so
        # neither belongs here. The Sequential wrapper is kept so the
        # parameter names (`linear.0.*`) do not change.
        self.linear = nn.Sequential(
            nn.Linear(self.channels * self.seq_len, 1)
        )

    def forward(self, x):
        """Return an unbounded critic score of shape ``(batch, 1)``."""
        output = self.first_conv(x)

        # Residual tower. The sum is built out of place: ReLU's backward pass
        # reads its own output, so `output += residual` directly after the
        # ReLU makes autograd raise "modified by an inplace operation".
        for _ in range(self.num_blocks):
            output = output + F.relu(self.conv_1d(output))

        output = output.reshape(output.shape[0], -1)
        return self.linear(output)

In [14]:
# Smoke test: score the generated batch with an untrained discriminator;
# one score per sample is expected.
dis=Discriminator()
test_input=dis(gen_dna)
print(test_input.shape)

torch.Size([16, 1])


### Trainer

In [15]:
# See https://zhuanlan.zhihu.com/p/25071913 for an explanation of WGAN-GP (in Chinese)
# Reused from https://github.com/EmilienDupont/wgan-gp/blob/ef82364f2a2ec452a52fbf4a739f95039ae76fe3/training.py
class Trainer:
    """Training loop for a WGAN with gradient penalty (WGAN-GP).

    The critic is updated on every batch; the generator is updated once every
    ``critic_iterations`` batches. Scalar losses are appended to
    ``self.losses`` under the keys ``'G'``, ``'D'``, ``'GP'`` and
    ``'gradient_norm'``.

    Args:
        generator: Module mapping latent vectors to samples. Unless
            ``latent_dim`` is given it must provide ``sample_latent(n)``.
        discriminator: Critic module returning one score per sample.
        gen_optimizer: Optimizer over the generator parameters.
        dis_optimizer: Optimizer over the critic parameters.
        gp_weight: Multiplier of the gradient penalty term.
        critic_iterations: Number of critic steps per generator step.
        print_every: Print the latest losses every this many batches.
        use_cuda: Move both networks and all sampled noise to the GPU.
        latent_dim: Optional latent size. When set, the trainer draws
            standard-normal noise of shape ``(n, latent_dim)`` itself instead
            of calling ``generator.sample_latent``.
    """

    def __init__(self, generator, discriminator, gen_optimizer, dis_optimizer,
                 gp_weight=10, critic_iterations=5, print_every=50,
                 use_cuda=False, latent_dim=None):
        self.G = generator
        self.G_opt = gen_optimizer
        self.D = discriminator
        self.D_opt = dis_optimizer
        self.losses = {'G': [], 'D': [], 'GP': [], 'gradient_norm': []}
        self.num_steps = 0
        self.use_cuda = use_cuda
        self.gp_weight = gp_weight
        self.critic_iterations = critic_iterations
        self.print_every = print_every
        self.latent_dim = latent_dim

        if self.use_cuda:
            self.G.cuda()
            self.D.cuda()

    def _sample_latent(self, num_samples):
        """Return ``num_samples`` latent vectors on the training device."""
        if self.latent_dim is not None:
            latent_samples = torch.randn(num_samples, self.latent_dim)
        else:
            latent_samples = self.G.sample_latent(num_samples)
        if self.use_cuda:
            latent_samples = latent_samples.cuda()
        return latent_samples

    def _critic_train_iteration(self, data):
        """Run one critic update on a batch of real samples."""
        batch_size = data.size()[0]
        # The fake batch is detached: a critic step must not spend time
        # back-propagating into the generator.
        generated_data = self.sample_generator(batch_size).detach()

        # Bring the real batch to the device and dtype of the generated batch.
        # Batches built from numpy arrays arrive as float64, which the float32
        # critic weights would reject.
        data = data.to(device=generated_data.device, dtype=generated_data.dtype)

        # Critic scores on real and generated data
        d_real = self.D(data)
        d_generated = self.D(generated_data)

        # Get gradient penalty
        gradient_penalty = self._gradient_penalty(data, generated_data)
        # .item() rather than .data[0]: the loss is a 0-dim tensor, which
        # cannot be indexed.
        self.losses['GP'].append(gradient_penalty.item())

        # Create total loss and optimize
        self.D_opt.zero_grad()
        d_loss = d_generated.mean() - d_real.mean() + gradient_penalty
        d_loss.backward()

        self.D_opt.step()

        # Record loss
        self.losses['D'].append(d_loss.item())

    def _generator_train_iteration(self, data):
        """Run one generator update; ``data`` is used only for its batch size."""
        self.G_opt.zero_grad()

        # Get generated data
        batch_size = data.size()[0]
        generated_data = self.sample_generator(batch_size)

        # Calculate loss and optimize
        d_generated = self.D(generated_data)
        g_loss = - d_generated.mean()
        g_loss.backward()
        self.G_opt.step()

        # Record loss
        self.losses['G'].append(g_loss.item())

    def _gradient_penalty(self, real_data, generated_data):
        """Return ``gp_weight * mean((||grad D(x_hat)||_2 - 1)^2)``.

        ``x_hat`` is a random per-sample interpolation between the real and
        the generated batch. The mean gradient norm is also recorded in
        ``self.losses['gradient_norm']``.
        """
        batch_size = real_data.size()[0]

        # https://ai.stackexchange.com/questions/34926/why-do-we-use-a-linear-interpolation-of-fake-and-real-data-to-penalize-the-gradi
        # One interpolation weight per sample, broadcast over all remaining
        # axes. The shape follows the rank of the data, so both
        # (batch, channels, length) sequences and (batch, C, H, W) images work.
        alpha_shape = [batch_size] + [1] * (real_data.dim() - 1)
        alpha = torch.rand(alpha_shape, device=real_data.device, dtype=real_data.dtype)
        interpolated = alpha * real_data.detach() + (1 - alpha) * generated_data.detach()
        interpolated.requires_grad_(True)

        # Critic score of the interpolated examples
        prob_interpolated = self.D(interpolated)

        # Gradients of the scores with respect to the interpolated examples;
        # create_graph=True so the penalty itself can be back-propagated.
        gradients = torch.autograd.grad(outputs=prob_interpolated, inputs=interpolated,
                                        grad_outputs=torch.ones_like(prob_interpolated),
                                        create_graph=True, retain_graph=True)[0]

        # Flatten to (batch_size, -1) to take one norm per example
        gradients = gradients.reshape(batch_size, -1)
        self.losses['gradient_norm'].append(gradients.norm(2, dim=1).mean().item())

        # Derivatives of the gradient close to 0 can cause problems because of
        # the square root, so manually calculate norm and add epsilon
        gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)

        # Return gradient penalty
        return self.gp_weight * ((gradients_norm - 1) ** 2).mean()

    def _train_epoch(self, data_loader):
        """Iterate once over ``data_loader``; element 0 of each batch is the real data."""
        for i, data in enumerate(data_loader):
            self.num_steps += 1
            self._critic_train_iteration(data[0])
            # Only update generator every |critic_iterations| iterations
            if self.num_steps % self.critic_iterations == 0:
                self._generator_train_iteration(data[0])

            if i % self.print_every == 0:
                print("Iteration {}".format(i + 1))
                print("D: {}".format(self.losses['D'][-1]))
                print("GP: {}".format(self.losses['GP'][-1]))
                print("Gradient norm: {}".format(self.losses['gradient_norm'][-1]))
                # Print the generator loss only once one has been recorded.
                if self.losses['G']:
                    print("G: {}".format(self.losses['G'][-1]))

    def train(self, data_loader, epochs, save_training_gif=True):
        """Train for ``epochs`` epochs.

        When ``save_training_gif`` is true, the generator output for 64 fixed
        latent vectors is rendered after every epoch and all frames are
        written to ``./training_{epochs}_epochs.gif``.
        """
        if save_training_gif:
            # Fix latents to see how generation improves during training
            fixed_latents = self._sample_latent(64)
            training_progress_images = []

        for epoch in range(epochs):
            print("\nEpoch {}".format(epoch + 1))
            self._train_epoch(data_loader)

            if save_training_gif:
                with torch.no_grad():
                    samples = self.G(fixed_latents).cpu()
                # make_grid expects a batch of images (B, C, H, W). A batch of
                # sequences (B, channels, length) is rendered as B one-channel
                # images of size channels x length.
                if samples.dim() == 3:
                    samples = samples.unsqueeze(1)
                img_grid = make_grid(samples)
                # Convert to numpy and transpose axes to fit imageio convention
                # i.e. (width, height, channels)
                img_grid = np.transpose(img_grid.numpy(), (1, 2, 0))
                # Add image grid to training progress
                training_progress_images.append(img_grid)

        if save_training_gif:
            imageio.mimsave('./training_{}_epochs.gif'.format(epochs),
                            training_progress_images)

    def sample_generator(self, num_samples):
        """Return ``num_samples`` generated samples as a tensor that is still attached to the graph."""
        latent_samples = self._sample_latent(num_samples)
        generated_data = self.G(latent_samples)
        return generated_data

    def sample(self, num_samples):
        """Return ``num_samples`` generated samples as a numpy array.

        For 4-D image output the first channel is selected, as before. Output
        of any other rank (e.g. ``(batch, 4, length)`` sequences) is returned
        unchanged, since there is no colour channel to remove.
        """
        generated_data = self.sample_generator(num_samples)
        samples = generated_data.detach().cpu().numpy()
        if samples.ndim == 4:
            # Remove color channel
            return samples[:, 0, :, :]
        return samples


### Testing network