# CONSTANTS

In [1]:
import math

GLOBAL = {
    'embedded_length': 512,
    'continuity_length': 64,
    'n_vocab': 29
}

# Transformer (Top level)
TRANS_CONST = {
    'n_attention_layers': 8,
    'n_attention_heads': 8,    

    'max_output_length': 6,
    
    'embedding_dic_size': GLOBAL['n_vocab'], 
    'embedded_vec_size': GLOBAL['embedded_length'],
    
    # 'pos_encoding_input': GLOBAL['embedded_length'],
    # 'pos_encoding_output': GLOBAL['embedded_length'],
    
    'linear_input': GLOBAL['embedded_length'],
    'linear_output': GLOBAL['n_vocab'] # output vocab size
}

# Encoder, EncoderLayer
ENCODER_CONST = {
    'norm1_size': GLOBAL['embedded_length'], # same as input matrix width
    'norm2_size': GLOBAL['embedded_length'],

    # maybe rename these two, it's just for knowing the input dim and the dim that the FF layer will work with
    'ff1': GLOBAL['embedded_length'], 
    'ff2': GLOBAL['embedded_length'] * 4
}

# Decoder, DecoderLayer
DECODER_CONST = {
    'norm1_size': GLOBAL['embedded_length'], # same as input matrix width
    'norm2_size': GLOBAL['embedded_length'],
    'norm3_size': GLOBAL['embedded_length'],

    'ff1': GLOBAL['embedded_length'],#TODO RENAME
    'ff2': GLOBAL['embedded_length'] * 4#TODO RENAME
}

# MultiHeadAttention, SingleHeadAttention
ATTENTION_CONST = {
    'mh_concat_width': GLOBAL['continuity_length']*TRANS_CONST['n_attention_heads'], # single head attention width * number of heads
    'mh_output_width': GLOBAL['embedded_length'], #TODO - I'm just guessing this. Didn't see in illustrated transformer. Since we have to use this for the add & norm layer though it has to be the same as the input width (I think)

    # W_q weight matrix 
    'sh_linear1_input': GLOBAL['embedded_length'], # same as embedded length to end up with n_words x 64
    'sh_linear1_output': GLOBAL['continuity_length'], # specified in the paper
    # W_k weight matrix 
    'sh_linear2_input': GLOBAL['embedded_length'], # same as embedded length to end up with n_words x 64
    'sh_linear2_output': GLOBAL['continuity_length'], # specified in the paper
    # W_v weight matrix 
    'sh_linear3_input': GLOBAL['embedded_length'], # same as embedded length to end up with n_words x 64
    'sh_linear3_output': GLOBAL['continuity_length'], # specified in the paper
    
    'sh_scale_factor': math.sqrt(GLOBAL['continuity_length']) # specified in the paper, square root of dimension of key vector/matrix (64)
}

# FeedForward
FEEDFORWARD_CONST = {
    'dropout': 0.1
}



# FEEDFORWARD

In [2]:
import torch.nn as nn

class FeedForward(nn.Module):
    def __init__(self, dim_model, dim_ff, dropout=FEEDFORWARD_CONST['dropout']):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(dim_model, dim_ff)
        # self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_ff, dim_model)

    def forward(self, inputs):
        x = inputs
        x = self.linear1(x)
        x = nn.functional.relu(x) 
        # x = self.dropout(x) 
        x = self.linear2(x) 
        return x


# MULTIHEADATTENTION

In [3]:
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads):
        super(MultiHeadAttention, self).__init__()

        self.linear = nn.Linear(64, 512)
        self.wQ = nn.Linear(ATTENTION_CONST['sh_linear1_input'], ATTENTION_CONST['sh_linear1_output'])
        self.wK = nn.Linear(ATTENTION_CONST['sh_linear2_input'], ATTENTION_CONST['sh_linear2_output'])
        self.wV = nn.Linear(ATTENTION_CONST['sh_linear3_input'], ATTENTION_CONST['sh_linear3_output'])

    def forward(self, q, k, v):
        q = self.wQ(q)
        k = self.wK(k)
        v = self.wV(v)

        # split heads - I think they do this instead of a loop
        x, attention_weights = self.applyHeads(q, k, v)
        # transpose ?
        # reshape ?
        x = self.linear(x)
        return x, attention_weights

    def applyHeads(self, q, k, v, mask=None):
        x = torch.matmul(q, k.permute(1, 0)) 
        # scale x
        # add mask
        attention_weights = nn.Softmax(dim=-1)(x)
        x = torch.matmul(attention_weights, v)
        return x, attention_weights



# DECODER

In [4]:
import torch.nn as nn

class Decoder(nn.Module):
    def __init__(self, n_layers, n_attention_heads):
        super(Decoder, self).__init__()
        
        #self.embedding
        #self.pos_encoding
        #self.dropout
        self.decoderLayers = nn.ModuleList([DecoderLayer(n_attention_heads) for _ in range(n_layers)])

    def forward(self, inputs, encoderOut):
        x = inputs
        #x = self.embedding(x)
        #x = x * root of dim, whatever that is exactly
        #x = self.pos_encoding(x)
        #x = self.dropout(x)

        attention_weights = []
        for layer in self.decoderLayers:
            x, att1, att2 = layer(x, encoderOut) 
            attention_weights.append([att1, att2])

        return x, attention_weights

class DecoderLayer(nn.Module):
    def __init__(self, n_attention_heads):
        super(DecoderLayer, self).__init__()

        self.mhattention1 = MultiHeadAttention(n_attention_heads)
        self.mhattention2 = MultiHeadAttention(n_attention_heads)
        self.norm1 = nn.LayerNorm(DECODER_CONST['norm1_size'])
        self.norm2 = nn.LayerNorm(DECODER_CONST['norm2_size'])
        self.norm3 = nn.LayerNorm(DECODER_CONST['norm3_size'])
        self.dropout1 = nn.Dropout()
        self.dropout2 = nn.Dropout()
        self.dropout3 = nn.Dropout()
        self.feedforward = FeedForward(DECODER_CONST['ff1'], DECODER_CONST['ff2'])

    def forward(self, inputs, encoderOut):
        x = inputs
        z = x
        x, att1 = self.mhattention1(x, x, x)
        x = self.dropout1(x)
        x = z + x        
        x = self.norm1(x)
        z = x
        x, att2 = self.mhattention2(x, encoderOut, encoderOut) 
        x = self.dropout2(x)
        x = z + x
        x = self.norm2(x)
        z = x
        x = self.feedforward(x)
        x = self.dropout3(x)
        x = z + x
        x = self.norm3(x)
        return x, att1, att2



# ENCODER

In [5]:
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, n_layer, n_attention_heads):
        super(Encoder, self).__init__()
        
        #self.embedding
        #self.pos_encoding
        #self.dropout
        self.encoderLayers = nn.ModuleList([EncoderLayer(n_attention_heads) for _ in range(n_layer)])

    def forward(self, inputs):
        x = inputs
        #x = self.embedding(x)
        #x = x * root of dim, whatever that is exactly
        #x = self.pos_encoding(x)
        #x = self.dropout(x)

        for layer in self.encoderLayers:
            x = layer(x)
        return x

class EncoderLayer(nn.Module):
    def __init__(self, n_attention_heads):
        super(EncoderLayer, self).__init__()

        self.mhattention = MultiHeadAttention(n_attention_heads)
        self.norm1 = nn.LayerNorm(ENCODER_CONST['norm1_size'])
        self.norm2 = nn.LayerNorm(ENCODER_CONST['norm2_size'])
        self.dropout1 = nn.Dropout()
        self.dropout2 = nn.Dropout()
        self.feedforward = FeedForward(ENCODER_CONST['ff1'], ENCODER_CONST['ff2'])

    def forward(self, inputs):
        x = inputs 
        z = x
        x, _ = self.mhattention(x, x, x)
        x = self.dropout1(x)
        x = z + x
        x = self.norm1(x)
        z = x
        x = self.feedforward(x)
        x = self.dropout2(x)
        x = z + x
        x = self.norm2(x) 
        return x



# TRANSFORMER

In [6]:
import torch
import torch.nn as nn
import numpy

class Transformer(nn.Module):
    def __init__(self, n_layers=TRANS_CONST['n_attention_layers'], n_attention_heads=TRANS_CONST['n_attention_heads']):
        super(Transformer, self).__init__()

        self.encoder = Encoder(n_layers, n_attention_heads)
        self.decoder = Decoder(n_layers, n_attention_heads)
        self.embedding = nn.Embedding(TRANS_CONST['embedding_dic_size'], TRANS_CONST['embedded_vec_size'])
        # self.posEncoding = #TODO
        self.linear = nn.Linear(TRANS_CONST['linear_input'], TRANS_CONST['linear_output'])
        self.softmax = nn.Softmax(dim=1)

    def __call__(self, inputs=None):
        if inputs != None: 
            raise NotImplementedError

        import random
        inputs = []
        for _ in range(13): inputs.append(numpy.zeros(26)) # 26 is vocab size, should be constant; 13 is just a random amount of words in the sequence
        inputs = torch.Tensor(inputs)
        for i in inputs: i[random.randint(0, len(i) - 1)] = 1

        return self.forward(inputs.long())

    def forward(self, inputs):
        x = self.doEmbedding(inputs)
        encoderOut = self.encoder(x)
        x, weights = self.decoder(x, encoderOut)
        x, weights = self.decoder(x, encoderOut)
        x = self.linear(x)
        x = self.softmax(x)
        return x, weights

    def doEmbedding(self, inputs):
        x = inputs.nonzero()[:, 1] # this gets all indices of nonzero values from the inputs matrix
        x = self.embedding(x)
        # x = self.posEncoding(x)
        return x








# Check Shapes

In [7]:
import random
inputs = []
for _ in range(13): inputs.append(numpy.zeros(26)) # 26 is vocab size, should be constant; 13 is just a random amount of words in the sequence
inputs = torch.Tensor(inputs)
for i in inputs: i[random.randint(0, len(i) - 1)] = 1

sample_enc = Encoder(1, 1)
sample_dec = Decoder(1, 1)
sample_tf = Transformer()

#sample_x = inputs
#print(sample_x.shape)
#sample_x = sample_enc(sample_x)
#print(sample_x.shape)
#sample_x = sample_dec(sample_x)
#print(sample_x.shape)
sample_x, weights = sample_tf()
print(sample_x.shape)


torch.Size([13, 29])


# TRAIN

In [9]:
import torch 
import torch.nn as nn
import random
import numpy as np

EPOCHS = 200

transformer = Transformer()
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0002)
loss = torch.nn.BCELoss()

real_sample = torch.Tensor([
        [1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
        [0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0]
])
for _ in range(1000):
    sample, __ = transformer()        
    # target = torch.ones(sample.shape[0], sample.shape[1])
    error = loss(sample, real_sample)
    if _ % 10 == 0: print(error, sample)
    error.backward()
    optimizer.step()
    if(error > 8): break
print(sample)


tensor(0.1623, grad_fn=<BinaryCrossEntropyBackward>) tensor([[0.0458, 0.0297, 0.0645, 0.0065, 0.0237, 0.0166, 0.0172, 0.0396, 0.0163,
         0.0108, 0.0347, 0.0145, 0.1121, 0.0391, 0.0214, 0.1197, 0.0355, 0.0254,
         0.0427, 0.0232, 0.0339, 0.0335, 0.0547, 0.0071, 0.0352, 0.0327, 0.0146,
         0.0285, 0.0208],
        [0.0438, 0.0275, 0.0617, 0.0219, 0.0498, 0.0380, 0.0243, 0.0398, 0.0239,
         0.0386, 0.0202, 0.0120, 0.0439, 0.0247, 0.0242, 0.0821, 0.0359, 0.0226,
         0.0525, 0.0494, 0.0451, 0.0157, 0.0756, 0.0106, 0.0174, 0.0363, 0.0232,
         0.0164, 0.0229],
        [0.0354, 0.0459, 0.0304, 0.0529, 0.0376, 0.0258, 0.0271, 0.0185, 0.0492,
         0.0327, 0.0528, 0.0063, 0.0360, 0.0091, 0.0304, 0.0433, 0.0152, 0.0333,
         0.0446, 0.0463, 0.0303, 0.0256, 0.0681, 0.0194, 0.0475, 0.0565, 0.0271,
         0.0330, 0.0197],
        [0.0295, 0.0687, 0.0269, 0.0250, 0.0322, 0.0169, 0.0249, 0.0382, 0.0306,
         0.0364, 0.0398, 0.0118, 0.0300, 0.0404, 0.0288, 0.

tensor(0.0923, grad_fn=<BinaryCrossEntropyBackward>) tensor([[0.2667, 0.0010, 0.0004, 0.0007, 0.1040, 0.0006, 0.0005, 0.2786, 0.0007,
         0.0005, 0.0011, 0.0053, 0.0005, 0.0005, 0.0032, 0.0021, 0.0008, 0.0004,
         0.0005, 0.3230, 0.0008, 0.0023, 0.0016, 0.0007, 0.0005, 0.0005, 0.0010,
         0.0006, 0.0007],
        [0.1221, 0.0017, 0.0009, 0.0010, 0.1843, 0.0021, 0.0011, 0.3974, 0.0011,
         0.0007, 0.0008, 0.0123, 0.0005, 0.0006, 0.0023, 0.0011, 0.0010, 0.0006,
         0.0006, 0.2547, 0.0017, 0.0027, 0.0011, 0.0009, 0.0006, 0.0019, 0.0019,
         0.0018, 0.0006],
        [0.1306, 0.0017, 0.0006, 0.0006, 0.2587, 0.0019, 0.0009, 0.2510, 0.0009,
         0.0013, 0.0008, 0.0072, 0.0008, 0.0011, 0.0015, 0.0013, 0.0020, 0.0008,
         0.0021, 0.3222, 0.0020, 0.0018, 0.0019, 0.0011, 0.0004, 0.0010, 0.0016,
         0.0010, 0.0011],
        [0.2368, 0.0011, 0.0012, 0.0007, 0.2292, 0.0006, 0.0021, 0.1540, 0.0007,
         0.0012, 0.0006, 0.0055, 0.0004, 0.0009, 0.0024, 0.

tensor(0.0774, grad_fn=<BinaryCrossEntropyBackward>) tensor([[0.0279, 0.0004, 0.0002, 0.0001, 0.0247, 0.0002, 0.0002, 0.0095, 0.0003,
         0.0002, 0.0002, 0.1123, 0.0001, 0.0002, 0.0002, 0.0004, 0.0002, 0.0002,
         0.0002, 0.8206, 0.0002, 0.0002, 0.0003, 0.0001, 0.0001, 0.0001, 0.0001,
         0.0001, 0.0002],
        [0.0083, 0.0003, 0.0001, 0.0001, 0.0210, 0.0001, 0.0001, 0.0087, 0.0001,
         0.0002, 0.0001, 0.1832, 0.0001, 0.0001, 0.0001, 0.0002, 0.0004, 0.0001,
         0.0002, 0.7749, 0.0002, 0.0002, 0.0002, 0.0001, 0.0001, 0.0003, 0.0002,
         0.0001, 0.0002],
        [0.0506, 0.0008, 0.0004, 0.0003, 0.0523, 0.0003, 0.0006, 0.0137, 0.0007,
         0.0003, 0.0006, 0.1117, 0.0003, 0.0003, 0.0005, 0.0003, 0.0006, 0.0002,
         0.0007, 0.7612, 0.0003, 0.0004, 0.0004, 0.0003, 0.0004, 0.0003, 0.0006,
         0.0007, 0.0003],
        [0.0164, 0.0004, 0.0001, 0.0001, 0.0171, 0.0003, 0.0001, 0.0059, 0.0001,
         0.0001, 0.0002, 0.1187, 0.0001, 0.0002, 0.0003, 0.

tensor(0.1137, grad_fn=<BinaryCrossEntropyBackward>) tensor([[2.0813e-04, 2.6811e-05, 9.8765e-06, 6.8617e-06, 5.3598e-04, 7.8464e-06,
         1.8330e-05, 1.8987e-03, 7.5485e-06, 1.0931e-05, 2.6748e-05, 5.2302e-01,
         2.0994e-05, 1.6651e-05, 1.1009e-05, 1.5005e-05, 3.4297e-05, 9.0911e-06,
         1.3892e-05, 4.7397e-01, 1.1904e-05, 8.4866e-06, 2.0223e-05, 9.9141e-06,
         2.0379e-05, 9.6738e-06, 1.8600e-05, 1.5369e-05, 2.1521e-05],
        [9.0121e-04, 2.4877e-05, 9.9112e-06, 6.3887e-06, 3.0030e-04, 1.1330e-05,
         2.3810e-05, 7.1602e-04, 2.4188e-05, 1.6396e-05, 1.1730e-05, 5.2531e-01,
         1.8639e-05, 2.3811e-05, 1.6399e-05, 2.3828e-05, 1.9584e-05, 5.9798e-06,
         2.4996e-05, 4.7235e-01, 4.8399e-05, 1.1416e-05, 1.6903e-05, 7.6824e-06,
         5.4634e-06, 1.1094e-05, 2.2555e-05, 1.4871e-05, 2.1470e-05],
        [4.3988e-04, 2.0615e-05, 6.4027e-06, 5.2274e-06, 4.4107e-04, 7.2721e-06,
         8.5958e-06, 6.0835e-04, 5.7839e-06, 5.7265e-06, 1.2369e-05, 5.9899e-0

tensor(0.0998, grad_fn=<BinaryCrossEntropyBackward>) tensor([[8.0591e-04, 3.5764e-05, 6.1882e-06, 2.5883e-06, 9.2574e-04, 9.2861e-06,
         5.7445e-06, 2.0154e-03, 1.5215e-05, 5.9099e-06, 7.7091e-06, 3.9559e-01,
         9.1744e-06, 6.4281e-06, 8.8118e-06, 1.9737e-05, 2.3332e-05, 3.1312e-06,
         1.3778e-05, 6.0039e-01, 1.1651e-05, 4.9529e-06, 1.0500e-05, 9.3253e-06,
         9.2690e-06, 5.6933e-06, 1.9834e-05, 1.4606e-05, 1.4186e-05],
        [7.2390e-04, 7.9865e-06, 6.9582e-06, 3.3067e-06, 6.4597e-04, 2.5622e-06,
         3.1968e-06, 7.1446e-04, 3.9689e-06, 1.8749e-06, 4.7680e-06, 2.8428e-01,
         3.2936e-06, 5.0294e-06, 7.7588e-06, 1.1753e-05, 2.3716e-06, 3.1151e-06,
         4.6029e-06, 7.1352e-01, 8.3217e-06, 2.2601e-06, 6.7185e-06, 3.3696e-06,
         2.1046e-06, 4.5398e-06, 6.6094e-06, 2.8473e-06, 4.8174e-06],
        [1.0851e-03, 1.9808e-05, 3.8341e-06, 9.6210e-06, 9.1870e-04, 1.0711e-05,
         7.5953e-06, 1.6063e-03, 1.3621e-05, 6.1888e-06, 1.1568e-05, 3.4835e-0

tensor(0.1124, grad_fn=<BinaryCrossEntropyBackward>) tensor([[3.4793e-03, 8.8074e-06, 3.2635e-06, 4.4070e-06, 6.2119e-03, 3.5709e-06,
         4.8447e-06, 8.8379e-03, 4.3151e-06, 1.5390e-06, 3.5605e-06, 1.7099e-02,
         1.8396e-06, 4.1583e-06, 2.8030e-06, 2.0527e-05, 5.2156e-06, 2.7898e-06,
         2.2816e-06, 9.6427e-01, 3.1275e-06, 1.4966e-06, 8.9423e-06, 3.1241e-06,
         2.4008e-06, 2.9848e-06, 2.6810e-06, 2.3900e-06, 3.2272e-06],
        [1.5401e-03, 1.0917e-06, 1.0052e-06, 3.4206e-07, 8.0398e-04, 4.3446e-07,
         6.8876e-07, 5.9807e-03, 1.0159e-06, 1.0036e-06, 6.7532e-07, 2.0288e-02,
         1.1142e-06, 5.0917e-07, 5.3504e-07, 2.4468e-06, 6.2623e-07, 3.3859e-07,
         5.2512e-07, 9.7137e-01, 7.4548e-07, 5.3736e-07, 1.5544e-06, 4.2133e-07,
         3.3857e-07, 1.0479e-06, 1.1540e-06, 4.0680e-07, 5.4919e-07],
        [9.8991e-04, 2.8992e-06, 6.2634e-07, 1.0230e-06, 4.1605e-03, 1.1608e-06,
         1.1647e-06, 9.2129e-03, 1.0003e-06, 1.9079e-06, 1.9666e-06, 1.7494e-0

tensor(0.0872, grad_fn=<BinaryCrossEntropyBackward>) tensor([[7.4787e-02, 1.3984e-05, 7.9487e-06, 8.4216e-06, 9.4877e-02, 3.7350e-06,
         6.8442e-06, 2.4518e-01, 3.2716e-06, 3.6720e-06, 6.3095e-06, 1.2270e-02,
         7.1460e-06, 5.5022e-06, 7.7720e-06, 1.6055e-05, 1.2345e-05, 4.2089e-06,
         6.5403e-06, 5.7271e-01, 4.6690e-06, 7.6716e-06, 1.7113e-05, 3.5241e-06,
         9.2052e-06, 5.0019e-06, 1.0482e-05, 7.1039e-06, 2.2836e-06],
        [1.0252e-01, 9.0908e-06, 5.1550e-06, 2.9823e-06, 1.9224e-01, 5.4064e-06,
         3.6763e-06, 2.9088e-01, 6.8835e-06, 2.4774e-06, 7.8972e-06, 1.2089e-02,
         3.6388e-06, 5.8263e-06, 5.5602e-06, 1.3049e-05, 4.8014e-06, 2.5383e-06,
         3.8997e-06, 4.0210e-01, 5.1092e-06, 6.2311e-06, 4.0911e-05, 1.2587e-05,
         6.1512e-06, 4.0759e-06, 3.4552e-06, 4.7877e-06, 6.2041e-06],
        [5.9487e-02, 1.1537e-05, 5.2934e-06, 3.9476e-06, 1.2531e-01, 5.9724e-06,
         9.7538e-06, 1.6948e-01, 5.8843e-06, 3.7301e-06, 8.1789e-06, 1.0521e-0

tensor(0.1519, grad_fn=<BinaryCrossEntropyBackward>) tensor([[3.5972e-01, 2.0799e-06, 8.4479e-07, 1.4466e-06, 2.8950e-01, 8.6893e-07,
         1.3347e-06, 3.3023e-01, 1.1809e-06, 7.2631e-07, 3.2023e-06, 3.5904e-03,
         1.4412e-06, 1.1507e-06, 6.5175e-07, 6.2067e-06, 1.1562e-06, 1.3462e-06,
         2.4325e-06, 1.6915e-02, 6.2870e-07, 2.1415e-06, 5.2092e-06, 1.1791e-06,
         1.6971e-06, 1.8030e-06, 2.2110e-06, 1.2310e-06, 1.3084e-06],
        [9.3906e-02, 4.4746e-06, 1.1763e-06, 2.1343e-06, 4.3041e-01, 1.4410e-06,
         2.8689e-06, 4.5067e-01, 3.6402e-06, 8.9902e-07, 2.7141e-06, 2.8895e-03,
         6.2431e-06, 1.8921e-06, 5.1525e-07, 4.4779e-06, 2.8207e-06, 9.2437e-07,
         1.3370e-06, 2.2064e-02, 1.2478e-06, 8.1483e-07, 1.0872e-05, 1.7275e-06,
         5.9835e-07, 2.2846e-06, 2.1351e-06, 7.5014e-07, 2.7686e-06],
        [2.2577e-01, 3.7024e-06, 1.7238e-06, 1.2685e-06, 5.1752e-01, 1.2590e-06,
         1.6725e-06, 2.2496e-01, 1.1262e-06, 1.6664e-06, 2.1375e-06, 2.6659e-0

tensor(0.1703, grad_fn=<BinaryCrossEntropyBackward>) tensor([[7.1519e-01, 3.3125e-06, 5.2305e-07, 4.6702e-07, 2.4786e-01, 8.7509e-07,
         1.0995e-06, 1.8347e-02, 5.4001e-07, 3.8462e-07, 1.8004e-06, 4.4857e-03,
         1.1752e-06, 1.0652e-06, 4.8672e-07, 1.9023e-06, 1.6343e-06, 6.4758e-07,
         8.8246e-07, 1.4090e-02, 4.0506e-07, 7.0916e-07, 4.3108e-06, 8.0451e-07,
         1.2587e-06, 1.1923e-06, 6.5216e-07, 4.2416e-07, 6.7314e-07],
        [5.9454e-01, 1.8019e-06, 9.2782e-07, 7.8780e-07, 3.1867e-01, 6.7846e-07,
         1.4134e-06, 6.7666e-02, 1.2123e-06, 2.1292e-07, 1.4087e-06, 4.5333e-03,
         3.3984e-06, 3.2858e-07, 2.7590e-07, 2.0396e-06, 5.9315e-07, 4.3012e-07,
         4.2292e-07, 1.4564e-02, 2.2619e-07, 3.4164e-07, 1.7441e-06, 4.4516e-07,
         5.4562e-07, 4.0837e-07, 6.8370e-07, 3.5754e-07, 3.7994e-07],
        [7.9462e-01, 3.4521e-07, 1.7285e-07, 1.0171e-07, 1.7789e-01, 1.6414e-07,
         1.4006e-07, 1.6937e-02, 2.4166e-07, 2.6532e-07, 5.5709e-07, 2.7420e-0

tensor(0.0856, grad_fn=<BinaryCrossEntropyBackward>) tensor([[1.9088e-01, 4.7863e-07, 2.2615e-07, 1.1037e-07, 8.4013e-02, 1.9300e-07,
         4.9276e-07, 1.4933e-02, 1.5844e-07, 2.0127e-07, 4.7992e-07, 1.1883e-01,
         3.1126e-07, 1.5713e-07, 1.9698e-07, 7.1465e-07, 2.3023e-07, 1.1345e-07,
         2.4464e-07, 5.9133e-01, 2.4827e-07, 1.2459e-07, 6.2148e-07, 2.8198e-07,
         2.5965e-07, 3.1175e-07, 1.9802e-07, 1.5935e-07, 1.4871e-07],
        [2.5476e-01, 3.6893e-07, 1.5191e-07, 8.5102e-08, 5.7765e-02, 7.8022e-08,
         2.4766e-07, 5.6804e-03, 2.1890e-07, 8.7501e-08, 2.9098e-07, 1.0322e-01,
         2.3569e-07, 1.6105e-07, 1.2021e-07, 4.1722e-07, 5.6142e-08, 5.0303e-08,
         9.2417e-08, 5.7857e-01, 1.2914e-07, 9.7895e-08, 6.0044e-07, 1.3617e-07,
         1.4581e-07, 1.0474e-07, 1.0973e-07, 1.0313e-07, 1.0464e-07],
        [2.9522e-01, 4.7657e-07, 4.2625e-07, 1.5960e-07, 5.3836e-02, 3.2819e-07,
         4.8124e-07, 7.2915e-03, 4.5763e-07, 2.6968e-07, 4.3412e-07, 1.1571e-0

tensor(0.1137, grad_fn=<BinaryCrossEntropyBackward>) tensor([[1.0548e-03, 3.9359e-08, 1.9165e-08, 1.9484e-08, 5.7378e-04, 1.4462e-08,
         4.1640e-08, 1.3557e-04, 2.7393e-08, 8.0600e-09, 3.4582e-08, 4.1808e-02,
         4.3550e-08, 1.4253e-08, 2.0748e-08, 9.4048e-08, 1.8614e-08, 2.6765e-08,
         2.3499e-08, 9.5643e-01, 1.1764e-08, 2.6991e-08, 1.6926e-07, 1.7422e-08,
         3.6051e-08, 1.4180e-08, 2.0001e-08, 4.4639e-08, 1.4886e-08],
        [6.5720e-04, 1.2684e-07, 2.4816e-08, 3.3260e-08, 9.6271e-04, 1.8201e-08,
         7.4422e-08, 2.5487e-04, 4.4148e-08, 1.4797e-08, 7.5077e-08, 7.3546e-02,
         6.7623e-08, 3.0417e-08, 9.5186e-09, 3.1099e-07, 2.2950e-08, 2.1763e-08,
         4.0132e-08, 9.2458e-01, 2.8490e-08, 2.4908e-08, 2.4016e-07, 3.1483e-08,
         5.2170e-08, 3.0476e-08, 2.0797e-08, 4.6212e-08, 2.6504e-08],
        [2.6405e-04, 1.9002e-08, 1.0719e-08, 6.1515e-09, 8.8421e-04, 8.3564e-09,
         5.3039e-09, 1.4468e-04, 1.7857e-08, 2.6175e-09, 1.6814e-08, 9.9283e-0

tensor(0.1361, grad_fn=<BinaryCrossEntropyBackward>) tensor([[1.5685e-05, 8.1997e-09, 3.6686e-09, 1.4237e-09, 4.9547e-05, 1.5564e-09,
         3.4838e-09, 1.4103e-05, 2.3468e-09, 2.3445e-09, 3.3725e-09, 6.8930e-02,
         4.0511e-09, 1.3959e-09, 8.2721e-10, 4.3478e-09, 8.9840e-10, 8.1507e-10,
         1.7366e-09, 9.3099e-01, 9.8922e-10, 6.6297e-10, 1.2763e-08, 1.0764e-09,
         2.1779e-09, 2.3285e-09, 1.7455e-09, 2.1980e-09, 2.6661e-09],
        [5.6409e-06, 5.0851e-09, 2.6218e-09, 1.1101e-09, 2.0144e-05, 2.2218e-09,
         3.6180e-09, 8.0959e-06, 1.9777e-09, 1.1386e-09, 4.3110e-09, 1.0516e-01,
         5.7584e-09, 1.2648e-09, 9.8794e-10, 1.7968e-08, 3.8611e-09, 1.2874e-09,
         2.6297e-09, 8.9481e-01, 7.4702e-10, 8.9423e-10, 1.7506e-08, 1.1279e-09,
         1.3106e-09, 1.4331e-09, 9.5540e-10, 2.0907e-09, 1.0199e-09],
        [1.1185e-05, 4.4755e-08, 1.6345e-08, 5.7313e-09, 4.7667e-05, 1.4598e-08,
         2.2037e-08, 1.8820e-05, 5.9130e-09, 1.0074e-08, 1.8104e-08, 9.0106e-0

tensor(0.1483, grad_fn=<BinaryCrossEntropyBackward>) tensor([[4.7534e-07, 1.3768e-09, 3.6761e-10, 4.1747e-10, 7.7233e-06, 3.4252e-10,
         8.6549e-10, 8.0526e-06, 5.4952e-10, 3.7820e-10, 9.1647e-10, 5.2165e-01,
         1.0488e-09, 3.2492e-10, 2.6089e-10, 2.4029e-09, 5.8798e-10, 5.0609e-10,
         7.7921e-10, 4.7833e-01, 4.9038e-10, 3.5244e-10, 4.6241e-09, 3.8386e-10,
         5.2269e-10, 5.7454e-10, 2.9854e-10, 6.3593e-10, 5.3173e-10],
        [2.2998e-06, 1.0511e-08, 1.7851e-09, 8.5186e-10, 7.8491e-06, 1.5370e-09,
         6.6323e-09, 1.0717e-05, 1.5398e-09, 2.5290e-09, 1.1494e-08, 4.9308e-01,
         1.2294e-08, 1.6989e-09, 1.7283e-09, 1.4369e-08, 3.2633e-09, 3.8456e-09,
         1.5080e-09, 5.0690e-01, 1.6521e-09, 1.5536e-09, 2.0604e-08, 3.4374e-09,
         3.6711e-09, 3.3729e-09, 1.8436e-09, 6.2278e-09, 2.8617e-09],
        [4.5042e-07, 4.4810e-10, 3.4553e-10, 2.1631e-10, 3.3757e-06, 1.3609e-10,
         7.7511e-10, 4.2212e-06, 2.3984e-10, 9.4339e-11, 3.3881e-10, 3.4169e-0

tensor(0.2073, grad_fn=<BinaryCrossEntropyBackward>) tensor([[3.6823e-08, 6.9117e-10, 8.5913e-11, 5.8505e-11, 1.3573e-06, 1.0114e-10,
         2.5985e-10, 1.7189e-06, 8.5239e-11, 2.2750e-10, 1.6009e-10, 7.4192e-01,
         2.6944e-10, 2.4449e-10, 2.6596e-11, 7.6014e-10, 5.7878e-11, 6.7782e-11,
         8.0756e-11, 2.5808e-01, 8.4280e-11, 5.8251e-11, 1.0472e-09, 1.4308e-10,
         2.0731e-10, 8.5051e-11, 9.7839e-11, 6.5375e-11, 1.9225e-10],
        [1.9258e-08, 1.2580e-10, 1.1494e-10, 3.8624e-11, 4.1982e-07, 5.8050e-11,
         1.8415e-10, 1.1624e-06, 5.3686e-11, 4.3241e-11, 7.3703e-11, 8.2614e-01,
         1.6063e-10, 2.8662e-11, 2.5205e-11, 1.4617e-10, 4.1352e-11, 1.9535e-11,
         2.7014e-11, 1.7386e-01, 1.1690e-11, 2.4082e-11, 4.7271e-10, 3.2657e-11,
         1.6965e-10, 1.1142e-10, 4.1855e-11, 6.1031e-11, 4.3085e-11],
        [1.8392e-08, 3.8063e-11, 1.2969e-11, 1.1673e-11, 7.2973e-07, 2.7837e-11,
         2.6230e-11, 7.9830e-07, 3.8167e-11, 1.1929e-11, 2.0043e-11, 8.3978e-0

tensor(0.1595, grad_fn=<BinaryCrossEntropyBackward>) tensor([[5.0739e-09, 7.9953e-11, 9.4192e-12, 1.6752e-11, 1.0615e-06, 1.7479e-11,
         2.0752e-11, 1.1945e-06, 2.5308e-11, 9.5944e-12, 2.4634e-11, 2.7004e-01,
         4.7835e-11, 1.8291e-11, 1.4317e-11, 3.6736e-11, 1.4781e-11, 2.1261e-11,
         1.3245e-11, 7.2996e-01, 6.0321e-12, 8.8402e-12, 1.4920e-10, 9.5170e-12,
         4.3065e-11, 1.5011e-11, 7.0428e-12, 2.8989e-11, 7.4189e-12],
        [2.1037e-08, 1.7723e-09, 4.7735e-10, 6.4330e-10, 2.0567e-06, 4.8616e-10,
         2.6155e-09, 7.0084e-06, 3.6228e-10, 7.9257e-10, 1.7253e-09, 2.8033e-01,
         2.3205e-09, 7.5836e-10, 2.0064e-10, 5.8474e-09, 5.7082e-10, 4.4262e-10,
         6.0915e-10, 7.1966e-01, 1.7586e-10, 1.6330e-10, 7.2884e-09, 5.0968e-10,
         2.1516e-09, 3.6456e-10, 2.2267e-10, 4.4632e-10, 5.8338e-10],
        [7.1860e-09, 7.0524e-11, 4.4462e-11, 5.8342e-11, 8.5065e-07, 7.4000e-11,
         1.1667e-10, 1.2531e-06, 1.0997e-10, 2.9772e-11, 6.2563e-11, 3.0407e-0

tensor(0.2198, grad_fn=<BinaryCrossEntropyBackward>) tensor([[1.1642e-09, 5.3420e-12, 1.2699e-12, 6.0585e-13, 3.6516e-07, 5.5534e-13,
         1.2861e-11, 9.6415e-07, 8.4012e-13, 2.6607e-12, 2.8839e-12, 6.6578e-03,
         1.5144e-11, 6.4614e-13, 1.1746e-12, 9.5474e-12, 1.3599e-12, 8.6202e-13,
         8.1896e-13, 9.9334e-01, 2.7599e-13, 5.8025e-13, 4.5778e-11, 9.0327e-13,
         2.5252e-12, 1.0584e-12, 4.2788e-13, 1.3034e-12, 9.1512e-13],
        [1.5217e-09, 2.2965e-11, 4.4225e-12, 4.9379e-12, 3.3943e-07, 5.8317e-12,
         1.5779e-11, 2.0570e-06, 1.0671e-11, 4.3724e-12, 1.4888e-11, 7.1443e-03,
         4.3677e-11, 5.9849e-12, 2.3852e-12, 5.3240e-11, 6.7903e-12, 9.1523e-12,
         3.4885e-12, 9.9285e-01, 4.9520e-12, 2.8537e-12, 1.2948e-10, 3.5351e-12,
         9.6312e-12, 3.2826e-12, 2.4727e-12, 3.7681e-12, 5.3790e-12],
        [1.9617e-09, 6.5007e-11, 3.3151e-11, 2.5024e-11, 1.0169e-06, 2.1286e-11,
         8.2461e-11, 2.3230e-06, 3.5405e-11, 3.9789e-11, 4.4540e-11, 9.7989e-0

tensor(0.2563, grad_fn=<BinaryCrossEntropyBackward>) tensor([[2.9781e-10, 7.9466e-12, 2.5942e-12, 1.9142e-12, 3.6250e-07, 8.3042e-13,
         6.6581e-12, 1.9210e-06, 1.9140e-12, 3.1877e-12, 3.1783e-12, 7.3765e-04,
         1.1627e-11, 1.0611e-12, 6.7234e-13, 1.2222e-11, 1.7457e-12, 9.9477e-13,
         1.0090e-12, 9.9926e-01, 5.6665e-13, 7.8953e-13, 4.0049e-11, 5.8558e-13,
         6.3158e-12, 4.9184e-13, 3.5625e-13, 9.6885e-13, 6.1223e-13],
        [1.7348e-09, 2.2709e-10, 3.7546e-11, 1.0163e-11, 2.9872e-06, 3.6704e-11,
         4.3881e-10, 1.2766e-05, 4.4495e-11, 5.6233e-11, 5.4291e-11, 1.5556e-03,
         6.7572e-10, 1.1515e-11, 3.9996e-12, 6.0980e-10, 2.7500e-11, 2.0702e-11,
         1.2132e-11, 9.9843e-01, 6.5759e-12, 1.1847e-11, 1.5057e-09, 1.2146e-11,
         4.7734e-11, 1.8015e-11, 9.2580e-12, 1.4016e-11, 1.7468e-11],
        [5.7032e-10, 2.4044e-11, 8.8254e-12, 4.3591e-12, 6.3639e-07, 6.0455e-12,
         5.0142e-11, 8.5008e-06, 2.7274e-12, 3.7425e-12, 5.9413e-12, 1.7772e-0

tensor(0.2405, grad_fn=<BinaryCrossEntropyBackward>) tensor([[6.2739e-10, 1.5751e-11, 6.2323e-12, 3.4706e-12, 4.8102e-06, 5.0791e-12,
         2.0347e-11, 3.7450e-05, 5.5411e-12, 5.2080e-12, 1.1421e-11, 1.3140e-03,
         4.3221e-11, 1.8290e-12, 1.2558e-12, 4.0888e-11, 5.3793e-12, 3.6598e-12,
         3.2104e-12, 9.9864e-01, 1.3648e-12, 2.5333e-12, 1.2001e-10, 1.4313e-12,
         1.0878e-11, 4.3104e-12, 1.3755e-12, 1.3285e-12, 2.7254e-12],
        [6.4705e-10, 7.6392e-11, 1.5015e-11, 1.0169e-11, 7.6572e-06, 9.0797e-12,
         6.8789e-11, 4.8782e-05, 1.4998e-11, 4.7041e-11, 2.1983e-11, 8.0957e-04,
         1.4012e-10, 1.1443e-11, 5.3404e-12, 1.2901e-10, 1.0381e-11, 1.1958e-11,
         1.8405e-11, 9.9913e-01, 3.5009e-12, 3.4287e-12, 2.5937e-10, 7.2819e-12,
         9.1749e-12, 1.2527e-11, 5.1471e-12, 9.5471e-12, 7.9492e-12],
        [2.4100e-09, 9.3645e-11, 4.1443e-11, 5.3451e-11, 7.4581e-06, 4.1836e-11,
         5.4038e-11, 5.7781e-05, 4.6946e-11, 5.4158e-11, 5.0208e-11, 2.4122e-0

tensor(0.1883, grad_fn=<BinaryCrossEntropyBackward>) tensor([[3.5662e-09, 3.8324e-11, 3.5558e-12, 5.9770e-12, 9.2795e-05, 7.5212e-12,
         3.6181e-11, 8.8528e-04, 1.1314e-11, 6.7513e-12, 1.1144e-11, 9.6468e-03,
         2.3164e-11, 3.7739e-12, 3.1796e-12, 2.2130e-11, 6.0692e-12, 3.2308e-12,
         7.1561e-12, 9.8938e-01, 3.0487e-12, 1.9490e-12, 8.7427e-11, 2.0469e-12,
         1.1129e-11, 7.1190e-12, 2.8727e-12, 2.6612e-12, 5.5386e-12],
        [2.8979e-09, 3.0275e-11, 3.2536e-12, 8.0525e-12, 8.9661e-05, 1.0938e-11,
         1.1628e-11, 4.1894e-04, 1.2756e-11, 6.1674e-12, 5.9526e-12, 1.2142e-02,
         1.7726e-11, 3.7092e-12, 2.3855e-12, 4.9221e-11, 1.1906e-11, 3.5230e-12,
         1.0505e-11, 9.8735e-01, 2.6505e-12, 2.1822e-12, 1.2298e-10, 2.5918e-12,
         9.4965e-12, 7.4334e-12, 1.5024e-12, 6.9187e-12, 6.4400e-12],
        [2.4045e-09, 3.7225e-11, 6.4751e-11, 6.6273e-11, 4.6255e-05, 3.8331e-11,
         1.7064e-10, 6.5468e-04, 4.0230e-11, 3.0155e-11, 4.8072e-11, 5.1867e-0

tensor(0.1030, grad_fn=<BinaryCrossEntropyBackward>) tensor([[8.3873e-08, 2.0704e-09, 6.0726e-10, 1.0058e-09, 6.9785e-03, 6.5045e-10,
         5.2625e-09, 3.9863e-02, 1.1821e-09, 4.9796e-10, 1.1849e-09, 2.7314e-01,
         1.0931e-08, 3.5144e-10, 2.6230e-10, 5.0685e-09, 2.9669e-10, 2.3447e-10,
         7.9019e-10, 6.8002e-01, 1.9796e-10, 2.7562e-10, 1.0822e-08, 1.7275e-10,
         1.0316e-09, 6.7868e-10, 5.5711e-11, 3.1495e-10, 3.2996e-10],
        [7.6089e-08, 4.5415e-10, 1.2851e-10, 1.3088e-10, 2.1316e-02, 1.3057e-10,
         8.8668e-10, 8.9096e-02, 1.2761e-10, 9.2414e-11, 1.2763e-10, 3.3807e-01,
         4.1529e-10, 1.3211e-10, 2.6803e-11, 4.3054e-10, 5.7203e-11, 1.0908e-10,
         7.4869e-11, 5.5152e-01, 4.7521e-11, 7.4475e-11, 7.7680e-10, 6.2590e-11,
         2.1377e-10, 1.1823e-10, 4.0283e-11, 8.6286e-11, 6.8754e-11],
        [9.5562e-08, 2.5243e-09, 1.4072e-09, 6.7808e-10, 8.4082e-03, 1.0612e-09,
         5.4253e-09, 6.7425e-02, 1.3672e-09, 8.1762e-10, 1.3874e-09, 2.2863e-0

tensor(0.1769, grad_fn=<BinaryCrossEntropyBackward>) tensor([[4.2476e-08, 1.8298e-10, 2.8173e-11, 2.5732e-11, 2.2764e-02, 3.5008e-11,
         2.6873e-10, 1.9197e-01, 4.4233e-11, 5.0781e-11, 1.6305e-10, 7.1123e-01,
         2.9224e-10, 3.1559e-11, 1.1736e-11, 1.7558e-10, 2.7410e-11, 1.0667e-11,
         2.7342e-11, 7.4028e-02, 1.7767e-11, 1.0279e-11, 1.4456e-09, 1.7251e-11,
         7.4267e-11, 2.8100e-11, 1.1980e-11, 3.8872e-11, 9.3115e-11],
        [1.0026e-08, 3.7657e-11, 2.3210e-11, 4.6219e-12, 1.1085e-02, 1.5439e-11,
         8.4301e-11, 2.5252e-01, 2.3496e-11, 8.8323e-12, 2.2990e-11, 7.0909e-01,
         9.8205e-11, 4.5973e-12, 5.9697e-12, 8.9861e-11, 1.2785e-11, 5.2596e-12,
         1.8566e-11, 2.7311e-02, 2.9784e-12, 5.7679e-12, 3.2968e-10, 1.0975e-11,
         7.5151e-12, 1.0685e-11, 9.3724e-12, 6.3791e-12, 6.4935e-12],
        [6.0204e-08, 2.2559e-09, 6.5988e-10, 4.7738e-10, 2.9294e-02, 5.7500e-10,
         2.0883e-09, 4.5999e-01, 5.0162e-10, 4.4619e-10, 5.9877e-10, 4.5753e-0

tensor(0.2368, grad_fn=<BinaryCrossEntropyBackward>) tensor([[4.3754e-08, 2.5647e-10, 7.7263e-11, 1.1080e-10, 2.7817e-02, 1.4410e-10,
         5.3397e-10, 3.1070e-01, 3.0573e-10, 1.1886e-10, 2.0137e-10, 6.5850e-01,
         1.5388e-09, 9.4388e-11, 3.8403e-11, 8.1609e-10, 1.2496e-10, 1.0603e-10,
         1.2302e-10, 2.9906e-03, 2.3322e-11, 9.2172e-11, 5.8478e-09, 8.4943e-11,
         8.0628e-10, 1.0555e-10, 1.2163e-11, 3.3750e-11, 3.9510e-11],
        [7.2724e-08, 6.6441e-10, 4.3194e-10, 4.0364e-10, 7.3388e-02, 4.0191e-10,
         3.0047e-10, 4.2817e-01, 3.5073e-10, 2.0213e-10, 4.9002e-10, 4.9596e-01,
         1.1479e-09, 3.8242e-10, 1.7261e-10, 1.1522e-09, 2.3187e-10, 1.6551e-10,
         2.3673e-10, 2.4730e-03, 4.4456e-11, 5.7183e-11, 6.9628e-09, 1.8525e-10,
         7.0993e-10, 3.3564e-10, 7.8940e-11, 3.0461e-10, 2.6307e-10],
        [1.4843e-08, 7.1471e-11, 2.8111e-11, 3.1568e-11, 5.3599e-02, 1.5059e-11,
         1.2011e-10, 5.3291e-01, 1.6505e-11, 4.1040e-11, 4.0861e-11, 4.1033e-0

tensor(0.2624, grad_fn=<BinaryCrossEntropyBackward>) tensor([[3.6768e-08, 7.9178e-11, 3.8848e-11, 3.7420e-11, 1.8093e-01, 2.1135e-11,
         1.4835e-10, 5.3437e-01, 3.2870e-11, 3.8092e-11, 4.2082e-11, 2.8294e-01,
         3.3082e-10, 1.6492e-11, 3.5492e-12, 4.1659e-10, 2.2798e-11, 1.7433e-11,
         2.3585e-11, 1.7673e-03, 5.1124e-12, 8.5699e-12, 2.3549e-09, 5.6115e-12,
         4.1224e-11, 1.7312e-11, 6.2150e-12, 3.9471e-11, 1.1110e-11],
        [1.6979e-08, 1.6388e-10, 7.5494e-11, 7.5691e-11, 1.0262e-01, 5.6133e-11,
         6.5254e-10, 8.1381e-01, 7.2148e-11, 6.7489e-11, 1.6923e-10, 8.3080e-02,
         1.3492e-09, 3.0132e-11, 9.2107e-12, 5.1942e-10, 4.0869e-11, 2.6332e-11,
         4.9019e-11, 4.9018e-04, 2.4556e-11, 4.2233e-11, 2.6162e-09, 2.5746e-11,
         1.4832e-10, 6.0316e-11, 8.0546e-12, 3.7697e-11, 5.9896e-11],
        [7.7399e-08, 1.9290e-10, 4.2314e-11, 4.2080e-11, 1.9712e-01, 2.2662e-11,
         9.1001e-11, 5.4276e-01, 4.6255e-11, 3.6424e-11, 8.6918e-11, 2.5924e-0

tensor(0.2864, grad_fn=<BinaryCrossEntropyBackward>) tensor([[1.3220e-07, 4.3034e-10, 5.5975e-11, 6.5377e-11, 6.6776e-01, 1.4102e-10,
         3.3941e-10, 1.5828e-01, 9.6843e-11, 1.0113e-10, 1.0469e-10, 1.7333e-01,
         1.4483e-09, 1.0131e-10, 2.3951e-11, 1.6018e-10, 5.2905e-11, 5.8924e-11,
         6.0079e-11, 6.3254e-04, 2.0617e-11, 1.5557e-11, 2.1735e-09, 1.9291e-11,
         2.5398e-10, 1.0097e-10, 1.9784e-11, 4.9277e-11, 1.2678e-10],
        [9.3294e-08, 3.6135e-10, 1.1510e-10, 8.9277e-11, 7.5266e-01, 1.1264e-10,
         8.1362e-10, 2.1047e-01, 8.9014e-11, 1.0462e-10, 6.1752e-10, 3.6758e-02,
         1.9053e-09, 5.7131e-11, 4.3456e-11, 9.4649e-10, 1.5366e-10, 9.1313e-11,
         1.2760e-10, 1.1405e-04, 3.9973e-11, 4.0156e-11, 1.7678e-09, 6.9067e-11,
         2.4813e-10, 1.0295e-10, 2.3149e-11, 5.5769e-11, 1.3001e-10],
        [1.0313e-07, 1.6427e-10, 5.4775e-11, 3.3487e-11, 7.3743e-01, 2.6797e-11,
         4.3279e-10, 1.5246e-01, 3.1460e-11, 1.0420e-10, 9.5182e-11, 1.0974e-0

tensor(0.2806, grad_fn=<BinaryCrossEntropyBackward>) tensor([[3.6299e-08, 2.1423e-11, 1.4780e-11, 9.5137e-12, 6.6153e-01, 1.5426e-11,
         3.4497e-11, 4.1558e-02, 2.0460e-11, 1.2264e-11, 2.7708e-11, 2.9587e-01,
         4.8542e-11, 1.4439e-11, 5.4443e-12, 9.7299e-11, 9.7189e-12, 5.1419e-12,
         8.1887e-12, 1.0400e-03, 9.4006e-13, 6.1943e-12, 5.4447e-10, 3.4145e-12,
         3.4720e-11, 5.4790e-12, 1.8431e-12, 8.1890e-12, 4.3258e-12],
        [3.0064e-08, 1.0990e-10, 6.5243e-11, 1.9836e-11, 5.3211e-01, 2.3944e-11,
         2.3362e-10, 2.5635e-02, 3.2547e-11, 4.3935e-11, 8.4977e-11, 4.4156e-01,
         6.4454e-10, 7.7677e-12, 4.1512e-12, 5.3319e-10, 1.0150e-11, 1.0787e-11,
         1.2614e-11, 6.9303e-04, 3.9774e-12, 7.9030e-12, 2.7027e-09, 4.9494e-12,
         8.7954e-11, 1.2689e-11, 1.0743e-12, 2.4306e-11, 1.2604e-11],
        [6.8844e-08, 2.3964e-10, 1.3001e-10, 1.4185e-10, 5.1563e-01, 1.5674e-10,
         1.7722e-10, 3.1840e-02, 9.8603e-11, 4.7967e-11, 5.1416e-11, 4.5122e-0

tensor(0.2332, grad_fn=<BinaryCrossEntropyBackward>) tensor([[8.0577e-07, 6.9644e-10, 2.5648e-10, 2.8859e-10, 1.6269e-01, 1.1924e-10,
         1.0320e-09, 2.2744e-02, 2.0319e-10, 1.1898e-10, 5.2405e-10, 8.1229e-01,
         2.4815e-09, 2.6640e-10, 3.2081e-11, 2.9151e-09, 9.3710e-11, 1.7718e-10,
         1.5190e-10, 2.2769e-03, 3.9233e-11, 4.1413e-11, 3.5179e-09, 9.5556e-11,
         3.2651e-10, 8.9598e-11, 2.9319e-11, 2.6709e-10, 1.4998e-10],
        [7.4538e-08, 9.7849e-11, 3.6758e-11, 2.1089e-11, 6.1259e-02, 2.3222e-11,
         1.3991e-10, 5.1177e-03, 3.2707e-11, 4.5632e-11, 6.6707e-11, 9.3227e-01,
         3.1612e-10, 5.8298e-11, 6.6790e-12, 1.3801e-10, 1.6151e-11, 1.4584e-11,
         3.1317e-11, 1.3502e-03, 2.9401e-12, 1.1016e-11, 6.4553e-10, 1.5699e-11,
         4.3816e-11, 1.3660e-11, 4.6271e-12, 4.7911e-11, 2.1279e-11],
        [6.4439e-07, 4.2697e-10, 4.5416e-10, 8.1326e-11, 8.9495e-02, 1.3627e-10,
         1.5718e-09, 4.3457e-02, 1.7280e-10, 1.7235e-10, 3.0597e-10, 8.6362e-0

tensor(0.2199, grad_fn=<BinaryCrossEntropyBackward>) tensor([[3.1035e-06, 1.3758e-09, 7.3714e-10, 4.5583e-10, 3.2048e-02, 5.8462e-10,
         1.9385e-09, 2.4225e-02, 5.1498e-10, 3.1181e-10, 8.0545e-10, 9.2455e-01,
         5.9627e-09, 3.3085e-10, 1.5672e-10, 1.7215e-09, 1.9549e-10, 1.9573e-10,
         7.1145e-10, 1.9176e-02, 6.3246e-11, 1.8984e-10, 7.4754e-09, 1.6066e-10,
         1.4193e-09, 3.2307e-10, 4.6956e-11, 4.0864e-10, 1.6241e-10],
        [3.1183e-06, 5.8336e-10, 2.8420e-11, 3.2547e-11, 4.6868e-02, 2.7503e-11,
         6.5226e-10, 2.6467e-02, 6.4340e-11, 6.0347e-11, 1.2139e-10, 9.1019e-01,
         1.3009e-09, 2.4633e-11, 1.0297e-11, 6.2243e-10, 1.8398e-11, 1.2733e-11,
         1.5884e-11, 1.6475e-02, 7.4389e-12, 8.4415e-12, 1.6287e-09, 1.1384e-11,
         6.0998e-11, 2.3351e-11, 4.0884e-12, 3.3916e-11, 3.6786e-11],
        [1.0385e-06, 1.6563e-10, 2.2781e-11, 1.0165e-11, 2.7516e-02, 1.7464e-11,
         3.4017e-10, 2.1343e-02, 1.6946e-11, 5.6188e-11, 1.0577e-10, 9.3540e-0

tensor(0.1172, grad_fn=<BinaryCrossEntropyBackward>) tensor([[3.1795e-06, 2.4305e-10, 8.5680e-11, 5.3034e-11, 1.2561e-02, 4.0221e-11,
         5.5465e-10, 1.6131e-02, 2.7977e-11, 1.0180e-10, 7.1582e-11, 4.0971e-01,
         5.0286e-10, 4.0941e-11, 1.3553e-11, 4.4505e-10, 3.4823e-11, 2.3146e-11,
         1.6752e-11, 5.6160e-01, 1.1193e-11, 1.5158e-11, 4.3136e-09, 1.9074e-11,
         9.5776e-11, 4.4020e-11, 1.1375e-11, 3.5561e-11, 3.7021e-11],
        [8.2755e-06, 2.6804e-09, 5.3260e-10, 4.8659e-10, 2.0589e-02, 6.1055e-10,
         3.5647e-09, 2.2312e-02, 5.6187e-10, 7.9035e-10, 9.3918e-10, 7.4700e-01,
         7.7206e-09, 7.8541e-10, 1.1515e-10, 3.4530e-09, 1.6274e-10, 2.2611e-10,
         2.0454e-10, 2.1009e-01, 8.9488e-11, 1.3029e-10, 3.6681e-08, 1.7079e-10,
         2.9821e-09, 2.1712e-10, 8.3619e-11, 4.3138e-10, 1.8794e-10],
        [5.4248e-06, 2.3356e-10, 1.0285e-10, 5.8537e-11, 2.7408e-02, 6.3375e-11,
         8.7738e-10, 3.3747e-02, 4.4938e-11, 1.1961e-10, 8.8775e-11, 4.5100e-0

tensor(0.1210, grad_fn=<BinaryCrossEntropyBackward>) tensor([[8.6096e-06, 5.5264e-10, 9.4206e-11, 5.4660e-11, 4.1081e-03, 5.6055e-11,
         1.2975e-09, 2.0701e-02, 1.0219e-10, 9.3061e-11, 9.8593e-11, 5.0962e-02,
         1.2015e-09, 9.5604e-11, 2.6649e-11, 1.6465e-09, 4.5080e-11, 4.4186e-11,
         4.0061e-11, 9.2422e-01, 1.4336e-11, 2.2400e-11, 1.1927e-09, 3.8707e-11,
         1.7951e-10, 5.1393e-11, 1.7878e-11, 5.8566e-11, 8.2052e-11],
        [8.8290e-05, 4.2147e-10, 6.0213e-10, 3.4686e-10, 5.4191e-03, 3.2813e-10,
         9.1929e-10, 1.4433e-02, 5.1506e-10, 2.6937e-10, 5.9447e-10, 5.2638e-02,
         1.4828e-09, 3.9649e-10, 4.3340e-10, 6.6745e-10, 5.4588e-10, 4.4717e-10,
         4.0449e-10, 9.2742e-01, 9.2115e-11, 1.9246e-10, 2.8912e-09, 2.8362e-10,
         9.3889e-10, 4.5738e-10, 5.9763e-11, 4.5324e-10, 1.4219e-10],
        [8.4255e-06, 2.7006e-10, 1.1346e-10, 1.5332e-10, 1.9901e-03, 8.8611e-11,
         6.9218e-10, 3.7507e-03, 1.7708e-10, 5.1956e-11, 1.8909e-10, 3.2316e-0

tensor(0.1500, grad_fn=<BinaryCrossEntropyBackward>) tensor([[1.1009e-04, 4.1723e-10, 1.8568e-10, 7.0227e-10, 7.7737e-04, 6.8575e-10,
         5.0136e-10, 8.6642e-03, 7.9415e-10, 2.2738e-10, 3.3533e-10, 1.3754e-03,
         1.5491e-09, 3.9990e-10, 6.7076e-11, 1.1886e-09, 2.7319e-10, 1.2830e-10,
         1.6147e-10, 9.8907e-01, 4.2557e-11, 1.0945e-10, 2.2624e-09, 5.3251e-10,
         8.7287e-10, 2.8799e-10, 2.3735e-11, 2.1884e-10, 3.0197e-10],
        [8.0605e-05, 4.3377e-10, 4.7884e-10, 2.3068e-10, 1.7974e-03, 1.6965e-10,
         3.2995e-09, 1.0937e-02, 4.7044e-10, 5.2569e-10, 5.1962e-10, 2.1528e-03,
         5.1661e-09, 2.4213e-10, 3.5908e-11, 3.8919e-09, 1.8241e-10, 1.0128e-10,
         2.8972e-10, 9.8503e-01, 3.6639e-11, 3.8711e-11, 8.3045e-09, 9.7059e-11,
         6.9548e-10, 1.5254e-10, 3.1445e-11, 3.5752e-10, 1.6377e-10],
        [2.8367e-05, 1.5017e-10, 4.3532e-11, 2.7541e-11, 7.3159e-04, 2.4461e-11,
         2.8100e-10, 1.0871e-02, 1.1580e-11, 5.8337e-11, 2.1677e-11, 1.8385e-0

tensor(0.1989, grad_fn=<BinaryCrossEntropyBackward>) tensor([[4.4917e-05, 5.0469e-11, 6.9893e-11, 2.6267e-11, 8.8012e-05, 1.7546e-11,
         3.7911e-10, 2.0645e-03, 3.4910e-11, 6.3457e-11, 5.6446e-11, 3.6256e-05,
         1.2882e-10, 5.0674e-11, 4.8090e-12, 1.7802e-10, 3.5967e-11, 2.9460e-11,
         3.3027e-11, 9.9777e-01, 3.8646e-12, 1.8003e-11, 5.1633e-10, 1.5955e-11,
         1.6785e-10, 5.6007e-11, 1.6218e-12, 4.4516e-11, 1.3711e-11],
        [9.7604e-06, 1.9132e-12, 1.2542e-12, 6.5568e-13, 6.9775e-05, 8.3112e-13,
         5.4715e-12, 1.3657e-03, 6.3424e-13, 3.2163e-13, 2.8871e-12, 3.7931e-05,
         1.2000e-11, 6.6879e-13, 3.0330e-13, 2.9259e-12, 4.4049e-13, 9.0308e-13,
         7.8940e-13, 9.9852e-01, 3.1851e-13, 2.5435e-13, 1.5378e-11, 6.9031e-13,
         2.5415e-12, 5.1554e-13, 1.8065e-13, 4.7279e-13, 4.3932e-13],
        [2.5225e-05, 5.3131e-11, 2.8977e-11, 1.9212e-11, 6.4263e-05, 1.1284e-11,
         2.2355e-10, 2.3574e-03, 2.2018e-11, 1.0833e-11, 4.1380e-11, 3.9655e-0

tensor(0.1919, grad_fn=<BinaryCrossEntropyBackward>) tensor([[1.6910e-04, 1.6638e-10, 4.2371e-11, 2.3253e-11, 9.3665e-05, 1.3728e-11,
         1.8792e-10, 3.5009e-02, 9.6409e-12, 2.1040e-11, 1.9894e-11, 2.8343e-05,
         8.1556e-10, 8.6951e-12, 1.2264e-12, 2.2916e-10, 4.2020e-12, 4.9103e-12,
         5.5875e-12, 9.6470e-01, 2.1791e-12, 3.2985e-12, 3.4230e-09, 8.2089e-12,
         3.4672e-11, 6.7277e-12, 2.5064e-12, 1.2429e-11, 1.7168e-11],
        [1.9146e-04, 1.1766e-11, 5.1253e-12, 7.9097e-12, 9.0520e-05, 7.3378e-12,
         1.4640e-11, 8.7686e-03, 9.2039e-12, 8.1629e-12, 7.3335e-12, 1.2108e-05,
         4.2871e-11, 5.7473e-12, 2.1029e-12, 1.7793e-11, 5.2608e-12, 4.5512e-12,
         3.5379e-12, 9.9094e-01, 2.2689e-12, 1.9181e-12, 1.2442e-10, 2.2941e-12,
         9.0521e-12, 3.9602e-12, 1.3288e-12, 4.9104e-12, 4.3346e-12],
        [1.3584e-04, 1.9727e-11, 6.9542e-12, 8.2305e-12, 4.3810e-05, 1.0873e-11,
         3.4584e-11, 1.7347e-03, 1.1017e-11, 8.1734e-12, 8.3226e-12, 4.9013e-0

tensor(0.1656, grad_fn=<BinaryCrossEntropyBackward>) tensor([[5.1807e-04, 1.0158e-11, 1.8335e-11, 1.3225e-11, 2.3864e-05, 7.9139e-12,
         1.7618e-11, 2.0117e-02, 1.1586e-11, 5.1311e-12, 3.6413e-12, 8.1978e-07,
         2.1303e-11, 9.3225e-12, 1.0740e-12, 1.7743e-11, 6.7974e-12, 7.2378e-12,
         5.7935e-12, 9.7934e-01, 1.7510e-12, 2.2064e-12, 3.4145e-10, 3.0260e-12,
         2.6769e-11, 5.7195e-12, 2.5301e-13, 1.0581e-11, 4.0882e-12],
        [2.0661e-03, 1.8346e-11, 9.0292e-12, 6.0175e-12, 7.3700e-05, 4.2259e-12,
         8.0481e-11, 6.6405e-02, 4.8575e-12, 4.5720e-12, 1.2666e-11, 3.0397e-06,
         1.4685e-10, 4.9212e-12, 2.4994e-12, 3.6961e-11, 9.9951e-12, 6.0554e-12,
         2.6958e-12, 9.3145e-01, 2.4519e-12, 3.2883e-12, 3.6243e-10, 2.5221e-12,
         2.1561e-11, 3.4281e-12, 8.4888e-13, 9.4849e-12, 4.1769e-12],
        [5.9903e-04, 3.6251e-11, 3.8825e-11, 5.3034e-12, 2.8096e-05, 9.6163e-12,
         9.1718e-11, 5.2657e-02, 1.6647e-11, 2.2076e-11, 1.3112e-11, 3.0209e-0

tensor(0.1637, grad_fn=<BinaryCrossEntropyBackward>) tensor([[3.8703e-03, 1.9580e-11, 2.5849e-11, 5.2473e-12, 6.3980e-05, 3.4567e-12,
         8.9241e-11, 5.4283e-01, 6.3663e-12, 1.3120e-11, 1.1003e-11, 1.3943e-06,
         5.1210e-11, 1.4549e-11, 5.2234e-12, 3.5241e-11, 5.7069e-12, 8.7301e-12,
         1.4707e-11, 4.5324e-01, 2.6881e-12, 2.9138e-12, 1.5616e-10, 3.0595e-12,
         1.6139e-11, 5.4413e-12, 2.3098e-12, 7.2277e-12, 2.2281e-12],
        [2.8018e-03, 1.9156e-10, 3.5382e-11, 3.4162e-11, 3.9220e-05, 3.3789e-11,
         3.2004e-10, 6.2959e-01, 5.0578e-11, 1.0713e-10, 7.9392e-11, 7.1401e-07,
         8.8186e-10, 3.0972e-11, 4.0725e-12, 2.2282e-10, 2.8244e-11, 2.5735e-11,
         2.8649e-11, 3.6757e-01, 4.9998e-12, 9.8198e-12, 6.4150e-09, 1.1508e-11,
         2.6414e-10, 3.1481e-11, 2.0684e-12, 5.5980e-11, 9.6992e-12],
        [3.5770e-03, 2.7092e-10, 5.0564e-11, 4.4374e-11, 8.7507e-05, 9.2368e-11,
         6.4026e-10, 4.8271e-01, 3.2067e-11, 2.0746e-11, 5.7683e-11, 6.9294e-0

tensor(0.2256, grad_fn=<BinaryCrossEntropyBackward>) tensor([[7.3313e-03, 3.9123e-10, 7.8751e-11, 1.1770e-10, 1.6130e-05, 4.6724e-11,
         2.0362e-10, 9.0053e-01, 5.7541e-11, 1.0606e-10, 3.0001e-11, 1.0422e-07,
         2.7982e-10, 3.2856e-11, 1.1362e-11, 5.0047e-10, 1.8994e-11, 2.8744e-11,
         3.3168e-11, 9.2122e-02, 5.1903e-12, 1.3144e-11, 1.7017e-09, 9.6963e-12,
         2.9045e-10, 3.3785e-11, 2.3255e-12, 6.2346e-11, 2.4050e-11],
        [3.4680e-02, 3.4528e-11, 1.0807e-11, 6.1758e-12, 1.7225e-05, 5.4245e-12,
         4.4159e-11, 8.8224e-01, 6.1646e-12, 1.2548e-11, 1.2617e-11, 9.1997e-08,
         5.7957e-11, 8.4781e-12, 4.5312e-13, 3.8341e-11, 2.6216e-12, 4.5780e-12,
         1.4930e-12, 8.3059e-02, 7.7268e-13, 3.2548e-12, 1.8958e-09, 2.3412e-12,
         4.6587e-11, 4.6187e-12, 2.2742e-13, 1.0921e-11, 1.7554e-12],
        [1.2330e-02, 1.2158e-11, 7.8861e-12, 2.8379e-12, 2.7971e-05, 3.9688e-12,
         1.6345e-11, 9.4113e-01, 5.2040e-12, 4.9432e-12, 3.0728e-12, 3.1551e-0

tensor(0.2398, grad_fn=<BinaryCrossEntropyBackward>) tensor([[1.5183e-02, 2.3830e-11, 1.2509e-11, 1.2612e-11, 8.8423e-06, 9.8685e-12,
         2.9940e-10, 9.2864e-01, 1.4231e-11, 1.1798e-11, 2.4445e-11, 3.3121e-08,
         1.1846e-10, 4.5060e-12, 6.9324e-13, 1.9132e-10, 3.4309e-12, 6.4157e-12,
         4.2199e-12, 5.6164e-02, 1.0911e-12, 1.8006e-12, 1.9253e-09, 2.0987e-12,
         2.7033e-11, 3.5988e-12, 2.3729e-13, 1.0888e-11, 2.7741e-12],
        [3.7168e-02, 7.5972e-11, 3.9257e-11, 7.1274e-11, 3.5245e-05, 3.8311e-11,
         2.4782e-10, 9.2342e-01, 6.2464e-11, 2.1343e-11, 5.8302e-11, 2.4956e-08,
         1.0146e-09, 4.2297e-11, 1.6492e-12, 2.7706e-10, 1.0925e-11, 1.3338e-11,
         1.6254e-11, 3.9380e-02, 6.1586e-13, 3.3629e-12, 6.0235e-09, 4.6469e-12,
         2.2305e-10, 2.9367e-11, 1.5310e-13, 2.2805e-11, 6.2391e-12],
        [3.7914e-02, 1.6721e-10, 1.1864e-10, 8.0478e-11, 3.7646e-05, 1.0372e-10,
         4.2448e-10, 7.6901e-01, 5.0446e-11, 2.9354e-11, 5.5531e-11, 6.4862e-0

tensor(0.1806, grad_fn=<BinaryCrossEntropyBackward>) tensor([[2.6820e-01, 3.5686e-11, 3.8356e-11, 1.4021e-11, 7.3947e-05, 1.9673e-11,
         1.7914e-10, 4.7885e-01, 2.9287e-11, 2.2778e-11, 3.2279e-11, 6.3485e-08,
         2.4990e-10, 1.5936e-11, 1.7579e-12, 8.1175e-11, 5.6665e-12, 8.6121e-12,
         1.5414e-11, 2.5287e-01, 1.4871e-12, 3.0681e-12, 2.1749e-09, 3.6662e-12,
         5.0398e-11, 2.4464e-11, 4.9771e-13, 1.6556e-11, 5.3584e-12],
        [5.9879e-01, 5.3953e-10, 2.7881e-10, 3.8616e-10, 1.4137e-04, 2.0796e-10,
         1.5197e-09, 2.8681e-01, 2.7775e-10, 2.1972e-10, 7.5075e-10, 5.1302e-08,
         1.6493e-09, 1.0673e-10, 1.8902e-11, 1.2936e-09, 2.2996e-10, 1.6129e-10,
         1.6666e-10, 1.1426e-01, 5.9643e-11, 1.2534e-10, 4.2913e-09, 1.0843e-10,
         3.7578e-10, 4.5744e-10, 1.9262e-11, 1.9034e-10, 1.7176e-10],
        [2.9996e-01, 1.0756e-09, 5.5764e-09, 2.4479e-09, 2.2664e-04, 2.2647e-09,
         8.2396e-09, 5.6095e-01, 3.7420e-09, 1.2422e-09, 1.1326e-09, 2.0470e-0

tensor(0.1742, grad_fn=<BinaryCrossEntropyBackward>) tensor([[6.0151e-01, 1.0283e-10, 2.5568e-11, 2.2656e-11, 4.2784e-05, 1.9674e-11,
         4.4812e-11, 2.5017e-02, 2.5602e-11, 2.7784e-11, 1.2589e-11, 1.2147e-08,
         1.1377e-10, 1.8863e-11, 5.0090e-12, 4.5076e-11, 8.8812e-12, 6.7047e-12,
         1.7967e-11, 3.7343e-01, 2.8570e-12, 6.9004e-12, 1.0423e-10, 1.0257e-11,
         2.4674e-11, 2.0629e-11, 1.3607e-12, 1.9045e-11, 1.3461e-11],
        [6.2839e-01, 1.3491e-10, 8.6403e-11, 6.7261e-11, 5.1352e-05, 1.2231e-10,
         6.6583e-10, 8.3142e-02, 1.1919e-10, 7.4123e-11, 1.7352e-10, 1.0735e-07,
         9.7500e-10, 2.9297e-11, 1.4518e-11, 6.3645e-10, 5.4134e-11, 5.1947e-11,
         1.0069e-10, 2.8842e-01, 1.3852e-11, 4.1507e-11, 6.6598e-09, 3.2439e-11,
         2.4314e-10, 1.2073e-10, 3.5894e-12, 4.2476e-11, 2.3949e-11],
        [4.2991e-01, 2.9448e-09, 5.5191e-10, 1.6575e-10, 1.1068e-04, 3.1855e-10,
         1.7389e-09, 3.2506e-02, 1.2690e-10, 5.1610e-10, 3.2766e-10, 5.1789e-0

tensor(0.1863, grad_fn=<BinaryCrossEntropyBackward>) tensor([[2.9546e-01, 1.8868e-11, 9.6322e-12, 1.3432e-11, 2.5940e-05, 1.0874e-11,
         7.8675e-11, 5.7752e-03, 9.1793e-12, 1.9657e-11, 1.1640e-11, 3.1239e-08,
         2.9048e-11, 1.6175e-11, 4.9030e-12, 4.7488e-11, 1.8617e-11, 4.6387e-12,
         5.6100e-12, 6.9874e-01, 3.0141e-12, 5.5651e-12, 1.2974e-10, 5.6649e-12,
         2.0144e-11, 1.9768e-11, 5.1478e-13, 1.8005e-11, 1.6946e-11],
        [4.9017e-01, 3.0102e-11, 1.0828e-11, 2.9146e-12, 2.2251e-05, 3.0046e-12,
         1.6692e-10, 5.1276e-03, 2.3084e-12, 1.5489e-11, 1.8521e-11, 5.4318e-09,
         2.1975e-10, 1.5111e-12, 3.9948e-13, 6.4857e-11, 1.4319e-12, 1.5904e-12,
         1.0249e-12, 5.0468e-01, 6.0785e-13, 6.2036e-13, 8.8353e-10, 6.5715e-13,
         2.8936e-11, 1.5378e-12, 1.1243e-13, 3.2581e-12, 1.3426e-12],
        [4.3888e-01, 2.6027e-10, 8.5590e-11, 2.5172e-11, 5.8865e-05, 2.9083e-11,
         3.3883e-10, 1.9376e-02, 1.5028e-11, 8.0523e-11, 1.4007e-10, 7.7440e-0

tensor(0.1923, grad_fn=<BinaryCrossEntropyBackward>) tensor([[5.9893e-01, 4.1752e-11, 6.5128e-12, 8.5666e-12, 4.2250e-05, 9.5775e-12,
         1.2452e-10, 5.2619e-04, 8.7667e-12, 1.0447e-11, 6.2502e-12, 3.5208e-09,
         8.2779e-11, 5.6789e-12, 9.1179e-13, 6.8947e-11, 3.3510e-12, 6.6669e-12,
         2.3543e-12, 4.0050e-01, 8.2153e-13, 2.8123e-12, 7.9265e-10, 1.6441e-12,
         3.3490e-11, 4.6236e-12, 1.9383e-13, 6.4076e-12, 2.3322e-12],
        [6.2753e-02, 2.2155e-12, 3.9332e-12, 4.3026e-12, 4.7236e-06, 2.4449e-12,
         6.2547e-12, 1.2404e-04, 1.6198e-12, 1.9264e-12, 1.2472e-12, 6.2381e-09,
         8.2870e-12, 2.2692e-12, 1.0310e-13, 4.6243e-12, 1.7893e-12, 1.3189e-12,
         1.2367e-12, 9.3712e-01, 2.0078e-13, 6.6703e-13, 2.1198e-10, 7.0307e-13,
         9.3247e-12, 1.3321e-12, 3.9468e-14, 2.1127e-12, 2.7366e-13],
        [8.1275e-02, 8.7153e-12, 9.4350e-12, 8.5470e-12, 1.0652e-05, 3.9194e-12,
         3.2494e-11, 5.5356e-04, 5.6461e-12, 4.0930e-12, 5.9963e-12, 1.4970e-0

tensor(0.2087, grad_fn=<BinaryCrossEntropyBackward>) tensor([[8.7294e-02, 8.1297e-11, 1.9016e-11, 3.5237e-11, 6.2816e-05, 1.7132e-11,
         4.6749e-11, 1.0764e-04, 1.5790e-11, 2.2671e-11, 1.4880e-11, 5.4776e-09,
         2.0275e-10, 1.1495e-11, 8.8074e-13, 7.5480e-11, 3.1923e-12, 1.1866e-11,
         7.3746e-12, 9.1254e-01, 3.5037e-12, 3.5087e-12, 7.3071e-10, 2.4623e-12,
         1.4437e-10, 1.0042e-11, 4.1840e-13, 1.4389e-11, 2.9692e-12],
        [8.8205e-03, 1.8029e-12, 3.4783e-13, 3.4790e-13, 4.6749e-06, 4.7350e-13,
         6.0734e-13, 6.3853e-06, 2.1289e-13, 2.6595e-13, 7.8303e-13, 7.4375e-09,
         1.1215e-12, 1.6424e-13, 5.1405e-14, 3.3027e-13, 2.4409e-13, 1.6123e-13,
         1.3652e-13, 9.9117e-01, 3.5232e-14, 1.1422e-13, 1.9660e-11, 9.3661e-14,
         1.2257e-12, 1.4901e-13, 2.3138e-14, 2.4063e-13, 1.1932e-13],
        [2.6826e-02, 6.9642e-12, 9.6232e-12, 8.6858e-12, 1.1603e-05, 1.5861e-11,
         6.0103e-11, 6.0143e-05, 8.3699e-12, 6.7572e-12, 1.2789e-11, 1.6042e-0

tensor(0.2255, grad_fn=<BinaryCrossEntropyBackward>) tensor([[2.0360e-02, 1.6355e-12, 3.9805e-12, 1.9225e-12, 1.4453e-05, 6.5325e-12,
         4.6515e-12, 6.9815e-06, 2.2907e-12, 3.6180e-12, 3.5889e-12, 3.4036e-08,
         3.1425e-12, 4.9155e-12, 1.6723e-13, 2.5620e-12, 1.3549e-12, 1.4783e-12,
         3.3248e-12, 9.7962e-01, 9.3186e-13, 1.4949e-12, 1.1674e-11, 1.7205e-12,
         6.1456e-12, 2.0477e-12, 1.0731e-13, 1.7831e-12, 8.5991e-13],
        [1.7449e-02, 5.2919e-10, 2.5657e-11, 2.6419e-11, 2.0640e-05, 2.3387e-11,
         4.5538e-10, 1.7437e-05, 4.8899e-11, 1.0879e-10, 1.4877e-10, 5.1213e-09,
         1.5072e-09, 2.0721e-11, 9.7960e-13, 1.0673e-09, 1.1452e-11, 7.8737e-12,
         3.5868e-11, 9.8251e-01, 1.2749e-11, 8.9315e-12, 7.2290e-09, 1.1915e-11,
         2.3032e-10, 2.0885e-11, 1.4460e-12, 2.4688e-11, 3.9121e-11],
        [3.2246e-03, 2.8717e-12, 8.1892e-12, 2.6267e-12, 1.0291e-05, 1.9432e-12,
         2.0090e-11, 1.1653e-05, 2.6150e-12, 4.3040e-12, 1.5430e-12, 2.6263e-0

tensor(0.2143, grad_fn=<BinaryCrossEntropyBackward>) tensor([[2.4647e-02, 2.9808e-11, 1.5420e-11, 8.8063e-12, 5.4775e-05, 1.1237e-11,
         9.3563e-11, 4.7171e-06, 6.9667e-12, 1.1973e-11, 1.6621e-11, 2.1762e-08,
         1.6716e-10, 6.5562e-12, 1.5002e-12, 1.1547e-10, 3.0082e-12, 6.4890e-12,
         7.5959e-12, 9.7529e-01, 3.0062e-12, 1.7382e-12, 1.9249e-10, 2.4706e-12,
         2.2895e-11, 5.5807e-12, 5.2784e-13, 6.2968e-12, 6.1270e-12],
        [3.4240e-02, 3.2152e-12, 8.0156e-13, 9.2365e-13, 2.6377e-05, 7.9280e-13,
         3.0657e-11, 4.5516e-06, 8.4327e-13, 2.2171e-12, 5.7732e-12, 7.0783e-09,
         4.0515e-11, 5.4298e-13, 1.3243e-13, 6.5779e-11, 1.9969e-13, 3.3750e-13,
         2.2594e-13, 9.6573e-01, 7.4060e-14, 2.3618e-13, 1.5716e-10, 1.3297e-13,
         4.9121e-12, 5.9303e-13, 2.6384e-14, 4.8009e-13, 4.4686e-13],
        [1.1905e-02, 5.4275e-11, 2.3693e-11, 1.6960e-11, 2.3286e-05, 1.9846e-11,
         1.5808e-10, 3.5584e-06, 5.5405e-12, 2.7588e-11, 2.1977e-11, 2.4077e-0

tensor(0.2082, grad_fn=<BinaryCrossEntropyBackward>) tensor([[2.8628e-02, 1.6487e-11, 1.7421e-12, 9.9990e-13, 2.9385e-05, 1.2528e-12,
         1.6594e-11, 2.3263e-06, 1.3140e-12, 6.7072e-12, 4.4044e-12, 1.5209e-08,
         1.7214e-11, 7.8094e-13, 1.8185e-13, 1.5222e-11, 3.7056e-13, 5.6579e-13,
         2.5668e-13, 9.7134e-01, 3.6693e-13, 8.3166e-13, 1.5039e-10, 1.5107e-13,
         1.0034e-11, 9.2910e-13, 8.9507e-14, 1.3821e-12, 9.5015e-13],
        [1.1454e-02, 6.9128e-11, 4.2478e-11, 7.3984e-12, 1.8624e-05, 7.5440e-12,
         6.2138e-11, 2.0146e-06, 9.2811e-12, 3.4230e-11, 2.6175e-11, 1.6148e-08,
         1.2141e-10, 2.0388e-11, 1.5403e-12, 6.4950e-11, 7.2499e-12, 1.0354e-11,
         6.3006e-12, 9.8852e-01, 4.0289e-12, 5.2816e-12, 3.9065e-10, 9.3608e-12,
         3.6958e-11, 8.9915e-12, 1.3043e-12, 8.9469e-12, 4.3875e-12],
        [1.9701e-02, 2.8095e-10, 9.4486e-11, 1.1785e-10, 2.9698e-05, 1.0725e-10,
         2.2328e-10, 2.3066e-06, 9.2329e-11, 6.0919e-11, 1.3957e-10, 2.9154e-0

tensor(0.1937, grad_fn=<BinaryCrossEntropyBackward>) tensor([[4.9187e-02, 1.4883e-12, 2.3995e-12, 8.8921e-13, 1.7853e-04, 8.2670e-13,
         8.2816e-12, 2.3014e-06, 4.1932e-13, 5.9953e-13, 5.8351e-13, 9.8618e-08,
         4.9694e-12, 5.2168e-13, 3.1191e-13, 3.1744e-12, 6.5732e-13, 1.0134e-12,
         2.9232e-13, 9.5063e-01, 1.9762e-13, 5.5095e-13, 1.7918e-11, 4.7531e-13,
         2.1988e-12, 4.1496e-13, 2.3531e-13, 7.4917e-13, 2.8094e-13],
        [3.5949e-01, 1.2234e-10, 7.2603e-11, 6.2775e-11, 1.5288e-04, 7.0369e-11,
         2.9022e-10, 2.5687e-06, 7.6370e-11, 3.2867e-11, 5.7065e-11, 5.0356e-08,
         2.3201e-10, 7.7199e-11, 6.1672e-12, 7.4166e-11, 5.0711e-11, 5.5048e-11,
         3.7513e-11, 6.4036e-01, 2.6897e-11, 2.4033e-11, 5.4661e-10, 3.8252e-11,
         1.3070e-10, 6.1566e-11, 2.9047e-12, 5.7441e-11, 1.7803e-11],
        [8.7340e-02, 4.1245e-10, 1.5723e-10, 1.3415e-10, 1.3246e-04, 1.7001e-10,
         1.0617e-09, 8.5107e-07, 8.7692e-11, 1.0780e-10, 1.6377e-10, 3.2994e-0

tensor(0.1926, grad_fn=<BinaryCrossEntropyBackward>) tensor([[1.5083e-01, 1.3873e-10, 1.5595e-11, 4.6588e-11, 3.6143e-04, 2.0216e-11,
         9.5711e-11, 6.2985e-07, 2.7818e-11, 2.4609e-11, 3.1297e-11, 1.1148e-07,
         1.0330e-10, 2.1165e-11, 2.3321e-12, 7.9100e-11, 1.7590e-11, 2.0129e-11,
         8.4939e-12, 8.4880e-01, 6.7847e-12, 1.1886e-11, 3.4249e-10, 6.1391e-12,
         3.8267e-11, 2.1496e-11, 1.5750e-12, 1.5028e-11, 1.5669e-11],
        [1.1262e-01, 2.5188e-13, 2.0225e-13, 4.5462e-13, 2.0957e-04, 2.2584e-13,
         3.5806e-13, 3.9663e-07, 1.9933e-13, 1.3211e-13, 8.9574e-14, 1.2041e-07,
         3.4050e-13, 2.6776e-13, 1.9072e-13, 2.6064e-13, 3.1204e-13, 1.7428e-13,
         1.9578e-13, 8.8717e-01, 1.5413e-13, 7.8470e-14, 7.2120e-12, 1.0734e-13,
         3.1913e-13, 1.4980e-13, 4.9515e-14, 3.0552e-13, 9.2608e-14],
        [7.9015e-02, 9.2949e-11, 1.7135e-11, 1.5463e-11, 4.2920e-04, 1.9481e-11,
         9.0379e-11, 6.6156e-07, 2.4274e-11, 3.3145e-11, 2.7324e-11, 1.6649e-0

tensor(0.1829, grad_fn=<BinaryCrossEntropyBackward>) tensor([[5.1138e-01, 9.7576e-11, 4.5048e-11, 3.3034e-11, 2.5549e-03, 3.0007e-11,
         4.3015e-10, 1.5891e-06, 3.2838e-11, 2.9064e-11, 3.7289e-11, 2.9584e-07,
         2.3323e-10, 1.8523e-11, 5.8828e-12, 3.3530e-10, 7.2901e-12, 1.1832e-11,
         9.7596e-12, 4.8606e-01, 2.2723e-12, 7.8529e-12, 2.5122e-09, 8.2594e-12,
         9.1559e-11, 9.6799e-12, 1.2127e-12, 2.2909e-11, 3.6638e-12],
        [3.2031e-01, 5.3548e-11, 9.6265e-12, 2.3352e-11, 7.1205e-04, 9.7280e-12,
         1.5918e-11, 8.9362e-07, 1.8330e-11, 1.8680e-11, 1.5284e-11, 3.9230e-07,
         8.3283e-11, 7.2411e-12, 1.1652e-12, 4.9565e-11, 2.2280e-12, 5.0356e-12,
         5.7722e-12, 6.7897e-01, 1.6313e-12, 2.2186e-12, 1.3255e-10, 6.4446e-12,
         5.4217e-11, 5.2735e-12, 5.2872e-13, 9.1996e-12, 8.0073e-12],
        [4.1480e-01, 1.5201e-10, 9.1500e-11, 4.1185e-11, 5.6023e-03, 3.2941e-11,
         3.0021e-10, 9.3146e-07, 2.7729e-11, 4.3965e-11, 3.8694e-11, 1.2771e-0

tensor(0.2118, grad_fn=<BinaryCrossEntropyBackward>) tensor([[7.2167e-01, 1.9618e-07, 7.6869e-08, 3.9481e-08, 2.5815e-02, 5.7579e-08,
         9.2127e-08, 2.2195e-06, 2.3858e-08, 4.3036e-08, 3.4392e-08, 1.2504e-05,
         2.0309e-07, 3.1492e-08, 5.6709e-09, 3.8661e-08, 1.3472e-08, 1.4783e-08,
         1.5057e-08, 2.5250e-01, 2.4383e-09, 5.9821e-09, 8.7189e-07, 7.8199e-09,
         1.1961e-07, 1.0056e-08, 2.0112e-09, 1.5294e-08, 8.5694e-09],
        [8.1798e-01, 4.7120e-10, 2.8204e-10, 2.0582e-10, 4.8965e-03, 1.2462e-10,
         9.2844e-10, 6.2531e-07, 1.1277e-10, 1.9691e-10, 1.0970e-10, 5.4993e-07,
         1.5913e-09, 6.9723e-11, 5.0407e-12, 1.6688e-09, 2.7177e-11, 6.0487e-11,
         2.3019e-11, 1.7713e-01, 2.1851e-11, 3.5623e-11, 2.4782e-09, 7.6475e-12,
         2.8748e-10, 1.2682e-10, 4.0409e-12, 1.0485e-10, 8.1624e-11],
        [6.4155e-01, 2.0112e-10, 6.5305e-10, 2.3168e-10, 2.4592e-02, 2.7073e-10,
         1.4746e-09, 2.8459e-06, 1.6264e-10, 3.1519e-10, 2.2958e-10, 3.1164e-0

tensor(0.2060, grad_fn=<BinaryCrossEntropyBackward>) tensor([[4.9793e-01, 9.0754e-10, 2.7323e-10, 7.9852e-11, 2.6481e-02, 9.9807e-11,
         4.2532e-09, 2.3809e-06, 6.7226e-11, 8.5219e-10, 4.5518e-10, 7.0932e-06,
         1.4589e-09, 4.7215e-11, 1.2471e-11, 3.1714e-09, 3.3628e-11, 4.7730e-11,
         4.5728e-11, 4.7558e-01, 2.7698e-11, 2.8785e-11, 2.7677e-08, 4.7095e-11,
         3.4164e-10, 7.8403e-11, 5.5121e-12, 1.1989e-10, 6.3105e-11],
        [8.6886e-01, 3.9761e-11, 6.5891e-12, 8.8430e-12, 2.4109e-02, 6.0795e-12,
         9.9909e-11, 4.2389e-07, 6.1049e-12, 1.5399e-11, 2.6844e-11, 1.3512e-06,
         4.3591e-11, 8.3458e-12, 1.3897e-11, 6.8198e-11, 6.5850e-12, 5.1682e-12,
         4.8530e-12, 1.0703e-01, 5.3549e-12, 2.2012e-12, 2.3516e-10, 2.6636e-12,
         1.1543e-11, 4.4282e-12, 1.5040e-12, 7.7830e-12, 4.6708e-12],
        [5.9195e-01, 3.5956e-09, 3.4843e-09, 5.9307e-10, 4.2734e-02, 1.0051e-09,
         4.2317e-09, 1.2026e-06, 6.7632e-10, 2.4527e-09, 1.7908e-09, 1.3982e-0

In [None]:
print(torch.argmax(real_sample, dim=1))
sample, _ = transformer()
print(torch.argmax(sample, dim=1))

In [None]:
print(__)