In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import os

# Layer sizes for the single dense-layer initializer experiment below
n_input, n_dense = 784, 256

# Weights & Biases logging is optional: it is enabled only when the package
# imports AND an API key is present in the environment.
try:
    import wandb
except ImportError:
    wandb = None
_wandb_ok = wandb is not None and bool(os.environ.get("WANDB_API_KEY"))

# Custom weight and bias initializers
class RandomNormalInitializer:
    """Callable that fills a tensor in place with samples from N(mean, std**2).

    Args:
        mean: mean of the normal distribution.
        std: standard deviation of the normal distribution.
    """

    def __init__(self, mean=0.0, std=1.0):
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        """Overwrite ``tensor`` in place and return it."""
        mu, sigma = self.mean, self.std
        return nn.init.normal_(tensor, mu, sigma)

class ZerosInitializer:
    """Callable that sets every element of a tensor to zero, in place."""

    def __call__(self, tensor):
        """Zero ``tensor`` in place and return it."""
        return nn.init.constant_(tensor, 0.0)

class GlorotNormalInitializer:
    """Callable applying Glorot/Xavier normal initialization in place.

    Samples from N(0, std**2) with ``std = gain * sqrt(2 / (fan_in + fan_out))``.

    Args:
        gain: optional scaling factor (default 1.0, the original behavior);
            e.g. ``nn.init.calculate_gain('tanh')`` for tanh layers.
    """

    def __init__(self, gain=1.0):
        self.gain = gain

    def __call__(self, tensor):
        """Initialize ``tensor`` (at least 2-D) in place and return it."""
        return nn.init.xavier_normal_(tensor, gain=self.gain)

class GlorotUniformInitializer:
    """Callable applying Glorot/Xavier uniform initialization in place.

    Samples from U(-b, b) with ``b = gain * sqrt(6 / (fan_in + fan_out))``.

    Args:
        gain: optional scaling factor (default 1.0, the original behavior).
    """

    def __init__(self, gain=1.0):
        self.gain = gain

    def __call__(self, tensor):
        """Initialize ``tensor`` (at least 2-D) in place and return it."""
        return nn.init.xavier_uniform_(tensor, gain=self.gain)

class HeNormalInitializer:
    """Callable applying He/Kaiming normal initialization in place.

    Args:
        a: negative slope, only used when ``nonlinearity='leaky_relu'``.
        mode: ``'fan_in'`` (preserve forward variance) or ``'fan_out'``
            (preserve backward variance).
        nonlinearity: activation the layer feeds into; defaults to ``'relu'``
            (the original behavior). Change it when the following activation
            changes so the gain matches.
    """

    def __init__(self, a=0.0, mode='fan_in', nonlinearity='relu'):
        self.a = a
        self.mode = mode
        self.nonlinearity = nonlinearity

    def __call__(self, tensor):
        """Initialize ``tensor`` (at least 2-D) in place and return it."""
        return nn.init.kaiming_normal_(
            tensor, a=self.a, mode=self.mode, nonlinearity=self.nonlinearity
        )

class HeUniformInitializer:
    """Callable applying He/Kaiming uniform initialization in place.

    Args:
        a: negative slope, only used when ``nonlinearity='leaky_relu'``.
        mode: ``'fan_in'`` or ``'fan_out'``.
        nonlinearity: activation the layer feeds into; defaults to ``'relu'``
            (the original behavior).
    """

    def __init__(self, a=0.0, mode='fan_in', nonlinearity='relu'):
        self.a = a
        self.mode = mode
        self.nonlinearity = nonlinearity

    def __call__(self, tensor):
        """Initialize ``tensor`` (at least 2-D) in place and return it."""
        return nn.init.kaiming_uniform_(
            tensor, a=self.a, mode=self.mode, nonlinearity=self.nonlinearity
        )

# Create a simple MLP model
class SimpleMLP(nn.Module):
    """A single fully connected layer followed by an element-wise activation.

    Args:
        n_input: input feature width.
        n_dense: number of output units.
        w_init: callable applied in place to ``self.fc.weight``
            (shape ``(n_dense, n_input)``).
        b_init: callable applied in place to ``self.fc.bias`` (shape ``(n_dense,)``).
        activation: module applied after the linear layer. Defaults to
            ``nn.ReLU()``; pass e.g. ``nn.Sigmoid()`` or ``nn.Tanh()`` to study
            how an initializer interacts with other nonlinearities.
    """

    def __init__(self, n_input, n_dense, w_init, b_init, activation=None):
        super().__init__()
        self.fc = nn.Linear(n_input, n_dense)
        # Initialization is not part of the training graph; no_grad also lets
        # plain in-place callables (not only nn.init functions) be used.
        with torch.no_grad():
            w_init(self.fc.weight)
            b_init(self.fc.bias)
        self.activation = nn.ReLU() if activation is None else activation

    def forward(self, x):
        """Return ``activation(fc(x))``."""
        return self.activation(self.fc(x))

# Initialize the model
# Seed so the weight draw and the random input are reproducible across runs.
torch.manual_seed(0)
# Swap in RandomNormalInitializer(std=1.0), GlorotNormalInitializer(), ... to compare.
w_init = HeNormalInitializer()
b_init = ZerosInitializer()
model = SimpleMLP(n_input, n_dense, w_init, b_init)

# Generate random input values
x = torch.randn((1, n_input))

# Forward propagate through the network; only the values are inspected, so no graph is needed.
with torch.no_grad():
    a = model(x)

# Plot the input distribution
x_np = x.detach().numpy()  # Convert to numpy for plotting
fig, ax = plt.subplots()
ax.hist(x_np.T)
ax.set_title("Input Distribution")
ax.set_xlabel("Input Value")
ax.set_ylabel("Frequency")
plt.show()


In [None]:
# Plot the distribution of the layer's activations
a_np = a.detach().numpy()  # numpy copy for matplotlib
fig, ax = plt.subplots()
ax.hist(a_np.T)
ax.set(title="Output Distribution", xlabel="Output Value", ylabel="Frequency")
plt.show()


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import numpy as np
from tqdm import tqdm

# Define a custom Batch Normalization layer
class CustomBatchNorm(nn.Module):
    """Per-channel batch normalization for 4-D ``(N, C, H, W)`` inputs.

    In training mode, activations are normalized with the mean and (biased)
    variance of the current batch, computed over the batch and spatial axes.
    An exponential moving average of those statistics is also accumulated.
    In eval mode, the accumulated running statistics are used instead.
    A learnable per-channel scale (``gamma``) and shift (``beta``) are then
    applied.

    Args:
        num_features: number of channels ``C``.
        eps: constant added to the variance for numerical stability.
        momentum: weight of the newest batch statistic in the running average.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        # Buffers: saved in state_dict and moved by .to(), but not optimized.
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        """Normalize ``x`` per channel, then scale by ``gamma`` and shift by ``beta``."""
        if self.training:
            # Statistics of this batch over the batch and spatial dims.
            mean = x.mean(dim=[0, 2, 3], keepdim=True)
            var = x.var(dim=[0, 2, 3], keepdim=True, unbiased=False)
            # The running averages are bookkeeping, not part of the
            # differentiable computation. Without no_grad they would become
            # non-leaf tensors holding a grad_fn into this batch's graph,
            # chaining every step's graph onto the previous one and growing
            # memory across iterations.
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.view(-1)
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var.view(-1)
        else:
            # Use running mean and variance during inference
            mean = self.running_mean.view(1, -1, 1, 1)
            var = self.running_var.view(1, -1, 1, 1)
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma.view(1, -1, 1, 1) * x_hat + self.beta.view(1, -1, 1, 1)

# Define a simple CNN model with custom Batch Normalization
class SimpleCNN(nn.Module):
    """Two conv blocks followed by a two-layer classifier producing 10 log-probabilities.

    Each conv block is: conv (5x5) -> 2x2 max-pool -> ReLU -> CustomBatchNorm
    -> dropout. The flattened feature size 320 = 20 x 4 x 4 corresponds to
    single-channel 28x28 inputs.
    """

    def __init__(self):
        super().__init__()
        # Submodules are created in the same order as before so that
        # parameter registration (and RNG consumption at init) is unchanged.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.bn1 = CustomBatchNorm(10)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.bn2 = CustomBatchNorm(20)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
        self.dropout = nn.Dropout(p=0.5)

    def _conv_block(self, features, conv, norm):
        """Apply conv -> max-pool -> ReLU -> batch norm -> dropout."""
        pooled = F.max_pool2d(conv(features), 2)
        return self.dropout(norm(F.relu(pooled)))

    def forward(self, x):
        """Return log-softmax class scores of shape ``(batch, 10)``."""
        h = self._conv_block(x, self.conv1, self.bn1)
        h = self._conv_block(h, self.conv2, self.bn2)
        hidden = F.relu(self.fc1(h.view(-1, 320)))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

# Set up training parameters
batch_size = 64
learning_rate = 0.01
weight_decay = 1e-4  # L2 regularization parameter
patience = 20  # Early stopping patience
seed = 42  # used for the train/val split and for torch's RNG (init, shuffling, dropout)

torch.manual_seed(seed)

# Load the dataset
train_dataset = datasets.MNIST('/tmp/data', train=True, download=True, transform=transforms.ToTensor())
# Split *indices*, not the Dataset itself. Handing the Dataset to
# train_test_split makes sklearn index every sample eagerly and hold all
# (image, label) pairs in Python lists. Splitting an index range with the same
# random_state yields the same partition, and images stay lazily loaded.
train_idx, val_idx = train_test_split(
    np.arange(len(train_dataset)), test_size=0.2, random_state=seed
)
train_data = torch.utils.data.Subset(train_dataset, train_idx.tolist())
val_data = torch.utils.data.Subset(train_dataset, val_idx.tolist())

train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False)

# Check if GPU is available and use it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the model, loss function, and optimizer
model = SimpleCNN().to(device)
criterion = nn.NLLLoss()  # model outputs log-probabilities (log_softmax)
optimizer = optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Epoch budget; defined before W&B init so the logged config always matches the loop.
num_epochs = 100

# Init W&B run (only when wandb imported successfully and an API key is set)
_wb_run = None
if _wandb_ok and wandb is not None:
    try:
        _wb_run = wandb.init(
            settings=wandb.Settings(init_timeout=120),
            project="eng-ai-agents",
            entity="pantelis",
            id="train-batch-norm-mnist",
            resume="allow",
            name="batch-norm-mnist",
            group="optimization",
            tags=["optimization", "batch-normalization"],
            job_type="training",
            config={
                "batch_size": batch_size,
                "learning_rate": learning_rate,
                "weight_decay": weight_decay,
                "patience": patience,
                "num_epochs": num_epochs,
            },
        )
    except Exception as e:
        # Logging is best-effort: report the failure and train without W&B.
        print(f"W&B init failed (non-fatal): {e}")
        _wb_run = None

# Training loop with Early Stopping and TQDM for Epoch Progress
train_losses = []
val_losses = []
min_val_loss = np.inf
patience_counter = 0
best_model_state = None  # remains None only if no epoch ever improved (e.g. NaN losses)

try:
    for epoch in tqdm(range(num_epochs), desc="Epoch Progress", position=0):
        # --- training pass ---
        model.train()
        total_train_loss = 0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()  # Zero the gradients
            output = model(data)  # Forward pass
            loss = criterion(output, target)  # Compute the loss
            loss.backward()  # Backpropagate the gradients
            optimizer.step()  # Update the weights
            total_train_loss += loss.item()
        avg_train_loss = total_train_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        # tqdm.write keeps per-epoch messages from corrupting the progress bar.
        tqdm.write(f'Epoch {epoch + 1}: Train Loss: {avg_train_loss:.6f}')

        # --- validation pass ---
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                loss = criterion(output, target)
                total_val_loss += loss.item()
        avg_val_loss = total_val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        tqdm.write(f'Epoch {epoch + 1}: Validation Loss: {avg_val_loss:.6f}')

        # Log to W&B
        if _wb_run is not None:
            _wb_run.log({
                "train_loss": avg_train_loss,
                "val_loss": avg_val_loss,
                "epoch": epoch,
            })

        # Early stopping check
        if avg_val_loss < min_val_loss:
            min_val_loss = avg_val_loss
            patience_counter = 0
            # state_dict() returns references to the live tensors, which the
            # optimizer keeps updating in place. Clone them so this is a real
            # snapshot of the best epoch rather than an alias of the latest one.
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        else:
            patience_counter += 1
            if patience_counter >= patience:
                tqdm.write(f'Early stopping triggered after {epoch + 1} epochs.')
                break

    # Restore the weights from the epoch with the lowest validation loss.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    # Log final summary to W&B
    if _wb_run is not None:
        _wb_run.summary["best_val_loss"] = min_val_loss
        _wb_run.summary["epochs_completed"] = len(train_losses)
        _wb_run.summary["early_stopped"] = patience_counter >= patience
finally:
    # Always close the run, even if training raised.
    if _wb_run is not None:
        _wb_run.finish()

# Learning curves: per-epoch mean train and validation loss (epochs numbered from 1)
fig, ax = plt.subplots()
train_epochs = range(1, len(train_losses) + 1)
val_epochs = range(1, len(val_losses) + 1)
ax.plot(train_epochs, train_losses, label='Train Loss')
ax.plot(val_epochs, val_losses, label='Validation Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Train and Validation Loss vs Epochs with Dropout, Batch Normalization, and Early Stopping')
ax.legend()
ax.grid(True)
plt.show()