# Imports

In [10]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision import transforms

from medmnist import BloodMNIST, INFO

import argparse
import numpy as np
from matplotlib import pyplot as plt
from sklearn.metrics import accuracy_score

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

data_flag = "bloodmnist"
info = INFO[data_flag]
n_classes = len(info["label"])
print("Classes:", n_classes, info["label"])

Device: cuda
Classes: 8 {'0': 'basophil', '1': 'eosinophil', '2': 'erythroblast', '3': 'immature granulocytes(myelocytes, metamyelocytes and promyelocytes)', '4': 'lymphocyte', '5': 'monocyte', '6': 'neutrophil', '7': 'platelet'}


In [11]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5], std=[.5])
])


# Plot Funtions

In [12]:
def plot(x, y, ylabel="", name=""):
    plt.clf()
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.plot(list(x), y)
    plt.savefig(f"{name}.pdf", bbox_inches="tight")


In [13]:
def plot_compare(x, series_dict, ylabel="", title="", filename="compare"):
    plt.clf()
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.title(title)
    for label, y in series_dict.items():
        plt.plot(list(x), y, label=label)
    plt.legend()
    plt.savefig(f"{filename}.pdf", bbox_inches="tight")

# CNN model

In [14]:
class CNN(nn.Module):
    def __init__(self, num_classes=8, use_pool=True):
        super().__init__()
        self.use_pool = use_pool

        # Convs iguais para os dois cenários
        self.conv1 = nn.Conv2d(3, 32, 3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, stride=1, padding=1)

        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.5)

        # Heads diferentes porque o flatten muda
        self.fc1_pool = nn.Linear(128 * 3 * 3, 256)
        self.fc2_pool = nn.Linear(256, num_classes)

        self.fc1_nopool = nn.Linear(128 * 28 * 28, 256)
        self.fc2_nopool = nn.Linear(256, num_classes)

    def forward(self, x, use_pool=None):
        if use_pool is None:
            use_pool = self.use_pool

        x = F.relu(self.conv1(x))
        if use_pool:
            x = self.pool(x)

        x = F.relu(self.conv2(x))
        if use_pool:
            x = self.pool(x)

        x = F.relu(self.conv3(x))
        if use_pool:
            x = self.pool(x)

        x = torch.flatten(x, 1)

        if use_pool:
            x = F.relu(self.fc1_pool(x))
            x = self.dropout(x)
            x = self.fc2_pool(x)
        else:
            x = F.relu(self.fc1_nopool(x))
            x = self.dropout(x)
            x = self.fc2_nopool(x)

        return x

# Train Epoch

In [15]:
def train_epoch(loader, model, criterion, optimizer, use_pool_flag):
    model.train()
    total_loss = 0.0

    for imgs, labels in loader:
        imgs = imgs.to(device)
        labels = labels.squeeze().long().to(device)

        optimizer.zero_grad()
        logits = model(imgs, use_pool=use_pool_flag)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(loader)


# Evaluate

In [16]:
def evaluate_both(loader, model, use_pool_flag):
    """
    Devolve (acc_sem_softmax, acc_com_softmax).
    """
    model.eval()
    preds_no_soft, preds_soft, targets = [], [], []

    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(device)
            labels = labels.squeeze().long().to(device)

            logits = model(imgs, use_pool=use_pool_flag)
            probs = F.softmax(logits, dim=1)

            preds_no_soft += logits.argmax(dim=1).cpu().tolist()
            preds_soft    += probs.argmax(dim=1).cpu().tolist()
            targets       += labels.cpu().tolist()

    acc_no_soft = accuracy_score(targets, preds_no_soft)
    acc_soft    = accuracy_score(targets, preds_soft)
    return acc_no_soft, acc_soft

In [17]:
def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Inicialização do modelo

In [18]:
batch_size = 64
train_dataset = BloodMNIST(split="train", transform=transform, download=True, size=28)
val_dataset   = BloodMNIST(split="val",   transform=transform, download=True, size=28)
test_dataset  = BloodMNIST(split="test",  transform=transform, download=True, size=28)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False)


epochs = 200
lr = 1e-3
use_pooling_flags = [False, True]

results = {
    "train_loss": {},
    "val_acc_no_soft": {},
    "val_acc_soft": {},
    "test_acc_no_soft": {},
    "test_acc_soft": {},
    "time_sec": {},
    "params": {}
}

global_start = time.time()

for pool_flag in use_pooling_flags:
    tag = "with_pooling" if pool_flag else "no_pooling"
    print("\n" + "="*70)
    print("EXPERIMENT:", tag)
    print("="*70)

    model = CNN(num_classes=n_classes, use_pool=pool_flag).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    print("Trainable params:", count_params(model))

    train_losses = []
    val_no_soft_list = []
    val_soft_list = []

    exp_start = time.time()

    for epoch in range(epochs):
        t0 = time.time()

        tr_loss = train_epoch(train_loader, model, criterion, optimizer, use_pool_flag=pool_flag)
        val_no_soft, val_soft = evaluate_both(val_loader, model, use_pool_flag=pool_flag)

        train_losses.append(tr_loss)
        val_no_soft_list.append(val_no_soft)
        val_soft_list.append(val_soft)

        dt = time.time() - t0

        print(
            f"Epoch {epoch+1:03d}/{epochs} | "
            f"Loss: {tr_loss:.4f} | "
            f"ValAcc no-soft: {val_no_soft:.4f} | "
            f"ValAcc soft: {val_soft:.4f} | "
            f"Time: {dt:.2f}s"
        )

    test_no_soft, test_soft = evaluate_both(test_loader, model, use_pool_flag=pool_flag)
    exp_time = time.time() - exp_start

    print(f"TestAcc no-soft: {test_no_soft:.4f} | TestAcc soft: {test_soft:.4f}")
    print(f"Total time ({tag}): {exp_time/60:.2f} min ({exp_time:.2f} s)")

    # Save model uniquely (avoid overwrite)
    torch.save(model.state_dict(), f"bloodmnist_cnn_{tag}.pth")
    print(f"Saved: bloodmnist_cnn_{tag}.pth")

    # Store results
    results["train_loss"][tag] = train_losses
    results["val_acc_no_soft"][tag] = val_no_soft_list
    results["val_acc_soft"][tag] = val_soft_list
    results["test_acc_no_soft"][tag] = test_no_soft
    results["test_acc_soft"][tag] = test_soft
    results["time_sec"][tag] = exp_time
    results["params"][tag] = count_params(model)

    # Per-model plots
    ep = range(len(train_losses))

    plot(ep, train_losses, ylabel="Loss",
         name=f"CNN-training-loss_{tag}_lr{lr}")

    plot(ep, val_no_soft_list, ylabel="Accuracy",
         name=f"CNN-validation-accuracy_{tag}_no-softmax_lr{lr}")

    plot(ep, val_soft_list, ylabel="Accuracy",
         name=f"CNN-validation-accuracy_{tag}_softmax_lr{lr}")

# Comparison plots (same figure)
ep = range(epochs)

plot_compare(
    ep,
    {
        "with pooling": results["train_loss"]["with_pooling"],
        "no pooling": results["train_loss"]["no_pooling"],
    },
    ylabel="Loss",
    title="Training Loss: pooling vs no pooling",
    filename=f"COMPARE_training_loss_pool_vs_nopool_lr{lr}"
)

plot_compare(
    ep,
    {
        "pool | no-soft": results["val_acc_no_soft"]["with_pooling"],
        "pool | soft": results["val_acc_soft"]["with_pooling"],
        "no-pool | no-soft": results["val_acc_no_soft"]["no_pooling"],
        "no-pool | soft": results["val_acc_soft"]["no_pooling"],
    },
    ylabel="Accuracy",
    title="Validation Accuracy: pooling vs no pooling (soft vs no-soft)",
    filename=f"COMPARE_val_accuracy_pool_nopool_soft_nosoft_lr{lr}"
)

global_time = time.time() - global_start
print("\n" + "="*70)
print("DONE")
print("Total run time:", f"{global_time/60:.2f} min ({global_time:.2f} s)")
print("Params:", results["params"])
print("Test:", {
    k: (results["test_acc_no_soft"][k], results["test_acc_soft"][k])
    for k in results["test_acc_no_soft"].keys()
})


EXPERIMENT: no_pooling
Trainable params: 26082896
Epoch 001/200 | Loss: 0.8928 | ValAcc no-soft: 0.8452 | ValAcc soft: 0.8452 | Time: 14.11s
Epoch 002/200 | Loss: 0.4624 | ValAcc no-soft: 0.8715 | ValAcc soft: 0.8715 | Time: 10.17s
Epoch 003/200 | Loss: 0.3651 | ValAcc no-soft: 0.8931 | ValAcc soft: 0.8931 | Time: 11.29s
Epoch 004/200 | Loss: 0.3194 | ValAcc no-soft: 0.8984 | ValAcc soft: 0.8984 | Time: 11.60s


KeyboardInterrupt: 