In [1]:
import os
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from PIL import Image
from timm import create_model

In [None]:

# **Paths**
DATA_DIR = "D:/oral_cancer_detection/data"
MODEL_DIR = "trained_models"
MODEL_PATH = os.path.join(MODEL_DIR, "vit_model.pth")

# **Set device**
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# **Ensure Model Directory Exists**
os.makedirs(MODEL_DIR, exist_ok=True)

# **Data Transformations**
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# **Function to Remove Corrupt Images**
def remove_corrupt_images(data_dir):
    corrupt_images = []
    for class_folder in os.listdir(data_dir):
        class_path = os.path.join(data_dir, class_folder)
        if not os.path.isdir(class_path):
            continue
        for img_file in os.listdir(class_path):
            img_path = os.path.join(class_path, img_file)
            try:
                with Image.open(img_path) as img:
                    img.verify()
            except (IOError, SyntaxError):
                print(f"\U0001F6A8 Removing corrupt image: {img_path}")
                corrupt_images.append(img_path)
    for img_path in corrupt_images:
        os.remove(img_path)

# **Remove Corrupt Images Before Loading Dataset**
remove_corrupt_images(DATA_DIR)

# **Load Dataset with Transformations**
dataset = datasets.ImageFolder(DATA_DIR, transform=transform)
print(f"Class Mapping: {dataset.class_to_idx}")

# **Split Dataset (80% Train, 20% Validation)**
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# **Data Loaders**
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# **Load ViT Model**
try:
    model = create_model("vit_base_patch16_224", pretrained=True)
except:
    print("Downloading ViT model...")
    os.system("pip install timm")
    from timm import create_model
    model = create_model("vit_base_patch16_224", pretrained=True)

# **Modify Classifier**
model.head = nn.Sequential(
    nn.Linear(model.head.in_features, 512),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(512, 2)
)

model.to(device)

# **Loss and Optimizer**
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.head.parameters(), lr=0.001)

# **Training Loop with Progress Bar**
EPOCHS = 30
for epoch in range(EPOCHS):
    model.train()
    total_loss, correct = 0, 0
    loop = tqdm(train_loader, leave=True, desc=f"Epoch [{epoch+1}/{EPOCHS}]")
    for images, labels in loop:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        correct += (outputs.argmax(1) == labels).sum().item()

        loop.set_postfix(loss=total_loss / len(train_loader), acc=100 * correct / len(train_dataset))

# **Save Model**
torch.save(model.state_dict(), MODEL_PATH)
print(f"✅ Training complete. Model saved at {MODEL_PATH}")


Using device: cuda
Class Mapping: {'cancer': 0, 'non_cancer': 1}


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Epoch [1/30]: 100%|██████████| 29/29 [18:07<00:00, 37.49s/it, acc=82.2, loss=0.435]
Epoch [2/30]: 100%|██████████| 29/29 [17:15<00:00, 35.72s/it, acc=90.8, loss=0.244] 
Epoch [3/30]: 100%|██████████| 29/29 [17:07<00:00, 35.44s/it, acc=92, loss=0.195]   
Epoch [4/30]: 100%|██████████| 29/29 [30:08<00:00, 62.35s/it, acc=93.5, loss=0.152] 
Epoch [5/30]: 100%|██████████| 29/29 [14:53<00:00, 30.82s/it, acc=95.1, loss=0.132] 
Epoch [6/30]: 100%|██████████| 29/29 [14:42<00:00, 30.44s/it, acc=94.7, loss=0.131] 
Epoch [7/30]: 100%|██████████| 29/29 [14:42<00:00, 30.41s/it, acc=95.2, loss=0.116] 
Epoch [8/30]: 100%|██████████| 29/29 [14:54<00:00, 30.83s/it, acc=97.6, loss=0.0658]
Epoch [9/30]: 100%|██████████| 29/29 [15:03<00:00, 31.15s/it