In [None]:
import copy
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms

# Define device (use the first GPU if available, otherwise fall back to CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load an ImageNet-pretrained ResNet-50.
# `pretrained=True` is deprecated since torchvision 0.13 in favour of the
# `weights` argument. IMAGENET1K_V1 is the checkpoint the legacy flag loaded,
# so the starting weights are unchanged.
if hasattr(models, "ResNet50_Weights"):
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
else:
    # torchvision < 0.13 only understands the legacy flag
    model = models.resnet50(pretrained=True)

# Replace the final fully connected layer with a 2-way classifier head
num_ftrs = model.fc.in_features  # Number of input features to the fully connected layer
model.fc = nn.Linear(num_ftrs, 2)  # Output layer sized for the 2 target classes
model = model.to(device)  # Move model to the selected device

# Define transformations for training and validation datasets.
# NOTE(review): no ImageNet mean/std normalization is applied here, although
# torchvision documents its pretrained weights as expecting normalized input.
# Consider adding transforms.Normalize -- and if you do, keep any inference
# preprocessing in sync with it.
data_transforms = {
    'TRAIN': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]),
    'VAL': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]),
}

# Load datasets: one ImageFolder per split, read from <data_dir>/<split>
data_dir = 'DATA_PATH'
image_datasets = {
    x: datasets.ImageFolder(root=f'{data_dir}/{x}', transform=data_transforms[x])
    for x in ['TRAIN', 'VAL']
}

# Create DataLoaders for training and validation (only TRAIN is shuffled)
dataloaders = {
    x: DataLoader(
        image_datasets[x],
        batch_size=32,
        shuffle=(x == 'TRAIN'),
        num_workers=8,
    )
    for x in ['TRAIN', 'VAL']
}

# Get class names (as discovered by ImageFolder in the TRAIN split)
class_names = image_datasets['TRAIN'].classes

# Define loss function (cross-entropy for classification) and optimizer (Adam)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

def train_model(model, dataloaders, criterion, optimizer, num_epochs=50):
    """Train ``model`` and return it with the best-validation-F1 weights loaded.

    Every epoch runs a 'TRAIN' phase (training mode, gradients enabled,
    optimizer stepped per batch) followed by a 'VAL' phase (evaluation mode,
    no gradients). For each phase the sample-weighted average loss and the
    weighted F1 score are printed. Whenever the validation F1 improves on the
    best value so far, a copy of the model's state dict is kept.

    Args:
        model: The ``torch.nn.Module`` to train. Batches are moved to the
            device its parameters live on.
        dataloaders: Mapping with the keys 'TRAIN' and 'VAL', each an iterable
            yielding ``(inputs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        num_epochs: Number of full TRAIN + VAL passes to run.

    Returns:
        The same ``model`` object, with the weights from the epoch that
        achieved the highest validation F1 score loaded.

    Raises:
        ValueError: If a phase's dataloader yields no samples.
    """
    since = time.time()

    # Use the device the model already lives on rather than a module-level
    # global, so the function also works when imported or reused elsewhere.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    best_model_wts = copy.deepcopy(model.state_dict())  # Best weights seen so far
    best_f1 = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['TRAIN', 'VAL']:
            is_train = phase == 'TRAIN'
            model.train(is_train)  # train(False) is equivalent to eval()

            running_loss = 0.0
            num_samples = 0
            all_preds = []
            all_labels = []

            # Iterate over data batches
            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(run_device), labels.to(run_device)

                # Gradients only need clearing when they are about to be computed
                if is_train:
                    optimizer.zero_grad()

                # Forward pass; autograd is recorded only while training
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # Backward pass and optimize in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Weight the batch loss by batch size so that a smaller final
                # batch does not skew the epoch average.
                batch_size = inputs.size(0)
                running_loss += loss.item() * batch_size
                num_samples += batch_size
                all_preds.extend(preds.tolist())
                all_labels.extend(labels.tolist())

            if num_samples == 0:
                raise ValueError(f"dataloaders[{phase!r}] yielded no samples")

            # Average over the samples actually seen. This stays correct with
            # drop_last or a sampler and does not require a `.dataset` attribute.
            epoch_loss = running_loss / num_samples
            epoch_f1 = f1_score(all_labels, all_preds, average='weighted')

            print(f'{phase} Loss: {epoch_loss:.4f} | F1 Score: {epoch_f1:.4f}')

            # Save best model based on validation F1 score
            if phase == 'VAL' and epoch_f1 > best_f1:
                best_f1 = epoch_f1
                best_model_wts = copy.deepcopy(model.state_dict())

    # Training complete
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best validation F1 Score: {best_f1:.4f}')

    # Load best model weights before returning
    model.load_state_dict(best_model_wts)
    return model

# Train the model
model = train_model(model, dataloaders, criterion, optimizer, num_epochs=50)

# Save the trained model.
# torch.save does not create missing parent directories, so make sure the
# target folder exists first. Otherwise the whole training run would be lost
# to an error at the very last step.
save_path = 'models/model_resnet50.pth'
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(model.state_dict(), save_path)
print("Model training complete and saved successfully.")