# Binary Classification Using the ChestX-ray14 Dataset

## Step 0: Install PyHealth

In [None]:
# Install PyHealth and ipywidgets into the notebook's kernel environment.
# NOTE(review): ipywidgets is presumably needed for in-notebook progress
# widgets — confirm.
%pip install pyhealth ipywidgets

## Step 1: Load Dataset

In [None]:
from pyhealth.datasets import ChestXray14Dataset

# Build the ChestX-ray14 dataset object and print its summary statistics.
# NOTE(review): `download=True` presumably fetches the data when it is not
# already on disk, and `partial=True` presumably restricts loading to a subset
# to keep this demo small — confirm against the PyHealth docs.
dataset = ChestXray14Dataset(
    partial=True,
    download=True,
)
dataset.stats()

## Step 2: Define Task

In [None]:
from pyhealth.tasks import ChestXray14BinaryClassification

# Define a binary classification task for a single disease label, then apply
# it to the dataset to produce the sample set used by the rest of the notebook.
task = ChestXray14BinaryClassification(
    disease="infiltration",
)
samples = dataset.set_task(task)

In [None]:
from pyhealth.datasets import get_dataloader, split_by_sample

# Fixed seed so the 70/10/20 train/val/test split is identical across runs.
# Without it, every re-run draws a different split and the reported test
# metrics are not reproducible.
SPLIT_SEED = 42
BATCH_SIZE = 16

# NOTE(review): ChestX-ray14 typically has several images per patient, so a
# sample-level split may put the same patient in more than one split. If the
# task's samples carry patient IDs, consider `split_by_patient` instead —
# confirm.
train_dataset, val_dataset, test_dataset = split_by_sample(
    samples, [0.7, 0.1, 0.2], seed=SPLIT_SEED
)

# Only the training loader is shuffled. Validation and test order stays fixed.
train_loader = get_dataloader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = get_dataloader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = get_dataloader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

## Step 3: Define Model

In [None]:
from pyhealth.models import CNN

# Instantiate PyHealth's CNN model. It is configured from the task's sample
# dataset, presumably to infer input and output schema — confirm.
model = CNN(
    dataset=samples,
)

## Step 4: Train Model

In [None]:
from pyhealth.trainer import Trainer

# Wrap the model in PyHealth's Trainer and run a single training epoch,
# passing the validation loader so the Trainer can score it during training.
trainer = Trainer(model=model)
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=1,
)

## Step 5: Evaluate Model

In [None]:
# Score the trained model on the held-out test split. The trailing bare
# expression makes the notebook display the returned results.
test_results = trainer.evaluate(test_loader)
test_results