In [1]:
import torch
from torch.nn import functional as F
from torch import nn
from torch.optim import Adam
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from keras.datasets import mnist
import os
from torchvision import datasets, transforms

Using TensorFlow backend.


In [2]:
class LitAutoEncoder(pl.LightningModule):
    """Fully connected autoencoder trained with a mean-squared reconstruction loss.

    The encoder maps a flattened input to a low-dimensional embedding through
    one hidden ReLU layer; the decoder mirrors it back to the input size.
    With the default arguments the layer sizes are 784 -> 64 -> 3 -> 64 -> 784.

    Args:
        input_dim: Number of values per flattened sample (default ``28*28``).
        hidden_dim: Width of the hidden layer in encoder and decoder (default 64).
        latent_dim: Size of the embedding produced by the encoder (default 3).
    """

    def __init__(self, input_dim=28*28, hidden_dim=64, latent_dim=3):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )

    def forward(self, x):
        """Return the encoder embedding of ``x``.

        ``x`` is passed to the encoder unchanged, so it must already be
        flattened to ``input_dim`` values along its last dimension.
        """
        return self.encoder(x)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate 1e-3."""
        return Adam(self.parameters(), lr=1e-3)

    def _reconstruction_loss(self, batch):
        """Flatten ``batch``, run it through encoder and decoder, and return the MSE.

        Shared by the training and validation steps so both measure exactly
        the same quantity.
        """
        # Collapse every dimension after the batch dimension into one.
        flat = batch.view(batch.size(0), -1)
        reconstruction = self.decoder(self.encoder(flat))
        return F.mse_loss(reconstruction, flat)

    def training_step(self, batch, batch_idx):
        """Compute the reconstruction loss for ``batch`` and log it as ``train_loss``."""
        loss = self._reconstruction_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute the reconstruction loss for ``val_batch`` and log it as ``val_loss``.

        The value is also shown in the progress bar.
        """
        loss = self._reconstruction_loss(val_batch)
        self.log('val_loss', loss, prog_bar=True)
        return loss

In [3]:
# Load only the MNIST training images; the labels and the test split are discarded.
(train_images_raw, _), _ = mnist.load_data()

# Copy into a float32 tensor and divide by 255 to rescale the pixel values.
train_images = torch.tensor(train_images_raw, dtype=torch.float32) / 255.

# Seed the split so the same 55,000 / 5,000 partition is drawn on every run.
SPLIT_SEED = 42
train_subset, val_subset = random_split(
    train_images,
    [55000, 5000],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# `mnist_train` / `mnist_val` are the loaders passed to `trainer.fit` in the next cell.
# The training loader reshuffles every epoch; the validation order stays fixed.
mnist_train = DataLoader(train_subset, batch_size=64, shuffle=True)
mnist_val = DataLoader(val_subset, batch_size=64)

In [4]:
# Seed the Python, NumPy and PyTorch RNGs so weight initialisation is repeatable.
pl.seed_everything(42)

model = LitAutoEncoder()

# Use one GPU when CUDA is available; otherwise run on CPU instead of raising.
# NOTE(review): `gpus` matches the Lightning release this notebook was run with;
# newer releases replace it with `accelerator` / `devices` - confirm on upgrade.
trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0, max_epochs=1)
trainer.fit(model, mnist_train, mnist_val)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 50.4 K
1 | decoder | Sequential | 51.2 K
---------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  f"The dataloader, {name}, does not have many workers which may be a bottleneck."
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]