In [7]:
#!/usr/bin/env python3
"""
Complete training script for Fashion-MNIST classification using MobileNetV2.
This script handles data preparation, model training, and model export to multiple formats.
"""

import os
import gc
import torch
import torch.nn as nn
import torch.optim as optim
import torch.onnx
import numpy as np
from torchvision import models, datasets, transforms
from PIL import Image
import json
import sys
from tqdm import tqdm
import argparse

# Memory management functions
def clear_memory():
    """Free unreferenced objects and release cached CUDA memory if a GPU is present.

    Python's cyclic garbage collector runs first so that unreachable tensors
    are actually freed; then PyTorch's caching allocator is asked to hand
    unused cached blocks back to the driver.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

def setup_environment():
    """Configure PyTorch for memory-conscious training and create working directories.

    - Turns off cuDNN benchmark (autotuning) mode, which the script does as a
      memory-efficiency measure.
    - On CUDA machines, limits the caching allocator's split size through
      ``PYTORCH_CUDA_ALLOC_CONF``. A value the user already exported is kept
      rather than overwritten.
    - Creates ``./data`` and ``./models`` relative to the current directory.
    """
    torch.backends.cudnn.benchmark = False
    if torch.cuda.is_available():
        # setdefault: respect an allocator configuration the user set explicitly.
        # NOTE(review): the allocator presumably reads this variable when it is
        # first initialised, so it only helps if no CUDA memory was allocated yet.
        os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'max_split_size_mb:64')

    for directory in ('./data', './models'):
        os.makedirs(directory, exist_ok=True)

def prepare_data(batch_size=16):
    """Download Fashion-MNIST and wrap both splits in DataLoaders.

    Each image is resized to 128x128, converted to a tensor and normalized
    with mean 0.5 / std 0.5. As a side effect the index-to-label mapping is
    written to ``index_to_name.json`` in the current working directory.

    Args:
        batch_size: Samples per batch for both the training and test loaders.

    Returns:
        ``(train_loader, test_loader, class_names)``. Only the training
        loader shuffles; both load in the main process (``num_workers=0``).

    Raises:
        Whatever the dataset download/construction or the JSON write raises;
        the error is printed first and then re-raised unchanged.
    """
    print("Setting up data preparation...")

    preprocessing = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])

    try:
        print("Downloading Fashion-MNIST dataset...")
        train_set = datasets.FashionMNIST(
            root='./data', train=True, transform=preprocessing, download=True)
        test_set = datasets.FashionMNIST(
            root='./data', train=False, transform=preprocessing, download=True)
        clear_memory()

        # Label order follows the Fashion-MNIST class indices 0..9.
        class_names = [
            'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
            'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot',
        ]
        index_to_name = dict(zip(map(str, range(len(class_names))), class_names))

        with open("index_to_name.json", "w") as mapping_file:
            json.dump(index_to_name, mapping_file)

        print(f"Dataset downloaded with {len(train_set)} training and {len(test_set)} test samples.")
        print(f"Class mapping created with {len(class_names)} classes.")

        train_loader = torch.utils.data.DataLoader(
            train_set, shuffle=True, batch_size=batch_size, num_workers=0)
        test_loader = torch.utils.data.DataLoader(
            test_set, shuffle=False, batch_size=batch_size, num_workers=0)
        return train_loader, test_loader, class_names
    except Exception as err:
        print(f"Error downloading dataset: {err}")
        raise

class LightweightModel(nn.Module):
    """MobileNetV2 adapted to 1-channel (grayscale) Fashion-MNIST images.

    The RGB stem convolution is replaced by a 1-input-channel convolution with
    the same geometry. The final linear layer is replaced by one that produces
    ``num_classes`` logits. The wrapped network is stored in ``self.model``,
    so ``state_dict`` keys keep their ``model.*`` names.

    Args:
        num_classes: Number of output logits.
        pretrained: If True (the default, matching previous behaviour), start
            from ImageNet weights. If False, build the architecture without
            downloading any weights.
    """

    def __init__(self, num_classes=10, pretrained=True):
        super().__init__()
        self.model = self._build_backbone(pretrained)

        # Replace the 3-channel stem with a 1-channel one of identical geometry.
        old_stem = self.model.features[0][0]
        new_stem = nn.Conv2d(
            1,
            old_stem.out_channels,
            kernel_size=old_stem.kernel_size,
            stride=old_stem.stride,
            padding=old_stem.padding,
            bias=False,
        )
        if pretrained:
            # Summing the RGB kernels makes the new stem respond to a gray image
            # exactly as the pretrained stem would to that image replicated on
            # R, G and B. The pretrained first-layer features are kept instead
            # of being discarded for a random init.
            with torch.no_grad():
                new_stem.weight.copy_(old_stem.weight.sum(dim=1, keepdim=True))
        self.model.features[0][0] = new_stem

        # Read the feature width from the existing head instead of hard-coding it.
        in_features = self.model.classifier[1].in_features
        self.model.classifier[1] = nn.Linear(in_features, num_classes)

    @staticmethod
    def _build_backbone(pretrained):
        """Create torchvision's MobileNetV2, preferring the ``weights=`` API when available."""
        weights_enum = getattr(models, "MobileNet_V2_Weights", None)
        if weights_enum is None:
            # Older torchvision (< 0.13) only understands the boolean flag.
            return models.mobilenet_v2(pretrained=pretrained)
        return models.mobilenet_v2(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)

    def forward(self, x):
        """Return class logits for a batch of 1-channel images."""
        return self.model(x)

def _train_one_epoch(model, loader, criterion, optimizer, device, epoch, log_interval):
    """Run one training pass over ``loader``, printing progress every ``log_interval`` batches."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for i, (images, labels) in enumerate(loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        # Cap the global gradient norm at 1.0 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        running_loss += loss.item()
        # detach() replaces the legacy ``.data`` access; argmax over dim 1 is the predicted class.
        predicted = outputs.detach().argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        if (i + 1) % log_interval == 0:
            # Loss is averaged over the last log_interval batches; accuracy is
            # cumulative over the epoch so far.
            print(f'Epoch {epoch+1}, Batch {i+1}: Loss: {running_loss/log_interval:.3f}, '
                  f'Accuracy: {100*correct/total:.2f}%')
            running_loss = 0.0

        # Periodically clear memory to avoid OOM errors.
        if (i + 1) % 10 == 0:
            clear_memory()


def _evaluate(model, loader, device):
    """Return ``(correct, total)`` prediction counts of ``model`` on ``loader`` without gradients."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct, total


def train_model(model, train_loader, test_loader, epochs=2, save_path="mobilenet_fashion_mnist.pth",
                log_interval=50):
    """Train ``model`` with Adam + StepLR, report validation accuracy per epoch, and save it.

    The model is moved to CUDA when available, otherwise CPU. After each epoch
    the learning rate is multiplied by 0.7. Accuracy is then measured on
    ``test_loader``. Once all epochs finish, the ``state_dict`` is written to
    ``save_path``; its parent directory is created if needed.

    Args:
        model: The ``nn.Module`` to train.
        train_loader: Iterable of ``(images, labels)`` training batches.
        test_loader: Iterable of ``(images, labels)`` validation batches. May be
            empty, in which case validation is reported as skipped.
        epochs: Number of passes over ``train_loader``.
        save_path: File the trained ``state_dict`` is saved to.
        log_interval: Print running loss/accuracy every this many batches.

    Returns:
        The trained model, left on the training device.

    Raises:
        ValueError: If ``log_interval`` is less than 1.
    """
    if log_interval < 1:
        raise ValueError(f"log_interval must be >= 1, got {log_interval}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training on: {device}")
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0005)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.7)

    for epoch in range(epochs):
        _train_one_epoch(model, train_loader, criterion, optimizer, device, epoch, log_interval)
        scheduler.step()

        val_correct, val_total = _evaluate(model, test_loader, device)
        if val_total:
            print(f'Validation Accuracy after Epoch {epoch+1}: {100*val_correct/val_total:.2f}%')
        else:
            # Avoid ZeroDivisionError when the validation loader yields nothing.
            print(f'Validation skipped after Epoch {epoch+1}: test loader is empty.')

        clear_memory()

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")

    return model

def export_for_torchserve(model, output_path="mobilenet_fashion_mnist.pt"):
    """Trace ``model`` with TorchScript and save it for TorchServe.

    The model is put in eval mode and traced with one random ``(1, 1, 128, 128)``
    example input. The input is created on the same device as the model's
    parameters, because tracing a CUDA-resident model with a CPU tensor fails
    with a device mismatch. Models without parameters are traced on CPU.

    Args:
        model: Module that accepts ``(N, 1, 128, 128)`` tensors.
        output_path: Destination file of the serialized TorchScript module.

    Returns:
        ``output_path``.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    example_input = torch.randn(1, 1, 128, 128, device=device)

    # torch.jit.trace records the ops executed for this input (it is not scripting).
    traced_model = torch.jit.trace(model, example_input)
    torch.jit.save(traced_model, output_path)
    print(f"Model exported for TorchServe to {output_path}")
    return output_path

def export_to_onnx(model, output_path="mobilenet_fashion_mnist.onnx"):
    """Export model to ONNX format"""
    model.eval()
    example_input = torch.randn(1, 1, 128, 128)
    
    torch.onnx.export(
        model,
        example_input,
        output_path,
        export_params=True,
        opset_version=12,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
    )
    
    # Verify ONNX model
    try:
        import onnx
        onnx_model = onnx.load(output_path)
        onnx.checker.check_model(onnx_model)
        print(f"Model exported to ONNX format at {output_path} and verified successfully.")
        return output_path, onnx_model
    except ImportError:
        print(f"Model exported to ONNX format at {output_path}, but couldn't verify (onnx package not found).")
        return output_path, None
    except Exception as e:
        print(f"Model exported to ONNX format at {output_path}, but verification failed: {e}")
        return output_path, None

def direct_tensorflow_export(model, class_names, output_dir="tensorflow_model"):
    """Build and save a Keras CNN whose layout loosely resembles the PyTorch model.

    NOTE: this is NOT a weight conversion. ``model`` is accepted only for API
    compatibility and none of its parameters are copied. The Keras network is
    randomly initialised.

    Side effects:
        * If TensorFlow cannot be imported, runs ``pip install tensorflow``
          with the current interpreter.
        * Creates ``output_dir`` and writes the saved model plus a
          ``test_tf_model.py`` inference script into it.

    Args:
        model: Trained PyTorch model (unused, see note above).
        class_names: Class labels; ``len(class_names)`` sets the number of logits.
        output_dir: Directory that receives the saved model and test script.

    Returns:
        Path of the saved Keras model, or ``None`` if any step failed. The
        path is ``<output_dir>/saved_model`` on Keras 2 and
        ``<output_dir>/saved_model.keras`` on Keras 3. On failure the error
        and traceback are printed.
    """
    try:
        try:
            import tensorflow as tf
            print(f"TensorFlow version: {tf.__version__}")
        except ImportError:
            # NOTE(review): installing at runtime mutates the user's environment
            # and needs network access; consider declaring tensorflow up front.
            print("TensorFlow not installed. Installing...")
            import subprocess
            subprocess.check_call([sys.executable, "-m", "pip", "install", "tensorflow"])
            import tensorflow as tf

        os.makedirs(output_dir, exist_ok=True)
        saved_model_path = os.path.join(output_dir, "saved_model")

        print("Creating TensorFlow model directly...")

        tf_model = tf.keras.Sequential([
            # tf.keras.Input works on both Keras 2 and Keras 3 (InputLayer's
            # ``input_shape`` argument does not) and builds the model at once.
            # Layout is NHWC: (height, width, channels).
            tf.keras.Input(shape=(128, 128, 1)),

            # Stem shaped like MobileNetV2's first conv: 32 filters, 3x3, stride 2, no bias.
            tf.keras.layers.Conv2D(32, kernel_size=3, strides=2, padding='same', use_bias=False),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.ReLU(),

            # Simplified body standing in for MobileNetV2's inverted residual blocks.
            tf.keras.layers.Conv2D(64, kernel_size=3, padding='same'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.ReLU(),
            tf.keras.layers.MaxPooling2D(pool_size=2),

            tf.keras.layers.Conv2D(128, kernel_size=3, padding='same'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.ReLU(),
            tf.keras.layers.MaxPooling2D(pool_size=2),

            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dense(len(class_names)),  # raw logits, one per class
        ])

        tf_model.compile(
            optimizer='adam',
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=['accuracy']
        )

        tf_model.summary()

        try:
            # Keras 2: an extension-less path is written as a SavedModel directory.
            tf_model.save(saved_model_path)
        except ValueError:
            # Keras 3 (TensorFlow >= 2.16) only accepts ".keras"/".h5" paths here.
            saved_model_path += ".keras"
            tf_model.save(saved_model_path)
        print(f"TensorFlow model saved to {saved_model_path}")

        # The model path is embedded with !r so quotes/backslashes (e.g. Windows
        # paths) still yield a valid Python literal. The doubled braces survive
        # str.format as the f-string braces of the generated script.
        test_script = """
import tensorflow as tf
import numpy as np
from PIL import Image
import json

def preprocess_image(image_path):
    # Load and preprocess image
    image = Image.open(image_path).convert('L')
    image = image.resize((128, 128))
    image = np.array(image).astype(np.float32) / 255.0
    image = (image - 0.5) / 0.5  # normalize to [-1, 1]
    # Return in NHWC format (batch, height, width, channels)
    return image.reshape(1, 128, 128, 1)

def predict(model_path, image_path):
    # Load model
    model = tf.keras.models.load_model(model_path)
    
    # Preprocess image
    input_tensor = preprocess_image(image_path)
    
    # Predict
    prediction = model.predict(input_tensor)
    
    # Get class with highest probability
    predicted_class = np.argmax(prediction, axis=1)[0]
    
    # Load class names
    with open("index_to_name.json", "r") as f:
        class_names = json.load(f)
    
    print(f"Predicted class: {{class_names[str(predicted_class)]}}") 
    return predicted_class, class_names[str(predicted_class)]

if __name__ == "__main__":
    import sys
    if len(sys.argv) > 1:
        image_path = sys.argv[1]
        model_path = {0!r}
        predict(model_path, image_path)
    else:
        print("Please provide an image path to test")
""".format(saved_model_path)

        script_path = os.path.join(output_dir, "test_tf_model.py")
        with open(script_path, "w") as f:
            f.write(test_script)

        print(f"Test script created at {script_path}")

        print("NOTE: This TensorFlow model has a similar architecture to the PyTorch model")
        print("but does not contain the exact trained weights. It uses randomly initialized weights.")
        print("For a full conversion with matching weights, specialized conversion would be needed.")

        return saved_model_path

    except Exception as e:
        print(f"Error creating TensorFlow model: {e}")
        import traceback
        traceback.print_exc()
        return None

def create_tensorflow_from_scratch(class_names, output_dir="tensorflow_model",
                                   epochs=2, train_subset=5000, test_subset=1000):
    """Train a small Keras CNN directly on Fashion-MNIST and save it.

    Images are preprocessed to match the PyTorch pipeline: scaled to [-1, 1],
    given a channel axis (NHWC) and resized to 128x128. Only the first
    ``train_subset`` training and ``test_subset`` test samples are used, to
    keep memory and run time small.

    Side effects:
        * If TensorFlow cannot be imported, runs ``pip install tensorflow``
          with the current interpreter.
        * Downloads Fashion-MNIST through ``tf.keras.datasets``.
        * Creates ``output_dir`` and writes the saved model plus a
          ``test_tf_model.py`` inference script into it.

    Args:
        class_names: Class labels; ``len(class_names)`` sets the number of logits.
        output_dir: Directory that receives the saved model and test script.
        epochs: Number of training epochs.
        train_subset: Number of leading training samples to use.
        test_subset: Number of leading test samples to use.

    Returns:
        Path of the saved Keras model, or ``None`` if any step failed. The
        path is ``<output_dir>/saved_model`` on Keras 2 and
        ``<output_dir>/saved_model.keras`` on Keras 3. On failure the error
        and traceback are printed.
    """
    try:
        try:
            import tensorflow as tf
            print(f"TensorFlow version: {tf.__version__}")
        except ImportError:
            # NOTE(review): installing at runtime mutates the user's environment
            # and needs network access; consider declaring tensorflow up front.
            print("TensorFlow not installed. Installing...")
            import subprocess
            subprocess.check_call([sys.executable, "-m", "pip", "install", "tensorflow"])
            import tensorflow as tf

        print("Creating and training a TensorFlow model directly on Fashion-MNIST...")

        os.makedirs(output_dir, exist_ok=True)
        saved_model_path = os.path.join(output_dir, "saved_model")

        (train_images, train_labels), (test_images, test_labels) = \
            tf.keras.datasets.fashion_mnist.load_data()

        # Slice first so only the samples actually used get converted and resized.
        train_images, train_labels = train_images[:train_subset], train_labels[:train_subset]
        test_images, test_labels = test_images[:test_subset], test_labels[:test_subset]

        def preprocess(images):
            """Scale uint8 28x28 images to [-1, 1], add a channel axis, resize to 128x128."""
            images = images.astype('float32') / 255.0
            images = (images - 0.5) / 0.5  # match PyTorch's Normalize(0.5, 0.5)
            images = images.reshape((-1, 28, 28, 1))
            # One batched resize instead of a Python loop over single images.
            return tf.image.resize(images, [128, 128]).numpy()

        print("Resizing training images...")
        train_images_resized = preprocess(train_images)
        print("Resizing test images...")
        test_images_resized = preprocess(test_images)

        model = tf.keras.Sequential([
            # tf.keras.Input works on both Keras 2 and Keras 3 (NHWC layout).
            tf.keras.Input(shape=(128, 128, 1)),

            tf.keras.layers.Conv2D(32, kernel_size=3, padding='same', activation='relu'),
            tf.keras.layers.MaxPooling2D(pool_size=2),

            tf.keras.layers.Conv2D(64, kernel_size=3, padding='same', activation='relu'),
            tf.keras.layers.MaxPooling2D(pool_size=2),

            tf.keras.layers.Conv2D(128, kernel_size=3, padding='same', activation='relu'),
            tf.keras.layers.MaxPooling2D(pool_size=2),

            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dropout(0.5),
            tf.keras.layers.Dense(len(class_names)),  # raw logits, one per class
        ])

        model.compile(
            optimizer='adam',
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=['accuracy']
        )

        print("Training TensorFlow model...")
        model.fit(
            train_images_resized,
            train_labels,
            epochs=epochs,
            batch_size=32,
            validation_data=(test_images_resized, test_labels)
        )

        test_loss, test_acc = model.evaluate(test_images_resized, test_labels)
        print(f"Test accuracy: {test_acc:.2f}")

        try:
            # Keras 2: an extension-less path is written as a SavedModel directory.
            model.save(saved_model_path)
        except ValueError:
            # Keras 3 (TensorFlow >= 2.16) only accepts ".keras"/".h5" paths here.
            saved_model_path += ".keras"
            model.save(saved_model_path)
        print(f"TensorFlow model saved to {saved_model_path}")

        # The model path is embedded with !r so quotes/backslashes (e.g. Windows
        # paths) still yield a valid Python literal. The doubled braces survive
        # str.format as the f-string braces of the generated script.
        test_script = """
import tensorflow as tf
import numpy as np
from PIL import Image
import json

def preprocess_image(image_path):
    # Load and preprocess image
    image = Image.open(image_path).convert('L')
    image = image.resize((128, 128))
    image = np.array(image).astype(np.float32) / 255.0
    image = (image - 0.5) / 0.5  # normalize to [-1, 1]
    # Return in NHWC format (batch, height, width, channels)
    return image.reshape(1, 128, 128, 1)

def predict(model_path, image_path):
    # Load model
    model = tf.keras.models.load_model(model_path)
    
    # Preprocess image
    input_tensor = preprocess_image(image_path)
    
    # Predict
    prediction = model.predict(input_tensor)
    
    # Get class with highest probability
    predicted_class = np.argmax(prediction, axis=1)[0]
    
    # Load class names
    with open("index_to_name.json", "r") as f:
        class_names = json.load(f)
    
    print(f"Predicted class: {{class_names[str(predicted_class)]}}") 
    return predicted_class, class_names[str(predicted_class)]

if __name__ == "__main__":
    import sys
    if len(sys.argv) > 1:
        image_path = sys.argv[1]
        model_path = {0!r}
        predict(model_path, image_path)
    else:
        print("Please provide an image path to test")
""".format(saved_model_path)

        script_path = os.path.join(output_dir, "test_tf_model.py")
        with open(script_path, "w") as f:
            f.write(test_script)

        print(f"Test script created at {script_path}")

        return saved_model_path

    except Exception as e:
        print(f"Error creating and training TensorFlow model: {e}")
        import traceback
        traceback.print_exc()
        return None

def main(argv=None):
    """Parse CLI options, train MobileNetV2 on Fashion-MNIST, and export it in several formats.

    Exports: a PyTorch ``state_dict`` (``.pth``), TorchScript (``.pt``), ONNX,
    and optionally a TensorFlow model (see ``--tf-method``). ``--output-dir``
    is created if it does not exist.

    Args:
        argv: Argument list to parse. ``None`` (the default) means
            ``sys.argv[1:]``. Pass an explicit list, e.g. ``main([])``, when
            ``sys.argv`` holds unrelated flags such as in an interactive kernel.

    Any exception during training or export is printed with its traceback, and
    the process then exits with status 1 so shells/CI can detect the failure.
    """
    parser = argparse.ArgumentParser(description="Train a MobileNetV2 model on Fashion-MNIST and export to multiple formats")
    parser.add_argument("--epochs", type=int, default=2, help="Number of epochs to train")
    parser.add_argument("--batch-size", type=int, default=16, help="Batch size for training")
    parser.add_argument("--output-dir", type=str, default=".", help="Directory to save models")
    parser.add_argument("--skip-tf", action="store_true", help="Skip TensorFlow export")
    parser.add_argument("--tf-method", type=str, default="direct", choices=["direct", "from_scratch"],
                        help="Method for TensorFlow export: 'direct' creates a similar architecture, 'from_scratch' trains a new model")
    args = parser.parse_args(argv)

    try:
        setup_environment()
        # Every export below writes into output_dir, so make sure it exists.
        os.makedirs(args.output_dir, exist_ok=True)

        train_loader, test_loader, class_names = prepare_data(batch_size=args.batch_size)

        print("Initializing model...")
        model = LightweightModel()

        print(f"Starting model training for {args.epochs} epochs...")
        pth_path = os.path.join(args.output_dir, "mobilenet_fashion_mnist.pth")
        trained_model = train_model(
            model,
            train_loader,
            test_loader,
            epochs=args.epochs,
            save_path=pth_path
        )

        print("Exporting model to different formats...")

        # 1. TorchScript for TorchServe
        torchscript_path = export_for_torchserve(
            trained_model,
            output_path=os.path.join(args.output_dir, "mobilenet_fashion_mnist.pt")
        )

        # 2. ONNX
        onnx_path, _ = export_to_onnx(
            trained_model,
            output_path=os.path.join(args.output_dir, "mobilenet_fashion_mnist.onnx")
        )

        # 3. TensorFlow (optional); both helpers return None on failure.
        tf_saved_model_path = None
        tf_output_dir = os.path.join(args.output_dir, "tensorflow_model")
        if args.skip_tf:
            print("Skipping TensorFlow conversion as requested.")
        elif args.tf_method == "direct":
            print("Creating TensorFlow model with similar architecture...")
            tf_saved_model_path = direct_tensorflow_export(
                trained_model,
                class_names,
                output_dir=tf_output_dir
            )
        else:  # from_scratch
            print("Creating and training a TensorFlow model from scratch...")
            tf_saved_model_path = create_tensorflow_from_scratch(
                class_names,
                output_dir=tf_output_dir
            )

        print("Training and export completed successfully!")
        print("Model saved in the following formats:")
        print(f"- PyTorch: {pth_path}")
        print(f"- TorchScript: {torchscript_path}")
        print(f"- ONNX: {onnx_path}")
        if tf_saved_model_path:
            if args.tf_method == "direct":
                print(f"- TensorFlow (similar architecture): {tf_saved_model_path}")
            else:  # from_scratch
                print(f"- TensorFlow (trained from scratch): {tf_saved_model_path}")
        elif not args.skip_tf:
            # The TF helpers swallow their own errors; surface that here too.
            print("- TensorFlow: export failed (see errors above)")

    except Exception as e:
        print(f"Error during training or export: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

# Entry point: run the full train-and-export pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()


Setting up data preparation...
Downloading Fashion-MNIST dataset...
Defining lightweight model...
Starting model training...
Training on device: cpu




Epoch 1, Batch 100: Loss: 1.624, Accuracy: 47.50%
Epoch 1, Batch 200: Loss: 1.398, Accuracy: 50.25%
Epoch 1, Batch 300: Loss: 1.204, Accuracy: 54.08%
Epoch 1, Batch 400: Loss: 0.974, Accuracy: 57.75%
Epoch 1, Batch 500: Loss: 0.954, Accuracy: 60.00%
Epoch 1, Batch 600: Loss: 1.061, Accuracy: 61.08%
Epoch 1, Batch 700: Loss: 0.909, Accuracy: 62.21%
Epoch 1, Batch 800: Loss: 0.877, Accuracy: 63.56%
Epoch 1, Batch 900: Loss: 0.862, Accuracy: 64.47%
Epoch 1, Batch 1000: Loss: 0.923, Accuracy: 65.22%
Epoch 1, Batch 1100: Loss: 0.880, Accuracy: 65.93%
Epoch 1, Batch 1200: Loss: 0.783, Accuracy: 66.56%
Epoch 1, Batch 1300: Loss: 0.773, Accuracy: 67.02%
Epoch 1, Batch 1400: Loss: 0.820, Accuracy: 67.48%
Epoch 1, Batch 1500: Loss: 0.791, Accuracy: 68.18%
Epoch 1, Batch 1600: Loss: 0.621, Accuracy: 69.03%
Epoch 1, Batch 1700: Loss: 0.711, Accuracy: 69.56%
Epoch 1, Batch 1800: Loss: 0.766, Accuracy: 69.97%
Epoch 1, Batch 1900: Loss: 0.626, Accuracy: 70.57%
Epoch 1, Batch 2000: Loss: 0.636, Accura