# PyTorch Lightning Examples

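These examples are taken from [Trivial Multi-Node Training With Pytorch-Lightning](https://towardsdatascience.com/trivial-multi-node-training-with-pytorch-lightning-ff75dfb809bd). Note that they use an early PyTorch Lightning API (`tng_dataloader`, `validation_end`, the `@data_loader` decorator, `Trainer(experiment=...)`), so they will not run unmodified on current Lightning releases.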


In [2]:
import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

import pytorch_lightning as ptl
from pytorch_lightning import Trainer
from test_tube import Experiment




In [3]:
class Model(ptl.LightningModule):
    """Single-linear-layer MNIST classifier for the early PyTorch Lightning API.

    The hook names used here (``tng_dataloader``, ``validation_end`` and the
    ``@ptl.data_loader`` decorator) belong to the pre-1.0 Lightning interface.
    """

    def __init__(self):
        super().__init__()

        # One fully connected layer: flattened 28x28 image -> 10 class scores.
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        """Flatten each sample of ``x`` and return ReLU(linear(x)).

        ``x.view(x.size(0), -1)`` keeps the first dimension (presumably the
        batch) and flattens the rest; the result must have 28*28 features.

        NOTE(review): the ReLU clamps negative logits to zero before they reach
        cross-entropy. This is kept to match the upstream example, but it is
        unusual for a classifier head and worth revisiting.
        """
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def loss_func(self, y_hat, y):
        """Return the cross-entropy between scores ``y_hat`` and class targets ``y``."""
        return F.cross_entropy(y_hat, y)

    def training_step(self, batch, batch_nb):
        """Compute the training loss for one ``(inputs, targets)`` batch."""
        x, y = batch
        return {'loss': self.loss_func(self(x), y)}

    def validation_step(self, batch, batch_nb):
        """Compute the validation loss for one ``(inputs, targets)`` batch."""
        x, y = batch
        return {'val_loss': self.loss_func(self(x), y)}

    def validation_end(self, outputs):
        """Average the per-batch ``val_loss`` values collected by ``validation_step``."""
        avg_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
        return {'avg_val_loss': avg_loss}

    def configure_optimizers(self):
        """Return a single Adam optimizer over all parameters (lr=0.02)."""
        return [torch.optim.Adam(self.parameters(), lr=0.02)]

    @staticmethod
    def _mnist_loader(train):
        """Build a batch-size-32 DataLoader over an MNIST split.

        ``train=True`` selects the training split and ``train=False`` the test
        split. Data is stored under (and downloaded into, if missing) the
        current working directory.
        """
        dataset = MNIST(os.getcwd(), train=train, download=True, transform=transforms.ToTensor())
        return DataLoader(dataset, batch_size=32)

    @ptl.data_loader
    def tng_dataloader(self):
        """Training batches from the MNIST training split."""
        return self._mnist_loader(train=True)

    @ptl.data_loader
    def val_dataloader(self):
        """Validation batches from the held-out MNIST test split."""
        # Previously this read the training split, so "validation" loss only
        # measured fit to data the model was trained on.
        return self._mnist_loader(train=False)

    @ptl.data_loader
    def test_dataloader(self):
        """Test batches from the held-out MNIST test split."""
        return self._mnist_loader(train=False)


In [4]:
# Instantiate the model and a test_tube Experiment whose save_dir is the
# current working directory, then fit with an early-API Lightning Trainer.
model = Model()
exp = Experiment(save_dir=os.getcwd())

# NOTE(review): `experiment=`, `max_nb_epochs` and `train_percent_check` are
# pre-1.0 Lightning Trainer arguments; presumably this runs one epoch over
# roughly 10% of the training batches — confirm against the pinned version.
trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1)
trainer.fit(model)

AttributeError: module 'tensorflow.io' has no attribute 'gfile'