In [26]:
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(87)
torch.cuda.manual_seed(87)

Setup

In [27]:
import numpy as np
from sklearn.datasets import make_classification

np.random.seed(87)

X, y = make_classification(
    n_samples=1000, n_features=20, n_informative=10, random_state=87)
X, y = X.astype(np.float32), y.astype(np.int64)

print(X.shape, y.shape, y.mean())

(1000, 20) (1000,) 0.496


Define pytorch classification module

In [28]:
from skorch import NeuralNetClassifier

class ClassifierModule(nn.Module):
    def __init__(
        self,
        num_units=10,
        nonlin=F.relu,
        dropout=0.5,
    ):
        super(ClassifierModule, self).__init__()
        
        self.num_units = num_units
        self.nonlin = nonlin
        self.dropout = dropout

        self.dense0 = nn.Linear(20, num_units)
        self.dense1 = nn.Linear(num_units, 10)
        self.output = nn.Linear(10, 2)
        self.dropout = nn.Dropout(dropout)
        self.nonlin = nonlin
        
    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        
        X = self.nonlin(self.dense1(X))
        X = F.softmax(self.output(X), dim=-1)
        return X

Writing CUSTOM callback
Rules:
    1. inherit skorch.callbacks.Callback
    2. implement atleast one on_ method
    3. argments first neuralnet instance, second optionally local data, **kwargs
    (4. attributes that should be reset in the initialize method

mock twitter API that tweets epoch validation accuracy

In [29]:
from skorch.callbacks import Callback

def tweet(msg):
    print('~'*60)
    print("*tweet*", msg, "#skorch #pytorch")
    
class AccuracyTweet(Callback):
    def __init__(self, min_accuracy):
        self.min_accuracy = min_accuracy
        
    def initialize(self):
        self.critical_epoch_ = -1
    
    # runst after each epoch
    def on_epoch_end(self, net, **kwargs):
        if self.critical_epoch_ > -1:
            return
        
        if net.history[-1, 'valid_acc'] >= self.min_accuracy:
            self.critical_epoch_ = len(net.history)
            
    # runs after each training
    def on_train_end(self, net, **kwargs):
        if self.critical_epoch_ < 0:
            msg = f"Accuracy never reached {self.min_accuracy}"
        else:
            msg = f"Accuracy reached {self.min_accuracy} at epoch {self.critical_epoch_}"
            
        tweet(msg)

Train the model

In [30]:
net = NeuralNetClassifier(
    ClassifierModule,
    max_epochs=15,
    lr=0.02,
    warm_start=True,
    callbacks=[AccuracyTweet(min_accuracy=0.7)],
)

net.fit(X, y)
net.fit(X, y)

  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.7250[0m       [32m0.5050[0m        [35m0.7168[0m  0.0168
      2        [36m0.7169[0m       0.5000        [35m0.7069[0m  0.0112
      3        [36m0.7091[0m       [32m0.5150[0m        [35m0.6997[0m  0.0125
      4        [36m0.6983[0m       [32m0.5250[0m        [35m0.6938[0m  0.0080
      5        [36m0.6949[0m       [32m0.5650[0m        [35m0.6886[0m  0.0056
      6        [36m0.6905[0m       [32m0.5850[0m        [35m0.6849[0m  0.0052
      7        [36m0.6806[0m       [32m0.6050[0m        [35m0.6812[0m  0.0056
      8        0.6845       [32m0.6500[0m        [35m0.6776[0m  0.0052
      9        0.6811       [32m0.6550[0m        [35m0.6743[0m  0.0060
     10        [36m0.6742[0m       [32m0.6700[0m        [35m0.6712[0m  0.0059
     11        [36m0.6707[0m       [32m0.6900[0m        [35m0.6675[

<class 'skorch.classifier.NeuralNetClassifier'>[initialized](
  module_=ClassifierModule(
    (dense0): Linear(in_features=20, out_features=10, bias=True)
    (dense1): Linear(in_features=10, out_features=10, bias=True)
    (output): Linear(in_features=10, out_features=2, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  ),
)

Multiple return values from forward

In [31]:
from skorch import NeuralNetRegressor

class Encoder(nn.Module):
    def __init__(self, num_units=5):
        super().__init__()
        self.num_units = num_units
        
        self.encode = nn.Sequential(
            nn.Linear(20, 10),
            nn.ReLU(),
            nn.Linear(10, num_units),
            nn.ReLU(),
        )
        
    def forward(self, X, **kwargs):
        return self.encode(X)
    
class Decoder(nn.Module):
    def __init__(self, num_units=5):
        super().__init__()
        self.num_units = num_units
        
        self.decode = nn.Sequential(
            nn.Linear(self.num_units, 10),
            nn.ReLU(),
            nn.Linear(10, 20),
        )
        
    def forward(self, X, **kwargs):
        return self.decode(X)
    
class AutoEncoder(nn.Module):
    def __init__(self, num_units=5):
        super().__init__()
        self.num_units = num_units
        
        self.encoder = Encoder(self.num_units)
        self.decoder = Decoder(self.num_units)
        
    def forward(self, X, **kwargs):
        encoded = self.encoder(X)
        decoded = self.decoder(encoded)
        return decoded, encoded

Override the get_loss method

In [32]:
class AutoEncoderNet(NeuralNetRegressor):
    def get_loss(self, y_pred, y_true, *args, **kwargs):
        decoded, encoded = y_pred
        
        loss_reconstruction = super().get_loss(
            decoded, y_true, *args, **kwargs)
        
        loss_l1 = torch.abs(encoded).sum()
        
        return loss_reconstruction + loss_l1

Training the autoencoder

In [41]:
net = AutoEncoderNet(
    AutoEncoder,
    module__num_units=5,
    lr=0.3,
    max_epochs=20
)

net.fit(X, X)

  epoch    train_loss    valid_loss     dur
-------  ------------  ------------  ------
      1       [36m29.4691[0m        [32m3.8066[0m  0.0164
      2        [36m3.8723[0m        [32m3.7768[0m  0.0141
      3        [36m3.8420[0m        [32m3.7617[0m  0.0105
      4        [36m3.8259[0m        [32m3.7545[0m  0.0092
      5        [36m3.8178[0m        [32m3.7510[0m  0.0095
      6        [36m3.8139[0m        [32m3.7494[0m  0.0073
      7        [36m3.8120[0m        [32m3.7485[0m  0.0099
      8        [36m3.8110[0m        [32m3.7480[0m  0.0099
      9        [36m3.8105[0m        [32m3.7476[0m  0.0093
     10        [36m3.8103[0m        [32m3.7474[0m  0.0080
     11        [36m3.8101[0m        [32m3.7472[0m  0.0110
     12        [36m3.8100[0m        [32m3.7471[0m  0.0083
     13        [36m3.8100[0m        [32m3.7470[0m  0.0072
     14        [36m3.8099[0m        [32m3.7470[0m  0.0126
     15        [36m3.8099[0m        [32m3

<class '__main__.AutoEncoderNet'>[initialized](
  module_=AutoEncoder(
    (encoder): Encoder(
      (encode): Sequential(
        (0): Linear(in_features=20, out_features=10, bias=True)
        (1): ReLU()
        (2): Linear(in_features=10, out_features=5, bias=True)
        (3): ReLU()
      )
    )
    (decoder): Decoder(
      (decode): Sequential(
        (0): Linear(in_features=5, out_features=10, bias=True)
        (1): ReLU()
        (2): Linear(in_features=10, out_features=20, bias=True)
      )
    )
  ),
)

Extracting decoder and encoder output

In [42]:
y_pred = net.predict(X)
print(y_pred.shape) # only decoder state is returned

(1000, 20)


In [43]:
# retrieve all predicted batches from Module.forward
decoder_pred, encoder_pred = net.forward(X)
print(decoder_pred.shape, encoder_pred.shape)

torch.Size([1000, 20]) torch.Size([1000, 5])


In [36]:
# lazy colleciton, one batch at a time
for decoder_pred, encoder_pred in net.forward_iter(X):
    print(decoder_pred.shape, encoder_pred.shape)
    break

torch.Size([128, 20]) torch.Size([128, 5])


Was the encoder sparse?

In [44]:
torch.isclose(encoder_pred, torch.zeros_like(encoder_pred)).float().mean()

tensor(1.)