In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

from data import OralCancerDataset
from utils import save_checkpoint, load_checkpoint, latest_checkpoint_path

In [2]:
# --- Training hyperparameters -------------------------------------------
batch_size = 48
num_epochs = 1000
learning_rate = 0.001
num_workers = 4  # worker count handed to the DataLoaders

# --- Image preprocessing ------------------------------------------------
# Steps run in list order: resize to 299x299, randomly mirror the image,
# convert to a tensor, then normalize each channel with the mean/std values
# below (the usual ImageNet statistics for torchvision's pretrained models).
_preprocessing_steps = [
    transforms.Resize((299, 299)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_preprocessing_steps)

In [3]:
# Dataset and checkpoint locations (relative paths).
path_to_csv = 'cancer-classification-challenge-2024/train.csv'
path_to_train_images = 'cancer-classification-challenge-2024/train'
path_to_test_images = 'cancer-classification-challenge-2024/test'
model_dir = 'checkpoints/'

# Validation metrics should be deterministic, so the validation split reuses
# the training pipeline with its random augmentation (the horizontal flip)
# removed. Deriving it from `transform` keeps resize/normalize in sync.
val_transform = transforms.Compose([
    step for step in transform.transforms
    if not isinstance(step, transforms.RandomHorizontalFlip)
])

# NOTE(review): presumably train=1 / train=0 select the training / validation
# split inside OralCancerDataset -- confirm against data.py.
train_dataset = OralCancerDataset(
    path_to_csv=path_to_csv,
    path_to_image=path_to_train_images,
    train=1,
    transform=transform,
)
val_dataset = OralCancerDataset(
    path_to_csv=path_to_csv,
    path_to_image=path_to_train_images,
    train=0,
    transform=val_transform,
)

# Only the training loader shuffles; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)


In [4]:
def get_inceptionv3_model(num_classes=2, pretrained=True, replace_aux_head=True):
    """Build a torchvision Inception v3 whose classifier heads emit ``num_classes`` logits.

    Args:
        num_classes: Number of output logits for the classifier head(s).
        pretrained: Forwarded to ``models.inception_v3``; when true the
            torchvision pretrained weights are loaded before the heads are
            replaced.
        replace_aux_head: When true (default) the auxiliary classifier
            (``AuxLogits.fc``) is also resized to ``num_classes`` so the
            auxiliary loss is computed over the task's classes instead of the
            original 1000-way output. Pass ``False`` to keep the original
            auxiliary head, e.g. to load a checkpoint that was saved before
            this head was resized.

    Returns:
        The Inception v3 ``nn.Module`` with freshly initialised head(s).
    """
    # NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
    # of `weights=`; kept as-is because the installed version is not visible here.
    model = models.inception_v3(pretrained=pretrained)

    # Main classifier head.
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # Auxiliary classifier head (absent/None when the model has no aux branch).
    aux_head = getattr(model, 'AuxLogits', None)
    if replace_aux_head and aux_head is not None:
        aux_head.fc = nn.Linear(aux_head.fc.in_features, num_classes)

    return model

def getConvNeXt(num_classes=2, pretrained=True):
    """Build a torchvision ConvNeXt-Tiny whose final layer emits ``num_classes`` logits.

    Args:
        num_classes: Number of output logits for the replaced final layer.
        pretrained: Forwarded to ``models.convnext_tiny``; defaults to true to
            match ``get_inceptionv3_model``.

    Returns:
        The ConvNeXt-Tiny ``nn.Module`` with a freshly initialised final
        linear layer.
    """
    # The factory must be *called*; the bare attribute is just the function.
    model = models.convnext_tiny(pretrained=pretrained)

    # ConvNeXt has no `.fc`; its head is the last module of `model.classifier`.
    old_head = model.classifier[-1]
    model.classifier[-1] = nn.Linear(old_head.in_features, num_classes)
    return model

In [5]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

# Build the network first, then move it onto the chosen device so the
# optimizer below is created over the already-moved parameters.
model = get_inceptionv3_model()
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# TensorBoard writer using SummaryWriter's default log directory.
writer = SummaryWriter()



In [6]:
def _split_outputs(model_output):
    """Return ``(logits, aux_logits)`` for either a tuple or a plain-tensor forward result.

    ``aux_logits`` is ``None`` when the model returned a single tensor (which
    is what torchvision's Inception v3 does in eval mode, and what models
    without an auxiliary head always do).
    """
    if isinstance(model_output, tuple):
        logits, aux_logits = model_output
        return logits, aux_logits
    return model_output, None


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, checkpoint_dir):
    """Train ``model``, running validation and TensorBoard logging after every epoch.

    Training resumes from the newest ``G_*.pth`` checkpoint found in
    ``checkpoint_dir`` when one can be loaded; otherwise it starts at epoch 0.
    Epoch indices are 0-based internally and shown 1-based in the printed log.
    A checkpoint is written every 10 epochs via ``save_checkpoint``.

    Args:
        model: Network to train. Its forward pass may return either a plain
            logits tensor or a ``(logits, aux_logits)`` tuple.
        train_loader: Iterable yielding ``(images, labels)`` training batches.
        val_loader: Iterable yielding ``(images, labels)`` validation batches.
        criterion: Loss callable applied as ``criterion(logits, labels)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        num_epochs: Exclusive upper bound of the epoch index.
        checkpoint_dir: Directory searched for a checkpoint to resume from and
            passed to ``save_checkpoint`` when saving.
    """
    # Best-effort resume: a missing or unreadable checkpoint falls back to a
    # fresh run, but the reason is reported and KeyboardInterrupt/SystemExit
    # are no longer swallowed.
    try:
        checkpoint_path = latest_checkpoint_path(checkpoint_dir, "G_*.pth")
        _, _, _, start_epoch = load_checkpoint(checkpoint_path, model, optimizer)
        print(f'Resumed from {checkpoint_path!r} at epoch index {start_epoch}.')
    except Exception as exc:
        print(f'No checkpoint resumed from {checkpoint_dir!r} ({exc!r}); starting at epoch 0.')
        start_epoch = 0

    # Move batches to wherever the model lives instead of relying on a global.
    device = next(model.parameters()).device

    writer = SummaryWriter()
    try:
        for epoch in range(start_epoch, num_epochs):
            # ---- training pass ----
            model.train()
            running_loss = 0.0
            aux_running_loss = 0.0
            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)

                optimizer.zero_grad()
                outputs, aux_outputs = _split_outputs(model(images))
                loss1 = criterion(outputs, labels)
                loss = loss1
                if aux_outputs is not None:
                    # The auxiliary loss is down-weighted by 0.4 relative to the main loss.
                    loss2 = criterion(aux_outputs, labels)
                    loss = loss1 + 0.4 * loss2
                    aux_running_loss += loss2.item()
                loss.backward()
                optimizer.step()

                running_loss += loss1.item()

            # max(..., 1) avoids a ZeroDivisionError on an empty loader.
            num_train_batches = max(len(train_loader), 1)
            avg_loss = running_loss / num_train_batches
            avg_aux_loss = aux_running_loss / num_train_batches
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Aux Loss: {avg_aux_loss:.4f}')

            writer.add_scalar('Loss/train', avg_loss, epoch)
            writer.add_scalar('Aux Loss/train', avg_aux_loss, epoch)

            # ---- validation pass ----
            model.eval()
            val_loss = 0.0
            correct = 0
            total = 0
            with torch.no_grad():
                for images, labels in val_loader:
                    images, labels = images.to(device), labels.to(device)
                    # In eval mode the model may return a bare tensor, so do
                    # not tuple-unpack the forward result directly.
                    outputs, _ = _split_outputs(model(images))
                    loss = criterion(outputs, labels)
                    val_loss += loss.item()
                    predicted = outputs.argmax(dim=1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

            avg_val_loss = val_loss / max(len(val_loader), 1)
            accuracy = 100 * correct / total if total else 0.0
            print(f'Validation Loss: {avg_val_loss:.4f}, Accuracy: {accuracy:.2f}%')

            writer.add_scalar('Loss/val', avg_val_loss, epoch)
            writer.add_scalar('Accuracy/val', accuracy, epoch)

            # Save checkpoint every 10 epochs
            if (epoch + 1) % 10 == 0:
                save_checkpoint(model, optimizer, learning_rate, epoch + 1, checkpoint_dir)
    finally:
        # Always close the writer, including when the run is interrupted.
        writer.close()

In [8]:
# Launch training; checkpoints are read from and written to `model_dir`.
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
    checkpoint_dir=model_dir,
)

KeyboardInterrupt: 

In [None]:
def infer(image_path, model, transform, checkpoint_dir):
    """Predict the class index of a single image using checkpointed weights.

    Args:
        image_path: Path of the image file to classify.
        model: Network into which the checkpoint is loaded.
        transform: Callable applied to the RGB PIL image; its result must
            support ``.unsqueeze(0)`` (i.e. be a tensor).
        checkpoint_dir: Either a directory, in which case its newest
            ``G_*.pth`` file is used (the same lookup ``train_model`` does),
            or the path of a checkpoint file, which is passed through as-is.

    Returns:
        The predicted class index as a Python ``int``.
    """
    # Local imports: neither name is imported at the top of the notebook.
    import os
    from PIL import Image

    checkpoint_path = checkpoint_dir
    if os.path.isdir(checkpoint_dir):
        checkpoint_path = latest_checkpoint_path(checkpoint_dir, "G_*.pth")
    model, _, _, _ = load_checkpoint(checkpoint_path, model)
    model.eval()

    # Send the input to wherever the model lives instead of relying on a global.
    device = next(model.parameters()).device

    # convert() returns an independent image, so the file can be closed right away.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    image = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(image)
        # Defensive: keep only the main logits if the model returns a tuple.
        if isinstance(output, tuple):
            output = output[0]
        predicted = output.argmax(dim=1)

    return predicted.item()