In [5]:
# Step 1: Importing Libraries
import torch
from torchvision import datasets, transforms
from torch import nn, optim
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np

In [6]:
train_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='./data',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            # Standardise pixels; 0.1307 / 0.3081 are presumably the MNIST
            # pixel mean / std.
            transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
        ]),
    ),
    # One sample per batch, so each backward pass in train_dp produces the
    # gradient of a single training example.
    batch_size=1,
    shuffle=True,
    num_workers=1,
    pin_memory=True,
)

# No download=True here: presumably relies on the files fetched into ./data
# by the training dataset above.
test_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='./data',
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
        ]),
    ),
    batch_size=1024,
    shuffle=True,
    num_workers=1,
    pin_memory=True,
)

In [7]:
class SimpleFCN(nn.Module):
    """Two-layer fully connected classifier.

    Architecture: flatten to 784 features -> Linear(784, width) -> ReLU
    -> Linear(width, 10). ``forward`` returns raw logits (no softmax), which is
    what ``nn.CrossEntropyLoss`` expects.

    Args:
        width: number of hidden units.
    """

    _IN_FEATURES = 28 * 28
    _NUM_CLASSES = 10

    def __init__(self, width):
        super().__init__()
        self.fc1 = nn.Linear(self._IN_FEATURES, width)
        self.fc2 = nn.Linear(width, self._NUM_CLASSES)

    def forward(self, x):
        # Any input whose element count is a multiple of 784 is reshaped to
        # (batch, 784); e.g. (N, 1, 28, 28) image batches.
        flat = x.view(-1, self._IN_FEATURES)
        hidden = F.relu(self.fc1(flat))
        logits = self.fc2(hidden)
        return logits

In [13]:
SEED = 0
# Seed torch's default RNG so weight initialisation (and later draws from the
# default generator, such as the DP noise in train_dp) are reproducible.
torch.manual_seed(SEED)

network_width = 128  # hidden-layer size of SimpleFCN
model = SimpleFCN(network_width)

# train_dp uses this optimizer only for zero_grad() and its learning rate;
# it applies the parameter update itself.
optimizer = optim.SGD(model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()

In [14]:
# Gradient-norm bound train_dp clips to. 10_000 is very large, so clipping is
# presumably rarely (if ever) triggered for this model.
clip_value = 10_000
# Standard deviation of the Gaussian noise train_dp adds to the summed
# gradients; 0 means no noise, i.e. no privacy protection.
noise_scale = 0

In [18]:
# Step 5: Training the Model with DP-GD
def train_dp(model, train_loader, optimizer, criterion, device, clip_value, noise_scale):
    """Run one epoch of differentially private full-batch gradient descent (DP-GD).

    For every batch yielded by ``train_loader`` the gradient is computed and its
    global L2 norm (over all parameters taken together) is clipped to
    ``clip_value``. The clipped gradients are summed over the whole epoch,
    Gaussian noise with standard deviation ``noise_scale`` is added to the sum,
    and the parameters take ONE step of size ``optimizer.param_groups[0]['lr']``
    along it. The sum is not averaged, so the effective step scales with the
    number of batches.

    Per-example clipping, which the DP guarantee relies on, only holds when
    the loader yields one sample per batch (batch_size=1).

    Args:
        model: network to train; its parameters are updated in place.
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: used for ``zero_grad()`` and its first param group's ``lr``;
            ``step()`` is never called, so momentum / weight decay are ignored.
        criterion: loss function ``criterion(output, target) -> scalar tensor``.
        device: device each batch is moved to.
        clip_value: maximum global L2 norm of each batch gradient.
        noise_scale: std of the Gaussian noise added to the summed gradient
            (0 disables noise).

    Returns:
        float: mean training loss over all samples seen this epoch, measured
        with the parameters as they were before the update.

    Raises:
        ValueError: if ``noise_scale`` is negative or ``train_loader`` yields
            no batches.
    """
    if noise_scale < 0:
        raise ValueError("noise_scale must be non-negative")

    model.train()
    params = list(model.parameters())
    lr = optimizer.param_groups[0]['lr']
    total_grads = [torch.zeros_like(param) for param in params]
    # Accumulate on-device to avoid a host sync for every sample.
    loss_sum = torch.zeros((), device=device)
    num_samples = 0

    for data, target in tqdm(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        batch_size = target.size(0)
        # assumes criterion returns a batch mean (the CrossEntropyLoss default)
        loss_sum += loss.detach() * batch_size
        num_samples += batch_size

        # Clip the per-example gradient as ONE vector (global norm across all
        # parameters). Clipping each tensor separately would let the total norm
        # reach clip_value * sqrt(len(params)), understating the sensitivity.
        torch.nn.utils.clip_grad_norm_(params, clip_value)
        for acc, param in zip(total_grads, params):
            if param.grad is not None:
                acc.add_(param.grad)

    if num_samples == 0:
        raise ValueError("train_loader yielded no batches")

    # Add noise (generated directly on the accumulator's device/dtype) and take
    # the single gradient step outside autograd.
    with torch.no_grad():
        for acc, param in zip(total_grads, params):
            acc.add_(torch.randn_like(acc) * noise_scale)
            param.sub_(lr * acc)

    avg_loss = (loss_sum / num_samples).item()
    print(f"Train Loss: {avg_loss:.6f}")
    return avg_loss

In [19]:
# Step 6: Evaluating the Model
def test(model, test_loader, criterion, device):
    model.eval()
    test_loss = 0
    correct = 0
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    
    print(f"Test Loss: {test_loss:.6f}, Accuracy: {accuracy:.2f}%")

In [20]:
# Main Training Loop
# Prefer the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
print(device)

# Each epoch: one DP-GD pass over the training set, then a full evaluation.
for epoch in range(1, 11):
    print(f"Epoch {epoch}")
    train_dp(model, train_loader, optimizer, criterion, device, clip_value, noise_scale)
    test(model, test_loader, criterion, device)

cuda
Epoch 1


  0%|          | 28/60000 [00:00<03:35, 278.81it/s]

torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])

  0%|          | 72/60000 [00:00<02:41, 370.60it/s]

torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])

  0%|          | 111/60000 [00:00<02:38, 377.38it/s]

torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])

  0%|          | 153/60000 [00:00<02:32, 392.76it/s]

torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])

  0%|          | 193/60000 [00:00<02:32, 391.00it/s]

torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])


  0%|          | 233/60000 [00:00<02:33, 390.25it/s]

torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])

  0%|          | 275/60000 [00:00<02:29, 399.75it/s]

torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])

  1%|          | 321/60000 [00:00<02:23, 416.23it/s]

torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])


  1%|          | 365/60000 [00:00<02:21, 421.64it/s]

torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])

  1%|          | 408/60000 [00:01<02:23, 415.36it/s]

torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])

  1%|          | 450/60000 [00:01<02:24, 412.89it/s]

torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])


  1%|          | 493/60000 [00:01<02:23, 415.85it/s]

torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])

  1%|          | 535/60000 [00:01<02:23, 414.60it/s]

torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])

  1%|          | 577/60000 [00:01<02:26, 404.74it/s]

torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])

  1%|          | 618/60000 [00:01<02:27, 403.22it/s]

torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])

  1%|          | 661/60000 [00:01<02:24, 409.27it/s]

torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])


  1%|          | 705/60000 [00:01<02:22, 416.48it/s]

torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])

  1%|▏         | 751/60000 [00:01<02:18, 426.93it/s]

torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])


  1%|▏         | 796/60000 [00:01<02:16, 432.31it/s]

torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])

  1%|▏         | 840/60000 [00:02<02:17, 431.66it/s]

torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])


  1%|▏         | 884/60000 [00:02<02:20, 422.10it/s]

torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])

  2%|▏         | 928/60000 [00:02<02:18, 426.60it/s]

torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])


  2%|▏         | 971/60000 [00:02<02:19, 423.93it/s]

torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])

  2%|▏         | 1059/60000 [00:02<02:18, 426.64it/s]

torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])

  2%|▏         | 1145/60000 [00:02<02:19, 420.89it/s]

torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])

  2%|▏         | 1233/60000 [00:02<02:16, 429.22it/s]

torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])


  2%|▏         | 1319/60000 [00:03<02:21, 416.11it/s]

torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
t

  2%|▏         | 1403/60000 [00:03<02:23, 409.62it/s]

torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])

  2%|▏         | 1490/60000 [00:03<02:20, 416.01it/s]

torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])

  3%|▎         | 1581/60000 [00:03<02:14, 433.23it/s]

torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])

  3%|▎         | 1669/60000 [00:04<02:15, 429.74it/s]

torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])

  3%|▎         | 1759/60000 [00:04<02:13, 437.74it/s]

torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])


  3%|▎         | 1803/60000 [00:04<02:15, 429.94it/s]

torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
t

  3%|▎         | 1847/60000 [00:04<02:18, 420.45it/s]

torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])


  3%|▎         | 1911/60000 [00:04<02:19, 416.14it/s]

torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
torch.Size([128, 784])
torch.Size([128])




KeyboardInterrupt: 