In [None]:
# Install third-party packages that the import cell below needs (pybedtools, Bio).
# NOTE(review): `Bio` is presumably meant to provide Biopython — confirm the PyPI
# package name and consider pinning versions for reproducibility.
!pip install pybedtools
!pip install Bio

In [None]:
import argparse
import datetime
import json
import logging
import math
import os
import pickle as pk
import pprint
from collections import defaultdict

import editdistance
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pybedtools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.utils.data
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from scipy.special import rel_entr
from sklearn import metrics
from sklearn.neighbors import KDTree
from sklearn.neighbors import NearestNeighbors
from torch import autograd
from torch import optim
from torch.utils.data import TensorDataset, DataLoader
from torchsummary import summary

In [None]:
# One-hot alphabet: every symbol maps to a 5-element indicator vector.
# The key order (A, T, G, C, N) matters: save_samples turns argmax indices back
# into letters by indexing into list(codes.keys()).
codes = {
 'A': [1., 0., 0., 0., 0.],
 'T': [0., 1., 0., 0., 0.],
 'G': [0., 0., 1., 0., 0.],
 'C': [0., 0., 0., 1., 0.],
 'N': [0., 0., 0., 0., 1.],
 }

In [None]:
def KL_divergence(real: pd.Series, generated: pd.Series) -> float:
    """
    Compute the Kullback-Leibler divergence KL(real || generated).

    The divergence is not symmetric, so swapping the two arguments generally
    changes the result.

    Parameters
    ----------
    real: array-like
        Probabilities of the first (reference) distribution.
    generated: array-like
        Probabilities of the second distribution, aligned element-wise
        with ``real``.

    Returns
    -------
    kl_divergence: float
        Sum over all entries of ``real * log(real / generated)`` as computed by
        ``scipy.special.rel_entr`` (entries where ``real`` is 0 contribute 0;
        entries where only ``generated`` is 0 contribute ``inf``).
    """
    # Plain float arrays avoid pandas index alignment between the two inputs.
    p = np.asarray(real, dtype=float)
    q = np.asarray(generated, dtype=float)
    kl_pq = rel_entr(p, q)
    return float(np.sum(kl_pq))

In [None]:
def diversity(
    generated, real, scoring_metric=KL_divergence, plot_motif_probs=False
):
    """
    Score how far the frequency distribution of the generated set is from the
    distribution of the real (training) set.

    Identifiers are taken from the union of both indexes; an identifier missing
    from one table gets a pseudo-count of 1 there. Counts are normalised to
    probabilities before scoring.

    Some scores, such as KL divergence, are not metrics in the formal sense
    because they are not symmetric, so the argument order matters. The
    generated probabilities are passed to ``scoring_metric`` first and the
    real probabilities second.

    Parameters
    ----------
    generated, real:
        Tables indexed by identifier whose first column holds the count.
        NOTE(review): assumed to be single-column DataFrames with unique
        index values — confirm against the callers.
    scoring_metric: callable
        ``scoring_metric(generated_probs, real_probs)`` returning the score.
    plot_motif_probs: bool
        When True, draw a regression plot of generated vs. training
        probabilities (requires seaborn).

    Returns
    -------
    Whatever ``scoring_metric`` returns for the two probability vectors.
    """

    def _count(table, key):
        # Pseudo-count of 1 keeps both probability vectors strictly positive.
        if key in table.index:
            return table.loc[key].iloc[0]
        return 1

    all_ids = set(generated.index.values.tolist() + real.index.values.tolist())
    rows = [[key, _count(generated, key), _count(real, key)] for key in all_ids]
    df_kl = pd.DataFrame(rows, columns=['Flipon', 'generated', 'real'])

    df_kl['Generated_seqs'] = df_kl['generated'] / df_kl['generated'].sum()
    df_kl['Training_seqs'] = df_kl['real'] / df_kl['real'].sum()

    if plot_motif_probs:
        # Imported here so the score itself can be computed without seaborn.
        import seaborn as sns

        # A local figure avoids changing the global rcParams for later plots.
        fig, ax = plt.subplots(figsize=(3, 3))
        sns.regplot(x='Generated_seqs', y='Training_seqs', data=df_kl, ax=ax)
        ax.set_xlabel('Generated_seqs')
        ax.set_ylabel('Training Seqs')
        ax.set_title('Motifs Probs')
        plt.show()

    return scoring_metric(df_kl['Generated_seqs'].values, df_kl['Training_seqs'].values)

In [None]:
def one_hot_encoding(sequence, target_length):
    """
    Encode a DNA string as a flat one-hot vector of length ``target_length * 4``.

    Columns are ordered A, C, G, T. Characters outside that alphabet (e.g. the
    'N' padding symbol) and positions past the end of ``sequence`` stay all-zero.
    Only the first ``target_length`` characters are read.
    """
    column_of = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
    encoded = np.zeros((target_length, len(column_of)))

    for row, base in enumerate(sequence[:target_length]):
        col = column_of.get(base)
        if col is not None:
            encoded[row, col] = 1

    # Row-major flattening: position 0's four columns come first.
    return encoded.flatten()

def preprocess_sequences(sequences, target_length):
    """Return the sequences right-padded with 'N' or cut to exactly ``target_length`` characters."""
    fixed_length = []
    for raw in sequences:
        padded = raw.ljust(target_length, 'N')
        fixed_length.append(padded[:target_length])
    return fixed_length

def calculate_diversity_optimized_one_hot(dataset, target_length):
    """
    Average pairwise edit distance between the one-hot encodings of a dataset.

    Each sequence is padded or truncated to ``target_length``, one-hot encoded
    and flattened; ``editdistance.eval`` is then applied to every pair of
    encoded vectors and the distances are averaged.

    Parameters
    ----------
    dataset: iterable of str
        DNA sequences.
    target_length: int
        Length every sequence is padded or truncated to before encoding.

    Returns
    -------
    float
        Mean pairwise distance, or 0.0 when there are fewer than two sequences.
    """
    sequences = preprocess_sequences(dataset, target_length)
    sequences_one_hot = [one_hot_encoding(seq, target_length) for seq in sequences]

    n_sequences = len(sequences_one_hot)
    if n_sequences < 2:
        return 0.0

    # Edit distance is symmetric, so each unordered pair is evaluated once; the
    # mean equals the one taken over all ordered pairs (i != j).
    total_distance = 0.0
    for i in range(n_sequences - 1):
        seq_i = sequences_one_hot[i]
        for j in range(i + 1, n_sequences):
            total_distance += editdistance.eval(seq_i, sequences_one_hot[j])

    n_pairs = n_sequences * (n_sequences - 1) // 2
    return total_distance / n_pairs

In [None]:
def one_hot_encoding(sequence, target_length):
    """
    Encode a DNA string as a flat one-hot vector of length ``target_length * 4``.

    Columns are ordered A, C, G, T. Characters outside that alphabet (e.g. the
    'N' padding symbol) and positions past the end of ``sequence`` stay all-zero.
    Only the first ``target_length`` characters are read.
    """
    column_of = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
    encoded = np.zeros((target_length, len(column_of)))

    for row, base in enumerate(sequence[:target_length]):
        col = column_of.get(base)
        if col is not None:
            encoded[row, col] = 1

    # Row-major flattening: position 0's four columns come first.
    return encoded.flatten()

def preprocess_sequences(sequences, target_length):
    """Return the sequences right-padded with 'N' or cut to exactly ``target_length`` characters."""
    fixed_length = []
    for raw in sequences:
        padded = raw.ljust(target_length, 'N')
        fixed_length.append(padded[:target_length])
    return fixed_length

def calculate_novelty_optimized_one_hot(generated_sequences, initial_sequences, target_length):
    """
    Mean nearest-neighbour distance from each generated sequence to the initial set.

    Both sets are padded or truncated to ``target_length`` and one-hot encoded.
    A KDTree is built over the initial encodings and queried for the single
    closest neighbour of every generated encoding; the distances returned by
    the tree are averaged.

    Parameters
    ----------
    generated_sequences: iterable of str
        Sequences whose novelty is measured.
    initial_sequences: iterable of str
        Reference sequences (must be non-empty for the KDTree to be built).
    target_length: int
        Length every sequence is padded or truncated to before encoding.

    Returns
    -------
    float
        Mean nearest-neighbour distance, or 0.0 when no sequences were generated.
    """
    initial_sequences = preprocess_sequences(initial_sequences, target_length)
    generated_sequences = preprocess_sequences(generated_sequences, target_length)

    initial_one_hot = np.asarray(
        [one_hot_encoding(seq, target_length) for seq in initial_sequences]
    )
    generated_one_hot = np.asarray(
        [one_hot_encoding(seq, target_length) for seq in generated_sequences]
    )

    kdtree = KDTree(initial_one_hot)

    if len(generated_one_hot) == 0:
        return 0.0

    # KDTree.query expects a 2-D (n_queries, n_features) array, so all generated
    # encodings are queried in one batched call.
    distances, _ = kdtree.query(generated_one_hot, k=1)
    return float(distances.mean())

# KAN

Linear KAN

In [None]:
class KANLinear(torch.nn.Module):
    """
    Linear-style layer whose output is the sum of two branches:

    * a base branch: ``base_activation(x)`` followed by a linear map with
      ``base_weight``;
    * a spline branch: B-spline bases of ``x`` followed by a linear map with
      ``spline_weight`` (optionally rescaled per (out, in) pair by
      ``spline_scaler``).

    The knot grid is stored as the buffer ``grid`` with one row per input feature.
    """

    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        """
        Args:
            in_features: Size of the last dimension of the input.
            out_features: Size of the last dimension of the output.
            grid_size: Number of grid intervals spanning ``grid_range``.
            spline_order: Order of the B-splines; the grid is extended by this
                many knots on each side of ``grid_range``.
            scale_noise: Scale of the random noise used to initialise the
                spline coefficients.
            scale_base: Multiplies ``sqrt(5)`` in the ``a`` argument of the
                Kaiming-uniform init of ``base_weight``.
            scale_spline: Multiplies the initial spline coefficients when the
                standalone scaler is disabled; otherwise multiplies ``sqrt(5)``
                in the ``a`` argument of the Kaiming-uniform init of
                ``spline_scaler``.
            enable_standalone_scale_spline: Whether to create the learnable
                ``spline_scaler`` of shape (out_features, in_features).
            base_activation: Module class, instantiated with no arguments, used
                in the base branch.
            grid_eps: Blend weight of the uniform grid versus the data-adaptive
                grid in ``update_grid``.
            grid_range: ``[low, high]`` range covered by the initial grid.
        """
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Uniform knot spacing; the knot vector runs from spline_order steps
        # below grid_range[0] to spline_order steps above grid_range[1] and is
        # repeated for every input feature.
        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        # Parameters are allocated uninitialised here and filled in by
        # reset_parameters() below.
        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialise all learnable parameters.

        ``base_weight`` (and ``spline_scaler`` when enabled) get a
        Kaiming-uniform init. ``spline_weight`` is set to the least-squares
        spline coefficients that fit small uniform noise sampled at the
        ``grid_size + 1`` grid points inside the un-extended range.
        """
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            # Noise in [-0.5, 0.5) scaled by scale_noise / grid_size, one value
            # per (grid point, input feature, output feature).
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    # Drop the spline_order extension knots on both ends.
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        # Order-0 bases: indicator of the half-open knot interval containing x.
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        # Cox-de Boor recursion: each pass raises the order by one and shortens
        # the last dimension by one.
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        The coefficients are obtained per input feature as the least-squares
        solution of ``b_splines(x) @ coeff = y``.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        """Spline coefficients multiplied by ``spline_scaler`` when the standalone scaler is enabled."""
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        """
        Apply the layer to ``x``.

        Args:
            x (torch.Tensor): Tensor whose last dimension is ``in_features``;
                any leading dimensions are flattened for the computation and
                restored afterwards.

        Returns:
            torch.Tensor: Tensor with the same leading dimensions and last
            dimension ``out_features``.
        """
        assert x.size(-1) == self.in_features
        original_shape = x.shape
        x = x.reshape(-1, self.in_features)

        base_output = F.linear(self.base_activation(x), self.base_weight)
        # Flatten (in_features, n_coeff) on both operands so the spline branch
        # becomes a single matrix product.
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output

        output = output.reshape(*original_shape[:-1], self.out_features)
        return output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        """
        Re-fit the knot grid to the distribution of ``x`` and refit the spline
        coefficients on the new grid.

        Args:
            x (torch.Tensor): Batch of shape (batch_size, in_features).
            margin (float): Amount by which the uniform grid extends beyond the
                observed minimum and maximum of each feature.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        # Per-input spline outputs on the current grid; these are the targets
        # for the least-squares refit at the end.
        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        # Adaptive grid: grid_size + 1 sorted samples taken at evenly spaced ranks.
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        # Uniform grid from (min - margin) to (max + margin) per feature.
        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        # Extend the blended grid by spline_order uniform steps on each side.
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        # NOTE(review): the targets above were computed with the scaled
        # coefficients, while the refit result is stored in the unscaled
        # spline_weight — confirm this is intended when the scaler is enabled.
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        The L1 regularization is now computed as mean absolute value of the spline
        weights. The authors implementation also includes this term in addition to the
        sample-based regularization.

        Args:
            regularize_activation (float): Weight of the L1 term.
            regularize_entropy (float): Weight of the entropy term.

        Returns:
            torch.Tensor: Weighted sum of the L1 term and the entropy of the
            normalised per-(out, in) mean absolute coefficients.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )


class KAN(torch.nn.Module):
    """
    Stack of KANLinear layers applied one after another.

    ``layers_hidden`` lists the feature sizes; consecutive entries define the
    (in_features, out_features) of each layer. The remaining arguments are
    forwarded unchanged to every KANLinear.

    This cell used to define the class twice, differing only in the default of
    ``update_grid``; the second definition (default ``False``) shadowed the
    first, so it is the one kept here.
    """

    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KAN, self).__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        self.layers = torch.nn.ModuleList()
        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                KANLinear(
                    in_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
            )

    def forward(self, x: torch.Tensor, update_grid=False):
        """
        Run ``x`` through every layer in order.

        Args:
            x (torch.Tensor): Input whose last dimension matches the first
                entry of ``layers_hidden``.
            update_grid (bool): When True, each layer re-fits its knot grid to
                its current input before being applied (this requires 2-D
                inputs, as enforced by ``KANLinear.update_grid``).

        Returns:
            torch.Tensor: Output of the last layer.
        """
        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """Sum of the per-layer regularization losses."""
        return sum(
            layer.regularization_loss(regularize_activation, regularize_entropy)
            for layer in self.layers
        )

KAN convolution

In [None]:
#Util
def add_padding_1d(array: np.ndarray, padding: int) -> np.ndarray:
    """
    Return a zero-padded copy of a 1D array.

    ``padding`` zeros are placed on each side of ``array``. The result is
    allocated with ``np.zeros`` and therefore uses NumPy's default float dtype.
    """
    length = array.shape[0]
    result = np.zeros(length + 2 * padding)
    interior = slice(padding, padding + length)
    result[interior] = array
    return result


def calc_out_dims_1d(array, kernel_size, stride, dilation, padding):
    """
    Calculate output dimensions for 1D convolution.

    Args:
        array: Object with a 3-element ``shape`` of (batch_size, channels, length).
        kernel_size (int): Number of taps of the kernel.
        stride (int): Step between consecutive windows.
        dilation (int): Spacing between kernel taps.
        padding (int): Number of elements padded on each side.

    Returns:
        tuple: ``(out_size, batch_size, n_channels)`` where ``out_size`` is the
        number of window positions along the length axis.
    """
    # The original read an undefined name `matrix`; the input is `array`.
    batch_size, n_channels, n = array.shape
    # Equivalent to floor((n + 2p - k - (k - 1)(d - 1)) / stride) + 1.
    out_size = (n + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
    return int(out_size), batch_size, n_channels


def multiple_convs_kan_conv1d(array,
                               kernels,
                               kernel_size,
                               out_channels,
                               stride=1,
                               dilation=1,
                               padding=0,
                               device="cuda") -> torch.Tensor:
    """Performs a 1D convolution with multiple kernels on the input array using specified stride, dilation, and padding.

    The kernels are split into ``out_channels`` consecutive groups of
    ``len(kernels) // out_channels``; the results of the kernels in a group are
    summed to form one output channel.

    Args:
        array (torch.Tensor): 1D tensor of shape (batch_size, channels, length).
        kernels (list): List of kernel functions to be applied. Each is called
            with a tensor of shape (n_windows, 1, kernel_size).
        kernel_size (int): Size of the 1D kernel.
        out_channels (int): Number of output channels.
        stride (int): Stride along the length of the array. Default is 1.
        dilation (int): Dilation rate along the length of the array. Default is 1.
        padding (int): Number of elements to pad on each side. Default is 0.
        device (str): Device to perform calculations on. Default is "cuda".

    Returns:
        torch.Tensor: Feature map after convolution with shape (batch_size, out_channels, length_out).
    """
    batch_size, _, length_in = array.shape
    # Extent of one dilated window along the length axis.
    span = dilation * (kernel_size - 1) + 1
    length_out = (length_in + 2 * padding - span) // stride + 1

    array_out = torch.zeros((batch_size, out_channels, length_out)).to(device)

    array = F.pad(array, (padding, padding), mode='constant', value=0)
    # Unfold full dilated windows, then keep every `dilation`-th element so each
    # window holds exactly kernel_size taps.
    conv_groups = array.unfold(2, span, stride)[..., ::dilation].contiguous()

    kern_per_out = len(kernels) // out_channels

    for c_out in range(out_channels):
        out_channel_accum = torch.zeros((batch_size, length_out), device=device)

        for k_idx in range(kern_per_out):
            kernel = kernels[c_out * kern_per_out + k_idx]
            conv_result = kernel(conv_groups.view(-1, 1, kernel_size))
            # NOTE(review): this reshape only succeeds when the kernel yields
            # batch_size * length_out values, i.e. for single-channel input if
            # the kernel returns one value per window — confirm intended use.
            out_channel_accum += conv_result.view(batch_size, length_out)

        array_out[:, c_out, :] = out_channel_accum

    return array_out

In [None]:
def kan_conv1d(matrix: torch.Tensor,
               kernel,
               kernel_size: int,
               stride: int = 1,
               dilation: int = 1,
               padding: int = 0,
               device: str = "cpu") -> torch.Tensor:
    """
    Slide a kernel over the width axis of a batched, multi-channel 1D signal.

    For every output position a dilated window of ``kernel_size`` taps is cut
    from the zero-padded input and handed to ``kernel.forward``; the trailing
    dimension of the kernel's result is squeezed and stored at that position.

    Args:
        matrix (torch.Tensor): 3D tensor (batch_size, channels, width) to be convolved.
        kernel: Object whose ``forward`` maps a (batch_size, channels, kernel_size)
            window to a tensor that squeezes to (batch_size, channels).
        kernel_size (int): Number of taps in each window.
        stride (int, optional): Stride along the width axis. Defaults to 1.
        dilation (int, optional): Dilation along the width axis. Defaults to 1.
        padding (int, optional): Padding along the width axis. Defaults to 0.
        device (str): Device on which the output buffer is allocated (e.g., "cuda" or "cpu").

    Returns:
        torch.Tensor: Feature map of shape (batch_size, channels, width_out).
    """
    n_batch, n_chan, in_width = matrix.shape
    receptive_field = dilation * (kernel_size - 1) + 1
    out_width = (in_width + 2 * padding - receptive_field) // stride + 1

    result = torch.zeros((n_batch, n_chan, out_width), device=device)
    padded = torch.nn.functional.pad(matrix, (padding, padding))

    tap_extent = kernel_size * dilation
    for out_pos, first_tap in enumerate(range(0, out_width * stride, stride)):
        # Every `dilation`-th element starting at first_tap: kernel_size taps.
        window = padded[:, :, first_tap:first_tap + tap_extent:dilation]
        result[:, :, out_pos] = kernel.forward(window).squeeze(-1)

    return result

In [None]:
class KAN_Convolutional_Layer_1D(torch.nn.Module):
    """
    1D KAN convolution layer built from ``in_channels * out_channels``
    single-channel ``KAN_Convolution_1D`` kernels.

    Kernel ``i * in_channels + j`` connects input channel ``j`` to output
    channel ``i``; the contributions of all input channels are summed per
    output channel. The kernels use the default spline settings of
    ``KAN_Convolution_1D``.

    NOTE: a later cell in this notebook defines a class with the same name
    that also exposes the spline hyper-parameters; when both cells are run,
    that later definition is the one in effect.
    """

    def __init__(self, in_channels=1, out_channels=1, kernel_size=5, stride=1, padding=0, dilation=1, device="cuda"):
        super(KAN_Convolutional_Layer_1D, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        # Stored for reference only; the kernels follow the device of their input.
        self.device = device
        # Keyword arguments are required here: passed positionally, `device`
        # would land in KAN_Convolution_1D's `grid_size` parameter.
        self.convs = torch.nn.ModuleList([
            KAN_Convolution_1D(
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
            )
            for _ in range(in_channels * out_channels)
        ])

    def forward(self, x: torch.Tensor):
        """
        Args:
            x (torch.Tensor): Input of shape (batch_size, in_channels, length).

        Returns:
            torch.Tensor: Output of shape (batch_size, out_channels, length_out).
        """
        per_output = []
        for out_idx in range(self.out_channels):
            first_kernel = out_idx * self.in_channels
            contributions = [
                self.convs[first_kernel + in_idx](x[:, in_idx, :].unsqueeze(1))
                for in_idx in range(self.in_channels)
            ]
            # Each contribution is (batch_size, 1, length_out); sum over inputs.
            per_output.append(torch.stack(contributions, dim=0).sum(dim=0))
        return torch.cat(per_output, dim=1)

In [None]:
class KAN_Convolutional_Layer_1D(torch.nn.Module):
    """
    Multi-channel 1D KAN convolution.

    Holds ``in_channels * out_channels`` single-channel ``KAN_Convolution_1D``
    kernels. Kernel ``i * in_channels + j`` maps input channel ``j`` to output
    channel ``i``; contributions are summed over the input channels.
    """

    def __init__(
            self,
            in_channels: int = 1,
            out_channels: int = 1,
            kernel_size: int = 2,
            stride: int = 1,
            padding: int = 0,
            dilation: int = 1,
            grid_size: int = 5,
            spline_order: int = 3,
            scale_noise: float = 0.1,
            scale_base: float = 1.0,
            scale_spline: float = 1.0,
            base_activation=torch.nn.SiLU,
            grid_eps: float = 0.02,
            grid_range: tuple = [-1, 1],
            device: str = "cpu"
        ):
        # `device` is accepted but not used: buffers are allocated on the
        # device of the tensor passed to forward().
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        # Every kernel shares the same geometry and spline hyper-parameters.
        kernel_settings = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            grid_size=grid_size,
            spline_order=spline_order,
            scale_noise=scale_noise,
            scale_base=scale_base,
            scale_spline=scale_spline,
            base_activation=base_activation,
            grid_eps=grid_eps,
            grid_range=grid_range,
        )
        self.convs = torch.nn.ModuleList(
            [KAN_Convolution_1D(**kernel_settings) for _ in range(in_channels * out_channels)]
        )

    def forward(self, x: torch.Tensor):
        """
        Args:
            x (torch.Tensor): Input of shape (batch_size, in_channels, length).

        Returns:
            torch.Tensor: Output of shape (batch_size, out_channels, output_length).
        """
        n_batch, _, in_length = x.shape
        window_span = self.dilation * (self.kernel_size - 1) + 1
        out_length = (in_length + 2 * self.padding - window_span) // self.stride + 1
        result = torch.zeros((n_batch, self.out_channels, out_length), device=x.device)

        for out_idx in range(self.out_channels):
            channel_total = torch.zeros((n_batch, out_length), device=x.device)
            first_kernel = out_idx * self.in_channels
            for in_idx in range(self.in_channels):
                single_channel = x[:, in_idx, :].unsqueeze(1)
                contribution = self.convs[first_kernel + in_idx].forward(single_channel)
                # Drop the singleton channel axis before accumulating.
                channel_total += contribution.squeeze(1)
            result[:, out_idx, :] = channel_total

        return result

class KAN_Convolution_1D(torch.nn.Module):
    """
    Single 1D KAN kernel: a ``KANLinear`` mapping ``kernel_size`` inputs to one
    output, slid along the width axis by ``kan_conv1d``.
    """

    def __init__(
            self,
            kernel_size: int = 2,
            stride: int = 1,
            padding: int = 0,
            dilation: int = 1,
            grid_size: int = 50,
            spline_order: int = 3,
            scale_noise: float = 0.1,
            scale_base: float = 1.0,
            scale_spline: float = 1.0,
            base_activation=torch.nn.SiLU,
            grid_eps: float = 0.02,
            grid_range: tuple = [-1, 1]
        ):
        super().__init__()
        # Window geometry used by forward().
        self.kernel_size, self.stride = kernel_size, stride
        self.padding, self.dilation = padding, dilation

        spline_settings = dict(
            grid_size=grid_size,
            spline_order=spline_order,
            scale_noise=scale_noise,
            scale_base=scale_base,
            scale_spline=scale_spline,
            base_activation=base_activation,
            grid_eps=grid_eps,
            grid_range=grid_range,
        )
        # One window of kernel_size values in, one value out.
        self.conv = KANLinear(in_features=kernel_size, out_features=1, **spline_settings)

    def forward(self, x: torch.Tensor):
        """Convolve ``x`` (batch_size, channels, width) with this kernel on ``x``'s device."""
        # The input's device is remembered on the module and used for the output buffer.
        self.device = x.device
        return kan_conv1d(
            matrix=x,
            kernel=self.conv,
            kernel_size=self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            padding=self.padding,
            device=self.device,
        )

# Model

In [None]:
class ResBlock(nn.Module):
    """
    Residual block: two (ReLU -> Conv1d with kernel 5, padding 2) stages whose
    result is scaled by 0.3 and added to the block input.

    The ReLUs run in place, so the first one overwrites the tensor passed to
    ``forward``; the skip connection therefore adds the rectified input.
    """

    def __init__(self, model_size):
        super().__init__()
        stages = []
        for _ in range(2):
            stages.append(nn.ReLU(inplace=True))
            stages.append(nn.Conv1d(model_size, model_size, kernel_size=5, padding=2))
        self.res_block = nn.Sequential(*stages)

    def forward(self, input):
        """Return ``input + 0.3 * res_block(input)`` for a (batch, model_size, length) tensor."""
        residual = self.res_block(input)
        return input + 0.3 * residual


class WGANGenerator(nn.Module):
    """
    WGAN generator mapping latent noise to per-position symbol probabilities.

    Pipeline: KANLinear (latent -> model_size * seq_len), five ResBlocks, a 1x1
    Conv1d down to ``onehot_len`` channels, then a softmax over the channels of
    every position. When ``last_channel_is_prob`` is True the last channel is
    excluded from the softmax and passed through a sigmoid instead.

    Args:
        model_size (int): Number of feature channels inside the network.
        seq_len (int): Length of the generated sequences. (Earlier revisions
            ignored this argument and always used 512.)
        onehot_len (int): Number of output channels per position.
        last_channel_is_prob (bool): Treat the last output channel as an
            independent probability.
        latent_dim (int): Size of the noise vector; defaults to the previously
            hard-coded 128.
    """

    def __init__(
            self,
            model_size,
            seq_len,
            onehot_len,
            last_channel_is_prob=False,
            latent_dim=128,
    ):
        super(WGANGenerator, self).__init__()

        self.model_size = model_size
        self.seq_len = seq_len
        self.onehot_len = onehot_len
        self.last_channel_is_prob = last_channel_is_prob
        self.latent_dim = latent_dim
        # KANLinear takes the place of nn.Linear(latent_dim, model_size * seq_len).
        self.fc1 = KANLinear(self.latent_dim, self.model_size * self.seq_len)
        self.block = nn.Sequential(ResBlock(self.model_size),
                                   ResBlock(self.model_size),
                                   ResBlock(self.model_size),
                                   ResBlock(self.model_size),
                                   ResBlock(self.model_size))
        self.conv1 = nn.Conv1d(self.model_size, self.onehot_len, 1)
        self.softmax = nn.Softmax(dim=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, noise):
        """
        Args:
            noise (torch.Tensor): Latent batch of shape (batch_size, latent_dim).

        Returns:
            torch.Tensor: Tensor of shape (batch_size, seq_len, onehot_len).
        """
        output = self.fc1(noise)
        output = output.view(-1, self.model_size, self.seq_len)
        output = self.block(output)
        output = self.conv1(output)
        # (batch, channels, length) -> (batch, length, channels)
        output = output.transpose(1, 2).contiguous()
        shape = output.size()
        # One row per sequence position so the softmax runs over the channels.
        output = output.view(noise.shape[0] * self.seq_len, -1)
        if self.last_channel_is_prob:
            symbol_probs = self.softmax(output[:, :-1])
            label_prob = self.sigmoid(output[:, -1]).view(output.shape[0], 1)
            output = torch.cat((symbol_probs, label_prob), 1)
        else:
            output = self.softmax(output)
        return output.view(shape)


class WGANDiscriminator(nn.Module):
    """
    WGAN critic scoring one-hot (or probability) encoded sequences.

    Pipeline: a 1x1 Conv1d from ``onehot_len`` to ``model_size`` channels, five
    ResBlocks, then a KANLinear over the flattened features producing one score
    per sample.

    Args:
        model_size (int): Number of feature channels inside the network.
        seq_len (int): Length of the input sequences. (Earlier revisions
            ignored this argument and always used 512.)
        onehot_len (int): Number of channels per sequence position.
    """

    def __init__(
            self,
            model_size,
            seq_len,
            onehot_len,
    ):
        super(WGANDiscriminator, self).__init__()
        self.model_size = model_size
        self.seq_len = seq_len
        self.onehot_len = onehot_len
        self.block = nn.Sequential(ResBlock(self.model_size),
                                   ResBlock(self.model_size),
                                   ResBlock(self.model_size),
                                   ResBlock(self.model_size),
                                   ResBlock(self.model_size))
        self.conv1d = nn.Conv1d(self.onehot_len, self.model_size, 1)
        # KANLinear takes the place of nn.Linear(seq_len * model_size, 1).
        self.linear = KANLinear(self.seq_len * self.model_size, 1)

    def forward(self, input):
        """
        Args:
            input (torch.Tensor): Batch of shape (batch_size, seq_len, onehot_len).

        Returns:
            torch.Tensor: Critic scores of shape (batch_size, 1).
        """
        # (batch, length, channels) -> (batch, channels, length) for Conv1d.
        output = input.transpose(1, 2)
        output = self.conv1d(output)
        output = self.block(output)
        # Flatten per sample; a sequence-length mismatch then fails KANLinear's
        # feature-size check instead of silently reshuffling the batch.
        output = output.flatten(start_dim=1)
        output = self.linear(output)
        return output


def load_wgan_generator(
    filepath,
    model_size=64,
    num_channels=6,
    latent_dim=100,
    post_proc_filt_len=512,
    upsample=True,
    last_channel_is_prob=True,
    **kwargs
):
    """
    Build a WGANGenerator and load its weights from ``filepath``.

    Args:
        filepath: Path to a state dict saved with ``torch.save``.
        model_size (int): Passed to the generator as ``model_size``.
        num_channels (int): Passed to the generator as ``onehot_len``.
        post_proc_filt_len (int): Passed to the generator as ``seq_len``.
        last_channel_is_prob (bool): Passed through to the generator.
        latent_dim, upsample, **kwargs: Accepted but not used by this loader.

    Returns:
        WGANGenerator: Model on the CPU with the loaded weights.
    """
    model = WGANGenerator(
        model_size=model_size,
        seq_len=post_proc_filt_len,
        onehot_len=num_channels,
        last_channel_is_prob=last_channel_is_prob,
    )
    # The model was just built on the CPU, so map the checkpoint there too; this
    # lets GPU-saved checkpoints load on machines without CUDA.
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    state_dict = torch.load(filepath, map_location="cpu")
    model.load_state_dict(state_dict)
    return model


def load_wgan_discriminator(
        filepath,
        model_size=64,
        ngpus=1,
        num_channels=6,
        shift_factor=2,
        alpha=0.2,
        post_proc_filt_len=512,
        **kwargs
):
    """
    Build a WGANDiscriminator and load its weights from ``filepath``.

    Args:
        filepath: Path to a state dict saved with ``torch.save``.
        model_size (int): Passed to the discriminator as ``model_size``.
        num_channels (int): Passed to the discriminator as ``onehot_len``.
        post_proc_filt_len (int): Passed to the discriminator as ``seq_len``
            (same name and default as in ``load_wgan_generator``).
        ngpus, shift_factor, alpha, **kwargs: Accepted for backward
            compatibility; WGANDiscriminator has no such parameters, so they
            are not forwarded.

    Returns:
        WGANDiscriminator: Model on the CPU with the loaded weights.
    """
    model = WGANDiscriminator(
        model_size=model_size,
        seq_len=post_proc_filt_len,
        onehot_len=num_channels,
    )
    # The model was just built on the CPU, so map the checkpoint there too; this
    # lets GPU-saved checkpoints load on machines without CUDA.
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    state_dict = torch.load(filepath, map_location="cpu")
    model.load_state_dict(state_dict)
    return model


# Save samples method
def save_samples(
        epoch_samples,
        epoch,
        output_dir,
        model_gen,
        model_dis,
        last_channel_is_prob=True,
):
    """
    Decode generated samples to letters, write them to ``<output_dir>/<epoch>.csv``
    and save the current generator / discriminator weights next to it.

    Args:
        epoch_samples: Array-like of shape (n_samples, seq_len, n_channels).
        epoch: Used as the CSV file name.
        output_dir: Directory to write into; created if missing.
        model_gen, model_dis: Models whose ``state_dict()`` is saved as
            ``model_gen_last.pkl`` / ``model_dis_last.pkl``.
        last_channel_is_prob (bool): When True, the last channel is split off
            as a per-position label and written as an extra row after each
            decoded sequence.
    """
    sample_dir = output_dir
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(sample_dir, exist_ok=True)

    samples = np.array(epoch_samples)
    if last_channel_is_prob:
        labels = np.take(samples, -1, axis=2)
        samples = np.delete(samples, -1, axis=2)

    # Most probable channel per position, mapped to letters via the key order of `codes`.
    seq = np.argmax(samples, axis=2)
    alphabet = np.array(list(codes.keys()))
    df = pd.DataFrame(np.take(alphabet, indices=seq))

    res = []
    for i in range(len(seq)):
        res.append(np.array(df.iloc[i]))
        if last_channel_is_prob:
            res.append(labels[i])
    res_df = pd.DataFrame(res)
    print(res_df)
    res_path = os.path.join(sample_dir, '{}.csv'.format(epoch))
    res_df.to_csv(res_path, index=False)

    # save model
    model_gen_output_path = os.path.join(sample_dir,
                                         'model_gen_last.pkl')
    model_dis_output_path = os.path.join(sample_dir,
                                         'model_dis_last.pkl')
    torch.save(model_gen.state_dict(), model_gen_output_path,
               pickle_protocol=pk.HIGHEST_PROTOCOL)
    torch.save(model_dis.state_dict(), model_dis_output_path,
               pickle_protocol=pk.HIGHEST_PROTOCOL)


# Wasserstein training process
# Module-level logger named 'g4gan'; DEBUG level lets all messages through.
LOGGER = logging.getLogger('g4gan')
LOGGER.setLevel(logging.DEBUG)


def compute_discr_loss_terms(
        model_dis,
        model_gen,
        real_data_v,
        noise_v,
        batch_size,
        latent_dim,
        lmbda,
        use_cuda,
        use_binary_mask,
        mask_v,
        p_mask,
        compute_grads=False,
        last_channel_is_prob=True,
):
    """Compute the WGAN-GP discriminator cost for one batch.

    Parameters
    ----------
    model_dis, model_gen : torch.nn.Module
        Discriminator (critic) and generator.
    real_data_v : torch.Tensor
        Batch of real samples.
    noise_v : torch.Tensor
        Latent noise fed to the generator.
    batch_size : int
        Forwarded to ``calc_gradient_penalty``.
    latent_dim : int
        Unused here; kept for interface compatibility.
    lmbda : float
        Gradient-penalty weight.
    use_cuda : bool
        Move helper tensors (and masked data) to the GPU.
    use_binary_mask, mask_v, p_mask
        When ``use_binary_mask`` is True both real and generated inputs are
        multiplied by ``mask_v * (1 / p_mask)`` before being scored.
    compute_grads : bool
        When True, gradients of the three loss terms are accumulated into the
        discriminator parameters (the caller then steps the optimizer).
    last_channel_is_prob : bool
        Unused here; kept for interface compatibility.

    Returns
    -------
    (D_cost, Wasserstein_D) : tuple of scalar tensors
        ``D_cost = D_fake - D_real + gradient_penalty`` and
        ``Wasserstein_D = D_real - D_fake``.
    """
    # Scalar seeds for backward(): +1 minimizes a term, -1 maximizes it.
    one = torch.tensor(1, dtype=torch.float)
    neg_one = one * -1
    if use_cuda:
        one = one.cuda()
        neg_one = neg_one.cuda()
    # Reset gradients
    model_dis.zero_grad()
    # Apply binary mask to real data
    if use_binary_mask:
        real_data_v = real_data_v * mask_v * (1 / p_mask)
        if use_cuda:
            real_data_v = real_data_v.cuda()
        real_data_v = real_data_v.detach()
    # a) Loss contribution from real data: the critic output mean is
    #    maximized, hence the -1 seed.
    D_real = model_dis(real_data_v).mean()
    if compute_grads:
        D_real.backward(neg_one)
    # b) Loss contribution from generated data. The generator is frozen in
    #    this phase, so run it without autograd: previously a full graph was
    #    built and then thrown away via `.data` (the old `volatile` flag on
    #    the noise no longer suppresses graph construction).
    with torch.no_grad():
        fake = model_gen(noise_v)
    inputv = fake
    # Apply binary mask to fake data
    if use_binary_mask:
        inputv = inputv * mask_v * (1 / p_mask)
        if use_cuda:
            inputv = inputv.cuda()
    D_fake = model_dis(inputv).mean()
    if compute_grads:
        D_fake.backward(one)
    # c) Gradient penalty. As before, it receives the (possibly masked) real
    #    batch and the unmasked generator output.
    gradient_penalty = calc_gradient_penalty(model_dis,
                                             real_data_v.detach(),
                                             fake,
                                             batch_size,
                                             lmbda,
                                             use_cuda=use_cuda,
                                             )
    if compute_grads:
        gradient_penalty.backward(one)
    # Metrics reported to the caller
    D_cost = D_fake - D_real + gradient_penalty
    Wasserstein_D = D_real - D_fake
    return (D_cost, Wasserstein_D)


def compute_gener_loss_terms(
        model_dis,
        model_gen,
        batch_size,
        latent_dim,
        use_cuda,
        use_binary_mask,
        mask_v,
        p_mask,
        compute_grads=False,
        last_channel_is_prob=True,
):
    """Compute the WGAN generator cost for one freshly sampled noise batch.

    Noise of shape ``(batch_size, latent_dim)`` is drawn uniformly from
    [-1, 1], passed through ``model_gen`` (optionally scaled by
    ``mask_v * (1 / p_mask)``) and scored by ``model_dis``. When
    ``compute_grads`` is True, gradients that *increase* the mean critic score
    are accumulated into the generator parameters.

    Returns
    -------
    torch.Tensor
        Scalar generator cost, i.e. the negated mean critic output.
    """
    # Backward seed of -1: the generator wants the critic score to go up.
    grad_seed = torch.tensor(1, dtype=torch.float) * -1
    if use_cuda:
        grad_seed = grad_seed.cuda()

    # Start from clean generator gradients.
    model_gen.zero_grad()

    # Draw the latent batch and run the generator.
    latent = torch.Tensor(batch_size, latent_dim).uniform_(-1, 1)
    latent = autograd.Variable(latent.cuda() if use_cuda else latent)
    generated = model_gen(latent)

    # Optional masking of the generated batch.
    if use_binary_mask:
        generated = generated * mask_v * (1 / p_mask)
    if use_cuda:
        generated = generated.cuda()

    # Mean critic score over the batch.
    critic_mean = model_dis(generated).mean()
    if compute_grads:
        critic_mean.backward(grad_seed)
    return -critic_mean


def np_to_input_var(data, use_cuda):
    """Convert array-like ``data`` to a float tensor, on the GPU if ``use_cuda``."""
    tensor = torch.Tensor(data)
    return autograd.Variable(tensor.cuda() if use_cuda else tensor)


# Adapted from https://github.com/caogang/wgan-gp/blob/master/gan_toy.py
def calc_gradient_penalty(
        model_dis,
        real_data,
        fake_data,
        batch_size,
        lmbda,
        use_cuda=True,
):
    """WGAN-GP gradient penalty on random interpolates of real and fake data.

    Parameters
    ----------
    model_dis : callable
        Critic; maps a batch tensor to per-sample scores.
    real_data, fake_data : torch.Tensor
        Batches of identical shape, with samples along axis 0.
    batch_size : int
        Kept for interface compatibility; the number of samples is taken
        from ``real_data`` itself so a mismatching value cannot break the
        interpolation.
    lmbda : float
        Penalty weight.
    use_cuda : bool
        Move the interpolation factors and interpolates to the GPU.

    Returns
    -------
    torch.Tensor
        Scalar ``lmbda * mean((||grad||_2 - 1) ** 2)``, where the norm is
        taken over *all* non-batch dimensions of each sample.
    """
    num_samples = real_data.size(0)
    # One interpolation factor per sample, broadcast over every other axis
    # (works for inputs of any rank, not only 3-D).
    alpha = torch.rand(num_samples, *([1] * (real_data.dim() - 1)))
    alpha = alpha.expand(real_data.size())
    alpha = (alpha.cuda() if use_cuda else alpha)
    # Interpolate between real and fake data
    interpolates = alpha * real_data + (1 - alpha) * fake_data
    if use_cuda:
        interpolates = interpolates.cuda()
    # Fresh leaf tensor: the penalty differentiates w.r.t. the inputs only.
    interpolates = interpolates.detach().requires_grad_(True)
    # Evaluate discriminator
    disc_interpolates = model_dis(interpolates)
    # Gradients of the critic output with respect to the interpolated inputs.
    # create_graph=True keeps the penalty itself differentiable.
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    # Flatten each sample before taking the norm. norm(dim=1) on a 3-D tensor
    # only reduces one axis and therefore does not give the per-sample
    # gradient norm that the 1-Lipschitz penalty is defined on.
    gradients = gradients.reshape(num_samples, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() \
                       * lmbda
    return gradient_penalty

In [None]:
# Colab only: mount Google Drive so that later cells can read files under
# /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
def train_wgan(
        model_gen,
        model_dis,
        train_gen,
        valid_gen,
        test_data,
        num_epochs,
        batches_per_epoch,
        batch_size,
        output_dir=None,
        lmbda=0.1,
        use_cuda=True,
        discriminator_updates=5,
        epochs_per_sample=10,
        sample_size=20,
        lr=1e-4,
        beta_1=0.5,
        beta_2=0.9,
        latent_dim=100,
        use_knn_score=False,
        knn_score_sample_size=1000,
        knn_score_real_data=None,
        use_binary_mask=False,
        start_p_mask=1,
        mask_end_epoch=100,
        last_channel_is_prob=True,
):
    """Train a generator/discriminator pair with the WGAN-GP objective.

    Each batch step runs ``discriminator_updates`` critic updates (also
    recording the critic cost on a validation batch) followed by one
    generator update. After every epoch the last batch's metrics are plotted,
    the optional KNN edit-distance scores are computed, and every
    ``epochs_per_sample`` epochs fixed-noise samples are generated and, when
    ``output_dir`` is set, saved through ``save_samples``.

    Parameters
    ----------
    model_gen, model_dis : torch.nn.Module
        Generator and discriminator; moved to the GPU when ``use_cuda``.
    train_gen, valid_gen : iterable
        Endless iterators yielding array batches.
    test_data : array-like
        Held-out data; only used for the initial KNN report.
    num_epochs, batches_per_epoch, batch_size : int
        Loop sizes; ``batch_size`` is also the noise batch size.
    output_dir : str or None
        Where samples and "best KNN" checkpoints are written.
    lmbda : float
        Gradient-penalty weight.
    lr, beta_1, beta_2 : float
        Adam hyper-parameters for both optimizers.
    latent_dim, sample_size : int
        Noise dimensionality and number of fixed monitoring samples.
    use_knn_score, knn_score_sample_size, knn_score_real_data
        KNN scoring is active only when ``use_knn_score`` is True *and*
        ``knn_score_real_data`` is given; otherwise it is skipped.
    use_binary_mask, start_p_mask, mask_end_epoch
        When enabled, ``p_mask`` grows linearly from ``start_p_mask`` to 1
        over ``mask_end_epoch`` epochs. NOTE(review): ``mask_v`` is the fixed
        scalar 0.1 here, not a sampled binary mask -- confirm this is intended.
    last_channel_is_prob : bool
        Whether the last generator channel is a label rather than a
        one-hot symbol channel.

    Returns
    -------
    tuple
        ``(model_gen, model_dis, history, final_discr_metrics, samples)``.
        ``final_discr_metrics`` is a placeholder dict of zeros; no final
        evaluation is performed.
    """

    def _edit_distance(a, b):
        # Metric callable handed to sklearn's NearestNeighbors.
        return editdistance.eval(a, b)

    def _uniform_noise(rows):
        # Latent noise in [-1, 1], on the GPU when requested.
        noise = torch.Tensor(rows, latent_dim).uniform_(-1, 1)
        return noise.cuda() if use_cuda else noise

    if use_cuda:
        model_gen = model_gen.cuda()
        model_dis = model_dis.cuda()
    # Initialize optimizers for each model
    optimizer_gen = optim.Adam(model_gen.parameters(), lr=lr,
                               betas=(beta_1, beta_2))
    optimizer_dis = optim.Adam(model_dis.parameters(), lr=lr,
                               betas=(beta_1, beta_2))
    # Fixed noise, reused to watch how generated samples evolve
    sample_noise_v = _uniform_noise(sample_size)
    samples = {}
    history = []
    plot_history = []
    plot_knn_score_history = []
    train_iter = iter(train_gen)
    valid_iter = iter(valid_gen)

    # KNN state is always defined so that the per-epoch check below cannot
    # hit an unbound name when scoring is requested without reference data.
    knn_tree = None
    knn_sample_noise_v = None
    best_self_knn_score = 0
    if use_knn_score and knn_score_real_data is not None:
        LOGGER.info('Create KDTree')
        X = np.argmax(knn_score_real_data, 2)

        knn_tree = NearestNeighbors(n_neighbors=2, algorithm='auto',
                                    metric=_edit_distance)
        knn_tree.fit(X)
        knn_sample_noise_v = _uniform_noise(knn_score_sample_size)
        # Score on train data: column 0 is the nearest neighbour (the point
        # itself when querying the fitted data), column 1 the next one.
        (dist, ind) = knn_tree.kneighbors(X[:], 2)
        d_self = np.mean(np.take(dist, 1, axis=1))
        d_train = np.mean(np.take(dist, 0, axis=1))
        LOGGER.info('KNN score on train data:')
        LOGGER.info('D_self: {}'.format(d_self))
        LOGGER.info('D_train: {}'.format(d_train))
        # Score on test data
        Y = np.argmax(test_data, 2)
        knn_self_tree = NearestNeighbors(n_neighbors=2, algorithm='auto',
                                         metric=_edit_distance)
        knn_self_tree.fit(Y)
        (dist, ind) = knn_self_tree.kneighbors(Y[:], 2)
        d_self = np.mean(np.take(dist, 1, axis=1))
        (dist, ind) = knn_tree.kneighbors(Y[:], 1)
        d_train = np.mean(np.take(dist, 0, axis=1))
        LOGGER.info('KNN score on test data:')
        LOGGER.info('D_self: {}'.format(d_self))
        LOGGER.info('D_train: {}'.format(d_train))
    # Set p_mask
    p_mask = None
    mask_v = 0.1
    if use_binary_mask:
        p_mask = start_p_mask
        LOGGER.info('Setup p_mask: {}'.format(p_mask))

    # Loop over the dataset multiple times
    for epoch in range(num_epochs):
        LOGGER.info('Epoch: {}/{}'.format(epoch + 1, num_epochs))
        epoch_history = []
        for batch_idx in range(batches_per_epoch):
            # Discriminator parameters need gradients during its own phase
            for p in model_dis.parameters():
                p.requires_grad = True
            # Initialize the metrics for this batch
            batch_history = {'discriminator': [], 'generator': {}}
            # (1) Discriminator training phase: k updates per generator step
            for iter_d in range(discriminator_updates):
                # Get real examples
                real_data_v = np_to_input_var(next(train_iter),
                                              use_cuda)
                # Get valid examples
                valid_data_v = np_to_input_var(next(valid_iter),
                                               use_cuda)
                # Plain noise tensor. The former `volatile=True` flag is a
                # removed no-op; freezing the generator is the job of the
                # discriminator loss helper.
                noise_v = _uniform_noise(batch_size)
                (D_cost_train, D_wass_train) = compute_discr_loss_terms(
                    model_dis,
                    model_gen,
                    real_data_v,
                    noise_v,
                    batch_size,
                    latent_dim,
                    lmbda,
                    use_cuda,
                    use_binary_mask,
                    mask_v,
                    p_mask,
                    compute_grads=True,
                    last_channel_is_prob=last_channel_is_prob,
                )
                # Update the discriminator
                optimizer_dis.step()
                # Same cost on a validation batch, without accumulating grads
                (D_cost_valid, D_wass_valid) = compute_discr_loss_terms(
                    model_dis,
                    model_gen,
                    valid_data_v,
                    noise_v,
                    batch_size,
                    latent_dim,
                    lmbda,
                    use_cuda,
                    use_binary_mask,
                    mask_v,
                    p_mask,
                    compute_grads=False,
                    last_channel_is_prob=last_channel_is_prob,
                )
                if use_cuda:
                    D_cost_train = D_cost_train.cpu()
                    D_cost_valid = D_cost_valid.cpu()
                    D_wass_train = D_wass_train.cpu()
                    D_wass_valid = D_wass_valid.cpu()

                batch_history['discriminator'].append({
                    'cost': float(D_cost_train.data.numpy()),
                    'wasserstein_cost':
                        float(D_wass_train.data.numpy()),
                    'cost_validation':
                        float(D_cost_valid.data.numpy()),
                    'wasserstein_cost_validation':
                        float(D_wass_valid.data.numpy()),
                })
            # (2) Generator update: the discriminator is frozen so that only
            # generator gradients are computed.
            for p in model_dis.parameters():
                p.requires_grad = False
            G_cost = compute_gener_loss_terms(
                model_dis,
                model_gen,
                batch_size,
                latent_dim,
                use_cuda,
                use_binary_mask,
                mask_v,
                p_mask,
                compute_grads=True,
                last_channel_is_prob=last_channel_is_prob,
            )
            # Update generator
            optimizer_gen.step()
            if use_cuda:
                G_cost = G_cost.cpu()
            # Record generator loss
            batch_history['generator']['cost'] = \
                float(G_cost.data.numpy())
            # Record batch metrics
            epoch_history.append(batch_history)
        # Update binary mask: linear ramp towards 1, clamped at 1
        if use_binary_mask:
            next_epoch = epoch + 1
            p_mask = (1 - start_p_mask) / mask_end_epoch * next_epoch \
                    + start_p_mask

            if next_epoch == mask_end_epoch or p_mask > 1.:
                p_mask = 1.
            LOGGER.info('Update p_mask: {}'.format(p_mask))
        # Record epoch metrics
        history.append(epoch_history)
        LOGGER.info(pprint.pformat(epoch_history[-1]))
        # Plot. The appended dict is the same object stored in `history`, so
        # the 'gen_cost' key added here also appears in the returned history.
        plot_history.append(epoch_history[-1]['discriminator'][-1])
        plot_history[-1]['gen_cost'] = epoch_history[-1]['generator'
        ]['cost']
        pd.DataFrame(plot_history).plot()
        plt.show()
        # Calc KNN score
        if use_knn_score and knn_tree is not None:
            LOGGER.info('Calculate KNN score...')
            # Evaluation only: no autograd graph is needed for these samples
            with torch.no_grad():
                knn_samp_output = model_gen(knn_sample_noise_v)
                if use_cuda:
                    knn_samp_output = knn_samp_output.cpu()
                # One-hot the symbol channels, leaving a trailing label
                # channel (if any) untouched.
                prob_channel = knn_samp_output.shape[2]
                if last_channel_is_prob:
                    prob_channel = -1
                (_, indices) = knn_samp_output[:, :, :prob_channel].max(2)
                knn_samp_output[:, :, :prob_channel] = 0
                indices = indices.view(indices.shape[0], indices.shape[1],
                                       1)
                knn_samp_output = knn_samp_output.scatter_(2, indices, 1)
            knn_samp_output = knn_samp_output.numpy()
            Y = np.argmax(knn_samp_output, 2)
            # Self KNN score (D_self): distance to the nearest *other*
            # generated sample.
            knn_self_tree = NearestNeighbors(n_neighbors=2,
                                             algorithm='auto',
                                             metric=_edit_distance)
            knn_self_tree.fit(Y)
            (dist, ind) = knn_self_tree.kneighbors(Y[:], 2)
            d_self = np.mean(np.take(dist, 1, axis=1))
            # Train KNN score: distance to the nearest reference sample
            (dist, ind) = knn_tree.kneighbors(Y[:], 1)
            d_train = np.mean(np.take(dist, 0, axis=1))
            # History and plot
            LOGGER.info('D_self: {}'.format(d_self))
            LOGGER.info('D_train: {}'.format(d_train))
            plot_knn_score_history.append({'d_self': d_self,
                                           'd_train': d_train})
            pd.DataFrame(plot_knn_score_history).plot()

            plt.show()
            # Save best knn self score model
            if output_dir and d_self > best_self_knn_score:
                best_self_knn_score = d_self
                # save model
                model_gen_output_path = os.path.join(
                    output_dir,
                    'model_gen_best_knn_self_{}.pkl'.format(int(best_self_knn_score)))
                model_dis_output_path = os.path.join(
                    output_dir,
                    'model_dis_best_knn_self_{}.pkl'.format(int(best_self_knn_score)))
                torch.save(model_gen.state_dict(),
                           model_gen_output_path,
                           pickle_protocol=pk.HIGHEST_PROTOCOL)
                torch.save(model_dis.state_dict(),
                           model_dis_output_path,
                           pickle_protocol=pk.HIGHEST_PROTOCOL)
        if (epoch + 1) % epochs_per_sample == 0:
            # Generate outputs for fixed latent samples
            LOGGER.info('Generating samples...')
            with torch.no_grad():
                samp_output = model_gen(sample_noise_v)
            if use_cuda:
                samp_output = samp_output.cpu()
            samples[epoch + 1] = samp_output.numpy()
            if output_dir:
                LOGGER.info('Saving samples...')
                save_samples(
                    samples[epoch + 1],
                    epoch + 1,
                    output_dir,
                    model_gen,
                    model_dis,
                    last_channel_is_prob=last_channel_is_prob,
                )
    # Placeholder: no final validation/test evaluation is performed.
    final_discr_metrics = {
        'cost_validation': 0,
        'wasserstein_cost_validation': 0,
        'cost_test': 0,
        'wasserstein_cost_test': 0,
    }
    return model_gen, model_dis, history, final_discr_metrics, samples


# Batch generator

def batch_generator(data, batch_size, shuffle_each_epoch=True):
    """Yield batches of ``data`` forever.

    Indices are walked epoch after epoch (reshuffled before each pass when
    ``shuffle_each_epoch`` is True). A partially filled batch at the end of a
    pass is completed with indices from the next pass, so every yielded batch
    has exactly ``batch_size`` rows.

    Parameters
    ----------
    data : numpy.ndarray
        Array indexed along axis 0.
    batch_size : int
        Number of rows per batch; must be positive.
    shuffle_each_epoch : bool
        Reshuffle the index order (via ``np.random.shuffle``) before each pass.

    Raises
    ------
    ValueError
        On first iteration, if ``data`` is empty or ``batch_size`` is not
        positive. Either case previously looped forever without yielding.
    """
    if batch_size <= 0:
        raise ValueError('batch_size must be positive, got {}'.format(batch_size))
    num_items = data.shape[0]
    if num_items == 0:
        raise ValueError('cannot generate batches from empty data')
    indices = np.arange(num_items)
    batch = []
    while True:
        if shuffle_each_epoch:
            np.random.shuffle(indices)
        for i in indices:
            batch.append(i)
            if len(batch) == batch_size:
                yield data[batch]
                batch = []


def create_data_split(
        g4_np_dataset_path,
        valid_ratio,
        test_ratio,
        train_batch_size,
        shuffle_each_epoch=True,
        train_subset_size=100,
):
    """Load a ``.npy`` dataset and split it randomly into train/valid/test.

    The validation and test sizes are the ceilings of the requested ratios;
    the remainder is the training set. All three parts must be non-empty.

    Returns
    -------
    tuple
        ``(train_gen, valid_gen, test_data, train_subset)`` where the two
        generators come from ``batch_generator``, ``test_data`` is an array
        and ``train_subset`` holds the first ``train_subset_size`` training
        rows (``None`` when ``train_subset_size`` is not positive).
    """
    data = np.load(g4_np_dataset_path)
    total = data.shape[0]

    # Split sizes
    num_valid = int(np.ceil(total * valid_ratio))
    num_test = int(np.ceil(total * test_ratio))
    num_train = total - num_valid - num_test
    assert num_valid > 0
    assert num_test > 0
    assert num_train > 0

    # Random permutation cut at [num_train, num_train + num_valid]
    order = np.arange(total)
    np.random.shuffle(order)
    train_idx, valid_idx, test_idx = np.split(
        order, [num_train, num_train + num_valid])

    train_data = data[train_idx]
    valid_data = data[valid_idx]
    test_data = data[test_idx]

    train_gen = batch_generator(train_data, train_batch_size,
                                shuffle_each_epoch)
    valid_gen = batch_generator(valid_data, train_batch_size,
                                shuffle_each_epoch)

    train_subset = (train_data[:train_subset_size]
                    if train_subset_size > 0 else None)
    return (train_gen, valid_gen, test_data, train_subset)


# Log
def init_console_logger(logger, verbose=False):
    """Attach a stderr handler to ``logger`` (at most once).

    Parameters
    ----------
    logger : logging.Logger
        Logger to configure.
    verbose : bool
        DEBUG level when True, INFO otherwise.

    Calling this again on the same logger (for instance when a notebook cell
    is re-run) only updates the level of the handler installed earlier.
    Previously every call added another handler, so each message was printed
    once per call.
    """
    level = logging.DEBUG if verbose else logging.INFO
    for handler in logger.handlers:
        if getattr(handler, '_from_init_console_logger', False):
            handler.setLevel(level)
            return
    # Log to stderr
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(level)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    stream_handler.setFormatter(formatter)
    # Marker used above to recognise our own handler on later calls.
    stream_handler._from_init_console_logger = True
    logger.addHandler(stream_handler)

# Data

In [None]:
# raw_dataset.to_csv("/content/drive/My Drive/Zdna.csv", index=False)

In [None]:
# # Train
# path = 'G4_Chip_seq_quadruplex_norm_quad_labeled.npy'
# channels = 6
# last_channel_is_prob = True
# path = '/content/G4_Chip_seq_quadruplex_norm_quad_labeled.npy'
# path = '/content/WuKou_zdna2016_filter-norm_to_512.npy'
# --- Run configuration -------------------------------------------------------
# Dataset file and channel layout used for this run.
path = "/content/zdna2016_norm_quad_labeled.npy"
channels = 6
last_channel_is_prob = True
# All settings are collected in `args` and dumped to config.json below.
# 'post_proc_filt_len', 'alpha', 'shift_factor' and 'batch_shuffle' are stored
# in the config but are not passed to any call in this cell.
args = {}
args['verbose'] = True
args['batches_per_epoch'] = 109
args['batch_size'] = 64
args['latent_dim'] = 128
args['ngpus'] = 1
args['model_size'] = 32
args['output_dir'] = 'models/'
args['g4_np_data_path'] = path
args['valid_ratio'] = 0.1
args['test_ratio'] = 0.1
args['shuffle_train_each_epoch'] = True
args['post_proc_filt_len'] = None
args['alpha'] = 0.2
args['shift_factor'] = 2
args['batch_shuffle'] = False
args['num_epochs'] = 200
args['learning_rate'] = 1e-4
args['beta1'] = 0.5
args['beta2'] = 0.9
args['lmbda'] = 10.0
args['discriminator_updates'] = 5
args['epochs_per_sample'] = 5
args['sample_size'] = 20
args['num_channels'] = channels
args['use_knn_score'] = True
args['knn_score_sample_size'] = 100
args['use_binary_mask'] = False
args['start_p_mask'] = 0.3
args['mask_end_epoch'] = 90
args['last_channel_is_prob'] = last_channel_is_prob
# --- Logging and output directory --------------------------------------------
init_console_logger(LOGGER, args['verbose'])
LOGGER.info('Initialized logger.')
batch_size = args['batch_size']
latent_dim = args['latent_dim']
ngpus = args['ngpus']
model_size = args['model_size']
# Each run gets its own timestamped sub-directory under output_dir.
model_dir = os.path.join(args['output_dir'],
                         datetime.datetime.now().strftime('%Y%m%d%H%M%S'
                                                          ))
args['model_dir'] = model_dir
if not args['use_knn_score']:
    args['knn_score_sample_size'] = 0
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
LOGGER.info('Saving configurations...')
config_path = os.path.join(model_dir, 'config.json')
with open(config_path, 'w') as f:
    json.dump(args, f)
# --- Data --------------------------------------------------------------------
LOGGER.info('Loading G4 data...')
g4_np_data_path = args['g4_np_data_path']
# The last positional argument is create_data_split's train_subset_size: the
# KNN sample size doubles as the size of the training subset used as the KNN
# reference set (passed to train_wgan as knn_score_real_data).
(train_gen, valid_gen, test_data, train_subset) = create_data_split(
    g4_np_data_path,
    args['valid_ratio'],
    args['test_ratio'],
    batch_size,
    args['shuffle_train_each_epoch'],
    args['knn_score_sample_size'],
)
# --- Models ------------------------------------------------------------------
LOGGER.info('Creating models...')
# NOTE(review): the literal 512 is presumably the sequence length -- confirm
# against the WGANGenerator / WGANDiscriminator constructors.
model_gen = WGANGenerator(args['model_size'], 512, args['num_channels'
],
                           last_channel_is_prob=args['last_channel_is_prob'
                           ])
model_dis = WGANDiscriminator(args['model_size'], 512,
                               args['num_channels'])
LOGGER.info(model_gen)
LOGGER.info(model_dis)
# --- Training ----------------------------------------------------------------
LOGGER.info('Starting training...')
# use_cuda is True whenever ngpus >= 1, so with the settings above train_wgan
# moves the models to CUDA; this cell therefore needs a GPU runtime.
(model_gen, model_dis, history, final_discr_metrics, samples) = \
    train_wgan(
        model_gen=model_gen,
        model_dis=model_dis,
        train_gen=train_gen,
        valid_gen=valid_gen,
        test_data=test_data,
        num_epochs=args['num_epochs'],
        batches_per_epoch=args['batches_per_epoch'],
        batch_size=batch_size,
        output_dir=model_dir,
        lr=args['learning_rate'],
        beta_1=args['beta1'],
        beta_2=args['beta2'],
        lmbda=args['lmbda'],
        use_cuda=ngpus >= 1,
        discriminator_updates=args['discriminator_updates'],
        latent_dim=latent_dim,
        epochs_per_sample=args['epochs_per_sample'],
        sample_size=args['sample_size'],

        use_knn_score=args['use_knn_score'],
        knn_score_sample_size=args['knn_score_sample_size'],
        knn_score_real_data=train_subset,
        use_binary_mask=args['use_binary_mask'],
        start_p_mask=args['start_p_mask'],
        mask_end_epoch=args['mask_end_epoch'],
        last_channel_is_prob=args['last_channel_is_prob'],
    )
LOGGER.info('Finished training.')
LOGGER.info('Final discriminator loss on validation and test:')
LOGGER.info(pprint.pformat(final_discr_metrics))
# --- Persist final weights and metrics ----------------------------------------
LOGGER.info('Saving models...')
model_gen_output_path = os.path.join(model_dir, 'model_gen.pkl')
model_dis_output_path = os.path.join(model_dir, 'model_dis.pkl')
torch.save(model_gen.state_dict(), model_gen_output_path,
           pickle_protocol=pk.HIGHEST_PROTOCOL)
torch.save(model_dis.state_dict(), model_dis_output_path,
           pickle_protocol=pk.HIGHEST_PROTOCOL)
LOGGER.info('Saving metrics...')
history_output_path = os.path.join(model_dir, 'history.pkl')
final_discr_metrics_output_path = os.path.join(model_dir,
                                               'final_discr_metrics.pkl')
with open(history_output_path, 'wb') as f:
    pk.dump(history, f)
with open(final_discr_metrics_output_path, 'wb') as f:
    pk.dump(final_discr_metrics, f)
LOGGER.info('Done with LKAN!')

In [None]:
# Total number of scalar elements across all generator parameter tensors.
sum(p.numel() for p in model_gen.parameters())

21023238

In [None]:
# # Train
# path = 'G4_Chip_seq_quadruplex_norm_quad_labeled.npy'
# channels = 6
# last_channel_is_prob = True
# path = '/content/G4_Chip_seq_quadruplex_norm_quad_labeled.npy'
# path = '/content/WuKou_zdna2016_filter-norm_to_512.npy'
# NOTE(review): this cell repeats the preceding training cell statement for
# statement; only the final log message differs. Running both trains two
# independent models into two timestamped directories -- consider turning the
# shared code into a function.
# --- Run configuration -------------------------------------------------------
path = "/content/zdna2016_norm_quad_labeled.npy"
channels = 6
last_channel_is_prob = True
# All settings are collected in `args` and dumped to config.json below.
args = {}
args['verbose'] = True
args['batches_per_epoch'] = 109
args['batch_size'] = 64
args['latent_dim'] = 128
args['ngpus'] = 1
args['model_size'] = 32
args['output_dir'] = 'models/'
args['g4_np_data_path'] = path
args['valid_ratio'] = 0.1
args['test_ratio'] = 0.1
args['shuffle_train_each_epoch'] = True
args['post_proc_filt_len'] = None
args['alpha'] = 0.2
args['shift_factor'] = 2
args['batch_shuffle'] = False
args['num_epochs'] = 200
args['learning_rate'] = 1e-4
args['beta1'] = 0.5
args['beta2'] = 0.9
args['lmbda'] = 10.0
args['discriminator_updates'] = 5
args['epochs_per_sample'] = 5
args['sample_size'] = 20
args['num_channels'] = channels
args['use_knn_score'] = True
args['knn_score_sample_size'] = 100
args['use_binary_mask'] = False
args['start_p_mask'] = 0.3
args['mask_end_epoch'] = 90
args['last_channel_is_prob'] = last_channel_is_prob
# --- Logging and output directory --------------------------------------------
init_console_logger(LOGGER, args['verbose'])
LOGGER.info('Initialized logger.')
batch_size = args['batch_size']
latent_dim = args['latent_dim']
ngpus = args['ngpus']
model_size = args['model_size']
# Each run gets its own timestamped sub-directory under output_dir.
model_dir = os.path.join(args['output_dir'],
                         datetime.datetime.now().strftime('%Y%m%d%H%M%S'
                                                          ))
args['model_dir'] = model_dir
if not args['use_knn_score']:
    args['knn_score_sample_size'] = 0
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
LOGGER.info('Saving configurations...')
config_path = os.path.join(model_dir, 'config.json')
with open(config_path, 'w') as f:
    json.dump(args, f)
# --- Data --------------------------------------------------------------------
LOGGER.info('Loading G4 data...')
g4_np_data_path = args['g4_np_data_path']
# The last positional argument is create_data_split's train_subset_size.
(train_gen, valid_gen, test_data, train_subset) = create_data_split(
    g4_np_data_path,
    args['valid_ratio'],
    args['test_ratio'],
    batch_size,
    args['shuffle_train_each_epoch'],
    args['knn_score_sample_size'],
)
# --- Models ------------------------------------------------------------------
LOGGER.info('Creating models...')
model_gen = WGANGenerator(args['model_size'], 512, args['num_channels'
],
                           last_channel_is_prob=args['last_channel_is_prob'
                           ])
model_dis = WGANDiscriminator(args['model_size'], 512,
                               args['num_channels'])
LOGGER.info(model_gen)
LOGGER.info(model_dis)
# --- Training ----------------------------------------------------------------
LOGGER.info('Starting training...')
# use_cuda is True whenever ngpus >= 1, so this cell needs a GPU runtime.
(model_gen, model_dis, history, final_discr_metrics, samples) = \
    train_wgan(
        model_gen=model_gen,
        model_dis=model_dis,
        train_gen=train_gen,
        valid_gen=valid_gen,
        test_data=test_data,
        num_epochs=args['num_epochs'],
        batches_per_epoch=args['batches_per_epoch'],
        batch_size=batch_size,
        output_dir=model_dir,
        lr=args['learning_rate'],
        beta_1=args['beta1'],
        beta_2=args['beta2'],
        lmbda=args['lmbda'],
        use_cuda=ngpus >= 1,
        discriminator_updates=args['discriminator_updates'],
        latent_dim=latent_dim,
        epochs_per_sample=args['epochs_per_sample'],
        sample_size=args['sample_size'],

        use_knn_score=args['use_knn_score'],
        knn_score_sample_size=args['knn_score_sample_size'],
        knn_score_real_data=train_subset,
        use_binary_mask=args['use_binary_mask'],
        start_p_mask=args['start_p_mask'],
        mask_end_epoch=args['mask_end_epoch'],
        last_channel_is_prob=args['last_channel_is_prob'],
    )
LOGGER.info('Finished training.')
LOGGER.info('Final discriminator loss on validation and test:')
LOGGER.info(pprint.pformat(final_discr_metrics))
# --- Persist final weights and metrics ----------------------------------------
LOGGER.info('Saving models...')
model_gen_output_path = os.path.join(model_dir, 'model_gen.pkl')
model_dis_output_path = os.path.join(model_dir, 'model_dis.pkl')
torch.save(model_gen.state_dict(), model_gen_output_path,
           pickle_protocol=pk.HIGHEST_PROTOCOL)
torch.save(model_dis.state_dict(), model_dis_output_path,
           pickle_protocol=pk.HIGHEST_PROTOCOL)
LOGGER.info('Saving metrics...')
history_output_path = os.path.join(model_dir, 'history.pkl')
final_discr_metrics_output_path = os.path.join(model_dir,
                                               'final_discr_metrics.pkl')
with open(history_output_path, 'wb') as f:
    pk.dump(history, f)
with open(final_discr_metrics_output_path, 'wb') as f:
    pk.dump(final_discr_metrics, f)
LOGGER.info('Done!')

# Reconstruct

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
from torch.utils import data
import numpy as np
import torch.optim as optim
import sklearn.metrics as metrics
import matplotlib.pyplot as plt

In [None]:
# One-hot vector for each sequence symbol. The position of the 1 follows the
# key order A, T, G, C, N.
codes = {
 'A': [1., 0., 0., 0., 0.],
 'T': [0., 1., 0., 0., 0.],
 'G': [0., 0., 1., 0., 0.],
 'C': [0., 0., 0., 1., 0.],
 'N': [0., 0., 0., 0., 1.],
 }

# NOTE(review): quad_len is not referenced in the neighbouring cells (the
# later slice uses a literal 100) -- confirm it is still needed.
quad_len = 100

In [None]:
# One-hot encode every sequence line of the FASTA file and save the result as
# a NumPy array. Each non-header line is treated as one complete sequence.
one_hot_quads = []
line_num = 0  # number of encoded sequences
with open("/content/drive/My Drive/data_for_gans/preprocessed_fastas/mmHDNA_filter-norm_to_500.fasta", 'r') as f:
  for line in f:
    # strip() drops the newline as well as '\r' and padding, which would
    # otherwise raise KeyError in the `codes` lookup.
    sequence = line.strip().upper()
    # Skip FASTA headers and blank lines; a blank line used to add an empty
    # row, which makes the final array ragged.
    if not sequence or sequence[0] == '>':
      continue
    one_hot_quads.append([codes[s] for s in sequence])
    line_num += 1
one_hot_quads_np = np.array(one_hot_quads)
np.save('mmHDNA_filter-norm_to_500.npy', one_hot_quads_np)

In [None]:
# One-hot encode every sequence line of the FASTA file and save the result as
# a NumPy array. Each non-header line is treated as one complete sequence.
one_hot_quads = []
line_num = 0  # number of encoded sequences
with open("/content/drive/My Drive/data_for_gans/preprocessed_fastas/mmQuad_filter-norm_to_500.fasta", 'r') as f:
  for line in f:
    # strip() drops the newline as well as '\r' and padding, which would
    # otherwise raise KeyError in the `codes` lookup.
    sequence = line.strip().upper()
    # Skip FASTA headers and blank lines; a blank line used to add an empty
    # row, which makes the final array ragged.
    if not sequence or sequence[0] == '>':
      continue
    one_hot_quads.append([codes[s] for s in sequence])
    line_num += 1
one_hot_quads_np = np.array(one_hot_quads)
np.save('mmQuad_filter-norm_to_500.npy', one_hot_quads_np)

In [None]:
# one_hot_quads = []
# line_num = 0
# with open('/content/drive/My Drive/data_for_gans/preprocessed_fastas/WuKou2016_filter-norm_to_512.fasta', 'r') as f:
#   for line in f:
#     if line[0] != '>':
#       one_hot = []
#       for s in line.upper():
#         if s != '\n':
#           one_hot.append(codes[s])
#       one_hot_quads.append(one_hot)
#       line_num += 1
# one_hot_quads_np = np.array(one_hot_quads)
# np.save('WuKou2016_filter-norm_to_500.npy', one_hot_quads_np)

In [None]:
# one_hot_quads = []
# line_num = 0
# with open('/content/drive/My Drive/data_for_gans/preprocessed_fastas/zdna2016_filter-norm_to_512.fasta', 'r') as f:
#   for line in f:
#     if line[0] != '>':
#       one_hot = []
#       for s in line.upper():
#         if s != '\n':
#           one_hot.append(codes[s])
#       one_hot_quads.append(one_hot)
#       line_num += 1
# one_hot_quads_np = np.array(one_hot_quads)
# np.save('zdna2016_filter-norm_to_500.npy', one_hot_quads_np)

In [None]:
# Load the two encoded datasets and stack them along the sample axis.
# NOTE(review): the encoding cells save to the current working directory while
# these loads use /content/ -- this assumes the notebook runs with /content as
# its working directory (the Colab default); confirm.
data1 = np.load("/content/mmHDNA_filter-norm_to_500.npy")
data2 = np.load("/content/mmQuad_filter-norm_to_500.npy")
data = np.concatenate((data1, data2), axis=0)
# dataq = np.load("/content/G4_filter-norm_to_500.npy")
# data = np.concatenate((data, dataq), axis=0)
# dataq = np.load("/content/zdna2016_filter-norm_to_500.npy")
# data = np.concatenate((data, dataq), axis=0)

In [None]:
# saved_state_dict = torch.load('./models/20231002120536/model_gen_last.pkl')

# # Initialize a new model with the current architecture
# model_gen = G4GANGenerator(model_size, seq_len, onehot_len)
# model_gen.eval()

# # Copy the weights from the saved model to the new model
# model_gen.conv1.weight.data.copy_(saved_state_dict['conv1.weight'])
# model_gen.conv1.bias.data.copy_(saved_state_dict['conv1.bias'])

In [None]:
# Restore a trained generator for sampling and switch it to eval mode.
# NOTE(review): the samples produced further down have shape (N, 512, 5), so
# `seq_len = 5` / `onehot_len = 65` may not mean what their names suggest once
# passed positionally -- confirm against load_wgan_generator's signature.
weights_path = 'wgan_zdna.pkl'
model_size = 32
seq_len = 5
onehot_len = 65
# Size of the latent noise vector; matches args['latent_dim'] = 128 used for
# training above.
hidden_state_len = 128
g4gan = load_wgan_generator(weights_path, model_size, seq_len,
                             onehot_len)
g4gan.eval()

In [None]:
# NOTE: this rebinds the name of the built-in input() for the rest of the
# notebook. sample_fake_data below reads this global to decide how many
# samples to draw.
input = data

In [None]:
input.shape[0]

1177297

In [None]:
def sample_fake_data(generator, batch_size, hidden_state_len, num_samples=None):
    """Sample one-hot encoded sequences from a generator in mini-batches.

    Noise vectors are drawn uniformly from [-1, 1), passed through
    ``generator``, and every position of the output is hardened to a one-hot
    vector by keeping only the channel with the highest score.

    Parameters
    ----------
    generator : callable
        Maps a ``(batch, hidden_state_len)`` float tensor to a 3-D tensor
        whose last axis holds the per-channel scores.
    batch_size : int
        Number of noise vectors fed to the generator per forward pass.
    hidden_state_len : int
        Length of each noise vector.
    num_samples : int, optional
        Exact number of sequences to return. When omitted, the legacy
        behaviour is kept: the count is read from the notebook-global
        ``input`` array and rounded up to a whole number of batches.

    Returns
    -------
    numpy.ndarray
        Generated one-hot sequences stacked along axis 0.

    Raises
    ------
    ValueError
        If ``batch_size`` or the requested number of samples is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    exact_count = num_samples is not None
    # Legacy path: fall back to the size of the global real dataset.
    total = num_samples if exact_count else input.shape[0]
    if total <= 0:
        raise ValueError("the number of samples to generate must be positive")

    fake_data = []
    with torch.no_grad():
        for start in range(0, total, batch_size):
            # Only shrink the final batch when an exact count was requested;
            # otherwise keep emitting full batches as the original code did.
            current = min(batch_size, total - start) if exact_count else batch_size
            noise = torch.empty(current, hidden_state_len).uniform_(-1, 1)
            scores = generator(noise)
            _, indices_hot = scores.max(2, keepdim=True)
            # Build the one-hot tensor separately instead of overwriting the
            # generator output in place.
            one_hot = torch.zeros_like(scores).scatter_(2, indices_hot, 1)
            fake_data.append(one_hot.cpu().numpy())
    return np.concatenate(fake_data, axis=0)

In [None]:
batch_size = 1000

# Draw the synthetic one-hot sequences from the loaded generator, one
# mini-batch of `batch_size` noise vectors at a time.
fake_samples = sample_fake_data(
    generator=g4gan,
    batch_size=batch_size,
    hidden_state_len=hidden_state_len,
)

In [None]:
# Draw one synthetic sequence per real sequence.
# The generator is run in mini-batches through sample_fake_data rather than in
# a single forward pass over all input.shape[0] noise vectors: with roughly
# 1.18M samples the one-shot output tensor alone would need about 12 GB.
# sample_fake_data may return a whole number of batches, so trim to the exact
# size of the real dataset.
input_fake = sample_fake_data(g4gan, batch_size, hidden_state_len)[:input.shape[0]]

# num_g4 = input.shape[0]
# num_test = int(np.ceil(num_g4 * test_ratio))
# num_train = num_g4 - num_test
# indices = np.arange(num_g4)
# np.random.shuffle(indices)
# train_data_indices = indices[:num_train]
# test_data_indices = indices[num_train:]

In [None]:
# Inspect the shape of the generated array.
fake_samples.shape

(1178000, 512, 5)

In [None]:
# Keep the first 100 positions and the first 5 channels of every generated
# sequence.
new_data = fake_samples[:, 0:100, 0:5]

# Channels from index 5 onward on the last axis; this slice is empty whenever
# the last axis holds only 5 entries.
cut_values = fake_samples[..., 5:]

In [None]:
# View cut_values as a single column to inspect its contents.
cut_values.reshape(-1, 1)

array([], shape=(0, 1), dtype=float32)

In [None]:
# Shape of one trimmed sequence (the entry at index 1).
new_data[1].shape

(100, 5)

In [None]:
# Shape of the whole trimmed array.
new_data.shape

(1178000, 100, 5)

In [None]:
# Symbol -> one-hot row. decode_data inverts this mapping; when several codes
# could match, the first one in insertion order wins.
reverse_codes = {
    'A': [1., 0., 0., 0., 0.],
    'T': [0., 1., 0., 0., 0.],
    'G': [0., 0., 1., 0., 0.],
    'C': [0., 0., 0., 1., 0.],
    'N': [0., 0., 0., 0., 1.],
}

def decode_data(encoded_sequences):
    """Turn one-hot encoded sequences back into strings.

    Each row of a sequence is compared against every code in
    ``reverse_codes``; a row that equals a code exactly contributes that
    code's symbol, and a row that matches no code is skipped.

    The comparison is done with one vectorized NumPy operation per sequence
    instead of building a tensor and calling ``torch.equal`` for every
    (position, symbol) pair.

    Parameters
    ----------
    encoded_sequences : iterable
        Sequences of shape ``(length, n_channels)``, given as torch tensors,
        NumPy arrays or nested lists. ``n_channels`` must equal the length of
        the codes in ``reverse_codes``.

    Returns
    -------
    list of str
        One decoded string per input sequence, in input order.

    Raises
    ------
    ValueError
        If a non-empty sequence is not 2-D or has the wrong channel count.
    """
    alphabet = np.array(list(reverse_codes.keys()))
    code_matrix = np.array(list(reverse_codes.values()), dtype=np.float64)

    decoded_sequences = []
    for sequence in encoded_sequences:
        if torch.is_tensor(sequence):
            sequence = sequence.detach().cpu().numpy()
        rows = np.asarray(sequence, dtype=np.float64)

        if rows.size == 0:
            decoded_sequences.append('')
            continue
        if rows.ndim != 2 or rows.shape[1] != code_matrix.shape[1]:
            raise ValueError(
                "each sequence must have shape (length, %d), got %s"
                % (code_matrix.shape[1], rows.shape)
            )

        # matches[i, k] is True when row i equals the k-th code exactly.
        matches = (rows[:, None, :] == code_matrix[None, :, :]).all(axis=2)
        has_match = matches.any(axis=1)
        # argmax returns the first True per row, i.e. the first matching code.
        symbol_idx = matches.argmax(axis=1)
        decoded_sequences.append(''.join(alphabet[symbol_idx[has_match]]))

    return decoded_sequences

# Decode the trimmed one-hot array; each sequence is handed over as a float64
# tensor.
decoded_seqs = decode_data(
    [torch.tensor(seq, dtype=torch.float64) for seq in new_data]
)

decoded_seqs_np = np.array(decoded_seqs)

# Persist the decoded strings, one sequence per line.
with open('generated_sequences.txt', 'w') as f:
    f.writelines(seq + '\n' for seq in decoded_seqs)

In [None]:
# Read four tab-separated BED files without a header row.
# NOTE(review): comment='t' makes pandas treat every 't' as a comment marker:
# a line starting with 't' is skipped entirely and any other line is cut off at
# its first 't'. Presumably this is meant to drop 'track' header lines --
# confirm; a chromosome name containing a lowercase 't' would be truncated too.
# In the visible part of the notebook these four frames are only referenced by
# commented-out code.
bedWK_df = pd.read_csv('/content/drive/My Drive/data_for_gans/data/hg19_zdna/raw/WuKou16_filter_norm_to_512.bed', sep='\t', comment='t', header=None)
bed16_df = pd.read_csv('/content/drive/My Drive/data_for_gans/data/hg19_zdna/raw/zdna2016_filter_norm_to_512.bed', sep='\t', comment='t', header=None)
bedG_df = pd.read_csv('/content/drive/My Drive/data_for_gans/G4_Chip_seq_filter_norm_to_500.bed', sep='\t', comment='t', header=None)
bed = pd.read_csv('/content/HDNA_norm_to_500.bed', sep='\t', comment='t', header=None)

In [None]:
# Read the two tab-separated BED files (no header row) that are combined below.
# NOTE(review): comment='t' skips lines starting with 't' and truncates every
# other line at its first 't' -- presumably intended to drop 'track' headers;
# confirm that no chromosome name in these files contains a lowercase 't'.
bed1 = pd.read_csv('/content/mmHDNA_norm_to_500.bed', sep='\t', comment='t', header=None)
bed2 = pd.read_csv('/content/Quadruplex_norm_to_500.bed', sep='\t', comment='t', header=None)

In [None]:
# un_bed_dataset = pd.concat([bedWK_df, bed16_df]).reset_index(drop=True)
# un_bed_dataset = pd.concat([un_bed_dataset, bedG_df]).reset_index(drop=True)
# un_bed_dataset = pd.concat([un_bed_dataset, bed]).reset_index(drop=True)
# Stack both BED tables row-wise and renumber the rows from 0.
un_bed_dataset = pd.concat([bed1, bed2], ignore_index=True)

In [None]:
# Give both BED tables the same three column names.
bed1.columns = bed2.columns = ['chrom', 'chromStart', 'chromEnd']

In [None]:
# bed = pd.read_csv('/content/HDNA_norm_to_500.bed', sep='\t', comment='t', header=None)

In [None]:
# Name the three columns of the combined BED table; the assignment replaces
# the labels on the existing DataFrame in place.
un_bed_dataset.columns = ['chrom', 'chromStart', 'chromEnd']

In [None]:
# Row and column count of the combined BED table.
un_bed_dataset.shape

(1211161, 3)

In [None]:
# Pair every decoded sequence with the chromosome label found at the same row
# position of the combined BED table.
df_gen_sequences = pd.DataFrame(
    {
        "chrom": un_bed_dataset["chrom"].iloc[:len(decoded_seqs_np)],
        "raw_sequence": decoded_seqs_np,
    }
)

In [None]:
# Second name for the same DataFrame object (no copy is made), so a change
# made through either name is visible through both.
df_gen_sequences_to_fasta = df_gen_sequences

In [None]:
# Build one two-line FASTA record per generated sequence: a '>' header holding
# the chromosome label, followed by the sequence itself.
# The two columns are zipped directly instead of using DataFrame.iterrows(),
# which builds a Series for every row and is very slow on ~1M rows; the
# resulting text is the same.
fasta_lines = [
    f'>{chrom}\n{sequence}\n'
    for chrom, sequence in zip(
        df_gen_sequences_to_fasta['chrom'],
        df_gen_sequences_to_fasta['raw_sequence'],
    )
]

fasta_file_path = '/content/mm_Quad+H_wgan.fasta'
with open(fasta_file_path, 'w') as fasta_file:
    fasta_file.writelines(fasta_lines)