In [None]:
import torch
from torch.utils.data import IterableDataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from skorch import NeuralNetClassifier
from skorch.callbacks import Callback


=== Solution 3: Custom validation with callbacks ===
  valid_acc: 0.5550
  epoch    train_loss    valid_acc     dur
-------  ------------  -----------  ------
      1        [36m0.7095[0m       0.5550  0.0530
  valid_acc: 0.4750
      2        [36m0.7076[0m       0.4750  0.0306
  valid_acc: 0.5200
      3        0.7113       0.5200  0.0277
  valid_acc: 0.5100
      4        [36m0.7030[0m       0.5100  0.0251
  valid_acc: 0.4500
      5        [36m0.7019[0m       0.4500  0.0259


In [None]:
# 1. Define a streaming dataset using IterableDataset
class StreamingDataset(IterableDataset):
    """Synthetic binary-classification stream of ``(features, label)`` pairs.

    One pass yields ``length`` samples. Each sample is a 20-element float
    tensor drawn from a standard normal distribution together with an integer
    label drawn uniformly from ``{0, 1}``.

    Every call to ``__iter__`` starts from a fresh generator seeded with
    ``seed``, so iterating the dataset again replays exactly the same
    samples. This keeps a dataset that is iterated once per epoch (for
    example a validation set) fixed across epochs.

    Args:
        length: Number of samples produced per pass.
        seed: Seed of the per-pass random generator.
    """

    def __init__(self, length=1000, seed=42):
        self.length = length
        # Only the seed is stored; the generator itself is created per pass so
        # that one pass cannot advance the random state seen by the next.
        self.seed = seed

    def __iter__(self):
        rng = torch.Generator().manual_seed(self.seed)
        for _ in range(self.length):
            X = torch.randn(20, generator=rng)
            # .item() turns the one-element tensor into a plain Python int.
            y = torch.randint(0, 2, (1,), generator=rng).item()
            yield X, y

In [None]:
# 2. Define a simple classifier
class SimpleClassifier(nn.Module):
    """Feed-forward network with one hidden ReLU layer.

    The forward pass returns raw (unnormalised) scores, one per output unit;
    no softmax is applied.
    """

    def __init__(self, input_dim=20, hidden_dim=50, output_dim=2):
        """Build the input->hidden (``fc1``) and hidden->output (``fc2``) layers."""
        super().__init__()
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, X):
        """Map ``X`` (last dimension ``input_dim``) to ``output_dim`` scores."""
        hidden = self.fc1(X)
        activated = F.relu(hidden)
        scores = self.fc2(activated)
        return scores



In [None]:
# 3. Training stream (1000 samples, seed 0) and its DataLoader, which groups
#    consecutive samples into batches of 16.
train_ds = StreamingDataset(seed=0, length=1000)
train_loader = DataLoader(dataset=train_ds, batch_size=16)

# 4. Validation stream (200 samples, seed 1) and its DataLoader, batched the
#    same way.
valid_ds = StreamingDataset(seed=1, length=200)
valid_loader = DataLoader(dataset=valid_ds, batch_size=16)

# Create a custom callback for validation with streaming data
class StreamingValidationCallback(Callback):
    """Record classification accuracy on a separate loader after each epoch.

    At the end of every epoch the callback runs ``net.module_`` over all
    batches of ``valid_loader``, takes the highest-scoring output as the
    predicted class, prints the fraction of correct predictions and stores it
    in the current epoch row of ``net.history`` under ``name``.

    Args:
        valid_loader: Iterable yielding ``(X_batch, y_batch)`` pairs. It is
            iterated once per epoch, so it must be re-iterable.
        name: History key (and printed label) for the accuracy value.
    """

    def __init__(self, valid_loader, name='valid_acc'):
        self.valid_loader = valid_loader
        self.name = name

    def on_epoch_end(self, net, dataset_train=None, dataset_valid=None, **kwargs):
        """Evaluate the module on ``valid_loader`` and record the accuracy.

        ``dataset_train`` and ``dataset_valid`` are accepted to match the
        hook signature but are not used. If the loader yields no samples the
        recorded accuracy is NaN instead of raising ``ZeroDivisionError``.
        """
        module = net.module_
        # Remember the current mode so evaluation does not leave the module
        # in eval mode for whatever runs after this callback.
        was_training = module.training
        module.eval()
        correct = 0
        total = 0

        try:
            with torch.no_grad():
                for X_batch, y_batch in self.valid_loader:
                    X_batch = X_batch.to(net.device)
                    # Labels may arrive as a plain sequence rather than a tensor.
                    if isinstance(y_batch, (list, tuple)):
                        y_batch = torch.tensor(y_batch)
                    y_batch = y_batch.to(net.device)

                    # Index of the largest score along dim 1 is the predicted class.
                    predicted = module(X_batch).argmax(dim=1)
                    total += y_batch.size(0)
                    correct += (predicted == y_batch).sum().item()
        finally:
            module.train(was_training)

        # An empty (or already exhausted) loader gives total == 0.
        accuracy = correct / total if total else float('nan')
        print(f"  {self.name}: {accuracy:.4f}")
        net.history.record(self.name, accuracy)

# Validation stream handed to the callback: 200 samples generated with seed 1,
# served by a DataLoader in batches of 16.
valid_ds_for_callback = StreamingDataset(seed=1, length=200)
valid_loader_for_callback = DataLoader(
    dataset=valid_ds_for_callback,
    batch_size=16,
)



In [None]:
# Seed the global torch RNG so the module's weight initialisation is the same
# on every run of this cell.
torch.manual_seed(0)

net = NeuralNetClassifier(
    module=SimpleClassifier,
    module__input_dim=20,
    module__hidden_dim=50,
    module__output_dim=2,
    max_epochs=5,
    lr=0.01,
    # The module returns raw scores, so use the loss that applies
    # log-softmax itself.
    criterion=nn.CrossEntropyLoss,
    iterator_train=DataLoader,
    # No internal train/validation split; validation is done by the callback.
    train_split=None,
    callbacks=[
        StreamingValidationCallback(valid_loader_for_callback, name='valid_acc'),
    ],
    device='cuda' if torch.cuda.is_available() else 'cpu',
    verbose=1,
)

# Pass the dataset itself (not a DataLoader); labels come from the dataset,
# hence y=None.
net.fit(train_ds, y=None)