In [1]:
from __future__ import print_function
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

"""
parser = argparse.ArgumentParser(description='VAE MNIST Example')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='enables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
"""


cuda = True
batch_size = 128
epochs = 10
seed = 1
log_interval = 10




#torch.manual_seed(args.seed)
torch.manual_seed(seed)

#device = torch.device("cuda" if args.cuda else "cpu")
device = torch.device("cuda" if cuda else "cpu")

#kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)


class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        #print(logvar.shape)
        if self.training:
            std = torch.exp(0.5*logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return F.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)


# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    # recon_x is output from the activation layer. if recon_x has not gone through sigmoid activation, 
    # use binary_cross_entropy_with_logits
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), size_average=False)
    print("BCE: "+str(BCE/batch_size))

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    #tmp = 1 + logvar - mu.pow(2) - logvar.exp()
    #print("tmp: "+str(tmp.shape))
    #KLD = -0.5*torch.sum(tmp)
    #print("KLD: "+str(KLD.shape))
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    #print(BCE)

    return BCE + KLD


def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        #print("recon: "+str(recon_batch.shape))
        #print("var: "+str(torch.exp(logvar)))
        loss = loss_function(recon_batch, data, mu, logvar)
        #print("loss: "+str(loss.item()/len(data)))
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        #print(len(data))
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))
            #print("loss "+str(loss.item()))
            #print("len "+str(len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))


def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                      recon_batch.view(batch_size, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                         'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))


for epoch in range(1, epochs + 1):
    train(epoch)
    test(epoch)
    with torch.no_grad():
        sample = torch.randn(64, 20).to(device)
        sample = model.decode(sample).cpu()
        save_image(sample.view(64, 1, 28, 28),
                   'results/sample_' + str(epoch) + '.png')

BCE: tensor(550.4553, device='cuda:0')
BCE: tensor(524.4629, device='cuda:0')
BCE: tensor(498.4903, device='cuda:0')
BCE: tensor(472.5098, device='cuda:0')
BCE: tensor(445.9448, device='cuda:0')
BCE: tensor(413.1568, device='cuda:0')
BCE: tensor(376.8734, device='cuda:0')
BCE: tensor(342.9244, device='cuda:0')
BCE: tensor(307.4867, device='cuda:0')
BCE: tensor(286.3045, device='cuda:0')
BCE: tensor(271.4376, device='cuda:0')
BCE: tensor(269.9990, device='cuda:0')
BCE: tensor(242.1886, device='cuda:0')
BCE: tensor(246.1439, device='cuda:0')
BCE: tensor(234.0586, device='cuda:0')
BCE: tensor(232.4550, device='cuda:0')
BCE: tensor(238.2680, device='cuda:0')
BCE: tensor(232.4213, device='cuda:0')
BCE: tensor(234.8048, device='cuda:0')
BCE: tensor(235.3544, device='cuda:0')
BCE: tensor(228.2004, device='cuda:0')
BCE: tensor(228.2328, device='cuda:0')
BCE: tensor(229.5990, device='cuda:0')
BCE: tensor(227.6273, device='cuda:0')
BCE: tensor(224.0567, device='cuda:0')
BCE: tensor(218.5842, dev

BCE: tensor(140.3069, device='cuda:0')
BCE: tensor(142.1251, device='cuda:0')
BCE: tensor(140.8108, device='cuda:0')
BCE: tensor(143.3246, device='cuda:0')
BCE: tensor(134.8968, device='cuda:0')
BCE: tensor(135.5101, device='cuda:0')
BCE: tensor(135.9858, device='cuda:0')
BCE: tensor(132.6208, device='cuda:0')
BCE: tensor(135.2714, device='cuda:0')
BCE: tensor(139.0252, device='cuda:0')
BCE: tensor(145.5138, device='cuda:0')
BCE: tensor(140.6665, device='cuda:0')
BCE: tensor(134.7077, device='cuda:0')
BCE: tensor(138.4087, device='cuda:0')
BCE: tensor(147.1949, device='cuda:0')
BCE: tensor(138.5897, device='cuda:0')
BCE: tensor(135.0194, device='cuda:0')
BCE: tensor(133.8106, device='cuda:0')
BCE: tensor(133.7702, device='cuda:0')
BCE: tensor(139.9744, device='cuda:0')
BCE: tensor(137.2571, device='cuda:0')
BCE: tensor(135.7049, device='cuda:0')
BCE: tensor(141.2299, device='cuda:0')
BCE: tensor(136.5901, device='cuda:0')
BCE: tensor(129.5381, device='cuda:0')
BCE: tensor(132.1745, dev

BCE: tensor(115.5907, device='cuda:0')
BCE: tensor(112.5023, device='cuda:0')
BCE: tensor(115.8947, device='cuda:0')
BCE: tensor(116.9788, device='cuda:0')
BCE: tensor(109.1163, device='cuda:0')
BCE: tensor(115.9317, device='cuda:0')
BCE: tensor(110.4351, device='cuda:0')
BCE: tensor(111.7496, device='cuda:0')
BCE: tensor(113.4461, device='cuda:0')
BCE: tensor(110.6179, device='cuda:0')
BCE: tensor(118.0482, device='cuda:0')
BCE: tensor(115.3419, device='cuda:0')
BCE: tensor(112.3787, device='cuda:0')
BCE: tensor(113.3605, device='cuda:0')
BCE: tensor(112.4492, device='cuda:0')
BCE: tensor(117.6300, device='cuda:0')
BCE: tensor(109.3702, device='cuda:0')
BCE: tensor(113.3417, device='cuda:0')
BCE: tensor(114.3627, device='cuda:0')
BCE: tensor(112.7672, device='cuda:0')
BCE: tensor(113.9465, device='cuda:0')
BCE: tensor(108.7418, device='cuda:0')
BCE: tensor(114.0589, device='cuda:0')
BCE: tensor(106.9518, device='cuda:0')
BCE: tensor(111.4654, device='cuda:0')
BCE: tensor(117.4809, dev

BCE: tensor(104.9141, device='cuda:0')
BCE: tensor(103.5047, device='cuda:0')
BCE: tensor(106.3395, device='cuda:0')
BCE: tensor(105.0266, device='cuda:0')
BCE: tensor(105.2963, device='cuda:0')
BCE: tensor(105.6754, device='cuda:0')
BCE: tensor(102.4691, device='cuda:0')
BCE: tensor(104.3124, device='cuda:0')
BCE: tensor(105.7452, device='cuda:0')
BCE: tensor(102.4893, device='cuda:0')
BCE: tensor(107.9761, device='cuda:0')
BCE: tensor(104.8931, device='cuda:0')
BCE: tensor(101.4401, device='cuda:0')
BCE: tensor(102.4455, device='cuda:0')
BCE: tensor(108.1486, device='cuda:0')
BCE: tensor(103.3183, device='cuda:0')
BCE: tensor(104.4914, device='cuda:0')
BCE: tensor(104.4707, device='cuda:0')
BCE: tensor(110.8865, device='cuda:0')
BCE: tensor(106.8913, device='cuda:0')
BCE: tensor(106.1307, device='cuda:0')
BCE: tensor(108.9270, device='cuda:0')
BCE: tensor(105.8031, device='cuda:0')
BCE: tensor(107.2867, device='cuda:0')
BCE: tensor(102.7793, device='cuda:0')
BCE: tensor(104.1164, dev

BCE: tensor(100.2165, device='cuda:0')
BCE: tensor(99.4167, device='cuda:0')
BCE: tensor(102.0073, device='cuda:0')
BCE: tensor(98.4584, device='cuda:0')
BCE: tensor(98.2431, device='cuda:0')
BCE: tensor(103.3540, device='cuda:0')
BCE: tensor(99.1083, device='cuda:0')
BCE: tensor(95.8166, device='cuda:0')
BCE: tensor(98.6326, device='cuda:0')
BCE: tensor(97.8230, device='cuda:0')
BCE: tensor(97.7437, device='cuda:0')
BCE: tensor(96.8327, device='cuda:0')
BCE: tensor(99.7467, device='cuda:0')
BCE: tensor(95.0701, device='cuda:0')
BCE: tensor(104.2478, device='cuda:0')
BCE: tensor(103.8524, device='cuda:0')
BCE: tensor(95.3147, device='cuda:0')
BCE: tensor(98.1987, device='cuda:0')
BCE: tensor(101.3011, device='cuda:0')
BCE: tensor(98.4934, device='cuda:0')
BCE: tensor(93.7872, device='cuda:0')
BCE: tensor(96.7326, device='cuda:0')
BCE: tensor(97.6096, device='cuda:0')
BCE: tensor(100.8404, device='cuda:0')
BCE: tensor(97.4294, device='cuda:0')
BCE: tensor(101.0519, device='cuda:0')
BCE:

BCE: tensor(89.4342, device='cuda:0')
BCE: tensor(93.1345, device='cuda:0')
BCE: tensor(94.6868, device='cuda:0')
BCE: tensor(95.0194, device='cuda:0')
BCE: tensor(94.6229, device='cuda:0')
BCE: tensor(94.6185, device='cuda:0')
BCE: tensor(95.2846, device='cuda:0')
BCE: tensor(92.7947, device='cuda:0')
BCE: tensor(91.3364, device='cuda:0')
BCE: tensor(89.7886, device='cuda:0')
BCE: tensor(90.6022, device='cuda:0')
BCE: tensor(94.0145, device='cuda:0')
BCE: tensor(96.3211, device='cuda:0')
BCE: tensor(95.4855, device='cuda:0')
BCE: tensor(94.0981, device='cuda:0')
BCE: tensor(93.4243, device='cuda:0')
BCE: tensor(95.7922, device='cuda:0')
BCE: tensor(93.8178, device='cuda:0')
BCE: tensor(91.8729, device='cuda:0')
BCE: tensor(92.7262, device='cuda:0')
BCE: tensor(88.6749, device='cuda:0')
BCE: tensor(92.1267, device='cuda:0')
BCE: tensor(88.2124, device='cuda:0')
BCE: tensor(88.6962, device='cuda:0')
BCE: tensor(69.0707, device='cuda:0')
====> Epoch: 2 Average loss: 121.7119
BCE: tensor(

BCE: tensor(91.7801, device='cuda:0')
BCE: tensor(90.6868, device='cuda:0')
BCE: tensor(96.1107, device='cuda:0')
BCE: tensor(91.6777, device='cuda:0')
BCE: tensor(92.2650, device='cuda:0')
BCE: tensor(92.6991, device='cuda:0')
BCE: tensor(90.2634, device='cuda:0')
BCE: tensor(95.0052, device='cuda:0')
BCE: tensor(94.2770, device='cuda:0')
BCE: tensor(91.3284, device='cuda:0')
BCE: tensor(96.7111, device='cuda:0')
BCE: tensor(90.0760, device='cuda:0')
BCE: tensor(89.1279, device='cuda:0')
BCE: tensor(93.3157, device='cuda:0')
BCE: tensor(94.3051, device='cuda:0')
BCE: tensor(91.1720, device='cuda:0')
BCE: tensor(90.4532, device='cuda:0')
BCE: tensor(86.1434, device='cuda:0')
BCE: tensor(90.2857, device='cuda:0')
BCE: tensor(91.6135, device='cuda:0')
BCE: tensor(87.2990, device='cuda:0')
BCE: tensor(91.1507, device='cuda:0')
BCE: tensor(85.2123, device='cuda:0')
BCE: tensor(94.0788, device='cuda:0')
BCE: tensor(88.0926, device='cuda:0')
BCE: tensor(85.2849, device='cuda:0')
BCE: tensor(

BCE: tensor(88.6026, device='cuda:0')
BCE: tensor(88.3870, device='cuda:0')
BCE: tensor(88.3594, device='cuda:0')
BCE: tensor(84.9945, device='cuda:0')
BCE: tensor(90.8540, device='cuda:0')
BCE: tensor(83.9627, device='cuda:0')
BCE: tensor(92.1393, device='cuda:0')
BCE: tensor(88.9450, device='cuda:0')
BCE: tensor(93.9971, device='cuda:0')
BCE: tensor(90.4540, device='cuda:0')
BCE: tensor(92.0059, device='cuda:0')
BCE: tensor(88.6819, device='cuda:0')
BCE: tensor(92.0239, device='cuda:0')
BCE: tensor(92.7946, device='cuda:0')
BCE: tensor(91.0194, device='cuda:0')
BCE: tensor(91.0829, device='cuda:0')
BCE: tensor(84.7052, device='cuda:0')
BCE: tensor(91.4329, device='cuda:0')
BCE: tensor(85.9003, device='cuda:0')
BCE: tensor(90.4319, device='cuda:0')
BCE: tensor(91.7857, device='cuda:0')
BCE: tensor(88.3500, device='cuda:0')
BCE: tensor(90.7405, device='cuda:0')
BCE: tensor(88.2219, device='cuda:0')
BCE: tensor(89.9873, device='cuda:0')
BCE: tensor(90.9906, device='cuda:0')
BCE: tensor(

BCE: tensor(78.1689, device='cuda:0')
BCE: tensor(76.7118, device='cuda:0')
BCE: tensor(74.5878, device='cuda:0')
BCE: tensor(79.6072, device='cuda:0')
BCE: tensor(78.7717, device='cuda:0')
BCE: tensor(78.1358, device='cuda:0')
BCE: tensor(76.1227, device='cuda:0')
BCE: tensor(85.1523, device='cuda:0')
BCE: tensor(77.9600, device='cuda:0')
BCE: tensor(76.6211, device='cuda:0')
BCE: tensor(75.3621, device='cuda:0')
BCE: tensor(78.4637, device='cuda:0')
BCE: tensor(81.0541, device='cuda:0')
BCE: tensor(78.6616, device='cuda:0')
BCE: tensor(78.0092, device='cuda:0')
BCE: tensor(77.7478, device='cuda:0')
BCE: tensor(78.9051, device='cuda:0')
BCE: tensor(78.5046, device='cuda:0')
BCE: tensor(80.5995, device='cuda:0')
BCE: tensor(78.1679, device='cuda:0')
BCE: tensor(80.4904, device='cuda:0')
BCE: tensor(80.5772, device='cuda:0')
BCE: tensor(79.4083, device='cuda:0')
BCE: tensor(78.1051, device='cuda:0')
BCE: tensor(81.4763, device='cuda:0')
BCE: tensor(80.5092, device='cuda:0')
BCE: tensor(

BCE: tensor(88.1633, device='cuda:0')
BCE: tensor(88.2324, device='cuda:0')
BCE: tensor(86.0766, device='cuda:0')
BCE: tensor(89.6232, device='cuda:0')
BCE: tensor(88.6068, device='cuda:0')
BCE: tensor(88.9358, device='cuda:0')
BCE: tensor(83.3992, device='cuda:0')
BCE: tensor(89.5160, device='cuda:0')
BCE: tensor(82.3019, device='cuda:0')
BCE: tensor(90.4387, device='cuda:0')
BCE: tensor(87.5774, device='cuda:0')
BCE: tensor(89.4104, device='cuda:0')
BCE: tensor(90.8558, device='cuda:0')
BCE: tensor(84.7517, device='cuda:0')
BCE: tensor(85.7834, device='cuda:0')
BCE: tensor(88.0695, device='cuda:0')
BCE: tensor(88.7728, device='cuda:0')
BCE: tensor(88.4809, device='cuda:0')
BCE: tensor(85.4205, device='cuda:0')
BCE: tensor(88.2796, device='cuda:0')
BCE: tensor(88.8328, device='cuda:0')
BCE: tensor(87.9854, device='cuda:0')
BCE: tensor(85.9990, device='cuda:0')
BCE: tensor(90.5369, device='cuda:0')
BCE: tensor(86.9131, device='cuda:0')
BCE: tensor(83.4847, device='cuda:0')
BCE: tensor(

BCE: tensor(87.0682, device='cuda:0')
BCE: tensor(85.0959, device='cuda:0')
BCE: tensor(86.8493, device='cuda:0')
BCE: tensor(80.9120, device='cuda:0')
BCE: tensor(90.4155, device='cuda:0')
BCE: tensor(86.8661, device='cuda:0')
BCE: tensor(83.8138, device='cuda:0')
BCE: tensor(82.3385, device='cuda:0')
BCE: tensor(86.2797, device='cuda:0')
BCE: tensor(82.7385, device='cuda:0')
BCE: tensor(83.7593, device='cuda:0')
BCE: tensor(88.1243, device='cuda:0')
BCE: tensor(85.7753, device='cuda:0')
BCE: tensor(87.3990, device='cuda:0')
BCE: tensor(84.5566, device='cuda:0')
BCE: tensor(90.8450, device='cuda:0')
BCE: tensor(89.1980, device='cuda:0')
BCE: tensor(87.3407, device='cuda:0')
BCE: tensor(88.5134, device='cuda:0')
BCE: tensor(90.2277, device='cuda:0')
BCE: tensor(86.8624, device='cuda:0')
BCE: tensor(83.8756, device='cuda:0')
BCE: tensor(84.8304, device='cuda:0')
BCE: tensor(87.3504, device='cuda:0')
BCE: tensor(86.1405, device='cuda:0')
BCE: tensor(86.5299, device='cuda:0')
BCE: tensor(

BCE: tensor(85.9692, device='cuda:0')
BCE: tensor(86.1344, device='cuda:0')
BCE: tensor(82.1957, device='cuda:0')
BCE: tensor(87.8633, device='cuda:0')
BCE: tensor(87.0423, device='cuda:0')
BCE: tensor(87.6094, device='cuda:0')
BCE: tensor(84.3320, device='cuda:0')
BCE: tensor(87.3302, device='cuda:0')
BCE: tensor(84.9139, device='cuda:0')
BCE: tensor(87.3297, device='cuda:0')
BCE: tensor(84.2589, device='cuda:0')
BCE: tensor(84.9112, device='cuda:0')
BCE: tensor(85.4978, device='cuda:0')
BCE: tensor(85.1496, device='cuda:0')
BCE: tensor(85.6296, device='cuda:0')
BCE: tensor(83.9510, device='cuda:0')
BCE: tensor(82.3843, device='cuda:0')
BCE: tensor(88.7184, device='cuda:0')
BCE: tensor(84.6880, device='cuda:0')
BCE: tensor(85.8722, device='cuda:0')
BCE: tensor(84.1608, device='cuda:0')
BCE: tensor(85.2445, device='cuda:0')
BCE: tensor(86.9738, device='cuda:0')
BCE: tensor(86.2471, device='cuda:0')
BCE: tensor(86.6744, device='cuda:0')
BCE: tensor(85.8490, device='cuda:0')
BCE: tensor(

BCE: tensor(82.7805, device='cuda:0')
BCE: tensor(80.2812, device='cuda:0')
BCE: tensor(86.5346, device='cuda:0')
BCE: tensor(83.1872, device='cuda:0')
BCE: tensor(81.3184, device='cuda:0')
BCE: tensor(84.9830, device='cuda:0')
BCE: tensor(86.0563, device='cuda:0')
BCE: tensor(85.7450, device='cuda:0')
BCE: tensor(86.3802, device='cuda:0')
BCE: tensor(87.7279, device='cuda:0')
BCE: tensor(83.4111, device='cuda:0')
BCE: tensor(85.1616, device='cuda:0')
BCE: tensor(86.8462, device='cuda:0')
BCE: tensor(84.1683, device='cuda:0')
BCE: tensor(84.0771, device='cuda:0')
BCE: tensor(81.7424, device='cuda:0')
BCE: tensor(80.4285, device='cuda:0')
BCE: tensor(82.8755, device='cuda:0')
BCE: tensor(86.7318, device='cuda:0')
BCE: tensor(83.5938, device='cuda:0')
BCE: tensor(87.2671, device='cuda:0')
BCE: tensor(84.9566, device='cuda:0')
BCE: tensor(85.8121, device='cuda:0')
BCE: tensor(86.0176, device='cuda:0')
BCE: tensor(84.5444, device='cuda:0')
BCE: tensor(83.1179, device='cuda:0')
BCE: tensor(

BCE: tensor(85.5568, device='cuda:0')
BCE: tensor(85.7654, device='cuda:0')
BCE: tensor(81.9998, device='cuda:0')
BCE: tensor(84.2035, device='cuda:0')
BCE: tensor(83.9211, device='cuda:0')
BCE: tensor(82.3461, device='cuda:0')
BCE: tensor(87.2376, device='cuda:0')
BCE: tensor(91.0324, device='cuda:0')
BCE: tensor(84.6112, device='cuda:0')
BCE: tensor(85.4140, device='cuda:0')
BCE: tensor(86.0809, device='cuda:0')
BCE: tensor(88.7628, device='cuda:0')
BCE: tensor(85.5498, device='cuda:0')
BCE: tensor(83.4135, device='cuda:0')
BCE: tensor(80.9310, device='cuda:0')
BCE: tensor(86.5419, device='cuda:0')
BCE: tensor(83.5050, device='cuda:0')
BCE: tensor(83.2059, device='cuda:0')
BCE: tensor(84.0081, device='cuda:0')
BCE: tensor(80.6170, device='cuda:0')
BCE: tensor(85.9138, device='cuda:0')
BCE: tensor(86.5124, device='cuda:0')
BCE: tensor(85.9754, device='cuda:0')
BCE: tensor(85.2402, device='cuda:0')
BCE: tensor(84.6942, device='cuda:0')
BCE: tensor(83.8411, device='cuda:0')
BCE: tensor(

BCE: tensor(80.3124, device='cuda:0')
BCE: tensor(83.4034, device='cuda:0')
BCE: tensor(87.7182, device='cuda:0')
BCE: tensor(85.7345, device='cuda:0')
BCE: tensor(84.9297, device='cuda:0')
BCE: tensor(84.1560, device='cuda:0')
BCE: tensor(85.1924, device='cuda:0')
BCE: tensor(83.1886, device='cuda:0')
BCE: tensor(85.6972, device='cuda:0')
BCE: tensor(87.6968, device='cuda:0')
BCE: tensor(82.0790, device='cuda:0')
BCE: tensor(86.5980, device='cuda:0')
BCE: tensor(82.4204, device='cuda:0')
BCE: tensor(82.9233, device='cuda:0')
BCE: tensor(88.2512, device='cuda:0')
BCE: tensor(85.0839, device='cuda:0')
BCE: tensor(84.0185, device='cuda:0')
BCE: tensor(83.7148, device='cuda:0')
BCE: tensor(83.5244, device='cuda:0')
BCE: tensor(81.5748, device='cuda:0')
BCE: tensor(85.6702, device='cuda:0')
BCE: tensor(87.2715, device='cuda:0')
BCE: tensor(85.3264, device='cuda:0')
BCE: tensor(87.9557, device='cuda:0')
BCE: tensor(84.5867, device='cuda:0')
BCE: tensor(89.1636, device='cuda:0')
BCE: tensor(

BCE: tensor(83.8507, device='cuda:0')
BCE: tensor(86.6892, device='cuda:0')
BCE: tensor(81.5632, device='cuda:0')
BCE: tensor(83.9906, device='cuda:0')
BCE: tensor(84.4536, device='cuda:0')
BCE: tensor(84.9573, device='cuda:0')
BCE: tensor(83.6616, device='cuda:0')
BCE: tensor(83.0008, device='cuda:0')
BCE: tensor(83.3436, device='cuda:0')
BCE: tensor(84.7953, device='cuda:0')
BCE: tensor(83.8354, device='cuda:0')
BCE: tensor(82.5879, device='cuda:0')
BCE: tensor(83.3079, device='cuda:0')
BCE: tensor(82.7351, device='cuda:0')
BCE: tensor(83.2392, device='cuda:0')
BCE: tensor(81.0091, device='cuda:0')
BCE: tensor(81.1664, device='cuda:0')
BCE: tensor(87.6882, device='cuda:0')
BCE: tensor(80.9333, device='cuda:0')
BCE: tensor(78.8519, device='cuda:0')
BCE: tensor(84.6264, device='cuda:0')
BCE: tensor(81.5496, device='cuda:0')
BCE: tensor(82.8764, device='cuda:0')
BCE: tensor(85.7341, device='cuda:0')
BCE: tensor(87.5990, device='cuda:0')
BCE: tensor(80.3056, device='cuda:0')
BCE: tensor(

BCE: tensor(85.4379, device='cuda:0')
BCE: tensor(80.7839, device='cuda:0')
BCE: tensor(82.0071, device='cuda:0')
BCE: tensor(82.0511, device='cuda:0')
BCE: tensor(82.9244, device='cuda:0')
BCE: tensor(80.6678, device='cuda:0')
BCE: tensor(83.5678, device='cuda:0')
BCE: tensor(64.3522, device='cuda:0')
====> Epoch: 6 Average loss: 108.7701
BCE: tensor(72.4289, device='cuda:0')
BCE: tensor(75.7602, device='cuda:0')
BCE: tensor(72.3347, device='cuda:0')
BCE: tensor(74.0697, device='cuda:0')
BCE: tensor(75.4249, device='cuda:0')
BCE: tensor(73.1436, device='cuda:0')
BCE: tensor(75.5420, device='cuda:0')
BCE: tensor(70.7291, device='cuda:0')
BCE: tensor(73.7928, device='cuda:0')
BCE: tensor(77.1585, device='cuda:0')
BCE: tensor(75.5759, device='cuda:0')
BCE: tensor(69.9013, device='cuda:0')
BCE: tensor(73.4019, device='cuda:0')
BCE: tensor(76.4283, device='cuda:0')
BCE: tensor(77.9653, device='cuda:0')
BCE: tensor(71.2647, device='cuda:0')
BCE: tensor(74.0709, device='cuda:0')
BCE: tensor(

BCE: tensor(81.1717, device='cuda:0')
BCE: tensor(82.6474, device='cuda:0')
BCE: tensor(82.8019, device='cuda:0')
BCE: tensor(81.1568, device='cuda:0')
BCE: tensor(82.1054, device='cuda:0')
BCE: tensor(84.3382, device='cuda:0')
BCE: tensor(77.1710, device='cuda:0')
BCE: tensor(83.8647, device='cuda:0')
BCE: tensor(83.5129, device='cuda:0')
BCE: tensor(85.6234, device='cuda:0')
BCE: tensor(83.7464, device='cuda:0')
BCE: tensor(83.7169, device='cuda:0')
BCE: tensor(86.4118, device='cuda:0')
BCE: tensor(80.2049, device='cuda:0')
BCE: tensor(82.1486, device='cuda:0')
BCE: tensor(84.5296, device='cuda:0')
BCE: tensor(85.4218, device='cuda:0')
BCE: tensor(81.1278, device='cuda:0')
BCE: tensor(80.1204, device='cuda:0')
BCE: tensor(80.3089, device='cuda:0')
BCE: tensor(79.8568, device='cuda:0')
BCE: tensor(86.4719, device='cuda:0')
BCE: tensor(81.9234, device='cuda:0')
BCE: tensor(79.5279, device='cuda:0')
BCE: tensor(84.5084, device='cuda:0')
BCE: tensor(83.8848, device='cuda:0')
BCE: tensor(

BCE: tensor(83.7987, device='cuda:0')
BCE: tensor(83.3132, device='cuda:0')
BCE: tensor(85.3290, device='cuda:0')
BCE: tensor(88.9150, device='cuda:0')
BCE: tensor(81.7775, device='cuda:0')
BCE: tensor(82.8644, device='cuda:0')
BCE: tensor(84.5246, device='cuda:0')
BCE: tensor(85.4568, device='cuda:0')
BCE: tensor(83.1663, device='cuda:0')
BCE: tensor(81.5841, device='cuda:0')
BCE: tensor(82.6491, device='cuda:0')
BCE: tensor(81.8122, device='cuda:0')
BCE: tensor(80.1064, device='cuda:0')
BCE: tensor(83.5336, device='cuda:0')
BCE: tensor(80.7732, device='cuda:0')
BCE: tensor(81.6695, device='cuda:0')
BCE: tensor(85.8543, device='cuda:0')
BCE: tensor(85.0763, device='cuda:0')
BCE: tensor(83.4336, device='cuda:0')
BCE: tensor(81.8882, device='cuda:0')
BCE: tensor(80.1013, device='cuda:0')
BCE: tensor(80.7142, device='cuda:0')
BCE: tensor(87.4104, device='cuda:0')
BCE: tensor(81.8456, device='cuda:0')
BCE: tensor(82.2948, device='cuda:0')
BCE: tensor(84.3355, device='cuda:0')
BCE: tensor(

BCE: tensor(74.3822, device='cuda:0')
BCE: tensor(74.1214, device='cuda:0')
BCE: tensor(72.5897, device='cuda:0')
BCE: tensor(70.8323, device='cuda:0')
BCE: tensor(70.5768, device='cuda:0')
BCE: tensor(73.4460, device='cuda:0')
BCE: tensor(71.0644, device='cuda:0')
BCE: tensor(71.7628, device='cuda:0')
BCE: tensor(71.6445, device='cuda:0')
BCE: tensor(73.3939, device='cuda:0')
BCE: tensor(70.6547, device='cuda:0')
BCE: tensor(74.6079, device='cuda:0')
BCE: tensor(72.0192, device='cuda:0')
BCE: tensor(73.4348, device='cuda:0')
BCE: tensor(72.0320, device='cuda:0')
BCE: tensor(72.6030, device='cuda:0')
BCE: tensor(75.3244, device='cuda:0')
BCE: tensor(74.9030, device='cuda:0')
BCE: tensor(75.8490, device='cuda:0')
BCE: tensor(72.0842, device='cuda:0')
BCE: tensor(79.7358, device='cuda:0')
BCE: tensor(8.1136, device='cuda:0')
====> Test set loss: 98.0701
BCE: tensor(87.8866, device='cuda:0')
BCE: tensor(80.4301, device='cuda:0')
BCE: tensor(82.0632, device='cuda:0')
BCE: tensor(83.2130, d

BCE: tensor(78.0086, device='cuda:0')
BCE: tensor(78.4728, device='cuda:0')
BCE: tensor(80.1329, device='cuda:0')
BCE: tensor(79.4376, device='cuda:0')
BCE: tensor(79.4157, device='cuda:0')
BCE: tensor(85.6193, device='cuda:0')
BCE: tensor(85.5086, device='cuda:0')
BCE: tensor(80.8317, device='cuda:0')
BCE: tensor(86.0385, device='cuda:0')
BCE: tensor(80.3475, device='cuda:0')
BCE: tensor(78.4212, device='cuda:0')
BCE: tensor(80.5171, device='cuda:0')
BCE: tensor(79.2996, device='cuda:0')
BCE: tensor(81.2368, device='cuda:0')
BCE: tensor(83.9739, device='cuda:0')
BCE: tensor(80.3020, device='cuda:0')
BCE: tensor(81.9234, device='cuda:0')
BCE: tensor(81.0051, device='cuda:0')
BCE: tensor(81.7866, device='cuda:0')
BCE: tensor(82.5627, device='cuda:0')
BCE: tensor(84.3725, device='cuda:0')
BCE: tensor(79.0607, device='cuda:0')
BCE: tensor(79.2315, device='cuda:0')
BCE: tensor(83.3455, device='cuda:0')
BCE: tensor(84.6841, device='cuda:0')
BCE: tensor(86.4217, device='cuda:0')
BCE: tensor(

BCE: tensor(82.9661, device='cuda:0')
BCE: tensor(83.3357, device='cuda:0')
BCE: tensor(82.5845, device='cuda:0')
BCE: tensor(83.1599, device='cuda:0')
BCE: tensor(81.7310, device='cuda:0')
BCE: tensor(79.9393, device='cuda:0')
BCE: tensor(80.9999, device='cuda:0')
BCE: tensor(79.0342, device='cuda:0')
BCE: tensor(82.7756, device='cuda:0')
BCE: tensor(79.1635, device='cuda:0')
BCE: tensor(85.8236, device='cuda:0')
BCE: tensor(75.2690, device='cuda:0')
BCE: tensor(82.3676, device='cuda:0')
BCE: tensor(80.1829, device='cuda:0')
BCE: tensor(79.5352, device='cuda:0')
BCE: tensor(81.1270, device='cuda:0')
BCE: tensor(84.0444, device='cuda:0')
BCE: tensor(81.5516, device='cuda:0')
BCE: tensor(79.9121, device='cuda:0')
BCE: tensor(83.2844, device='cuda:0')
BCE: tensor(83.1868, device='cuda:0')
BCE: tensor(81.0173, device='cuda:0')
BCE: tensor(82.8151, device='cuda:0')
BCE: tensor(83.6585, device='cuda:0')
BCE: tensor(81.5868, device='cuda:0')
BCE: tensor(80.6114, device='cuda:0')
BCE: tensor(

BCE: tensor(79.8919, device='cuda:0')
BCE: tensor(82.9067, device='cuda:0')
BCE: tensor(83.6967, device='cuda:0')
BCE: tensor(80.5283, device='cuda:0')
BCE: tensor(80.2363, device='cuda:0')
BCE: tensor(81.7283, device='cuda:0')
BCE: tensor(80.8133, device='cuda:0')
BCE: tensor(80.8348, device='cuda:0')
BCE: tensor(81.5531, device='cuda:0')
BCE: tensor(77.3017, device='cuda:0')
BCE: tensor(78.0708, device='cuda:0')
BCE: tensor(82.2699, device='cuda:0')
BCE: tensor(79.4230, device='cuda:0')
BCE: tensor(79.3423, device='cuda:0')
BCE: tensor(82.0235, device='cuda:0')
BCE: tensor(81.3668, device='cuda:0')
BCE: tensor(79.0194, device='cuda:0')
BCE: tensor(80.2403, device='cuda:0')
BCE: tensor(79.4234, device='cuda:0')
BCE: tensor(82.4987, device='cuda:0')
BCE: tensor(85.7090, device='cuda:0')
BCE: tensor(82.8009, device='cuda:0')
BCE: tensor(81.5287, device='cuda:0')
BCE: tensor(83.0555, device='cuda:0')
BCE: tensor(77.9108, device='cuda:0')
BCE: tensor(81.8843, device='cuda:0')
BCE: tensor(

BCE: tensor(80.3407, device='cuda:0')
BCE: tensor(84.6497, device='cuda:0')
BCE: tensor(80.0138, device='cuda:0')
BCE: tensor(84.2781, device='cuda:0')
BCE: tensor(83.0990, device='cuda:0')
BCE: tensor(82.8100, device='cuda:0')
BCE: tensor(80.7626, device='cuda:0')
BCE: tensor(76.6055, device='cuda:0')
BCE: tensor(82.0713, device='cuda:0')
BCE: tensor(82.6287, device='cuda:0')
BCE: tensor(83.0821, device='cuda:0')
BCE: tensor(86.6925, device='cuda:0')
BCE: tensor(79.9447, device='cuda:0')
BCE: tensor(83.2963, device='cuda:0')
BCE: tensor(81.3017, device='cuda:0')
BCE: tensor(77.9565, device='cuda:0')
BCE: tensor(80.1601, device='cuda:0')
BCE: tensor(79.8599, device='cuda:0')
BCE: tensor(80.6151, device='cuda:0')
BCE: tensor(81.2475, device='cuda:0')
BCE: tensor(81.7957, device='cuda:0')
BCE: tensor(79.3487, device='cuda:0')
BCE: tensor(80.6032, device='cuda:0')
BCE: tensor(83.5897, device='cuda:0')
BCE: tensor(84.5185, device='cuda:0')
BCE: tensor(79.5417, device='cuda:0')
BCE: tensor(

BCE: tensor(82.1261, device='cuda:0')
BCE: tensor(83.4655, device='cuda:0')
BCE: tensor(82.1524, device='cuda:0')
BCE: tensor(79.8550, device='cuda:0')
BCE: tensor(77.4068, device='cuda:0')
BCE: tensor(84.6033, device='cuda:0')
BCE: tensor(82.9290, device='cuda:0')
BCE: tensor(81.7094, device='cuda:0')
BCE: tensor(83.6018, device='cuda:0')
BCE: tensor(81.5666, device='cuda:0')
BCE: tensor(79.9140, device='cuda:0')
BCE: tensor(80.5347, device='cuda:0')
BCE: tensor(82.3408, device='cuda:0')
BCE: tensor(81.8898, device='cuda:0')
BCE: tensor(80.9087, device='cuda:0')
BCE: tensor(81.3560, device='cuda:0')
BCE: tensor(82.4159, device='cuda:0')
BCE: tensor(81.8942, device='cuda:0')
BCE: tensor(84.4698, device='cuda:0')
BCE: tensor(76.9810, device='cuda:0')
BCE: tensor(82.7029, device='cuda:0')
BCE: tensor(82.7050, device='cuda:0')
BCE: tensor(82.0561, device='cuda:0')
BCE: tensor(79.0740, device='cuda:0')
BCE: tensor(82.7266, device='cuda:0')
BCE: tensor(81.8133, device='cuda:0')
BCE: tensor(

BCE: tensor(79.4468, device='cuda:0')
BCE: tensor(75.5564, device='cuda:0')
BCE: tensor(81.0668, device='cuda:0')
BCE: tensor(83.3612, device='cuda:0')
BCE: tensor(81.7592, device='cuda:0')
BCE: tensor(86.6559, device='cuda:0')
BCE: tensor(80.4227, device='cuda:0')
BCE: tensor(82.9741, device='cuda:0')
BCE: tensor(80.8957, device='cuda:0')
BCE: tensor(81.2642, device='cuda:0')
BCE: tensor(81.9029, device='cuda:0')
BCE: tensor(81.0512, device='cuda:0')
BCE: tensor(82.6356, device='cuda:0')
BCE: tensor(80.5613, device='cuda:0')
BCE: tensor(78.3603, device='cuda:0')
BCE: tensor(79.7966, device='cuda:0')
BCE: tensor(80.4089, device='cuda:0')
BCE: tensor(80.1804, device='cuda:0')
BCE: tensor(80.0642, device='cuda:0')
BCE: tensor(86.0708, device='cuda:0')
BCE: tensor(82.1227, device='cuda:0')
BCE: tensor(80.0772, device='cuda:0')
BCE: tensor(80.6121, device='cuda:0')
BCE: tensor(79.7581, device='cuda:0')
BCE: tensor(86.7340, device='cuda:0')
BCE: tensor(82.4306, device='cuda:0')
BCE: tensor(

BCE: tensor(77.7049, device='cuda:0')
BCE: tensor(80.5116, device='cuda:0')
BCE: tensor(81.2141, device='cuda:0')
BCE: tensor(82.5149, device='cuda:0')
BCE: tensor(83.1400, device='cuda:0')
BCE: tensor(80.3127, device='cuda:0')
BCE: tensor(80.8088, device='cuda:0')
BCE: tensor(78.5407, device='cuda:0')
BCE: tensor(86.4324, device='cuda:0')
BCE: tensor(79.6420, device='cuda:0')
BCE: tensor(82.6584, device='cuda:0')
BCE: tensor(81.6568, device='cuda:0')
BCE: tensor(79.4548, device='cuda:0')
BCE: tensor(79.5058, device='cuda:0')
BCE: tensor(80.3042, device='cuda:0')
BCE: tensor(77.1753, device='cuda:0')
BCE: tensor(84.4442, device='cuda:0')
BCE: tensor(81.9859, device='cuda:0')
BCE: tensor(81.7222, device='cuda:0')
BCE: tensor(84.3543, device='cuda:0')
BCE: tensor(80.4401, device='cuda:0')
BCE: tensor(83.5729, device='cuda:0')
BCE: tensor(82.6252, device='cuda:0')
BCE: tensor(79.5094, device='cuda:0')
BCE: tensor(79.5771, device='cuda:0')
BCE: tensor(81.0187, device='cuda:0')
BCE: tensor(

BCE: tensor(76.8485, device='cuda:0')
BCE: tensor(80.0884, device='cuda:0')
BCE: tensor(82.2178, device='cuda:0')
BCE: tensor(79.9798, device='cuda:0')
BCE: tensor(78.4674, device='cuda:0')
BCE: tensor(80.8303, device='cuda:0')
BCE: tensor(85.3770, device='cuda:0')
BCE: tensor(61.9431, device='cuda:0')
====> Epoch: 10 Average loss: 106.3760
BCE: tensor(71.9850, device='cuda:0')
BCE: tensor(71.5086, device='cuda:0')
BCE: tensor(71.9345, device='cuda:0')
BCE: tensor(74.8455, device='cuda:0')
BCE: tensor(72.2213, device='cuda:0')
BCE: tensor(73.1502, device='cuda:0')
BCE: tensor(73.8238, device='cuda:0')
BCE: tensor(69.2600, device='cuda:0')
BCE: tensor(71.3151, device='cuda:0')
BCE: tensor(71.1923, device='cuda:0')
BCE: tensor(73.4261, device='cuda:0')
BCE: tensor(73.1665, device='cuda:0')
BCE: tensor(69.8356, device='cuda:0')
BCE: tensor(70.2552, device='cuda:0')
BCE: tensor(70.4265, device='cuda:0')
BCE: tensor(72.3901, device='cuda:0')
BCE: tensor(71.5564, device='cuda:0')
BCE: tensor

In [12]:
x = torch.tensor([[1,2],[3,4]])
t = torch.tensor([[1,1],[1,1]])
x = (x>t).float()
x

tensor([[ 0.,  1.],
        [ 1.,  1.]])

In [None]:
train_loader

In [5]:
for batch_idx, (data, _) in enumerate(train_loader):
    print(type(data))
    print(data.shape)
    #d = data.view(-1, 784)
    #print (d.shape)

<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torc

<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torc

<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torch.Tensor'>
torch.Size([128, 1, 28, 28])
<class 'torc

In [5]:
x = torch.randn(2,4)
x

tensor([[-0.3111, -0.9662, -0.7928,  0.6137],
        [-0.1155,  0.2167,  0.5646, -3.1376]])

In [7]:
x.view(8)

tensor([-0.3111, -0.9662, -0.7928,  0.6137, -0.1155,  0.2167,  0.5646,
        -3.1376])