In [None]:
# Colab-only setup: mount Google Drive so the dataset folders under
# /content/drive are reachable from this runtime.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
import os
import warnings

import timm
import torch
import torch.nn as nn
from PIL import Image
from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# === CONFIG === #
# Seed torch's global RNG so that weight initialisation and DataLoader
# shuffling start from the same state on every "Restart & Run All".
# NOTE(review): GPU kernels may still introduce small run-to-run differences.
SEED = 42
torch.manual_seed(SEED)

BATCH_SIZE = 8
NUM_EPOCHS = 10
LEARNING_RATE = 1e-4
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# "datatset" is the folder's real name on Drive (see the `ls` cell in this
# notebook), so the spelling is kept as-is.
DATA_ROOT = "/content/drive/MyDrive/datatset"
TRAIN_DIR = f"{DATA_ROOT}/train"
TEST_DIR = f"{DATA_ROOT}/test"

# === TRANSFORMS === #
# Preprocessing pipeline: resize each image to 224x224, then convert the PIL
# image to a tensor.
# NOTE(review): there is no Normalize step here; pretrained weights are usually
# trained with a model-specific mean/std -- confirm whether one should be added.
_IMAGE_SIZE = (224, 224)
_preprocessing_steps = [
    transforms.Resize(_IMAGE_SIZE),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocessing_steps)


# === DATASET === #
class CustomDataset(Dataset):
    """Binary image dataset read from a ``<root>/<class>/<folder>/Segmentadas`` tree.

    Expected layout::

        root_dir/
            healthy/<any folder>/Segmentadas/*.jpg|*.jpeg|*.png   -> label 0
            sick/<any folder>/Segmentadas/*.jpg|*.jpeg|*.png      -> label 1

    Folders without a ``Segmentadas`` sub-directory and files with other
    extensions are ignored. Directory listings are sorted so that the sample
    order (and therefore index -> image mapping) is the same on every run and
    every filesystem.

    Attributes:
        images: list of image file paths, one per sample.
        labels: list of integer labels aligned with ``images``.
        transform: optional callable applied to the RGB PIL image.
    """

    # Position in this tuple is the integer label of the class.
    CLASS_NAMES = ("healthy", "sick")
    # Compared against the lower-cased file name, so matching is case-insensitive.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
    # Only images inside this sub-directory of each folder are used.
    SEGMENTED_SUBDIR = "Segmentadas"

    def __init__(self, root_dir, transform=None):
        """Index every image below ``root_dir``.

        Args:
            root_dir: directory containing the ``healthy`` / ``sick`` folders.
            transform: optional callable applied to each image in ``__getitem__``.

        A missing class directory is skipped, but a ``UserWarning`` is emitted
        so that a wrong path does not silently produce an empty dataset.
        """
        self.images = []
        self.labels = []
        self.transform = transform

        for label, label_name in enumerate(self.CLASS_NAMES):
            class_root = os.path.join(root_dir, label_name)
            if not os.path.isdir(class_root):
                warnings.warn(
                    f"Class directory not found, skipping: {class_root}",
                    stacklevel=2,
                )
                continue

            # sorted(): os.listdir returns entries in arbitrary order.
            for folder in sorted(os.listdir(class_root)):
                sub_path = os.path.join(class_root, folder, self.SEGMENTED_SUBDIR)
                if not os.path.isdir(sub_path):
                    continue

                for file_name in sorted(os.listdir(sub_path)):
                    if file_name.lower().endswith(self.IMAGE_EXTENSIONS):
                        self.images.append(os.path.join(sub_path, file_name))
                        self.labels.append(label)

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is opened, converted to RGB and, if a transform was given,
        passed through it. ``label`` is the integer class index.
        """
        image_path = self.images[idx]
        # The context manager closes the underlying file as soon as the pixel
        # data has been copied by convert(), instead of waiting for GC.
        with Image.open(image_path) as handle:
            image = handle.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image, self.labels[idx]



In [None]:
!ls /content/drive/MyDrive/datatset/

test  train


In [None]:
# === LOADERS === #
train_dataset = CustomDataset(TRAIN_DIR, transform=transform)
test_dataset = CustomDataset(TEST_DIR, transform=transform)

# Fail fast with a clear message: CustomDataset skips folders it cannot find,
# so a wrong path or unexpected layout would otherwise surface later as an
# unrelated-looking error during training or evaluation.
for _split_name, _split_dir, _split in (
    ("train", TRAIN_DIR, train_dataset),
    ("test", TEST_DIR, test_dataset),
):
    if len(_split) == 0:
        raise RuntimeError(
            f"No images found for the {_split_name} split under {_split_dir!r}; "
            "check the path and the <class>/<folder>/Segmentadas layout."
        )

# Only the training data is shuffled; the test order is kept fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# === MODEL WRAPPER === #
def get_model(model_name, num_classes=2, pretrained=True):
    """Create a timm model with a fresh classification head and move it to DEVICE.

    Args:
        model_name: timm architecture name passed to ``timm.create_model``.
        num_classes: size of the classification head (default 2, matching the
            healthy/sick labels used in this notebook).
        pretrained: forwarded to ``timm.create_model``; defaults to True.

    Returns:
        The model, already moved to ``DEVICE``.
    """
    model = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)
    return model.to(DEVICE)

# === TRAINING === #
def train_model(model, train_loader, num_epochs=None, learning_rate=None):
    """Train ``model`` with Adam and cross-entropy loss, reporting the loss per epoch.

    Args:
        model: module to optimise in place. Batches are moved to the device
            its parameters live on.
        train_loader: iterable yielding ``(inputs, labels)`` batches.
        num_epochs: number of passes over ``train_loader``; defaults to the
            notebook-level ``NUM_EPOCHS``.
        learning_rate: Adam learning rate; defaults to the notebook-level
            ``LEARNING_RATE``.

    Returns:
        List with the sample-weighted mean training loss of every epoch
        (``nan`` for an epoch in which the loader yielded no samples).
    """
    if num_epochs is None:
        num_epochs = NUM_EPOCHS
    if learning_rate is None:
        learning_rate = LEARNING_RATE

    # Follow the model rather than a global, so inputs always land on the
    # same device as the weights.
    device = next(model.parameters()).device

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    epoch_losses = []
    model.train()
    for epoch in range(num_epochs):
        loss_sum, samples_seen = 0.0, 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()

            # Weight by batch size so a smaller final batch does not skew the mean.
            batch_count = labels.size(0)
            loss_sum += loss.item() * batch_count
            samples_seen += batch_count

        mean_loss = loss_sum / samples_seen if samples_seen else float("nan")
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch + 1}/{num_epochs} - loss: {mean_loss:.4f}")

    return epoch_losses

# === EVALUATION === #
def evaluate_model(model, test_loader):
    """Evaluate a binary classifier and return ``(acc, prec, sens, spec)``.

    Label 1 is treated as the positive class. All four metrics are derived
    from the same confusion counts, so they stay defined (and consistent with
    each other) even when only one class occurs in the labels or predictions:
    any ratio whose denominator is zero is reported as 0.

    Args:
        model: module producing per-class scores of shape ``(batch, classes)``;
            it is switched to eval mode. Batches are moved to the device its
            parameters live on.
        test_loader: iterable yielding ``(inputs, labels)`` batches with
            labels in ``{0, 1}``.

    Returns:
        Tuple ``(accuracy, precision, sensitivity, specificity)``.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    device = next(model.parameters()).device
    model.eval()
    true_batches, pred_batches = [], []

    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs.to(device))
            pred_batches.append(torch.argmax(outputs, dim=1).cpu())
            true_batches.append(torch.as_tensor(labels).cpu())

    if not true_batches:
        raise ValueError("test_loader yielded no batches; cannot compute metrics")

    y_true = torch.cat(true_batches)
    y_pred = torch.cat(pred_batches)

    actual_pos = y_true == 1
    predicted_pos = y_pred == 1
    tp = int((actual_pos & predicted_pos).sum())
    fp = int((~actual_pos & predicted_pos).sum())
    fn = int((actual_pos & ~predicted_pos).sum())
    tn = int((~actual_pos & ~predicted_pos).sum())

    total = tp + fp + fn + tn
    acc = (tp + tn) / total if total > 0 else 0
    prec = tp / (tp + fp) if (tp + fp) > 0 else 0
    sens = tp / (tp + fn) if (tp + fn) > 0 else 0
    spec = tn / (tn + fp) if (tn + fp) > 0 else 0

    return acc, prec, sens, spec

# === MAIN WORKFLOW === #
# Both models are trained on the same loader and evaluated on the same test
# set, and every metric is printed for each of them so the two architectures
# can be compared directly.
print("\n🌀 Training Swin Transformer...")
swin = get_model("swin_tiny_patch4_window7_224")
train_model(swin, train_loader)
# `acc` / `prec` keep their original meaning (Swin results).
acc, prec, swin_sens, swin_spec = evaluate_model(swin, test_loader)
print(
    f"\n🎯 Swin - Accuracy: {acc:.4f}, Precision: {prec:.4f}, "
    f"Sensitivity: {swin_sens:.4f}, Specificity: {swin_spec:.4f}"
)

print("\n🌀 Training Vision Transformer...")
vit = get_model("vit_base_patch16_224")
train_model(vit, train_loader)
# `sens` / `spec` keep their original meaning (ViT results).
vit_acc, vit_prec, sens, spec = evaluate_model(vit, test_loader)
print(
    f"\n🎯 ViT - Accuracy: {vit_acc:.4f}, Precision: {vit_prec:.4f}, "
    f"Sensitivity: {sens:.4f}, Specificity: {spec:.4f}"
)



🌀 Training Swin Transformer...

🎯 Swin - Accuracy: 0.7281, Precision: 0.6527

🌀 Training Vision Transformer...


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]


🎯 ViT - Sensitivity: 0.9812, Specificity: 0.2062
