In [None]:
# Imports and environment setup
import os
import random
from PIL import Image, ImageFile
import pandas as pd
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
from torchvision import transforms, models
from torchvision.models import MobileNet_V2_Weights

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    confusion_matrix, precision_score, recall_score, f1_score, classification_report
)

ImageFile.LOAD_TRUNCATED_IMAGES = True


# Reproducibility: seed Python, NumPy and PyTorch (CPU and every visible GPU) with one value.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)


# Device configuration
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f"Device: {device}")

if not use_cuda:
    print("Not CUDA compatible, CPU will be used.")
else:
    # Report basic specs of GPU 0.
    gpu_name = torch.cuda.get_device_name(0)
    props = torch.cuda.get_device_properties(0)
    total_mem_gb = props.total_memory / (1024**3)

    print(f"GPU: {gpu_name}")
    print(f"Total memory: {total_mem_gb:.1f} GB")
    print(f"SM count: {props.multi_processor_count}")
    print(f"Compute Capability (SM): {props.major}.{props.minor}")

    # ================================
    #  PyTorch speed settings
    # ================================
    # TF32 trades a little matmul/convolution precision for throughput on GPUs that support it.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # benchmark=True lets cuDNN time candidate kernels and keep the fastest one.
    # NOTE: this may make GPU runs non-bit-reproducible even with the seeds above.
    torch.backends.cudnn.benchmark = True


In [None]:
# Dataset: Kaggle fruit detection dataset
# The dataset ships object-detection annotations (_annotations.csv); here we turn it into an
# image-classification task where a picture's label is its most frequently occurring object class.

class FruitsFromAnnotations(Dataset):
    """Image-classification view over a detection-style annotations CSV.

    Each image (``filename`` column) gets one label: the most frequent value of its
    ``class`` column. Ties are broken alphabetically so the labelling is deterministic.

    Args:
        images_dir: Directory containing the image files named in the CSV.
        annotations_csv: Path to a CSV with at least ``filename`` and ``class`` columns.
        transform: Optional callable applied to each PIL image in ``__getitem__``.
        class_to_idx: Optional fixed class-name -> index mapping (e.g. the training
            split's), so several splits share identical label indices. When omitted,
            the mapping is built from this CSV's image labels in alphabetical order.
            Images whose label is not in a supplied mapping are skipped and counted.

    Attributes:
        class_to_idx / idx_to_class: Label mapping in both directions.
        samples: List of ``(image_path, class_index)`` for images that exist on disk.
    """

    def __init__(self, images_dir: str, annotations_csv: str, transform=None, class_to_idx=None):
        self.images_dir = images_dir
        self.annotations_csv = annotations_csv
        self.transform = transform

        # Reading annotations; rows without a filename or class cannot contribute a label.
        df = pd.read_csv(annotations_csv).dropna(subset=['filename', 'class'])
        # Choose the most frequent class for each image (filename). Series.mode() returns all
        # modes in sorted order, so iloc[0] is a deterministic winner on ties.
        agg = (
            df.groupby('filename')['class']
              .agg(lambda s: s.mode().iloc[0])
              .reset_index()
        )

        # Class names -> indices: the caller's mapping if given, else alphabetical.
        if class_to_idx is None:
            classes = sorted(agg['class'].unique().tolist())
            class_to_idx = {c: i for i, c in enumerate(classes)}
        self.class_to_idx = dict(class_to_idx)
        self.idx_to_class = {i: c for c, i in self.class_to_idx.items()}

        # List of (image path, class index) tuples; annotated images missing on disk are ignored.
        self.samples = []
        skipped_unknown = 0
        for filename, cls in zip(agg['filename'], agg['class']):
            img_path = os.path.join(images_dir, filename)
            if not os.path.exists(img_path):
                continue
            if cls not in self.class_to_idx:
                skipped_unknown += 1
                continue
            self.samples.append((img_path, self.class_to_idx[cls]))

        print(f"Loaded samples: {len(self.samples)} | Classes: {len(self.class_to_idx)}")
        if skipped_unknown:
            print(f"Skipped {skipped_unknown} image(s) whose class is not in class_to_idx.")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)``; image is RGB PIL, or whatever ``transform`` produces."""
        path, label = self.samples[idx]
        # The context manager closes the file handle; convert() returns an in-memory copy.
        with Image.open(path) as src:
            img = src.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, label


In [None]:
# Transformations
# Standard ImageNet channel statistics (the normalisation torchvision's pretrained weights expect).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Evaluation: deterministic resize + normalisation only.
transform_eval = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Training: the same pipeline plus light random flip / colour jitter as augmentation.
_train_steps = [
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
]
transform_train = transforms.Compose(_train_steps)


In [None]:
# Datasets, dataloaders
# Resolve paths next to the script when run as a .py file, otherwise relative to the notebook's CWD.
ROOT = os.path.abspath(os.path.dirname(__file__)) if '__file__' in globals() else os.getcwd()
train_images = os.path.join(ROOT, 'archive', 'train')
valid_images = os.path.join(ROOT, 'archive', 'valid')
test_images = os.path.join(ROOT, 'archive', 'test')

train_csv = os.path.join(train_images, '_annotations.csv')
valid_csv = os.path.join(valid_images, '_annotations.csv')
# The test split is optional: None when its annotations file is absent.
_test_csv_candidate = os.path.join(test_images, '_annotations.csv')
test_csv = _test_csv_candidate if os.path.exists(_test_csv_candidate) else None

train_ds = FruitsFromAnnotations(train_images, train_csv, transform=transform_train)
valid_ds = FruitsFromAnnotations(valid_images, valid_csv, transform=transform_eval)

num_classes = len(train_ds.class_to_idx)
# Each dataset derives its own class -> index mapping, so labels are only comparable when the
# mappings are identical. Equal class *counts* alone could hide a name mismatch and silently
# misalign validation labels with the model's output indices.
assert valid_ds.class_to_idx == train_ds.class_to_idx, (
    "Train and valid class suite differs: "
    f"train={sorted(train_ds.class_to_idx)}, valid={sorted(valid_ds.class_to_idx)}"
)

batch_size = 64
pin = torch.cuda.is_available()
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=pin)
valid_loader = DataLoader(valid_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=pin)

print(f"Train batches: {len(train_loader)}, Valid batches: {len(valid_loader)}")


In [None]:
# Set up model: MobileNetV2 backbone initialised with torchvision's default pretrained weights.
pretrained_weights = MobileNet_V2_Weights.DEFAULT
model = models.mobilenet_v2(weights=pretrained_weights)

# Replace the final classifier layer with one sized for our fruit classes.
in_features = model.classifier[-1].in_features
model.classifier[-1] = nn.Linear(in_features, num_classes)
model = model.to(device)

# Loss and optimizer: all parameters are fine-tuned.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)


In [None]:
# Training loop
num_epochs = 6
train_losses, valid_losses = [], []
train_accs, valid_accs = [], []

# Predictions/labels of the most recent validation pass (used by the evaluation cell).
all_valid_preds = []
all_valid_labels = []


def _train_one_epoch(net, loader, loss_fn, opt, dev, log_every=10):
    """Run one optimisation pass over ``loader``.

    Args:
        net: Model to train (switched to train mode).
        loader: Iterable of ``(images, labels)`` batches supporting ``len()``.
        loss_fn: Loss assumed to use mean reduction (the nn.CrossEntropyLoss default).
        opt: Optimizer stepping ``net``'s parameters.
        dev: Device each batch is moved to.
        log_every: Print the batch loss every ``log_every`` batches and on the last one.

    Returns:
        ``(mean loss per sample, accuracy in percent)`` over the epoch.
    """
    net.train()
    n_batches = len(loader)
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    for batch_idx, (images, labels) in enumerate(loader):
        images = images.to(dev, non_blocking=True)
        labels = labels.to(dev, non_blocking=True)
        opt.zero_grad()
        outputs = net(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        opt.step()

        batch_n = labels.size(0)
        # loss is a per-batch mean; weight it by batch size so a smaller final batch
        # does not skew the epoch average.
        loss_sum += loss.item() * batch_n
        n_correct += (outputs.argmax(dim=1) == labels).sum().item()
        n_seen += batch_n

        if (batch_idx + 1) % log_every == 0 or (batch_idx + 1) == n_batches:
            print(f"  Train batch {batch_idx+1}/{n_batches} | loss={loss.item():.4f}")

    return loss_sum / max(1, n_seen), 100.0 * n_correct / max(1, n_seen)


def _evaluate_epoch(net, loader, loss_fn, dev):
    """Evaluate ``net`` on ``loader`` without gradient tracking.

    Returns:
        ``(mean loss per sample, accuracy in percent, predicted indices, true indices)``;
        the two lists hold Python ints in loader order.
    """
    net.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    preds_out, labels_out = [], []
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(dev, non_blocking=True)
            labels = labels.to(dev, non_blocking=True)
            outputs = net(images)
            batch_n = labels.size(0)
            loss_sum += loss_fn(outputs, labels).item() * batch_n
            preds = outputs.argmax(dim=1)
            n_correct += (preds == labels).sum().item()
            n_seen += batch_n
            preds_out.extend(preds.cpu().tolist())
            labels_out.extend(labels.cpu().tolist())
    return loss_sum / max(1, n_seen), 100.0 * n_correct / max(1, n_seen), preds_out, labels_out


for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    epoch_train_loss, epoch_train_acc = _train_one_epoch(model, train_loader, criterion, optimizer, device)
    train_losses.append(epoch_train_loss)
    train_accs.append(epoch_train_acc)

    # Eval; only the latest epoch's predictions are kept for the confusion matrix / report.
    epoch_valid_loss, epoch_valid_acc, all_valid_preds, all_valid_labels = _evaluate_epoch(
        model, valid_loader, criterion, device
    )
    valid_losses.append(epoch_valid_loss)
    valid_accs.append(epoch_valid_acc)

    print(f"Epoch ended | train_loss={epoch_train_loss:.4f}, train_acc={epoch_train_acc:.2f}% | "
          f"valid_loss={epoch_valid_loss:.4f}, valid_acc={epoch_valid_acc:.2f}%")


In [None]:
# Visualizing learning curves
def _plot_train_valid(ax, train_vals, valid_vals, title, ylabel, train_label, valid_label):
    """Draw one train-vs-valid curve pair on ``ax``, using 1-based epoch numbers on the x axis
    so they match the "Epoch k/N" training log."""
    ax.plot(range(1, len(train_vals) + 1), train_vals, marker='o', label=train_label)
    ax.plot(range(1, len(valid_vals) + 1), valid_vals, marker='s', label=valid_label)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.5)


fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))
_plot_train_valid(ax_loss, train_losses, valid_losses, 'Loss curves', 'Loss', 'Train Loss', 'Valid Loss')
_plot_train_valid(ax_acc, train_accs, valid_accs, 'Accuracy curves', 'Accuracy (%)', 'Train Acc', 'Valid Acc')
fig.tight_layout()
try:
    os.makedirs('GeneratedPhotos', exist_ok=True)
    fig.savefig('./GeneratedPhotos/learning_curves.png', dpi=150)
    print("Learning curves saved: learning_curves.png")
except Exception as e:
    print(f"Saving picture failed: {e}")
plt.show()


In [None]:
# Evaluation: Confusion Matrix, Precision, Recall, F1
class_names = [train_ds.idx_to_class[i] for i in range(num_classes)]
class_indices = list(range(num_classes))


def _save_and_show_cm(matrix, fmt, cmap, title, filename, saved_msg, fail_prefix, **heatmap_kwargs):
    """Draw a confusion-matrix heatmap, save it to ./GeneratedPhotos/<filename>, then show it.

    Save failures are reported as "<fail_prefix>: <error>" instead of raising, so the cell continues.
    Extra keyword arguments are forwarded to ``sns.heatmap`` (e.g. ``vmin``/``vmax``).
    """
    plt.figure(figsize=(10, 8))
    sns.heatmap(matrix, annot=True, fmt=fmt, cmap=cmap,
                xticklabels=class_names, yticklabels=class_names, **heatmap_kwargs)
    plt.title(title)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.tight_layout()
    try:
        # Create the folder here too, so this cell does not rely on an earlier cell having made it.
        os.makedirs('GeneratedPhotos', exist_ok=True)
        plt.savefig(f'./GeneratedPhotos/{filename}', dpi=150)
        print(saved_msg)
    except Exception as e:
        print(f"{fail_prefix}: {e}")
    plt.show()


assert len(all_valid_labels) == len(all_valid_preds), "The length of labels and predictions must be equal."
if len(all_valid_labels) == 0:
    print("Warning: No valid samples for evaluation.")
    print("The metrics cannot be calculated because there are no valid predictions.")
else:
    cm = confusion_matrix(all_valid_labels, all_valid_preds, labels=class_indices)
    _save_and_show_cm(cm, 'd', 'Blues', 'Confusion Matrix (Valid) - Counts',
                      'confusion_matrix_counts.png',
                      "Confusion matrix (counts) saved: confusion_matrix_counts.png",
                      "CM save failed")

    # Normalized CM over true-label rows, so the diagonal is per-class recall.
    cm_norm = confusion_matrix(all_valid_labels, all_valid_preds, labels=class_indices, normalize='true')
    _save_and_show_cm(cm_norm, '.2f', 'Greens', 'Confusion Matrix (Valid) – Normalized',
                      'confusion_matrix_normalized.png',
                      "Confusion matrix (normalized) saved: confusion_matrix_normalized.png",
                      "CM norm save failed", vmin=0, vmax=1)

    prec = precision_score(all_valid_labels, all_valid_preds, average='macro', zero_division=0)
    rec = recall_score(all_valid_labels, all_valid_preds, average='macro', zero_division=0)
    f1 = f1_score(all_valid_labels, all_valid_preds, average='macro', zero_division=0)
    print(f"Precision (macro): {prec:.4f}")
    print(f"Recall (macro):    {rec:.4f}")
    print(f"F1-score (macro):  {f1:.4f}")

    print("\nDetailed report:\n")
    # Pass labels explicitly: without it, classification_report raises ValueError when some
    # class appears in neither labels nor predictions (target_names would be too long).
    print(classification_report(all_valid_labels, all_valid_preds, labels=class_indices,
                                target_names=class_names, zero_division=0))



In [None]:
# Some valid samples visualized with predictions
def imshow_tensor(img_tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), ax=None):
    """Display a normalised image tensor after undoing the per-channel normalisation.

    Args:
        img_tensor: Image tensor laid out channels-first (it is permuted to H, W, C for display),
            normalised with ``mean``/``std``. It may live on any device; it is copied to CPU.
        mean, std: Per-channel statistics used for normalisation (ImageNet values by default,
            matching the transforms used in this notebook).
        ax: Optional matplotlib Axes to draw on; defaults to the current axes.
    """
    # Normalize computes (x - m') / s'; with m' = -mean/std and s' = 1/std this equals x * std + mean.
    inv_norm = transforms.Normalize(
        mean=[-m / s for m, s in zip(mean, std)],
        std=[1 / s for s in std],
    )
    img = inv_norm(img_tensor.cpu()).clamp(0, 1)
    target = plt.gca() if ax is None else ax
    target.imshow(img.permute(1, 2, 0).numpy())
    target.axis('off')

model.eval()
images_shown = 0
plt.figure(figsize=(12,8))
# Only the first validation batch is needed; None when the loader is empty.
first_batch = next(iter(valid_loader), None)
if first_batch is not None:
    images, labels = first_batch
    with torch.no_grad():
        preds = model(images.to(device, non_blocking=True)).argmax(dim=1).cpu()
    # 2x4 grid -> at most 8 samples. Title colour marks correct (green) vs wrong (red) predictions.
    for i in range(min(8, images.size(0))):
        plt.subplot(2, 4, images_shown + 1)
        imshow_tensor(images[i])
        pred_idx = preds[i].item()
        plt.title(f"Pred: {class_names[pred_idx]}",
                  color='green' if pred_idx == labels[i].item() else 'red')
        images_shown += 1
plt.tight_layout()
try:
    os.makedirs('GeneratedPhotos', exist_ok=True)
    plt.savefig('./GeneratedPhotos/val_samples.png', dpi=150)
    print("Samples saved: val_samples.png")
except Exception as e:
    print(f"Sample save failed: {e}")
plt.show()


## Quantization-Aware Training (QAT) with PyTorch

This optional section demonstrates how to run Quantization Aware Training (QAT) on the trained MobileNetV2 model using the FX graph mode quantization APIs in PyTorch. It will:

- Prepare a CPU copy of the trained model for QAT.
- Optionally run a short QAT fine-tuning loop on a small subset of the training data (at most 50 batches) to keep it fast.
- Convert the model to a quantized int8 model and evaluate accuracy on the validation set.
- Compare model size (saved state_dict files) and average CPU inference latency per batch, measured on a single sample batch.

Notes:
- QAT fine-tuning runs on CUDA if available (it uses fake-quant modules that work on GPU). INT8 conversion and inference remain on CPU (x86 fbgemm kernels).
- Deprecation/User warnings from legacy torch.ao.quantization APIs are suppressed here for cleaner output; consider migrating to torchao pt2e (prepare_pt2e/convert_pt2e) in the future.
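- The QAT fine-tuning loop is controlled by the `RUN_QAT` flag (currently set to `True` in the code below); set it to `False` to skip fine-tuning. Preparation, INT8 conversion, evaluation and saving still run either way, on a CPU copy of the trained model, so the main workflow is unaffected.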



In [13]:
import copy
import time
import warnings

# Import the FX QAT APIs; on any failure, report it and mark QAT as unavailable (tq = None).
try:
    import torch.ao.quantization as tq
    from torch.ao.quantization import get_default_qat_qconfig, QConfigMapping
    # Compatibility: in some PyTorch versions the FX APIs live in quantize_fx, in others in fx.
    try:
        from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx  # PyTorch >= 1.13/2.x
    except Exception:
        # Older layout. If this also fails, the error propagates to the outer handler below.
        from torch.ao.quantization.fx import prepare_qat_fx, convert_fx
except Exception as _qat_import_err:
    print(
        "Warning: QAT imports failed. FX QAT APIs not found in your torch build. "
        f"Details: {_qat_import_err}\n"
        "Tip: Ensure you're using a PyTorch version that provides torch.ao.quantization.quantize_fx or fx."
    )
    tq = None


# Switch for the QAT fine-tuning step (prepare/convert/evaluate run regardless).
RUN_QAT = True  # Set to True to run a brief QAT fine-tune and convert


def evaluate_top1(model_cpu, data_loader):
    """Top-1 accuracy (in percent) of ``model_cpu`` over ``data_loader``, computed on CPU.

    The model is put in eval mode and run without gradients. Returns 0.0 for an empty loader.
    """
    model_cpu.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_images, batch_labels in data_loader:
            logits = model_cpu(batch_images.to('cpu', non_blocking=True))
            hits = logits.argmax(dim=1) == batch_labels.to('cpu', non_blocking=True)
            n_correct += hits.sum().item()
            n_seen += batch_labels.size(0)
    return 100.0 * n_correct / max(1, n_seen)


def estimate_inference_time(model_cpu, data_loader, warmup=3, iters=10):
    """Average CPU forward-pass latency, in milliseconds per batch.

    Times ``iters`` forward passes on the first batch of ``data_loader`` after ``warmup``
    untimed passes, with gradients disabled and the model in eval mode.

    Returns:
        The average latency as a float, or None if the loader yields no batches.
    """
    model_cpu.eval()
    first = next(iter(data_loader), None)
    if first is None:
        return None
    images_it = first[0].to('cpu', non_blocking=True)

    with torch.no_grad():
        for _ in range(warmup):
            model_cpu(images_it)
        # perf_counter is monotonic and high-resolution. time.time() is wall-clock, can be
        # adjusted, and may have coarse resolution on some platforms.
        start = time.perf_counter()
        for _ in range(iters):
            model_cpu(images_it)
        elapsed = time.perf_counter() - start

    return elapsed * 1000.0 / max(1, iters)


# Only proceed if imports are available
def _qat_output_dir():
    """Return the ``QATmodels`` folder under ROOT (or the CWD if ROOT is undefined), creating it if needed."""
    base_root = ROOT if 'ROOT' in globals() else os.getcwd()
    qat_dir = os.path.join(base_root, 'QATmodels')
    os.makedirs(qat_dir, exist_ok=True)
    return qat_dir


def _qat_finetune(prepared, loader, max_steps=50, lr=5e-5):
    """Briefly fine-tune an FX-prepared QAT model.

    Trains on CUDA when available, otherwise CPU, for at most ``max_steps`` batches of
    ``loader`` with Adam + cross-entropy. Returns the model on the training device.
    """
    qat_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    prepared = prepared.to(qat_device)
    criterion_q = nn.CrossEntropyLoss()
    optimizer_q = optim.Adam(prepared.parameters(), lr=lr)
    prepared.train()

    step_count = 0
    print(f"Starting brief QAT fine-tuning on {qat_device}...")
    for images, labels in loader:
        images = images.to(qat_device, non_blocking=True)
        labels = labels.to(qat_device, non_blocking=True)
        optimizer_q.zero_grad()
        loss = criterion_q(prepared(images), labels)
        loss.backward()
        optimizer_q.step()
        step_count += 1
        if step_count % 10 == 0:
            print(f"  QAT step {step_count} | loss={loss.item():.4f}")
        if step_count >= max_steps:
            break
    print("QAT fine-tuning finished.")
    return prepared


def _report_eval(tag, acc, time_ms):
    """Print validation accuracy and, when available, the per-batch latency for one model variant."""
    if time_ms is None:
        print(f"[{tag}] Valid top-1 acc: {acc:.2f}%")
    else:
        print(f"[{tag}] Valid top-1 acc: {acc:.2f}% | Sample avg latency: {time_ms:.2f} ms/batch")


if tq is None:
    print("QAT section skipped because torch.ao.quantization is unavailable.")
else:
    # Suppress known deprecation and observer warnings from legacy torch.ao.quantization APIs
    try:
        warnings.filterwarnings(
            "ignore",
            category=DeprecationWarning,
            message=r".*torch\.ao\.quantization is deprecated and will be removed in 2\.10.*",
        )
        warnings.filterwarnings(
            "ignore",
            category=UserWarning,
            message=r".*reduce_range will be deprecated.*",
            module=r"torch\.ao\.quantization\.observer"
        )
    except Exception:
        pass

    # Create a CPU float copy of the trained model (the training model itself is left untouched)
    float_model_cpu = copy.deepcopy(model).to('cpu')
    float_model_cpu.eval()

    # Baseline (float) accuracy and timing
    try:
        baseline_acc = evaluate_top1(float_model_cpu, valid_loader)
        baseline_time_ms = estimate_inference_time(float_model_cpu, valid_loader)
        _report_eval("Baseline FP32", baseline_acc, baseline_time_ms)
    except Exception as e:
        print(f"Baseline evaluation failed: {e}")
        baseline_acc, baseline_time_ms = None, None

    # QAT preparation (FX graph mode)
    example_inputs = (torch.randn(1, 3, 224, 224),)
    qconfig = get_default_qat_qconfig('fbgemm')
    qconfig_mapping = QConfigMapping().set_global(qconfig)

    try:
        # Prepare from a separate deep copy. The QAT modules built by prepare_qat_fx may reuse the
        # input model's parameter/buffer objects (PyTorch's FX quantization tutorials deep-copy
        # first). Without the copy, fine-tuning could silently overwrite float_model_cpu, and the
        # "FP32" checkpoint saved below would hold QAT-tuned weights.
        prepared_qat = prepare_qat_fx(copy.deepcopy(float_model_cpu), qconfig_mapping, example_inputs)
        print("Model prepared for QAT (FX).")
    except Exception as e:
        print(f"QAT prepare failed: {e}")
        prepared_qat = None

    quantized_model = None
    if prepared_qat is not None:
        if RUN_QAT:
            # Brief QAT fine-tuning loop on a small subset
            prepared_qat = _qat_finetune(prepared_qat, train_loader)
        else:
            print("RUN_QAT is False → skipping QAT fine-tune (you can enable it to improve int8 accuracy).")

        # Convert to quantized model
        try:
            # INT8 conversion/inference happens on CPU (fbgemm kernels)
            prepared_qat = prepared_qat.to('cpu')
            quantized_model = convert_fx(prepared_qat)
            quantized_model.eval()
            print("Converted to quantized int8 model.")
        except Exception as e:
            print(f"Quantized convert failed: {e}")

    if quantized_model is not None:
        try:
            q_acc = evaluate_top1(quantized_model, valid_loader)
            q_time_ms = estimate_inference_time(quantized_model, valid_loader)
            _report_eval("Quantized INT8", q_acc, q_time_ms)
        except Exception as e:
            print(f"Quantized evaluation failed: {e}")

        # Size comparison via saved state_dicts
        try:
            qat_dir = _qat_output_dir()
            fp32_path = os.path.join(qat_dir, 'model_fp32_state_dict.pth')
            int8_path = os.path.join(qat_dir, 'model_int8_qat_fx_state_dict.pth')
            torch.save(float_model_cpu.state_dict(), fp32_path)
            torch.save(quantized_model.state_dict(), int8_path)
            fp32_size = os.path.getsize(fp32_path) / (1024**2)
            int8_size = os.path.getsize(int8_path) / (1024**2)
            print(f"Saved FP32 state_dict: {fp32_path} ({fp32_size:.2f} MB)")
            print(f"Saved INT8 state_dict: {int8_path} ({int8_size:.2f} MB)")
        except Exception as e:
            print(f"Saving state_dicts failed: {e}")

        # Optional: save a TorchScript version (may fail depending on ops)
        try:
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore",
                    category=UserWarning,
                    message=r"The TorchScript type system doesn't support instance-level annotations.*",
                    module=r"torch\.jit\._check"
                )
                scripted = torch.jit.script(quantized_model)
            ts_path = os.path.join(_qat_output_dir(), 'model_int8_qat_fx_scripted.pt')
            scripted.save(ts_path)
            print(f"Saved scripted INT8 model: {ts_path}")
        except Exception as e:
            print(f"Saving scripted INT8 model failed: {e}")


[Baseline FP32] Valid top-1 acc: 79.83% | Sample avg latency: 438.30 ms/batch
Model prepared for QAT (FX).
Starting brief QAT fine-tuning on cuda...
  QAT step 10 | loss=1.0398
  QAT step 20 | loss=1.0956
  QAT step 30 | loss=0.8959
  QAT step 40 | loss=0.7308
  QAT step 50 | loss=0.5603
QAT fine-tuning finished.
Converted to quantized int8 model.
[Quantized INT8] Valid top-1 acc: 74.33% | Sample avg latency: 97.98 ms/batch
Saved FP32 state_dict: C:\Dev\AI\Fruit-Classification\QATmodels\model_fp32_state_dict.pth (8.79 MB)
Saved INT8 state_dict: C:\Dev\AI\Fruit-Classification\QATmodels\model_int8_qat_fx_state_dict.pth (2.54 MB)
Saved scripted INT8 model: C:\Dev\AI\Fruit-Classification\QATmodels\model_int8_qat_fx_scripted.pt
