**Conditioning the diffusion process by hypernetworks**

- This is a prototype of conditioning of the diffusion process by using hypernetworks that generate the parameters for the last blocks of the UNET.
- 90% of the code is adapted from the original notebook on conditional diffusion (Lucas Ferreira da Silva, Luca Pinello, Zach Nussbaum), only with a reorganization of the code and some minor changes.
- As of now, the hypernetworks predict parameters based on the *component id*, not the *cell type*.
- Hypernetworks are trained along with the diffusion process (end-to-end).
- Initial experiments show performance (loss; KL divergence) similar to the baseline without hypernetworks, although over multiple runs the loss dropped slightly faster *(training was quite unstable, though, because GPU memory limits forced a very small batch size of 14)*. Comparisons were made on only 4 components; with more conditional classes this approach might show a bigger difference.
- Why hypernetworks:
  - More expressive than classical embedding methods
  - Potential for compression of the parameters (hypernetworks can be seen as a weight-sharing mechanism)
  - Less training data is needed for new cell types/components, which could be beneficial in the future (thanks to the transfer learning of the hypernetwork).
  - Idea: Mixture of denoisers (one hypernetwork generating parameters for multiple denoisers).
  
Contributor:
- Jan Sobotka (github: Johnny1188)

# Importing

In [None]:
# !pip install hypnettorch # for hypernetworks (https://github.com/chrhenning/hypnettorch)

In [None]:
from hypnettorch.hnets import ChunkedHMLP

In [None]:
import os; os.getpid()
import torch
import copy
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
import torch
from IPython.display import display
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import random
import pandas as pd
from tqdm import tqdm_notebook
import seaborn as sns
import matplotlib.pyplot as plt
from torch.nn.modules.activation import ReLU
from torch.optim import Adam
from tqdm import tqdm_notebook
import matplotlib
import math
from functools import partial
from scipy.special import rel_entr
from torch import nn, einsum
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange
from torch import nn, einsum
import torch.nn.functional as F
import random
from livelossplot import PlotLosses
from functools import partial

%matplotlib inline


# Utils

In [None]:
# Nucleotide alphabet; its order is the row order used when sequences are plotted and decoded.
nucleotides = ['A', 'C', 'T', 'G']

#### Data

In [None]:
def one_hot_encode(seq, alphabet, max_seq_len):
    """Return a ``(max_seq_len, len(alphabet))`` one-hot matrix for ``seq``.

    Row ``i`` has a single 1 in the column of ``seq[i]`` within ``alphabet``;
    rows beyond ``len(seq)`` stay all-zero (right padding).
    """
    encoded = np.zeros((max_seq_len, len(alphabet)))
    for position, symbol in enumerate(seq):
        encoded[position, alphabet.index(symbol)] = 1
    return encoded

def encode(seq, alphabet):
    """Return a presence vector over ``alphabet`` for the symbols of ``seq``.

    Entry ``j`` is 1 when ``alphabet[j]`` occurs anywhere in ``seq`` and 0
    otherwise; positions and repeat counts are not recorded.
    """
    presence = np.zeros(len(alphabet))
    for symbol in seq:
        presence[alphabet.index(symbol)] = 1

    return presence

def show_seq(dataloader_seq):
    """Print and draw a heatmap for every item of every batch in ``dataloader_seq``.

    Each item is reshaped to 4 x 200 and its rows are labelled with the
    module-level ``nucleotides`` list.
    """
    for batch in dataloader_seq:
        for item in batch:
            matrix = item.numpy().reshape(4, 200)
            print (matrix)
            plt.rcParams["figure.figsize"] = (20, 1)
            frame = pd.DataFrame(matrix)
            frame.index = nucleotides
            sns.heatmap(frame, linewidth=1, cmap='bwr', center=0)
            plt.show()

class SequenceDataset(Dataset):
    """Map-style PyTorch dataset pairing each sequence with its label.

    Args:
        seqs: indexable collection of samples.
        c: indexable collection of labels aligned with ``seqs``.
        transform: optional callable applied to a sample before it is
            returned; when ``None`` the sample is returned unchanged.
    """

    def __init__(self, seqs, c, transform=None):
        """Store the samples, labels and optional transform."""
        self.seqs = seqs
        self.c = c
        self.transform = transform

    def __len__(self):
        """Return the total number of samples."""
        return len(self.seqs)

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair stored at ``index``."""
        sample = self.seqs[index]

        # ``transform`` defaults to None, so it must be optional here as well;
        # calling it unconditionally raised TypeError for the default value.
        if self.transform is not None:
            sample = self.transform(sample)

        return sample, self.c[index]


#### Metrics

In [None]:
def sampling_to_metric(number_of_samples=20, specific_group=False, group_number=None, cond_weight_to_metric=0):
  # Sampling regions using the trained  model
    final_sequences=[]
    for n_a in tqdm_notebook(range(number_of_samples)): # generating 20*10 sequences
      #sampled_images = bit_diffusion.sample(batch_size = 4)
        sample_bs = 10
        if specific_group:
            sampled = torch.from_numpy(np.array([group_number] * sample_bs) )
            print ('specific')
        else:
            sampled = torch.from_numpy(np.random.choice(cell_types, sample_bs))
            
        
#         random_classes = torch.zeros((sample_bs, TOTAL_class_number))
#         random_classes = random_classes.scatter_(1, sampled.unsqueeze(dim=1), 1).float().cuda()
        random_classes = sampled.float().cuda()
        sampled_images = sample(model, classes=random_classes, image_size=image_size, batch_size=sample_bs, channels=1, cond_weight=cond_weight_to_metric, \
            betas=betas, sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, sqrt_recip_alphas=sqrt_recip_alphas, posterior_variance=posterior_variance)
        #sampled_images = sampled_images
        for n_b, x in enumerate(sampled_images[-1]):
            #x = x[-1]
            #print(x.shape)
            seq_final = f'>seq_test_{n_a}_{n_b}\n' +''.join([nucleotides[s] for s in np.argmax(x.reshape(4,200), axis=0)]) 
            final_sequences.append(seq_final)

    save_motifs_syn = open('synthetic_motifs.fasta', 'w')

    save_motifs_syn.write('\n'.join(final_sequences))
    save_motifs_syn.close()
    #Scan for motifs
    !gimme scan synthetic_motifs.fasta -p   JASPAR2020_vertebrates -g hg38 > syn_results_motifs.bed
    df_results_syn = pd.read_csv('syn_results_motifs.bed', sep='\t', skiprows=5, header=None)
    df_results_syn['motifs'] = df_results_syn[8].apply(lambda x: x.split( 'motif_name "')[1].split('"')[0]   )
    df_results_syn[0] = df_results_syn[0].apply(lambda x : '_'.join(  x.split('_')[:-1])    )
    df_motifs_count_syn = df_results_syn[[0,'motifs']].drop_duplicates().groupby('motifs').count()
    plt.rcParams["figure.figsize"] = (30,2)
    df_motifs_count_syn.sort_values(0, ascending=False).head(50)[0].plot.bar()
    plt.show()

    return df_motifs_count_syn


In [None]:
# Works on the occurrence count per motif (i.e. the share of sequences containing a given motif), not on the total number of motif hits.
def compare_motif_list(df_motifs_a, df_motifs_b):
    """Compare two motif-count tables and return their KL divergence.

    Both frames are indexed by motif name and hold the count in column ``0``.
    A motif missing from one table gets a count of 1 there. Counts are
    normalised to probabilities, plotted against each other, displayed, and
    the sum of ``rel_entr(a, b)`` is returned.
    """
    def _count_or_one(frame, motif):
        # Count stored for ``motif``; 1 stands in when the motif is absent.
        if motif in frame.index:
            return frame.loc[motif][0]
        return 1

    all_motifs = set(df_motifs_a.index.values.tolist() + df_motifs_b.index.values.tolist())
    rows = [
        [motif, _count_or_one(df_motifs_a, motif), _count_or_one(df_motifs_b, motif)]
        for motif in all_motifs
    ]

    df_motifs = pd.DataFrame(rows, columns=['motif', 'motif_a', 'motif_b'])

    total_a = df_motifs['motif_a'].sum()
    df_motifs['Diffusion_seqs'] = df_motifs['motif_a'] / total_a
    total_b = df_motifs['motif_b'].sum()
    df_motifs['Training_seqs'] = df_motifs['motif_b'] / total_b

    plt.rcParams["figure.figsize"] = (3, 3)
    sns.regplot(x='Diffusion_seqs', y='Training_seqs', data=df_motifs)
    plt.xlabel('Diffusion Seqs')
    plt.ylabel('Training Seqs')
    plt.title('Motifs Probs')
    plt.show()

    display(df_motifs)
    p_values = df_motifs['Diffusion_seqs'].values
    q_values = df_motifs['Training_seqs'].values
    return np.sum(rel_entr(p_values, q_values))


In [None]:
def kl_comparison_between_dataset(first_dic, second_dict):
    """Return the matrix of pairwise ``compare_motif_list`` results.

    Row ``i`` corresponds to the i-th value of ``first_dic`` and column ``j``
    to the j-th key of ``second_dict`` (both in dict iteration order).
    """
    return [
        [compare_motif_list(motifs_first, second_dict[key_second]) for key_second in second_dict]
        for motifs_first in first_dic.values()
    ]


#### Other

In [None]:
class EMA:  #https://github.com/dome272/Diffusion-Models-pytorch/blob/main/modules.py
    """Exponential moving average of a model's parameters.

    ``beta`` is the weight kept from the running average on each update;
    ``step`` counts how many times ``step_ema`` has been called.
    """

    def __init__(self, beta):
        super().__init__()
        self.beta = beta
        self.step = 0

    def update_model_average(self, ma_model, current_model):
        """Blend every parameter of ``current_model`` into ``ma_model`` in place."""
        pairs = zip(current_model.parameters(), ma_model.parameters())
        for source_param, averaged_param in pairs:
            averaged_param.data = self.update_average(averaged_param.data, source_param.data)

    def update_average(self, old, new):
        """Return ``beta * old + (1 - beta) * new``, or ``new`` when ``old`` is None."""
        return new if old is None else old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        """Advance one step: copy weights during warm-up, average afterwards."""
        if self.step >= step_start_ema:
            self.update_model_average(ema_model, model)
        else:
            # Warm-up: keep the EMA model identical to the live model.
            self.reset_parameters(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        """Overwrite ``ema_model``'s state with ``model``'s state dict."""
        ema_model.load_state_dict(model.state_dict())


In [None]:
@torch.no_grad()
def p_sample(model, x, t, t_index):
    """Run one unconditional reverse-diffusion step on ``x`` at timestep ``t``.

    Reads the schedule tensors ``betas``, ``sqrt_one_minus_alphas_cumprod``,
    ``sqrt_recip_alphas`` and ``posterior_variance`` from module-level
    globals. Returns the model mean when ``t_index`` is 0, otherwise the mean
    plus Gaussian noise scaled by the posterior standard deviation.
    """
    beta_t = extract(betas, t, x.shape)
    sqrt_one_minus_acp_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_recip_alpha_t = extract(sqrt_recip_alphas, t, x.shape)

    # Equation 11 in the paper: the noise predictor gives the posterior mean.
    predicted_noise = model(x, time=t)
    model_mean = sqrt_recip_alpha_t * (x - beta_t * predicted_noise / sqrt_one_minus_acp_t)

    if t_index == 0:
        return model_mean

    # Algorithm 2 line 4
    variance_t = extract(posterior_variance, t, x.shape)
    return model_mean + torch.sqrt(variance_t) * torch.randn_like(x)

@torch.no_grad()
def p_sample_guided(model, x, classes, t, t_index, context_mask, betas, sqrt_one_minus_alphas_cumprod, sqrt_recip_alphas, posterior_variance, cond_weight=0.0):
    """Run one classifier-free-guided reverse-diffusion step.

    Adapted from https://openreview.net/pdf?id=qw8AKxfYbI. ``x`` and ``t`` are
    duplicated so the model is evaluated once on a doubled batch; the first
    half of the prediction is treated as conditional and the second half as
    unconditional. ``classes`` and ``context_mask`` are multiplied
    element-wise and fed to the doubled batch, so they are presumably already
    of doubled length with the mask zeroed on the second half - confirm
    against the caller.

    Returns the model mean when ``t_index`` is 0, otherwise the mean plus
    Gaussian noise scaled by the posterior standard deviation.
    """
    n = x.shape[0]

    # Duplicate inputs: one copy for the guided pass, one for the unguided pass.
    t_pair = t.repeat(2)
    x_pair = x.repeat(2, 1, 1, 1)
    beta_t = extract(betas, t_pair, x_pair.shape)
    sqrt_one_minus_acp_t = extract(sqrt_one_minus_alphas_cumprod, t_pair, x_pair.shape)
    sqrt_recip_alpha_t = extract(sqrt_recip_alphas, t_pair, x_pair.shape)

    # Masked-out entries become class 0; the cast back to long is needed
    # because multiplying by the mask changes the dtype.
    masked_classes = (classes * context_mask).type(torch.long)
    noise_pred = model(x_pair, time=t_pair, classes=masked_classes)

    # Classifier-free guidance: extrapolate from the second-half prediction
    # towards the first-half prediction by ``cond_weight``.
    guided_noise = (1 + cond_weight) * noise_pred[:n] - cond_weight * noise_pred[n:]

    # Equation 11 in the paper, using the guided noise estimate.
    model_mean = sqrt_recip_alpha_t[:n] * (
        x - beta_t[:n] * guided_noise / sqrt_one_minus_acp_t[:n]
    )

    if t_index == 0:
        return model_mean

    # Algorithm 2 line 4
    variance_t = extract(posterior_variance, t, x.shape)
    return model_mean + torch.sqrt(variance_t) * torch.randn_like(x)

@torch.no_grad()
def p_sample_loop(model, classes, shape, cond_weight, timesteps, betas, sqrt_one_minus_alphas_cumprod, sqrt_recip_alphas, posterior_variance):
    """Run the full reverse process from pure noise and return every intermediate.

    With ``classes`` given, each step uses ``p_sample_guided`` on a doubled
    class vector whose second half is masked out; without classes it falls
    back to ``p_sample``.

    Returns:
        List with one NumPy array per timestep, ordered from ``timesteps - 1``
        down to 0 (the last entry is the final sample).
    """
    device = next(model.parameters()).device
    batch = shape[0]

    if classes is None:
        sampling_fn = partial(p_sample)
    else:
        n_sample = classes.shape[0]
        context_mask = torch.ones_like(classes).to(device)
        # Double the batch; the second half is made context free.
        classes = classes.repeat(2)
        context_mask = context_mask.repeat(2)
        context_mask[n_sample:] = 0.
        sampling_fn = partial(
            p_sample_guided,
            classes=classes,
            cond_weight=cond_weight,
            context_mask=context_mask,
            betas=betas,
            sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod,
            sqrt_recip_alphas=sqrt_recip_alphas,
            posterior_variance=posterior_variance,
        )

    # Start from pure noise for each example in the batch.
    img = torch.randn(shape, device=device)
    imgs = []
    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        step_t = torch.full((batch,), i, device=device, dtype=torch.long)
        img = sampling_fn(model, x=img, t=step_t, t_index=i)
        imgs.append(img.cpu().numpy())
    return imgs

@torch.no_grad()
def sample(model, image_size, classes=None, batch_size=16, channels=3, cond_weight=0, timesteps=50, betas=None, sqrt_one_minus_alphas_cumprod=None, sqrt_recip_alphas=None, posterior_variance=None):
    """Sample a batch of shape ``(batch_size, channels, 4, image_size)`` via ``p_sample_loop``.

    All schedule tensors and the guidance weight are forwarded unchanged; the
    return value is whatever ``p_sample_loop`` returns.
    """
    shape = (batch_size, channels, 4, image_size)
    return p_sample_loop(
        model,
        classes=classes,
        shape=shape,
        cond_weight=cond_weight,
        timesteps=timesteps,
        betas=betas,
        sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod,
        sqrt_recip_alphas=sqrt_recip_alphas,
        posterior_variance=posterior_variance,
    )

def cosine_beta_schedule(timesteps, s=0.008):
    """Return ``timesteps`` betas from the cosine schedule.

    Cosine schedule as proposed in https://arxiv.org/abs/2102.09672; the
    result is clipped to ``[0.0001, 0.9999]``.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1)
    phase = ((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5
    cumulative = torch.cos(phase) ** 2
    cumulative = cumulative / cumulative[0]
    ratio = cumulative[1:] / cumulative[:-1]
    return torch.clip(1 - ratio, 0.0001, 0.9999)

def linear_beta_schedule(timesteps, beta_end=0.005):
    """Return ``timesteps`` betas spaced linearly from 0.0001 to ``beta_end``."""
    return torch.linspace(0.0001, beta_end, timesteps)

def quadratic_beta_schedule(timesteps):
    """Return ``timesteps`` betas from 0.0001 to 0.02, linear in the square root."""
    sqrt_start, sqrt_end = 0.0001 ** 0.5, 0.02 ** 0.5
    roots = torch.linspace(sqrt_start, sqrt_end, timesteps)
    return roots ** 2

def sigmoid_beta_schedule(timesteps):
    """Return ``timesteps`` betas following a sigmoid ramp from 0.001 to 0.02.

    The sigmoid is evaluated on an even grid over ``[-6, 6]`` and rescaled to
    the ``[0.001, 0.02]`` range.
    """
    low, high = 0.001, 0.02
    ramp = torch.sigmoid(torch.linspace(-6, 6, timesteps))
    return ramp * (high - low) + low

def extract(a, t, x_shape):
    """Gather ``a[t]`` per batch element and shape it to broadcast over ``x_shape``.

    The result has shape ``(len(t), 1, ..., 1)`` with as many dimensions as
    ``x_shape`` and lives on ``t``'s device. The gather itself runs on a CPU
    copy of ``t``.
    """
    gathered = a.gather(-1, t.cpu())
    trailing_ones = (1,) * (len(x_shape) - 1)
    return gathered.reshape(t.shape[0], *trailing_ones).to(t.device)

# forward diffusion
def q_sample(x_start, t, alphas_cumprod, noise=None):
    """Forward diffusion: noise ``x_start`` to timestep ``t``.

    Returns ``sqrt(alphas_cumprod[t]) * x_start + sqrt(1 - alphas_cumprod[t]) * noise``.
    Fresh Gaussian noise is drawn when ``noise`` is None.
    """
    if noise is None:
        noise = torch.randn_like(x_start)

    signal_scale = extract(torch.sqrt(alphas_cumprod), t, x_start.shape)
    noise_scale = extract(torch.sqrt(1. - alphas_cumprod), t, x_start.shape)

    return signal_scale * x_start + noise_scale * noise

def p_losses(denoise_model, x_start, t, classes, alphas_cumprod, noise=None, loss_type="l1", p_uncond=0.1):
    """Compute the noise-prediction loss for one training batch.

    ``x_start`` is noised to timestep ``t`` with ``q_sample``. Each class
    label is independently replaced by 0 with probability ``p_uncond``
    (classifier-free guidance training). The model's predicted noise is then
    compared with the true noise.

    Args:
        loss_type: ``'l1'``, ``'l2'`` or ``'huber'``.

    Raises:
        NotImplementedError: for any other ``loss_type`` (after the model has run).
    """
    device = x_start.device
    if noise is None:
        noise = torch.randn_like(x_start)  # Gaussian noise
    # Noised input for timestep t, built from the known noise.
    x_noisy = q_sample(x_start=x_start, t=t, alphas_cumprod=alphas_cumprod, noise=noise)

    # Mask for unconditional guidance: 1 keeps the label, 0 drops it.
    keep_prob = torch.zeros(classes.shape[0]) + (1 - p_uncond)
    context_mask = torch.bernoulli(keep_prob).to(device)

    # nn.Embedding needs long indices; multiplying by the mask changes the dtype.
    masked_classes = (classes * context_mask).type(torch.long)
    predicted_noise = denoise_model(x_noisy, t, masked_classes)

    if loss_type == 'l1':
        return F.l1_loss(noise, predicted_noise)
    if loss_type == 'l2':
        return F.mse_loss(noise, predicted_noise)
    if loss_type == "huber":
        return F.smooth_l1_loss(noise, predicted_noise)
    raise NotImplementedError()


In [None]:
def exists(x):
    """Return True unless ``x`` is None."""
    if x is None:
        return False
    return True

def default(val, d):
    """Return ``val`` unless it is None, else the fallback ``d``.

    A callable ``d`` is invoked (lazily) to produce the fallback.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val

def cycle(dl):
    """Yield the items of ``dl`` forever, re-iterating it each time it is exhausted."""
    while True:
        yield from dl

def has_int_squareroot(num):
    """Return True when squaring the float square root of ``num`` gives ``num`` back.

    The check is done in floating point, so it is only reliable for values
    whose square root is exactly representable.
    """
    root = math.sqrt(num)
    return (root ** 2) == num

def num_to_groups(num, divisor):
    """Split ``num`` into chunks of ``divisor`` plus a final smaller remainder chunk.

    Example: ``num_to_groups(10, 3)`` gives ``[3, 3, 3, 1]``.
    """
    full_groups, leftover = divmod(num, divisor)
    sizes = [divisor] * full_groups
    if leftover > 0:
        sizes.append(leftover)
    return sizes

def convert_image_to(img_type, image):
    """Return ``image`` converted to mode ``img_type``, or ``image`` itself if it already has that mode."""
    return image if image.mode == img_type else image.convert(img_type)

def l2norm(t):
    """Normalise ``t`` along its last dimension with ``F.normalize``."""
    normalised = F.normalize(t, dim=-1)
    return normalised

def default(val, d):
    """Return ``val`` when it is not None; otherwise ``d()`` if ``d`` is callable, else ``d``.

    Re-defines the identical helper declared earlier in this file.
    """
    if exists(val):
        return val
    if callable(d):
        return d()
    return d

def log(t, eps = 1e-20):
    """Return the natural log of ``t`` after clamping it to at least ``eps``."""
    clamped = torch.clamp(t, min=eps)
    return torch.log(clamped)

def right_pad_dims_to(x, t):
    """Append trailing size-1 dimensions to ``t`` until it has as many dims as ``x``.

    ``t`` is returned unchanged when it already has at least ``x.ndim`` dims.
    """
    missing = x.ndim - t.ndim
    if missing > 0:
        return t.view(*t.shape, *((1,) * missing))
    return t

def beta_linear_log_snr(t):
    """Return ``-log(expm1(1e-4 + 10 * t**2))`` element-wise for tensor ``t``.

    Uses ``torch.expm1`` explicitly: the bare name ``expm1`` is not imported
    anywhere in the notebook's import cell, so the original call raised
    ``NameError``.
    """
    return -torch.log(torch.expm1(1e-4 + 10 * (t ** 2)))

def alpha_cosine_log_snr(t, s: float = 0.008):
    """Return ``-log(cos((t + s) / (1 + s) * pi / 2) ** -2 - 1)`` with the log clamped at 1e-5."""
    angle = (t + s) / (1 + s) * math.pi * 0.5
    # not sure if this accounts for beta being clipped to 0.999 in discrete version
    return -log((torch.cos(angle) ** -2) - 1, eps = 1e-5)

def log_snr_to_alpha_sigma(log_snr):
    """Return ``(sqrt(sigmoid(log_snr)), sqrt(sigmoid(-log_snr)))``."""
    alpha = torch.sqrt(torch.sigmoid(log_snr))
    sigma = torch.sqrt(torch.sigmoid(-log_snr))
    return alpha, sigma


In [None]:
def count_parameters(model):
    """Return the total number of elements across ``model``'s trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Plotting images for a forward pass example
def plot_images(dataloader, nucleotides, timesteps, alphas_cumprod=None):
    """Plot forward-diffusion snapshots of the first item of every batch.

    For every 10th timestep below ``timesteps`` the first item of the batch
    is noised with ``q_sample`` and shown as a heatmap (one row per entry of
    ``nucleotides``), followed by a bar chart of the per-row mean.

    Args:
        dataloader: iterable yielding ``(batch, labels)`` pairs.
        nucleotides: row labels; its length sets the number of heatmap rows.
        timesteps: exclusive upper bound of the timesteps to visualise.
        alphas_cumprod: schedule tensor forwarded to ``q_sample``. It is
            required by ``q_sample``; the previous version did not pass it
            and therefore failed with ``TypeError`` on the first plot.

    Raises:
        ValueError: if a plot is needed but ``alphas_cumprod`` was not given.
    """
    for img_x_show, y in dataloader:
        # Step by 10 directly instead of testing ``i % 10`` on every index.
        for i in range(0, timesteps, 10):
            if alphas_cumprod is None:
                raise ValueError('plot_images requires alphas_cumprod to call q_sample')
            print (i)
            image_use = q_sample(img_x_show[0], t=torch.tensor([i]), alphas_cumprod=alphas_cumprod)
            image_use_numpy = image_use.numpy()

            plt.rcParams["figure.figsize"] = (20,1)
            # One row per nucleotide; the column count follows from the data.
            pd_seq = pd.DataFrame(image_use_numpy.reshape(len(nucleotides), -1))

            pd_seq.index = nucleotides
            sns.heatmap(pd_seq, linewidth=1, cmap='bwr', center=0)
            plt.show()
            plt.rcParams["figure.figsize"] = (2,2)

            plt.bar(nucleotides, pd_seq.mean(1).T)
            plt.show()


# Models

In [None]:
# ADDED FOR HYPERNETWORK CONDITIONING
class ParamFreeLinear(nn.Module):
    """Linear layer without own parameters; weight and bias are passed to ``forward``.

    Args:
        act_func: module applied to the input before the linear map. Defaults
            to a fresh ``nn.Identity()`` per instance (``None`` sentinel), so
            instances no longer share one default module object.
    """

    def __init__(self, act_func=None):
        super().__init__()
        self.act_func = nn.Identity() if act_func is None else act_func

    def forward(self, x, W, b):
        """Return ``linear(act_func(x), W, b)``."""
        return F.linear(self.act_func(x), weight=W, bias=b)

class ParamFreeConv2d(nn.Module):
    """2-D convolution without own parameters; kernel and bias are passed to ``forward``.

    ``kernel_size`` is stored but not used by ``forward``; the kernel's shape
    comes from the supplied weight tensor.
    """

    def __init__(self, kernel_size, stride=1, padding=0, dilation=1, groups=1):
        super().__init__()
        self.kernel_size, self.stride, self.padding = kernel_size, stride, padding
        self.dilation, self.groups = dilation, groups

    def forward(self, x, W, b):
        """Convolve ``x`` with weight ``W`` and bias ``b`` using the stored hyper-parameters."""
        return F.conv2d(x, W, b, self.stride, self.padding, self.dilation, self.groups)

class ParamFreeGroupNorm(nn.Module):
    """Group normalisation without own parameters; scale and shift are passed to ``forward``."""

    def __init__(self, num_groups):
        super().__init__()
        self.num_groups = num_groups

    def forward(self, x, W, b):
        """Group-normalise ``x`` over ``num_groups`` groups with affine weight ``W`` and bias ``b``."""
        return F.group_norm(x, self.num_groups, W, b)


In [None]:
class Residual(nn.Module):
    """Wrap ``fn`` with a skip connection: ``forward(x, ...) = fn(x, ...) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        """Apply the wrapped callable and add the input back."""
        transformed = self.fn(x, *args, **kwargs)
        return transformed + x

def Upsample(dim, dim_out = None):
    """Build a 2x nearest-neighbour upsampling followed by a 3x3 convolution.

    The convolution maps ``dim`` channels to ``dim_out`` (``dim`` when omitted).
    """
    out_channels = default(dim_out, dim)
    layers = [
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv2d(dim, out_channels, 3, padding = 1),
    ]
    return nn.Sequential(*layers)

def Downsample(dim, dim_out = None):
    """Build a stride-2 convolution (kernel 4, padding 1) from ``dim`` to ``dim_out`` channels.

    ``dim_out`` falls back to ``dim`` when omitted.
    """
    out_channels = default(dim_out, dim)
    return nn.Conv2d(dim, out_channels, kernel_size=4, stride=2, padding=1)

class LayerNorm(nn.Module):
    """Normalise over the channel dimension (dim 1) with a learned per-channel gain ``g``."""

    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        """Return ``(x - mean) / sqrt(var + eps) * g`` with statistics taken over dim 1."""
        # Larger epsilon for anything that is not float32.
        eps = 1e-3 if x.dtype != torch.float32 else 1e-5
        mean = torch.mean(x, dim = 1, keepdim = True)
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        centered = x - mean
        return centered * (var + eps).rsqrt() * self.g

class PreNorm(nn.Module):
    """Apply the channel ``LayerNorm`` to the input before calling ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        """Return ``fn(norm(x))``."""
        return self.fn(self.norm(x))

# positional embeds
class LearnedSinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D batch of scalars with learned frequencies.

    Following @crowsonkb's lead with learned sinusoidal pos emb:
    https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8

    The output has ``dim + 1`` features per element: the raw input followed
    by ``dim / 2`` sines and ``dim / 2`` cosines.
    """

    def __init__(self, dim):
        super().__init__()
        assert (dim % 2) == 0
        self.weights = nn.Parameter(torch.randn(dim // 2))

    def forward(self, x):
        """Embed ``x`` of shape ``(b,)`` into shape ``(b, dim + 1)``."""
        x = rearrange(x, 'b -> b 1')
        angles = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        features = (x, angles.sin(), angles.cos())
        return torch.cat(features, dim = -1)

# MODIFIED FOR HYPERNETWORK CONDITIONING
class Block(nn.Module):
    """3x3 convolution, group norm, optional scale/shift, then SiLU.

    With ``no_params`` the convolution and norm hold no parameters; ``forward``
    then expects ``params = [conv_W, conv_b, norm_W, norm_b]``.
    """

    def __init__(self, dim, dim_out, no_params=False, groups = 8):
        super().__init__()
        self.no_params = no_params

        if no_params is not False:
            # Parameters are provided externally on every forward call.
            self.proj = ParamFreeConv2d(kernel_size=3, padding=1)
            self.norm = ParamFreeGroupNorm(num_groups=groups)
        else:
            self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
            self.norm = nn.GroupNorm(groups, dim_out)

        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None, params=None):
        """Run the block; ``scale_shift`` is an optional ``(scale, shift)`` pair applied after the norm."""
        assert self.no_params is False or params is not None, \
            "no_params is True but parameters not provided"

        if self.no_params:
            x = self.proj(x, W=params[0], b=params[1])
            x = self.norm(x, W=params[2], b=params[3])
        else:
            x = self.proj(x)
            x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        return self.act(x)

# MODIFIED FOR HYPERNETWORK CONDITIONING
class ResnetBlock(nn.Module):
    """Two ``Block``s with an optional time-conditioned scale/shift and a residual path.

    Args:
        dim: input channels.
        dim_out: output channels.
        time_emb_dim: size of the time embedding; when ``None`` the block has
            no time MLP and applies no scale/shift.
        no_params: when true, every sub-layer is parameter-free and
            ``forward`` must be given ``params`` laid out as
            ``[mlp_W, mlp_b, block1 (4 tensors), block2 (4 tensors),
            res_conv_W, res_conv_b]``; the last two are only read when
            ``dim != dim_out``.
        groups: number of groups for the group norms.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim = None, no_params=False, groups = 8):
        super().__init__()
        self.no_params = no_params

        # ``self.mlp`` must always exist because ``forward`` reads it; it was
        # previously left undefined when ``time_emb_dim`` was None, which made
        # ``forward`` raise AttributeError.
        if not exists(time_emb_dim):
            self.mlp = None
        elif no_params is False:
            self.mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_emb_dim, dim_out * 2)
            )
        else:
            self.mlp = ParamFreeLinear(act_func=nn.SiLU()) # parameters provided externally

        self.block1 = Block(dim, dim_out, no_params=no_params, groups=groups)
        self.block2 = Block(dim_out, dim_out, no_params=no_params, groups=groups)

        self.res_conv = nn.Identity()
        if dim != dim_out:
            if no_params is False:
                self.res_conv = nn.Conv2d(dim, dim_out, 1)
            else:
                self.res_conv = ParamFreeConv2d(kernel_size=1) # parameters provided externally

    def forward(self, x, time_emb=None, params=None):
        """Apply the block to ``x``; ``time_emb`` (if given) is split into scale and shift for ``block1``."""
        assert self.no_params is False or params is not None, \
            "no_params is True but parameters not provided"

        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb) if not self.no_params else self.mlp(time_emb, W=params[0], b=params[1])
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            scale_shift = time_emb.chunk(2, dim = 1)

        if not self.no_params:
            h = self.block1(x, scale_shift=scale_shift)
            h = self.block2(h)
            return h + self.res_conv(x)

        h = self.block1(x, scale_shift=scale_shift, params=params[2:6])
        h = self.block2(h, params=params[6:10])
        if isinstance(self.res_conv, ParamFreeConv2d):
            return h + self.res_conv(x, W=params[10], b=params[11])
        # dim == dim_out: the shortcut is nn.Identity, which takes no W/b
        # (passing them raised TypeError).
        return h + self.res_conv(x)

class LinearAttention(nn.Module):
    """Multi-head linear attention over the flattened spatial positions of a feature map.

    Queries are soft-maxed over the feature axis and keys over the position
    axis, so a per-head context matrix is built first and then applied to the
    queries. The output projection is a 1x1 convolution followed by the
    channel ``LayerNorm``.
    """

    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            LayerNorm(dim)
        )

    def forward(self, x):
        """Return the attended feature map with the same shape as ``x``."""
        _, _, height, width = x.shape
        q, k, v = [
            rearrange(part, 'b (h c) x y -> b h c (x y)', h = self.heads)
            for part in self.to_qkv(x).chunk(3, dim = 1)
        ]

        q = q.softmax(dim = -2)
        k = k.softmax(dim = -1)

        q = q * self.scale
        v = v / (height * width)

        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)

        attended = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        attended = rearrange(attended, 'b h c (x y) -> b (h c) x y', h = self.heads, x = height, y = width)
        return self.to_out(attended)

class Attention(nn.Module):
    """Multi-head attention over spatial positions using L2-normalised queries and keys.

    Similarities between the normalised queries and keys are multiplied by
    the fixed ``scale`` before the softmax.
    """

    def __init__(self, dim, heads = 4, dim_head = 32, scale = 10):
        super().__init__()
        self.scale = scale
        self.heads = heads
        inner_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def forward(self, x):
        """Return the attended feature map with the same shape as ``x``."""
        _, _, height, width = x.shape
        q, k, v = [
            rearrange(part, 'b (h c) x y -> b h c (x y)', h = self.heads)
            for part in self.to_qkv(x).chunk(3, dim = 1)
        ]

        q = l2norm(q)
        k = l2norm(k)

        sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)
        attended = einsum('b h i j, b h d j -> b h i d', attn, v)
        attended = rearrange(attended, 'b h (x y) d -> b (h d) x y', x = height, y = width)
        return self.to_out(attended)


In [None]:
# bit diffusion class
class Unet_lucas(nn.Module):
    """U-Net noise predictor with an optional hypernetwork-generated output head.

    The down path, bottleneck and up path always own their parameters. When
    ``use_hypernet`` is true, the final ResNet block and the final 1x1
    convolution are parameter-free modules; their weights are generated per
    class id by a ``ChunkedHMLP`` hypernetwork inside ``forward``.
    """
    def __init__(
        self,
        dim,
        init_dim = None,
        dim_mults=(1, 2, 4),
        channels = 1,
        resnet_block_groups = 8,
        learned_sinusoidal_dim = 18,
        num_classes=10,
        class_embed_dim=3,
        use_hypernet=False,
    ):
        """Build all layers.

        Args:
            dim: base channel width; stage widths are ``dim * m`` for ``m`` in ``dim_mults``.
            init_dim: channels produced by the first convolution (defaults to ``dim``).
            dim_mults: width multipliers, one per resolution stage.
            channels: see the review note below - currently overwritten with 1.
            resnet_block_groups: group count for every group norm.
            learned_sinusoidal_dim: size of the learned sinusoidal time embedding.
            num_classes: number of conditioning classes; ``None`` skips both the
                hypernetwork and the label embedding.
            class_embed_dim: not referenced anywhere in this class.
            use_hypernet: generate the final block's parameters with a hypernetwork.
        """
        super().__init__()

        ### determine dimensions
        # NOTE(review): this overwrites the ``channels`` argument, so the value
        # passed to the constructor is ignored - confirm this is intended.
        channels = 1
        self.channels = channels
        # if you want to do self conditioning uncomment this
        #input_channels = channels * 2
        input_channels = channels
        init_dim = default(init_dim, dim)

        #self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3) # original
        self.init_conv = nn.Conv2d(input_channels, init_dim, (7,7), padding = 3)

        # Channel widths per stage and the (in, out) pair of every transition.
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        # Regular blocks own their weights; the param-free variant receives them in forward().
        block_klass = partial(ResnetBlock, no_params=False, groups=resnet_block_groups)
        block_klass_param_free = partial(ResnetBlock, no_params=True, groups=resnet_block_groups)

        ### time embeddings
        time_dim = dim * 4
        # +1 because LearnedSinusoidalPosEmb concatenates the raw input to the sin/cos features.
        fourier_dim = learned_sinusoidal_dim + 1
        sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim)
        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(fourier_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        ### layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        # Down path: two ResNet blocks, linear attention, then a downsample
        # (a plain 3x3 conv on the last stage, which keeps the resolution).
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            self.downs.append(nn.ModuleList([
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
            ]))

        # Bottleneck: ResNet block, full attention, ResNet block.
        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)

        # Up path mirrors the down path; each ResNet block takes the current
        # features concatenated with one skip tensor (hence dim_out + dim_in inputs).
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)

            self.ups.append(nn.ModuleList([
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                Upsample(dim_out, dim_in) if not is_last else  nn.Conv2d(dim_out, dim_in, 3, padding = 1)
            ]))

        # Output head: parameter-free when its weights come from the hypernetwork.
        self.use_hypernet = use_hypernet
        if use_hypernet:
            self.final_res_block = block_klass_param_free(dim * 2, dim, time_emb_dim = time_dim)
        else:
            self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)

        if use_hypernet:
            self.final_conv = ParamFreeConv2d(kernel_size=1)
        else:
            self.final_conv = nn.Conv2d(dim, channels, 1)
        
        ### hypernetwork for class-conditioned weight generation
        if num_classes is not None:
            if use_hypernet:
                target_shapes = self._calc_hypernet_target_shapes(resnet_block_groups=resnet_block_groups, dim=dim, time_dim=time_dim, channels=channels)
                # One conditional embedding per class (num_cond_embs); the other
                # values are hypernetwork hyper-parameters - see the hypnettorch
                # ChunkedHMLP documentation for their exact meaning.
                self.hypernet_head = ChunkedHMLP(
                    target_shapes=target_shapes,
                    chunk_size=10000,
                    chunk_emb_size=48,
                    cond_chunk_embs=True,
                    cond_in_size=12,
                    layers=[80, 120],
                    num_cond_embs=num_classes
                )
            else:
                # NOTE(review): label_emb is created here but is not referenced in
                # forward(), so the non-hypernet path does not use ``classes``.
                self.label_emb = nn.Embedding(num_classes, time_dim)

        print('final',dim, channels, self.final_conv)
        
    def _calc_hypernet_target_shapes(self, resnet_block_groups, dim, time_dim, channels):
        """Return the parameter shapes the hypernetwork has to generate.

        The list holds the shapes of a regular final ResNet block (in
        ``parameters()`` order) followed by the weight and bias of the final
        1x1 convolution; ``forward`` slices the generated list the same way.
        """
        target_shapes = []
        
        # initialize modules, get shapes and then delete them (ugly but works)
        block_klass = partial(ResnetBlock, no_params=False, groups=resnet_block_groups)
        tmp_final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim)
        target_shapes.extend([list(p.shape) for p in tmp_final_res_block.parameters()])
        del tmp_final_res_block

        tmp_final_conv = nn.Conv2d(dim, channels, 1)
        target_shapes.extend([list(p.shape) for p in tmp_final_conv.parameters()])
        del tmp_final_conv

        return target_shapes

    def forward(self, x, time, classes, x_self_cond = None):
        """Predict the noise for ``x`` at timesteps ``time``.

        ``classes`` is mandatory. With ``use_hypernet`` it selects the
        generated parameters of the output head per sample; otherwise it is
        not used by the computation below. ``x_self_cond`` is unused (the
        self-conditioning code is commented out).
        """
        assert classes is not None, 'classes must be provided'
        #x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
        #x = torch.cat((x_self_cond, x), dim = 1)
        #print ('classes inside unet',classes, 'time inside unet', time)
        x = self.init_conv(x)
        # Copy of the stem output, concatenated back in before the final block.
        r = x.clone()

        # The same time embedding is used in all three stages (cloned copies).
        t_start = self.time_mlp(time)
        t_mid = t_start.clone()
        t_end = t_start.clone()

        # Skip connections: two tensors are pushed per down stage and popped
        # in reverse order on the way up.
        h = []
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t_start)#, classes)
            h.append(x)

            x = block2(x, t_start)#, classes)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t_mid)#, classes)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t_mid)#, classes)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim = 1)
            x = block1(x, t_mid)#, classes)

            x = torch.cat((x, h.pop()), dim = 1)
            x = block2(x, t_mid)#, classes)
            x = attn(x)

            x = upsample(x)
        
        x = torch.cat((x, r), dim = 1)

        if self.use_hypernet:
            # for each unique class, generate params for final res block and conv
            uniq_classes, x_idxs = torch.unique(classes, sorted=False, return_inverse=True)
            params = self.hypernet_head.forward(cond_id=uniq_classes.tolist())
            out = torch.zeros(x.shape[0], self.channels, x.shape[2], x.shape[3], device=x.device)
            # Run the output head once per class on the samples of that class;
            # the last two generated tensors are the final conv's weight and bias.
            for c_idx, (c, params_for_c) in enumerate(zip(uniq_classes, params)):
                tmp_x = self.final_res_block(x[x_idxs == c_idx], t_end[x_idxs == c_idx], params=params_for_c[:-2])#, classes)
                out[x_idxs == c_idx] = self.final_conv(tmp_x, W=params_for_c[-2], b=params_for_c[-1])              
            x = out

            ### LATER TODO: make this more efficient ###
            # 'params' is a list of lists of parameters, ex: [[W1,b1],[W2,b2],...] where W1,b1 are torch.Tensor
            #   -> group by parameter/layer destination to run only once
            # generate params for final res block and conv
            # params = self.hypernet_head.forward(cond_id=classes.tolist())
            # grouped_params_per_type = list(zip(*params)) # [[W1,W2],[b1,b2],...]
            # grouped_params_per_type = [torch.cat(p, dim=0) for p in grouped_params_per_type] # [W,b,...]
            # x = self.final_res_block(x, t_end, params=grouped_params_per_type[:-2])#, classes)
            # x = self.final_conv(x, W=params[-2], b=grouped_params_per_type[-1])
            ############################################
        else:
            x = self.final_res_block(x, t_end)
            x = self.final_conv(x)
        return x


# Loading data and generating fasta files and motifs

In [None]:
class LoadingData():
    '''Load a tab-separated sequence table, split it and scan every split for motifs.

    On construction the table is read, optionally restricted to
    `subset_components`, sampled into train / test / shuffled-train frames of
    `sample_number` rows each, and every frame is written to a FASTA file and
    scanned with the external `gimme scan` command.

    The results are stored on `self.train`, `self.test` and
    `self.train_shuffle` as dicts with the keys 'fasta_name', 'motifs',
    'motifs_per_components_dict' and 'dataset'.

    The table must provide the columns 'Unnamed: 0', 'raw_sequence' and
    'component' (all three are read when the FASTA files are written).
    '''

    def __init__(self, input_csv, sample_number=1000, subset_components=None,
                 plot=True, change_component_index=True):
        '''input_csv: path of the tab-separated input table.
           sample_number : 1000
               Number of sequences in each of the train and test splits
               (twice that many rows are sampled from the table).
           subset_components: (NONE) list
               Components to keep, ex: [3,12,4,3]. None keeps every component.
           plot: when True, show a bar plot of the component proportions.
           change_component_index: when True, add 1 to the 'component' column.
        '''
        self.csv = input_csv
        self.plot = plot
        self.sample_number = sample_number
        self.subset_components = subset_components
        self.change_comp_index = change_component_index
        self.data = self.read_csv()
        self.df_generate = self.experiment()
        self.df_train_in, self.df_test_in, self.df_train_shuffled_in = self.create_train_groups()

        self.train = None
        self.test = None
        self.train_shuffle = None
        self.get_motif()

    def read_csv(self):
        '''Read the tab-separated table; shift 'component' by +1 when requested.'''
        df = pd.read_csv(self.csv, sep="\t")
        if self.change_comp_index:
            df['component'] = df['component'] + 1
        return df

    def experiment(self):
        '''Return a copy of the table restricted to `subset_components` (if given).

        When `self.plot` is set, the proportion of rows per component is shown
        as a bar plot.
        '''
        df_generate = self.data.copy()
        if isinstance(self.subset_components, (list, tuple)):
            keep = df_generate['component'].isin(self.subset_components)
            df_generate = df_generate[keep].copy()
            print ('Subseting...')

        if self.plot:
            per_component = df_generate.groupby('component').count()['raw_sequence']
            (per_component / per_component.sum()).plot.bar()
            plt.title('Component % on Training Sample')
            plt.show()

        return df_generate

    def create_train_groups(self):
        '''Sample disjoint train / test frames plus a shuffled copy of train.

        Returns (df_train, df_test, df_train_shuffled). The shuffled frame is
        the train frame with the characters of every 'raw_sequence' randomly
        permuted. `DataFrame.sample` raises ValueError when the table holds
        fewer than `2 * sample_number` rows.
        '''
        df_sampled = self.df_generate.sample(self.sample_number*2) #getting train and test
        df_train = df_sampled.iloc[:self.sample_number].copy()
        df_test = df_sampled.iloc[self.sample_number:].copy()
        df_train_shuffled = df_train.copy()
        df_train_shuffled['raw_sequence'] = df_train_shuffled['raw_sequence'].apply(
            lambda seq: ''.join(random.sample(list(seq), len(seq))))
        return df_train, df_test, df_train_shuffled

    def get_motif(self):
        '''Write FASTA files and compute motifs for the three splits.'''
        self.train = self.generate_motifs_and_fastas(self.df_train_in, 'train')
        self.test = self.generate_motifs_and_fastas(self.df_test_in, 'test')
        self.train_shuffle = self.generate_motifs_and_fastas(self.df_train_shuffled_in, 'train_shuffle')

    def _fasta_prefix(self, name):
        '''Build the FASTA base name "<name>_<sample_number>_<components>".

        Uses "all" as the component part when no subset was requested, so the
        default `subset_components=None` no longer raises a TypeError.
        '''
        if self.subset_components is None:
            components = 'all'
        else:
            components = '_'.join(str(c) for c in self.subset_components)
        return f"{name}_{self.sample_number}_{components}"

    def generate_motifs_and_fastas(self, df, name):
        '''Return the FASTA name, overall motif counts, per-component motif counts and `df`.'''
        print ('Generating Fasta and Motis:', name)
        print ('---' * 10)
        fasta_saved = self.save_fasta(df, self._fasta_prefix(name))
        print('Generating Motifs (all seqs)')
        motif_all_components = LoadingData.motifs_from_fasta(fasta_saved, False)
        print('Generating Motifs per component')
        train_comp_motifs_dict = self.generate_motifs_components(df)
        return {'fasta_name': fasta_saved,
                'motifs': motif_all_components,
                'motifs_per_components_dict': train_comp_motifs_dict,
                'dataset': df}

    def save_fasta(self, df, name_fasta):
        '''Write `df` to "<name_fasta>.fasta" and return that file name.

        Every row becomes a record with the header
        ">{Unnamed: 0}_component_{component}" followed by its 'raw_sequence'.
        '''
        fasta_final_name = name_fasta + '.fasta'
        # Columns are addressed by label; positional access on a labelled row is deprecated in pandas.
        records = [
            f'>{seq_id}_component_{component}\n{sequence}'
            for seq_id, sequence, component
            in zip(df['Unnamed: 0'], df['raw_sequence'], df['component'])
        ]
        # The context manager closes the handle even if the write fails.
        with open(fasta_final_name, 'w') as save_fasta_file:
            save_fasta_file.write('\n'.join(records))
        return fasta_final_name

    def generate_motifs_components(self, df):
        '''Return {component: motif counts} computed separately for every component in `df`.

        Each component is written to the same 'temp_component.fasta' file,
        which is overwritten on every iteration.
        '''
        final_comp_values = {}
        for comp, v_comp in df.groupby('component'):
            print (comp)
            name_c_fasta = self.save_fasta(v_comp, 'temp_component')
            final_comp_values[comp] = LoadingData.motifs_from_fasta(name_c_fasta, False)
        return final_comp_values

    @staticmethod
    def motifs_from_fasta(fasta, generate_heatmap=True):
        '''Scan `fasta` with `gimme scan` and count the distinct sequences hit by each motif.

        The scan output is written to 'train_results_motifs.bed' (overwritten
        on every call). Returns a DataFrame indexed by motif name whose column
        0 holds the number of distinct sequence names containing that motif.
        A bar plot of the 50 most frequent motifs is shown as a side effect.

        `generate_heatmap` is currently unused. Requires the `gimme`
        executable on PATH; raises FileNotFoundError when it is missing and
        subprocess.CalledProcessError when the scan exits with an error.
        '''
        import subprocess  # local import: only this method shells out

        print ('Computing Motifs....')
        # Plain-Python equivalent of the IPython `!gimme scan ... > file` escape.
        with open('train_results_motifs.bed', 'w') as bed_file:
            subprocess.run(
                ['gimme', 'scan', str(fasta), '-p', 'JASPAR2020_vertebrates', '-g', 'hg38'],
                stdout=bed_file,
                check=True,
            )
        # The first five lines of the scan output are skipped (presumably a header - TODO confirm).
        df_results_seq_guime = pd.read_csv('train_results_motifs.bed', sep='\t', skiprows=5, header=None)
        # Column 8 carries free text that contains `motif_name "<name>"`.
        df_results_seq_guime['motifs'] = df_results_seq_guime[8].apply(
            lambda x: x.split('motif_name "')[1].split('"')[0])

        # Drop the last '_'-separated token of the sequence name in column 0.
        df_results_seq_guime[0] = df_results_seq_guime[0].apply(lambda x: '_'.join(x.split('_')[:-1]))
        df_results_seq_guime_count_out = (
            df_results_seq_guime[[0, 'motifs']].drop_duplicates().groupby('motifs').count())
        plt.rcParams["figure.figsize"] = (30,2)
        df_results_seq_guime_count_out.sort_values(0, ascending=False).head(50)[0].plot.bar()
        # NOTE(review): the title is hard-coded although this method is also called
        # for the full split and for every component.
        plt.title('Top 50 MOTIFS on component 0 ')
        plt.show()
        return df_results_seq_guime_count_out

# Build the train / test / shuffled splits (and their motif scans) for four components.
encode_data = LoadingData(
    "../../../models/vanilla_diffusion/train_all_classifier_WM20220916.csv",
    sample_number=1000,
    subset_components=[3, 8, 12, 15],
    plot=False,
)

In [None]:
# Motif counts over all sequences of each split (train, test, shuffled train).
(
    df_results_seq_guime_count_train,
    df_results_seq_guime_count_test,
    df_results_seq_guime_count_shuffle,
) = (
    split['motifs']
    for split in (encode_data.train, encode_data.test, encode_data.train_shuffle)
)

# Motif counts per component for the same three splits.
(
    final_comp_values_trian,
    final_comp_values_test,
    final_comp_values_shuffle,
) = (
    split['motifs_per_components_dict']
    for split in (encode_data.train, encode_data.test, encode_data.train_shuffle)
)

df = encode_data.train['dataset']
# Sorted list of the distinct component ids present in the training frame.
# TODO: move this inside the dataloader.
cell_components = df['component'].sort_values().unique().tolist()

In [None]:
# One "<component id> <label>" entry per component.
names_comp = [
    '7 Trophoblasts',
    '5 CD8_cells',
    '15 CD34_cells',
    '9 Fetal_heart',
    '12 Fetal_muscle',
    '14 HMVEC(vascular)',
    '3 hESC(Embryionic)',
    '8 Fetal(Neural)',
    '13 Intestine',
    '2 Skin(stromalA)',
    '4 Fibroblast(stromalB)',
    '6 Renal(Cancer)',
    '16 Esophageal(Cancer)',
    '11 Fetal_Lung',
    '10 Fetal_kidney',
    '1 Tissue_Invariant',
]

# Map the integer component id to its label.
labels_test = {
    int(comp_id): comp_label
    for comp_id, comp_label in (entry.split(' ') for entry in names_comp)
}

def generate_heatmap_components(df_heat, x_label, y_label):
    '''Draw an annotated heatmap of `df_heat` labelled with the component names.

    df_heat: square table of values (anything `pd.DataFrame` accepts) with one
        row and one column per entry of the global `cell_components`.
    x_label, y_label: names of the two sequence sets, used in the title and
        in the axis labels.

    Sets the global matplotlib figure size to 10x10 as a side effect.
    '''
    plt.rcParams["figure.figsize"] = (10,10)
    df_plot = pd.DataFrame(df_heat)
    component_names = [labels_test[component] for component in cell_components]
    df_plot.columns = component_names
    df_plot.index = df_plot.columns

    heatmap_options = dict(cmap='Blues_r', annot=True, lw=0.1, vmax=1, vmin=0)
    sns.heatmap(df_plot, **heatmap_options)

    plt.title(f'Kl divergence \n {x_label} sequences x  {y_label} sequences \n MOTIFS probabilities')
    plt.xlabel(f'{x_label} Sequences  \n(motifs dist)')
    plt.ylabel(f'{y_label} \n (motifs dist)')

# Compare the per-component motif counts of the train split against the test split.
heat_train_test = kl_comparison_between_dataset(
    final_comp_values_trian,
    final_comp_values_test,
)
generate_heatmap_components(heat_train_test, x_label='Train', y_label='Test')

In [None]:
# Compare the per-component motif counts of the train split against its shuffled copy.
heat_train_shuffle = kl_comparison_between_dataset(
    final_comp_values_trian,
    final_comp_values_shuffle,
)
generate_heatmap_components(heat_train_shuffle, x_label='Train', y_label='shuffle')

In [None]:
dna_alphabet = ['A', 'C', 'T', 'G']

# Drop sequences containing 'N' from `df` itself rather than only from the
# encoded array: the class labels are built from `df["component"]` later, so
# filtering only the sequences would leave sequences and labels misaligned.
df = df[df['raw_sequence'].map(lambda seq: 'N' not in seq)]

x_train_seq = np.array([one_hot_encode(x, dna_alphabet, 200) for x in tqdm_notebook(df['raw_sequence'])])
# Transpose every encoded sequence (assumes one_hot_encode returns a 2-D array - TODO confirm).
X_train = np.array([x.T.tolist() for x in x_train_seq])
# Recode the zeros of the one-hot encoding as -1.
X_train[X_train == 0] = -1
X_train.shape

# Training initialization and training loop

In [None]:
# Conditional training setup: class labels are the component ids of the training frame.
cell_types = sorted(df.component.unique())
print(cell_types)
# Number of classes handed to the model (presumably ids 1-16 plus an unused 0 - TODO confirm).
TOTAL_class_number = 17
x_train_cell_type = torch.from_numpy(df["component"].to_numpy())

# Data loader.
# batch_size = 256
batch_size = 14 # not optimal
seq_dataset = SequenceDataset(
    seqs=X_train,
    c=x_train_cell_type,
    transform=T.Compose([T.ToTensor()]),
)
train_dl = DataLoader(
    seq_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=3,
    pin_memory=True,
)

In [None]:
timesteps = 50

# Beta schedule (linear); the cosine variant is kept for reference.
betas = linear_beta_schedule(timesteps=timesteps, beta_end=0.2)
# betas = cosine_beta_schedule(timesteps=timesteps,  s=0.0001)

# Alphas, their running product, and that product shifted by one step
# (padded with 1.0 at the front).
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

# Terms of the forward diffusion q(x_t | x_{t-1}) and others.
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1. - alphas_cumprod).sqrt()

# Variance of the posterior q(x_{t-1} | x_t, x_0).
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

# Plotting images for a forward pass example
# plot_images(train_dl, nucleotides, timesteps)

In [None]:
# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

In [None]:
image_size = 200
channels = 1
use_hypernet = True

# Build the U-Net directly on the selected device. The previous unconditional
# `.cuda()` failed on CPU-only machines even though `device` falls back to "cpu".
model = Unet_lucas(
    dim = image_size,
    channels = channels,
    dim_mults = (1,2,4),
    resnet_block_groups = 4,
    num_classes=TOTAL_class_number,
    use_hypernet=use_hypernet
).to(device)

print(f"# of parameters of the full model: {count_parameters(model)}")
print(f"# of parameters of the hypernet alone: {0 if not use_hypernet else count_parameters(model.hypernet_head)}")

optimizer = Adam(model.parameters(), lr=1e-4)

# Placeholder KL values shown until the first evaluation, and the live plot of KL / loss.
train_kl, test_kl, shuffle_kl = 1, 1, 1
live_kl = PlotLosses(groups={ 'KL':['train', 'test', 'shuffle'], 'DiffusionLoss':['loss'] })

In [None]:
# Smoke test: take the first batch only and run one loss computation on it.
for step, batch in enumerate(train_dl):
    break
x, y = batch
x = x.float().to(device)
y = y.long().to(device)
# One random timestep per example in the batch.
t = torch.randint(0, timesteps, (x.shape[0],), device=device).long()
loss = p_losses(model, x, t, y, alphas_cumprod=alphas_cumprod, loss_type="huber")

## Metrics: Train and Test should go down (train lower than test) and shuffle should go up

In [None]:
epochs = 10000
save_and_sample_every = 50  # sample and evaluate every this many epochs
epochs_loss_show = 10       # print the loss every this many epochs

# Exponential moving average copy of the model, updated after every optimizer step.
# NOTE(review): sampling below uses `model`, not `ema_model` - confirm this is intended.
ema = EMA(0.995)
ema_model = copy.deepcopy(model).eval().requires_grad_(False)


for epoch in tqdm(range(epochs)):
    model.train()
    for step, batch in enumerate(train_dl):
        x, y = batch
        x = x.type(torch.float32).to(device)
        y = y.type(torch.long).to(device)

        batch_size = x.shape[0]
        # Algorithm 1 line 3: sample t uniformly for every example in the batch
        t = torch.randint(0, timesteps, (batch_size,), device=device).long()
        #loss = p_losses(model, batch, t, loss_type="l2")
        loss = p_losses(model, x, t, y, alphas_cumprod=alphas_cumprod, loss_type="huber")
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.step_ema(ema_model, model)
        # live_kl.update({'train':train_kl, 'test':test_kl , 'shuffle':shuffle_kl , 'loss': loss.item()})
        # live_kl.send()

    # `loss` here is the loss of the last batch of the epoch.
    if (epoch % epochs_loss_show) == 0:
        print(f" Epoch {epoch} Loss:", loss.item())

    # Periodically sample sequences, plot them and update the KL metrics.
    if epoch != 0 and epoch % save_and_sample_every == 0:
        model.eval()
        print('saving')
        # torch.save(model,"UNET_HYPERNET_components_3_8_12_15.model")
        # Count milestones in epochs: `step` is the batch index within the epoch
        # and restarts at 0 every epoch.
        milestone = epoch // save_and_sample_every
        sample_bs = 2
        #This needs to be fixed to the random
        sampled = torch.from_numpy(np.random.choice(cell_types, sample_bs))
        # Follow the selected device instead of assuming CUDA is available.
        random_classes = sampled.to(device)

        samples = sample(
            model,
            classes=random_classes,
            image_size=image_size,
            batch_size=sample_bs,  # keep in sync with the number of sampled classes
            channels=1,
            cond_weight=1,
            timesteps=timesteps,
            betas=betas,
            sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod,
            sqrt_recip_alphas=sqrt_recip_alphas,
            posterior_variance=posterior_variance,
        )
        # Show at most four of the final-step samples.
        for n_print, (image, class_show) in enumerate(zip(samples[-1], random_classes)):
            if n_print >= 4:
                break
            plt.rcParams["figure.figsize"] = (20,1)
            pd_seq = pd.DataFrame(image.reshape(4,200))
            pd_seq.index = nucleotides
            sns.heatmap(pd_seq, linewidth=1, cmap='bwr', center=0)
            plt.show()
            plt.rcParams["figure.figsize"] = (2,2)

            plt.bar(['a', 'c', 't', 'g'],pd_seq.mean(1).T)
            plt.title(f'Class: {class_show}')
            plt.show()

        # Motif-based KL of freshly sampled sequences against the three reference splits.
        synt_df = sampling_to_metric(20)
        train_kl = compare_motif_list(synt_df, df_results_seq_guime_count_train)
        test_kl  = compare_motif_list(synt_df, df_results_seq_guime_count_test)
        shuffle_kl = compare_motif_list(synt_df, df_results_seq_guime_count_shuffle)
        live_kl.update({'train':train_kl, 'test':test_kl , 'shuffle':shuffle_kl , 'loss': loss.item()})
        live_kl.send()
        print('KL_TRAIN', train_kl  , 'KL' )
        print('KL_TEST',  test_kl  , 'KL' )
        print('KL_SHUFFLE',  shuffle_kl , 'KL' )


In [None]:
# model = torch.load('UNET_HYPERNET_components_3_8_12_15.model').to(device)

## Module metric by component

In [None]:
def kl_comparison_generated_sequences(components_list, dict_targer_components):
    '''Compare sequences generated for each component against every reference component.

    components_list: component ids to generate for and to compare against,
        ex: [3, 8, 12, 15]
    dict_targer_components: mapping from component id to the reference motif
        table passed to `compare_motif_list`; it must contain every id of
        `components_list` (a missing id raises KeyError).

    Returns a list of rows, one per generated component, where entry [i][j] is
    the `compare_motif_list` value between the sequences generated for
    components_list[i] and the reference of components_list[j].
    '''
    final_comp_kl = []
    for r in components_list:
        # The original `print (r), 'component'` printed only `r` and discarded the label.
        print(r, 'component')
        # One conditional sampling run per generated component, reused for all comparisons.
        synt_df_cond = sampling_to_metric(20, True, r, cond_weight_to_metric=1)
        comp_array = []
        for k in components_list:
            kl_out = compare_motif_list(synt_df_cond, dict_targer_components[k])
            print (r,k,kl_out)
            comp_array.append(kl_out)
        final_comp_kl.append(comp_array)
    return final_comp_kl

In [None]:
# Put the model in evaluation mode, then score the generated sequences of each
# component against the per-component motifs of the training split.
model.eval()
use_comp = [3, 8, 12, 15]
heat_new_sequences_test = kl_comparison_generated_sequences(
    components_list=use_comp,
    dict_targer_components=final_comp_values_trian,
)

In [None]:
# Heatmap of generated-vs-train comparisons, one row / column per component.
generate_heatmap_components(
    heat_new_sequences_test,
    x_label='DNA(DIFFUSION)',
    y_label='Train',
)

# The lowest values should lie along the top-left to bottom-right diagonal (after ~1000 epochs)