In [None]:
# Melanoma Detection Tool

**Contributors**
- S. Antoniadis
- N. Lin
- J. Lim

**Model**
- ResNet-50 (pretrained, fine-tuned)

**Performance**
| Metric                | Value   |
|----------------------:|--------:|
| Train Loss            | 0.1372  |
| Validation Accuracy   | 93.83%  |
| Test Accuracy         | 93.15%  |

Dataset source:
<https://www.kaggle.com/datasets/drscarlat/melanoma>

Code source:
<https://www.digitalocean.com/community/tutorials/writing-resnet-from-scratch-in-pytorch>

In [None]:
import torch
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

In [None]:
# 1. Build the preprocessing pipeline
# Resize(224) followed by CenterCrop(224) yields 224x224 inputs; the resulting
# tensor is then standardised with the ImageNet per-channel means / std-devs.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

preprocessing_steps = [
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
]
transform = transforms.Compose(preprocessing_steps)

In [None]:
# 2. Build one ImageFolder dataset per split, all sharing the same preprocessing
# NOTE(review): '/train_sep' and '/test' are anchored at the filesystem root while
# 'valid' is relative to the working directory — confirm this mix is intended.
def _load_split(split_root):
    """Return an ImageFolder rooted at `split_root` that applies the shared `transform`."""
    return datasets.ImageFolder(split_root, transform=transform)


train_dataset = _load_split('/train_sep')
val_dataset = _load_split('valid')
test_dataset = _load_split('/test')

In [None]:
# 3. Wrap every split in a DataLoader that serves fixed-size mini-batches
# Only the training split is shuffled; validation and test keep dataset order
# because ordering has no effect on the accuracy computed from them.
BATCH_SIZE = 32

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE)

In [None]:
# 4. Load pre-trained ResNet50
# Start from ImageNet weights, then swap the final fully connected layer for a
# 2-way classifier so the network can be fine-tuned for binary classification.
if hasattr(models, "ResNet50_Weights"):
    # Newer torchvision: `pretrained=` is deprecated in favour of `weights=`.
    # IMAGENET1K_V1 is the checkpoint that `pretrained=True` used to select.
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
else:
    # Older torchvision releases only understand the legacy flag.
    model = models.resnet50(pretrained=True)

num_features = model.fc.in_features  # input width of the original classifier head
model.fc = nn.Linear(num_features, 2)  # two output classes: melanoma and no_melanoma

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU,
# and move the model's parameters onto the chosen device.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model = model.to(device)

In [None]:
# 5. Loss function and optimizer
# Cross-entropy over the two class logits; Adam updates every model parameter.
LEARNING_RATE = 0.001

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
# 6. Training the ResNet50 model
# Every epoch performs one optimisation pass over `train_loader` followed by an
# accuracy measurement on `val_loader`. The weights of the epoch with the highest
# validation accuracy are snapshotted and restored into `model` once training
# ends, so later cells evaluate and save the best epoch instead of the last one.
num_epochs = 10  # Can be increased

best_val_acc = 0.0
best_epoch = None   # 1-based index of the epoch that produced `best_state`
best_state = None   # CPU copy of the best-performing weights seen so far

for epoch in range(num_epochs):
    model.train()  # set to training mode
    running_loss = 0.0
    train_iter = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")
    for images, labels in train_iter:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by the batch size so the epoch average
        # below is a true per-sample mean even when the last batch is smaller.
        running_loss += loss.item() * images.size(0)
        train_iter.set_postfix(loss=loss.item())

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}")

    # Validate after each epoch
    model.eval()  # evaluation mode: dropout is off, batch norm uses running statistics
    val_correct = 0
    val_total = 0
    with torch.no_grad():  # no gradients needed
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)              # Get predicted class
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()
    val_acc = 100 * val_correct / val_total
    print(f'Validation Accuracy: {val_acc:.2f}%')

    # Keep a copy of the weights whenever validation accuracy improves.
    # state_dict() hands back references to the live tensors, so they must be
    # cloned or the snapshot would keep changing as training continues.
    if best_state is None or val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch + 1
        best_state = {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        }

# Restore the best epoch so evaluation and saving use those weights.
if best_state is not None:
    model.load_state_dict(best_state)
    val_acc = best_val_acc  # keep the reported accuracy in sync with the restored weights
    print(f"Restored weights from epoch {best_epoch} (Validation Accuracy: {best_val_acc:.2f}%)")

In [None]:
# 7. Test Evaluation (after training is complete)
# Counts how many test images the model classifies correctly; gradients are
# switched off because no parameter update happens here.
model.eval()
test_correct, test_total = 0, 0
with torch.no_grad():
    for batch_images, batch_labels in test_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)
        logits = model(batch_images)
        # Index of the largest logit along the class dimension = predicted class.
        predicted_classes = torch.max(logits, dim=1).indices
        hits = predicted_classes.eq(batch_labels)
        test_total += batch_labels.shape[0]
        test_correct += int(hits.sum())

In [None]:
# Persist the model's current weights (state_dict only, not the full module).
# The message reports `val_acc` as left by the training cell.
checkpoint_path = 'best_resnet50_melanoma.pth'
torch.save(model.state_dict(), checkpoint_path)
print("Model saved with Validation Accuracy: {:.2f}%".format(val_acc))

In [None]:
# Turn the correct / total counts from the test evaluation cell into a percentage.
test_acc = 100 * test_correct / test_total
print('Test Accuracy: {:.2f}%'.format(test_acc))