In [1]:
import numpy as np
import pandas as pd
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import matplotlib.pyplot as plt
import json
from torchvision import transforms


from model.EfficientNet import EfficientNet
from trainer.process_data import CatTripletDataset, load_cat_data
from trainer.evaluate import evaluate_model  # NEW: Import evaluation module

print(f'PyTorch version: {torch.__version__}')
print(f'CUDA available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'CUDA device: {torch.cuda.get_device_name(0)}')

Using device: cuda
PyTorch version: 2.8.0+cu128
CUDA available: True
CUDA device: NVIDIA GeForce RTX 3050 Laptop GPU


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Hyperparameters
CONFIG = {
    'BATCH_SIZE': 16,
    'EMBEDDING_DIM': 128,
    'NUM_EPOCHS': 50,
    'LEARNING_RATE': 1e-4,
    'SUBSET_SIZE': 100,  # Number of cats to use per epoch
    'TRIPLETS_PER_EPOCH': 1000,  # Number of triplets per epoch
    'VAL_SUBSET_SIZE': 50,
    'VAL_TRIPLETS': 500,
    'MARGIN': 1.0,  # Triplet loss margin
    'PATIENCE': 20,  # Early stopping patience
    'DATASET_PATH': './cat',
}
print("TRAINING CONFIGURATION")
print("-"*60)
for key, value in CONFIG.items():
    print(f"{key:<25} {value}")

Using device: cuda
TRAINING CONFIGURATION
------------------------------------------------------------
BATCH_SIZE                16
EMBEDDING_DIM             128
NUM_EPOCHS                50
LEARNING_RATE             0.0001
SUBSET_SIZE               100
TRIPLETS_PER_EPOCH        1000
VAL_SUBSET_SIZE           50
VAL_TRIPLETS              500
MARGIN                    1.0
PATIENCE                  20
DATASET_PATH              ./cat


In [3]:
print("\nLoading dataset...")
image_paths, labels = load_cat_data(CONFIG['DATASET_PATH'])

print(f"\nTotal images: {len(image_paths):,}")
print(f"Total unique cats: {len(set(labels)):,}")


Loading dataset...
Loading data from: ./cat
Found 164100 cat folders


Loading cats: 100%|██████████| 164100/164100 [00:14<00:00, 10969.03it/s]


Loaded 643539 images from 164100 unique cats

Total images: 643,539
Total unique cats: 164,100





In [4]:
unique_cats = sorted(list(set(labels)))
print(f'Total unique cats: {len(unique_cats):,}')

# Split cats 80/20
train_cat_ids, val_cat_ids = train_test_split(
    unique_cats, 
    test_size=0.2, 
    random_state=42
)

train_cat_ids = set(train_cat_ids)
val_cat_ids = set(val_cat_ids)

print(f'Train cats: {len(train_cat_ids):,}')
print(f'Val cats: {len(val_cat_ids):,}')

# Split images based on cat ID
train_paths, train_labels = [], []
val_paths, val_labels = [], []

for img_path, cat_id in zip(image_paths, labels):
    if cat_id in train_cat_ids:
        train_paths.append(img_path)
        train_labels.append(cat_id)
    else:
        val_paths.append(img_path)
        val_labels.append(cat_id)

print(f'\nTraining images: {len(train_paths):,}')
print(f'Validation images: {len(val_paths):,}')
print(f'Training cats: {len(set(train_labels)):,}')
print(f'Validation cats: {len(set(val_labels)):,}')

Total unique cats: 164,100
Train cats: 131,280
Val cats: 32,820

Training images: 515,159
Validation images: 128,380
Training cats: 131,280
Validation cats: 32,820


In [5]:
# Change the SIZE as needed
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                        std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                        std=[0.229, 0.224, 0.225])
])

In [6]:
train_dataset = CatTripletDataset(
    train_paths, 
    train_labels, 
    transform=train_transform,
    subset_size=CONFIG['SUBSET_SIZE'],
    triplets_per_epoch=CONFIG['TRIPLETS_PER_EPOCH']
)

val_dataset = CatTripletDataset(
    val_paths, 
    val_labels, 
    transform=val_transform,
    subset_size=CONFIG['VAL_SUBSET_SIZE'],
    triplets_per_epoch=CONFIG['VAL_TRIPLETS']
)
# DataLoaders
train_loader = DataLoader(
    train_dataset, 
    batch_size=CONFIG['BATCH_SIZE'], 
    shuffle=True, 
    num_workers=4,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset, 
    batch_size=CONFIG['BATCH_SIZE'], 
    shuffle=False, 
    num_workers=4,
    pin_memory=True
)

print(f'\n✓ Train batches: {len(train_loader):,}')
print(f'✓ Val batches: {len(val_loader):,}')

Total cats: 131280
Total images: 515159
Sampling from 100 cats...


Generating triplets: 100%|██████████| 1000/1000 [00:00<00:00, 136284.90it/s]


Total triplets generated: 1000
Total cats: 32820
Total images: 128380
Sampling from 50 cats...


Generating triplets: 100%|██████████| 500/500 [00:00<00:00, 173075.18it/s]

Total triplets generated: 500

✓ Train batches: 63
✓ Val batches: 32





In [7]:
model = EfficientNet(embedding_dim=CONFIG['EMBEDDING_DIM'])
model = model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("MODEL SUMMARY")
print(f'{"-"*60}')
print(f'Total parameters:     {total_params:,}')
print(f'Trainable parameters: {trainable_params:,}')
print(f'Model size:           {total_params * 4 / 1024 / 1024:.2f} MB')
print(f'Device:               {next(model.parameters()).device}')



MODEL SUMMARY
------------------------------------------------------------
Total parameters:     4,730,364
Trainable parameters: 4,730,364
Model size:           18.04 MB
Device:               cuda:0


In [8]:
criterion = nn.TripletMarginLoss(margin=CONFIG['MARGIN'], p=2)
optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['LEARNING_RATE'])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

print(f"\n✓ Loss: Triplet Margin Loss (margin={CONFIG['MARGIN']})")
print(f"✓ Optimizer: Adam (lr={CONFIG['LEARNING_RATE']})")
print(f"✓ Scheduler: StepLR (step=10, gamma=0.5)")


✓ Loss: Triplet Margin Loss (margin=1.0)
✓ Optimizer: Adam (lr=0.0001)
✓ Scheduler: StepLR (step=10, gamma=0.5)


In [9]:
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Train for one epoch"""
    model.train()
    total_loss = 0
    
    pbar = tqdm(loader, desc="Training")
    for anchor, positive, negative in pbar:
        # Move to device
        anchor = anchor.to(device)
        positive = positive.to(device)
        negative = negative.to(device)
        
        optimizer.zero_grad()
        
        # Forward pass
        anchor_emb = model(anchor)
        pos_emb = model(positive)
        neg_emb = model(negative)
        
        # Compute loss
        loss = criterion(anchor_emb, pos_emb, neg_emb)
        
        # Backward
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    return total_loss / len(loader)


@torch.no_grad()
def validate(model, loader, criterion, device):
    """Validate model"""
    model.eval()
    total_loss = 0
    
    pbar = tqdm(loader, desc="Validating")
    for anchor, positive, negative in pbar:
        anchor = anchor.to(device)
        positive = positive.to(device)
        negative = negative.to(device)
        
        anchor_emb = model(anchor)
        pos_emb = model(positive)
        neg_emb = model(negative)
        
        loss = criterion(anchor_emb, pos_emb, neg_emb)
        total_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    return total_loss / len(loader)

In [11]:
history = {
    'train_loss': [],
    'val_loss': [],
    'learning_rates': []
}

best_val_loss = float('inf')
patience_counter = 0

print("\n" + "="*60)
print("STARTING TRAINING")
print("="*60)

for epoch in range(CONFIG['NUM_EPOCHS']):
    print(f'\nEpoch {epoch+1}/{CONFIG["NUM_EPOCHS"]}')
    print('-' * 60)
    
    # Train
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    history['train_loss'].append(train_loss)
    
    # Validate
    val_loss = validate(model, val_loader, criterion, device)
    history['val_loss'].append(val_loss)
    
    # Learning rate
    current_lr = optimizer.param_groups[0]['lr']
    history['learning_rates'].append(current_lr)
    
    print(f'\nTrain Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | LR: {current_lr:.6f}')
    
    # Learning rate scheduler
    scheduler.step()
    
    # Early stopping check
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        # Save best model
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'config': CONFIG
        }, 'best_efficientnet_triplet.pth')
        print(f'✓ New best val loss: {best_val_loss:.4f} - Model saved!')
    else:
        patience_counter += 1
        print(f'No improvement for {patience_counter} epochs')
    
    if patience_counter >= CONFIG['PATIENCE']:
        print(f'\n⚠ Early stopping triggered after {epoch+1} epochs')
        break

print('\n' + "="*60)
print('✓ TRAINING COMPLETED!')


STARTING TRAINING

Epoch 1/50
------------------------------------------------------------


Training:   0%|          | 0/63 [00:01<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 3.81 GiB of which 1.94 MiB is free. Including non-PyTorch memory, this process has 3.79 GiB memory in use. Of the allocated memory 3.61 GiB is allocated by PyTorch, and 97.27 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)