In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
from PIL import Image

In [10]:
# **Paths**
DATA_DIR = "D:/oral_cancer_detection/data"
MODEL_PATH = "trained_models/resnet_model.pth"

# **Set device**``
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# **Data Transformations**
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# **Function to Remove Corrupt Images**
def remove_corrupt_images(data_dir):
    corrupt_images = []
    for class_folder in os.listdir(data_dir):
        class_path = os.path.join(data_dir, class_folder)
        if not os.path.isdir(class_path):
            continue
        for img_file in os.listdir(class_path):
            img_path = os.path.join(class_path, img_file)
            try:
                with Image.open(img_path) as img:
                    img.verify()  # Check if image is corrupt
            except (IOError, SyntaxError):
                print(f"🚨 Removing corrupt image: {img_path}")
                corrupt_images.append(img_path)

    # Delete corrupt images
    for img_path in corrupt_images:
        os.remove(img_path)

# **Remove Corrupt Images Before Loading Dataset**
remove_corrupt_images(DATA_DIR)

# **Load Dataset with Transformations**
dataset = datasets.ImageFolder(DATA_DIR, transform=transform)

# **Check Class Labels (Ensure Cancer & Non-Cancer Are Correct)**
print(f"Class Mapping: {dataset.class_to_idx}")

# **Split Dataset (80% Train, 20% Validation)**
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# **Data Loaders**
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# **Load Pretrained ResNet50 and Modify the Classifier**
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
for param in model.parameters():
    param.requires_grad = False  # Freeze base layers

# **Modify Fully Connected Layer**
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 512),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(512, 2)  # 2 Classes: Cancer & Non-Cancer
)

model.to(device)

# **Loss and Optimizer**
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)

# **Training Loop with Progress Bar**
EPOCHS = 30
for epoch in range(EPOCHS):
    model.train()
    total_loss, correct = 0, 0

    loop = tqdm(train_loader, leave=True, desc=f"Epoch [{epoch+1}/{EPOCHS}]")
    for images, labels in loop:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        correct += (outputs.argmax(1) == labels).sum().item()

        loop.set_postfix(loss=total_loss / len(train_loader), acc=100 * correct / len(train_dataset))

# **Save Model**
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)
print("✅ Training complete. Model saved.")


Using device: cuda
Class Mapping: {'cancer': 0, 'non_cancer': 1}


Epoch [1/30]: 100%|██████████| 29/29 [00:17<00:00,  1.63it/s, acc=78.8, loss=0.443]
Epoch [2/30]: 100%|██████████| 29/29 [00:17<00:00,  1.70it/s, acc=87.2, loss=0.308]
Epoch [3/30]: 100%|██████████| 29/29 [00:17<00:00,  1.68it/s, acc=90.2, loss=0.241] 
Epoch [4/30]: 100%|██████████| 29/29 [00:16<00:00,  1.74it/s, acc=92.2, loss=0.2]   
Epoch [5/30]: 100%|██████████| 29/29 [00:17<00:00,  1.69it/s, acc=93.3, loss=0.164] 
Epoch [6/30]: 100%|██████████| 29/29 [00:17<00:00,  1.69it/s, acc=94.9, loss=0.12]  
Epoch [7/30]: 100%|██████████| 29/29 [00:17<00:00,  1.70it/s, acc=95.7, loss=0.11]  
Epoch [8/30]: 100%|██████████| 29/29 [00:17<00:00,  1.68it/s, acc=96.3, loss=0.109] 
Epoch [9/30]: 100%|██████████| 29/29 [00:17<00:00,  1.67it/s, acc=93.9, loss=0.151] 
Epoch [10/30]: 100%|██████████| 29/29 [00:16<00:00,  1.71it/s, acc=97.9, loss=0.0666]
Epoch [11/30]: 100%|██████████| 29/29 [00:15<00:00,  1.86it/s, acc=98.5, loss=0.0499]
Epoch [12/30]: 100%|██████████| 29/29 [00:15<00:00,  1.91it/s, ac

✅ Training complete. Model saved.


In [22]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import requests
from io import BytesIO
import os

# **Path to Trained Model**
MODEL_PATH = "trained_models/resnet_model.pth"

# **Set Device**
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# **Image Preprocessing (Same as Training)**
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# **Load ResNet50 Model**
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 512),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(512, 2)  # 2 Classes: Cancer & Non-Cancer
)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.to(device)
model.eval()

# **Function to Load Image from URL or Local Path**
def load_image(image_path_or_url):
    try:
        if image_path_or_url.startswith("http"):  # **Load from URL**
            response = requests.get(image_path_or_url)
            image = Image.open(BytesIO(response.content)).convert("RGB")
        elif os.path.exists(image_path_or_url):  # **Load from Local Path**
            image = Image.open(image_path_or_url).convert("RGB")
        else:
            return None, "Error: Invalid path or URL"

        # **Preprocess Image**
        image = transform(image).unsqueeze(0).to(device)
        return image, None

    except Exception as e:
        return None, f"Error loading image: {e}"

# **Function to Predict Cancer**
def predict_cancer(image_path_or_url):
    image, error = load_image(image_path_or_url)
    if error:
        return error

    # **Model Prediction**
    with torch.no_grad():
        output = model(image)
        pred_class = torch.argmax(output, dim=1).item()

    # **Class Mapping**
    class_names = ["Cancerous", "Non-Cancerous"]
    return f"Prediction: {class_names[pred_class]}"



# For Local Image
image_path =   r"https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTqsHJCmC0V85uixyzu-q1icAA922jtAKhgnOHB7fYh5uvBK_eGaY7_hzmedKZFxNScTNg&usqp=CAU"# Replace with actual file path
print(predict_cancer(image_path))


Prediction: Non-Cancerous
