In [4]:
import itertools
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torchinfo import summary
import numpy as np
import pandas as pd


def config():
    bases = ["A","C","G","T"]
    bits = len(bases)
    pairs,diPairs = {},{}
    for i,v in enumerate(bases):
        pairs[v] = np.zeros(len(bases))
        pairs[v][i] = 1
    pairs["N"] = np.ones(bits, dtype = float) / bits

    for i, di in enumerate(itertools.product(bases, repeat = 2)):
        diPairs[di] = np.zeros((bits * bits))
        diPairs[di][i] = 1

    for bp1 in bases:
        diPairs[(bp1, "N")] = np.zeros((bits * bits))
        diPairs[("N", bp1)] = np.zeros((bits * bits))
        for bp2 in bases:
            diPairs[(bp1, "N")] += diPairs[(bp1, bp2)]
            diPairs[("N", bp1)] += diPairs[(bp2, bp1)]
        diPairs[(bp1, "N")] = diPairs[(bp1, "N")] / bits
        diPairs[("N", bp1)] = diPairs[("N", bp1)] / bits
    return pairs, diPairs

def one_hot(seq,pairs):
    return np.array(list(map(lambda x: pairs[x], seq)))

def preprocess_selfloop(x):
    num_linkages = x.shape[0]
    linkages = torch.arange(num_linkages)
    linkages_prev = torch.cat([linkages,F.pad(torch.arange(num_linkages-1),(1,0))])
    linkages_prev= linkages_prev.reshape((-1,num_linkages)).T
    linkages_next = torch.cat([linkages,F.pad(torch.arange(1,num_linkages),(0,1),value=num_linkages-1)])
    linkages_next = linkages_next.reshape((-1,num_linkages)).T
    return x, linkages_prev,linkages_next


class MessagePassingConv(nn.Module):
    
    def __init__(self,filters=64, multiply=True, if_gru=False):
        super(MessagePassingConv,self).__init__()
        self.multiply = multiply
        self.if_gru = if_gru
        self.weight_next = nn.Parameter(torch.randn(filters,filters))
        self.weight_prev = nn.Parameter(torch.randn(filters,filters))
        self.bias = nn.Parameter(torch.zeros(1,filters))
        if self.multiply:
            if self.multiply == "add":
                self.weight_next_all = nn.Parameter(torch.randn(filters,filters))
                self.weight_prev_all = nn.Parameter(torch.randn(filters,filters))
            self.bias_all = nn.Parameter(torch.zeros(1,filters))
        self.bn = nn.BatchNorm1d(filters)
        if if_gru:
            self.gru_layer = torch.nn.GRUCell(filters,filters)
    
    def forward(self,x,pprev,pnext,kmers):
        #input_shape is (node features, node pairs to next, node pairs to prev)
        #               [(B * N, k), (None, 2), (None, 2)]
        #output_shape is (B * N, k)
        dim0,dim1 = x.shape[0],x.shape[1]
        prev_x = torch.gather(x,0,pprev[:,-1].unsqueeze(1).expand(1,dim0,dim1).squeeze())
        prev_sumx = torch.zeros(dim0,dim1,dtype=torch.float)
        prev_sumx=prev_sumx.scatter_add(0, pprev[:,0].unsqueeze(1).expand(1,dim0,dim1).squeeze(), prev_x)
        next_x = torch.gather(x,0,pnext[:,-1].unsqueeze(1).expand(1,dim0,dim1).squeeze())
        next_sumx = torch.zeros(dim0,dim1,dtype=torch.float)
        next_sumx=next_sumx.scatter_add(0, pnext[:,0].unsqueeze(1).expand(1,dim0,dim1).squeeze(), next_x)
        aggre=torch.matmul(prev_sumx,self.weight_prev)+torch.matmul(next_sumx,self.weight_next)+ self.bias
        if self.multiply:
            if self.multiply == "add":
                aggre = aggre + torch.matmul(next_sumx, self.weight_next_all) + tf.matmul(prev_sumx, self.weight_prev_all)
            aggre = aggre * x  + self.bias_all
        else:
            aggre = aggre + x
        aggre = self.bn(F.relu(aggre))
        if self.if_gru:
            x= self.gru_layer(aggre, x)
        else:
            x = F.sigmoid(aggre)
        return x 

    
    
class AVGFeature(nn.Module):
    
    def __init__(self,targetFeature=1,filters=64):
        super(AVGFeature,self).__init__()
        self.targetFeature = targetFeature if targetFeature != 0 else 1
        self.channel = filters
        self.aggrefeaturenums = self.channel // self.targetFeature
        self.padding = self.channel %self.targetFeature
    
    def forward(self,x):
        x = F.pad(x, (0,0, 0, self.padding))
        x = torch.reshape(x, (-1, self.targetFeature, self.aggrefeaturenums))
        return torch.reshape(torch.mean(x, dim=-1), [-1, self.targetFeature])

class DNAModel(nn.Module):
    
    def __init__(self,input_features=4, filters=64, basefeatures=1, mp_layer=1, mp_steps=10,
                 dropout_rate=0.0,constraints=True):
        super(DNAModel,self).__init__()
        self.mp_layers_count = mp_layer
        self.input_features = input_features
        self.steps = mp_steps
        self.mp_layers = []
        self.num_basefeatures = basefeatures
        self.filter_size = filters
        self.dropout_rate = dropout_rate
        self.constraints=constraints
        self.selfconv = torch.nn.Conv1d(input_features,filters, 1)
        for _ in range(mp_layer):
            self.mp_layers.append(MessagePassingConv(filters=filters))
        self.dropout_layer = torch.nn.Dropout(dropout_rate)
        self.avg_layer = AVGFeature(self.num_basefeatures, filters=filters)
        
    def forward(self,x,pprev,pnext,kmers):
        x = self.selfconv(torch.permute(x,(0,2,1)))
        x = torch.permute(x,(0,2,1)).squeeze(0)
        results=[]
        for i,layer in enumerate(self.mp_layers):
            for _ in range(self.steps):
                x = layer(x, pprev,pnext, kmers)
            if self.constraints:
                x = self.dropout_layer(x)
                results.append(self.avg_layer(x))
        if self.constraints:
            return torch.concat(results, dim=1)
        x = self.dropout_layer(x)
        return self.avg_layer(x)
    
    
seq = "NN" + "ACGT" + "NN"
pairs,diPairs = config()
x = one_hot(seq,pairs)
kmers = [len(seq)]
x,pprev,pnext = preprocess_selfloop(x)
x = torch.tensor(x,dtype=torch.float).unsqueeze(0)

# messagepassing
# x = torch.permute(x,(0,2,1))
# selfconv = torch.nn.Conv1d(4,64, 1)
# x=selfconv(x)
# x=torch.permute(x,(0,2,1)).reshape((-1,64))
# model = MessagePassingConv()
# model(x,pprev,pnext,kmers).shape
# model

model = DNAModel()
print(summary(model, input_data=[x,pprev,pnext,kmers]))

Layer (type:depth-idx)                   Output Shape              Param #
DNAModel                                 [8, 1]                    --
├─Conv1d: 1-1                            [1, 64, 8]                320
├─Dropout: 1-2                           [8, 64]                   --
├─AVGFeature: 1-3                        [8, 1]                    --
Total params: 320
Trainable params: 320
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.00
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.01


In [5]:
import tensorflow as tf

def oneHot(seq, padding = False):
    #if padding:
    #    seq = "N" + seq + "N"
    return np.array(list(map(lambda x: pairs[x], seq)))
def oneHotDi(seq, padding = False):
    #if padding:
    #    seq = "N" + seq + "N"
    return np.array(list(map(lambda i: diPairs[(seq[i], seq[i+1])], range(len(seq)-1))))


def loadData(filenames):
    fhandle = []
    for filename in filenames:
        fhandle.append(open(filename, "r"))
    while 1:
        curSeq = ""
        features = []
        eof = False
        for fin in fhandle:
            line = fin.readline()
            if not line:
                continue
            items = line.strip().split()
            seq = items[0]
            features += list(map(float, items[1:]))
            if padded:
                seq = "NN" + seq + "NN"
                features = [float("nan"), float("nan")] + features + [float("nan"), float("nan")]
            yield tf.ragged.constant(oneHot(seq, padded), dtype = tf.float32), tf.ragged.constant(oneHotDi(seq, padded), dtype = tf.float32), tf.ragged.constant(np.array(features)[tf.newaxis], dtype = tf.float32)
    for fin in fhandle:
        fin.close()
import os
from functools import partial
dname = "Expt" # MC, MD
padded = False
fdir = f"/Users/john/data/DNAshape/Deep DNAshape Training data/{dname}"
filenames = []
for root,dirs,files in os.walk(fdir):
    for f in files:
        filenames.append(os.path.join(root,f))

for v in partial(loadData, filenames)():
    print(v)

(<tf.RaggedTensor [[0.0, 0.0, 1.0, 0.0],
 [0.0, 0.0, 1.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 1.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 1.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 1.0, 0.0]]>, <tf.RaggedTensor [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.

(<tf.RaggedTensor [[0.0, 0.0, 1.0, 0.0],
 [0.0, 1.0, 0.0, 0.0],
 [0.0, 0.0, 1.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 1.0, 0.0],
 [0.0, 0.0, 1.0, 0.0],
 [0.0, 0.0, 1.0, 0.0],
 [0.0, 1.0, 0.0, 0.0],
 [0.0, 0.0, 1.0, 0.0],
 [0.0, 0.0, 0.0, 1.0]]>, <tf.RaggedTensor [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  1.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0

(<tf.RaggedTensor [[1.0, 0.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 1.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 1.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 0.0, 1.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 1.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 0.0, 1.0]]>, <tf.RaggedTensor [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 

(<tf.RaggedTensor [[0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 1.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 1.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [1.0, 0.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 0.0, 1.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 1.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 1.0, 0.0]]>, <tf.RaggedTensor [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  1.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  1.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,
  0.0, 0.0],
 [1.0, 0.0, 0.0,

(<tf.RaggedTensor [[0.0, 1.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 1.0, 0.0, 0.0],
 [0.0, 1.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 1.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 1.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 1.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 1.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 1.0, 0.0],
 [0.0, 0.0, 1.0, 0.0],
 [1.0, 0.0, 0.0, 0.0]]>, <tf.RaggedTensor [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 1.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,

(<tf.RaggedTensor [[0.0, 1.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 1.0, 0.0],
 [0.0, 1.0, 0.0, 0.0],
 [0.0, 1.0, 0.0, 0.0],
 [0.0, 0.0, 1.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 1.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 1.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 1.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 1.0, 0.0]]>, <tf.RaggedTensor [[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0,

(<tf.RaggedTensor [[0.0, 1.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 1.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 1.0, 0.0, 0.0],
 [0.0, 1.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 1.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 1.0, 0.0, 0.0]]>, <tf.RaggedTensor [[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  1.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0

(<tf.RaggedTensor [[1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 1.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 1.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 1.0, 0.0],
 [0.0, 1.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 1.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 1.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 1.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0]]>, <tf.RaggedTensor [[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  1.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0

(<tf.RaggedTensor [[0.0, 0.0, 1.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 1.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 1.0, 0.0]]>, <tf.RaggedTensor [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  1.0, 0.0]]>, <tf.RaggedTensor [[-0.21, -0.29333332, -0.16666667, 0.13666667, 0.44333333, -0.5125,
  -0.265, -0.15, -0.053333335, -0.026666667, 0.30333334, 38.445,
  30.093334, 38.72333, 31.19, 43.616665, -0.2225, -0.59, 0.16666667,
  -0.6533333, 1.33, -3.65, -10.2775, -7.9766665, -16.576666, -10.823334,
  5.3566666, -0.1825, -0.0525, 0.10333333, 0.043333333, -0.06,
  -0.06333333, -0.49, 0.05, 0.1, -0.03, 0.2

(<tf.RaggedTensor [[0.0, 1.0, 0.0, 0.0],
 [0.0, 1.0, 0.0, 0.0],
 [0.0, 0.0, 1.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 1.0, 0.0, 0.0]]>, <tf.RaggedTensor [[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0]]>, <tf.RaggedTensor [[0.39208, -0.39144, -0.29896, 0.7972, 0.0854, 0.0236, -0.0798, 0.19844,
  -0.2817255, 35.22048, 36.9248, 35.25844, 27.90348, 0.13048, 0.78004,
  -0.44092, -0.81356, -10.1096, -7.10068, -13.62712, -7.57184, 1.9028628,
  -0.00016, -0.11684, 0.06832, 0.05144, 0.11164706, -0.03908, 0.22116,
  -0.12704, 0.08704, 0.1412549, -1.21964, 1.6918, 1.25632, -0.26152,
  2.3759608, 4.32768, 3.10124, -0.47228, 2.09568, -8.39472, -1.6586,
  -0.78476, 6.15828, 14.103333, 3.258, 3.40492, 3.09236, 3.191

(<tf.RaggedTensor [[0.0, 0.0, 1.0, 0.0],
 [0.0, 0.0, 1.0, 0.0],
 [0.0, 1.0, 0.0, 0.0],
 [0.0, 0.0, 1.0, 0.0],
 [0.0, 1.0, 0.0, 0.0],
 [0.0, 1.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [1.0, 0.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 1.0, 0.0],
 [0.0, 0.0, 1.0, 0.0],
 [0.0, 0.0, 1.0, 0.0],
 [0.0, 0.0, 1.0, 0.0]]>, <tf.RaggedTensor [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 

(<tf.RaggedTensor [[0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 1.0, 0.0, 0.0],
 [0.0, 0.0, 1.0, 0.0],
 [0.0, 1.0, 0.0, 0.0],
 [0.0, 0.0, 1.0, 0.0],
 [0.0, 1.0, 0.0, 0.0],
 [0.0, 0.0, 1.0, 0.0],
 [0.0, 0.0, 1.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 0.0, 1.0]]>, <tf.RaggedTensor [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 1.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 1.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 1.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0

(<tf.RaggedTensor [[0.0, 1.0, 0.0, 0.0],
 [0.0, 1.0, 0.0, 0.0],
 [0.0, 1.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 1.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 1.0, 0.0],
 [0.0, 0.0, 1.0, 0.0],
 [1.0, 0.0, 0.0, 0.0],
 [0.0, 1.0, 0.0, 0.0],
 [0.0, 1.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 0.0, 1.0],
 [0.0, 1.0, 0.0, 0.0],
 [0.0, 1.0, 0.0, 0.0]]>, <tf.RaggedTensor [[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  1.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
  0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0

AssertionError: 