In [3]:
import zipfile
#from google.colab import drive

#drive.mount('/content/drive/')

In [1]:
!unzip "/Users/sarahcasale/Downloads/LesionAid-main/data/aug_balanced_imgs.zip"
#!unzip "/content/drive/MyDrive/DS4440_Project/preprocessed_images.zip"

'unzip' is not recognized as an internal or external command,
operable program or batch file.


In [1]:
# Imports
import pandas as pd
import torch
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights
from tqdm.auto import tqdm

from torch import device
from torch.cuda import is_available, device_count
from torchvision.utils import save_image
import numpy as np

In [2]:
# CONSTANTS

DEVICE = device("cuda" if is_available() else "cpu")  # (GPU if available)

nz: int = 128  # length of latent vector
ngf: int = 64  # depth of feature maps carried through the generator.
ndf: int = 32  # depth of feature maps propagated through the discriminator
nc: int = 3  # number of color channels (for color images = 3)
# niter: int = NUM_EPOCHS BATCH_SIZE  # 300
n_dnn: int = 32  # number of output features of the label's linear

# Image and Label Constants
IMAGE_SIZE: int = 64  # 128
LABEL_TO_CLASS: dict = {
    'N': 0,
    'D': 1,
    'G': 2,
    'C': 3,
    'A': 4,
    'H': 5,
    'M': 6,
    'O': 7
}
LABEL_TO_TITLE: dict = {
    'N': "Normal",
    'D': "Diabetes",
    'G': "Galucoma",
    'C': "Cataract",
    'A': "Age related Macular Degeneration",
    'H': "Hypertension",
    'M': "Pathological Myopia",
    'O': "Other diseases/abnormalities"
}
CLASS_TO_LABEL: dict = {_v: _k for _k, _v in LABEL_TO_CLASS.items()}
NUM_CLASSES: int = len(LABEL_TO_CLASS)

# ---------------------
# Training parameters
# ---------------------
BATCH_SIZE: int = 64 if IMAGE_SIZE == 64 else 32
NUM_EPOCHS: int = 200

# ---------------------
# Hyperparameters
# ---------------------
# kept same hyperparameters as https://arxiv.org/pdf/1511.06434.pdf
LEARNING_RATE: float = 0.0002
BETA_1: float = 0.5

In [8]:
# Create CustomDataset Class
class CustomDataset(Dataset):
  def __init__(self, root_dir, label_file, transform=None, label="ALL"):
        self.root_dir = root_dir
        self.label_file = label_file
        self.transform = transform
        self.label = label
        if self.label == "ALL":
          self.labels_df = pd.read_csv(label_file)
        else: # only load images of the given label
          label_df = pd.read_csv(label_file)
          label_df = label_df[label_df['labels']==self.label].reset_index(drop=True)
          self.labels_df = label_df
        self.image_files = os.listdir(root_dir)

  def __len__(self):
    return len(self.labels_df)

  def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.image_files[idx])
        image = Image.open(img_name).convert('RGB')
        image = self.transform(image)

        # Extract label from the label file
        filename = self.image_files[idx]
        if self.label=='ALL':
          label = self.labels_df.loc[self.labels_df['filename'] == filename]['labels'].item()
        else:
          label = self.label
        # Convert label to label class, to get index for encoding
        label_ind = LABEL_TO_CLASS[label]
        # one hot encoding for labels:
        label_encoded = torch.zeros(NUM_CLASSES)
        label_encoded[label_ind] = 1

        return image, label_encoded

## GAN

Adapted Code from [gcastro-98/synthetic-medical-images](https://github.com/gcastro-98/synthetic-medical-images)



### `synthetic-medical-images` GAN
Structure and Code from [gcastro-98/synthetic-medical-images](https://github.com/gcastro-98/synthetic-medical-images):
- `Generator64` Class
- `Discriminator64` Class
- `train_gan` method
- `

In [62]:
class Generator(nn.Module):
    def __init__(self):
        assert IMAGE_SIZE == 64, \
            f"This architecture is not suitable for IMAGE_SIZE = {IMAGE_SIZE}"
        super(Generator, self).__init__()

        self.y_label = nn.Sequential(
            nn.Linear(NUM_CLASSES, n_dnn),  # 120, 1000
            nn.ReLU(True)
        )

        self.yz = nn.Sequential(
            nn.Linear(nz, 2 * nz),  # 100, 200
            nn.ReLU(True)
        )

        self.main = nn.Sequential(
            nn.ConvTranspose2d(n_dnn + 2 * nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Sigmoid()
            # state size. (nc) x 64 x 64
        )

    def forward(self, z, y):
        # mapping noise and label
        z = self.yz(z)
        y = self.y_label(y)

        # mapping concatenated input to the main generator network
        inp = torch.cat([z, y], 1)
        inp = inp.view(-1, n_dnn + 2 * nz, 1, 1)  # 1000 + 200
        output = self.main(inp)

        return output


class Discriminator(nn.Module):
    def __init__(self):
        assert IMAGE_SIZE == 64, \
            f"This architecture is not suitable for IMAGE_SIZE = {IMAGE_SIZE}"
        super(Discriminator, self).__init__()
        # self.ngpu = _ngpu
        self.y_label = nn.Sequential(
            nn.Linear(NUM_CLASSES, 64 * 64 * 1),
            nn.ReLU(True)
        )

        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc + 1, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x, y):
        y = self.y_label(y)
        y = y.view(-1, 1, 64, 64)
        inp = torch.cat([x, y], 1)
        output = self.main(inp)

        return output.view(-1, 1).squeeze(1)

In [63]:
def __generate_random_noise():
    return torch.randn(BATCH_SIZE, nz, device=DEVICE)


def __generate_random_labels():
    label = torch.zeros(BATCH_SIZE, NUM_CLASSES, device=DEVICE)
    for i in range(BATCH_SIZE):
        x = np.random.randint(0, NUM_CLASSES)
        label[i][x] = 1
    return label

_checkpoint_noise = __generate_random_noise()
_checkpoint_labels = __generate_random_labels()

In [64]:
def _plot_losses(g_losses, d_losses,
                 _show: bool = False) -> None:
    plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(g_losses, label="Generator")
    plt.plot(d_losses, label="Discriminator")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(os.path.join('.img', 'losses.png'), dpi=200)
    if _show:
        plt.show()
    plt.close()

In [65]:
def train_gan(
        data_loader, use_cpu: bool = False,
        save_best_model: bool = True, save_generated_images: bool = True,
        verbose: bool = False, _freq: int = 5):
    #DEVICE = device('cpu') if use_cpu else _DEVICE
    # initialize (with weights) generator and discriminator
    net_g = Generator().to(DEVICE)
    #net_g.apply(weights_init)
    net_d = Discriminator().to(DEVICE)
    #net_d.apply(weights_init)

    # loss function and optimizers
    criterion = nn.BCELoss()  # we are simply detecting whether it's real/fake

    real_label = float(1)
    fake_label = float(0)

    # setup optimizer
    optimizer_d = optim.Adam(
        net_d.parameters(), lr=LEARNING_RATE, betas=(BETA_1, 0.999))
    optimizer_g = optim.Adam(
        net_g.parameters(), lr=LEARNING_RATE, betas=(BETA_1, 0.999))
    d_error_epoch = []
    g_error_epoch = []

    for epoch in tqdm(range(NUM_EPOCHS)):
        # we will start iterating each batch element
        d_error_iter = 0
        g_error_iter = 0
        for i, data in enumerate(data_loader, 0):
            # DISCRIMINATOR
            # train with real
            net_d.zero_grad()
            real_cpu = data[0].to(DEVICE)
            batch_size = real_cpu.size(0)
            pathology_one_hot = data[1].to(DEVICE)
            label = torch.full((batch_size, ), real_label, device=DEVICE)

            output = net_d(real_cpu, pathology_one_hot)
            err_d_real = criterion(output, label)
            err_d_real.backward()
            # D_x = output.mean().item()

            # train with fake
            noise = torch.randn(batch_size, nz, device=DEVICE)
            fake = net_g(noise, pathology_one_hot)
            label.fill_(fake_label)
            output = net_d(fake.detach(), pathology_one_hot)
            err_d_fake = criterion(output, label)
            err_d_fake.backward()
            # D_G_z1 = output.mean().item()
            err_d = err_d_real + err_d_fake
            d_error_iter += err_d.item()
            optimizer_d.step()

            # GENERATOR
            net_g.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            output = net_d(fake, pathology_one_hot)
            err_g = criterion(output, label)
            g_error_iter += err_g.item()
            err_g.backward()
            # D_G_z2 = output.mean().item()
            optimizer_g.step()

            if (i + 1) % (BATCH_SIZE // 4) == 0 and verbose:
                # we print the losses
                _counter = f"Epoch [{epoch}/{NUM_EPOCHS}]" \
                           f"[{i}/{len(data_loader)}]"
                print(f"{_counter} --- Loss G: {err_g.item()}")
                print(f"{_counter} --- Loss D: {err_d.item()}")

        if (epoch + 1) % _freq == 0:
            # we save generated images
            with torch.no_grad():
                if save_generated_images:
                    print(
                        f"CHECKPOINT {epoch + 1}: saving some generated images at 'output/' directory")
                    checkpoint_images = net_g(
                        _checkpoint_noise, _checkpoint_labels)
                    # we re-scale generated images to [0, 1] and save them
                    save_image((checkpoint_images + 1) / 2,
                               f"output/epoch_{epoch + 1}.png", nrow=8, normalize=True)

            # save models as checkpoint
            if save_best_model:
                print(f"CHECKPOINT {epoch + 1}: saving the trained"
                             " models at 'models/' directory")
                torch.save(net_g.state_dict(),
                           "models/generator.pth")
                torch.save(net_d.state_dict(),
                           "models/discriminator.pth")

        # accumulate error for each epoch
        d_error_epoch.append(d_error_iter)
        g_error_epoch.append(g_error_iter)

    _plot_losses(g_error_epoch, d_error_epoch)

    # save the trained generator
    torch.save(net_g.state_dict(),
               "models/generator.pth")
    # as well as the trained discriminator
    torch.save(net_d.state_dict(),
               "models/discriminator.pth")

    return net_g, net_d

In [66]:
def plot_fake_images(
        generator, n_images: int = 9, _show: bool = False) -> None:
    cols, rows = 3, 3
    fig, axs = plt.subplots(rows, cols, sharex='all')
    axs = axs.flatten()

    gen_z, label, _label_names = __generate_random_inputs(n_images)
    gen_images = generator(gen_z, label)
    images = gen_images.to("cpu").clone().detach()
    images = images.numpy().transpose(0, 2, 3, 1)

    for i in range(9):
        axs[i].set_title(_label_names[i])
        axs[i].set_axis_off()
        axs[i].imshow(images[i])
    plt.tight_layout(pad=1.04)
    plt.savefig(os.path.join('.img', 'fake_samples.png'), dpi=200)
    if _show:
        plt.show()
    plt.close()

def __generate_random_inputs(n_images: int):
    gen_z = torch.randn(n_images, nz, device=DEVICE)
    label = torch.zeros(n_images, NUM_CLASSES, device=DEVICE)
    _label_names = []
    for i in range(n_images):
        x = np.random.randint(0, NUM_CLASSES)
        label[i][x] = 1
        _label_names.append(LABEL_TO_TITLE[CLASS_TO_LABEL[x]])
    return gen_z, label, _label_names

In [67]:
def load_model():
    print("Loading already trained and serialized generator model")
    generator = Generator().to(device(DEVICE))
    generator.load_state_dict(torch.load(
        os.path.join("models/generator.pth")))
    plot_fake_images(generator, _show=False)
    return

In [68]:
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

#dataset = CustomDataset(root_dir='/content/preprocessed_images', label_file='/content/drive/MyDrive/DS4440_Project/odir_labels.csv', transform=transform)
dataset = CustomDataset(root_dir='../data/original/preprocessed_images', label_file='../data/original/odir_labels.csv', transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
generator, discriminator = train_gan(dataloader)
plot_fake_images(generator, _show=False)

## Diffusion Model

In [4]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [5]:
# Training Config
from dataclasses import dataclass

@dataclass
class TrainingConfig:
    image_size = 64 # the generated image resolution
    train_batch_size = 16
    eval_batch_size = 16  # how many images to sample during evaluation
    num_epochs = 200
    gradient_accumulation_steps = 1
    learning_rate = 1e-4
    lr_warmup_steps = 500
    save_image_epochs = 10
    save_model_epochs = 30
    mixed_precision = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir = "occularaid-M"  # the model name locally and on the HF Hub

    push_to_hub = True  # whether to upload the saved model to the HF Hub
    hub_model_id = "cheungra/occularaid-M"  # the name of the repository to create on the HF Hub
    hub_private_repo = False
    overwrite_output_dir = True  # overwrite the old model when re-running the notebook
    seed = 0


config = TrainingConfig()

In [9]:
transform = transforms.Compose([
    transforms.Resize((config.image_size, config.image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

#dataset = CustomDataset(root_dir='/content/preprocessed_images', label_file='/content/drive/MyDrive/DS4440_Project/odir_labels.csv', transform=transform)
dataset = CustomDataset(root_dir='../data/preprocessed_images', label_file='../data/odir_labels.csv', transform=transform, label='M')
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)

In [10]:
from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=config.image_size,  # the target image resolution
    in_channels=3,  # the number of input channels, 3 for RGB images
    out_channels=3,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(64, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",  # a regular ResNet upsampling block
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)

In [30]:
import torch
from PIL import Image
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
timesteps = torch.LongTensor([50])

In [31]:
from diffusers.optimization import get_cosine_schedule_with_warmup

optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=(len(train_dataloader) * config.num_epochs),
)

In [32]:
from diffusers import DDPMPipeline
from diffusers.utils import make_image_grid
import os

def evaluate(config, epoch, pipeline):
    # Sample some images from random noise (this is the backward diffusion process).
    # The default pipeline output type is `List[PIL.Image]`
    images = pipeline(
        batch_size=config.eval_batch_size,
        generator=torch.manual_seed(config.seed),
    ).images

    # Make a grid out of the images
    image_grid = make_image_grid(images, rows=4, cols=4)

    # Save the images
    test_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(test_dir, exist_ok=True)
    image_grid.save(f"{test_dir}/{epoch:04d}.png")

In [33]:
from accelerate import Accelerator
from huggingface_hub import create_repo, upload_folder
from tqdm.auto import tqdm
from pathlib import Path
import os

def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
    # Initialize accelerator and tensorboard logging
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(config.output_dir, "logs"),
    )
    if accelerator.is_main_process:
        if config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)
        if config.push_to_hub:
            repo_id = create_repo(
                repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True
            ).repo_id
        accelerator.init_trackers("train_example")

    # Prepare everything
    # There is no specific order to remember, you just need to unpack the
    # objects in the same order you gave them to the prepare method.
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    global_step = 0

    # Now you train the model
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for step, batch in enumerate(train_dataloader):
            clean_images = batch[0]
            # Sample noise to add to the images
            noise = torch.randn(clean_images.shape, device=clean_images.device)
            bs = clean_images.shape[0]

            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device,
                dtype=torch.int64
            )

            # Add noise to the clean images according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            with accelerator.accumulate(model):
                # Predict the noise residual
                noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        # After each epoch you optionally sample some demo images with evaluate() and save the model
        if accelerator.is_main_process:
            pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)

            if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
                evaluate(config, epoch, pipeline)

            if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
                if config.push_to_hub:
                    upload_folder(
                        repo_id=repo_id,
                        folder_path=config.output_dir,
                        commit_message=f"Epoch {epoch}",
                        ignore_patterns=["step_*", "epoch_*"],
                    )
                else:
                    pipeline.save_pretrained(config.output_dir)

In [None]:
from accelerate import notebook_launcher

args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)

notebook_launcher(train_loop, args, num_processes=1)

In [37]:
# push to huggingface hub
model.push_to_hub("ocularaid-diffusion")
     

diffusion_pytorch_model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/cheungra/ocularaid-diffusion/commit/99e708d0536305f11d6d941e00ae14a066f2e036', commit_message='Upload model', commit_description='', oid='99e708d0536305f11d6d941e00ae14a066f2e036', pr_url=None, pr_revision=None, pr_num=None)

In [41]:
model.save_config('config.json')

In [49]:
from diffusers import UNet2DConditionModel

repo_id = "cheungra/ocularaid-diffusion"
unet = UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet", class_embed_type = "projection",projection_class_embeddings_input_dim=16,revision=None, low_cpu_mem_usage=False, device_map=None)

OSError: cheungra/ocularaid-diffusion does not appear to have a file named config.json.