## Working with Streaming Datasets with Skorch
This notebook demonstrates how to train a neural network classifier using streaming data with the skorch library. This approach is useful when working with large datasets that don't fit entirely in memory. This notebook was contributed by [Parag Ekbote](https://github.com/ParagEkbote).

We will also implement a custom validation callback, since the usual `train_split` mechanism cannot be applied to streaming data.

First, install the required libraries: skorch and torch.

<table align="left"><td>
<a target="_blank" href="https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/Streaming_Dataset.ipynb">
    <img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>  
</td><td>
<a target="_blank" href="https://github.com/skorch-dev/skorch/blob/master/notebooks/Streaming_Dataset.ipynb"><img width=32px src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a></td></table>

## Imports

In [1]:
import torch
from torch.utils.data import IterableDataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from skorch import NeuralNetClassifier
from skorch.callbacks import Callback

## Data Preparation

In this example, we will use a synthetic dataset to train the model. We subclass PyTorch's `IterableDataset` to generate a binary classification dataset on the fly, with 20 random features per sample.

In [None]:
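class StreamingDataset(IterableDataset):
    """Synthetic binary classification stream with 20 random features per sample."""

    def __init__(self, length=1000, seed=42):
        self.length = length  # number of samples yielded per pass
        # Seeded generator so that the stream is reproducible across runs
        self.rng = torch.Generator().manual_seed(seed)

    def __iter__(self):
        # Samples are generated lazily, one at a time, instead of being stored
        for _ in range(self.length):
            X = torch.randn(20, generator=self.rng)
            y = torch.randint(0, 2, (1,), generator=self.rng).item()  # label in {0, 1}
            yield X, y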
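
As a quick sanity check, the stream can be wrapped in a `DataLoader` and inspected batch by batch. The sketch below repeats the class definition so that it runs on its own. `ReseededStreamingDataset` is a hypothetical variant that is not part of the original notebook; it illustrates that the stored generator keeps advancing between passes, so a second pass over the same `StreamingDataset` object yields new samples unless the generator is reseeded.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class StreamingDataset(IterableDataset):
    """Same synthetic stream as in the cell above."""

    def __init__(self, length=1000, seed=42):
        self.length = length
        self.rng = torch.Generator().manual_seed(seed)

    def __iter__(self):
        for _ in range(self.length):
            X = torch.randn(20, generator=self.rng)
            y = torch.randint(0, 2, (1,), generator=self.rng).item()
            yield X, y


class ReseededStreamingDataset(StreamingDataset):
    """Hypothetical variant: restart the generator at the start of every pass."""

    def __init__(self, length=1000, seed=42):
        super().__init__(length=length, seed=seed)
        self.seed = seed

    def __iter__(self):
        self.rng.manual_seed(self.seed)  # rewind the stream
        return super().__iter__()


# Inspect one full pass over the original dataset
loader = DataLoader(StreamingDataset(length=1000, seed=0), batch_size=16)
batches = list(loader)
X_first, y_first = batches[0]
print(len(batches))                  # 63 batches: 62 full ones plus a final batch of 8
print(X_first.shape, y_first.shape)  # torch.Size([16, 20]) torch.Size([16])
print(y_first.dtype)                 # torch.int64

# A second pass continues from the advanced generator state
X_second, _ = next(iter(loader))
same_without_reseed = torch.equal(X_first, X_second)

# The reseeding variant yields identical data on every pass
reseeded_loader = DataLoader(ReseededStreamingDataset(length=64, seed=0), batch_size=16)
X_a, _ = next(iter(reseeded_loader))
X_b, _ = next(iter(reseeded_loader))
same_with_reseed = torch.equal(X_a, X_b)
print(same_without_reseed, same_with_reseed)
```

Whether reseeding is desirable depends on the use case: a fixed validation stream makes per-epoch scores comparable, while a fresh training stream on every epoch may be exactly what is wanted.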

## Neural Network Parameters

The `SimpleClassifier` is a small feed-forward network for binary classification. It consists of a single hidden layer with a ReLU activation, followed by an output layer that produces one logit per class. The input, hidden, and output dimensions are constructor arguments, so the architecture is easy to adjust.

In [None]:
class SimpleClassifier(nn.Module):
    def __init__(self, input_dim=20, hidden_dim=50, output_dim=2):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, X):
        X = F.relu(self.fc1(X))
        return self.fc2(X)
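
A forward pass on a random batch is a cheap way to confirm the shapes before training. The following sketch repeats the module definition so that it is self-contained; the batch size of 4 is an arbitrary choice for illustration. Note that the module returns raw logits, which is what `nn.CrossEntropyLoss` expects, so a softmax is only needed when probabilities are wanted explicitly.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleClassifier(nn.Module):
    """Same architecture as in the cell above."""

    def __init__(self, input_dim=20, hidden_dim=50, output_dim=2):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, X):
        X = F.relu(self.fc1(X))
        return self.fc2(X)


model = SimpleClassifier()
batch = torch.randn(4, 20)  # 4 samples with 20 features each

with torch.no_grad():
    logits = model(batch)             # raw, unnormalized scores
    probs = F.softmax(logits, dim=1)  # convert to class probabilities

# Weights and biases of both linear layers: (20*50 + 50) + (50*2 + 2)
n_params = sum(p.numel() for p in model.parameters())

print(logits.shape)  # torch.Size([4, 2])
print(n_params)      # 1152
```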

## Custom Callback

We will now set up a validation callback for use during model training. The callback evaluates the model's accuracy on a validation set at the end of each training epoch. The training and validation datasets are wrapped in a PyTorch `DataLoader` to enable batch processing.

Normally, giving a callback a reference to the data is discouraged, because the model and the training data should stay separate. However, the traditional `train_split` cannot be used with streaming datasets, and the data is streamed rather than stored, so it is acceptable here.

In [None]:
# Create streaming datasets
train_ds = StreamingDataset(length=1000, seed=0)
valid_ds = StreamingDataset(length=200, seed=1)

# Wrap both datasets in DataLoaders for batch processing
train_loader = DataLoader(train_ds, batch_size=16)
valid_loader = DataLoader(valid_ds, batch_size=16)

class StreamingValidationCallback(Callback):
    def __init__(self, valid_loader, name='valid_acc'):
        self.valid_loader = valid_loader
        self.name = name
    
    def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
        net.module_.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            for X_batch, y_batch in self.valid_loader:
                X_batch = X_batch.to(net.device)
                if isinstance(y_batch, (list, tuple)):
                    y_batch = torch.tensor(y_batch).to(net.device)
                else:
                    y_batch = y_batch.to(net.device)
                
                outputs = net.module_(X_batch)
                # The predicted class is the index of the largest logit
                predicted = outputs.argmax(dim=1)
                total += y_batch.size(0)
                correct += (predicted == y_batch).sum().item()

        # Guard against an empty validation stream
        accuracy = correct / total if total > 0 else float('nan')
        print(f"  {self.name}: {accuracy:.4f}")
        net.history.record(self.name, accuracy)

valid_ds_for_callback = StreamingDataset(length=200, seed=1)
valid_loader_for_callback = DataLoader(valid_ds_for_callback, batch_size=16)
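
The counting logic inside `on_epoch_end` can also be exercised without skorch. The sketch below factors it into a standalone helper. `streaming_accuracy` and `AlwaysZero` are hypothetical names introduced for illustration, and a small in-memory `TensorDataset` stands in for the stream so that the expected result is known in advance.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def streaming_accuracy(module, loader, device="cpu"):
    """Accumulate accuracy batch by batch, without holding the dataset in memory."""
    module.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for X_batch, y_batch in loader:
            X_batch = X_batch.to(device)
            y_batch = torch.as_tensor(y_batch).to(device)
            predicted = module(X_batch).argmax(dim=1)
            total += y_batch.size(0)
            correct += (predicted == y_batch).sum().item()
    return correct / total if total > 0 else float("nan")


class AlwaysZero(nn.Module):
    """Dummy model that predicts class 0 for every sample."""

    def forward(self, X):
        logits = torch.zeros(X.shape[0], 2, device=X.device)
        logits[:, 0] = 1.0
        return logits


# Five of the eight labels are 0, so a model that always predicts 0 scores 5/8
labels = torch.tensor([0, 0, 0, 1, 1, 0, 1, 0])
loader = DataLoader(TensorDataset(torch.randn(8, 20), labels), batch_size=3)

acc = streaming_accuracy(AlwaysZero(), loader)
print(acc)  # 0.625
```

Because only the running counts `correct` and `total` are kept, memory use stays constant no matter how long the validation stream is.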



## Model Training and Result

We will now wrap the PyTorch model in skorch's `NeuralNetClassifier` and use `DataLoader` as the training iterator to create data batches. The custom validation callback is registered through the `callbacks` argument, so validation is handled separately from the training loop. Because we are using an `IterableDataset`, we have to set `train_split` to `None`, since streaming data cannot be split in advance.

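Finally, we train the neural network by calling `fit` on the streaming dataset and observe that training completes successfully. We pass `y=None` because the dataset already yields the labels together with the features.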

In [None]:
net = NeuralNetClassifier(
    module=SimpleClassifier,
    module__input_dim=20,
    module__hidden_dim=50,
    module__output_dim=2,
    max_epochs=5,
    lr=0.01,
    criterion=nn.CrossEntropyLoss,
    iterator_train=DataLoader,
    train_split=None,  
    callbacks=[
        StreamingValidationCallback(valid_loader_for_callback, name='valid_acc'),
    ],
    device='cuda' if torch.cuda.is_available() else 'cpu',
    verbose=1,
)

net.fit(train_ds, y=None) 