In [1]:
import os
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from timm import create_model
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.metrics import accuracy_score, classification_report
from tqdm import tqdm


2025-04-20 23:03:43.243469: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745190223.449418      55 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745190223.510516      55 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
def jaccard_index(y_true, y_pred, smooth=100):
    y_true_f = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    y_pred_f = tf.reshape(tf.cast(y_pred, tf.float32), [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    total = tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) - intersection
    return (intersection + smooth) / (total + smooth)

def dice_coefficient(y_true, y_pred, smooth=1):
    y_true_f = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    y_pred_f = tf.reshape(tf.cast(y_pred, tf.float32), [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)


In [3]:
segment_model = load_model(
    '/kaggle/input/segment_model/keras/default/1/best_model (1).keras',
    custom_objects={'dice_coefficient': dice_coefficient, 'jaccard_index': jaccard_index}
)


I0000 00:00:1745190249.389918      55 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13942 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5
I0000 00:00:1745190249.390660      55 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13942 MB memory:  -> device: 1, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5


In [4]:
class SegmentedChestXrayDataset(Dataset):
    def __init__(self, root_dir, categories, segment_model, transform=None, limit_normal=1000):
        self.images = []
        self.labels = []
        self.transform = transform
        self.segment_model = segment_model

        for label_index, category in enumerate(categories):
            category_path = os.path.join(root_dir, category)
            image_files = [f for f in os.listdir(category_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
            if category.lower() == 'normal':
                image_files = image_files[:limit_normal]

            for file in tqdm(image_files, desc=f"Loading {category}"):
                try:
                    img_path = os.path.join(category_path, file)
                    img = cv2.imread(img_path)
                    img = cv2.resize(img, (256, 256)) / 255.0
                    input_img = np.expand_dims(img, axis=0)
                    mask = self.segment_model.predict(input_img, verbose=0)[0]
                    mask = (mask > 0.5).astype(np.float32)
                    segmented = img * mask
                    segmented_uint8 = (segmented * 255).astype(np.uint8)
                    if self.transform:
                        segmented_uint8 = self.transform(segmented_uint8)
                    self.images.append(segmented_uint8)
                    self.labels.append(label_index)
                except Exception as e:
                    print(f"Skipping {file}: {e}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
root_dir = '/kaggle/input/tuberculosis-tb-chest-xray-dataset/TB_Chest_Radiography_Database'
categories = ['Normal', 'Tuberculosis']

transform_vit = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

dataset = SegmentedChestXrayDataset(
    root_dir=root_dir,
    categories=categories,
    segment_model=segment_model,
    transform=transform_vit,
    limit_normal=1000
)


Loading Normal:   0%|          | 0/1000 [00:00<?, ?it/s]I0000 00:00:1745190253.217717      92 cuda_dnn.cc:529] Loaded cuDNN version 90300
Loading Normal: 100%|██████████| 1000/1000 [01:54<00:00,  8.77it/s]
Loading Tuberculosis: 100%|██████████| 700/700 [01:17<00:00,  8.99it/s]


In [6]:
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_set, test_set = torch.utils.data.random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
test_loader = DataLoader(test_set, batch_size=16, shuffle=False)


In [7]:
vit_model = create_model('vit_base_patch16_224', pretrained=True, num_classes=2)
vit_model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vit_model.parameters(), lr=1e-4)


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

In [8]:
num_epochs = 20

for epoch in range(num_epochs):
    vit_model.train()
    total_loss = 0

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = vit_model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)

    # --- Validation Accuracy ---
    vit_model.eval()
    y_true_val, y_pred_val = [], []
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            outputs = vit_model(images)
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            y_pred_val.extend(preds)
            y_true_val.extend(labels.numpy())

    val_acc = accuracy_score(y_true_val, y_pred_val)
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f} - Val Accuracy: {val_acc:.4f}")


Epoch 1/20: 100%|██████████| 85/85 [00:48<00:00,  1.76it/s]


Epoch 1/20 - Loss: 0.6719 - Val Accuracy: 0.7265


Epoch 2/20: 100%|██████████| 85/85 [00:51<00:00,  1.65it/s]


Epoch 2/20 - Loss: 0.4369 - Val Accuracy: 0.9206


Epoch 3/20: 100%|██████████| 85/85 [00:50<00:00,  1.68it/s]


Epoch 3/20 - Loss: 0.3339 - Val Accuracy: 0.8941


Epoch 4/20: 100%|██████████| 85/85 [00:50<00:00,  1.68it/s]


Epoch 4/20 - Loss: 0.2455 - Val Accuracy: 0.9029


Epoch 5/20: 100%|██████████| 85/85 [00:50<00:00,  1.69it/s]


Epoch 5/20 - Loss: 0.1921 - Val Accuracy: 0.8853


Epoch 6/20: 100%|██████████| 85/85 [00:50<00:00,  1.68it/s]


Epoch 6/20 - Loss: 0.1638 - Val Accuracy: 0.8441


Epoch 7/20: 100%|██████████| 85/85 [00:50<00:00,  1.68it/s]


Epoch 7/20 - Loss: 0.1061 - Val Accuracy: 0.8588


Epoch 8/20: 100%|██████████| 85/85 [00:50<00:00,  1.68it/s]


Epoch 8/20 - Loss: 0.1072 - Val Accuracy: 0.8912


Epoch 9/20: 100%|██████████| 85/85 [00:50<00:00,  1.67it/s]


Epoch 9/20 - Loss: 0.0920 - Val Accuracy: 0.9176


Epoch 10/20: 100%|██████████| 85/85 [00:50<00:00,  1.68it/s]


Epoch 10/20 - Loss: 0.1397 - Val Accuracy: 0.8647


Epoch 11/20: 100%|██████████| 85/85 [00:50<00:00,  1.68it/s]


Epoch 11/20 - Loss: 0.0933 - Val Accuracy: 0.9000


Epoch 12/20: 100%|██████████| 85/85 [00:50<00:00,  1.68it/s]


Epoch 12/20 - Loss: 0.0584 - Val Accuracy: 0.8471


Epoch 13/20: 100%|██████████| 85/85 [00:50<00:00,  1.68it/s]


Epoch 13/20 - Loss: 0.0928 - Val Accuracy: 0.9118


Epoch 14/20: 100%|██████████| 85/85 [00:50<00:00,  1.68it/s]


Epoch 14/20 - Loss: 0.0777 - Val Accuracy: 0.9059


Epoch 15/20: 100%|██████████| 85/85 [00:50<00:00,  1.68it/s]


Epoch 15/20 - Loss: 0.0357 - Val Accuracy: 0.8500


Epoch 16/20: 100%|██████████| 85/85 [00:50<00:00,  1.68it/s]


Epoch 16/20 - Loss: 0.0306 - Val Accuracy: 0.8941


Epoch 17/20: 100%|██████████| 85/85 [00:50<00:00,  1.68it/s]


Epoch 17/20 - Loss: 0.0276 - Val Accuracy: 0.9118


Epoch 18/20: 100%|██████████| 85/85 [00:50<00:00,  1.68it/s]


Epoch 18/20 - Loss: 0.0326 - Val Accuracy: 0.8853


Epoch 19/20: 100%|██████████| 85/85 [00:50<00:00,  1.68it/s]


Epoch 19/20 - Loss: 0.0610 - Val Accuracy: 0.8971


Epoch 20/20: 100%|██████████| 85/85 [00:50<00:00,  1.68it/s]


Epoch 20/20 - Loss: 0.0292 - Val Accuracy: 0.9235


In [9]:
vit_model.eval()
y_true, y_pred = [], []

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        outputs = vit_model(images)
        preds = torch.argmax(outputs, dim=1).cpu().numpy()
        y_pred.extend(preds)
        y_true.extend(labels.numpy())

acc = accuracy_score(y_true, y_pred)
print("\n--- Evaluation Metrics ---")
print("Accuracy:", acc)
print(classification_report(y_true, y_pred, target_names=categories))



--- Evaluation Metrics ---
Accuracy: 0.9235294117647059
              precision    recall  f1-score   support

      Normal       0.93      0.95      0.94       204
Tuberculosis       0.92      0.89      0.90       136

    accuracy                           0.92       340
   macro avg       0.92      0.92      0.92       340
weighted avg       0.92      0.92      0.92       340

