### Install and download

In [None]:
!nvidia-smi

Fri Nov 18 07:26:04 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   41C    P8    10W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

### Importing

In [None]:
!pip install einops
!pip install pytorch_lightning
import os; os.getpid()
from scipy.stats import zscore
import torch
import copy
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
import torch
from IPython.display import display
import torch.nn as nn
from torchvision.utils import make_grid
from torchvision.utils import save_image
from IPython.display import Image
import matplotlib.pyplot as plt
import numpy as np
import random
import pandas as pd
from tqdm import tqdm_notebook
import seaborn as sns
import matplotlib.pyplot as plt
from torch.nn.modules.activation import ReLU
import torch.optim as optim
from torch.optim import Adam
from tqdm import tqdm_notebook
from torchvision.utils import save_image
import matplotlib
import math
from inspect import isfunction
from functools import partial
import scipy
from scipy.special import rel_entr
from torch import nn, einsum
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange
from torch import nn, einsum
import torch.nn.functional as F
import matplotlib.animation as animation
import matplotlib.image as mpimg
import glob
from PIL import Image
import pytorch_lightning as pl 
import imageio
from torch.autograd import Variable
from torch.autograd import grad as torch_grad

%matplotlib inline

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
[K     |████████████████████████████████| 41 kB 372 kB/s 
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.0
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch_lightning
  Downloading pytorch_lightning-1.8.2-py3-none-any.whl (798 kB)
[K     |████████████████████████████████| 798 kB 5.1 MB/s 
Collecting torchmetrics>=0.7.0
  Downloading torchmetrics-0.10.3-py3-none-any.whl (529 kB)
[K     |████████████████████████████████| 529 kB 60.3 MB/s 
Collecting lightning-utilities==0.3.*
  Downloading lightning_utilities-0.3.0-py3-none-any.whl (15 kB)
Collecting fire
  Downloading fire-0.4.0.tar.gz (87 kB)
[K     |████████████████████████████████| 87 kB 8.0 MB/s 
Building wheels for collected packages: fire
  Building wheel for fire

### Dataset 

In [None]:
class SequenceDatasetBase(Dataset):
    def __init__(self, data_path, sequence_length=200, sequence_encoding="polar", sequence_transform=None, cell_type_transform=None):
        super().__init__()
        self.data = pd.read_csv(data_path, sep="\t")
        self.sequence_length = sequence_length
        self.sequence_encoding = sequence_encoding
        self.sequence_transform = sequence_transform
        self.cell_type_transform = cell_type_transform
        self.alphabet = ["A", "C", "T", "G"]
        self.check_data_validity()

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        # Iterating through DNA sequences from dataset and one-hot encoding all nucleotides
        current_seq = self.data["raw_sequence"][index]
        if 'N' not in current_seq: 
            X_seq = self.encode_sequence(current_seq, encoding=self.sequence_encoding)
            
            # Reading cell component at current index
            X_cell_type = self.data["component"][index]
            
            if self.sequence_transform is not None:
                X_seq = self.sequence_transform(X_seq)
            if self.cell_type_transform is not None:
                X_cell_type = self.cell_type_transform(X_cell_type)

            return X_seq, X_cell_type

    def check_data_validity(self):
        """
        Checks if the data is valid.
        """
        if not set("".join(self.data["raw_sequence"])).issubset(set(self.alphabet)):
            raise ValueError(f"Sequence contains invalid characters.")

        uniq_raw_seq_len = self.data["raw_sequence"].str.len().unique()
        if len(uniq_raw_seq_len) != 1 or uniq_raw_seq_len[0] != self.sequence_length:
            raise ValueError(f"The sequence length does not match the data.")

    def encode_sequence(self, seq, encoding):
        """
        Encodes a sequence using the given encoding scheme ("polar", "onehot", "ordinal").
        """
        if encoding == "polar":
            seq = self.one_hot_encode(seq).T
            seq[seq == 0] = -1
        elif encoding == "onehot":
            seq = self.one_hot_encode(seq).T
        elif encoding == "ordinal":
            seq = np.array([self.alphabet.index(n) for n in seq])
        else:
            raise ValueError(f"Unknown encoding scheme: {encoding}")
        return seq

    # Function for one hot encoding each line of the sequence dataset
    def one_hot_encode(self, seq):
        """
        One-hot encoding a sequence
        """
        seq_len = len(seq)
        seq_array = np.zeros((self.sequence_length, len(self.alphabet)))
        for i in range(seq_len):
            seq_array[i, self.alphabet.index(seq[i])] = 1
        return seq_array




In [None]:
class SequenceDatasetTrain(SequenceDatasetBase):
    def __init__(self, data_path="", **kwargs):
        super().__init__(data_path=data_path, **kwargs)

class SequenceDatasetValidation(SequenceDatasetBase):
    def __init__(self, data_path="", **kwargs):
        super().__init__(data_path=data_path, **kwargs)

class SequenceDatasetTest(SequenceDatasetBase):
    def __init__(self, data_path="", **kwargs):
        super().__init__(data_path=data_path, **kwargs)


class SequenceDataModule(pl.LightningDataModule):
    def __init__(
        self,
        train_path=None,
        val_path=None,
        test_path=None,
        sequence_length=200,
        sequence_encoding="polar",
        sequence_transform=None,
        cell_type_transform=None,
        batch_size=None,
        num_workers=1
    ):
        super().__init__()
        self.datasets = dict()
        self.train_dataloader, self.val_dataloader, self.test_dataloader = None, None, None

        if train_path:
            self.datasets["train"] = train_path
            self.train_dataloader = self._train_dataloader

        if val_path:
            self.datasets["validation"] = val_path
            self.val_dataloader = self._val_dataloader

        if test_path:
            self.datasets["test"] = test_path
            self.test_dataloader = self._test_dataloader

        self.sequence_length = sequence_length
        self.sequence_encoding = sequence_encoding
        self.sequence_transform = sequence_transform
        self.cell_type_transform = cell_type_transform
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self):
        if "train" in self.datasets:
            self.train_data = SequenceDatasetTrain(
                data_path=self.datasets["train"],
                sequence_length=self.sequence_length,
                sequence_encoding=self.sequence_encoding,
                sequence_transform=self.sequence_transform,
                cell_type_transform=self.cell_type_transform
            )
        if "validation" in self.datasets:
            self.val_data = SequenceDatasetValidation(
                data_path=self.datasets["validation"],
                sequence_length=self.sequence_length,
                sequence_encoding=self.sequence_encoding,
                sequence_transform=self.sequence_transform,
                cell_type_transform=self.cell_type_transform
            )
        if "test" in self.datasets:
            self.test_data = SequenceDatasetTest(
                data_path=self.datasets["test"],
                sequence_length=self.sequence_length,
                sequence_encoding=self.sequence_encoding,
                sequence_transform=self.sequence_transform,
                cell_type_transform=self.cell_type_transform
            )

    def _train_dataloader(self):
        return DataLoader(self.train_data,
                          self.batch_size, 
                          shuffle=True, 
                          num_workers=self.num_workers, 
                          pin_memory=True)

    def _val_dataloader(self):
        return DataLoader(self.val_data,
                          self.batch_size, 
                          shuffle=True,
                          num_workers=self.num_workers,
                          pin_memory=True)

    def _test_dataloader(self):
        return DataLoader(self.test_data,
                          self.batch_size, 
                          shuffle=True, 
                          num_workers=self.num_workers, 
                          pin_memory=True)

### Get the data

We will use the dataset from https://www.meuleman.org/research/synthseqs/#material

training set: 160k sequences, 10k per NMF component (chr3-chrY)

In [None]:
! wget https://www.meuleman.org/train_all_classifier_light.csv.gz
! gunzip -d /content/train_all_classifier_light.csv.gz

--2022-11-18 07:26:24--  https://www.meuleman.org/train_all_classifier_light.csv.gz
Resolving www.meuleman.org (www.meuleman.org)... 185.199.110.153, 185.199.111.153, 185.199.108.153, ...
Connecting to www.meuleman.org (www.meuleman.org)|185.199.110.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 15316727 (15M) [application/gzip]
Saving to: ‘train_all_classifier_light.csv.gz’


2022-11-18 07:26:24 (128 MB/s) - ‘train_all_classifier_light.csv.gz’ saved [15316727/15316727]



validation set: 16k sequences, 1k per NMF component (chr2 only)

In [None]:
! wget https://www.meuleman.org/validation_all_classifier_light.csv.gz
! gunzip -d /content/validation_all_classifier_light.csv.gz

--2022-11-18 07:26:24--  https://www.meuleman.org/validation_all_classifier_light.csv.gz
Resolving www.meuleman.org (www.meuleman.org)... 185.199.110.153, 185.199.111.153, 185.199.108.153, ...
Connecting to www.meuleman.org (www.meuleman.org)|185.199.110.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1517659 (1.4M) [application/gzip]
Saving to: ‘validation_all_classifier_light.csv.gz’


2022-11-18 07:26:25 (26.3 MB/s) - ‘validation_all_classifier_light.csv.gz’ saved [1517659/1517659]



test set: 16k sequences, 1k per NMF component (chr1 only)

In [None]:
! wget https://www.meuleman.org/test_all_classifier_light.csv.gz
! gunzip -d /content/test_all_classifier_light.csv.gz

--2022-11-18 07:26:25--  https://www.meuleman.org/test_all_classifier_light.csv.gz
Resolving www.meuleman.org (www.meuleman.org)... 185.199.110.153, 185.199.111.153, 185.199.108.153, ...
Connecting to www.meuleman.org (www.meuleman.org)|185.199.110.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1516076 (1.4M) [application/gzip]
Saving to: ‘test_all_classifier_light.csv.gz’


2022-11-18 07:26:25 (24.2 MB/s) - ‘test_all_classifier_light.csv.gz’ saved [1516076/1516076]



Previews some of the data inside the test set

In [None]:
data_preview = pd.read_csv('/content/test_all_classifier_light.csv', sep="\t")
data_preview

Unnamed: 0,seqname,start,end,DHS_width,summit,total_signal,numsamples,raw_sequence,component,proportion
0,chr1,6281480,6281920,440,6281630,8.114064,3,TCAAGCCCCCGCCCAGGCGGGCTCTCTCCTGGCCGGGAGTGGCAGC...,12,0.734255
1,chr1,167472820,167473100,280,167472970,9.213141,12,CATGGCCTAGAGAGGATTCTTTGTGTGTCCACACCTGTGTTGCCTG...,10,0.696706
2,chr1,203479740,203479940,200,203479840,11.202431,14,GGAGTCTCTCTAGAGAATCTGCTGTTTATAAACAAATAAATGAGTA...,4,0.975903
3,chr1,87720660,87720860,200,87720770,19.890169,23,TTCCATTCTTTTTGAACTTACTCTCTACCCCGGAAGAATGACAACA...,6,0.792214
4,chr1,151313200,151313400,200,151313290,0.507441,1,AATGAATTCAGGTATTTCATTCTGTCAGTATCAGATAACGCAGGAG...,13,0.869445
...,...,...,...,...,...,...,...,...,...,...
15995,chr1,86616560,86616760,200,86616670,6.529608,8,TAACTTTAAAAAAAAAAAAAAAAAAGAGCTGGGCATGCTGGGAACA...,13,0.887900
15996,chr1,100233640,100234000,360,100233910,0.918690,1,TTAAAAGAGGCAAAGGTAGAGGAGAACAAAGGAAGGAGGAAGTAAC...,16,1.000000
15997,chr1,2072180,2072600,420,2072430,1.173570,1,GTTCAGGCAGGTGTGGGAGGCCAGCCATCAGGAGATGATGCCGTTG...,10,0.759222
15998,chr1,90817520,90817860,340,90817660,12.766754,10,CTGCTTCCTCCACATCTGTCTCCTTCAATGGTATATCATCACCACC...,6,0.800863


### Encode data and test loader looping

In [None]:
BATCH_SIZE = 16
NUM_EPOCHS = 6000
LEARNING_RATE = 0.7
MOMENTUM = 0.9

USE_CUDA = torch.cuda.is_available()

In [None]:
encoded_data = SequenceDataModule(
        train_path = "/content/train_all_classifier_light.csv",
        val_path = "/content/validation_all_classifier_light.csv",
        test_path = "/content/test_all_classifier_light.csv",
        sequence_length = 200,
        sequence_encoding = "polar",
        sequence_transform = None,
        cell_type_transform = None,
        batch_size = BATCH_SIZE,
        num_workers = 0
    )

In [None]:
encoded_data.setup()

In [None]:
print(len(encoded_data.train_data))
print(len(encoded_data.val_data))
print(len(encoded_data.test_data))

160000
16000
16000


In [None]:
train_loader=encoded_data.train_dataloader()
val_loader=encoded_data.val_dataloader()
test_loader=encoded_data.test_dataloader()

In [None]:
s=iter(train_loader)
l=next(s)

print(l)

[tensor([[[ 1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1.,  1.,  1.],
         [-1.,  1.,  1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ...,  1., -1., -1.]],

        [[ 1.,  1., -1.,  ..., -1., -1., -1.],
         [-1., -1.,  1.,  ...,  1., -1., -1.],
         [-1., -1., -1.,  ..., -1.,  1.,  1.],
         [-1., -1., -1.,  ..., -1., -1., -1.]],

        [[-1., -1., -1.,  ..., -1., -1., -1.],
         [ 1., -1.,  1.,  ..., -1., -1.,  1.],
         [-1.,  1., -1.,  ...,  1.,  1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.]],

        ...,

        [[-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ...,  1., -1., -1.],
         [-1.,  1.,  1.,  ..., -1.,  1., -1.],
         [ 1., -1., -1.,  ..., -1., -1.,  1.]],

        [[-1.,  1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1.,  1.],
         [-1., -1.,  1.,  ..., -1.,  1., -1.],
         [ 1., -1., -1.,  ...,  1., -1., -1.]],

        [[-1.,  1., -1.,  ...,  1.,

In [None]:
for index,(x,y) in enumerate(train_loader):
  print(x.shape) #(bs,4,len)
  print(y.shape) # (bs)
  break

torch.Size([16, 4, 200])
torch.Size([16])


### WGAN-GP

https://www.biorxiv.org/content/10.1101/2022.07.26.501466v1.full.pdf#page=10
Generative Adversarial Network
To train a GAN model, we used Wasserstein GAN architecture with gradient penalty similar to earlier work. 
The model consists of two parts; generator and discriminator. Generator takes noise as input (size is 128), 
followed by a dense layer with 64,000 (500 * 128) units with ELU activation, a reshape layer (500, 128), 
a convolution tower of 5 convolution blocks with skip connections, 
a 1D convolution layer with 4 filters with kernel width 1, and finally a SOFTMAX activation layer. 
The output of the generator is a 500 × 4 matrix, which represents one-hot encoded DNA sequence. 

Discriminator takes 500 bp one-hot encoded DNA sequence as input (real or fake), 
followed by a 1D convolution layer with 128 filters with kernel width 1, 
a convolution tower of 5 convolution blocks with skip connections, a flatten layer, 
and finally a dense layer with 1 unit.
Each block in the convolution tower consists of a RELU activation layer 
followed by 1D convolution with 128 filters with kernel width 5. 
The noise is generated by the numpy.random.normal(0, 1, (batch_size, 128)) command. We used a batch size of 128. 
For every train_on_batch iteration of the generator, we performed 10 train_on_batch iteration for the discriminator. 
We used Adam optimizer with learning_rate of 0.0001, beta_1 of 0.5, and beta_2 of 0.9. 
We trained the models for around 260,000 batch training iteration for KC and 
around 160,000 batch training iteration for MEL.


In [None]:
SIZE_OF_NOISE_INPUT = 128
SIZE_OF_FEATURE_MAP = 128
DNA_BP = 200
SIZE_OF_HIDDEN_LAYERS = DNA_BP*SIZE_OF_NOISE_INPUT
NUM_OF_1D_CONV_FILTERS = 4
NUM_OF_CONV_1D = 5

In [None]:
class Generator(nn.Module):

    def __init__(self, latent_dim=128):
        super(Generator, self).__init__()

        self.latent_dim = latent_dim

        self.linears = nn.Sequential(
            nn.Linear(SIZE_OF_NOISE_INPUT, SIZE_OF_HIDDEN_LAYERS),
            nn.ReLU()  # replace ELU with RELU
            #nn.Dropout()
        )

        self.relu=nn.ReLU()
        self.conv_1d= nn.Conv1d(in_channels=128, out_channels=128, kernel_size=5,stride=1, padding=2, dilation=1)
        self.final_conv_1d=nn.Conv1d(in_channels=128,out_channels=4,kernel_size=1)

    def forward(self, x):
        dense_output = self.linears(x)
        conv_1d_input = torch.reshape(dense_output, (dense_output.shape[0], SIZE_OF_NOISE_INPUT, DNA_BP )) 
      
        for i in range(NUM_OF_CONV_1D):  # a convolution tower of 5 convolution blocks with skip connections           
            if i==0:
              residual = conv_1d_input
              output = F.relu(self.conv_1d(conv_1d_input))
              output = output.clone()
              output += residual
              
            else:
              residual=output
              output = F.relu(self.conv_1d(output))
              output = output.clone()
              output += residual
       
        output = self.final_conv_1d(output)
        output = F.softmax(output,dim=1)
            
        return output

    def sample_latent(self, num_samples):
        return torch.randn((num_samples, self.latent_dim))

In [None]:
Gen_test=Generator()
normal_dist=torch.randn(BATCH_SIZE,128)
gen_dna=Gen_test(normal_dist) # final output should be (200,4) or (4,200) matrix

Accessible elements are typically around 200bp in length, and using larger regions could conflate things by combining multiple neighboring sites.

For initial testing purpose, we will use @meuleman's 200 bp dna dataset

In [None]:
gen_dna.shape

torch.Size([16, 4, 200])

What they say about the discriminator:


Discriminator	 takes	500	bp	one-hot	encoded	DNA	sequence	as	input	(real or	fake),	followed	by	a	1D	convolution	layer	with	128 filters	 with	kernel	width	1,	a	convolution	tower	of	5	convolution	 blocks	with	 skip	 connections,	a	 flatten	layer,	and	 finally	a	  dense	layer	with	1	unit. Each	 block	 in	 the	 convolution	 tower	 consists	 of	 a	 RELU	 activation	layer	followed	by	1D	convolution	with	128	filters	 with	 kernel	 width	 5.	 

In [None]:
class Discriminator(nn.Module):

    def __init__(self):
        super(Discriminator, self).__init__()

        self.first_conv = nn.Conv1d(in_channels=4,out_channels=128,kernel_size=1)
        self.conv_1d = nn.Conv1d(in_channels=128, out_channels=128, kernel_size=5,stride=1, padding=2, dilation=1)
        self.linear = nn.Sequential(
            nn.Linear(SIZE_OF_FEATURE_MAP*DNA_BP,1),
            nn.ReLU()  # replace ELU with RELU
            #nn.Dropout()
        )

    def forward(self, x):       
        conv_1d_input = self.first_conv(x)

        for i in range(NUM_OF_CONV_1D):  # a convolution tower of 5 convolution blocks with skip connections
          if i==0:
            residual = conv_1d_input
            output = F.relu(self.conv_1d(conv_1d_input))
            output = output.clone()
            output += residual

          else:
            residual = output
            output = F.relu(self.conv_1d(output))
            output = output.clone()
            output += residual

        output = output.view(output.shape[0],-1)
        output = self.linear(output)

        return output

In [None]:
dis=Discriminator()
test_input=dis(gen_dna)
print(test_input.shape)

torch.Size([16, 1])


### Trainer

In [None]:
# See https://zhuanlan.zhihu.com/p/25071913 for a chinese explanation on WGAN-GP
# Reused from https://github.com/EmilienDupont/wgan-gp/blob/ef82364f2a2ec452a52fbf4a739f95039ae76fe3/training.py
class Trainer:
    def __init__(self, generator, discriminator, gen_optimizer, dis_optimizer,
                 gp_weight=10, critic_iterations=5, print_every=50,
                 use_cuda=False):
        self.G = generator
        self.G_opt = gen_optimizer
        self.D = discriminator
        self.D_opt = dis_optimizer
        self.losses = {'G': [], 'D': [], 'GP': [], 'gradient_norm': []}
        self.num_steps = 0
        self.use_cuda = use_cuda
        self.gp_weight = gp_weight
        self.critic_iterations = critic_iterations
        self.print_every = print_every

        if self.use_cuda:
            self.G.cuda()
            self.D.cuda()

    def _critic_train_iteration(self, data):
        """ """
        # Get generated data
        batch_size = data.size()[0]
        generated_data = self.sample_generator(batch_size)

        # Calculate probabilities on real and generated data
        data = Variable(data).float()
        if self.use_cuda:
            data = data.cuda()
        d_real = self.D(data)
        d_generated = self.D(generated_data)

        # Get gradient penalty
        gradient_penalty = self._gradient_penalty(data, generated_data)
        self.losses['GP'].append(gradient_penalty)#.data[0])

        # Create total loss and optimize
        self.D_opt.zero_grad()
        d_loss = d_generated.mean() - d_real.mean() + gradient_penalty
        d_loss.backward()

        self.D_opt.step()

        # Record loss
        self.losses['D'].append(d_loss)#.data[0])

    def _generator_train_iteration(self, data):
        """ """
        self.G_opt.zero_grad()

        # Get generated data
        batch_size = data.size()[0]
        generated_data = self.sample_generator(batch_size)

        # Calculate loss and optimize
        d_generated = self.D(generated_data)
        g_loss = - d_generated.mean()
        g_loss.backward()
        self.G_opt.step()

        # Record loss
        self.losses['G'].append(g_loss)#.data[0])

    def _gradient_penalty(self, real_data, generated_data):
        batch_size = real_data.size()[0]

        # https://ai.stackexchange.com/questions/34926/why-do-we-use-a-linear-interpolation-of-fake-and-real-data-to-penalize-the-gradi
        # Calculates interpolation
        alpha = torch.rand(batch_size, 1, 1)
        alpha = alpha.expand_as(real_data)
        if self.use_cuda:
            alpha = alpha.cuda()
        interpolated = alpha * real_data.data + (1 - alpha) * generated_data.data
        interpolated = Variable(interpolated, requires_grad=True)
        if self.use_cuda:
            interpolated = interpolated.cuda()

        # Calculate probability of interpolated examples
        prob_interpolated = self.D(interpolated)

        # Calculate gradients of probabilities with respect to examples
        gradients = torch_grad(outputs=prob_interpolated, inputs=interpolated,
                               grad_outputs=torch.ones(prob_interpolated.size()).cuda()
                               if self.use_cuda else torch.ones(prob_interpolated.size()),
                               create_graph=True, retain_graph=True)[0]

        # Gradients have shape (batch_size, num_channels, img_width, img_height),
        # so flatten to easily take norm per example in batch
        gradients = gradients.view(batch_size, -1)
        self.losses['gradient_norm'].append(gradients.norm(2, dim=1).mean())#.data[0])

        # Derivatives of the gradient close to 0 can cause problems because of
        # the square root, so manually calculate norm and add epsilon
        gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)

        # Return gradient penalty
        return self.gp_weight * ((gradients_norm - 1) ** 2).mean()

    def _train_epoch(self, data_loader):
        for i, data in enumerate(data_loader):
            self.num_steps += 1
            self._critic_train_iteration(data[0])
            # Only update generator every |critic_iterations| iterations
            if self.num_steps % self.critic_iterations == 0:
                self._generator_train_iteration(data[0])

            if i % self.print_every == 0:
                print("Iteration {}".format(i + 1))
                print("D: {}".format(self.losses['D'][-1]))
                print("GP: {}".format(self.losses['GP'][-1]))
                print("Gradient norm: {}".format(self.losses['gradient_norm'][-1]))
                if self.num_steps > self.critic_iterations:
                    print("G: {}".format(self.losses['G'][-1]))

    def train(self, data_loader, epochs, save_training_gif=True):
        if save_training_gif:
            # Fix latents to see how image generation improves during training
            fixed_latents = Variable(self.G.sample_latent(64))
            if self.use_cuda:
                fixed_latents = fixed_latents.cuda()
            training_progress_images = []

        for epoch in range(epochs):
            print("\nEpoch {}".format(epoch + 1))
            self._train_epoch(data_loader)

            if save_training_gif:
                # Generate batch of images and convert to grid
                img_grid = make_grid(self.G(fixed_latents).cpu().data)
                # Convert to numpy and transpose axes to fit imageio convention
                # i.e. (width, height, channels)
                img_grid = np.transpose(img_grid.numpy(), (1, 2, 0))
                # Add image grid to training progress
                training_progress_images.append(img_grid)

        if save_training_gif:
            imageio.mimsave('./training_{}_epochs.gif'.format(epochs),
                            training_progress_images)

    def sample_generator(self, num_samples):
        latent_samples = Variable(self.G.sample_latent(num_samples))
        if self.use_cuda:
            latent_samples = latent_samples.cuda()
        #print("size of latent_samples = ", latent_samples.shape)
        generated_data = self.G(latent_samples)
        #print("size of generated_data = ", generated_data.shape)
        return generated_data

    def sample(self, num_samples):
        generated_data = self.sample_generator(num_samples)
        # Remove color channel
        return generated_data.data.cpu().numpy()#[:, 0, :, :]


### Testing network

In [None]:
generator=Generator(latent_dim = SIZE_OF_NOISE_INPUT)
discriminator=Discriminator()

print(generator)
print(discriminator)

Generator(
  (linears): Sequential(
    (0): Linear(in_features=128, out_features=25600, bias=True)
    (1): ReLU()
  )
  (relu): ReLU()
  (conv_1d): Conv1d(128, 128, kernel_size=(5,), stride=(1,), padding=(2,))
  (final_conv_1d): Conv1d(128, 4, kernel_size=(1,), stride=(1,))
)
Discriminator(
  (first_conv): Conv1d(4, 128, kernel_size=(1,), stride=(1,))
  (conv_1d): Conv1d(128, 128, kernel_size=(5,), stride=(1,), padding=(2,))
  (linear): Sequential(
    (0): Linear(in_features=25600, out_features=1, bias=True)
    (1): ReLU()
  )
)


In [None]:
gen_optimizer=torch.optim.Adam(generator.parameters(), amsgrad=True, lr=0.001)
dis_optimizer=torch.optim.Adam(discriminator.parameters(), amsgrad=True, lr=0.001)

In [None]:
train_loader = encoded_data.train_dataloader()
val_loader = encoded_data.val_dataloader()
test_loader = encoded_data.test_dataloader()

In [None]:
# Train model
trainer = Trainer(generator,discriminator,gen_optimizer,dis_optimizer,use_cuda=USE_CUDA)

In [None]:
# 'save_training_gif = False' is sufficient for now, let's get some genomic person to check this part later
# needs to adapt 'fixed_latents' stuff to plot intermediary figures of the generated DNA sequences
epochs = 10
trainer.train(train_loader, epochs, save_training_gif = False)


Epoch 1
Iteration 1
D: 1.2670364379882812
GP: 1.2819125652313232
Gradient norm: 0.8303024768829346
Iteration 51
D: -48.8807487487793
GP: 13.572396278381348
Gradient norm: 2.1551895141601562
G: -30.87215232849121
Iteration 101
D: -49.768341064453125
GP: 22.538537979125977
Gradient norm: 2.4976038932800293
G: -18.478038787841797
Iteration 151
D: -43.333885192871094
GP: 21.27605438232422
Gradient norm: 2.4570722579956055
G: -12.298599243164062
Iteration 201
D: -48.06883239746094
GP: 19.276140213012695
Gradient norm: 2.38116455078125
G: -2.7971625328063965
Iteration 251
D: -42.80603790283203
GP: 23.888444900512695
Gradient norm: 2.5449233055114746
G: -27.403606414794922
Iteration 301
D: -46.64417266845703
GP: 14.609437942504883
Gradient norm: 2.202986717224121
G: -0.9619186520576477
Iteration 351
D: -42.83174514770508
GP: 17.83197784423828
Gradient norm: 2.3349967002868652
G: -8.67154312133789
Iteration 401
D: -46.29710388183594
GP: 20.61407470703125
Gradient norm: 2.428495407104492
G: -1

In [None]:
# Save models
name = 'dna_model'
torch.save(trainer.G.state_dict(), './gen_' + name + '.pt')
torch.save(trainer.D.state_dict(), './dis_' + name + '.pt')