## Note
The easiest and fastest way to train the neural network is to use Google Colab. Upload this `train.ipynb` notebook to Google Colab, switch the runtime to a T4 GPU, and run all cells. Training and exporting the best model to ONNX takes approximately 10 minutes.

In [28]:
import numpy as np

from torch.utils.data import Dataset
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms, models
from io import BytesIO
from pathlib import Path

import torch
import torch.optim as optim
import torch.nn as nn
import torchvision.models as models

import os
import requests
import glob
import shutil
import zipfile
import re

In [21]:
!pip install onnxscript

Collecting onnxscript
  Downloading onnxscript-0.5.7-py3-none-any.whl.metadata (13 kB)
Collecting onnx_ir<2,>=0.1.12 (from onnxscript)
  Downloading onnx_ir-0.1.13-py3-none-any.whl.metadata (3.2 kB)
Collecting onnx>=1.16 (from onnxscript)
  Downloading onnx-1.20.0-cp312-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (8.4 kB)
Downloading onnxscript-0.5.7-py3-none-any.whl (693 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m693.4/693.4 kB[0m [31m17.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading onnx-1.20.0-cp312-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (18.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.1/18.1 MB[0m [31m126.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading onnx_ir-0.1.13-py3-none-any.whl (133 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m133.1/133.1 kB[0m [31m14.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: onnx, onnx_ir, onnxscript
Successfully install

In [27]:
# Record the library versions used for this run, in requirements-file style.
for package_name, package in (('numpy', np), ('pytorch', torch)):
    print(f'{package_name}=={package.__version__}')

numpy==2.0.2
pytorch==2.9.0+cu126


In [3]:
# ImageNet normalization statistics (one value per image channel), passed to
# transforms.Normalize in both the training and validation pipelines.
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

In [4]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Multi-class classification loss applied to the raw logits of the model.
criterion = nn.CrossEntropyLoss()

In [5]:
# Final steps shared by both pipelines: PIL image -> tensor -> ImageNet-normalized tensor.
_tensor_and_normalize = [
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
]

# NOTE(review): both pipelines take the central 224x224 crop *before* scaling to 256x256.
# The more common order is resize-then-crop — confirm this order is intended, and keep any
# inference-time preprocessing identical to `val_transforms`.

# Training pipeline: center crop followed by random augmentation.
train_transforms = transforms.Compose(
    [
        transforms.CenterCrop(224),
        transforms.RandomRotation(10),                        # rotate by up to 10 degrees
        transforms.RandomResizedCrop(256, scale=(0.9, 1.0)),  # mild random zoom, output 256x256
        transforms.RandomHorizontalFlip(),                    # random left/right mirror
    ]
    + _tensor_and_normalize
)

# Validation pipeline: deterministic, no augmentation.
val_transforms = transforms.Compose(
    [
        transforms.CenterCrop(224),
        transforms.Resize((256, 256)),
    ]
    + _tensor_and_normalize
)

In [6]:
class GemstoneDataset(Dataset):
    """Image dataset laid out as ``data_dir/<class name>/<image file>``.

    Every visible sub-directory of ``data_dir`` is one class; class indices are assigned
    in sorted order of the directory names. Hidden entries (names starting with ``.``,
    e.g. ``.ipynb_checkpoints`` or ``.DS_Store``) and anything that is not a regular file
    inside a class directory are ignored.

    Args:
        data_dir: Root directory containing one sub-directory per class.
        transform: Optional callable applied to the RGB PIL image in ``__getitem__``.

    Attributes set by ``__init__``: ``classes`` (sorted class names), ``class_to_idx``
    (name -> index), ``image_paths`` and ``labels`` (parallel lists).
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # Only real, non-hidden directories are classes; a stray file in data_dir would
        # otherwise become a bogus class and crash the listing below.
        self.classes = sorted(
            entry for entry in os.listdir(data_dir)
            if not entry.startswith('.') and os.path.isdir(os.path.join(data_dir, entry))
        )
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

        for label_name in self.classes:
            label_dir = os.path.join(data_dir, label_name)
            # os.listdir order is platform-dependent; sort so sample indices are reproducible.
            for img_name in sorted(os.listdir(label_dir)):
                img_path = os.path.join(label_dir, img_name)
                if img_name.startswith('.') or not os.path.isfile(img_path):
                    continue
                self.image_paths.append(img_path)
                self.labels.append(self.class_to_idx[label_name])

    def __len__(self):
        """Return the number of images found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; ``image`` is transformed if a transform is set."""
        img_path = self.image_paths[idx]
        # convert() loads the pixel data into a new image, so the file can be closed right away.
        with Image.open(img_path) as img_file:
            image = img_file.convert('RGB')
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

In [7]:
class GemstoneClassifierResnet101(nn.Module):
    """Frozen pre-trained ResNet-101 feature extractor with a small trainable head.

    Head: global average pooling -> Linear(features, size_inner) -> ReLU -> Dropout
    -> Linear(size_inner, num_classes). The forward pass returns raw logits.

    Args:
        size_inner: Width of the hidden linear layer in the head.
        droprate: Dropout probability applied after the hidden layer.
        num_classes: Number of output logits.

    NOTE(review): the backbone weights are frozen via ``requires_grad = False``, but calling
    ``.train()`` on this module still puts the backbone's BatchNorm layers in training mode,
    so their running statistics keep updating — confirm this is intended.
    """

    def __init__(self, size_inner=128, droprate=0.2, num_classes=87):
        super().__init__()

        # Load pre-trained ResNet-101 and drop its last two children (pooling + classifier),
        # keeping only the convolutional feature extractor.
        backbone = models.resnet101(weights='DEFAULT')
        feature_dim = backbone.fc.in_features  # 2048 for ResNet-101
        self.base_model = nn.Sequential(*list(backbone.children())[:-2])

        # Freeze base model parameters so only the head is trained.
        for param in self.base_model.parameters():
            param.requires_grad = False

        # Custom classification head.
        self.global_avg_pooling = nn.AdaptiveAvgPool2d((1, 1))
        self.inner = nn.Linear(feature_dim, size_inner)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(droprate)
        self.output_layer = nn.Linear(size_inner, num_classes)

    def forward(self, x):
        """Map a batch of images to class logits of shape ``(batch, num_classes)``."""
        x = self.base_model(x)
        x = self.global_avg_pooling(x)
        x = torch.flatten(x, 1)  # drop the 1x1 spatial dims left by the pooling
        x = self.inner(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.output_layer(x)
        return x

In [8]:
def make_model(learning_rate=0.001, size_inner=128, droprate=0.2, num_classes=87):
    """Build the classifier on the global ``device`` together with its Adam optimizer.

    Args:
        learning_rate: Learning rate for Adam.
        size_inner: Width of the hidden layer in the classification head.
        droprate: Dropout probability in the classification head.
        num_classes: Number of output classes (defaults to the 87 used by this notebook).

    Returns:
        Tuple ``(model, optimizer)``.
    """
    model = GemstoneClassifierResnet101(droprate=droprate, size_inner=size_inner, num_classes=num_classes)
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    return model, optimizer

In [9]:
def download_and_rename_dataset(url, destination_filename):
    """Stream ``url`` to ``destination_filename``, reporting problems instead of raising.

    The body is written to ``<destination_filename>.part`` first and moved into place only
    after the transfer completed, so an interrupted download never leaves a truncated file
    under the final name.

    Args:
        url: Address of the file to download.
        destination_filename: Path the downloaded file is saved as.

    Returns:
        ``True`` if the file was downloaded and saved, ``False`` otherwise (a message
        describing the failure is printed).
    """
    partial_path = f"{destination_filename}.part"
    try:
        # (connect timeout, read timeout) in seconds, so a stalled server cannot hang the cell.
        with requests.get(url, stream=True, timeout=(10, 60)) as r:
            r.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)
            with open(partial_path, 'wb') as f:
                # iter_content decodes any transfer/content encoding and reports mid-stream
                # failures as RequestException, which reading r.raw directly does not.
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
        os.replace(partial_path, destination_filename)
        print(f"Successfully downloaded and renamed file to: {os.path.abspath(destination_filename)}")
        return True
    except requests.exceptions.RequestException as e:
        print(f"An error occurred during download: {e}")
    except IOError as e:
        print(f"An error occurred while saving the file: {e}")

    # Failure path: best-effort removal of the incomplete temporary file.
    try:
        if os.path.exists(partial_path):
            os.remove(partial_path)
    except OSError:
        pass
    return False


In [10]:
def load_data(force_download=False):
    """Download the gemstone image archive (if needed) and extract it into ``./dataset``.

    A valid ``gemstones_images.zip`` already present in the working directory is reused,
    so re-running the notebook does not download the data again.

    Args:
        force_download: When true, download even if a valid archive is already present.

    Raises:
        FileNotFoundError: The archive is missing after the download attempt.
        zipfile.BadZipFile: The archive exists but is not a valid zip file.
    """
    data_url = 'https://www.kaggle.com/api/v1/datasets/download/lsind18/gemstones-images'
    new_name = 'gemstones_images.zip'

    # is_zipfile() is False for a missing file as well as for a corrupt one.
    if force_download or not zipfile.is_zipfile(new_name):
        download_and_rename_dataset(data_url, new_name)

    # The download helper only prints on failure, so verify the result explicitly here.
    if not os.path.isfile(new_name):
        raise FileNotFoundError(f"Dataset archive '{new_name}' is missing; the download did not succeed.")
    if not zipfile.is_zipfile(new_name):
        raise zipfile.BadZipFile(f"'{new_name}' is not a valid zip archive; the download is incomplete or corrupt.")

    with zipfile.ZipFile(new_name, 'r') as zip_ref:
        # Extract all contents to the specified directory
        # If files already exist in './dataset', they will be overwritten
        zip_ref.extractall('./dataset')

In [11]:
def train_and_evaluate(model, optimizer, train_loader, val_loader, criterion, num_epochs, device,
                       enable_checkout=False, max_gap=0.10, history=None):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    An epoch counts as a new "best" when its validation accuracy exceeds the best seen so
    far AND the absolute train/validation accuracy gap is below ``max_gap``. The best
    accuracy is tracked whether or not checkpoints are written.

    Args:
        model: Network to train (expected to already be on ``device``).
        optimizer: Optimizer updating the model's parameters.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        num_epochs: Number of passes over ``train_loader``.
        device: Device each batch is moved to.
        enable_checkout: When true, save ``model.state_dict()`` to a ``.pth`` file in the
            working directory for every new best epoch.
        max_gap: Largest accepted ``abs(train_acc - val_acc)`` for a best epoch.
        history: Optional list; one dict of per-epoch metrics (``epoch``, ``train_loss``,
            ``train_acc``, ``val_loss``, ``val_acc``, ``gap``) is appended per epoch.

    Returns:
        The best validation accuracy among epochs that met the gap criterion
        (0.0 if none did).

    Raises:
        ValueError: A loader yielded no samples.
    """
    best_val_accuracy = 0.0

    for epoch in range(num_epochs):
        # --- Training phase ---
        model.train()
        running_loss, correct, total, num_batches = 0.0, 0, 0, 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            num_batches += 1

        if total == 0:
            raise ValueError("train_loader yielded no samples")
        train_loss = running_loss / num_batches  # mean of the per-batch losses
        train_acc = correct / total

        # --- Validation phase ---
        model.eval()
        val_loss, val_correct, val_total, val_batches = 0.0, 0, 0, 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                val_loss += criterion(outputs, labels).item()
                predicted = outputs.argmax(dim=1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
                val_batches += 1

        if val_total == 0:
            raise ValueError("val_loader yielded no samples")
        val_loss /= val_batches
        val_acc = val_correct / val_total

        # Generalization gap: a smaller gap means train and validation accuracy agree.
        current_gap = abs(train_acc - val_acc)

        print(f'Epoch {epoch+1}/{num_epochs} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f} | Gap: {current_gap:.4f}')

        if history is not None:
            history.append({
                'epoch': epoch + 1,
                'train_loss': train_loss,
                'train_acc': train_acc,
                'val_loss': val_loss,
                'val_acc': val_acc,
                'gap': current_gap,
            })

        # --- Best-epoch tracking and optional checkpointing ---
        is_best_acc = val_acc > best_val_accuracy
        is_low_overfit = current_gap < max_gap

        if is_best_acc and is_low_overfit:
            best_val_accuracy = val_acc
            if enable_checkout:
                checkpoint_path = f'gemstone_classifier_model_ep{epoch+1}_acc{val_acc:.3f}_gap{current_gap:.3f}.pth'
                torch.save(model.state_dict(), checkpoint_path)
                print(f'--> Best model saved (High Acc & Low Overfit)')

    return best_val_accuracy

In [15]:
def train():
    """Train the gemstone classifier with the selected hyperparameters.

    Reads images from ``./dataset/train`` and ``./dataset/test`` (the latter is used as the
    validation set), trains for 10 epochs and writes a checkpoint for every new best epoch.

    Returns:
        The best validation accuracy reported by ``train_and_evaluate``.

    Raises:
        ValueError: The two splits disagree on their class-to-index mapping, or the data
            has more classes than the model has outputs.
    """
    num_epochs = 10
    # Number of outputs of the model built by make_model().
    num_classes = 87

    # the best learning_rate
    learning_rate = 0.001
    # the best size_inner
    size_inner = 128
    # The best drop rate
    droprate = 0.2

    train_dataset = GemstoneDataset(
        data_dir='./dataset/train',
        transform=train_transforms
    )

    val_dataset = GemstoneDataset(
        data_dir='./dataset/test',
        transform=val_transforms
    )

    # Labels are positions in each split's sorted class list; if the lists differ, the same
    # index means different gemstones in the two splits and validation accuracy is meaningless.
    if train_dataset.class_to_idx != val_dataset.class_to_idx:
        raise ValueError("Train and validation splits do not contain the same classes.")
    if len(train_dataset.classes) > num_classes:
        raise ValueError(
            f"Dataset has {len(train_dataset.classes)} classes but the model has only {num_classes} outputs."
        )

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    model, optimizer = make_model(
        learning_rate=learning_rate,
        size_inner=size_inner,
        droprate=droprate,
    )

    return train_and_evaluate(model, optimizer, train_loader, val_loader, criterion, num_epochs, device, enable_checkout=True)

In [17]:
def get_accuracy(filename):
    """Return the accuracy embedded in a checkpoint name as ``acc<digits>.<digits>``.

    Returns 0.0 when the name contains no such pattern.
    """
    found = re.search(r"acc(\d+\.\d+)", filename)
    if found is None:
        return 0.0
    return float(found.group(1))

In [24]:
def load_best_model():
    """Load the checkpoint with the highest accuracy in its filename, in eval mode.

    Checkpoints are the ``gemstone_classifier_model_ep*.pth`` files in the working
    directory; they are ranked by the accuracy parsed from the name by ``get_accuracy``.

    Returns:
        A ``GemstoneClassifierResnet101`` on the global ``device`` with the checkpoint loaded.

    Raises:
        FileNotFoundError: No checkpoint file exists in the working directory.
    """
    model_list = glob.glob('gemstone_classifier_model_ep*.pth')
    if not model_list:
        raise FileNotFoundError(
            "No checkpoint matching 'gemstone_classifier_model_ep*.pth' was found; run train() first."
        )
    best_model = max(model_list, key=get_accuracy)

    print(f"Loading best model: {best_model}")

    model = GemstoneClassifierResnet101()
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine;
    # weights_only restricts unpickling to tensors, which is all a state_dict contains.
    model.load_state_dict(torch.load(best_model, map_location=device, weights_only=True))
    model.to(device)
    model.eval()

    return model

In [13]:
def export_to_onnx(model, onnx_path="./gemstone_classifier_resnet101.onnx"):
    """Export ``model`` to an ONNX file with a dynamic batch dimension.

    The model is traced with a ``(1, 3, 256, 256)`` dummy input and exported with the
    tensor names ``input`` and ``output``. It is switched to eval mode for the export so
    dropout is inactive, and returned to training mode afterwards if it was in it.

    Args:
        model: The network to export.
        onnx_path: Destination file; missing parent directories are created.
    """
    # Put the dummy input on the model's own device; fall back to the global device
    # for a model without parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device
    dummy_input = torch.randn(1, 3, 256, 256, device=target_device)

    out_dir = os.path.dirname(onnx_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    was_training = model.training
    model.eval()
    try:
        torch.onnx.export(
            model,
            dummy_input,
            onnx_path,
            verbose=True,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={
                'input': {0: 'batch_size'},
                'output': {0: 'batch_size'}
            }
        )
    finally:
        if was_training:
            model.train()

    print(f"Model exported to {onnx_path}")

In [14]:
# Fetch and unpack the dataset into ./dataset, then run the training loop
# (train() writes a .pth checkpoint for every new best epoch).
load_data()
train()

Successfully downloaded and renamed file to: /content/gemstones_images.zip
Downloading: "https://download.pytorch.org/models/resnet101-cd907fc2.pth" to /root/.cache/torch/hub/checkpoints/resnet101-cd907fc2.pth


100%|██████████| 171M/171M [00:00<00:00, 193MB/s]


Epoch 1/10 | Train Acc: 0.1509 | Val Acc: 0.3361 | Gap: 0.1852
Epoch 2/10 | Train Acc: 0.4051 | Val Acc: 0.4959 | Gap: 0.0908
--> Best model saved (High Acc & Low Overfit)
Epoch 3/10 | Train Acc: 0.5368 | Val Acc: 0.6116 | Gap: 0.0748
--> Best model saved (High Acc & Low Overfit)
Epoch 4/10 | Train Acc: 0.6282 | Val Acc: 0.6253 | Gap: 0.0028
--> Best model saved (High Acc & Low Overfit)
Epoch 5/10 | Train Acc: 0.6716 | Val Acc: 0.6198 | Gap: 0.0517
Epoch 6/10 | Train Acc: 0.7132 | Val Acc: 0.6887 | Gap: 0.0245
--> Best model saved (High Acc & Low Overfit)
Epoch 7/10 | Train Acc: 0.7283 | Val Acc: 0.6832 | Gap: 0.0451
Epoch 8/10 | Train Acc: 0.7504 | Val Acc: 0.7025 | Gap: 0.0479
--> Best model saved (High Acc & Low Overfit)
Epoch 9/10 | Train Acc: 0.7882 | Val Acc: 0.6997 | Gap: 0.0884
Epoch 10/10 | Train Acc: 0.7987 | Val Acc: 0.6997 | Gap: 0.0989


IndexError: list index out of range

In [25]:
# Reload the saved checkpoint whose filename carries the highest accuracy.
model = load_best_model()

Loading best model: gemstone_classifier_model_ep8_acc0.702_gap0.048.pth


In [26]:
# Write the reloaded model to an ONNX file.
export_to_onnx(model)

  torch.onnx.export(


[torch.onnx] Obtain model graph for `GemstoneClassifierResnet101([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `GemstoneClassifierResnet101([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 209 of general pattern rewrite rules.
Model exported to ./models/gemstone_classifier_resnet101.onnx
