In [None]:
!pip install datasets
!pip install timm

In [5]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import io
import os
from sklearn.preprocessing import MultiLabelBinarizer


transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


csv_path = 'archive/Data_Entry_2017.csv'
df = pd.read_csv(csv_path)

# Separate the labels for each disease and generate single-label samples
expanded_rows = []
for _, row in df.iterrows():
    labels = row['Finding Labels'].split('|')
    for label in labels:
        new_row = row.copy()
        new_row['Finding Labels'] = label
        expanded_rows.append(new_row)

# Create a new DataFrame where each example has only one label
expanded_df = pd.DataFrame(expanded_rows)

all_labels = expanded_df['Finding Labels'].unique()
label_map = {label: idx for idx, label in enumerate(all_labels)}

# Convert each label to integer encoding
expanded_df['Encoded Labels'] = expanded_df['Finding Labels'].map(label_map)

class ChestXRayDataset(Dataset):
    def __init__(self, dataframe, img_dir, transform=None):
        self.data = dataframe
        self.img_dir = img_dir
        self.transform = transform
        
        self.image_paths = {}
        for root, _, files in os.walk(self.img_dir):
            for file in files:
                if file.endswith('.png'):
                    self.image_paths[file] = os.path.join(root, file)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        img_name = self.data.iloc[idx]['Image Index']
        
        img_path = self.image_paths.get(img_name)
        if img_path is None:
            raise FileNotFoundError(f"Image {img_name} not found in specified directories.")
        
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
        
        label = self.data.iloc[idx]['Encoded Labels']
        
        return image, label

img_dir = 'archive/images_all/'
dataset = ChestXRayDataset(dataframe=expanded_df, img_dir=img_dir, transform=transform)
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

total_samples = len(data_loader.dataset)
print(f"Total samples in the dataset: {total_samples}")

# Get label distribution statistics
label_distribution = expanded_df['Finding Labels'].value_counts().to_dict()
print(f"Label distribution: {label_distribution}")

Total samples in the dataset: 141537
Label distribution: {'No Finding': 60361, 'Infiltration': 19894, 'Effusion': 13317, 'Atelectasis': 11559, 'Nodule': 6331, 'Mass': 5782, 'Pneumothorax': 5302, 'Consolidation': 4667, 'Pleural_Thickening': 3385, 'Cardiomegaly': 2776, 'Emphysema': 2516, 'Edema': 2303, 'Fibrosis': 1686, 'Pneumonia': 1431, 'Hernia': 227}


In [6]:
from torch.utils.data import DataLoader
import torch
from timm import create_model

# Load the pre-trained ConvNext model and remove the classification layer
def load_convnext_backbone():
    model = create_model("convnext_base", pretrained=True)
    model.reset_classifier(0)
    return model

device = "cuda" if torch.cuda.is_available() else "cpu"
backbone = load_convnext_backbone().to(device)
backbone.eval()

ConvNeXt(
  (stem): Sequential(
    (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
    (1): LayerNorm2d((128,), eps=1e-06, elementwise_affine=True)
  )
  (stages): Sequential(
    (0): ConvNeXtStage(
      (downsample): Identity()
      (blocks): Sequential(
        (0): ConvNeXtBlock(
          (conv_dw): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
          (norm): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=128, out_features=512, bias=True)
            (act): GELU()
            (drop1): Dropout(p=0.0, inplace=False)
            (norm): Identity()
            (fc2): Linear(in_features=512, out_features=128, bias=True)
            (drop2): Dropout(p=0.0, inplace=False)
          )
          (shortcut): Identity()
          (drop_path): Identity()
        )
        (1): ConvNeXtBlock(
          (conv_dw): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), g

In [7]:
from tqdm import tqdm

def extract_features(backbone, dataloader, device):
    features, labels = [], []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Extracting features"):
            inputs, batch_labels = batch

            inputs = inputs.to(device)
            batch_labels = batch_labels.to(device)

            outputs = backbone(inputs)
            features.extend(outputs.cpu().numpy())
            labels.extend(batch_labels.cpu().numpy())
    
    return features, labels

# Extracting features using ConvNext
train_features, train_labels = extract_features(backbone, data_loader, device)

print(f"Extracted feature shape: {np.array(train_features).shape}")
print(f"Extracted labels shape: {np.array(train_labels).shape}")


Extracting features: 100%|█████████████████████████████████████████████████████████| 4424/4424 [51:10<00:00,  1.44it/s]


Extracted feature shape: (141537, 1024)
Extracted labels shape: (141537,)


In [8]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split

In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

In [19]:
train_features_np = np.array(train_features)
train_labels_np = np.array(train_labels)


input_dim = train_features_np.shape[1]
print(input_dim)

num_classes = np.max(train_labels_np) + 1
print(f"Number of classes (num_classes): {num_classes}")

1024
Number of classes (num_classes): 15


In [20]:
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

In [22]:
# Define the fully connected layer model
class FullyConnectedModel(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(FullyConnectedModel, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
            #nn.Softmax(dim=1)
        )

    def forward(self, x):
        return self.fc(x)

# Split training set and testing set
train_features_np, val_features_np, train_labels_np, val_labels_np = train_test_split(train_features_np, train_labels_np, test_size=0.1, random_state=42)

train_features_tensor = torch.tensor(train_features_np, dtype=torch.float32)
train_labels_tensor = torch.tensor(train_labels_np, dtype=torch.long)
val_features_tensor = torch.tensor(val_features_np, dtype=torch.float32)
val_labels_tensor = torch.tensor(val_labels_np, dtype=torch.long)

train_dataset = TensorDataset(train_features_tensor, train_labels_tensor)
val_dataset = TensorDataset(val_features_tensor, val_labels_tensor)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

input_dim = train_features_np.shape[1]
model = FullyConnectedModel(input_dim=input_dim, num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


epochs = 100
for epoch in range(epochs):
    model.train()
    running_loss = 0.0

    train_loader_tqdm = tqdm(train_loader, desc="Training", leave=False)
    for inputs, labels in train_loader_tqdm:
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = model(inputs)

        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss / len(train_loader):.4f}")

print("Training Completed！")

model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for inputs, labels in val_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)

        _, preds = torch.max(outputs, 1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

y_pred = np.array(all_preds)
y_true = np.array(all_labels)

accuracy = accuracy_score(y_true, y_pred)
print(f'Accuracy: {accuracy:.4f}')



                                                                                                                       

Epoch [1/100], Loss: 1.8335


                                                                                                                       

Epoch [2/100], Loss: 1.7859


                                                                                                                       

Epoch [3/100], Loss: 1.7693


                                                                                                                       

Epoch [4/100], Loss: 1.7602


                                                                                                                       

Epoch [5/100], Loss: 1.7527


                                                                                                                       

Epoch [6/100], Loss: 1.7484


                                                                                                                       

Epoch [7/100], Loss: 1.7431


                                                                                                                       

Epoch [8/100], Loss: 1.7391


                                                                                                                       

Epoch [9/100], Loss: 1.7355


                                                                                                                       

Epoch [10/100], Loss: 1.7329


                                                                                                                       

Epoch [11/100], Loss: 1.7309


                                                                                                                       

Epoch [12/100], Loss: 1.7273


                                                                                                                       

Epoch [13/100], Loss: 1.7248


                                                                                                                       

Epoch [14/100], Loss: 1.7241


                                                                                                                       

Epoch [15/100], Loss: 1.7207


                                                                                                                       

Epoch [16/100], Loss: 1.7203


                                                                                                                       

Epoch [17/100], Loss: 1.7192


                                                                                                                       

Epoch [18/100], Loss: 1.7167


                                                                                                                       

Epoch [19/100], Loss: 1.7156


                                                                                                                       

Epoch [20/100], Loss: 1.7140


                                                                                                                       

Epoch [21/100], Loss: 1.7118


                                                                                                                       

Epoch [22/100], Loss: 1.7115


                                                                                                                       

Epoch [23/100], Loss: 1.7094


                                                                                                                       

Epoch [24/100], Loss: 1.7080


                                                                                                                       

Epoch [25/100], Loss: 1.7080


                                                                                                                       

Epoch [26/100], Loss: 1.7070


                                                                                                                       

Epoch [27/100], Loss: 1.7050


                                                                                                                       

Epoch [28/100], Loss: 1.7045


                                                                                                                       

Epoch [29/100], Loss: 1.7031


                                                                                                                       

Epoch [30/100], Loss: 1.7027


                                                                                                                       

Epoch [31/100], Loss: 1.7019


                                                                                                                       

Epoch [32/100], Loss: 1.7018


                                                                                                                       

Epoch [33/100], Loss: 1.7001


                                                                                                                       

Epoch [34/100], Loss: 1.6996


                                                                                                                       

Epoch [35/100], Loss: 1.6976


                                                                                                                       

Epoch [36/100], Loss: 1.6983


                                                                                                                       

Epoch [37/100], Loss: 1.6974


                                                                                                                       

Epoch [38/100], Loss: 1.6961


                                                                                                                       

Epoch [39/100], Loss: 1.6961


                                                                                                                       

Epoch [40/100], Loss: 1.6944


                                                                                                                       

Epoch [41/100], Loss: 1.6944


                                                                                                                       

Epoch [42/100], Loss: 1.6940


                                                                                                                       

Epoch [43/100], Loss: 1.6928


                                                                                                                       

Epoch [44/100], Loss: 1.6921


                                                                                                                       

Epoch [45/100], Loss: 1.6928


                                                                                                                       

Epoch [46/100], Loss: 1.6911


                                                                                                                       

Epoch [47/100], Loss: 1.6907


                                                                                                                       

Epoch [48/100], Loss: 1.6906


                                                                                                                       

Epoch [49/100], Loss: 1.6901


                                                                                                                       

Epoch [50/100], Loss: 1.6886


                                                                                                                       

Epoch [51/100], Loss: 1.6879


                                                                                                                       

Epoch [52/100], Loss: 1.6883


                                                                                                                       

Epoch [53/100], Loss: 1.6876


                                                                                                                       

Epoch [54/100], Loss: 1.6878


                                                                                                                       

Epoch [55/100], Loss: 1.6865


                                                                                                                       

Epoch [56/100], Loss: 1.6860


                                                                                                                       

Epoch [57/100], Loss: 1.6851


                                                                                                                       

Epoch [58/100], Loss: 1.6848


                                                                                                                       

Epoch [59/100], Loss: 1.6849


                                                                                                                       

Epoch [60/100], Loss: 1.6837


                                                                                                                       

Epoch [61/100], Loss: 1.6830


                                                                                                                       

Epoch [62/100], Loss: 1.6829


                                                                                                                       

Epoch [63/100], Loss: 1.6834


                                                                                                                       

Epoch [64/100], Loss: 1.6816


                                                                                                                       

Epoch [65/100], Loss: 1.6835


                                                                                                                       

Epoch [66/100], Loss: 1.6818


                                                                                                                       

Epoch [67/100], Loss: 1.6815


                                                                                                                       

Epoch [68/100], Loss: 1.6814


                                                                                                                       

Epoch [69/100], Loss: 1.6812


                                                                                                                       

Epoch [70/100], Loss: 1.6806


                                                                                                                       

Epoch [71/100], Loss: 1.6793


                                                                                                                       

Epoch [72/100], Loss: 1.6792


                                                                                                                       

Epoch [73/100], Loss: 1.6799


                                                                                                                       

Epoch [74/100], Loss: 1.6790


                                                                                                                       

Epoch [75/100], Loss: 1.6785


                                                                                                                       

Epoch [76/100], Loss: 1.6774


                                                                                                                       

Epoch [77/100], Loss: 1.6771


                                                                                                                       

Epoch [78/100], Loss: 1.6775


                                                                                                                       

Epoch [79/100], Loss: 1.6762


                                                                                                                       

Epoch [80/100], Loss: 1.6761


                                                                                                                       

Epoch [81/100], Loss: 1.6762


                                                                                                                       

Epoch [82/100], Loss: 1.6754


                                                                                                                       

Epoch [83/100], Loss: 1.6753


                                                                                                                       

Epoch [84/100], Loss: 1.6756


                                                                                                                       

Epoch [85/100], Loss: 1.6748


                                                                                                                       

Epoch [86/100], Loss: 1.6736


                                                                                                                       

Epoch [87/100], Loss: 1.6733


                                                                                                                       

Epoch [88/100], Loss: 1.6742


                                                                                                                       

Epoch [89/100], Loss: 1.6737


                                                                                                                       

Epoch [90/100], Loss: 1.6733


                                                                                                                       

Epoch [91/100], Loss: 1.6723


                                                                                                                       

Epoch [92/100], Loss: 1.6735


                                                                                                                       

Epoch [93/100], Loss: 1.6732


                                                                                                                       

Epoch [94/100], Loss: 1.6729


                                                                                                                       

Epoch [95/100], Loss: 1.6723


                                                                                                                       

Epoch [96/100], Loss: 1.6720


                                                                                                                       

Epoch [97/100], Loss: 1.6706


                                                                                                                       

Epoch [98/100], Loss: 1.6714


                                                                                                                       

Epoch [99/100], Loss: 1.6708


                                                                                                                       

Epoch [100/100], Loss: 1.6722
Training Completed！
Accuracy: 0.4357
