In [21]:
import torch
import os
from utils import *
import torch.nn as nn
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights


def load_model(num_classes=5):
    """Build a pretrained EfficientNet-B0 with a freshly initialised classifier head.

    Args:
        num_classes: Number of outputs of the replacement classification
            layer. Defaults to 5, the value that was previously hard-coded.

    Returns:
        The model, moved to CUDA when it is available and to the CPU otherwise.
    """
    # --- 1. Load Pretrained EfficientNet ---
    weights = EfficientNet_B0_Weights.DEFAULT
    model = efficientnet_b0(weights=weights)

    # Replace the final linear layer so the head emits `num_classes` logits
    # while keeping the same input width as the pretrained head.
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features, num_classes)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return model.to(device)


def train(model, train_loader, val_loader, preprocess_method):
    """Fine-tune `model` for a fixed number of epochs and save its weights.

    Each epoch runs one pass over `train_loader` with Adam (lr=1e-4) and
    cross-entropy loss, prints training metrics, evaluates on `val_loader`
    via `evaluate`, and re-plots the accuracy/loss history. After the last
    epoch the state dict is written to
    ``efficientNet/<preprocess_method.__name__>.pth``.

    Args:
        model: Network to train; it is expected to already live on the device
            selected here (CUDA when available, else CPU).
        train_loader: Iterable yielding ``(images, labels)`` batches.
        val_loader: Loader handed to `evaluate` after every epoch.
        preprocess_method: Object whose ``__name__`` names the checkpoint
            file. It must expose ``__name__`` (``None`` will fail at save time).
    """
    criterion = nn.CrossEntropyLoss()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # --- 6. Training Loop ---
    num_epochs = 5
    val_accs = []
    val_losses = []
    train_accs = []
    train_losses = []

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 30)

        model.train()
        # Two separate accumulators: the original added both quantities to a
        # single variable, which inflated the reported training loss.
        sample_loss_sum = 0.0  # batch-mean loss weighted by batch size
        batch_loss_sum = 0.0   # plain sum of batch-mean losses
        num_batches = 0
        correct = 0
        total = 0

        for images, labels in tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            batch_loss_sum += batch_loss
            # The criterion averages over the batch, so weight by batch size
            # to obtain a per-sample average once divided by `total`.
            sample_loss_sum += batch_loss * images.size(0)
            num_batches += 1

            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        # Guard the divisions so an empty loader reports zeros instead of
        # raising ZeroDivisionError.
        train_loss = sample_loss_sum / total if total else 0.0
        train_acc = correct / total if total else 0.0
        mean_batch_loss = batch_loss_sum / num_batches if num_batches else 0.0
        train_accs.append(train_acc)
        train_losses.append(train_loss)
        print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}')
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {mean_batch_loss:.4f}")

        # --- 7. Evaluation (optional) ---
        acc, loss, precision, recall, f1 = evaluate(model, val_loader)
        val_accs.append(acc)
        val_losses.append(loss)

        print(f'Val Accuracy: {acc:.4f}')
        print(f'Val Loss: {loss:.4f}')
        print(f'Val Precision: {precision:.4f}')
        print(f'Val Recall: {recall:.4f}')
        print(f'Val F1 Score: {f1:.4f}')

        # Re-plot the history after every epoch so progress is visible
        # while training is still running.
        plot_accuracy_and_loss(train_accs, train_losses, val_accs, val_losses)

    # exist_ok avoids the race between checking for and creating the directory.
    os.makedirs('efficientNet', exist_ok=True)

    torch.save(model.state_dict(), f'efficientNet/{preprocess_method.__name__}.pth')


def test(model, test_loader1, test_loader2, preprocess_method):
    """Load the saved weights for `preprocess_method` and report test metrics.

    The checkpoint ``efficientNet/<preprocess_method.__name__>.pth`` is loaded
    into `model`, which is then evaluated with `evaluate` on both loaders.
    Results are printed under the "Aptos" heading for `test_loader1` and the
    "DR" heading for `test_loader2`.

    Args:
        model: Network whose architecture matches the saved state dict.
        test_loader1: Loader for the first test set.
        test_loader2: Loader for the second test set.
        preprocess_method: Object whose ``__name__`` names the checkpoint file.
    """
    checkpoint_path = f'efficientNet/{preprocess_method.__name__}.pth'

    # Map the stored tensors onto the device the model currently lives on so
    # that a checkpoint saved on a GPU can still be loaded on a CPU-only host.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.load_state_dict(torch.load(checkpoint_path, map_location=target_device))
    model.eval()

    acc, loss, precision, recall, f1 = evaluate(model, test_loader1)
    acc2, loss2, precision2, recall2, f12 = evaluate(model, test_loader2)

    print(f'Metrics for Aptos Dataset')
    print(f'Test Accuracy: {acc:.4f}')
    print(f'Test Loss: {loss:.4f}')
    print(f'Test  Precision: {precision:.4f}')
    print(f'Test  Recall: {recall:.4f}')
    print(f'Test  F1 Score: {f1:.4f}')

    print(f'Metrics for DR Dataset')
    print(f'Test Accuracy: {acc2:.4f}')
    print(f'Test Loss: {loss2:.4f}')
    print(f'Test  Precision: {precision2:.4f}')
    print(f'Test  Recall: {recall2:.4f}')
    print(f'Test  F1 Score: {f12:.4f}')


def main(preprocess_method=None):
    """Run the whole pipeline: load data, build the model, train, then test.

    Args:
        preprocess_method: Preprocessing callable forwarded to
            `get_data_loaders`, `train` and `test`. `train` and `test` build
            the checkpoint path from ``preprocess_method.__name__``, so the
            default ``None`` fails at that point; pass a named callable.
    """
    # Load training, validation, and test data
    full_train_balanced, full_val, test_df1, test_df2 = load_data()

    # Build the data loaders, applying `preprocess_method`.
    # NOTE(review): 16 is presumably the batch size - confirm against
    # get_data_loaders' signature.
    train_loader, val_loader, test_loader1, test_loader2, test_dataset1, test_dataset2 = get_data_loaders(full_train_balanced, full_val, test_df1, test_df2, 16, preprocess_method)

    # Load efficientNet
    model = load_model()

    # Train model
    train(model, train_loader, val_loader, preprocess_method)

    # Test model. `test` requires preprocess_method to locate the checkpoint;
    # the original call omitted it and raised TypeError after training.
    test(model, test_loader1, test_loader2, preprocess_method)


In [22]:
# Run the full pipeline (data loading, training, testing) with the
# gaussian_subtractive_normalization preprocessing method.
main(preprocess_method=gaussian_subtractive_normalization)

Epoch 1/5
------------------------------


  3%|▎         | 47/1563 [02:55<1:31:46,  3.63s/it]

In [7]:
# Evaluate the checkpoint saved as 'efficientNet/regular.pth' on both test sets.
# NOTE(review): this cell reads `model`, `test_loader1` and `test_loader2`
# from the kernel namespace; the cells shown only create them as locals inside
# main() - confirm where they are defined before a Restart & Run All.
regular_state = torch.load('efficientNet/regular.pth')
model.load_state_dict(regular_state)
model.eval()

# Run both evaluations before printing so the progress bars come first.
acc, loss, precision, recall, f1 = evaluate(model, test_loader1)
acc2, loss2, precision2, recall2, f12 = evaluate(model, test_loader2)

print(
    'Metrics for Aptos Dataset',
    f'Test Accuracy: {acc:.4f}',
    f'Test Loss: {loss:.4f}',
    f'Test  Precision: {precision:.4f}',
    f'Test  Recall: {recall:.4f}',
    f'Test  F1 Score: {f1:.4f}',
    sep='\n',
)

print(
    'Metrics for DR Dataset',
    f'Test Accuracy: {acc2:.4f}',
    f'Test Loss: {loss2:.4f}',
    f'Test  Precision: {precision2:.4f}',
    f'Test  Recall: {recall2:.4f}',
    f'Test  F1 Score: {f12:.4f}',
    sep='\n',
)

100%|██████████| 46/46 [01:36<00:00,  2.10s/it]
100%|██████████| 440/440 [21:55<00:00,  2.99s/it]

Metrics for Aptos Dataset
Test Accuracy: 0.7722
Test Loss: 0.8601
Test  Precision: 0.7995
Test  Recall: 0.7722
Test  F1 Score: 0.7766
Metrics for DR Dataset
Test Accuracy: 0.6662
Test Loss: 0.8676
Test  Precision: 0.7446
Test  Recall: 0.6662
Test  F1 Score: 0.6952



