In [22]:
import torch
import pytorch_lightning as pl
from collections import OrderedDict
from torch.utils.data import DataLoader, random_split
import torchvision as tv
from pytorch_lightning.callbacks import EarlyStopping

In [16]:
class FashionMNIST(pl.LightningModule):
    """Small fully connected classifier for Fashion-MNIST as a LightningModule.

    The network flattens each image to 784 features, applies batch
    normalisation, one hidden layer of 256 ReLU units and a 10-way linear
    output that produces raw logits.

    Args:
        batch_size: Batch size used by the train, validation and test loaders.
        lr: Learning rate passed to the SGD optimiser.
        data_dir: Root directory where torchvision stores / looks for the
            dataset. Defaults to the previously hard-coded ``'../..'``.
        val_size: Number of training examples held out for validation.
            Defaults to 5000, which yields the original 55000 / 5000 split.
    """

    def __init__(self, batch_size, lr, data_dir='../..', val_size=5000):
        super().__init__()
        self.batch_size = batch_size
        self.lr = lr
        self.data_dir = data_dir
        self.val_size = val_size

        # Layer names are kept stable because they form the state_dict keys.
        self.model = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('batchnorm', torch.nn.BatchNorm1d(784)),
            ('lin1', torch.nn.Linear(784, 256)),
            ('relu', torch.nn.ReLU()),
            ('lin2', torch.nn.Linear(256, 10))
        ]))

    def forward(self, x):
        """Return the class logits produced by the network for input ``x``."""
        return self.model(x)

    def configure_optimizers(self):
        """Return a plain SGD optimiser over all parameters using ``self.lr``."""
        return torch.optim.SGD(self.parameters(), lr=self.lr)

    def cross_entropy_loss(self, logits, labels):
        """Return the mean cross-entropy between ``logits`` and ``labels``."""
        # Functional form: same result as nn.CrossEntropyLoss() with default
        # arguments, without building a new module object on every step.
        return torch.nn.functional.cross_entropy(logits, labels)

    def _batch_loss(self, batch):
        """Unpack an ``(inputs, targets)`` batch and return its loss."""
        X, y = batch
        # Call the module itself rather than .forward() so that any hooks
        # registered on the module are executed.
        return self.cross_entropy_loss(self(X), y)

    @staticmethod
    def _mean_of(outputs, key):
        """Average the scalar tensors stored under ``key`` in each step output."""
        return torch.stack([out[key] for out in outputs]).mean()

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and expose it for logging."""
        loss = self._batch_loss(batch)
        return {'loss': loss, 'log': {'loss': loss}}

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss for one batch."""
        return {'val_loss': self._batch_loss(batch)}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses and log them as ``val_loss``."""
        avg_loss = self._mean_of(outputs, 'val_loss')
        return {'avg_val_loss': avg_loss, 'log': {'val_loss': avg_loss}}

    def test_step(self, batch, batch_nb):
        """Compute the test loss for one batch."""
        return {'test_loss': self._batch_loss(batch)}

    def test_epoch_end(self, outputs):
        """Average the per-batch test losses; log and show them as ``test_loss``."""
        avg_loss = self._mean_of(outputs, 'test_loss')
        logs = {'test_loss': avg_loss}
        return {'avg_test_loss': avg_loss, 'log': logs, 'progress_bar': logs}

    def prepare_data(self):
        """Download Fashion-MNIST if needed and build the train/val/test sets.

        Raises:
            ValueError: If ``val_size`` is not strictly between 0 and the size
                of the official training split.
        """
        to_tensor = tv.transforms.ToTensor()
        full_train = tv.datasets.FashionMNIST(
            self.data_dir, train=True, download=True, transform=to_tensor)

        n_total = len(full_train)
        if not 0 < self.val_size < n_total:
            raise ValueError(
                'val_size must be between 1 and {} (got {})'.format(
                    n_total - 1, self.val_size))

        # The split is random; seed torch's global RNG beforehand for a
        # repeatable partition.
        self.train_set, self.val_set = random_split(
            full_train, [n_total - self.val_size, self.val_size])
        self.test_set = tv.datasets.FashionMNIST(
            self.data_dir, train=False, download=True, transform=to_tensor)

    def train_dataloader(self):
        """Return the shuffled loader over the training subset."""
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return the loader over the held-out validation subset."""
        return DataLoader(self.val_set, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return the loader over the official test split."""
        return DataLoader(self.test_set, batch_size=self.batch_size)

In [41]:
# Seed torch's global RNG before anything stochastic happens so that weight
# initialisation, the random train/validation split and batch shuffling are
# repeatable after a kernel restart (assuming cells are run top to bottom).
SEED = 0
torch.manual_seed(SEED)

# Hyper-parameters a reader is likely to tune.
BATCH_SIZE = 256
LEARNING_RATE = 0.05

clf = FashionMNIST(batch_size=BATCH_SIZE, lr=LEARNING_RATE)

In [42]:
# Early-stopping settings, gathered in one place so they are easy to tweak:
# watch the logged 'val_loss' (lower is better) and allow 5 checks without an
# improvement of at least 0.05 before stopping.
early_stop_options = dict(
    monitor='val_loss',
    mode='min',
    min_delta=0.05,
    patience=5,
    verbose=False,
)
early_stop_callback = EarlyStopping(**early_stop_options)

In [43]:
# Build the trainer with default settings, wiring in the early-stopping callback.
# NOTE(review): `early_stop_callback=` is the spelling this notebook was run with;
# newer pytorch_lightning releases may expect `callbacks=[...]` instead — confirm
# against the installed version before upgrading.
trainer = pl.Trainer(
    early_stop_callback=early_stop_callback,
)

In [44]:
# Train the classifier; the module supplies its own data loaders.
trainer.fit(model=clf)

HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…



HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…




1

In [45]:
# Evaluate on the test loader; the averaged 'test_loss' is printed below.
trainer.test()

HBox(children=(FloatProgress(value=0.0, description='Testing', layout=Layout(flex='2'), max=40.0, style=Progre…

----------------------------------------------------------------------------------------------------
TEST RESULTS
{'test_loss': tensor(0.3562)}
----------------------------------------------------------------------------------------------------

