In [None]:
import torch
import os
from utils import *
import torch.nn as nn
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights


def load_model(num_classes=5):
    """Build an EfficientNet-B0 with pretrained weights and a fresh classifier head.

    The backbone is created with torchvision's default EfficientNet-B0 weights,
    then the linear layer at ``model.classifier[1]`` is replaced by a new
    ``nn.Linear`` that maps the same input features to ``num_classes`` outputs.
    The model is moved to CUDA when available, otherwise it stays on the CPU.

    Args:
        num_classes: Number of output classes for the replacement head.
            Defaults to 5, the value this notebook previously hard-coded.

    Returns:
        The ``nn.Module`` on the selected device, with a newly initialised
        (untrained) classification layer.
    """
    # --- 1. Load Pretrained EfficientNet ---
    weights = EfficientNet_B0_Weights.DEFAULT
    model = efficientnet_b0(weights=weights)

    # Swap the final linear layer so the output size matches our label set;
    # the input width is read from the layer being replaced.
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features, num_classes)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return model.to(device)


def train(model, train_loader, val_loader, preprocess_method):
    """Fine-tune ``model`` for a fixed number of epochs and save its weights.

    Every epoch performs one pass over ``train_loader`` using Adam (lr=1e-4)
    and cross-entropy loss, prints the training loss/accuracy, calls
    ``evaluate(model, val_loader)``, prints the validation metrics and redraws
    the learning curves with ``plot_accuracy_and_loss``.

    After the final epoch the state dict is written to
    ``efficientNet/regular.pth`` when ``preprocess_method`` is ``None`` and to
    ``efficientNet/<preprocess_method.__name__>.pth`` otherwise.

    Args:
        model: Network to train. It is expected to already be on the device
            chosen here (CUDA when available, else CPU), which is the same
            rule ``load_model`` applies.
        train_loader: Iterable yielding ``(images, labels)`` tensor batches.
        val_loader: Loader handed unchanged to ``evaluate`` after each epoch.
        preprocess_method: Preprocessing callable or ``None``; used here only
            to derive the checkpoint file name.

    Raises:
        ValueError: If ``train_loader`` yields no samples in an epoch.
    """
    criterion = nn.CrossEntropyLoss()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # --- 6. Training Loop ---
    num_epochs = 5
    val_accs = []
    val_losses = []
    train_accs = []
    train_losses = []

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 30)

        model.train()
        sample_loss_sum = 0.0  # sum of (mean batch loss * batch size)
        batch_loss_sum = 0.0   # sum of mean batch losses
        num_batches = 0
        correct = 0
        total = 0

        for images, labels in tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            batch_size = labels.size(0)
            # CrossEntropyLoss averages over the batch by default, so weight
            # it by the batch size exactly once to accumulate a per-sample
            # sum. (Previously the loss was added twice per step, inflating
            # the reported training loss.)
            sample_loss_sum += batch_loss * batch_size
            batch_loss_sum += batch_loss
            num_batches += 1

            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += batch_size

        if total == 0:
            raise ValueError('train_loader yielded no samples; cannot compute training metrics')

        train_loss = sample_loss_sum / total
        train_acc = correct / total
        train_accs.append(train_acc)
        train_losses.append(train_loss)
        print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}')
        # Mean of the per-batch losses (differs from train_loss only when the
        # last batch is smaller than the others).
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {batch_loss_sum/num_batches:.4f}")

        # --- 7. Evaluation (optional) ---
        acc, loss, precision, recall, f1 = evaluate(model, val_loader)
        val_accs.append(acc)
        val_losses.append(loss)

        print(f'Val Accuracy: {acc:.4f}')
        print(f'Val Loss: {loss:.4f}')
        print(f'Val Precision: {precision:.4f}')
        print(f'Val Recall: {recall:.4f}')
        print(f'Val F1 Score: {f1:.4f}')

        # Curves are redrawn after every epoch with the history so far.
        plot_accuracy_and_loss(train_accs, train_losses, val_accs, val_losses)

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs('efficientNet', exist_ok=True)

    if preprocess_method is None:
        checkpoint_path = 'efficientNet/regular.pth'
    else:
        checkpoint_path = f'efficientNet/{preprocess_method.__name__}.pth'
    torch.save(model.state_dict(), checkpoint_path)


def test(model, test_loader, preprocess_method):
    """Load the saved checkpoint into ``model`` and report test-set metrics.

    The checkpoint path mirrors the one ``train`` writes:
    ``efficientNet/regular.pth`` when ``preprocess_method`` is ``None``,
    otherwise ``efficientNet/<preprocess_method.__name__>.pth``.

    Args:
        model: Network whose architecture matches the saved state dict; its
            weights are overwritten in place and it is switched to eval mode.
        test_loader: Loader passed unchanged to ``evaluate``.
        preprocess_method: Preprocessing callable or ``None``; used here only
            to pick the checkpoint file.

    Prints accuracy, loss, precision, recall and F1 as returned by
    ``evaluate``; returns nothing.
    """
    if preprocess_method is None:
        checkpoint_path = 'efficientNet/regular.pth'
    else:
        checkpoint_path = f'efficientNet/{preprocess_method.__name__}.pth'

    # map_location lets a checkpoint saved on a GPU be deserialised on a
    # CPU-only machine; without it torch.load tries to restore tensors onto
    # the device they were saved from.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()

    acc, loss, precision, recall, f1 = evaluate(model, test_loader)

    print(f'Test Accuracy: {acc:.4f}')
    print(f'Test Loss: {loss:.4f}')
    print(f'Test  Precision: {precision:.4f}')
    print(f'Test  Recall: {recall:.4f}')
    print(f'Test  F1 Score: {f1:.4f}')


def main(preprocess_method=None):
    """Run the whole experiment: load data, build loaders, train, then test.

    Args:
        preprocess_method: Preprocessing callable forwarded to
            ``get_data_loaders``, ``train`` and ``test``. ``None`` (the
            default) is forwarded as-is.
    """
    # Fetch the three dataset splits.
    train_split, val_split, test_split = load_data()

    # Build the loaders; 16 is passed as the fourth positional argument
    # (presumably the batch size -- confirm against utils.get_data_loaders).
    loaders = get_data_loaders(train_split, val_split, test_split, 16, preprocess_method)
    train_loader, val_loader, test_loader, test_dataset = loaders

    # Fresh EfficientNet with a new classification head.
    net = load_model()

    # Fit on the training split (validating each epoch), then score the
    # saved checkpoint on the test split.
    train(net, train_loader, val_loader, preprocess_method)
    test(net, test_loader, preprocess_method)


In [None]:
# Run the full pipeline (train, validate each epoch, save the checkpoint, test)
# with gaussian_subtractive_normalization as the preprocessing method.
# To compare methods, call main() again with another preprocessing callable;
# main() with no argument uses preprocess_method=None and saves to regular.pth.
main(gaussian_subtractive_normalization)

In [None]:
from utils import *

def test_from_checkpoint(preprocess_method=None):
    """Evaluate a previously saved checkpoint on the test loader without retraining.

    Rebuilds the data loaders and a fresh model, then delegates to ``test``,
    which loads ``efficientNet/regular.pth`` (when ``preprocess_method`` is
    ``None``) or ``efficientNet/<preprocess_method.__name__>.pth`` and prints
    accuracy, loss, precision, recall and F1.

    Args:
        preprocess_method: Preprocessing callable forwarded to
            ``get_data_loaders`` and used to select the checkpoint file.
            ``None`` (the default) is forwarded as-is.
    """
    # Load training, validation, and test data
    full_train_balanced, full_val, full_test = load_data()

    # Only the test loader is needed here; the other returned values are discarded.
    _, _, test_loader, _ = get_data_loaders(full_train_balanced, full_val, full_test, 16, preprocess_method)

    model = load_model()

    # Checkpoint loading, eval mode and metric printing live in test();
    # reuse it instead of keeping a second copy of the same code.
    test(model, test_loader, preprocess_method)

# Re-score the saved gaussian_subtractive_normalization checkpoint on the test
# loader without retraining (the checkpoint file must already exist).
test_from_checkpoint(gaussian_subtractive_normalization)