In [1]:
import os
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from timm import create_model
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.metrics import accuracy_score, classification_report
from tqdm import tqdm


2025-05-29 19:37:51.867361: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1748547472.052812      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1748547472.110066      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
def jaccard_index(y_true, y_pred, smooth=100):
    """Smoothed Jaccard index (intersection over union) of two tensors.

    Both inputs are cast to float32 and flattened, so any matching shapes
    are accepted.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor.
        smooth: Constant added to numerator and denominator so the ratio
            stays defined when both inputs sum to zero.

    Returns:
        Scalar tensor ``(|A∩B| + smooth) / (|A| + |B| - |A∩B| + smooth)``.
    """
    truth_flat = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    pred_flat = tf.reshape(tf.cast(y_pred, tf.float32), [-1])

    overlap = tf.reduce_sum(truth_flat * pred_flat)
    union = tf.reduce_sum(truth_flat) + tf.reduce_sum(pred_flat) - overlap

    return (overlap + smooth) / (union + smooth)

def dice_coefficient(y_true, y_pred, smooth=1):
    """Smoothed Dice coefficient of two tensors.

    Both inputs are cast to float32 and flattened before the overlap is
    measured.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor.
        smooth: Constant added to numerator and denominator so the ratio
            stays defined when both inputs sum to zero.

    Returns:
        Scalar tensor ``(2 * |A∩B| + smooth) / (|A| + |B| + smooth)``.
    """
    truth_flat = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    pred_flat = tf.reshape(tf.cast(y_pred, tf.float32), [-1])

    overlap = tf.reduce_sum(truth_flat * pred_flat)
    numerator = 2. * overlap + smooth
    denominator = tf.reduce_sum(truth_flat) + tf.reduce_sum(pred_flat) + smooth

    return numerator / denominator


In [3]:
# The saved segmentation model references these metric functions by name,
# so they are handed to Keras for deserialization.
segmentation_metrics = {
    'dice_coefficient': dice_coefficient,
    'jaccard_index': jaccard_index,
}

segment_model = load_model(
    '/kaggle/input/segment_model/keras/default/1/best_model (1).keras',
    custom_objects=segmentation_metrics,
)


I0000 00:00:1748547507.778369      35 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13942 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5
I0000 00:00:1748547507.779052      35 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13942 MB memory:  -> device: 1, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5


In [4]:
import os
import pandas as pd
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
import cv2  # for resizing

class SegmentedChestXrayDataset(Dataset):
    """Chest X-ray dataset whose images are masked by a segmentation model.

    Each item is loaded from disk, run through ``segment_model`` to obtain a
    mask, multiplied by the binarized mask, and then passed through
    ``transform``.  Labels are binary: 1 when the CSV ``image_type`` column
    equals ``'tb'`` (case/whitespace-insensitive), 0 otherwise.

    Args:
        csv_path: CSV file with at least ``fname`` and ``image_type`` columns.
        image_folder: Directory that ``fname`` values are relative to.
        segment_model: Object exposing a Keras-style
            ``predict(batch, verbose=0)`` that returns one mask per input.
        transform: Optional callable applied to the masked PIL image.
        limit_normal: If given, keep at most this many label-0 rows
            (the first ones in file order).
        cache_masks: If True, keep each computed binary mask in memory so
            the segmentation model runs once per image instead of once per
            access.  Off by default because masks are stored at the original
            image resolution.

    NOTE(review): rows are taken as-is; if the CSV lists the same ``fname``
    more than once, those duplicates can end up on both sides of a later
    random split - confirm ``fname`` uniqueness for this dataset.
    """

    # (width, height) the image is resized to before segmentation.
    SEG_INPUT_SIZE = (256, 256)
    # Mask values strictly above this are treated as foreground.
    MASK_THRESHOLD = 0.5

    def __init__(self, csv_path, image_folder, segment_model, transform=None,
                 limit_normal=None, cache_masks=False):
        self.df = pd.read_csv(csv_path)
        self.image_folder = image_folder
        self.segment_model = segment_model
        self.transform = transform
        self.data = []
        self._mask_cache = {} if cache_masks else None

        # Rows without a file name or label cannot be used; previously a
        # missing image_type crashed on ``.strip()``.
        usable = self.df.dropna(subset=['fname', 'image_type'])

        normal_count = 0
        for fname, image_type in zip(usable['fname'], usable['image_type']):
            label = 1 if str(image_type).strip().lower() == 'tb' else 0
            if label == 0 and limit_normal is not None:
                if normal_count >= limit_normal:
                    continue
                normal_count += 1
            self.data.append((fname, label))

    def __len__(self):
        """Number of (file name, label) pairs kept after filtering."""
        return len(self.data)

    def _predict_mask(self, image_np):
        """Return a uint8 {0, 1} mask with the same height/width as ``image_np``."""
        height, width = image_np.shape[:2]

        seg_input = cv2.resize(image_np, self.SEG_INPUT_SIZE).astype(np.float32) / 255.0
        seg_input = np.expand_dims(seg_input, axis=0)  # add batch axis

        # verbose=0: the default prints a progress bar for every single
        # sample, which floods the notebook output during training.
        mask = self.segment_model.predict(seg_input, verbose=0)[0]
        if mask.ndim == 3:
            mask = mask[:, :, 0]  # drop the channel axis

        # cv2.resize takes (width, height).
        mask_resized = cv2.resize(mask, (width, height))
        return (mask_resized > self.MASK_THRESHOLD).astype(np.uint8)

    def __getitem__(self, idx):
        """Load, mask and transform the image at ``idx``.

        Returns:
            ``(image, label)`` where ``image`` is the output of ``transform``
            (or a PIL image when no transform is set) and ``label`` is 0/1.
        """
        img_name, label = self.data[idx]
        img_path = os.path.join(self.image_folder, img_name)

        # Context manager releases the file handle deterministically.
        with Image.open(img_path) as img:
            image_np = np.array(img.convert("RGB"))

        mask_binary = None
        if self._mask_cache is not None:
            mask_binary = self._mask_cache.get(idx)
        if mask_binary is None:
            mask_binary = self._predict_mask(image_np)
            if self._mask_cache is not None:
                self._mask_cache[idx] = mask_binary

        # Zero out every pixel outside the predicted mask.
        masked_image = image_np * np.expand_dims(mask_binary, axis=-1)
        masked_pil = Image.fromarray(masked_image.astype(np.uint8))

        if self.transform:
            masked_pil = self.transform(masked_pil)

        return masked_pil, label


In [5]:
import os
import torchvision.transforms as transforms

# Prefer the GPU when one is visible to PyTorch.
use_gpu = torch.cuda.is_available()
device = torch.device("cuda" if use_gpu else "cpu")

# Dataset location: one CSV index plus a folder of images.
root_dir = '/kaggle/input/tbx11k-simplified/tbx11k-simplified'
csv_path = os.path.join(root_dir, "data.csv")
image_folder = os.path.join(root_dir, "images")

# Augmentations run first; the final two steps bring the image to the
# 224x224 tensor the ViT expects.
vit_steps = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform_vit = transforms.Compose(vit_steps)

# `segment_model` must already be loaded by an earlier cell.
dataset = SegmentedChestXrayDataset(
    csv_path,
    image_folder,
    segment_model,
    transform=transform_vit,
    limit_normal=1000,
)


In [6]:
import copy

# Fixed seed so the 80/20 split (and every metric derived from it) is the
# same after Restart & Run All.
SPLIT_SEED = 42

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_set, test_set = torch.utils.data.random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Both subsets returned by random_split wrap the SAME dataset object, so the
# random flip/rotation/colour-jitter would also be applied to validation
# samples and make the reported metrics noisy.  A shallow copy shares the
# file list and segmentation model but carries a deterministic transform.
eval_dataset = copy.copy(dataset)
eval_dataset.transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
test_set = torch.utils.data.Subset(eval_dataset, test_set.indices)

train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
test_loader = DataLoader(test_set, batch_size=16, shuffle=False)


In [4]:
# ImageNet-pretrained ViT-B/16 with a fresh 2-way classification head,
# moved to the selected device (Module.to returns the module itself).
vit_model = create_model(
    'vit_base_patch16_224',
    pretrained=True,
    num_classes=2,
).to(device)

# Standard cross-entropy over the two logits; Adam over every parameter.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vit_model.parameters(), lr=1e-4)


NameError: name 'create_model' is not defined

In [2]:
# Sanity-check the forward pass on a dummy batch.  `input_tensor` was not
# defined by any earlier cell, so it is created here; shape is
# (batch=1, channels=3, 224, 224) to match the ViT input size.
input_tensor = torch.randn(1, 3, 224, 224, device=device)

vit_model.eval()          # the training cell switches back to train() itself
with torch.no_grad():     # no gradients needed for a shape check
    output = vit_model(input_tensor)

print(output.shape)  # Should be [1, num_classes], e.g., [1, 2]
print(output)        # Check if logits look reasonable (not all zeros)


NameError: name 'vit_model' is not defined

In [8]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import matplotlib.pyplot as plt


In [None]:
num_epochs = 100

# Per-epoch history, one entry appended after every completed epoch.
train_losses = []
val_losses = []
val_accuracies = []
val_recalls = []
val_precisions = []
val_f1s = []

for epoch in range(num_epochs):
    # --- Training ---
    vit_model.train()
    total_loss = 0.0

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = vit_model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    train_losses.append(avg_loss)

    # --- Validation ---
    vit_model.eval()
    y_true_val, y_pred_val = [], []
    val_loss = 0.0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = vit_model(images)
            val_loss += criterion(outputs, labels).item()
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            y_pred_val.extend(preds)
            y_true_val.extend(labels.cpu().numpy())

    # The validation loss used to be accumulated and then discarded; it is
    # now averaged per batch, stored and reported next to the train loss.
    avg_val_loss = val_loss / max(len(test_loader), 1)
    val_losses.append(avg_val_loss)

    # zero_division=0 yields the same 0.0 score as the default but without
    # an UndefinedMetricWarning on epochs where one class is never predicted.
    val_acc = accuracy_score(y_true_val, y_pred_val)
    precision = precision_score(y_true_val, y_pred_val, average='binary', zero_division=0)
    recall = recall_score(y_true_val, y_pred_val, average='binary', zero_division=0)  # sensitivity
    f1 = f1_score(y_true_val, y_pred_val, average='binary', zero_division=0)

    val_accuracies.append(val_acc)
    val_precisions.append(precision)
    val_recalls.append(recall)
    val_f1s.append(f1)

    print(f"Epoch {epoch+1}/{num_epochs} - "
          f"Train Loss: {avg_loss:.4f} - "
          f"Val Loss: {avg_val_loss:.4f} - "
          f"Val Acc: {val_acc:.4f} - "
          f"Precision: {precision:.4f} - "
          f"Recall (Sensitivity): {recall:.4f} - "
          f"F1 Score: {f1:.4f}")


In [None]:
# Plot only the epochs that actually completed.  Deriving the x-axis from
# num_epochs made plt.plot raise a length-mismatch error whenever training
# was interrupted before the last epoch.
histories = [train_losses, val_accuracies, val_recalls, val_precisions, val_f1s]
n_done = min(len(history) for history in histories)
epochs = range(1, n_done + 1)

fig, axes = plt.subplots(2, 2, figsize=(16, 10))

# One entry per panel: (title, y-label, [(values, line kwargs), ...]).
panels = [
    ('Loss vs Epoch', 'Loss', [
        (train_losses, {'label': 'Train Loss'}),
    ]),
    ('Accuracy vs Epoch', 'Accuracy', [
        (val_accuracies, {'label': 'Val Accuracy', 'color': 'green'}),
    ]),
    ('Sensitivity (Recall) vs Epoch', 'Recall', [
        (val_recalls, {'label': 'Recall (Sensitivity)', 'color': 'red'}),
    ]),
    ('Precision & F1 Score vs Epoch', 'Score', [
        (val_precisions, {'label': 'Precision', 'color': 'purple'}),
        (val_f1s, {'label': 'F1 Score', 'color': 'orange'}),
    ]),
]

# axes.flat walks the grid row by row, matching subplot positions 1..4.
for ax, (title, ylabel, series) in zip(axes.flat, panels):
    for values, line_style in series:
        ax.plot(epochs, values[:n_done], **line_style)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.grid(True)
    ax.legend()

fig.tight_layout()
plt.show()


In [10]:
# Persist only the learned weights (state dict), not the full module object.
vit_weights = vit_model.state_dict()
torch.save(vit_weights, 'vit_model.pth')
print("Model saved to vit_model.pth")

Model saved to vit_model.pth
