#### Imports

In [0]:
import json
import os
import tempfile
from datetime import datetime

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# MLflow imports - these are essential for tracking
import mlflow
import mlflow.pytorch
import mlflow.sklearn  # For potential metric logging
from mlflow.models.signature import infer_signature
from mlflow.utils.environment import _mlflow_conda_env

# Send all MLflow tracking calls in this notebook to the Databricks workspace.
mlflow.set_tracking_uri(uri="databricks")

In [0]:
image_dir ='/Workspace/sid-v2/computervision1/Classification_dataset_v3/images/train'
for label, class_dir in enumerate(os.listdir(image_dir)):
    print(label,class_dir)

In [0]:
class generate_image_dataset(Dataset):
    """Image-classification dataset read from a ``root/<class_name>/<image>`` layout.

    Every sub-directory of ``image_dir`` is one class. Class indices are assigned
    in sorted order of the sub-directory names (``os.listdir`` order is
    arbitrary), so datasets built separately from ``train/`` and ``test/`` folders
    with the same class folders get the same label mapping. Image files inside
    each class folder are also taken in sorted order.

    Args:
        image_dir: Root directory containing one sub-directory per class.
            Non-directory entries at this level are ignored.
        transform: Optional callable applied to each RGB PIL image in
            ``__getitem__``.

    Attributes:
        image_paths: Full path of every image file, grouped by class.
        labels: Integer class index for each entry of ``image_paths``.
        class_name: Mapping from class index to sub-directory name.
        dataset_stats: Summary dict returned by ``get_stats`` (for MLflow logging).
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.image_paths = []
        self.labels = []
        self.class_name = {}
        self.transform = transform

        # Sorting makes label indices deterministic; skipping plain files
        # (e.g. .DS_Store) avoids calling os.listdir on a non-directory.
        class_dirs = sorted(
            entry for entry in os.listdir(image_dir)
            if os.path.isdir(os.path.join(image_dir, entry))
        )

        class_counts = {}
        for label, class_dir in enumerate(class_dirs):
            self.class_name[label] = class_dir
            class_path = os.path.join(image_dir, class_dir)
            # Only regular files are samples; nested folders would fail in Image.open.
            img_names = sorted(
                name for name in os.listdir(class_path)
                if os.path.isfile(os.path.join(class_path, name))
            )
            for img_name in img_names:
                self.image_paths.append(os.path.join(class_path, img_name))
                self.labels.append(label)
            class_counts[class_dir] = len(img_names)

        # Built once, after every class has been scanned, so it is always
        # defined (also for an empty root directory).
        self.dataset_stats = {
            'total_images': len(self.image_paths),
            'num_classes': len(self.class_name),
            'class_distribution': class_counts,
            'class_names': list(self.class_name.values())
        }

    def __len__(self):
        """Number of image files found under ``image_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; image is RGB, transformed if set."""
        image_path = self.image_paths[idx]
        # The context manager closes the file handle; convert() returns an
        # independent in-memory image, so it stays valid after the block.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

    def get_stats(self):
        """Return dataset statistics for MLflow logging"""
        return self.dataset_stats

In [0]:
# Preprocessing parameters live in a dict so the exact configuration can be
# logged to MLflow with each run.
transform_config = {
    'resize': (128, 128),
    'normalize_mean': [0.5, 0.5, 0.5],
    'normalize_std': [0.5, 0.5, 0.5]
}

# Resize -> float tensor in [0, 1] -> per-channel (x - mean) / std.
transform = transforms.Compose([
    transforms.Resize(transform_config['resize']),
    transforms.ToTensor(),
    transforms.Normalize(mean=transform_config['normalize_mean'],
                         std=transform_config['normalize_std'])
])

# NOTE(review): absolute workspace paths; consider a single configurable dataset root.
train_image_dir = '/Workspace/sid-v2/computervision1/Classification_dataset_v3/images/train'
test_image_dir = '/Workspace/sid-v2/computervision1/Classification_dataset_v3/images/test'

# Seed torch's global RNG so DataLoader shuffling (and later draws from the same
# RNG, such as weight initialization) is reproducible on Restart & Run All.
seed = 42
torch.manual_seed(seed)

# Create datasets
training_image_dataset = generate_image_dataset(image_dir=train_image_dir, transform=transform)
test_image_dataset = generate_image_dataset(image_dir=test_image_dir, transform=transform)

# Test-set metrics are reported with the training set's class names, so both
# splits must map every label index to the same class folder.
assert test_image_dataset.class_name == training_image_dataset.class_name, (
    "Train/test class mappings differ: "
    f"{training_image_dataset.class_name} vs {test_image_dataset.class_name}"
)

# Create data loaders with configuration
batch_size = 32
train_image_loader = DataLoader(dataset=training_image_dataset, batch_size=batch_size, shuffle=True)
test_image_loader = DataLoader(dataset=test_image_dataset, batch_size=batch_size, shuffle=False)

# Prepare configuration dictionary for MLflow logging
data_config = {
    'train_data_path': train_image_dir,
    'test_data_path': test_image_dir,
    'batch_size': batch_size,
    'seed': seed,
    'transform_config': transform_config,
    'train_dataset_stats': training_image_dataset.get_stats(),
    'test_dataset_stats': test_image_dataset.get_stats()
}

print("Dataset loaded successfully!")
print(f"Training samples: {len(training_image_dataset)}")
print(f"Test samples: {len(test_image_dataset)}")
print(f"Classes: {training_image_dataset.class_name}")

In [0]:
# Sanity-check tensor shapes on a single batch; iterating the whole loader
# would decode every training image and print one line per batch.
images, labels = next(iter(train_image_loader))
print(images.shape, labels.shape)

#### Sample Visualization with Artifact Logging

In [0]:
import matplotlib.pyplot as plt
import numpy as np

def create_sample_visualization(data_loader, dataset, save_path="sample_images.png",
                                mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Plot up to 8 images from the first batch of ``data_loader`` and save the figure.

    The saved PNG is meant to be attached to an MLflow run as an artifact.

    Args:
        data_loader: Iterable of ``(images, labels)`` batches, where each image
            is a normalized (C, H, W) tensor.
        dataset: Object whose ``class_name`` dict maps label index -> class
            name; used for the subplot titles.
        save_path: File the figure is written to.
        mean: Per-channel mean used by ``transforms.Normalize``.
        std: Per-channel std used by ``transforms.Normalize``. The normalization
            is undone before display. Both defaults match the 0.5/0.5
            normalization used in this notebook.

    Returns:
        ``save_path``.
    """
    fig, axes = plt.subplots(2, 4, figsize=(12, 6))
    fig.suptitle('Sample Training Images', fontsize=16)

    images, labels = next(iter(data_loader))
    # Per-channel vectors broadcast over the trailing channel axis of an HWC image.
    mean_arr = np.asarray(mean, dtype=np.float32)
    std_arr = np.asarray(std, dtype=np.float32)

    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i >= len(images):
            # Batch smaller than the 2x4 grid (tiny dataset): leave the cell blank.
            continue
        img = np.transpose(images[i].numpy(), (1, 2, 0))  # CHW -> HWC for imshow
        img = np.clip(img * std_arr + mean_arr, 0, 1)      # undo Normalize
        ax.imshow(img)
        ax.set_title(f'{dataset.class_name[labels[i].item()]}')

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # free the figure once it has been displayed and saved
    return save_path

# Render one batch of training images; the saved PNG's path is kept so the
# training run can attach it as an MLflow artifact.
sample_viz_path = create_sample_visualization(
    data_loader=train_image_loader,
    dataset=training_image_dataset,
)

#### Model definition (structured for MLflow logging)

In [0]:
import torch
import torch.nn as nn
import torch.optim as optim

class CustomCNNModel(nn.Module):
    """
    Custom CNN for RGB image classification whose layer hyper-parameters are
    exposed for MLflow tracking.

    Architecture: three [Conv2d(3x3, stride 1, pad 1) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2)] blocks with 32/64/128 channels, then fully connected layers
    512 -> 128 -> ``num_classes``.

    Args:
        input_dim: Height and width of the square 3-channel input images.
        num_classes: Number of output logits.
    """
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes

        # Plain-dict copy of the layer hyper-parameters, returned by
        # get_architecture_config() for logging; keep in sync with the layers below.
        self.architecture_config = {
            'input_dim': input_dim,
            'num_classes': num_classes,
            'conv_layers': [
                {'out_channels': 32, 'kernel_size': 3, 'stride': 1, 'padding': 1},
                {'out_channels': 64, 'kernel_size': 3, 'stride': 1, 'padding': 1},
                {'out_channels': 128, 'kernel_size': 3, 'stride': 1, 'padding': 1}
            ],
            'fc_layers': [512, 128, num_classes]
        }

        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        # Flattened conv-feature size; sizes the first Linear layer below.
        self._to_linear = None
        self.get_conv_output(self.input_dim)

        self.fc_layers = nn.Sequential(
            nn.Linear(self._to_linear, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def get_conv_output(self, input_dim=128):
        """Compute and store in ``self._to_linear`` the flattened conv-feature size.

        Runs a dummy all-zeros batch through ``conv_layers``. The conv stack is
        put in eval mode for this pass, because in train mode BatchNorm would
        update its running mean/var and ``num_batches_tracked`` from the dummy
        input and skew the statistics before training starts. The previous
        train/eval mode is restored afterwards.

        Args:
            input_dim: Height/width of the square dummy input.

        Returns:
            The flattened feature size (also stored on ``self._to_linear``).
        """
        param_device = next(self.conv_layers.parameters()).device
        was_training = self.conv_layers.training
        self.conv_layers.eval()
        try:
            with torch.no_grad():
                dummy_input = torch.zeros(1, 3, input_dim, input_dim, device=param_device)
                output = self.conv_layers(dummy_input)
        finally:
            self.conv_layers.train(was_training)
        self._to_linear = torch.flatten(output, 1).size(1)
        return self._to_linear

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for a batch of 3-channel images."""
        x = self.conv_layers(x)
        x = torch.flatten(x, 1)
        return self.fc_layers(x)

    def get_architecture_config(self):
        """Return model architecture configuration for logging"""
        return self.architecture_config

# Build the network for 128x128 inputs and 3 classes, placed on the CPU.
device = torch.device("cpu")
model = CustomCNNModel(input_dim=128, num_classes=3).to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Model initialized successfully!")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## Training with MLflow tracking

In [0]:
import os
import mlflow

# Workspace path of this notebook (e.g. "/Users/your.username/Folder/MyNotebook")
# and the folder that contains it (e.g. "/Users/your.username/Folder").
notebook_path = (
    dbutils.entry_point.getDbutils().notebook().getContext().notebookPath().get()
)
folder_path = os.path.dirname(notebook_path)

# Experiments are tracked outside the notebook's folder: that folder is part of
# a Git repo, where the experiment cannot be created.
experiment_path = "/Workspace/sid-v2/experiment_tracker"
experiment_name = f"{experiment_path}/3class_cnn_classifier"

# Create the experiment on first use, then make it the active experiment.
if mlflow.get_experiment_by_name(experiment_name) is None:
    mlflow.create_experiment(experiment_name)
mlflow.set_experiment(experiment_name)

In [0]:
def _train_one_epoch(model, loader, criterion, optimizer, epoch):
    """Run one optimisation pass over ``loader`` and return the per-batch losses.

    Logs ``batch_loss`` to the active MLflow run every 10 batches, using the
    global step ``epoch * len(loader) + batch_idx`` so that curves from
    successive epochs continue along one axis.
    """
    model.train()
    batch_losses = []
    for batch_idx, (images, labels) in enumerate(loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        batch_losses.append(batch_loss)
        if batch_idx % 10 == 0:
            mlflow.log_metric("batch_loss", batch_loss, step=epoch * len(loader) + batch_idx)
    return batch_losses


def _evaluate(model, loader, criterion, num_classes):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns:
        (average loss per batch, overall accuracy in percent,
         per-class correct counts, per-class sample counts) — the last two as
         lists indexed by class label.
    """
    model.eval()
    total_loss = 0.0
    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            total_loss += criterion(outputs, labels).item()

            predicted = outputs.argmax(dim=1).cpu()
            labels_cpu = labels.cpu()
            # Vectorised per-class tallies instead of a per-sample Python loop.
            class_total += torch.bincount(labels_cpu, minlength=num_classes)
            class_correct += torch.bincount(labels_cpu[predicted == labels_cpu], minlength=num_classes)

    accuracy = 100 * int(class_correct.sum()) / int(class_total.sum())
    return total_loss / len(loader), accuracy, class_correct.tolist(), class_total.tolist()


def _log_training_plot(training_history, out_dir):
    """Plot epoch- and batch-level training loss, display it, and log it under ``plots/``."""
    fig, (ax_epoch, ax_batch) = plt.subplots(1, 2, figsize=(12, 4))
    panels = (
        (ax_epoch, 'epoch_losses', 'Training Loss per Epoch', 'Epoch'),
        (ax_batch, 'batch_losses', 'Training Loss per Batch', 'Batch'),
    )
    for ax, key, title, xlabel in panels:
        ax.plot(training_history[key])
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Loss')
        ax.grid(True)

    fig.tight_layout()
    plot_path = os.path.join(out_dir, "training_history.png")
    fig.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)
    mlflow.log_artifact(plot_path, "plots")


def train_model_with_mlflow():
    """
    Train the global ``model`` on ``train_image_loader``, evaluate it on
    ``test_image_loader`` and record everything in a new MLflow run.

    Logged to the run: training/transform/architecture/dataset parameters,
    learning rate, batch and epoch training loss, test loss and accuracy
    (overall and per class), the sample-image grid, a training-loss plot, the
    training history JSON, a ``state_dict`` file, a run summary JSON, and the
    model itself with an inferred signature, registered as
    ``image_classification_cnn``.

    Relies on notebook globals: ``model``, ``device``, ``batch_size``,
    ``transform_config``, ``data_config``, ``sample_viz_path``,
    ``train_image_loader``, ``test_image_loader``, ``training_image_dataset``.

    Returns:
        (model, run_summary): the trained model and a dict with the run id,
        model URI, final metrics, parameter count and timing.

    Raises:
        ValueError: if either data loader yields no batches.
    """
    # Check before opening a run, so an empty loader neither divides by zero
    # nor leaves a half-logged FAILED run behind.
    if len(train_image_loader) == 0 or len(test_image_loader) == 0:
        raise ValueError("train_image_loader and test_image_loader must each yield at least one batch")

    num_classes = model.num_classes
    start_time = datetime.now()

    with mlflow.start_run(run_name=f"cnn_training_{start_time.strftime('%Y%m%d_%H%M%S')}") as run:

        # NOTE(review): mlflow.pytorch.autolog targets PyTorch Lightning's
        # Trainer.fit; with this hand-written loop it is not expected to record
        # anything, so all logging below is explicit. log_models=False because
        # the model is logged manually at the end.
        mlflow.pytorch.autolog(log_models=False)

        training_params = {
            'model_type': 'CustomCNN',
            'optimizer': 'Adam',
            'learning_rate': 0.001,
            'batch_size': batch_size,
            'epochs': 2,
            'device': str(device),
            'input_size': model.input_dim,
            'num_classes': num_classes
        }

        mlflow.log_params(training_params)
        mlflow.log_params(transform_config)
        mlflow.log_params(model.get_architecture_config())
        mlflow.log_params(data_config['train_dataset_stats'])
        mlflow.log_dict(data_config, "data_configuration.json")

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=training_params['learning_rate'])
        epochs = training_params['epochs']

        mlflow.log_artifact(sample_viz_path, "visualizations")

        training_history = {
            'epoch_losses': [],
            'batch_losses': [],
            'learning_rates': []
        }

        for epoch in range(epochs):
            current_lr = optimizer.param_groups[0]['lr']
            mlflow.log_metric("learning_rate", current_lr, step=epoch)
            training_history['learning_rates'].append(current_lr)

            batch_losses = _train_one_epoch(model, train_image_loader, criterion, optimizer, epoch)
            avg_epoch_loss = sum(batch_losses) / len(batch_losses)
            training_history['epoch_losses'].append(avg_epoch_loss)
            training_history['batch_losses'].extend(batch_losses)

            mlflow.log_metric("epoch_loss", avg_epoch_loss, step=epoch)
            mlflow.log_metric("epoch", epoch + 1, step=epoch)
            print(f"Epoch {epoch+1}/{epochs}, Average Loss: {avg_epoch_loss:.4f}")

        avg_test_loss, test_accuracy, class_correct, class_total = _evaluate(
            model, test_image_loader, criterion, num_classes
        )
        final_train_loss = training_history['epoch_losses'][-1]

        mlflow.log_metric("test_accuracy", test_accuracy)
        mlflow.log_metric("test_loss", avg_test_loss)
        mlflow.log_metric("final_train_loss", final_train_loss)

        # Per-class accuracy, skipping classes absent from the test set.
        for class_idx in range(num_classes):
            if class_total[class_idx] > 0:
                class_acc = 100 * class_correct[class_idx] / class_total[class_idx]
                class_name = training_image_dataset.class_name[class_idx]
                mlflow.log_metric(f"accuracy_{class_name}", class_acc)

        # Intermediate files go to a temporary directory (removed afterwards)
        # instead of the current working directory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            _log_training_plot(training_history, tmp_dir)

            state_dict_path = os.path.join(tmp_dir, "model_state_dict.pth")
            torch.save(model.state_dict(), state_dict_path)
            mlflow.log_artifact(state_dict_path, "model_files")

        # Written as indented JSON at artifact path metrics/training_history.json.
        mlflow.log_dict(training_history, "metrics/training_history.json")

        # A random tensor is enough to infer the input/output schema.
        sample_input = torch.randn(1, 3, model.input_dim, model.input_dim, device=device)
        with torch.no_grad():
            sample_output = model(sample_input)
        sample_input_np = sample_input.cpu().numpy()
        signature = infer_signature(sample_input_np, sample_output.cpu().numpy())

        # NOTE(review): Unity Catalog registration normally needs a three-level
        # "catalog.schema.model" name and the "databricks-uc" registry URI;
        # confirm which registry this single-part name is meant for.
        model_info = mlflow.pytorch.log_model(
            pytorch_model=model,
            artifact_path="model",
            signature=signature,
            input_example=sample_input_np,
            registered_model_name="image_classification_cnn"
        )

        end_time = datetime.now()
        run_summary = {
            'run_id': run.info.run_id,
            'model_uri': model_info.model_uri,
            'final_test_accuracy': test_accuracy,
            'final_train_loss': final_train_loss,
            'total_parameters': sum(p.numel() for p in model.parameters()),
            'training_time': end_time.isoformat(),  # completion timestamp
            'training_duration_seconds': (end_time - start_time).total_seconds()
        }
        mlflow.log_dict(run_summary, "run_summary.json")

        print("Training completed!")
        print(f"Final Test Accuracy: {test_accuracy:.2f}%")
        print("Model registered as: image_classification_cnn")
        print(f"Run ID: {run.info.run_id}")

        return model, run_summary

# Kick off training, evaluation and model registration in one MLflow run,
# keeping the returned model and run summary in notebook variables.
trained_model, run_info = train_model_with_mlflow()