In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image

# Custom dataset class for loading banana freshness images
class BananaFreshnessDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.images = []
        self.labels = []

        # Ensure the directory exists
        if not os.path.exists(data_dir):
            raise ValueError(f"Directory {data_dir} does not exist. Please check the path.")

        # Load images and labels
        for day in os.listdir(data_dir):
            day_path = os.path.join(data_dir, day)
            # Validate directory names (day1, day2, etc.)
            if os.path.isdir(day_path) and day.startswith('day') and day[3:].isdigit():
                day_label = int(day.replace('day', ''))  # Extract numeric label
                for filename in os.listdir(day_path):
                    if filename.endswith(('.jpg', '.png')):  # Support both .jpg and .png
                        self.images.append(os.path.join(day_path, filename))
                        self.labels.append(day_label)

        # If no images are found, raise an error
        if len(self.images) == 0:
            raise ValueError(f"No images found in {data_dir}. Please check the directory structure and file formats.")
        
        print(f"Loaded {len(self.images)} images from {data_dir}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = float(self.labels[idx])  # Convert label to float

        try:
            img = Image.open(img_path).convert('RGB')  # Convert image to RGB format
        except Exception as e:
            raise RuntimeError(f"Failed to load image at {img_path}: {e}")

        if self.transform:
            img = self.transform(img)
        
        return img, label

# Data transformations
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to 224x224 for GoogLeNet
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load datasets
data_dir = 'data'  # Path to the main data folder
try:
    train_dataset = BananaFreshnessDataset(data_dir, transform=data_transforms)
    val_dataset = BananaFreshnessDataset(data_dir, transform=data_transforms)  # Adjust if you have a separate validation set
except ValueError as e:
    print(e)
    raise SystemExit

# Create dataloaders
dataloaders = {
    'train': DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4),
    'val': DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4),
}

dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}

# Load and modify GoogLeNet model
model = models.googlenet(weights='IMAGENET1K_V1')
model.fc = nn.Linear(model.fc.in_features, 1)  # Single output for regression (freshness index)

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Loss function and optimizer for regression
criterion = nn.MSELoss()  # Mean Squared Error for regression
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adjust learning rate as needed

# Training loop
def train_model(model, criterion, optimizer, num_epochs=25):
    best_model_wts = model.state_dict()
    best_mse = float('inf')

    for epoch in range(num_epochs):
        print(f"Epoch {epoch}/{num_epochs - 1}")
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_mse = 0.0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.float().unsqueeze(1).to(device)  # Reshape labels for regression

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_mse += ((outputs - labels) ** 2).sum().item()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_mse = running_mse / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} MSE: {epoch_mse:.4f}')

            # Save the best model
            if phase == 'val' and epoch_mse < best_mse:
                best_mse = epoch_mse
                best_model_wts = model.state_dict()

    print(f'Best val MSE: {best_mse:.4f}')
    model.load_state_dict(best_model_wts)
    return model

# Train the model
num_epochs = 25
model = train_model(model, criterion, optimizer, num_epochs=num_epochs)

# Inference example function
def predict(model, image):
    model.eval()
    with torch.no_grad():
        image = image.to(device)
        output = model(image.unsqueeze(0))  # Add batch dimension
        return output.item()

# Example usage of inference (replace with actual image loading code)
# image_path = 'path/to/your/image.jpg'
# image = Image.open(image_path).convert('RGB')
# image = data_transforms(image)
# freshness_index = predict(model, image)
# print(f"Predicted Freshness Index: {freshness_index:.2f}")

Loaded 1260 images from data
Loaded 1260 images from data
Epoch 0/24
----------
train Loss: 2.4987 MSE: 2.4987
val Loss: 0.3846 MSE: 0.3846
Epoch 1/24
----------
train Loss: 0.2373 MSE: 0.2373
val Loss: 0.0673 MSE: 0.0673
Epoch 2/24
----------
train Loss: 0.2660 MSE: 0.2660
val Loss: 0.6076 MSE: 0.6076
Epoch 3/24
----------
train Loss: 0.1753 MSE: 0.1753
val Loss: 0.0670 MSE: 0.0670
Epoch 4/24
----------
train Loss: 0.1609 MSE: 0.1609
val Loss: 0.1045 MSE: 0.1045
Epoch 5/24
----------
train Loss: 0.1378 MSE: 0.1378
val Loss: 0.1519 MSE: 0.1519
Epoch 6/24
----------
train Loss: 0.1903 MSE: 0.1903
val Loss: 0.1609 MSE: 0.1609
Epoch 7/24
----------
train Loss: 0.1983 MSE: 0.1983
val Loss: 0.0794 MSE: 0.0794
Epoch 8/24
----------
train Loss: 0.1310 MSE: 0.1310
val Loss: 0.1672 MSE: 0.1672
Epoch 9/24
----------
train Loss: 0.1864 MSE: 0.1864
val Loss: 0.4958 MSE: 0.4958
Epoch 10/24
----------
train Loss: 0.1004 MSE: 0.1004
val Loss: 0.0707 MSE: 0.0707
Epoch 11/24
----------
train Loss: 0.17

In [2]:
image_path = 'data/day2/IMG20240916123330.jpg'
image = Image.open(image_path).convert('RGB')
image = data_transforms(image)
freshness_index = predict(model, image)
print(f"Predicted Freshness Index: {freshness_index/7:.2f}")

Predicted Freshness Index: 0.28


In [3]:
# Saving the trained model
def save_model(model, file_path='banana_freshness_model.pth'):
    torch.save(model.state_dict(), file_path)
    print(f"Model saved to {file_path}")

# Example usage after training
save_model(model, 'banana_freshness_model.pth')

Model saved to banana_freshness_model.pth
