In [None]:
# Model
class CancerClassifierCNN(nn.Module):
    """Small three-block CNN that maps an RGB image to one binary-classification logit.

    Layer sizes assume a 3 x 96 x 96 input (``fc1`` is sized for it):

        conv1 (3 -> 16)  + ReLU + 2x2 max-pool : 3 x 96 x 96  -> 16 x 48 x 48
        conv2 (16 -> 32) + ReLU + 2x2 max-pool : 16 x 48 x 48 -> 32 x 24 x 24
        conv3 (32 -> 64) + ReLU + 2x2 max-pool : 32 x 24 x 24 -> 64 x 12 x 12
        flatten -> fc1 (64*12*12 -> 128) + ReLU -> fc2 (128 -> 1)

    ``forward`` returns raw logits of shape (batch, 1); no sigmoid is applied, so pair
    the model with ``nn.BCEWithLogitsLoss`` and apply ``torch.sigmoid`` to the output
    when probabilities are needed.
    """

    def __init__(self):
        super(CancerClassifierCNN, self).__init__()

        # Input: 3 x 96 x 96 (the 3 channels are the Red, Green and Blue color channels).
        # 16 learned 3x3 filters (each spanning all 3 input channels) produce 16 feature
        # maps; padding=1 keeps the spatial size at 96 x 96. The number of output
        # channels is a design choice.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        # 2x2 max-pooling with stride 2 halves the height and width of every feature map.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)  # 16 feature maps -> 32, 3x3 kernels
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)  # 32 feature maps -> 64, 3x3 kernels

        # After three poolings the feature maps are 64 x 12 x 12 for a 96 x 96 input.
        self.fc1 = nn.Linear(64 * 12 * 12, 128)
        self.fc2 = nn.Linear(128, 1)  # single logit for binary classification

        # Not used in forward(): the model returns logits for BCEWithLogitsLoss.
        # Kept so existing code that references `model.sigmoid` keeps working.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return raw logits of shape (batch, 1) for a (batch, 3, 96, 96) image batch."""
        x = self.pool(F.relu(self.conv1(x)))  # (B, 3, 96, 96)  -> (B, 16, 48, 48)
        x = self.pool(F.relu(self.conv2(x)))  # (B, 16, 48, 48) -> (B, 32, 24, 24)
        x = self.pool(F.relu(self.conv3(x)))  # (B, 32, 24, 24) -> (B, 64, 12, 12)

        # Flatten everything except the batch dimension. Unlike view(-1, 64 * 12 * 12),
        # this never merges samples together: an input with the wrong spatial size
        # now fails loudly in fc1 instead of silently changing the batch size.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [None]:
# Choose the compute device here so this cell also works on a fresh kernel
# (Restart & Run All): the "Device config" cell further down runs only after the model exists.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CancerClassifierCNN().to(device)

In [None]:
# Define class imbalance weight for BCEWithLogitsLoss.
# NOTE(review): these counts presumably are the negative / positive label counts of the
# training set (PyTorch convention: pos_weight = n_negative / n_positive) — confirm
# against the dataset before reusing.
N_NEGATIVE = 130908
N_POSITIVE = 89117
# Place the weight on the same device as the model's parameters so this cell does not
# rely on a `device` variable having been defined by an earlier cell.
pos_weight = torch.tensor(
    [N_NEGATIVE / N_POSITIVE],
    dtype=torch.float32,
    device=next(model.parameters()).device,
)
# Loss function for binary classification with imbalance (expects raw logits).
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
# Optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Device config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🚀 Using device: {device}")
if device.type == 'cuda':
    print(f"🖥️ GPU: {torch.cuda.get_device_name(0)}")

best_mps = 0  # best Medical Priority Score so far; a checkpoint is saved when it improves
epochs = 10

# Per-epoch validation metric history (one entry appended per epoch).
f1_list = []
recall_list = []
precision_list = []
medical_priority_list = []
roc_auc_list = []

for epoch in range(epochs):
    print(f"\n📘 Epoch {epoch+1}")

    # ---- Training pass ----
    model.train()
    train_loader_tqdm = tqdm(train_loader, desc="🔁 Training", leave=True)
    for images, labels in train_loader_tqdm:
        images = images.to(device)
        # BCEWithLogitsLoss needs float targets shaped like the (B, 1) logits.
        labels = labels.float().unsqueeze(1).to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loader_tqdm.set_postfix(loss=loss.item())

    # ---- Validation pass ----
    model.eval()
    total_val_loss = 0
    y_true, y_pred, y_prob = [], [], []

    val_loader_tqdm = tqdm(val_loader, desc="🧪 Validating", leave=True)
    with torch.inference_mode():
        for images, labels in val_loader_tqdm:
            images = images.to(device)
            labels = labels.float().unsqueeze(1).to(device)

            outputs = model(images)           # raw logits
            loss = loss_fn(outputs, labels)
            probs = torch.sigmoid(outputs)    # logits -> probabilities
            preds = (probs > 0.5).float()     # hard 0/1 predictions at threshold 0.5

            total_val_loss += loss.item()
            y_true += labels.cpu().numpy().flatten().tolist()
            y_pred += preds.cpu().numpy().flatten().tolist()
            y_prob += probs.cpu().numpy().flatten().tolist()

            val_loader_tqdm.set_postfix(loss=loss.item())

    avg_loss_val = total_val_loss / len(val_loader)

    # Metrics
    # Threshold-based metrics use the hard predictions. ROC AUC must be computed from the
    # continuous probabilities: on 0/1 predictions it collapses to a single operating point
    # ((TPR + TNR) / 2) and would disagree with the ROC curve built from y_prob below.
    f1 = f1_score(y_true, y_pred, zero_division=0)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    cm = confusion_matrix(y_true, y_pred)
    roc = roc_auc_score(y_true, y_prob)
    Priority = medical_priority_score(precision, recall, roc)

    f1_list.append(f1)
    precision_list.append(precision)
    recall_list.append(recall)
    roc_auc_list.append(roc)
    medical_priority_list.append(Priority)

    print(f"\n📊 Epoch: {epoch+1} | Priority Medical Score: {Priority:.4f} | F1: {f1:.4f} | Precision: {precision:.4f} | Recall: {recall:.4f} | "
          f"ROC AUC: {roc:.4f} | Accuracy: {accuracy_score(y_true, y_pred):.4f} | Loss: {avg_loss_val:.4f}")
    print("🧾 Confusion Matrix:\n", cm)
    plot_roc_curve(y_true, y_prob, roc, epoch=epoch+1)

    # Keep only the checkpoint with the highest Medical Priority Score.
    if Priority > best_mps:
        best_mps = Priority
        torch.save({'model_state': model.state_dict(),
                   'epoch': epoch,  # 0-based index (printed output uses epoch+1)
                   'f1_score': f1,
                   'roc_score': roc,
                   'precision': precision,
                   'Priority': Priority},
                   "best_cnn_model.pth")
        print("✅ Best model saved (↑ Medical Priority Score)")
