In [3]:
import inspect

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from datasets import load_dataset
from sklearn.metrics import classification_report
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from transformers import TrainingArguments, Trainer

In [4]:
# NOTE(review): the recorded run of this cell failed with DatasetNotFoundError — no Hub
# dataset named "nih_chest_xray" was reachable. Point DATASET_ID at a real Hub id or a
# local copy. Later cells read an "image" column and a "|"-separated "Finding Labels"
# string column from the "train" and "test" splits.
DATASET_ID = "nih_chest_xray"

nih = load_dataset(DATASET_ID)
train_ds, test_ds = nih["train"], nih["test"]

DatasetNotFoundError: Dataset 'nih_chest_xray' doesn't exist on the Hub or cannot be accessed.

In [None]:
# Fixed label vocabulary; a label's position here is its column in the multi-hot vector.
diseases = [
    "Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Effusion",
    "Emphysema", "Fibrosis", "Hernia", "Infiltration", "Mass", "No Finding",
    "Nodule", "Pleural_Thickening", "Pneumonia", "Pneumothorax",
]


def encode_labels(example):
    """Add a multi-hot ``"labels"`` list derived from ``example["Finding Labels"]``.

    The finding string is split on ``"|"``. Each entry of ``diseases`` maps to 1 when it
    is one of the resulting tokens and to 0 otherwise. Tokens not in ``diseases`` are
    ignored. The example is updated in place and also returned, as ``Dataset.map`` expects.
    """
    present = set(example["Finding Labels"].split("|"))
    example["labels"] = [int(name in present) for name in diseases]
    return example

# Attach the multi-hot "labels" column to both splits (train first, then test).
train_ds, test_ds = train_ds.map(encode_labels), test_ds.map(encode_labels)


In [None]:
CHECKPOINT = "akhaliq/chexnet"

extractor = AutoFeatureExtractor.from_pretrained(CHECKPOINT)

# Multi-label head over `diseases`: one logit per label, trained with a sigmoid/BCE loss.
model = AutoModelForImageClassification.from_pretrained(
    CHECKPOINT,
    num_labels=len(diseases),
    id2label=dict(enumerate(diseases)),
    label2id={name: idx for idx, name in enumerate(diseases)},
    problem_type="multi_label_classification",
    # NOTE(review): the checkpoint's head size is not visible here. The original CheXNet
    # presumably predicts 14 findings, whereas `diseases` has 15 (it adds "No Finding").
    # Without this flag, from_pretrained raises on a classifier shape mismatch. With it,
    # a mismatched head is freshly (randomly) initialised, so any "before training" metrics
    # would reflect an untrained head. The flag has no effect when sizes already match.
    ignore_mismatched_sizes=True,
)


In [None]:
# Freeze every parameter under `model.densenet121.features`; parameters outside that
# submodule keep requires_grad and are the only ones the optimizer will update.
# NOTE(review): assumes the loaded model exposes a `.densenet121.features` submodule —
# confirm for this checkpoint.
model.densenet121.features.requires_grad_(False)


In [None]:
# Per-label positive weights for BCEWithLogitsLoss: 2.0 for Edema, 1.0 for every other label.
# pos_weight scales only the loss on *positive* targets, so Edema misses (FNs) cost more;
# this usually trades some extra Edema FPs for fewer FNs.
_edema_pos = diseases.index("Edema")
pos_weight = torch.tensor([2.0 if i == _edema_pos else 1.0 for i in range(len(diseases))])

loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)


In [None]:
def preprocess(batch):
    """Turn a batch of raw examples into model inputs.

    Args:
        batch: a batched ``map`` input with images under ``"image"`` (converted to RGB
            via ``.convert``, i.e. expected to be PIL images) and the multi-hot lists
            from ``encode_labels`` under ``"labels"``.

    Returns:
        A new dict holding only ``"pixel_values"`` (from ``extractor``) and float32
        ``"labels"`` (BCE targets must be floating point). A fresh dict is returned
        instead of the mutated input batch so that ``map(..., remove_columns=...)``
        drops the raw columns while keeping these two.
    """
    images = [img.convert("RGB") for img in batch["image"]]
    pixel_values = extractor(images, return_tensors="pt")["pixel_values"]
    labels = torch.tensor(batch["labels"], dtype=torch.float32)
    return {"pixel_values": pixel_values, "labels": labels}


def _prepare_split(ds, batch_size=16):
    """Preprocess one split and expose its columns as torch tensors."""
    # Map over the *unformatted* dataset. `map` passes batches in the dataset's current
    # format, and under "torch" the images arrive as tensors, which have no `.convert`.
    # So the torch format is applied only after preprocessing.
    processed = ds.map(
        preprocess,
        batched=True,
        batch_size=batch_size,
        # Keep only pixel_values/labels: avoids re-storing raw images in the cache and
        # lets the default DataLoader collate handle every remaining column.
        remove_columns=ds.column_names,
    )
    # NOTE(review): materialising pixel_values for every image is disk-heavy on a
    # dataset of this size. An on-the-fly `set_transform` is an alternative, but it
    # needs `remove_unused_columns=False` in TrainingArguments.
    return processed.with_format("torch")


train_ds = _prepare_split(train_ds)
test_ds = _prepare_split(test_ds)


In [None]:
class WeightedTrainer(Trainer):
    """Trainer that computes its loss with the notebook-level ``loss_fn``
    (BCEWithLogitsLoss with per-label ``pos_weight``) instead of the model's own loss."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the weighted BCE loss (and the model outputs if requested).

        Args:
            model: the model being trained/evaluated.
            inputs: batch dict; ``"labels"`` is popped so the model does not compute
                its own loss, and the rest is forwarded to ``model``.
            return_outputs: when True, return ``(loss, outputs)`` instead of ``loss``.
            **kwargs: extra arguments newer Trainer versions pass (e.g.
                ``num_items_in_batch``). They are accepted and ignored here; without them
                the override raises TypeError on those versions.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # loss_fn's pos_weight buffer was created on CPU, while Trainer moves the model
        # and batch to its device. Module.to moves the buffer in place (a no-op once it
        # is already there).
        criterion = loss_fn.to(logits.device)
        loss = criterion(logits, labels.to(logits.dtype))
        return (loss, outputs) if return_outputs else loss


In [None]:
def predict(model, dataset, batch_size=32):
    """Run ``model`` over ``dataset`` and collect sigmoid probabilities and labels.

    Args:
        model: a torch module called positionally with a ``pixel_values`` batch and
            returning an object with a ``.logits`` attribute.
        dataset: an indexable dataset whose items, after default ``DataLoader``
            collation, provide ``"pixel_values"`` and ``"labels"`` tensors.
        batch_size: evaluation batch size (32 by default, as before).

    Returns:
        ``(probs, labels)``: NumPy arrays stacked row-wise over all batches.

    Inputs are moved to the device of the model's parameters, so this works both
    before training (CPU) and after Trainer has moved the model to an accelerator.
    The model's train/eval mode is restored on exit.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

    was_training = model.training
    model.eval()
    prob_chunks, label_chunks = [], []
    try:
        with torch.no_grad():
            for batch in loader:
                logits = model(batch["pixel_values"].to(device)).logits
                # .float() so half-precision logits can be converted to NumPy.
                prob_chunks.append(torch.sigmoid(logits.float()).cpu().numpy())
                label_chunks.append(batch["labels"].cpu().numpy())
    finally:
        model.train(was_training)
    return np.vstack(prob_chunks), np.vstack(label_chunks)

print("Running baseline eval (before training)...")
pred_before, true_before = predict(model, test_ds)

# Edema column; a probability > 0.5 counts as a positive prediction.
edema_idx = diseases.index("Edema")
edema_true_before = true_before[:, edema_idx]
edema_prob_before = pred_before[:, edema_idx]

fp_before = np.sum((edema_true_before == 0) & (edema_prob_before > 0.5))
fn_before = np.sum((edema_true_before == 1) & (edema_prob_before <= 0.5))

print("Baseline Edema FPs:", fp_before)
print("Baseline Edema FNs:", fn_before)


In [None]:
# transformers renamed `evaluation_strategy` to `eval_strategy` and later dropped the
# old name, so pass whichever keyword the installed TrainingArguments accepts.
_eval_strategy_kw = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)

args = TrainingArguments(
    output_dir="chexnet_nih_finetuned",
    learning_rate=1e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=3,
    save_strategy="epoch",
    logging_steps=100,
    **{_eval_strategy_kw: "epoch"},
)

# NOTE(review): eval_dataset is the test split, which is also used for the before/after
# comparison. Fine for monitoring; avoid using it for checkpoint selection.
trainer = WeightedTrainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
)

trainer.train()


In [None]:
print("Running eval AFTER fine-tuning...")
pred_after, true_after = predict(model, test_ds)

# Same Edema column and 0.5 decision threshold as the baseline cell.
edema_true_after = true_after[:, edema_idx]
edema_prob_after = pred_after[:, edema_idx]

fp_after = np.sum((edema_true_after == 0) & (edema_prob_after > 0.5))
fn_after = np.sum((edema_true_after == 1) & (edema_prob_after <= 0.5))

print("After Fine-Tune Edema FPs:", fp_after)
print("After Fine-Tune Edema FNs:", fn_after)


In [None]:
# Side-by-side Edema error counts (pandas is already imported as `pd` in the first cell).
# "Change" is After - Before: negative values mean fewer errors after fine-tuning.
summary = pd.DataFrame(
    {
        "Metric": ["Edema False Positives", "Edema False Negatives"],
        "Before": [int(fp_before), int(fn_before)],
        "After": [int(fp_after), int(fn_after)],
    }
)
summary["Change"] = summary["After"] - summary["Before"]

summary
