In [1]:
import os
import cv2
import numpy as np
from tqdm import tqdm

REBUILD_DATA = False

class DogsVSCats():
    IMG_SIZE = 50
    CATS = 'PetImages/Cat'
    DOGS = 'PetImages/Dog'
    LABELS = {CATS: 0, DOGS: 1}
    training_data = []
    catcount = 0
    dogcount = 0
    
    def make_training_data(self):
        for label in self.LABELS:
            for f in tqdm(os.listdir(label)):
                try:
                    path = os.path.join(label, f)
                    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
                    img = cv2.resize(img, (self.IMG_SIZE, self.IMG_SIZE))
                    self.training_data.append([np.array(img), np.eye(2)[self.LABELS[label]]])
                    if label == self.CATS:
                        self.catcount += 1
                    elif label == self.DOGS:
                        self.dogcount += 1
                except Exception as e:
                    pass
                
        np.random.shuffle(self.training_data)
        np.save('training_data.npy', self.training_data)
        print('Cats:', self.catcount)
        print('Dogs:', self.dogcount)
        
if REBUILD_DATA:
    dogsvcats = DogsVSCats()
    dogsvcats.make_training_data()
    
training_data = np.load('training_data.npy', allow_pickle = True)
print(len(training_data))

24946


In [31]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 5)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.conv3 = nn.Conv2d(64, 128, 5)
        
        x = torch.randn(50, 50).view(-1, 1, 50, 50)
        self._to_linear = None
        self.convs(x)
        self.fc1 = nn.Linear(self._to_linear, 512)
        self.fc2 = nn.Linear(512, 2)
        
    def convs(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv3(x)), (2, 2))
        
        print(x[0].shape)
        if self._to_linear is None:
            self._to_linear = x[0].shape[0]*x[0].shape[1]*x[0].shape[2]
#         print(x)
        return x
    
    def forward(self, x):
        x = self.convs(x)
        x = x.view(-1, self._to_linear)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.softmax(x, dim = 1)
    
net = Net()
net = net.double()

torch.Size([128, 2, 2])


In [32]:
import torch.optim as optim

optimizer = optim.Adam(net.parameters(), lr = 0.001)
loss_function = nn.MSELoss()

X = torch.tensor([i[0] for i in training_data]).view(-1, 50, 50)
X = X/255.0
y = torch.tensor([i[1] for i in training_data])

VAL_PCT = 0.1
val_size = int(len(X)*VAL_PCT)
print(val_size)

2494


In [36]:
train_X = X[:-val_size].double()
train_y = y[:-val_size].double()

test_X = X[-val_size:].double()
test_y = y[-val_size:].double()

print(len(train_X))
print(len(test_X))

22452
2494


In [34]:
BATCH_SIZE = 100
EPOCHS = 1

for epoch in range(EPOCHS):
    for i in tqdm(range(0, len(train_X), BATCH_SIZE)):
        batch_X = train_X[i:i+BATCH_SIZE].view(-1, 1, 50, 50)
        batch_y = train_y[i:i+BATCH_SIZE]
        net.zero_grad()
        outputs = net(batch_X)
        loss = loss_function(outputs, batch_y)
        loss.backward()
        optimizer.step()
print(loss)

  0%|          | 0/225 [00:00<?, ?it/s]

torch.Size([128, 2, 2])


  0%|          | 1/225 [00:00<02:52,  1.30it/s]

torch.Size([128, 2, 2])


  1%|          | 2/225 [00:01<02:44,  1.35it/s]

torch.Size([128, 2, 2])


  1%|▏         | 3/225 [00:02<02:37,  1.41it/s]

torch.Size([128, 2, 2])


  2%|▏         | 4/225 [00:02<02:32,  1.45it/s]

torch.Size([128, 2, 2])


  2%|▏         | 5/225 [00:03<02:27,  1.49it/s]

torch.Size([128, 2, 2])


  3%|▎         | 6/225 [00:03<02:24,  1.51it/s]

torch.Size([128, 2, 2])


  3%|▎         | 7/225 [00:04<02:23,  1.52it/s]

torch.Size([128, 2, 2])


  4%|▎         | 8/225 [00:05<02:22,  1.53it/s]

torch.Size([128, 2, 2])


  4%|▍         | 9/225 [00:05<02:20,  1.53it/s]

torch.Size([128, 2, 2])


  4%|▍         | 10/225 [00:06<02:20,  1.53it/s]

torch.Size([128, 2, 2])


  5%|▍         | 11/225 [00:07<02:20,  1.52it/s]

torch.Size([128, 2, 2])


  5%|▌         | 12/225 [00:07<02:18,  1.53it/s]

torch.Size([128, 2, 2])


  6%|▌         | 13/225 [00:08<02:18,  1.53it/s]

torch.Size([128, 2, 2])


  6%|▌         | 14/225 [00:09<02:16,  1.54it/s]

torch.Size([128, 2, 2])


  7%|▋         | 15/225 [00:09<02:16,  1.54it/s]

torch.Size([128, 2, 2])


  7%|▋         | 16/225 [00:10<02:18,  1.51it/s]

torch.Size([128, 2, 2])


  8%|▊         | 17/225 [00:11<02:18,  1.50it/s]

torch.Size([128, 2, 2])


  8%|▊         | 18/225 [00:11<02:18,  1.50it/s]

torch.Size([128, 2, 2])


  8%|▊         | 19/225 [00:12<02:18,  1.49it/s]

torch.Size([128, 2, 2])


  9%|▉         | 20/225 [00:13<02:18,  1.48it/s]

torch.Size([128, 2, 2])


  9%|▉         | 21/225 [00:13<02:18,  1.47it/s]

torch.Size([128, 2, 2])


 10%|▉         | 22/225 [00:14<02:18,  1.47it/s]

torch.Size([128, 2, 2])


 10%|█         | 23/225 [00:15<02:18,  1.46it/s]

torch.Size([128, 2, 2])


 11%|█         | 24/225 [00:15<02:17,  1.46it/s]

torch.Size([128, 2, 2])


 11%|█         | 25/225 [00:16<02:17,  1.45it/s]

torch.Size([128, 2, 2])


 12%|█▏        | 26/225 [00:17<02:16,  1.46it/s]

torch.Size([128, 2, 2])


 12%|█▏        | 27/225 [00:18<02:15,  1.46it/s]

torch.Size([128, 2, 2])


 12%|█▏        | 28/225 [00:18<02:13,  1.48it/s]

torch.Size([128, 2, 2])


 13%|█▎        | 29/225 [00:19<02:11,  1.49it/s]

torch.Size([128, 2, 2])


 13%|█▎        | 30/225 [00:20<02:09,  1.51it/s]

torch.Size([128, 2, 2])


 14%|█▍        | 31/225 [00:20<02:08,  1.51it/s]

torch.Size([128, 2, 2])


 14%|█▍        | 32/225 [00:21<02:07,  1.52it/s]

torch.Size([128, 2, 2])


 15%|█▍        | 33/225 [00:21<02:06,  1.52it/s]

torch.Size([128, 2, 2])


 15%|█▌        | 34/225 [00:22<02:04,  1.53it/s]

torch.Size([128, 2, 2])


 16%|█▌        | 35/225 [00:23<02:07,  1.49it/s]

torch.Size([128, 2, 2])


 16%|█▌        | 36/225 [00:24<02:11,  1.44it/s]

torch.Size([128, 2, 2])


 16%|█▋        | 37/225 [00:24<02:10,  1.44it/s]

torch.Size([128, 2, 2])


 17%|█▋        | 38/225 [00:25<02:08,  1.45it/s]

torch.Size([128, 2, 2])


 17%|█▋        | 39/225 [00:26<02:11,  1.41it/s]

torch.Size([128, 2, 2])


 18%|█▊        | 40/225 [00:26<02:13,  1.39it/s]

torch.Size([128, 2, 2])


 18%|█▊        | 41/225 [00:27<02:17,  1.34it/s]

torch.Size([128, 2, 2])


 19%|█▊        | 42/225 [00:28<02:21,  1.30it/s]

torch.Size([128, 2, 2])


 19%|█▉        | 43/225 [00:29<02:26,  1.24it/s]

torch.Size([128, 2, 2])


 20%|█▉        | 44/225 [00:30<02:26,  1.24it/s]

torch.Size([128, 2, 2])


 20%|██        | 45/225 [00:31<02:21,  1.27it/s]

torch.Size([128, 2, 2])


 20%|██        | 46/225 [00:31<02:17,  1.30it/s]

torch.Size([128, 2, 2])


 21%|██        | 47/225 [00:32<02:13,  1.33it/s]

torch.Size([128, 2, 2])


 21%|██▏       | 48/225 [00:33<02:09,  1.37it/s]

torch.Size([128, 2, 2])


 22%|██▏       | 49/225 [00:33<02:07,  1.38it/s]

torch.Size([128, 2, 2])


 22%|██▏       | 50/225 [00:34<02:04,  1.40it/s]

torch.Size([128, 2, 2])


 23%|██▎       | 51/225 [00:35<02:02,  1.42it/s]

torch.Size([128, 2, 2])


 23%|██▎       | 52/225 [00:35<02:00,  1.44it/s]

torch.Size([128, 2, 2])


 24%|██▎       | 53/225 [00:36<01:59,  1.44it/s]

torch.Size([128, 2, 2])


 24%|██▍       | 54/225 [00:37<01:58,  1.45it/s]

torch.Size([128, 2, 2])


 24%|██▍       | 55/225 [00:37<01:57,  1.45it/s]

torch.Size([128, 2, 2])


 25%|██▍       | 56/225 [00:38<01:57,  1.44it/s]

torch.Size([128, 2, 2])


 25%|██▌       | 57/225 [00:39<01:57,  1.44it/s]

torch.Size([128, 2, 2])


 26%|██▌       | 58/225 [00:40<01:56,  1.44it/s]

torch.Size([128, 2, 2])


 26%|██▌       | 59/225 [00:40<01:55,  1.44it/s]

torch.Size([128, 2, 2])


 27%|██▋       | 60/225 [00:41<01:54,  1.44it/s]

torch.Size([128, 2, 2])


 27%|██▋       | 61/225 [00:42<01:53,  1.44it/s]

torch.Size([128, 2, 2])


 28%|██▊       | 62/225 [00:42<01:55,  1.41it/s]

torch.Size([128, 2, 2])


 28%|██▊       | 63/225 [00:43<01:56,  1.39it/s]

torch.Size([128, 2, 2])


 28%|██▊       | 64/225 [00:44<01:57,  1.37it/s]

torch.Size([128, 2, 2])


 29%|██▉       | 65/225 [00:45<01:58,  1.36it/s]

torch.Size([128, 2, 2])


 29%|██▉       | 66/225 [00:45<01:58,  1.34it/s]

torch.Size([128, 2, 2])


 30%|██▉       | 67/225 [00:46<01:55,  1.37it/s]

torch.Size([128, 2, 2])


 30%|███       | 68/225 [00:47<01:52,  1.39it/s]

torch.Size([128, 2, 2])


 31%|███       | 69/225 [00:47<01:51,  1.40it/s]

torch.Size([128, 2, 2])


 31%|███       | 70/225 [00:48<01:49,  1.41it/s]

torch.Size([128, 2, 2])


 32%|███▏      | 71/225 [00:49<01:54,  1.35it/s]

torch.Size([128, 2, 2])


 32%|███▏      | 72/225 [00:50<01:56,  1.31it/s]

torch.Size([128, 2, 2])


 32%|███▏      | 73/225 [00:51<01:53,  1.34it/s]

torch.Size([128, 2, 2])


 33%|███▎      | 74/225 [00:51<01:51,  1.36it/s]

torch.Size([128, 2, 2])


 33%|███▎      | 75/225 [00:52<01:51,  1.35it/s]

torch.Size([128, 2, 2])


 34%|███▍      | 76/225 [00:53<01:50,  1.34it/s]

torch.Size([128, 2, 2])


 34%|███▍      | 77/225 [00:53<01:48,  1.36it/s]

torch.Size([128, 2, 2])


 35%|███▍      | 78/225 [00:54<01:46,  1.38it/s]

torch.Size([128, 2, 2])


 35%|███▌      | 79/225 [00:55<01:45,  1.38it/s]

torch.Size([128, 2, 2])


 36%|███▌      | 80/225 [00:56<01:43,  1.40it/s]

torch.Size([128, 2, 2])


 36%|███▌      | 81/225 [00:56<01:41,  1.42it/s]

torch.Size([128, 2, 2])


 36%|███▋      | 82/225 [00:57<01:45,  1.36it/s]

torch.Size([128, 2, 2])


 37%|███▋      | 83/225 [00:58<01:48,  1.31it/s]

torch.Size([128, 2, 2])


 37%|███▋      | 84/225 [00:59<01:48,  1.30it/s]

torch.Size([128, 2, 2])


 38%|███▊      | 85/225 [00:59<01:45,  1.32it/s]

torch.Size([128, 2, 2])


 38%|███▊      | 86/225 [01:00<01:43,  1.34it/s]

torch.Size([128, 2, 2])


 39%|███▊      | 87/225 [01:01<01:41,  1.37it/s]

torch.Size([128, 2, 2])


 39%|███▉      | 88/225 [01:01<01:38,  1.40it/s]

torch.Size([128, 2, 2])


 40%|███▉      | 89/225 [01:02<01:35,  1.42it/s]

torch.Size([128, 2, 2])


 40%|████      | 90/225 [01:03<01:34,  1.43it/s]

torch.Size([128, 2, 2])


 40%|████      | 91/225 [01:04<01:33,  1.43it/s]

torch.Size([128, 2, 2])


 41%|████      | 92/225 [01:04<01:33,  1.43it/s]

torch.Size([128, 2, 2])


 41%|████▏     | 93/225 [01:05<01:32,  1.43it/s]

torch.Size([128, 2, 2])


 42%|████▏     | 94/225 [01:06<01:31,  1.43it/s]

torch.Size([128, 2, 2])


 42%|████▏     | 95/225 [01:06<01:31,  1.43it/s]

torch.Size([128, 2, 2])


 43%|████▎     | 96/225 [01:07<01:30,  1.43it/s]

torch.Size([128, 2, 2])


 43%|████▎     | 97/225 [01:08<01:31,  1.40it/s]

torch.Size([128, 2, 2])


 44%|████▎     | 98/225 [01:08<01:29,  1.41it/s]

torch.Size([128, 2, 2])


 44%|████▍     | 99/225 [01:09<01:28,  1.42it/s]

torch.Size([128, 2, 2])


 44%|████▍     | 100/225 [01:10<01:28,  1.42it/s]

torch.Size([128, 2, 2])


 45%|████▍     | 101/225 [01:11<01:26,  1.43it/s]

torch.Size([128, 2, 2])


 45%|████▌     | 102/225 [01:11<01:25,  1.45it/s]

torch.Size([128, 2, 2])


 46%|████▌     | 103/225 [01:12<01:24,  1.44it/s]

torch.Size([128, 2, 2])


 46%|████▌     | 104/225 [01:13<01:23,  1.45it/s]

torch.Size([128, 2, 2])


 47%|████▋     | 105/225 [01:13<01:22,  1.45it/s]

torch.Size([128, 2, 2])


 47%|████▋     | 106/225 [01:14<01:22,  1.44it/s]

torch.Size([128, 2, 2])


 48%|████▊     | 107/225 [01:15<01:22,  1.43it/s]

torch.Size([128, 2, 2])


 48%|████▊     | 108/225 [01:15<01:22,  1.42it/s]

torch.Size([128, 2, 2])


 48%|████▊     | 109/225 [01:16<01:21,  1.43it/s]

torch.Size([128, 2, 2])


 49%|████▉     | 110/225 [01:17<01:20,  1.44it/s]

torch.Size([128, 2, 2])


 49%|████▉     | 111/225 [01:18<01:19,  1.44it/s]

torch.Size([128, 2, 2])


 50%|████▉     | 112/225 [01:18<01:17,  1.45it/s]

torch.Size([128, 2, 2])


 50%|█████     | 113/225 [01:19<01:17,  1.45it/s]

torch.Size([128, 2, 2])


 51%|█████     | 114/225 [01:20<01:16,  1.45it/s]

torch.Size([128, 2, 2])


 51%|█████     | 115/225 [01:20<01:16,  1.44it/s]

torch.Size([128, 2, 2])


 52%|█████▏    | 116/225 [01:21<01:16,  1.43it/s]

torch.Size([128, 2, 2])


 52%|█████▏    | 117/225 [01:22<01:16,  1.42it/s]

torch.Size([128, 2, 2])


 52%|█████▏    | 118/225 [01:22<01:16,  1.40it/s]

torch.Size([128, 2, 2])


 53%|█████▎    | 119/225 [01:23<01:16,  1.38it/s]

torch.Size([128, 2, 2])


 53%|█████▎    | 120/225 [01:24<01:15,  1.39it/s]

torch.Size([128, 2, 2])


 54%|█████▍    | 121/225 [01:25<01:13,  1.41it/s]

torch.Size([128, 2, 2])


 54%|█████▍    | 122/225 [01:25<01:12,  1.41it/s]

torch.Size([128, 2, 2])


 55%|█████▍    | 123/225 [01:26<01:11,  1.42it/s]

torch.Size([128, 2, 2])


 55%|█████▌    | 124/225 [01:27<01:10,  1.43it/s]

torch.Size([128, 2, 2])


 56%|█████▌    | 125/225 [01:27<01:09,  1.43it/s]

torch.Size([128, 2, 2])


 56%|█████▌    | 126/225 [01:28<01:09,  1.43it/s]

torch.Size([128, 2, 2])


 56%|█████▋    | 127/225 [01:29<01:08,  1.43it/s]

torch.Size([128, 2, 2])


 57%|█████▋    | 128/225 [01:29<01:07,  1.44it/s]

torch.Size([128, 2, 2])


 57%|█████▋    | 129/225 [01:30<01:06,  1.44it/s]

torch.Size([128, 2, 2])


 58%|█████▊    | 130/225 [01:31<01:06,  1.44it/s]

torch.Size([128, 2, 2])


 58%|█████▊    | 131/225 [01:32<01:05,  1.44it/s]

torch.Size([128, 2, 2])


 59%|█████▊    | 132/225 [01:32<01:04,  1.44it/s]

torch.Size([128, 2, 2])


 59%|█████▉    | 133/225 [01:33<01:03,  1.44it/s]

torch.Size([128, 2, 2])


 60%|█████▉    | 134/225 [01:34<01:03,  1.44it/s]

torch.Size([128, 2, 2])


 60%|██████    | 135/225 [01:34<01:02,  1.44it/s]

torch.Size([128, 2, 2])


 60%|██████    | 136/225 [01:35<01:02,  1.43it/s]

torch.Size([128, 2, 2])


 61%|██████    | 137/225 [01:36<01:01,  1.43it/s]

torch.Size([128, 2, 2])


 61%|██████▏   | 138/225 [01:36<01:00,  1.43it/s]

torch.Size([128, 2, 2])


 62%|██████▏   | 139/225 [01:37<01:00,  1.43it/s]

torch.Size([128, 2, 2])


 62%|██████▏   | 140/225 [01:38<00:59,  1.44it/s]

torch.Size([128, 2, 2])


 63%|██████▎   | 141/225 [01:39<00:58,  1.44it/s]

torch.Size([128, 2, 2])


 63%|██████▎   | 142/225 [01:39<00:58,  1.41it/s]

torch.Size([128, 2, 2])


 64%|██████▎   | 143/225 [01:40<00:57,  1.42it/s]

torch.Size([128, 2, 2])


 64%|██████▍   | 144/225 [01:41<00:57,  1.42it/s]

torch.Size([128, 2, 2])


 64%|██████▍   | 145/225 [01:41<00:56,  1.42it/s]

torch.Size([128, 2, 2])


 65%|██████▍   | 146/225 [01:42<00:55,  1.41it/s]

torch.Size([128, 2, 2])


 65%|██████▌   | 147/225 [01:43<00:55,  1.41it/s]

torch.Size([128, 2, 2])


 66%|██████▌   | 148/225 [01:43<00:54,  1.42it/s]

torch.Size([128, 2, 2])


 66%|██████▌   | 149/225 [01:44<00:53,  1.43it/s]

torch.Size([128, 2, 2])


 67%|██████▋   | 150/225 [01:45<00:52,  1.44it/s]

torch.Size([128, 2, 2])


 67%|██████▋   | 151/225 [01:46<00:51,  1.44it/s]

torch.Size([128, 2, 2])


 68%|██████▊   | 152/225 [01:46<00:50,  1.44it/s]

torch.Size([128, 2, 2])


 68%|██████▊   | 153/225 [01:47<00:50,  1.44it/s]

torch.Size([128, 2, 2])


 68%|██████▊   | 154/225 [01:48<00:49,  1.43it/s]

torch.Size([128, 2, 2])


 69%|██████▉   | 155/225 [01:48<00:49,  1.42it/s]

torch.Size([128, 2, 2])


 69%|██████▉   | 156/225 [01:49<00:48,  1.41it/s]

torch.Size([128, 2, 2])


 70%|██████▉   | 157/225 [01:50<00:48,  1.41it/s]

torch.Size([128, 2, 2])


 70%|███████   | 158/225 [01:50<00:47,  1.42it/s]

torch.Size([128, 2, 2])


 71%|███████   | 159/225 [01:51<00:46,  1.43it/s]

torch.Size([128, 2, 2])


 71%|███████   | 160/225 [01:52<00:45,  1.44it/s]

torch.Size([128, 2, 2])


 72%|███████▏  | 161/225 [01:53<00:44,  1.44it/s]

torch.Size([128, 2, 2])


 72%|███████▏  | 162/225 [01:53<00:44,  1.43it/s]

torch.Size([128, 2, 2])


 72%|███████▏  | 163/225 [01:54<00:43,  1.42it/s]

torch.Size([128, 2, 2])


 73%|███████▎  | 164/225 [01:55<00:43,  1.41it/s]

torch.Size([128, 2, 2])


 73%|███████▎  | 165/225 [01:55<00:42,  1.41it/s]

torch.Size([128, 2, 2])


 74%|███████▍  | 166/225 [01:56<00:41,  1.41it/s]

torch.Size([128, 2, 2])


 74%|███████▍  | 167/225 [01:57<00:40,  1.42it/s]

torch.Size([128, 2, 2])


 75%|███████▍  | 168/225 [01:57<00:39,  1.43it/s]

torch.Size([128, 2, 2])


 75%|███████▌  | 169/225 [01:58<00:38,  1.44it/s]

torch.Size([128, 2, 2])


 76%|███████▌  | 170/225 [01:59<00:37,  1.45it/s]

torch.Size([128, 2, 2])


 76%|███████▌  | 171/225 [02:00<00:37,  1.42it/s]

torch.Size([128, 2, 2])


 76%|███████▋  | 172/225 [02:00<00:38,  1.36it/s]

torch.Size([128, 2, 2])


 77%|███████▋  | 173/225 [02:01<00:38,  1.34it/s]

torch.Size([128, 2, 2])


 77%|███████▋  | 174/225 [02:02<00:37,  1.35it/s]

torch.Size([128, 2, 2])


 78%|███████▊  | 175/225 [02:03<00:37,  1.34it/s]

torch.Size([128, 2, 2])


 78%|███████▊  | 176/225 [02:03<00:36,  1.35it/s]

torch.Size([128, 2, 2])


 79%|███████▊  | 177/225 [02:04<00:35,  1.36it/s]

torch.Size([128, 2, 2])


 79%|███████▉  | 178/225 [02:05<00:34,  1.38it/s]

torch.Size([128, 2, 2])


 80%|███████▉  | 179/225 [02:06<00:33,  1.39it/s]

torch.Size([128, 2, 2])


 80%|████████  | 180/225 [02:06<00:31,  1.41it/s]

torch.Size([128, 2, 2])


 80%|████████  | 181/225 [02:07<00:31,  1.40it/s]

torch.Size([128, 2, 2])


 81%|████████  | 182/225 [02:08<00:30,  1.40it/s]

torch.Size([128, 2, 2])


 81%|████████▏ | 183/225 [02:08<00:29,  1.43it/s]

torch.Size([128, 2, 2])


 82%|████████▏ | 184/225 [02:09<00:28,  1.43it/s]

torch.Size([128, 2, 2])


 82%|████████▏ | 185/225 [02:10<00:28,  1.43it/s]

torch.Size([128, 2, 2])


 83%|████████▎ | 186/225 [02:10<00:27,  1.43it/s]

torch.Size([128, 2, 2])


 83%|████████▎ | 187/225 [02:11<00:26,  1.43it/s]

torch.Size([128, 2, 2])


 84%|████████▎ | 188/225 [02:12<00:25,  1.42it/s]

torch.Size([128, 2, 2])


 84%|████████▍ | 189/225 [02:12<00:25,  1.44it/s]

torch.Size([128, 2, 2])


 84%|████████▍ | 190/225 [02:13<00:24,  1.44it/s]

torch.Size([128, 2, 2])


 85%|████████▍ | 191/225 [02:14<00:23,  1.45it/s]

torch.Size([128, 2, 2])


 85%|████████▌ | 192/225 [02:15<00:22,  1.45it/s]

torch.Size([128, 2, 2])


 86%|████████▌ | 193/225 [02:15<00:21,  1.46it/s]

torch.Size([128, 2, 2])


 86%|████████▌ | 194/225 [02:16<00:22,  1.37it/s]

torch.Size([128, 2, 2])


 87%|████████▋ | 195/225 [02:17<00:21,  1.37it/s]

torch.Size([128, 2, 2])


 87%|████████▋ | 196/225 [02:18<00:20,  1.38it/s]

torch.Size([128, 2, 2])


 88%|████████▊ | 197/225 [02:18<00:20,  1.36it/s]

torch.Size([128, 2, 2])


 88%|████████▊ | 198/225 [02:19<00:19,  1.35it/s]

torch.Size([128, 2, 2])


 88%|████████▊ | 199/225 [02:20<00:19,  1.36it/s]

torch.Size([128, 2, 2])


 89%|████████▉ | 200/225 [02:20<00:18,  1.36it/s]

torch.Size([128, 2, 2])


 89%|████████▉ | 201/225 [02:21<00:17,  1.38it/s]

torch.Size([128, 2, 2])


 90%|████████▉ | 202/225 [02:22<00:16,  1.40it/s]

torch.Size([128, 2, 2])


 90%|█████████ | 203/225 [02:23<00:15,  1.41it/s]

torch.Size([128, 2, 2])


 91%|█████████ | 204/225 [02:23<00:14,  1.42it/s]

torch.Size([128, 2, 2])


 91%|█████████ | 205/225 [02:24<00:13,  1.43it/s]

torch.Size([128, 2, 2])


 92%|█████████▏| 206/225 [02:25<00:13,  1.44it/s]

torch.Size([128, 2, 2])


 92%|█████████▏| 207/225 [02:25<00:12,  1.45it/s]

torch.Size([128, 2, 2])


 92%|█████████▏| 208/225 [02:26<00:11,  1.45it/s]

torch.Size([128, 2, 2])


 93%|█████████▎| 209/225 [02:27<00:11,  1.45it/s]

torch.Size([128, 2, 2])


 93%|█████████▎| 210/225 [02:27<00:10,  1.45it/s]

torch.Size([128, 2, 2])


 94%|█████████▍| 211/225 [02:28<00:09,  1.45it/s]

torch.Size([128, 2, 2])


 94%|█████████▍| 212/225 [02:29<00:08,  1.45it/s]

torch.Size([128, 2, 2])


 95%|█████████▍| 213/225 [02:29<00:08,  1.45it/s]

torch.Size([128, 2, 2])


 95%|█████████▌| 214/225 [02:30<00:07,  1.44it/s]

torch.Size([128, 2, 2])


 96%|█████████▌| 215/225 [02:31<00:06,  1.44it/s]

torch.Size([128, 2, 2])


 96%|█████████▌| 216/225 [02:32<00:06,  1.44it/s]

torch.Size([128, 2, 2])


 96%|█████████▋| 217/225 [02:32<00:05,  1.45it/s]

torch.Size([128, 2, 2])


 97%|█████████▋| 218/225 [02:33<00:04,  1.41it/s]

torch.Size([128, 2, 2])


 97%|█████████▋| 219/225 [02:34<00:04,  1.36it/s]

torch.Size([128, 2, 2])


 98%|█████████▊| 220/225 [02:35<00:03,  1.33it/s]

torch.Size([128, 2, 2])


 98%|█████████▊| 221/225 [02:35<00:03,  1.33it/s]

torch.Size([128, 2, 2])


 99%|█████████▊| 222/225 [02:36<00:02,  1.34it/s]

torch.Size([128, 2, 2])


 99%|█████████▉| 223/225 [02:37<00:01,  1.31it/s]

torch.Size([128, 2, 2])


100%|█████████▉| 224/225 [02:38<00:00,  1.19it/s]

torch.Size([128, 2, 2])


100%|██████████| 225/225 [02:38<00:00,  1.42it/s]


In [37]:
correct = 0
total = 0
with torch.no_grad():
    for i in tqdm(range(len(test_X))):
        real_class = torch.argmax(test_y[i])
        net_out = net(test_X[i].view(-1, 1, 50, 50))[0]
        print(real_class, net_out)
        predicted_class = torch.argmax(net_out)
        if predicted_class == real_class:
            correct += 1
        total += 1
        
print("Accuracy:", round(correct/total, 3))

  2%|▏         | 43/2494 [00:00<00:15, 155.34it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.5313, 0.4687], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5549, 0.4451], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6367, 0.3633], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7407, 0.2593], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4685, 0.5315], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7444, 0.2556], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7938, 0.2062], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6203, 0.3797], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4070, 0.5930], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4791, 0.5209], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6327, 0.3673], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5044, 0.4956], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

  4%|▍         | 105/2494 [00:00<00:11, 207.77it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.6182, 0.3818], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6508, 0.3492], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6515, 0.3485], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7022, 0.2978], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6539, 0.3461], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5805, 0.4195], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4786, 0.5214], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6598, 0.3402], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5308, 0.4692], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6222, 0.3778], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8378, 0.1622], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6523, 0.3477], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

  7%|▋         | 170/2494 [00:00<00:09, 254.32it/s]

tensor([0.3349, 0.6651], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7464, 0.2536], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3779, 0.6221], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6742, 0.3258], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4822, 0.5178], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4692, 0.5308], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6125, 0.3875], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8944, 0.1056], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4240, 0.5760], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5159, 0.4841], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.8367, 0.1633], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.2856, 0.7144], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4623, 0.5377], dtype=torch.flo

 10%|▉         | 237/2494 [00:00<00:07, 287.77it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.6835, 0.3165], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4498, 0.5502], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5813, 0.4187], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7360, 0.2640], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8057, 0.1943], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6557, 0.3443], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6657, 0.3343], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.3629, 0.6371], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6079, 0.3921], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5949, 0.4051], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6430, 0.3570], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4309, 0.5691], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 12%|█▏        | 304/2494 [00:01<00:07, 307.53it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.4162, 0.5838], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3145, 0.6855], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8513, 0.1487], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4980, 0.5020], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7220, 0.2780], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6319, 0.3681], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6429, 0.3571], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7933, 0.2067], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6224, 0.3776], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.9071, 0.0929], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7062, 0.2938], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8496, 0.1504], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 15%|█▍        | 369/2494 [00:01<00:06, 315.15it/s]

tensor(1) tensor([0.3081, 0.6919], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7042, 0.2958], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7356, 0.2644], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2575, 0.7425], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5887, 0.4113], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6744, 0.3256], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2636, 0.7364], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4821, 0.5179], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4527, 0.5473], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4752, 0.5248], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6844, 0.3156], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2835, 0.7165], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5159, 0.4841], dtype

 17%|█▋        | 434/2494 [00:01<00:06, 316.11it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.5468, 0.4532], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6280, 0.3720], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4870, 0.5130], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7078, 0.2922], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3623, 0.6377], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.8192, 0.1808], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7243, 0.2757], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5313, 0.4687], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5145, 0.4855], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4388, 0.5612], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7907, 0.2093], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3339, 0.6661], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor

 20%|█▉        | 498/2494 [00:01<00:06, 310.86it/s]

tensor(1) tensor([0.3946, 0.6054], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5565, 0.4435], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4748, 0.5252], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7072, 0.2928], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7821, 0.2179], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5687, 0.4313], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5136, 0.4864], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4419, 0.5581], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6567, 0.3433], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3346, 0.6654], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3572, 0.6428], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4981, 0.5019], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6747, 0.3253], dtype

 22%|██▏       | 561/2494 [00:01<00:06, 305.98it/s]

tensor([0.3540, 0.6460], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.8382, 0.1618], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6519, 0.3481], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4010, 0.5990], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6359, 0.3641], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6521, 0.3479], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5507, 0.4493], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4034, 0.5966], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4699, 0.5301], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3397, 0.6603], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4524, 0.5476], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4864, 0.5136], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7843, 0.2157], dtype=torch.flo

 25%|██▌       | 624/2494 [00:02<00:06, 307.30it/s]

tensor(0) tensor([0.7890, 0.2110], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.3754, 0.6246], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6703, 0.3297], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5321, 0.4679], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7429, 0.2571], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7546, 0.2454], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6816, 0.3184], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8176, 0.1824], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6850, 0.3150], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5836, 0.4164], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3843, 0.6157], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7161, 0.2839], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7757, 0.2243], dtype

 26%|██▋       | 655/2494 [00:02<00:06, 294.44it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.3797, 0.6203], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7734, 0.2266], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6083, 0.3917], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6557, 0.3443], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5983, 0.4017], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4098, 0.5902], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3800, 0.6200], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5390, 0.4610], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6573, 0.3427], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2902, 0.7098], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4416, 0.5584], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8286, 0.1714], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 29%|██▉       | 720/2494 [00:02<00:05, 307.58it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.5164, 0.4836], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7003, 0.2997], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5336, 0.4664], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4014, 0.5986], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5679, 0.4321], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7651, 0.2349], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6766, 0.3234], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3300, 0.6700], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5622, 0.4378], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5470, 0.4530], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3355, 0.6645], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5188, 0.4812], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor

 31%|███▏      | 784/2494 [00:02<00:05, 311.24it/s]

tensor([0.4773, 0.5227], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7906, 0.2094], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5234, 0.4766], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5409, 0.4591], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3707, 0.6293], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3475, 0.6525], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6195, 0.3805], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5106, 0.4894], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3077, 0.6923], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3847, 0.6153], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6127, 0.3873], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4427, 0.5573], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3470, 0.6530], dtype=torch.flo

 34%|███▍      | 849/2494 [00:02<00:05, 316.07it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.6305, 0.3695], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5056, 0.4944], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6936, 0.3064], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5912, 0.4088], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4495, 0.5505], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6342, 0.3658], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8211, 0.1789], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7891, 0.2109], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4897, 0.5103], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4965, 0.5035], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7845, 0.2155], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.3890, 0.6110], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor

 37%|███▋      | 914/2494 [00:02<00:04, 319.46it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.8440, 0.1560], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6508, 0.3492], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.8075, 0.1925], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6205, 0.3795], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2772, 0.7228], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4658, 0.5342], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8307, 0.1693], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5406, 0.4594], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7089, 0.2911], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5291, 0.4709], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6711, 0.3289], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3839, 0.6161], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor

 39%|███▉      | 979/2494 [00:03<00:04, 319.74it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.4510, 0.5490], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3184, 0.6816], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3885, 0.6115], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6152, 0.3848], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5827, 0.4173], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5509, 0.4491], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5231, 0.4769], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4904, 0.5096], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6062, 0.3938], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4521, 0.5479], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4243, 0.5757], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8245, 0.1755], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor

 42%|████▏     | 1043/2494 [00:03<00:04, 315.48it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.5340, 0.4660], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5501, 0.4499], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4334, 0.5666], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6121, 0.3879], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.9224, 0.0776], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6483, 0.3517], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5535, 0.4465], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4029, 0.5971], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4921, 0.5079], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6839, 0.3161], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7330, 0.2670], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6806, 0.3194], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor

 44%|████▍     | 1109/2494 [00:03<00:04, 318.93it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.5765, 0.4235], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6858, 0.3142], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8003, 0.1997], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5687, 0.4313], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6885, 0.3115], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5511, 0.4489], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7003, 0.2997], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4139, 0.5861], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.8180, 0.1820], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8436, 0.1564], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7138, 0.2862], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5449, 0.4551], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 47%|████▋     | 1175/2494 [00:03<00:04, 321.02it/s]

tensor(0) tensor([0.7486, 0.2514], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4348, 0.5652], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6845, 0.3155], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7404, 0.2596], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5683, 0.4317], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6872, 0.3128], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5187, 0.4813], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6691, 0.3309], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5493, 0.4507], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5887, 0.4113], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3011, 0.6989], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7718, 0.2282], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4946, 0.5054], dtype

 50%|████▉     | 1241/2494 [00:04<00:03, 316.67it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.6027, 0.3973], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2888, 0.7112], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2901, 0.7099], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7823, 0.2177], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5774, 0.4226], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7425, 0.2575], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4764, 0.5236], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6638, 0.3362], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.3765, 0.6235], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5337, 0.4663], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4001, 0.5999], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7958, 0.2042], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor

 52%|█████▏    | 1306/2494 [00:04<00:03, 316.65it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.6637, 0.3363], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6521, 0.3479], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6675, 0.3325], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4020, 0.5980], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5304, 0.4696], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5167, 0.4833], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5892, 0.4108], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8653, 0.1347], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5672, 0.4328], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4618, 0.5382], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5129, 0.4871], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3532, 0.6468], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 55%|█████▍    | 1370/2494 [00:04<00:03, 309.58it/s]

tensor(1) tensor([0.5941, 0.4059], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7731, 0.2269], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6067, 0.3933], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7738, 0.2262], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6187, 0.3813], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5622, 0.4378], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3293, 0.6707], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5653, 0.4347], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5892, 0.4108], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8816, 0.1184], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5857, 0.4143], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6497, 0.3503], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4311, 0.5689], dtype

 57%|█████▋    | 1433/2494 [00:04<00:03, 309.37it/s]

tensor(1) tensor([0.2267, 0.7733], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6258, 0.3742], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8309, 0.1691], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5022, 0.4978], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4974, 0.5026], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6994, 0.3006], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5568, 0.4432], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4616, 0.5384], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2383, 0.7617], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6137, 0.3863], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3458, 0.6542], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5298, 0.4702], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6920, 0.3080], dtype

 60%|█████▉    | 1495/2494 [00:04<00:03, 307.92it/s]

tensor(1) tensor([0.6632, 0.3368], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5256, 0.4744], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6649, 0.3351], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6417, 0.3583], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6642, 0.3358], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6023, 0.3977], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4293, 0.5707], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3581, 0.6419], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5662, 0.4338], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7572, 0.2428], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8370, 0.1630], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5032, 0.4968], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8055, 0.1945], dtype

 62%|██████▏   | 1557/2494 [00:05<00:03, 303.09it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.5171, 0.4829], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3490, 0.6510], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2022, 0.7978], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.3643, 0.6357], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7359, 0.2641], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5887, 0.4113], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6817, 0.3183], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7128, 0.2872], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.3613, 0.6387], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7072, 0.2928], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3253, 0.6747], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3620, 0.6380], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 65%|██████▍   | 1620/2494 [00:05<00:02, 305.60it/s]

tensor(0) tensor([0.5727, 0.4273], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4689, 0.5311], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5907, 0.4093], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7123, 0.2877], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7984, 0.2016], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4002, 0.5998], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6076, 0.3924], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4711, 0.5289], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6274, 0.3726], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7419, 0.2581], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4258, 0.5742], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2125, 0.7875], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5244, 0.4756], dtype

 67%|██████▋   | 1682/2494 [00:05<00:02, 300.67it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.3581, 0.6419], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5930, 0.4070], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8300, 0.1700], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4502, 0.5498], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7994, 0.2006], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6843, 0.3157], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4190, 0.5810], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3034, 0.6966], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5927, 0.4073], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6139, 0.3861], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.9130, 0.0870], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8002, 0.1998], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 70%|██████▉   | 1744/2494 [00:05<00:02, 303.07it/s]

tensor(1) tensor([0.4721, 0.5279], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3461, 0.6539], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6037, 0.3963], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2997, 0.7003], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4467, 0.5533], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5784, 0.4216], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4681, 0.5319], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4713, 0.5287], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4792, 0.5208], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6955, 0.3045], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4458, 0.5542], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.3647, 0.6353], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6536, 0.3464], dtype

 72%|███████▏  | 1807/2494 [00:05<00:02, 305.83it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.5963, 0.4037], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5373, 0.4627], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5419, 0.4581], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5490, 0.4510], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4677, 0.5323], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6706, 0.3294], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5129, 0.4871], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3896, 0.6104], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5456, 0.4544], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5777, 0.4223], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.3857, 0.6143], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5313, 0.4687], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 75%|███████▍  | 1869/2494 [00:06<00:02, 301.00it/s]

tensor([0.6422, 0.3578], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4952, 0.5048], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8283, 0.1717], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7403, 0.2597], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5046, 0.4954], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8305, 0.1695], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7741, 0.2259], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3224, 0.6776], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7655, 0.2345], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4423, 0.5577], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6024, 0.3976], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7082, 0.2918], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5448, 0.4552], dtype=torch.flo

 77%|███████▋  | 1931/2494 [00:06<00:01, 303.36it/s]

tensor(0) tensor([0.4551, 0.5449], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5005, 0.4995], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4559, 0.5441], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7527, 0.2473], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3936, 0.6064], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5593, 0.4407], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8319, 0.1681], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5653, 0.4347], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6770, 0.3230], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5816, 0.4184], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2448, 0.7552], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3235, 0.6765], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.8270, 0.1730], dtype

 80%|███████▉  | 1993/2494 [00:06<00:01, 302.85it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.6670, 0.3330], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6951, 0.3049], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.2840, 0.7160], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7412, 0.2588], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7621, 0.2379], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.2109, 0.7891], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7284, 0.2716], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4571, 0.5429], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.8335, 0.1665], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6705, 0.3295], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5247, 0.4753], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5591, 0.4409], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor

 82%|████████▏ | 2055/2494 [00:06<00:01, 303.83it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.2939, 0.7061], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6541, 0.3459], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6673, 0.3327], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5161, 0.4839], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4866, 0.5134], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7382, 0.2618], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7246, 0.2754], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4228, 0.5772], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7374, 0.2626], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5941, 0.4059], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5197, 0.4803], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6358, 0.3642], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor

 85%|████████▍ | 2117/2494 [00:06<00:01, 304.37it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.7470, 0.2530], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5495, 0.4505], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5679, 0.4321], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5186, 0.4814], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8185, 0.1815], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7957, 0.2043], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6111, 0.3889], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6610, 0.3390], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8238, 0.1762], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.3556, 0.6444], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7115, 0.2885], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5581, 0.4419], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor

 87%|████████▋ | 2179/2494 [00:07<00:01, 304.52it/s]

tensor([0.6166, 0.3834], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5993, 0.4007], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4503, 0.5497], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.8573, 0.1427], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8881, 0.1119], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5264, 0.4736], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6995, 0.3005], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4689, 0.5311], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2954, 0.7046], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4208, 0.5792], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3580, 0.6420], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6573, 0.3427], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3538, 0.6462], dtype=torch.flo

 90%|████████▉ | 2241/2494 [00:07<00:00, 301.86it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.7642, 0.2358], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4793, 0.5207], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6271, 0.3729], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4772, 0.5228], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7084, 0.2916], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6892, 0.3108], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5997, 0.4003], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5736, 0.4264], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4602, 0.5398], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4293, 0.5707], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7901, 0.2099], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6739, 0.3261], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 91%|█████████ | 2272/2494 [00:07<00:00, 302.32it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.7482, 0.2518], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5282, 0.4718], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.3577, 0.6423], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3348, 0.6652], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5212, 0.4788], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2242, 0.7758], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6896, 0.3104], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6776, 0.3224], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7890, 0.2110], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7493, 0.2507], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5544, 0.4456], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5329, 0.4671], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor

 94%|█████████▎| 2336/2494 [00:07<00:00, 306.34it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.7907, 0.2093], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5688, 0.4312], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7601, 0.2399], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4726, 0.5274], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5114, 0.4886], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5701, 0.4299], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5008, 0.4992], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4998, 0.5002], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8447, 0.1553], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4518, 0.5482], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.8276, 0.1724], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6844, 0.3156], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor

 96%|█████████▌| 2400/2494 [00:07<00:00, 310.46it/s]

tensor(1) tensor([0.1693, 0.8307], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6124, 0.3876], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.6269, 0.3731], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.3216, 0.6784], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2210, 0.7790], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6760, 0.3240], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.2998, 0.7002], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.5534, 0.4466], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7185, 0.2815], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8457, 0.1543], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4534, 0.5466], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7397, 0.2603], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.2743, 0.7257], dtype

100%|██████████| 2494/2494 [00:08<00:00, 308.12it/s]

torch.Size([128, 2, 2])
tensor(0) tensor([0.7234, 0.2766], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.7000, 0.3000], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4956, 0.5044], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.7915, 0.2085], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.3284, 0.6716], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.8153, 0.1847], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.4021, 0.5979], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6743, 0.3257], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.6342, 0.3658], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor([0.5811, 0.4189], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.3797, 0.6203], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(1) tensor([0.4962, 0.5038], dtype=torch.float64)
torch.Size([128, 2, 2])
tensor(0) tensor


