In [None]:
# pip install pyhealth

### STEP 1: load the chest X-ray data

In [None]:
from pyhealth.datasets import COVID19CXRDataset

# Directory holding the COVID-19 Radiography dataset.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
root = "/srv/local/data/COVID-19_Radiography_Dataset"
# Build the base dataset from `root`. refresh_cache=False presumably reuses a
# previously cached parse when one exists — TODO confirm against the PyHealth docs.
base_dataset = COVID19CXRDataset(root, refresh_cache=False)

# Last expression of the cell: report the dataset statistics provided by stat().
base_dataset.stat()

### STEP 2: set the task and process the data

In [None]:
from torchvision import transforms

# Turn the base dataset into a sample dataset using the dataset's default task.
sample_dataset = base_dataset.set_task()


def to_three_channels(image):
    """Return ``image`` as a 3-channel tensor.

    Args:
        image: image tensor whose first dimension is the channel axis
            (assumed channel-first, e.g. (C, H, W) — TODO confirm against the
            dataset's image loader).

    Returns:
        ``image`` unchanged when it already has 3 channels; otherwise its
        first channel repeated three times along the channel axis.
    """
    if image.shape[0] == 3:
        return image
    # Keep only the first channel before repeating, so an input with any other
    # channel count (1, 2, 4, ...) still comes out with exactly 3 channels.
    # For single-channel input this is identical to repeating the whole tensor.
    return image[:1].repeat(3, 1, 1)


# NOTE(review): nothing in this pipeline rescales pixel values; they are
# presumably already in [0, 1] when they arrive from the dataset — confirm.
transform = transforms.Compose([
    transforms.Lambda(to_three_channels),  # force a 3-channel image
    transforms.Resize((128, 128)),         # fixed 128x128 spatial size
])


def encode(sample):
    """Apply ``transform`` to the image stored under ``sample["path"]``.

    Args:
        sample: sample dict holding the image under the ``"path"`` key.

    Returns:
        The same dict object, with ``sample["path"]`` replaced by the
        transformed image (the dict is updated in place).
    """
    sample["path"] = transform(sample["path"])
    return sample


# Register `encode` as the dataset's transform.
sample_dataset.set_transform(encode)

In [None]:
from pyhealth.datasets import split_by_visit, get_dataloader

# Split the samples 80% / 10% / 10% into train / validation / test.
# A fixed seed keeps the split identical across kernel restarts, so that
# results from different runs of the notebook are comparable.
SPLIT_SEED = 42
train_dataset, val_dataset, test_dataset = split_by_visit(
    sample_dataset, [0.8, 0.1, 0.1], seed=SPLIT_SEED
)

# One loader per split; only the training loader is shuffled.
BATCH_SIZE = 256
train_dataloader = get_dataloader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Pull a single batch and print the shape of its first image as a sanity check
# on the transform configured for the dataset.
data = next(iter(train_dataloader))
print(data["path"][0].shape)

# len() of a dataset is its number of samples (not the number of batches in the
# corresponding loader), so the label says "dataset size".
print(
    "dataset size: train/val/test",
    len(train_dataset),
    len(val_dataset),
    len(test_dataset),
)

### STEP 3: define the GAN model

In [None]:
from pyhealth.models import GAN

# GAN constructor arguments, gathered in one place so they are easy to tune.
# input_channel=3 and input_size=128 mirror the 3-channel, 128x128 images
# produced by the preprocessing transform.
gan_kwargs = {
    "input_channel": 3,
    "input_size": 128,
    "hidden_dim": 256,
}
model = GAN(**gan_kwargs)

### STEP 4: train the GAN model adversarially

In [None]:
import torch
from tqdm import tqdm

# Binary cross-entropy on the discriminator's real-vs-fake output.
loss = torch.nn.BCELoss()

# Prefer GPU 4 (the device this notebook was written for), but fall back to the
# first GPU, or to the CPU, so the cell also runs on machines that do not have
# at least five GPUs.
PREFERRED_GPU = 4
if not torch.cuda.is_available():
    device = "cpu"
elif torch.cuda.device_count() > PREFERRED_GPU:
    device = f"cuda:{PREFERRED_GPU}"
else:
    device = "cuda:0"
model.to(device)

# One optimizer per adversary; the generator uses a 10x larger learning rate
# than the discriminator.
opt_G = torch.optim.AdamW(model.generator.parameters(), lr=1e-3)
opt_D = torch.optim.AdamW(model.discriminator.parameters(), lr=1e-4)

# Per-epoch losses; each entry is the SUM of the batch losses of that epoch.
curve_D, curve_G = [], []

for epoch in range(50):
    # The sampling cell further down switches the model to eval mode; restore
    # training mode so that re-running this cell afterwards still trains with
    # training-time layer behaviour.
    model.train()

    curve_G.append(0)
    curve_D.append(0)
    for batch in tqdm(train_dataloader):
        # batch["path"] is a sequence of image tensors; stack them into one
        # batch tensor and move it to the training device.
        real_imgs = torch.stack(batch["path"], dim=0).to(device)
        batch_size = real_imgs.shape[0]
        fake_imgs = model.generate_fake(batch_size, device)

        # Targets are created directly on the device (no CPU allocation plus
        # copy) and shared by the discriminator and generator steps.
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # --- train discriminator ---
        opt_D.zero_grad()
        real_loss = loss(model.discriminator(real_imgs), real_labels)
        # detach(): the discriminator step must not back-propagate into the
        # generator.
        fake_loss = loss(model.discriminator(fake_imgs.detach()), fake_labels)
        loss_D = (real_loss + fake_loss) / 2

        loss_D.backward()
        opt_D.step()

        # --- train generator ---
        # The generator is rewarded when the discriminator labels its fakes
        # as real, hence the "real" targets here.
        opt_G.zero_grad()
        loss_G = loss(model.discriminator(fake_imgs), real_labels)

        loss_G.backward()
        opt_G.step()

        curve_G[-1] += loss_G.item()
        curve_D[-1] += loss_D.item()

    print (f"epoch: {epoch} --- loss of G: {curve_G[-1]}, loss of D: {curve_D[-1]}")

### EXP 2: synthesize random images

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

# Sample one image from the generator in eval mode, with autograd disabled,
# and bring it back to the CPU for plotting.
model.eval()
with torch.no_grad():
    fake_imgs = model.generate_fake(1, device).detach().cpu()

# Display the first channel of the single generated image in grayscale.
fig, ax = plt.subplots()
ax.imshow(fake_imgs[0][0], cmap="gray")
plt.show()