In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split

from sklearn.metrics import accuracy_score, classification_report
from tqdm import tqdm


# =========================
# DEVICE
# =========================

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)


# =========================
# DATA PATH
# =========================

data_dir = "/kaggle/input/datasets/yasserhessein/the-kvasir-dataset/kvasir-dataset-v2"  
# change if needed


# =========================
# TRANSFORMS
# =========================

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])


# =========================
# DATASET
# =========================

full_dataset = datasets.ImageFolder(data_dir, transform=transform)

class_names = full_dataset.classes
num_classes = len(class_names)

print("Classes:", class_names)


# =========================
# SPLIT DATASET
# =========================

train_size = int(0.7 * len(full_dataset))
val_size   = int(0.15 * len(full_dataset))
test_size  = len(full_dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, val_size, test_size]
)


train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=32, shuffle=False)


# =========================
# MODEL
# =========================

model = models.alexnet(pretrained=True)

num_ftrs = model.classifier[6].in_features
model.classifier[6] = nn.Linear(num_ftrs, num_classes)

model = model.to(device)


# =========================
# LOSS & OPTIMIZER
# =========================

criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=0.0001)


# =========================
# TRAIN FUNCTION
# =========================

def train(model, loader):

    model.train()

    preds_all = []
    targets_all = []
    running_loss = 0

    for x, y in tqdm(loader):

        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()

        outputs = model(x)

        loss = criterion(outputs, y)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        preds_all.append(outputs.argmax(1).cpu().numpy())
        targets_all.append(y.cpu().numpy())

    preds_all = np.concatenate(preds_all)
    targets_all = np.concatenate(targets_all)

    acc = accuracy_score(targets_all, preds_all)

    return running_loss / len(loader), acc


# =========================
# EVALUATE FUNCTION
# =========================

def evaluate(model, loader):

    model.eval()

    preds_all = []
    targets_all = []
    running_loss = 0

    with torch.no_grad():

        for x, y in tqdm(loader):

            x = x.to(device)
            y = y.to(device)

            outputs = model(x)

            loss = criterion(outputs, y)

            running_loss += loss.item()

            preds_all.append(outputs.argmax(1).cpu().numpy())
            targets_all.append(y.cpu().numpy())

    preds_all = np.concatenate(preds_all)
    targets_all = np.concatenate(targets_all)

    acc = accuracy_score(targets_all, preds_all)

    return running_loss / len(loader), acc


# =========================
# TRAINING LOOP
# =========================

EPOCHS = 15
best_val_acc = 0

for epoch in range(EPOCHS):

    print(f"\nEpoch {epoch+1}/{EPOCHS}")

    train_loss, train_acc = train(model, train_loader)
    val_loss, val_acc = evaluate(model, val_loader)

    print(f"Train Acc: {train_acc*100:.2f}%")
    print(f"Val   Acc: {val_acc*100:.2f}%")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_model.pth")
        print("Best model saved")


# =========================
# LOAD BEST MODEL
# =========================

model.load_state_dict(torch.load("best_model.pth"))


# =========================
# TESTING
# =========================

def test(model, loader):

    model.eval()

    preds_all = []
    targets_all = []

    with torch.no_grad():

        for x, y in tqdm(loader):

            x = x.to(device)
            y = y.to(device)

            outputs = model(x)

            preds_all.append(outputs.argmax(1).cpu().numpy())
            targets_all.append(y.cpu().numpy())

    preds_all = np.concatenate(preds_all)
    targets_all = np.concatenate(targets_all)

    acc = accuracy_score(targets_all, preds_all)

    print("\n===== TEST RESULTS =====")
    print(f"Test Accuracy: {acc*100:.2f}%\n")

    print(classification_report(targets_all, preds_all, target_names=class_names))


test(model, test_loader)

Device: cuda
Classes: ['dyed-lifted-polyps', 'dyed-resection-margins', 'esophagitis', 'normal-cecum', 'normal-pylorus', 'normal-z-line', 'polyps', 'ulcerative-colitis']





Epoch 1/15


100%|██████████| 175/175 [01:24<00:00,  2.08it/s]
100%|██████████| 38/38 [00:16<00:00,  2.26it/s]


Train Acc: 72.84%
Val   Acc: 81.08%
Best model saved

Epoch 2/15


100%|██████████| 175/175 [01:24<00:00,  2.08it/s]
100%|██████████| 38/38 [00:16<00:00,  2.27it/s]


Train Acc: 85.88%
Val   Acc: 86.42%
Best model saved

Epoch 3/15


100%|██████████| 175/175 [01:21<00:00,  2.14it/s]
100%|██████████| 38/38 [00:16<00:00,  2.31it/s]


Train Acc: 88.95%
Val   Acc: 83.50%

Epoch 4/15


100%|██████████| 175/175 [01:23<00:00,  2.10it/s]
100%|██████████| 38/38 [00:17<00:00,  2.23it/s]


Train Acc: 91.62%
Val   Acc: 87.00%
Best model saved

Epoch 5/15


100%|██████████| 175/175 [01:23<00:00,  2.10it/s]
100%|██████████| 38/38 [00:16<00:00,  2.27it/s]


Train Acc: 93.96%
Val   Acc: 87.33%
Best model saved

Epoch 6/15


100%|██████████| 175/175 [01:22<00:00,  2.12it/s]
100%|██████████| 38/38 [00:16<00:00,  2.24it/s]


Train Acc: 94.86%
Val   Acc: 86.83%

Epoch 7/15


100%|██████████| 175/175 [01:23<00:00,  2.09it/s]
100%|██████████| 38/38 [00:16<00:00,  2.24it/s]


Train Acc: 95.59%
Val   Acc: 88.42%
Best model saved

Epoch 8/15


100%|██████████| 175/175 [01:23<00:00,  2.10it/s]
100%|██████████| 38/38 [00:16<00:00,  2.26it/s]


Train Acc: 95.95%
Val   Acc: 88.92%
Best model saved

Epoch 9/15


100%|██████████| 175/175 [01:22<00:00,  2.12it/s]
100%|██████████| 38/38 [00:16<00:00,  2.33it/s]


Train Acc: 96.41%
Val   Acc: 86.92%

Epoch 10/15


100%|██████████| 175/175 [01:22<00:00,  2.12it/s]
100%|██████████| 38/38 [00:17<00:00,  2.23it/s]


Train Acc: 97.36%
Val   Acc: 87.58%

Epoch 11/15


100%|██████████| 175/175 [01:24<00:00,  2.08it/s]
100%|██████████| 38/38 [00:17<00:00,  2.22it/s]


Train Acc: 97.14%
Val   Acc: 85.42%

Epoch 12/15


100%|██████████| 175/175 [01:23<00:00,  2.08it/s]
100%|██████████| 38/38 [00:17<00:00,  2.21it/s]


Train Acc: 97.38%
Val   Acc: 87.00%

Epoch 13/15


100%|██████████| 175/175 [01:24<00:00,  2.07it/s]
100%|██████████| 38/38 [00:16<00:00,  2.24it/s]


Train Acc: 98.50%
Val   Acc: 87.33%

Epoch 14/15


100%|██████████| 175/175 [01:23<00:00,  2.09it/s]
100%|██████████| 38/38 [00:16<00:00,  2.24it/s]


Train Acc: 98.48%
Val   Acc: 85.92%

Epoch 15/15


100%|██████████| 175/175 [01:23<00:00,  2.10it/s]
100%|██████████| 38/38 [00:16<00:00,  2.30it/s]


Train Acc: 98.57%
Val   Acc: 87.25%


100%|██████████| 38/38 [00:17<00:00,  2.13it/s]


===== TEST RESULTS =====
Test Accuracy: 89.92%

                        precision    recall  f1-score   support

    dyed-lifted-polyps       0.92      0.91      0.91       153
dyed-resection-margins       0.91      0.93      0.92       148
           esophagitis       0.87      0.66      0.75       137
          normal-cecum       0.94      0.97      0.96       155
        normal-pylorus       0.99      0.99      0.99       172
         normal-z-line       0.74      0.90      0.81       143
                polyps       0.91      0.89      0.90       158
    ulcerative-colitis       0.91      0.91      0.91       134

              accuracy                           0.90      1200
             macro avg       0.90      0.89      0.89      1200
          weighted avg       0.90      0.90      0.90      1200




