In [4]:
import torch

if torch.cuda.is_available():
    device = torch.device("cuda:0")  # you can continue going on here, like cuda:1 cuda:2....etc. 
    print("Running on the GPU")
else:
    device = torch.device("cpu")
    print("Running on the CPU")

Running on the CPU


In [3]:
import os
import cv2
import numpy as np
from tqdm import tqdm

REBUILD_DATA = False

class DogsVSCats():
    IMG_SIZE = 50
    CATS = 'PetImages/Cat'
    DOGS = 'PetImages/Dog'
    LABELS = {CATS: 0, DOGS: 1}
    training_data = []
    catcount = 0
    dogcount = 0
    
    def make_training_data(self):
        for label in self.LABELS:
            for f in tqdm(os.listdir(label)):
                try:
                    path = os.path.join(label, f)
                    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
                    img = cv2.resize(img, (self.IMG_SIZE, self.IMG_SIZE))
                    self.training_data.append([np.array(img), np.eye(2)[self.LABELS[label]]])
                    if label == self.CATS:
                        self.catcount += 1
                    elif label == self.DOGS:
                        self.dogcount += 1
                except Exception as e:
                    pass
                
        np.random.shuffle(self.training_data)
        np.save('training_data.npy', self.training_data)
        print('Cats:', self.catcount)
        print('Dogs:', self.dogcount)
        
if REBUILD_DATA:
    dogsvcats = DogsVSCats()
    dogsvcats.make_training_data()
    
training_data = np.load('training_data.npy', allow_pickle = True)
print(len(training_data))

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 5)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.conv3 = nn.Conv2d(64, 128, 5)
        
        x = torch.randn(50, 50).view(-1, 1, 50, 50)
        self._to_linear = None
        self.convs(x)
        self.fc1 = nn.Linear(self._to_linear, 512)
        self.fc2 = nn.Linear(512, 2)
        
    def convs(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv3(x)), (2, 2))
        
        print(x[0].shape)
        if self._to_linear is None:
            self._to_linear = x[0].shape[0]*x[0].shape[1]*x[0].shape[2]
#         print(x)
        return x
    
    def forward(self, x):
        x = self.convs(x)
        x = x.view(-1, self._to_linear)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.softmax(x, dim = 1)
    
net = Net()
net = net.double()

import torch.optim as optim

optimizer = optim.Adam(net.parameters(), lr = 0.001)
loss_function = nn.MSELoss()

X = torch.tensor([i[0] for i in training_data]).view(-1, 50, 50)
X = X/255.0
y = torch.tensor([i[1] for i in training_data])

VAL_PCT = 0.1
val_size = int(len(X)*VAL_PCT)
print(val_size)

train_X = X[:-val_size].double()
train_y = y[:-val_size].double()

test_X = X[-val_size:].double()
test_y = y[-val_size:].double()

print(len(train_X))
print(len(test_X))

BATCH_SIZE = 100
EPOCHS = 1

for epoch in range(EPOCHS):
    for i in tqdm(range(0, len(train_X), BATCH_SIZE)):
        batch_X = train_X[i:i+BATCH_SIZE].view(-1, 1, 50, 50)
        batch_y = train_y[i:i+BATCH_SIZE]
        net.zero_grad()
        outputs = net(batch_X)
        loss = loss_function(outputs, batch_y)
        loss.backward()
        optimizer.step()
print(loss)

correct = 0
total = 0
with torch.no_grad():
    for i in tqdm(range(len(test_X))):
        real_class = torch.argmax(test_y[i])
        net_out = net(test_X[i].view(-1, 1, 50, 50))[0]
        print(real_class, net_out)
        predicted_class = torch.argmax(net_out)
        if predicted_class == real_class:
            correct += 1
        total += 1
        
print("Accuracy:", round(correct/total, 3))

24946
torch.Size([128, 2, 2])


  0%|          | 0/225 [00:00<?, ?it/s]

2494
22452
2494
torch.Size([128, 2, 2])


  0%|          | 1/225 [00:00<02:45,  1.35it/s]

torch.Size([128, 2, 2])


  1%|          | 2/225 [00:01<02:38,  1.40it/s]

torch.Size([128, 2, 2])


  1%|▏         | 3/225 [00:02<02:34,  1.44it/s]

torch.Size([128, 2, 2])


  2%|▏         | 4/225 [00:02<02:33,  1.44it/s]

torch.Size([128, 2, 2])


  2%|▏         | 5/225 [00:03<02:29,  1.48it/s]

torch.Size([128, 2, 2])


  3%|▎         | 6/225 [00:04<02:27,  1.48it/s]

torch.Size([128, 2, 2])


  3%|▎         | 7/225 [00:04<02:27,  1.48it/s]

torch.Size([128, 2, 2])


  4%|▎         | 8/225 [00:05<02:24,  1.50it/s]

torch.Size([128, 2, 2])


  4%|▍         | 9/225 [00:06<02:22,  1.52it/s]

torch.Size([128, 2, 2])


  4%|▍         | 10/225 [00:06<02:22,  1.51it/s]

torch.Size([128, 2, 2])


  5%|▍         | 11/225 [00:07<02:20,  1.53it/s]

torch.Size([128, 2, 2])


  5%|▌         | 12/225 [00:07<02:20,  1.52it/s]

torch.Size([128, 2, 2])


  6%|▌         | 13/225 [00:08<02:20,  1.51it/s]

torch.Size([128, 2, 2])


  6%|▌         | 14/225 [00:09<02:21,  1.49it/s]

torch.Size([128, 2, 2])


  7%|▋         | 15/225 [00:10<02:23,  1.46it/s]

torch.Size([128, 2, 2])


  7%|▋         | 16/225 [00:10<02:22,  1.46it/s]

torch.Size([128, 2, 2])


  8%|▊         | 17/225 [00:11<02:22,  1.46it/s]

torch.Size([128, 2, 2])


  8%|▊         | 18/225 [00:12<02:22,  1.45it/s]

torch.Size([128, 2, 2])


  8%|▊         | 19/225 [00:12<02:21,  1.46it/s]

torch.Size([128, 2, 2])


  9%|▉         | 20/225 [00:13<02:20,  1.46it/s]

torch.Size([128, 2, 2])


  9%|▉         | 21/225 [00:14<02:21,  1.44it/s]

torch.Size([128, 2, 2])


 10%|▉         | 22/225 [00:14<02:22,  1.42it/s]

torch.Size([128, 2, 2])


 10%|█         | 23/225 [00:15<02:22,  1.42it/s]

torch.Size([128, 2, 2])


 11%|█         | 24/225 [00:16<02:21,  1.42it/s]

torch.Size([128, 2, 2])


 11%|█         | 25/225 [00:17<02:22,  1.40it/s]

torch.Size([128, 2, 2])


 12%|█▏        | 26/225 [00:17<02:22,  1.40it/s]

torch.Size([128, 2, 2])


 12%|█▏        | 27/225 [00:18<02:21,  1.40it/s]

torch.Size([128, 2, 2])


 12%|█▏        | 28/225 [00:19<02:21,  1.39it/s]

torch.Size([128, 2, 2])


 13%|█▎        | 29/225 [00:19<02:23,  1.37it/s]

torch.Size([128, 2, 2])


 13%|█▎        | 30/225 [00:20<02:22,  1.37it/s]

torch.Size([128, 2, 2])


 14%|█▍        | 31/225 [00:21<02:20,  1.38it/s]

torch.Size([128, 2, 2])


 14%|█▍        | 32/225 [00:22<02:19,  1.38it/s]

torch.Size([128, 2, 2])


 15%|█▍        | 33/225 [00:22<02:19,  1.38it/s]

torch.Size([128, 2, 2])


 15%|█▌        | 34/225 [00:23<02:16,  1.40it/s]

torch.Size([128, 2, 2])


 16%|█▌        | 35/225 [00:24<02:14,  1.41it/s]

torch.Size([128, 2, 2])


 16%|█▌        | 36/225 [00:24<02:13,  1.41it/s]

torch.Size([128, 2, 2])


 16%|█▋        | 37/225 [00:25<02:12,  1.42it/s]

torch.Size([128, 2, 2])


 17%|█▋        | 38/225 [00:26<02:10,  1.43it/s]

torch.Size([128, 2, 2])


 17%|█▋        | 39/225 [00:27<02:10,  1.42it/s]

torch.Size([128, 2, 2])


 18%|█▊        | 40/225 [00:27<02:10,  1.42it/s]

torch.Size([128, 2, 2])


 18%|█▊        | 41/225 [00:28<02:12,  1.39it/s]

torch.Size([128, 2, 2])


 19%|█▊        | 42/225 [00:29<02:12,  1.38it/s]

torch.Size([128, 2, 2])


 19%|█▉        | 43/225 [00:30<02:14,  1.35it/s]

torch.Size([128, 2, 2])


 20%|█▉        | 44/225 [00:30<02:15,  1.34it/s]

torch.Size([128, 2, 2])


 20%|██        | 45/225 [00:31<02:12,  1.36it/s]

torch.Size([128, 2, 2])


 20%|██        | 46/225 [00:32<02:11,  1.36it/s]

torch.Size([128, 2, 2])


 21%|██        | 47/225 [00:33<02:12,  1.35it/s]

torch.Size([128, 2, 2])


 21%|██▏       | 48/225 [00:33<02:13,  1.33it/s]

torch.Size([128, 2, 2])


 22%|██▏       | 49/225 [00:34<02:10,  1.35it/s]

torch.Size([128, 2, 2])


 22%|██▏       | 50/225 [00:35<02:09,  1.35it/s]

torch.Size([128, 2, 2])


 23%|██▎       | 51/225 [00:35<02:08,  1.36it/s]

torch.Size([128, 2, 2])


 23%|██▎       | 52/225 [00:36<02:05,  1.38it/s]

torch.Size([128, 2, 2])


 24%|██▎       | 53/225 [00:37<02:04,  1.38it/s]

torch.Size([128, 2, 2])


 24%|██▍       | 54/225 [00:38<02:03,  1.39it/s]

torch.Size([128, 2, 2])


 24%|██▍       | 55/225 [00:38<02:02,  1.39it/s]

torch.Size([128, 2, 2])


 25%|██▍       | 56/225 [00:39<01:59,  1.41it/s]

torch.Size([128, 2, 2])


 25%|██▌       | 57/225 [00:40<01:59,  1.40it/s]

torch.Size([128, 2, 2])


 26%|██▌       | 58/225 [00:40<01:58,  1.41it/s]

torch.Size([128, 2, 2])


 26%|██▌       | 59/225 [00:41<01:56,  1.42it/s]

torch.Size([128, 2, 2])


 27%|██▋       | 60/225 [00:42<01:55,  1.42it/s]

torch.Size([128, 2, 2])


 27%|██▋       | 61/225 [00:42<01:53,  1.44it/s]

torch.Size([128, 2, 2])


 28%|██▊       | 62/225 [00:43<01:54,  1.43it/s]

torch.Size([128, 2, 2])


 28%|██▊       | 63/225 [00:44<01:55,  1.40it/s]

torch.Size([128, 2, 2])


 28%|██▊       | 64/225 [00:45<01:54,  1.40it/s]

torch.Size([128, 2, 2])


 29%|██▉       | 65/225 [00:45<01:52,  1.42it/s]

torch.Size([128, 2, 2])


 29%|██▉       | 66/225 [00:46<01:50,  1.44it/s]

torch.Size([128, 2, 2])


 30%|██▉       | 67/225 [00:47<01:48,  1.45it/s]

torch.Size([128, 2, 2])


 30%|███       | 68/225 [00:47<01:47,  1.46it/s]

torch.Size([128, 2, 2])


 31%|███       | 69/225 [00:48<01:46,  1.46it/s]

torch.Size([128, 2, 2])


 31%|███       | 70/225 [00:49<01:45,  1.47it/s]

torch.Size([128, 2, 2])


 32%|███▏      | 71/225 [00:49<01:44,  1.47it/s]

torch.Size([128, 2, 2])


 32%|███▏      | 72/225 [00:50<01:44,  1.46it/s]

torch.Size([128, 2, 2])


 32%|███▏      | 73/225 [00:51<01:44,  1.46it/s]

torch.Size([128, 2, 2])


 33%|███▎      | 74/225 [00:52<01:45,  1.43it/s]

torch.Size([128, 2, 2])


 33%|███▎      | 75/225 [00:52<01:44,  1.43it/s]

torch.Size([128, 2, 2])


 34%|███▍      | 76/225 [00:53<01:43,  1.44it/s]

torch.Size([128, 2, 2])


 34%|███▍      | 77/225 [00:54<01:42,  1.44it/s]

torch.Size([128, 2, 2])


 35%|███▍      | 78/225 [00:54<01:41,  1.45it/s]

torch.Size([128, 2, 2])


 35%|███▌      | 79/225 [00:55<01:40,  1.45it/s]

torch.Size([128, 2, 2])


 36%|███▌      | 80/225 [00:56<01:40,  1.45it/s]

torch.Size([128, 2, 2])


 36%|███▌      | 81/225 [00:56<01:39,  1.45it/s]

torch.Size([128, 2, 2])


 36%|███▋      | 82/225 [00:57<01:38,  1.45it/s]

torch.Size([128, 2, 2])


 37%|███▋      | 83/225 [00:58<01:39,  1.43it/s]

torch.Size([128, 2, 2])


 37%|███▋      | 84/225 [00:58<01:39,  1.42it/s]

torch.Size([128, 2, 2])


 38%|███▊      | 85/225 [00:59<01:38,  1.42it/s]

torch.Size([128, 2, 2])


 38%|███▊      | 86/225 [01:00<01:38,  1.42it/s]

torch.Size([128, 2, 2])


 39%|███▊      | 87/225 [01:01<01:37,  1.41it/s]

torch.Size([128, 2, 2])


 39%|███▉      | 88/225 [01:01<01:37,  1.41it/s]

torch.Size([128, 2, 2])


 40%|███▉      | 89/225 [01:02<01:36,  1.41it/s]

torch.Size([128, 2, 2])


 40%|████      | 90/225 [01:03<01:36,  1.40it/s]

torch.Size([128, 2, 2])


 40%|████      | 91/225 [01:03<01:35,  1.40it/s]

torch.Size([128, 2, 2])


 41%|████      | 92/225 [01:04<01:33,  1.42it/s]

torch.Size([128, 2, 2])


 41%|████▏     | 93/225 [01:05<01:32,  1.42it/s]

torch.Size([128, 2, 2])


 42%|████▏     | 94/225 [01:06<01:32,  1.42it/s]

torch.Size([128, 2, 2])


 42%|████▏     | 95/225 [01:06<01:30,  1.43it/s]

torch.Size([128, 2, 2])


 43%|████▎     | 96/225 [01:07<01:30,  1.43it/s]

torch.Size([128, 2, 2])


 43%|████▎     | 97/225 [01:08<01:29,  1.42it/s]

torch.Size([128, 2, 2])


 44%|████▎     | 98/225 [01:08<01:30,  1.41it/s]

torch.Size([128, 2, 2])


 44%|████▍     | 99/225 [01:09<01:29,  1.41it/s]

torch.Size([128, 2, 2])


 44%|████▍     | 100/225 [01:10<01:29,  1.39it/s]

torch.Size([128, 2, 2])


 45%|████▍     | 101/225 [01:11<01:29,  1.38it/s]

torch.Size([128, 2, 2])


 45%|████▌     | 102/225 [01:11<01:28,  1.39it/s]

torch.Size([128, 2, 2])


 46%|████▌     | 103/225 [01:12<01:26,  1.40it/s]

torch.Size([128, 2, 2])


 46%|████▌     | 104/225 [01:13<01:25,  1.42it/s]

torch.Size([128, 2, 2])


 47%|████▋     | 105/225 [01:13<01:24,  1.42it/s]

torch.Size([128, 2, 2])


 47%|████▋     | 106/225 [01:14<01:23,  1.43it/s]

torch.Size([128, 2, 2])


 48%|████▊     | 107/225 [01:15<01:22,  1.43it/s]

torch.Size([128, 2, 2])


 48%|████▊     | 108/225 [01:15<01:22,  1.42it/s]

torch.Size([128, 2, 2])


 48%|████▊     | 109/225 [01:16<01:21,  1.42it/s]

torch.Size([128, 2, 2])


 49%|████▉     | 110/225 [01:17<01:21,  1.41it/s]

torch.Size([128, 2, 2])


 49%|████▉     | 111/225 [01:18<01:22,  1.38it/s]

torch.Size([128, 2, 2])


 50%|████▉     | 112/225 [01:18<01:21,  1.39it/s]

torch.Size([128, 2, 2])


 50%|█████     | 113/225 [01:19<01:20,  1.39it/s]

torch.Size([128, 2, 2])


 51%|█████     | 114/225 [01:20<01:19,  1.40it/s]

torch.Size([128, 2, 2])


 51%|█████     | 115/225 [01:20<01:18,  1.39it/s]

torch.Size([128, 2, 2])


 52%|█████▏    | 116/225 [01:21<01:17,  1.41it/s]

torch.Size([128, 2, 2])


 52%|█████▏    | 117/225 [01:22<01:15,  1.43it/s]

torch.Size([128, 2, 2])


 52%|█████▏    | 118/225 [01:23<01:14,  1.43it/s]

torch.Size([128, 2, 2])


 53%|█████▎    | 119/225 [01:23<01:13,  1.44it/s]

torch.Size([128, 2, 2])


 53%|█████▎    | 120/225 [01:24<01:12,  1.45it/s]

torch.Size([128, 2, 2])


 54%|█████▍    | 121/225 [01:25<01:12,  1.43it/s]

torch.Size([128, 2, 2])


 54%|█████▍    | 122/225 [01:25<01:12,  1.43it/s]

torch.Size([128, 2, 2])


 55%|█████▍    | 123/225 [01:26<01:11,  1.43it/s]

torch.Size([128, 2, 2])


 55%|█████▌    | 124/225 [01:27<01:11,  1.42it/s]

torch.Size([128, 2, 2])


 56%|█████▌    | 125/225 [01:27<01:10,  1.42it/s]

torch.Size([128, 2, 2])


 56%|█████▌    | 126/225 [01:28<01:09,  1.42it/s]

torch.Size([128, 2, 2])


 56%|█████▋    | 127/225 [01:29<01:08,  1.43it/s]

torch.Size([128, 2, 2])


 57%|█████▋    | 128/225 [01:30<01:08,  1.42it/s]

torch.Size([128, 2, 2])


 57%|█████▋    | 129/225 [01:30<01:07,  1.41it/s]

torch.Size([128, 2, 2])


 58%|█████▊    | 130/225 [01:31<01:06,  1.42it/s]

torch.Size([128, 2, 2])


 58%|█████▊    | 131/225 [01:32<01:07,  1.40it/s]

torch.Size([128, 2, 2])


 59%|█████▊    | 132/225 [01:32<01:07,  1.38it/s]

torch.Size([128, 2, 2])


 59%|█████▉    | 133/225 [01:33<01:06,  1.38it/s]

torch.Size([128, 2, 2])


 60%|█████▉    | 134/225 [01:34<01:06,  1.37it/s]

torch.Size([128, 2, 2])


 60%|██████    | 135/225 [01:35<01:06,  1.36it/s]

torch.Size([128, 2, 2])


 60%|██████    | 136/225 [01:35<01:04,  1.38it/s]

torch.Size([128, 2, 2])


 61%|██████    | 137/225 [01:36<01:03,  1.39it/s]

torch.Size([128, 2, 2])


 61%|██████▏   | 138/225 [01:37<01:01,  1.40it/s]

torch.Size([128, 2, 2])


 62%|██████▏   | 139/225 [01:37<01:01,  1.39it/s]

torch.Size([128, 2, 2])


 62%|██████▏   | 140/225 [01:38<01:00,  1.40it/s]

torch.Size([128, 2, 2])


 63%|██████▎   | 141/225 [01:39<00:59,  1.42it/s]

torch.Size([128, 2, 2])


 63%|██████▎   | 142/225 [01:40<00:58,  1.42it/s]

torch.Size([128, 2, 2])


 64%|██████▎   | 143/225 [01:40<00:58,  1.40it/s]

torch.Size([128, 2, 2])


 64%|██████▍   | 144/225 [01:41<00:57,  1.41it/s]

torch.Size([128, 2, 2])


 64%|██████▍   | 145/225 [01:42<00:57,  1.39it/s]

torch.Size([128, 2, 2])


 65%|██████▍   | 146/225 [01:42<00:56,  1.41it/s]

torch.Size([128, 2, 2])


 65%|██████▌   | 147/225 [01:43<00:54,  1.42it/s]

torch.Size([128, 2, 2])


 66%|██████▌   | 148/225 [01:44<00:53,  1.43it/s]

torch.Size([128, 2, 2])


 66%|██████▌   | 149/225 [01:45<00:52,  1.44it/s]

torch.Size([128, 2, 2])


 67%|██████▋   | 150/225 [01:45<00:51,  1.46it/s]

torch.Size([128, 2, 2])


 67%|██████▋   | 151/225 [01:46<00:51,  1.44it/s]

torch.Size([128, 2, 2])


 68%|██████▊   | 152/225 [01:47<00:50,  1.44it/s]

torch.Size([128, 2, 2])


 68%|██████▊   | 153/225 [01:47<00:49,  1.44it/s]

torch.Size([128, 2, 2])


 68%|██████▊   | 154/225 [01:48<00:49,  1.43it/s]

torch.Size([128, 2, 2])


 69%|██████▉   | 155/225 [01:49<00:49,  1.43it/s]

torch.Size([128, 2, 2])


 69%|██████▉   | 156/225 [01:49<00:49,  1.40it/s]

torch.Size([128, 2, 2])


 70%|██████▉   | 157/225 [01:50<00:48,  1.40it/s]

torch.Size([128, 2, 2])


 70%|███████   | 158/225 [01:51<00:47,  1.41it/s]

torch.Size([128, 2, 2])


 71%|███████   | 159/225 [01:52<00:46,  1.41it/s]

torch.Size([128, 2, 2])


 71%|███████   | 160/225 [01:52<00:45,  1.42it/s]

torch.Size([128, 2, 2])


 72%|███████▏  | 161/225 [01:53<00:44,  1.44it/s]

torch.Size([128, 2, 2])


 72%|███████▏  | 162/225 [01:54<00:44,  1.42it/s]

torch.Size([128, 2, 2])


 72%|███████▏  | 163/225 [01:54<00:44,  1.40it/s]

torch.Size([128, 2, 2])


 73%|███████▎  | 164/225 [01:55<00:43,  1.40it/s]

torch.Size([128, 2, 2])


 73%|███████▎  | 165/225 [01:56<00:42,  1.41it/s]

torch.Size([128, 2, 2])


 74%|███████▍  | 166/225 [01:57<00:42,  1.39it/s]

torch.Size([128, 2, 2])


 74%|███████▍  | 167/225 [01:57<00:41,  1.40it/s]

torch.Size([128, 2, 2])


 75%|███████▍  | 168/225 [01:58<00:40,  1.42it/s]

torch.Size([128, 2, 2])


 75%|███████▌  | 169/225 [01:59<00:39,  1.43it/s]

torch.Size([128, 2, 2])


 76%|███████▌  | 170/225 [01:59<00:38,  1.42it/s]

torch.Size([128, 2, 2])


 76%|███████▌  | 171/225 [02:00<00:37,  1.43it/s]

torch.Size([128, 2, 2])


 76%|███████▋  | 172/225 [02:01<00:37,  1.41it/s]

torch.Size([128, 2, 2])


 77%|███████▋  | 173/225 [02:01<00:37,  1.40it/s]

torch.Size([128, 2, 2])


 77%|███████▋  | 174/225 [02:02<00:36,  1.40it/s]

torch.Size([128, 2, 2])


 78%|███████▊  | 175/225 [02:03<00:35,  1.40it/s]

torch.Size([128, 2, 2])


 78%|███████▊  | 176/225 [02:04<00:35,  1.38it/s]

torch.Size([128, 2, 2])


 79%|███████▊  | 177/225 [02:04<00:34,  1.39it/s]

torch.Size([128, 2, 2])


 79%|███████▉  | 178/225 [02:05<00:33,  1.41it/s]

torch.Size([128, 2, 2])


 80%|███████▉  | 179/225 [02:06<00:32,  1.43it/s]

torch.Size([128, 2, 2])


 80%|████████  | 180/225 [02:06<00:31,  1.45it/s]

torch.Size([128, 2, 2])


 80%|████████  | 181/225 [02:07<00:30,  1.45it/s]

torch.Size([128, 2, 2])


 81%|████████  | 182/225 [02:08<00:29,  1.46it/s]

torch.Size([128, 2, 2])


 81%|████████▏ | 183/225 [02:08<00:29,  1.44it/s]

torch.Size([128, 2, 2])


 82%|████████▏ | 184/225 [02:09<00:28,  1.44it/s]

torch.Size([128, 2, 2])


 82%|████████▏ | 185/225 [02:10<00:27,  1.45it/s]

torch.Size([128, 2, 2])


 83%|████████▎ | 186/225 [02:11<00:27,  1.43it/s]

torch.Size([128, 2, 2])


 83%|████████▎ | 187/225 [02:11<00:26,  1.42it/s]

torch.Size([128, 2, 2])


 84%|████████▎ | 188/225 [02:12<00:25,  1.43it/s]

torch.Size([128, 2, 2])


 84%|████████▍ | 189/225 [02:13<00:25,  1.42it/s]

torch.Size([128, 2, 2])


 84%|████████▍ | 190/225 [02:13<00:24,  1.41it/s]

torch.Size([128, 2, 2])


 85%|████████▍ | 191/225 [02:14<00:23,  1.42it/s]

torch.Size([128, 2, 2])


 85%|████████▌ | 192/225 [02:15<00:23,  1.41it/s]

torch.Size([128, 2, 2])


 86%|████████▌ | 193/225 [02:16<00:22,  1.41it/s]

torch.Size([128, 2, 2])


 86%|████████▌ | 194/225 [02:16<00:21,  1.42it/s]

torch.Size([128, 2, 2])


 87%|████████▋ | 195/225 [02:17<00:20,  1.44it/s]

torch.Size([128, 2, 2])


 87%|████████▋ | 196/225 [02:18<00:20,  1.44it/s]

torch.Size([128, 2, 2])


 88%|████████▊ | 197/225 [02:18<00:19,  1.45it/s]

torch.Size([128, 2, 2])


 88%|████████▊ | 198/225 [02:19<00:18,  1.45it/s]

torch.Size([128, 2, 2])


 88%|████████▊ | 199/225 [02:20<00:18,  1.44it/s]

torch.Size([128, 2, 2])


 89%|████████▉ | 200/225 [02:20<00:17,  1.40it/s]

torch.Size([128, 2, 2])


 89%|████████▉ | 201/225 [02:21<00:17,  1.41it/s]

torch.Size([128, 2, 2])


 90%|████████▉ | 202/225 [02:22<00:16,  1.41it/s]

torch.Size([128, 2, 2])


 90%|█████████ | 203/225 [02:23<00:15,  1.42it/s]

torch.Size([128, 2, 2])


 91%|█████████ | 204/225 [02:23<00:14,  1.44it/s]

torch.Size([128, 2, 2])


 91%|█████████ | 205/225 [02:24<00:13,  1.45it/s]

torch.Size([128, 2, 2])


 92%|█████████▏| 206/225 [02:25<00:13,  1.44it/s]

torch.Size([128, 2, 2])


 92%|█████████▏| 207/225 [02:25<00:12,  1.44it/s]

torch.Size([128, 2, 2])


 92%|█████████▏| 208/225 [02:26<00:11,  1.44it/s]

torch.Size([128, 2, 2])


 93%|█████████▎| 209/225 [02:27<00:11,  1.41it/s]

torch.Size([128, 2, 2])


 93%|█████████▎| 210/225 [02:27<00:10,  1.39it/s]

torch.Size([128, 2, 2])


 94%|█████████▍| 211/225 [02:28<00:09,  1.41it/s]

torch.Size([128, 2, 2])


 94%|█████████▍| 212/225 [02:29<00:09,  1.42it/s]

torch.Size([128, 2, 2])


 95%|█████████▍| 213/225 [02:30<00:08,  1.42it/s]

torch.Size([128, 2, 2])


 95%|█████████▌| 214/225 [02:30<00:07,  1.44it/s]

torch.Size([128, 2, 2])


 96%|█████████▌| 215/225 [02:31<00:06,  1.44it/s]

torch.Size([128, 2, 2])


 96%|█████████▌| 216/225 [02:32<00:06,  1.43it/s]

torch.Size([128, 2, 2])


 96%|█████████▋| 217/225 [02:32<00:05,  1.42it/s]

torch.Size([128, 2, 2])


 97%|█████████▋| 218/225 [02:33<00:04,  1.43it/s]

torch.Size([128, 2, 2])


 97%|█████████▋| 219/225 [02:34<00:04,  1.42it/s]

torch.Size([128, 2, 2])


 98%|█████████▊| 220/225 [02:34<00:03,  1.41it/s]

torch.Size([128, 2, 2])


 98%|█████████▊| 221/225 [02:35<00:02,  1.41it/s]

torch.Size([128, 2, 2])


 99%|█████████▊| 222/225 [02:36<00:02,  1.41it/s]

torch.Size([128, 2, 2])


 99%|█████████▉| 223/225 [02:37<00:01,  1.41it/s]

torch.Size([128, 2, 2])


100%|█████████▉| 224/225 [02:37<00:00,  1.43it/s]

torch.Size([128, 2, 2])


100%|██████████| 225/225 [02:38<00:00,  1.42it/s]
  1%|          | 19/2494 [00:00<00:13, 188.42it/s]

tensor(0.2226, dtype=torch.float64, grad_fn=<MseLossBackward>)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5921, 0.4079], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3941, 0.6059], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6395, 0.3605], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.3096, 0.6904], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.2987, 0.7013], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4233, 0.5767], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4688, 0.5312], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3889, 0.6111], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5510, 0.4490], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4180, 0.5820], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5853, 0.4147], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5929, 0.4071]

  3%|▎         | 78/2494 [00:00<00:10, 229.25it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.6511, 0.3489], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2573, 0.7427], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5024, 0.4976], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.8062, 0.1938], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5146, 0.4854], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7312, 0.2688], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3478, 0.6522], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7211, 0.2789], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2970, 0.7030], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4762, 0.5238], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4619, 0.5381], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4015, 0.5985], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor

  5%|▌         | 135/2494 [00:00<00:09, 248.87it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.3428, 0.6572], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4630, 0.5370], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4192, 0.5808], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4192, 0.5808], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3867, 0.6133], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5554, 0.4446], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6012, 0.3988], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5890, 0.4110], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7371, 0.2629], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4931, 0.5069], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6853, 0.3147], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5734, 0.4266], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

  8%|▊         | 192/2494 [00:00<00:08, 264.58it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.4851, 0.5149], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4205, 0.5795], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6724, 0.3276], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6126, 0.3874], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5798, 0.4202], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6523, 0.3477], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5365, 0.4635], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.3983, 0.6017], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4293, 0.5707], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5736, 0.4264], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5867, 0.4133], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7385, 0.2615], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 10%|█         | 254/2494 [00:00<00:07, 284.12it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.5518, 0.4482], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7273, 0.2727], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2140, 0.7860], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4122, 0.5878], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5477, 0.4523], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4655, 0.5345], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8153, 0.1847], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6589, 0.3411], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4003, 0.5997], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3621, 0.6379], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4128, 0.5872], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3729, 0.6271], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 13%|█▎        | 316/2494 [00:01<00:07, 294.92it/s]

tensor([0.6225, 0.3775], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4675, 0.5325], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7713, 0.2287], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6881, 0.3119], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5666, 0.4334], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7241, 0.2759], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4164, 0.5836], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5823, 0.4177], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6204, 0.3796], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7727, 0.2273], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6174, 0.3826], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6989, 0.3011], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4457, 0.5543], dtype=torch.flo

 15%|█▌        | 377/2494 [00:01<00:07, 297.81it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.5220, 0.4780], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5862, 0.4138], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4896, 0.5104], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6930, 0.3070], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5259, 0.4741], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5687, 0.4313], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7240, 0.2760], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5296, 0.4704], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5665, 0.4335], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4487, 0.5513], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7810, 0.2190], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5970, 0.4030], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor

 18%|█▊        | 437/2494 [00:01<00:07, 289.52it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.5310, 0.4690], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.3612, 0.6388], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6301, 0.3699], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4713, 0.5287], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.1602, 0.8398], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4077, 0.5923], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5630, 0.4370], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6704, 0.3296], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6286, 0.3714], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6378, 0.3622], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7571, 0.2429], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5747, 0.4253], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor

 20%|█▉        | 494/2494 [00:01<00:07, 276.60it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.7068, 0.2932], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7180, 0.2820], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5360, 0.4640], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3725, 0.6275], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4485, 0.5515], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6086, 0.3914], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3964, 0.6036], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.3941, 0.6059], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4981, 0.5019], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6453, 0.3547], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4552, 0.5448], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6079, 0.3921], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor

 22%|██▏       | 554/2494 [00:01<00:06, 286.38it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.6495, 0.3505], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5558, 0.4442], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7110, 0.2890], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3746, 0.6254], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8421, 0.1579], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3828, 0.6172], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6552, 0.3448], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5196, 0.4804], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3293, 0.6707], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2865, 0.7135], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5144, 0.4856], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2516, 0.7484], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor

 25%|██▍       | 614/2494 [00:02<00:06, 291.16it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.5713, 0.4287], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3431, 0.6569], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3686, 0.6314], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5385, 0.4615], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4909, 0.5091], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3446, 0.6554], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3422, 0.6578], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4391, 0.5609], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6903, 0.3097], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5496, 0.4504], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7724, 0.2276], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5251, 0.4749], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor

 27%|██▋       | 674/2494 [00:02<00:06, 282.67it/s]

tensor([0.2951, 0.7049], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7445, 0.2555], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3417, 0.6583], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4006, 0.5994], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7302, 0.2698], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5009, 0.4991], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5383, 0.4617], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5779, 0.4221], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5435, 0.4565], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.3478, 0.6522], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7725, 0.2275], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7605, 0.2395], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6980, 0.3020], dtype=torch.flo

 29%|██▉       | 731/2494 [00:02<00:06, 266.25it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.7837, 0.2163], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6703, 0.3297], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.3874, 0.6126], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5392, 0.4608], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5236, 0.4764], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7152, 0.2848], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5323, 0.4677], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8144, 0.1856], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5651, 0.4349], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5961, 0.4039], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3722, 0.6278], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4173, 0.5827], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 32%|███▏      | 791/2494 [00:02<00:06, 280.18it/s]

tensor(0) tensor([0.5851, 0.4149], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6510, 0.3490], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4067, 0.5933], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5058, 0.4942], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.3697, 0.6303], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2718, 0.7282], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7185, 0.2815], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7845, 0.2155], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.3779, 0.6221], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4186, 0.5814], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7696, 0.2304], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4931, 0.5069], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4827, 0.5173], dtype

 34%|███▍      | 852/2494 [00:03<00:05, 288.73it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.5800, 0.4200], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3656, 0.6344], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5997, 0.4003], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4468, 0.5532], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5846, 0.4154], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7252, 0.2748], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6431, 0.3569], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3303, 0.6697], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7015, 0.2985], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2580, 0.7420], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5621, 0.4379], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5789, 0.4211], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 37%|███▋      | 912/2494 [00:03<00:05, 291.67it/s]

tensor(0) tensor([0.5907, 0.4093], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4219, 0.5781], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4288, 0.5712], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4194, 0.5806], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7116, 0.2884], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3639, 0.6361], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3605, 0.6395], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5486, 0.4514], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5980, 0.4020], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4881, 0.5119], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6485, 0.3515], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3366, 0.6634], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5539, 0.4461], dtype

 39%|███▉      | 972/2494 [00:03<00:05, 294.23it/s]

tensor(0) tensor([0.6726, 0.3274], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.3940, 0.6060], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7045, 0.2955], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5996, 0.4004], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5574, 0.4426], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5563, 0.4437], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7036, 0.2964], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6801, 0.3199], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4399, 0.5601], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4540, 0.5460], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.2449, 0.7551], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3897, 0.6103], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6887, 0.3113], dtype

 40%|████      | 1002/2494 [00:03<00:05, 290.81it/s]

tensor([0.3516, 0.6484], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6664, 0.3336], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4073, 0.5927], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6606, 0.3394], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6348, 0.3652], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4636, 0.5364], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4279, 0.5721], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6849, 0.3151], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6367, 0.3633], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5235, 0.4765], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4726, 0.5274], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4571, 0.5429], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3517, 0.6483], dtype=torch.flo

 43%|████▎     | 1060/2494 [00:03<00:05, 271.41it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.6459, 0.3541], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4288, 0.5712], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6705, 0.3295], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5262, 0.4738], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7505, 0.2495], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5890, 0.4110], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5959, 0.4041], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3561, 0.6439], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7724, 0.2276], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7312, 0.2688], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6784, 0.3216], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.1817, 0.8183], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 45%|████▍     | 1119/2494 [00:03<00:04, 281.83it/s]

tensor(0) tensor([0.5088, 0.4912], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5331, 0.4669], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5866, 0.4134], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5791, 0.4209], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6463, 0.3537], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5554, 0.4446], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4726, 0.5274], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5056, 0.4944], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5325, 0.4675], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4713, 0.5287], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5665, 0.4335], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6711, 0.3289], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6609, 0.3391], dtype

 47%|████▋     | 1179/2494 [00:04<00:04, 289.35it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.5169, 0.4831], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7704, 0.2296], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8142, 0.1858], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5106, 0.4894], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.8167, 0.1833], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4602, 0.5398], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5484, 0.4516], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3334, 0.6666], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6241, 0.3759], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5217, 0.4783], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4380, 0.5620], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6015, 0.3985], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor

 50%|████▉     | 1239/2494 [00:04<00:04, 287.44it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.4758, 0.5242], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3997, 0.6003], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3709, 0.6291], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7226, 0.2774], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3593, 0.6407], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6067, 0.3933], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6826, 0.3174], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5996, 0.4004], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7544, 0.2456], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5161, 0.4839], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5115, 0.4885], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5117, 0.4883], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor

 52%|█████▏    | 1296/2494 [00:04<00:04, 271.12it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.6437, 0.3563], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7012, 0.2988], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4895, 0.5105], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4771, 0.5229], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7360, 0.2640], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4167, 0.5833], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3754, 0.6246], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4742, 0.5258], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6609, 0.3391], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6146, 0.3854], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5781, 0.4219], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5106, 0.4894], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 54%|█████▍    | 1356/2494 [00:04<00:04, 283.77it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.5430, 0.4570], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3533, 0.6467], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4546, 0.5454], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5961, 0.4039], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7057, 0.2943], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4331, 0.5669], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5218, 0.4782], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7908, 0.2092], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7526, 0.2474], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3127, 0.6873], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3022, 0.6978], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4504, 0.5496], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 57%|█████▋    | 1415/2494 [00:04<00:03, 288.85it/s]

tensor([0.5879, 0.4121], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7549, 0.2451], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3630, 0.6370], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5554, 0.4446], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3909, 0.6091], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.1966, 0.8034], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7028, 0.2972], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4911, 0.5089], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7034, 0.2966], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7395, 0.2605], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7929, 0.2071], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6460, 0.3540], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5116, 0.4884], dtype=torch.flo

 59%|█████▉    | 1476/2494 [00:05<00:03, 293.44it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.6410, 0.3590], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4930, 0.5070], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7035, 0.2965], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4588, 0.5412], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5810, 0.4190], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5992, 0.4008], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8357, 0.1643], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5626, 0.4374], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5912, 0.4088], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5602, 0.4398], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4186, 0.5814], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7122, 0.2878], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 62%|██████▏   | 1536/2494 [00:05<00:03, 294.10it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.3485, 0.6515], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5725, 0.4275], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3943, 0.6057], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5213, 0.4787], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5899, 0.4101], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5780, 0.4220], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4552, 0.5448], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5752, 0.4248], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5242, 0.4758], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5658, 0.4342], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5541, 0.4459], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5382, 0.4618], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor

 64%|██████▍   | 1596/2494 [00:05<00:03, 271.65it/s]

tensor([0.5223, 0.4777], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6737, 0.3263], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5676, 0.4324], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6419, 0.3581], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4094, 0.5906], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6005, 0.3995], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4630, 0.5370], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4137, 0.5863], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8821, 0.1179], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3546, 0.6454], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6374, 0.3626], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6178, 0.3822], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6461, 0.3539], dtype=torch.flo

 66%|██████▌   | 1652/2494 [00:05<00:03, 269.29it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.6960, 0.3040], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8301, 0.1699], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6870, 0.3130], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7123, 0.2877], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3553, 0.6447], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5129, 0.4871], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5535, 0.4465], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5206, 0.4794], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5630, 0.4370], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5786, 0.4214], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6395, 0.3605], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6317, 0.3683], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor

 69%|██████▊   | 1713/2494 [00:06<00:02, 283.30it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.6275, 0.3725], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8308, 0.1692], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6054, 0.3946], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3838, 0.6162], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7120, 0.2880], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6596, 0.3404], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5581, 0.4419], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5814, 0.4186], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6534, 0.3466], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6577, 0.3423], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3406, 0.6594], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3374, 0.6626], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 71%|███████   | 1773/2494 [00:06<00:02, 290.12it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.4517, 0.5483], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7590, 0.2410], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6404, 0.3596], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5668, 0.4332], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5169, 0.4831], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7092, 0.2908], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6117, 0.3883], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4738, 0.5262], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5681, 0.4319], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4880, 0.5120], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3531, 0.6469], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4570, 0.5430], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 72%|███████▏  | 1803/2494 [00:06<00:02, 285.34it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.6661, 0.3339], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5342, 0.4658], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7400, 0.2600], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4343, 0.5657], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7399, 0.2601], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7464, 0.2536], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5780, 0.4220], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5225, 0.4775], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5786, 0.4214], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4775, 0.5225], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8534, 0.1466], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4510, 0.5490], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 75%|███████▍  | 1860/2494 [00:06<00:02, 269.69it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.5720, 0.4280], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6235, 0.3765], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8132, 0.1868], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5802, 0.4198], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4952, 0.5048], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4723, 0.5277], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7244, 0.2756], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7972, 0.2028], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7210, 0.2790], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5817, 0.4183], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3556, 0.6444], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3952, 0.6048], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 77%|███████▋  | 1920/2494 [00:06<00:02, 282.26it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.5152, 0.4848], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5868, 0.4132], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5822, 0.4178], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7165, 0.2835], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5105, 0.4895], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5023, 0.4977], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7177, 0.2823], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6819, 0.3181], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4259, 0.5741], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8024, 0.1976], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6409, 0.3591], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4153, 0.5847], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 79%|███████▉  | 1979/2494 [00:06<00:01, 288.45it/s]

tensor(0) tensor([0.4959, 0.5041], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5437, 0.4563], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3424, 0.6576], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6366, 0.3634], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5238, 0.4762], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5064, 0.4936], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6658, 0.3342], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4101, 0.5899], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4597, 0.5403], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6226, 0.3774], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3969, 0.6031], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3166, 0.6834], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4443, 0.5557], dtype

 82%|████████▏ | 2038/2494 [00:07<00:01, 291.25it/s]

tensor([0.4718, 0.5282], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7123, 0.2877], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.2980, 0.7020], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7302, 0.2698], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5001, 0.4999], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3725, 0.6275], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5600, 0.4400], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6295, 0.3705], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.3791, 0.6209], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5379, 0.4621], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4569, 0.5431], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6609, 0.3391], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5179, 0.4821], dtype=torch.flo

 84%|████████▍ | 2098/2494 [00:07<00:01, 289.61it/s]

tensor(1) tensor([0.7537, 0.2463], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6062, 0.3938], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4591, 0.5409], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.3538, 0.6462], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7636, 0.2364], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5246, 0.4754], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4536, 0.5464], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4223, 0.5777], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6604, 0.3396], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7859, 0.2141], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7606, 0.2394], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3510, 0.6490], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3954, 0.6046], dtype

 86%|████████▋ | 2156/2494 [00:07<00:01, 278.49it/s]

tensor([0.4627, 0.5373], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5948, 0.4052], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6395, 0.3605], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7245, 0.2755], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6756, 0.3244], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5344, 0.4656], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6309, 0.3691], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7099, 0.2901], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6489, 0.3511], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4607, 0.5393], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5009, 0.4991], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5355, 0.4645], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7421, 0.2579], dtype=torch.flo

 89%|████████▊ | 2211/2494 [00:07<00:01, 264.63it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.4375, 0.5625], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4126, 0.5874], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6422, 0.3578], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3786, 0.6214], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.2824, 0.7176], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5622, 0.4378], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7594, 0.2406], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7344, 0.2656], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5626, 0.4374], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4532, 0.5468], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5801, 0.4199], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6006, 0.3994], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 91%|█████████ | 2268/2494 [00:08<00:00, 274.39it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.4397, 0.5603], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2796, 0.7204], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5178, 0.4822], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6873, 0.3127], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6093, 0.3907], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6487, 0.3513], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5072, 0.4928], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5045, 0.4955], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3848, 0.6152], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6058, 0.3942], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5806, 0.4194], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.8075, 0.1925], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 93%|█████████▎| 2325/2494 [00:08<00:00, 277.83it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.4099, 0.5901], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7037, 0.2963], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6049, 0.3951], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7875, 0.2125], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6306, 0.3694], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4284, 0.5716], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3732, 0.6268], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5884, 0.4116], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4597, 0.5403], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6131, 0.3869], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6987, 0.3013], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6839, 0.3161], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 96%|█████████▌| 2382/2494 [00:08<00:00, 279.51it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.3811, 0.6189], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7377, 0.2623], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6309, 0.3691], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5841, 0.4159], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4951, 0.5049], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5113, 0.4887], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8337, 0.1663], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6513, 0.3487], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7160, 0.2840], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.3351, 0.6649], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4947, 0.5053], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3704, 0.6296], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 98%|█████████▊| 2439/2494 [00:08<00:00, 280.87it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.6723, 0.3277], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2257, 0.7743], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5673, 0.4327], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4912, 0.5088], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7481, 0.2519], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5983, 0.4017], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3136, 0.6864], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4789, 0.5211], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3131, 0.6869], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4601, 0.5399], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4451, 0.5549], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3680, 0.6320], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

100%|██████████| 2494/2494 [00:08<00:00, 282.17it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.3562, 0.6438], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3905, 0.6095], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2413, 0.7587], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7107, 0.2893], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.8210, 0.1790], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2903, 0.7097], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6252, 0.3748], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6805, 0.3195], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7136, 0.2864], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5140, 0.4860], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4741, 0.5259], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6498, 0.3502], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor




In [5]:
def train(net):
    optimizer = optim.Adam(net.parameters(), lr=0.001)
    BATCH_SIZE = 100
    EPOCHS = 3
    for epoch in range(EPOCHS):
        for i in range(0, len(train_X), BATCH_SIZE): # from 0, to the len of x, stepping BATCH_SIZE at a time. [:50] ..for now just to dev
            #print(f"{i}:{i+BATCH_SIZE}")
            batch_X = train_X[i:i+BATCH_SIZE].view(-1, 1, 50, 50)
            batch_y = train_y[i:i+BATCH_SIZE]

            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            net.zero_grad()

            optimizer.zero_grad()   # zero the gradient buffers
            outputs = net(batch_X)
            loss = loss_function(outputs, batch_y)
            loss.backward()
            optimizer.step()    # Does the update

        print(f"Epoch: {epoch}. Loss: {loss}")

train(net)

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

In [6]:
test_X.to(device)
test_y.to(device)

def test(net):
    correct = 0
    total = 0
    with torch.no_grad():
        for i in tqdm(range(len(test_X))):
            real_class = torch.argmax(test_y[i]).to(device)
            net_out = net(test_X[i].view(-1, 1, 50, 50).to(device))[0]  # returns a list, 
            predicted_class = torch.argmax(net_out)

            if predicted_class == real_class:
                correct += 1
            total += 1

    print("Accuracy: ", round(correct/total, 3))

test(net)

  2%|▏         | 61/2494 [00:00<00:08, 271.52it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

  6%|▌         | 142/2494 [00:00<00:07, 325.24it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

  8%|▊         | 211/2494 [00:00<00:06, 333.48it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 12%|█▏        | 290/2494 [00:00<00:06, 360.55it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 15%|█▌        | 376/2494 [00:01<00:05, 392.03it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 19%|█▊        | 463/2494 [00:01<00:04, 411.15it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 22%|██▏       | 547/2494 [00:01<00:04, 412.95it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 25%|██▌       | 631/2494 [00:01<00:04, 414.66it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 29%|██▊       | 715/2494 [00:01<00:04, 412.74it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 30%|███       | 757/2494 [00:01<00:04, 412.21it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 34%|███▎      | 838/2494 [00:02<00:04, 371.00it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 37%|███▋      | 922/2494 [00:02<00:03, 393.62it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 40%|████      | 1007/2494 [00:02<00:03, 407.89it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 44%|████▎     | 1090/2494 [00:02<00:03, 405.70it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 47%|████▋     | 1170/2494 [00:03<00:03, 371.88it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 50%|█████     | 1248/2494 [00:03<00:03, 371.08it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 53%|█████▎    | 1334/2494 [00:03<00:02, 397.13it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 57%|█████▋    | 1419/2494 [00:03<00:02, 408.55it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 60%|██████    | 1502/2494 [00:03<00:02, 406.78it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 64%|██████▎   | 1584/2494 [00:04<00:02, 398.65it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 65%|██████▌   | 1624/2494 [00:04<00:02, 377.56it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 68%|██████▊   | 1707/2494 [00:04<00:02, 392.16it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 72%|███████▏  | 1793/2494 [00:04<00:01, 407.78it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 75%|███████▌  | 1875/2494 [00:04<00:01, 405.50it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 79%|███████▊  | 1959/2494 [00:04<00:01, 410.79it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 82%|████████▏ | 2044/2494 [00:05<00:01, 410.08it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 85%|████████▌ | 2130/2494 [00:05<00:00, 412.67it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 89%|████████▉ | 2215/2494 [00:05<00:00, 417.16it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 92%|█████████▏| 2300/2494 [00:05<00:00, 419.84it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 96%|█████████▌| 2385/2494 [00:05<00:00, 417.74it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

 99%|█████████▉| 2467/2494 [00:06<00:00, 379.05it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128,

100%|██████████| 2494/2494 [00:06<00:00, 395.82it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
Accuracy:  0.757





In [7]:
correct = 0
total = 0
for i in tqdm(range(0, len(test_X), BATCH_SIZE)):

    batch_X = test_X[i:i+BATCH_SIZE].view(-1, 1, 50, 50).to(device)
    batch_y = test_y[i:i+BATCH_SIZE].to(device)
    batch_out = net(batch_X)

    out_maxes = [torch.argmax(i) for i in batch_out]
    target_maxes = [torch.argmax(i) for i in batch_y]
    for i,j in zip(out_maxes, target_maxes):
        if i == j:
            correct += 1
        total += 1
print("Accuracy: ", round(correct/total, 3))

  4%|▍         | 1/25 [00:00<00:08,  2.88it/s]

torch.Size([128, 2, 2])


  8%|▊         | 2/25 [00:00<00:08,  2.81it/s]

torch.Size([128, 2, 2])


 12%|█▏        | 3/25 [00:01<00:07,  2.89it/s]

torch.Size([128, 2, 2])


 16%|█▌        | 4/25 [00:01<00:06,  3.02it/s]

torch.Size([128, 2, 2])


 20%|██        | 5/25 [00:01<00:06,  3.12it/s]

torch.Size([128, 2, 2])


 24%|██▍       | 6/25 [00:01<00:05,  3.19it/s]

torch.Size([128, 2, 2])


 28%|██▊       | 7/25 [00:02<00:05,  3.21it/s]

torch.Size([128, 2, 2])


 32%|███▏      | 8/25 [00:02<00:05,  3.22it/s]

torch.Size([128, 2, 2])


 36%|███▌      | 9/25 [00:02<00:04,  3.26it/s]

torch.Size([128, 2, 2])


 40%|████      | 10/25 [00:03<00:04,  3.25it/s]

torch.Size([128, 2, 2])


 44%|████▍     | 11/25 [00:03<00:04,  3.19it/s]

torch.Size([128, 2, 2])


 48%|████▊     | 12/25 [00:03<00:03,  3.26it/s]

torch.Size([128, 2, 2])


 52%|█████▏    | 13/25 [00:04<00:03,  3.28it/s]

torch.Size([128, 2, 2])


 56%|█████▌    | 14/25 [00:04<00:03,  3.32it/s]

torch.Size([128, 2, 2])


 60%|██████    | 15/25 [00:04<00:03,  3.30it/s]

torch.Size([128, 2, 2])


 64%|██████▍   | 16/25 [00:04<00:02,  3.30it/s]

torch.Size([128, 2, 2])


 68%|██████▊   | 17/25 [00:05<00:02,  3.20it/s]

torch.Size([128, 2, 2])


 72%|███████▏  | 18/25 [00:05<00:02,  3.26it/s]

torch.Size([128, 2, 2])


 76%|███████▌  | 19/25 [00:05<00:01,  3.31it/s]

torch.Size([128, 2, 2])


 80%|████████  | 20/25 [00:06<00:01,  3.31it/s]

torch.Size([128, 2, 2])


 84%|████████▍ | 21/25 [00:06<00:01,  3.33it/s]

torch.Size([128, 2, 2])


 88%|████████▊ | 22/25 [00:06<00:00,  3.35it/s]

torch.Size([128, 2, 2])


 92%|█████████▏| 23/25 [00:07<00:00,  3.24it/s]

torch.Size([128, 2, 2])


 96%|█████████▌| 24/25 [00:07<00:00,  3.28it/s]

torch.Size([128, 2, 2])


100%|██████████| 25/25 [00:07<00:00,  3.25it/s]

torch.Size([128, 2, 2])
Accuracy:  0.757



