# SkinSight: Stage 1 - Filter Model Training

This notebook trains the **Filter Model** (Binary Classifier: Skin vs Random).
Its purpose is to filter out non-skin images before they reach the diagnostic model.

**Classes**:
0: Random Object
1: Skin (Melanoma + Tinea)

In [1]:
# 1. Dependencies
import os
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms, models
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from tqdm import tqdm

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cpu


In [2]:
# 2. Configuration
DATASET_DIR = 'Dataset'
CACHE_DIR = 'Dataset_Cache_Filter' # Separate cache for safety
BATCH_SIZE = 16
EPOCHS = 5
LEARNING_RATE = 0.001

def set_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
set_seed()

In [3]:
# 3. Preprocessing (Optimized with Caching)
class AdvancedPreprocessing:
    @staticmethod
    def hair_removal(image_np):
        gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (17, 17))
        blackhat = cv2.morphologyEx(gray, cv2.MORPH_BLACKHAT, kernel)
        _, mask = cv2.threshold(blackhat, 10, 255, cv2.THRESH_BINARY)
        return cv2.inpaint(image_np, mask, 3, cv2.INPAINT_TELEA)

    @staticmethod
    def apply_clahe(image_np):
        lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)
        l, a, b = cv2.split(lab)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
        l = clahe.apply(l)
        return cv2.cvtColor(cv2.merge((l,a,b)), cv2.COLOR_LAB2RGB)

    @staticmethod
    def shades_of_grey(image_np, power=6):
        img_dtype = image_np.dtype
        r_norm = np.power(np.mean(np.power(image_np[:,:,0], power)), 1/power)
        g_norm = np.power(np.mean(np.power(image_np[:,:,1], power)), 1/power)
        b_norm = np.power(np.mean(np.power(image_np[:,:,2], power)), 1/power)
        norm_vec = np.array([r_norm, g_norm, b_norm])
        norm_vec = norm_vec / (np.sqrt(np.sum(np.square(norm_vec))) + 1e-6)
        uniform_illum = 1 / np.sqrt(3)
        manul_wb = np.diag([uniform_illum]*3) / (np.diag(norm_vec) + 1e-6)
        corrected = np.dot(image_np, manul_wb)
        return np.clip(corrected, 0, 255).astype(img_dtype)


In [4]:
# 4. Dataset Class (Filter Mode)
class SkinFilterDataset(Dataset):
    def __init__(self, file_paths, labels):
        self.file_paths = file_paths
        self.labels = labels
        self.resize = transforms.Resize((224, 224))
        self.to_tensor = transforms.ToTensor()
        self.normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        
        # Ensure Cache Directory
        os.makedirs(CACHE_DIR, exist_ok=True)

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        img_path = self.file_paths[idx]
        label = self.labels[idx]
        
        # Caching Logic
        filename_id = f"{os.path.basename(os.path.dirname(img_path))}_{os.path.basename(img_path)}"
        cache_path = os.path.join(CACHE_DIR, filename_id)
        
        image = None
        if os.path.exists(cache_path):
            try:
                image = cv2.imread(cache_path)
                if image is not None:
                    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            except:
                pass
        
        if image is None:
            try:
                image = cv2.imread(img_path)
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                # Apply preprocessing
                image = AdvancedPreprocessing.hair_removal(image)
                image = AdvancedPreprocessing.apply_clahe(image)
                image = AdvancedPreprocessing.shades_of_grey(image)
                # Save cache
                cv2.imwrite(cache_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
            except Exception as e:
                print(f"Error: {e}")
                image = np.zeros((224, 224, 3), dtype=np.uint8)

        image_pil = Image.fromarray(image)
        x = self.resize(image_pil)
        x = self.to_tensor(x)
        x = self.normalize(x)
        return x, label

In [5]:
# 5. Data Preparation
def prepare_filter_data(dataset_dir):
    melanoma = [os.path.join(dataset_dir, 'Melanoma', f) for f in os.listdir(os.path.join(dataset_dir, 'Melanoma'))]
    tinea = [os.path.join(dataset_dir, 'Tinea', f) for f in os.listdir(os.path.join(dataset_dir, 'Tinea'))]
    
    random_path = os.path.join(dataset_dir, 'random')
    if not os.path.exists(random_path): random_path = os.path.join(dataset_dir, 'Random_Obj')
    random_files = [os.path.join(random_path, f) for f in os.listdir(random_path)]
    
    # Class 0: Random, Class 1: Skin
    files = random_files + melanoma + tinea
    labels = [0]*len(random_files) + [1]*(len(melanoma) + len(tinea))
    
    X_train, X_temp, y_train, y_temp = train_test_split(files, labels, test_size=0.2, stratify=labels, random_state=42)
    X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)
    
    return (X_train, y_train), (X_val, y_val), (X_test, y_test)

def get_weighted_sampler(labels):
    counts = np.bincount(labels)
    weights = 1. / counts
    sample_weights = [weights[l] for l in labels]
    return WeightedRandomSampler(sample_weights, len(sample_weights))

(train_X, train_y), (val_X, val_y), (test_X, test_y) = prepare_filter_data(DATASET_DIR)

train_ds = SkinFilterDataset(train_X, train_y)
val_ds = SkinFilterDataset(val_X, val_y)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=get_weighted_sampler(train_y), drop_last=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

print(f"Train: {len(train_ds)} | Val: {len(val_ds)}")

Train: 4375 | Val: 547


In [6]:
# 6. Model Definition (ResNet18)
class FilterModel(nn.Module):
    def __init__(self):
        super(FilterModel, self).__init__()
        self.model = models.resnet18(weights='IMAGENET1K_V1')
        self.model.fc = nn.Linear(self.model.fc.in_features, 2) 
    def forward(self, x):
        return self.model(x)

model = FilterModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [7]:
# 7. Training Loop
os.makedirs('models', exist_ok=True)
best_acc = 0.0

print("Starting Training...")
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    torch.cuda.empty_cache()
    
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        
    # Validation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()
            
    val_acc = 100 * correct / total
    print(f"Loss: {running_loss/len(train_loader):.4f} | Val Acc: {val_acc:.2f}%")
    
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), 'models/filter_model.pth')
        print("Model Saved!")

print("Filter Model Training Completed.")

Starting Training...


Epoch 1/5: 100%|██████████████████████████████████████████████████████████| 273/273 [09:35<00:00,  2.11s/it]


Loss: 0.1396 | Val Acc: 98.35%
Model Saved!


Epoch 2/5: 100%|██████████████████████████████████████████████████████████| 273/273 [07:16<00:00,  1.60s/it]


Loss: 0.0914 | Val Acc: 98.16%


Epoch 3/5: 100%|██████████████████████████████████████████████████████████| 273/273 [06:44<00:00,  1.48s/it]


Loss: 0.0691 | Val Acc: 95.22%


Epoch 4/5: 100%|██████████████████████████████████████████████████████████| 273/273 [06:12<00:00,  1.36s/it]


Loss: 0.0604 | Val Acc: 96.88%


Epoch 5/5: 100%|██████████████████████████████████████████████████████████| 273/273 [05:46<00:00,  1.27s/it]


Loss: 0.0318 | Val Acc: 98.53%
Model Saved!
Filter Model Training Completed.
