In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Compute device: use the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Experiment configuration

# Schedule: start-up steps on generated instances, then epochs on the on-disk data
online_epochs = 50
disk_epochs = 150
epochs = disk_epochs + online_epochs

# Steps over which apply_transform ramps the transform strength up to full,
# and after which the sampler starts being trained
warmup_steps = epochs // 4
# Evaluate every `eval_interval` steps (alternative setting: epochs // 20)
eval_interval = 1

# Data / model sizes
batch_size = 64
image_dim = 28
noise_dim = 10

# Bounds of the sampled transform: translation in pixels, rotation in degrees
max_rotation = 15
max_translate = 3

# Test accuracy at which the supervised loops stop early
target_accuracy = 0.98

In [None]:
# Load MNIST and extract one sample per class
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())

# Keep the first training image encountered for each of the 10 digit classes.
base_images, base_labels = [], []
seen_classes = set()
for img, label in train_dataset:
    if label in seen_classes:
        continue
    base_images.append(img)
    base_labels.append(label)
    seen_classes.add(label)
    if len(seen_classes) == 10:
        break

# Order the exemplars by label so that base_images[d] is the exemplar of digit d
# (visualize_augmentations indexes base_images by digit).
label_order = sorted(range(len(base_labels)), key=base_labels.__getitem__)
base_images = torch.stack([base_images[i] for i in label_order], dim=0).to(device)
base_labels = torch.tensor([base_labels[i] for i in label_order], dtype=torch.long).to(device)

In [None]:
# Define Sampler MLP
class SamplerMLP(nn.Module):
    """MLP that maps a flattened image plus a noise vector to bounded transform parameters.

    The input has 784 + noise_dim features per sample. The output has three
    columns per sample: x-translation and y-translation, each bounded by
    +/- max_translate, and a rotation bounded by +/- max_rotation. The bounds
    are enforced with tanh.
    """

    def __init__(self):
        super().__init__()
        in_features = 784 + noise_dim
        self.net = nn.Sequential(
            nn.Linear(in_features, 256), nn.LayerNorm(256), nn.ReLU(),
            nn.Linear(256, 128), nn.LayerNorm(128), nn.ReLU(),
            nn.Linear(128, 3),
        )

        # Small output-layer weights and a zero bias keep the initial
        # transform parameters close to zero.
        output_layer = self.net[-1]
        with torch.no_grad():
            output_layer.weight.uniform_(-0.1, 0.1)
            output_layer.bias.zero_()

    def forward(self, x):
        """Return a (batch, 3) tensor of [tx, ty, rotation] for input batch `x`."""
        squashed = torch.tanh(self.net(x))
        shift_x = max_translate * squashed[:, 0]
        shift_y = max_translate * squashed[:, 1]
        angle = max_rotation * squashed[:, 2]
        return torch.stack([shift_x, shift_y, angle], dim=1)

In [None]:
# Define Classifier CNN
class ClassifierCNN(nn.Module):
    """Small CNN classifier: two conv/ReLU/max-pool stages followed by a two-layer head.

    The first linear layer expects 16 * 7 * 7 features, i.e. a single-channel
    28x28 input after two 2x2 poolings. The output has 10 scores per sample.
    """

    def __init__(self):
        super().__init__()
        feature_layers = [
            nn.Conv2d(1, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(8, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
        head_layers = [
            nn.Flatten(),
            nn.Linear(16 * 7 * 7, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
        ]
        # A single Sequential keeps the layer indices (and state_dict keys) flat.
        self.net = nn.Sequential(*feature_layers, *head_layers)

    def forward(self, x):
        """Return the (batch, 10) class scores for image batch `x`."""
        logits = self.net(x)
        return logits


In [None]:
# Transformation function
def apply_transform(images, params, current_step):
    """Translate and rotate a batch of images with per-sample parameters.

    Args:
        images: tensor of shape (B, C, H, W).
        params: tensor of shape (B, 3) whose columns are x-translation and
            y-translation in pixels and rotation in degrees.
        current_step: training step used to ramp the transform strength. While
            the module-level `warmup_steps` is positive, all three parameters
            are scaled by current_step / warmup_steps clamped to [0, 1].

    Returns:
        Tensor with the same shape as `images`, resampled bilinearly with
        border padding.
    """
    batch, _, height, width = images.shape
    # Work on the images' device and dtype rather than a global device.
    params = params.to(device=images.device, dtype=images.dtype)
    tx, ty, rot_deg = params[:, 0], params[:, 1], params[:, 2]

    if warmup_steps > 0:
        progress = min(max(current_step / warmup_steps, 0.0), 1.0)
        tx = tx * progress
        ty = ty * progress
        rot_deg = rot_deg * progress

    rot_rad = torch.deg2rad(rot_deg)
    cos, sin = torch.cos(rot_rad), torch.sin(rot_rad)

    # Forward transform in homogeneous coordinates. affine_grid works in
    # normalized [-1, 1] coordinates, so a shift of one pixel equals
    # 1 / (size / 2) along that axis (14.0 for 28-pixel images).
    affine_mat = torch.zeros(batch, 3, 3, device=images.device, dtype=images.dtype)
    affine_mat[:, 0, 0] = cos
    affine_mat[:, 0, 1] = -sin
    affine_mat[:, 1, 0] = sin
    affine_mat[:, 1, 1] = cos
    affine_mat[:, 0, 2] = tx / (width / 2.0)
    affine_mat[:, 1, 2] = ty / (height / 2.0)
    affine_mat[:, 2, 2] = 1.0

    # grid_sample maps output coordinates back to input coordinates, so the
    # sampling grid is built from the inverse of the forward transform.
    theta = torch.inverse(affine_mat)[:, :2, :]
    grid = F.affine_grid(theta, images.size(), align_corners=False)
    return F.grid_sample(images, grid, align_corners=False, padding_mode='border')

In [None]:
# Evaluation function
def evaluate(model, dataset=None, batch_size=1000):
    """Return the classification accuracy of `model` on a dataset.

    Args:
        model: module mapping a batch of inputs to per-class scores of shape
            (batch, num_classes).
        dataset: dataset yielding (input, label) pairs. Defaults to the
            module-level `test_dataset`.
        batch_size: number of samples per evaluation batch.

    Returns:
        Fraction of correctly classified samples as a float in [0, 1];
        0.0 when the dataset is empty.

    The model is switched to eval mode for the duration of the call and then
    restored to whichever mode it was in before.
    """
    if dataset is None:
        dataset = test_dataset
    test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Evaluate on the device the model lives on; fall back to the global
    # `device` only for parameter-free models.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(target_device), labels.to(target_device)
                predicted = model(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Restore the previous mode instead of unconditionally forcing train mode.
        model.train(was_training)
    return correct / total if total else 0.0

In [None]:
# Visualization function
def visualize_augmentations(num_samples=5):
    """Plot every base image next to `num_samples` sampler-generated augmentations.

    Uses the module-level `sampler`, `base_images`, `noise_dim` and
    `warmup_steps`. Each row shows one base image in the first column followed
    by augmentations drawn with fresh noise, all at full transform strength.

    Args:
        num_samples: number of augmented variants to draw per base image.
    """
    n_rows = base_images.size(0)
    was_training = sampler.training
    sampler.eval()
    try:
        with torch.no_grad():
            # squeeze=False keeps `axs` two-dimensional even for a single column.
            fig, axs = plt.subplots(n_rows, num_samples + 1, figsize=(15, 25), squeeze=False)
            for digit in range(n_rows):
                base = base_images[digit]
                axs[digit, 0].imshow(base.cpu().numpy().squeeze(), cmap='gray')
                axs[digit, 0].axis('off')

                for i in range(num_samples):
                    noise = torch.randn(1, noise_dim, device=base_images.device)

                    # Flatten both to 1-D, concatenate, then add a batch dimension.
                    sampler_input = torch.cat([base.flatten(), noise.flatten()], 0).unsqueeze(0)
                    params = sampler(sampler_input)

                    # apply_transform ramps its strength with step / warmup_steps, so
                    # passing warmup_steps itself always gives the full transform.
                    aug_img = apply_transform(base.unsqueeze(0), params, warmup_steps)
                    axs[digit, i + 1].imshow(aug_img.squeeze().cpu().numpy(), cmap='gray')
                    axs[digit, i + 1].axis('off')
            plt.tight_layout()
            plt.show()
    finally:
        # Put the sampler back into the mode it was in before plotting.
        sampler.train(was_training)

In [None]:
# Online Training

def online_training(epochs, warmup_steps, sampler, optimizer_s,
                    classifier, optimizer_c, scheduler_c, criterion):
    """Train the classifier (and, after warm-up, the sampler) on generated instances.

    Every step draws `batch_size` random base images with fresh noise, lets the
    sampler propose transform parameters, and updates the classifier on the
    transformed images. Once `step > warmup_steps`, the sampler is also updated
    to maximise the classifier loss on the same batch.

    Args:
        epochs: number of training steps (converted with int()).
        warmup_steps: number of initial steps during which the sampler is not
            trained. Note that `apply_transform` ramps its strength using the
            module-level `warmup_steps`, not this argument.
        sampler, optimizer_s: sampler network and its optimizer.
        classifier, optimizer_c: classifier network and its optimizer.
        scheduler_c: learning-rate scheduler, stepped once per training step;
            it must allow at least `epochs` steps.
        criterion: loss taking (logits, labels).

    Returns:
        (accuracy_history, loss_history, s_loss_history): values recorded every
        `eval_interval` steps, plus the final step if it was not already
        recorded. All three are empty when `eval_interval` is 0.
    """
    print("Starting training with generated instances...")

    # Results
    loss_history = []
    s_loss_history = []
    accuracy_history = []

    def _record(step, loss_c, loss_s):
        """Evaluate on the test set, print progress and append to the histories."""
        acc = evaluate(classifier)
        print(f"Step {step:4d} | Class Loss: {loss_c.item():.4f} | Sampler Loss: {loss_s.item():.4f}")
        print(f"Test Accuracy: {acc*100:.2f}%")

        loss_history.append(loss_c.item())
        s_loss_history.append(loss_s.item())
        accuracy_history.append(acc)

    epochs = int(epochs)
    loss_c = loss_s = None

    for step in range(1, epochs + 1):

        # Phase 1: Train classifier
        idx = torch.randint(0, 10, (batch_size,), device=device)
        batch_images = base_images[idx]
        batch_labels = base_labels[idx]
        noise = torch.randn(batch_size, noise_dim, device=device)

        # No gradient through the sampler in this phase.
        with torch.no_grad():
            params = sampler(torch.cat([batch_images.view(batch_size, -1), noise], 1))

        aug_images = apply_transform(batch_images, params, step)

        optimizer_c.zero_grad()
        logits = classifier(aug_images)
        loss_c = criterion(logits, batch_labels)
        loss_c.backward()
        torch.nn.utils.clip_grad_norm_(classifier.parameters(), 1.0)
        optimizer_c.step()
        scheduler_c.step()

        # Phase 2: Train sampler (after warmup)
        if step > warmup_steps:
            # Freeze the classifier so only the sampler receives gradients.
            for p in classifier.parameters():
                p.requires_grad = False

            params = sampler(torch.cat([batch_images.view(batch_size, -1), noise], 1))
            aug_images = apply_transform(batch_images, params, step)
            logits = classifier(aug_images)
            # Negated loss: the sampler is rewarded for harder examples.
            loss_s = -criterion(logits, batch_labels)

            optimizer_s.zero_grad()
            loss_s.backward()
            torch.nn.utils.clip_grad_norm_(sampler.parameters(), 1.0)
            optimizer_s.step()

            for p in classifier.parameters():
                p.requires_grad = True
        else:
            loss_s = torch.tensor(0.0)  # Dummy value during warmup

        # Evaluate loss and test accuracy at intervals; an interval of 0
        # disables periodic evaluation instead of dividing by zero.
        if eval_interval > 0 and step % eval_interval == 0:
            _record(step, loss_c, loss_s)

    # Record the final step when the interval did not land on it.
    if epochs > 0 and eval_interval > 0 and epochs % eval_interval != 0:
        _record(epochs, loss_c, loss_s)

    return accuracy_history, loss_history, s_loss_history

In [None]:
def run(startup_proportion, total_epochs=None):
    """Train fresh models on generated instances for a share of the epoch budget.

    Args:
        startup_proportion: fraction of `total_epochs` spent on start-up
            (generated-instance) training.
        total_epochs: total epoch budget; defaults to the module-level `epochs`.

    Returns:
        (accuracy_history, loss_history, s_loss_history) as returned by
        `online_training`; three empty lists when the start-up phase has no steps.
    """
    if total_epochs is None:
        total_epochs = epochs

    # Whole number of start-up steps. A local name distinct from the global
    # `online_epochs` is used so that nothing here reads a local before assignment.
    startup_epochs = int(round(startup_proportion * total_epochs))
    if startup_epochs <= 0:
        # Nothing to train; OneCycleLR also rejects a non-positive total_steps.
        return [], [], []

    # Initialize models and optimizers
    sampler = SamplerMLP().to(device)
    optimizer_s = torch.optim.Adam(sampler.parameters(), lr=1e-4)

    classifier = ClassifierCNN().to(device)
    optimizer_c = torch.optim.Adam(classifier.parameters(), lr=3e-4)
    # The schedule length must match the number of steps actually taken.
    scheduler_c = torch.optim.lr_scheduler.OneCycleLR(optimizer_c, max_lr=3e-4, total_steps=startup_epochs)

    criterion = nn.CrossEntropyLoss()

    # online training (start-up steps)
    accuracy_history, loss_history, s_loss_history = online_training(
        startup_epochs, warmup_steps, sampler, optimizer_s,
        classifier, optimizer_c, scheduler_c, criterion)

    # TODO: run the supervised (on-disk) phase for the remaining
    # total_epochs - startup_epochs epochs. The original second call repeated
    # the start-up phase and would have overrun the OneCycleLR schedule.

    return accuracy_history, loss_history, s_loss_history


In [3]:
# Start-up proportions to sweep: 0.0, 0.1, ..., 0.9
percentages = [round(0.1 * p, 2) for p in range(10)]

# run() takes the total epoch budget as its second argument; keep whatever
# each run returns, keyed by its start-up proportion.
results_by_percent = {}
for percent in percentages:
    results_by_percent[percent] = run(percent, epochs)

[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

## Experiment

Vary the number of start-up epochs (X, trained on generated instances) and the number of regular epochs (Y, trained on the on-disk dataset).

In [None]:
# Build the experiment's models, optimizers, scheduler and loss
criterion = nn.CrossEntropyLoss()

# Networks (the sampler is created before the classifier)
sampler = SamplerMLP().to(device)
classifier = ClassifierCNN().to(device)

# One Adam optimizer per network
optimizer_s = torch.optim.Adam(sampler.parameters(), lr=1e-4)
optimizer_c = torch.optim.Adam(classifier.parameters(), lr=3e-4)

# One-cycle learning-rate schedule spanning the start-up phase
scheduler_c = torch.optim.lr_scheduler.OneCycleLR(
    optimizer_c,
    max_lr=3e-4,
    total_steps=online_epochs,
)

### X - Startup Steps

For the first X epochs, the pipeline uses no data from disk; instead, it trains on generated instances (transformed copies of one base image per class) to start up the model.

In [None]:
# Epoch at which the target accuracy is first met or surpassed (0 until found),
# and the flag recording whether that has happened yet
target_epoch, target_epoch_found = 0, False

In [None]:
# Start-up training loop (generated instances only)

print("Starting training with generated instances...")
for step in range(1, online_epochs + 1):

    # Phase 1: classifier update on a random batch of transformed base images
    idx = torch.randint(0, 10, (batch_size,), device=device)
    batch_images, batch_labels = base_images[idx], base_labels[idx]
    noise = torch.randn(batch_size, noise_dim, device=device)

    # The sampler only proposes parameters here; it receives no gradient
    with torch.no_grad():
        params = sampler(torch.cat([batch_images.view(batch_size, -1), noise], 1))
    aug_images = apply_transform(batch_images, params, step)

    optimizer_c.zero_grad()
    logits = classifier(aug_images)
    loss_c = criterion(logits, batch_labels)
    loss_c.backward()
    torch.nn.utils.clip_grad_norm_(classifier.parameters(), 1.0)
    optimizer_c.step()
    scheduler_c.step()

    # Phase 2: sampler update, only once the warm-up is over
    if step <= warmup_steps:
        loss_s = torch.tensor(0.0)  # Dummy value during warmup
    else:
        # Freeze the classifier so only the sampler is updated
        classifier.requires_grad_(False)

        params = sampler(torch.cat([batch_images.view(batch_size, -1), noise], 1))
        aug_images = apply_transform(batch_images, params, step)
        logits = classifier(aug_images)
        # Negated loss: the sampler is pushed towards harder examples
        loss_s = -criterion(logits, batch_labels)

        optimizer_s.zero_grad()
        loss_s.backward()
        torch.nn.utils.clip_grad_norm_(sampler.parameters(), 1.0)
        optimizer_s.step()

        classifier.requires_grad_(True)

    # Progress report every 5 steps
    if step % 5 == 0:
        shown_sampler_loss = loss_s.item() if step > warmup_steps else 0
        print(f"Step {step:4d} | Class Loss: {loss_c.item():.4f} | Sampler Loss: {shown_sampler_loss:.4f}")

    # Test-set accuracy every `eval_interval` steps
    if step % eval_interval == 0:
        acc = evaluate(classifier)
        print(f"Test Accuracy: {acc*100:.2f}%")
        # visualize_augmentations()

# Final evaluation
classifier.eval()
test_acc = evaluate(classifier)
print(f"\nFinal Test Accuracy: {test_acc*100:.2f}%")

Starting training with generated instances...
Test Accuracy: 16.07%
Test Accuracy: 16.12%
Test Accuracy: 16.25%
Test Accuracy: 16.59%
Step    5 | Class Loss: 2.2792 | Sampler Loss: 0.0000
Test Accuracy: 16.97%
Test Accuracy: 17.55%
Test Accuracy: 17.99%
Test Accuracy: 18.35%
Test Accuracy: 18.57%
Step   10 | Class Loss: 2.2947 | Sampler Loss: 0.0000
Test Accuracy: 18.45%
Test Accuracy: 17.87%
Test Accuracy: 17.37%
Test Accuracy: 17.60%
Test Accuracy: 16.84%
Step   15 | Class Loss: 2.2968 | Sampler Loss: 0.0000
Test Accuracy: 15.54%
Test Accuracy: 15.24%
Test Accuracy: 15.37%
Test Accuracy: 15.02%
Test Accuracy: 16.49%
Step   20 | Class Loss: 2.2703 | Sampler Loss: 0.0000
Test Accuracy: 17.84%
Test Accuracy: 19.03%
Test Accuracy: 20.12%
Test Accuracy: 20.50%
Test Accuracy: 20.40%
Step   25 | Class Loss: 2.2511 | Sampler Loss: 0.0000
Test Accuracy: 20.31%
Test Accuracy: 20.28%
Test Accuracy: 20.40%
Test Accuracy: 20.49%
Test Accuracy: 20.58%
Step   30 | Class Loss: 2.2459 | Sampler Loss:

### Y - Typical Supervision

For the remaining Y epochs, train with the typical supervised learning pipeline on the on-disk training data.

In [None]:
# Training loop (classic supervision)

# Load the MNIST training split and restart the one-cycle schedule for this phase
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
scheduler_c = torch.optim.lr_scheduler.OneCycleLR(optimizer_c, max_lr=3e-4, total_steps=disk_epochs)

print("Starting the typical supervised training portion...")
for step in range(1, disk_epochs + 1):

    classifier.train()
    # Running statistics over this epoch's training batches
    total_loss = 0.0
    correct = 0
    total = 0

    for batch_images, batch_labels in train_loader:

        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        # Train just the classifier
        optimizer_c.zero_grad()
        logits = classifier(batch_images)
        loss_c = criterion(logits, batch_labels)
        loss_c.backward()

        torch.nn.utils.clip_grad_norm_(classifier.parameters(), 1.0)
        optimizer_c.step()

        # Accumulate sample-weighted loss and the number of correct training predictions
        total_loss += loss_c.item() * batch_images.size(0)
        correct += (logits.argmax(dim=1) == batch_labels).sum().item()
        total += batch_images.size(0)

    # The scheduler advances once per epoch (total_steps=disk_epochs)
    scheduler_c.step()

    # Monitoring: last-batch loss and the epoch's mean training loss
    if step % 5 == 0:
        print(f"Step {step:4d} | Classifier Loss: {loss_c.item():.4f} | Calculated Loss: {total_loss/total:.4f}")

    # Evaluate test accuracy at intervals
    if step % eval_interval == 0:
        acc = evaluate(classifier)
        # correct/total is measured on the training batches, so it is reported
        # as training accuracy rather than validation accuracy.
        print(f"Test Accuracy: {acc*100:.2f}% | Train Accuracy:  {correct/total*100:.2f}%")

        # Remember the first epoch of this phase whose test accuracy reaches the target
        if acc >= target_accuracy and not target_epoch_found:
            target_epoch_found = True
            target_epoch = step

    # Stop as soon as the target has been reached
    if target_epoch_found:
        break

# Final evaluation
classifier.eval()
test_acc = evaluate(classifier)
print(f"\nFinal Test Accuracy: {test_acc*100:.2f}%")

Starting the typical supervised training portion...
Test Accuracy: 57.29% | Validation Accuracy:  43.47%
Test Accuracy: 67.47% | Validation Accuracy:  61.43%
Test Accuracy: 78.22% | Validation Accuracy:  72.98%
Test Accuracy: 83.29% | Validation Accuracy:  80.05%
Step    5 | Classifier Loss: 0.4735 | Calculated Loss: 0.6078
Test Accuracy: 86.44% | Validation Accuracy:  83.95%
Test Accuracy: 88.75% | Validation Accuracy:  86.66%
Test Accuracy: 90.28% | Validation Accuracy:  88.77%
Test Accuracy: 91.48% | Validation Accuracy:  90.30%
Test Accuracy: 92.44% | Validation Accuracy:  91.51%
Step   10 | Classifier Loss: 0.2528 | Calculated Loss: 0.2570
Test Accuracy: 93.16% | Validation Accuracy:  92.54%
Test Accuracy: 94.11% | Validation Accuracy:  93.50%
Test Accuracy: 94.99% | Validation Accuracy:  94.41%
Test Accuracy: 95.53% | Validation Accuracy:  95.00%
Test Accuracy: 96.14% | Validation Accuracy:  95.67%
Step   15 | Classifier Loss: 0.3246 | Calculated Loss: 0.1295
Test Accuracy: 96.50

## Control

Train for all X + Y epochs using classic supervision only.

In [None]:
# Control run: a fresh classifier with its own optimizer and schedule
control_classifier = ClassifierCNN().to(device)
control_optimizer = torch.optim.Adam(
    control_classifier.parameters(),
    lr=3e-4,
)
# The one-cycle schedule spans the whole epoch budget
control_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    control_optimizer,
    max_lr=3e-4,
    total_steps=epochs,
)

In [None]:
# Epoch at which the control run first meets or surpasses the target accuracy
# (0 until found); the shared "found" flag is reset for the control run
control_target_epoch, target_epoch_found = 0, False

In [None]:
# Train

# Load the MNIST training split
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

print("Starting classic supervised training...")
for step in range(1, epochs + 1):

    control_classifier.train()
    # Running statistics over this epoch's training batches
    total_loss = 0.0
    correct = 0
    total = 0

    for batch_images, batch_labels in train_loader:

        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        # Train just the classifier
        control_optimizer.zero_grad()
        logits = control_classifier(batch_images)
        loss = criterion(logits, batch_labels)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(control_classifier.parameters(), 1.0)
        control_optimizer.step()

        # Accumulate sample-weighted loss and the number of correct training predictions
        total_loss += loss.item() * batch_images.size(0)
        correct += (logits.argmax(dim=1) == batch_labels).sum().item()
        total += batch_images.size(0)

    # The scheduler advances once per epoch (total_steps=epochs)
    control_scheduler.step()

    # Monitoring: last-batch loss and the epoch's mean training loss
    if step % 5 == 0:
        print(f"Step {step:4d} | Classifier Loss: {loss.item():.4f} | Calculated Loss: {total_loss/total:.4f}")

    # Evaluate test accuracy at intervals
    if step % eval_interval == 0:
        acc = evaluate(control_classifier)
        # correct/total is measured on the training batches, so it is reported
        # as training accuracy rather than validation accuracy.
        print(f"Test Accuracy: {acc*100:.2f}% | Train Accuracy:  {correct/total*100:.2f}%")

        # Remember the first epoch whose test accuracy reaches the target
        if acc >= target_accuracy and not target_epoch_found:
            target_epoch_found = True
            control_target_epoch = step

    # Stop as soon as the target has been reached
    if target_epoch_found:
        break

# Final evaluation
control_classifier.eval()
test_acc = evaluate(control_classifier)
print(f"\nFinal Test Accuracy: {test_acc*100:.2f}%")

### Comparisons

In [None]:
# Epoch at which each run first reached target_accuracy on the test set (0 if it
# never did). target_epoch counts epochs of the experiment's supervised phase only;
# control_target_epoch counts epochs of the control run.
target_epoch, control_target_epoch