<a href="https://colab.research.google.com/github/sirmammingtonham/topics_nlp/blob/main/sru.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [301]:
class SRUCell(nn.Module):
    def __init__(self, input_size, hidden_size, activation=F.tanh):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        
        self.linear = nn.Linear(input_size, hidden_size, bias=False) #W
        self.gate = nn.Linear(input_size, 2*hidden_size, bias=True) #W' and W''
        self.activation = activation

        self.gate_norm = nn.LayerNorm(2 * hidden_size)
        self.c_norm = nn.LayerNorm(hidden_size)
        
        # self.v = nn.Parameter(torch.randn(2*hidden_size))

        self.dropout = nn.Dropout(0.1)

    def forward(self, x, c=None):
        if c is None:
            c = torch.zeros((x.shape[0], self.hidden_size), dtype=x.dtype, device=x.device)

        xt = self.linear(x)
        # gate = torch.sigmoid(self.gate_norm(self.gate(x) + torch.einsum('bs,bs->b', self.v, c)))
        gate = torch.sigmoid(self.gate_norm(self.gate(x)))
        f = gate[:, :, :self.hidden_size]
        r = gate[:, :, self.hidden_size:]
        z = (1.0 - f) * xt + f

        hidden_states = []
        for t in range(x.shape[1]):
            c = f[:, t] * c + z[:, t]
            hidden_states.append(c)

        c_stack = torch.stack(hidden_states, dim=1)
        h = r * self.activation(self.c_norm(c_stack)) + (1.0 - r) * xt
        return h, c

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            val_range = (3.0 / self.input_size) ** 0.5
            module.weight.data.uniform_(-val_range, val_range)
            if module.bias is not None:
                module.bias.data.zero_()

In [380]:
class SRU(nn.Module):
    def __init__(self, num_layers=2, num_digits=16, hidden_dim=128):
        super().__init__()
        self.embed = nn.Embedding(num_embeddings=num_digits, embedding_dim=num_digits)
        self.dropout_embed = nn.Dropout(0.5)
        # self.srus = nn.LSTM(input_size=num_digits, hidden_size=hidden_dim, batch_first=True)#nn.Sequential()
        self.srus = []
        for i in range(num_layers):
            self.srus.append(
                SRUCell(
                    input_size=num_digits if i == 0 else hidden_dim, 
                    hidden_size=hidden_dim
                ).cuda()
            )

        self.linear = nn.Linear(in_features=hidden_dim, out_features=num_digits)
        self.softmax = nn.Softmax(-1)

    def forward(self, x):
        x = self.embed(x)
        x = self.dropout_embed(x)
        c = None
        for sru_layer in self.srus:
            x, c = sru_layer(x, c)

        x, _ = x.max(dim=1)
        x = self.linear(x)
        return self.softmax(x)

In [381]:
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm.autonotebook import tqdm

In [382]:
NUM_DIGITS = 16
class CountingDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.int_to_binary(self.data[idx]), self.int_to_binary(self.data[idx] + 1)

    def int_to_binary(self, n):
        binary_list = [int(i) for i in bin(n)[2:]]
        while len(binary_list) < NUM_DIGITS:
            binary_list.insert(0, 0)
        return torch.LongTensor(binary_list)

examples = np.arange(2**NUM_DIGITS-1)
np.random.shuffle(examples)

test_size = 10000
train_ds = CountingDataset(examples[:-test_size])
test_ds = CountingDataset(examples[-test_size:])

train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=128, shuffle=True)

In [383]:
model = SRU(num_layers=2, num_digits=16, hidden_dim=128).cuda()
criterion = nn.BCELoss(size_average=True).cuda()
optimizer = torch.optim.Adam(model.parameters())

In [384]:
def evaluate(model, eval_loader):
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for i, (inputs, labels) in enumerate(tqdm(eval_loader)):
            inputs = inputs.cuda()
            labels = labels.cuda()

            predicted = model(inputs)
            total += labels.shape[0]
            predicted = (predicted > 0.5).float()
            correct += (predicted == labels).all(1).sum().item()
    accuracy = 100 * correct / total
    return accuracy

In [385]:
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    for i, (inputs, labels) in enumerate(tqdm(train_loader)):
        inputs = inputs.cuda()
        labels = labels.cuda()

        outputs = model(inputs)
        loss = criterion(outputs, labels.float())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    eval_accuracy = evaluate(model, test_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, Eval Accuracy: {eval_accuracy:.2f}')

  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [1/100], Loss: 1.3988, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [2/100], Loss: 1.3675, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [3/100], Loss: 1.3745, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [4/100], Loss: 1.3970, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [5/100], Loss: 1.3963, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [6/100], Loss: 1.4290, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [7/100], Loss: 1.3628, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [8/100], Loss: 1.3570, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [9/100], Loss: 1.3467, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [10/100], Loss: 1.3519, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [11/100], Loss: 1.3929, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [12/100], Loss: 1.3614, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [13/100], Loss: 1.3658, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [14/100], Loss: 1.4044, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [15/100], Loss: 1.3176, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [16/100], Loss: 1.4323, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [17/100], Loss: 1.3724, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [18/100], Loss: 1.3891, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [19/100], Loss: 1.3758, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [20/100], Loss: 1.3068, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [21/100], Loss: 1.3816, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [22/100], Loss: 1.3835, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [23/100], Loss: 1.3134, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [24/100], Loss: 1.3688, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [25/100], Loss: 1.2709, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [26/100], Loss: 1.3267, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [27/100], Loss: 1.3106, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [28/100], Loss: 1.3175, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [29/100], Loss: 1.2919, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [30/100], Loss: 1.3667, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [31/100], Loss: 1.3065, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [32/100], Loss: 1.3903, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [33/100], Loss: 1.3261, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [34/100], Loss: 1.3377, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [35/100], Loss: 1.4111, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [36/100], Loss: 1.3503, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [37/100], Loss: 1.3906, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [38/100], Loss: 1.3427, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [39/100], Loss: 1.3195, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [40/100], Loss: 1.3682, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [41/100], Loss: 1.3197, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [42/100], Loss: 1.3054, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [43/100], Loss: 1.3777, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [44/100], Loss: 1.3969, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [45/100], Loss: 1.3927, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [46/100], Loss: 1.3258, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [47/100], Loss: 1.3550, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [48/100], Loss: 1.3679, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [49/100], Loss: 1.3643, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [50/100], Loss: 1.2819, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [51/100], Loss: 1.3433, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [52/100], Loss: 1.3646, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [53/100], Loss: 1.3802, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [54/100], Loss: 1.3692, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [55/100], Loss: 1.3465, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [56/100], Loss: 1.3432, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [57/100], Loss: 1.2678, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [58/100], Loss: 1.3269, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [59/100], Loss: 1.3713, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [60/100], Loss: 1.3330, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [61/100], Loss: 1.3899, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [62/100], Loss: 1.3643, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [63/100], Loss: 1.3741, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [64/100], Loss: 1.3620, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [65/100], Loss: 1.3799, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [66/100], Loss: 1.3623, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [67/100], Loss: 1.4361, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [68/100], Loss: 1.3285, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [69/100], Loss: 1.3883, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [70/100], Loss: 1.3859, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [71/100], Loss: 1.3717, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [72/100], Loss: 1.3785, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [73/100], Loss: 1.3188, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [74/100], Loss: 1.4037, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [75/100], Loss: 1.3443, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [76/100], Loss: 1.3931, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [77/100], Loss: 1.3214, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [78/100], Loss: 1.3249, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [79/100], Loss: 1.3250, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [80/100], Loss: 1.3499, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [81/100], Loss: 1.3943, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [82/100], Loss: 1.3140, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [83/100], Loss: 1.3647, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [84/100], Loss: 1.3504, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [85/100], Loss: 1.3444, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [86/100], Loss: 1.3407, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [87/100], Loss: 1.3038, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [88/100], Loss: 1.3546, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [89/100], Loss: 1.3333, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [90/100], Loss: 1.3084, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [91/100], Loss: 1.4121, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [92/100], Loss: 1.3137, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [93/100], Loss: 1.3185, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [94/100], Loss: 1.3844, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [95/100], Loss: 1.3862, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [96/100], Loss: 1.3294, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [97/100], Loss: 1.3850, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [98/100], Loss: 1.3844, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [99/100], Loss: 1.3025, Eval Accuracy: 0.00


  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch [100/100], Loss: 1.3848, Eval Accuracy: 0.00


In [389]:
_, (example, label) = next(enumerate(test_loader))
example = example.cuda()
label = label.cuda()

In [400]:
preds = model(example)



In [401]:
preds = (preds > 0.1).int()

In [404]:
example[0]

tensor([0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1], device='cuda:0')

In [402]:
preds[0]

tensor([0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], device='cuda:0',
       dtype=torch.int32)