## Imports

In [1]:
import torch
# import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

from ignite.engine import Engine, Events
from ignite.metrics import Accuracy, Precision, Recall, RunningAverage #, EpochMetric
# from ignite.contrib.metrics import ROC_AUC

# from sklearn.metrics import roc_auc_score
import numpy as np
import warnings

# from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
import time
# import os

In [2]:
%run dataset.py

In [3]:
%run model.py

In [4]:
%run utils.py

In [5]:
%run utilsT.py

In [6]:
# FIXME: fix deprecations
# Silence every DeprecationWarning and UserWarning raised from here on.
# NOTE(review): the filters are process-wide, so unrelated warnings are hidden as well.
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)

In [7]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Show the index of the active CUDA device. torch.cuda.current_device() raises
# on a machine without CUDA, so it is only queried when the GPU was selected;
# on CPU the cell simply displays nothing.
torch.cuda.current_device() if device.type == "cuda" else None

0

In [8]:
# Confirm which device the model and the batches will be moved to.
print(device)

cuda


## Common

In [9]:
# In server
# Root folder from which the model, log and dataset paths are derived.
# NOTE(review): absolute, machine-specific path; adjust it when running elsewhere.
base_dir = '/mnt/data/chest-x-ray-8'

In [10]:
# Sub-folders of base_dir.
# NOTE(review): model_save_dir is not referenced by any visible cell;
# presumably it is where checkpoints are meant to be written (confirm).
model_save_dir = base_dir + '/savedModels'
log_dir = base_dir + '/runs'
# Passed to CXRDataset as the dataset location.
dataset_dir = base_dir + '/dataset'

## Loss

In [11]:
def weighted_bce(output, target):
    """Weighted binary cross entropy, summed over every element.

    The positive and negative terms are scaled by ``total / positive`` and
    ``total / negative`` respectively, where the counts are taken over the
    whole ``target`` tensor. When the tensor contains only positives or only
    negatives both weights stay at 1.

    Args:
        output: Predicted probabilities; clamped to [1e-5, 1 - 1e-5] before
            the logarithm is taken.
        target: Ground-truth labels with the same shape as ``output``;
            entries greater than 0 count as positive.

    Returns:
        A scalar tensor holding the summed weighted loss.
    """
    # Keep the probabilities away from 0 and 1 so the logarithms stay finite.
    probs = output.clamp(min=1e-5, max=1-1e-5)
    truth = target.float()

    # Class counts over all samples and labels of the batch.
    n_total = np.prod(truth.size())
    n_pos = int((truth > 0).sum())
    n_neg = n_total - n_pos

    if n_pos == 0 or n_neg == 0:
        # Single-class batch: no re-weighting.
        pos_weight = neg_weight = 1
    else:
        pos_weight = n_total / n_pos
        neg_weight = n_total / n_neg

    pos_term = -pos_weight * truth * torch.log(probs)
    neg_term = neg_weight * (1 - truth) * torch.log(1 - probs)

    return torch.sum(pos_term - neg_term)

## Transform functions

In [12]:
# Value subtracted from every image by the Normalize step of transform_image.
# NOTE(review): presumably the mean pixel intensity of the dataset; confirm how it was computed.
mean = 0.50576189

In [13]:
# Pre-processing applied to every image, in order:
#   1. Resize(512): resize the image,
#   2. ToTensor(): convert it to a tensor,
#   3. Normalize([mean], [1.]): subtract `mean`; dividing by a std of 1 leaves the scale unchanged.
transform_image = transforms.Compose([transforms.Resize(512),
                                      transforms.ToTensor(),
                                      transforms.Normalize([mean], [1.])
                                     ])

## Train params

In [14]:
# Mini-batch size used by both DataLoaders.
BATCH_SIZE = 4
# Default epoch count of the legacy train() helper.
# NOTE(review): the trainer.run(...) call passes its own epoch count and does not read this.
N_EPOCHS = 100

# Hyper-parameters handed to optim.SGD when the optimizer is built.
learning_rate = 1e-6
optimizer_moment = 0.9
weight_decay = 0
# NOTE(review): not referenced by any visible cell.
regularization = 0

## Data loading

### Select only some diseases

In [15]:
# Disease labels to work with; passed as `diseases=` to both CXRDataset instances.
chosen_diseases = ["Infiltration"] # list(ALL_DISEASES[:1])
chosen_diseases

['Infiltration']

### Training dataset

In [16]:
# Training dataset restricted to the chosen diseases.
# NOTE(review): CXRDataset is not defined in the visible cells (presumably it comes from dataset.py),
# and no dataset_type is passed, so this presumably relies on the default being the training split. Confirm both.
train_dataset = CXRDataset(dataset_dir, transform=transform_image, diseases=chosen_diseases)
train_dataset.size()

(45, 15)

In [17]:
# Batches the training set BATCH_SIZE samples at a time.
# NOTE(review): shuffle is left at its default, so samples come in a fixed order every epoch.
# Confirm that this is intended for SGD training.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE)

In [18]:
# Label names exposed by the training dataset. Their order gives the column
# index used to slice predictions and targets for each per-disease metric.
TRAINED_WITH_DISEASES = list(train_dataset.classes)
TRAINED_WITH_DISEASES

['Infiltration']

### Validation dataset

In [19]:
# Validation dataset (dataset_type="val") with the same transform and diseases as the training one.
val_dataset = CXRDataset(dataset_dir, dataset_type="val", transform=transform_image, diseases=chosen_diseases)
val_dataset.size()

(14, 15)

In [20]:
# Batches the validation set BATCH_SIZE samples at a time.
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

## Model

In [21]:
# Instantiate the network and move it to the selected device.
# NOTE(review): Model is not defined in the visible cells; presumably it comes from model.py.
model = Model().to(device)

In [22]:
# SGD with momentum over all model parameters, configured from the train-params cell.
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=optimizer_moment, weight_decay=weight_decay)

## Training

### Step function

In [23]:
def get_step_fn(training=True):
    """Build the per-batch process function for an ignite ``Engine``.

    The returned function relies on the notebook-level ``model``,
    ``optimizer``, ``device`` and ``weighted_bce``.

    Args:
        training: When True the step back-propagates the loss and updates
            the optimizer; when False it only runs the forward pass with
            gradient recording disabled.

    Returns:
        A ``step(engine, data_batch)`` callable that returns the tuple
        ``(batch_loss, outputs, labels)``; ``batch_loss`` is a Python float.
    """
    def step(engine, data_batch):
        """Run one batch through the model and return its loss, outputs and labels."""
        # The batch is a 5-tuple; only the images and the labels are used here.
        images, labels, _, _, _ = data_batch

        # Move tensors to the selected device
        images = images.to(device)
        labels = labels.to(device)

        # Switch the model between train and eval behaviour
        model.train(training)

        # Scope the grad mode to this step. Calling torch.set_grad_enabled()
        # as a plain function flips a global switch, which used to leave
        # gradient recording disabled after the last validation step.
        with torch.set_grad_enabled(training):
            # Forward, receive outputs from the model and segments (bboxes)
            outputs, segments = model(images)

            # Compute classification loss
            loss = weighted_bce(outputs, labels)

            batch_loss = loss.item()

            if training:
                # Gradients only need clearing when a backward pass follows.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # Detach so the metrics do not keep the autograd graph of the batch alive.
        return batch_loss, outputs.detach(), labels

    return step

### Train utils

In [24]:
def get_transform_one_label(label_index, use_round=True):
    """Create an output transform that isolates one label column.

    Args:
        label_index: Column to pick from the prediction and target tensors.
        use_round: When True the selected predictions are rounded with
            ``torch.round``; when False they are returned unchanged.

    Returns:
        A function mapping a ``(loss, y_pred, y_true)`` step output to the
        pair ``(y_pred[:, label_index], y_true[:, label_index])``.
    """
    def transform(output):
        # The first element (the batch loss) is not needed by the metrics.
        _loss, all_scores, all_targets = output

        label_scores = all_scores[:, label_index]
        label_targets = all_targets[:, label_index]

        return (torch.round(label_scores) if use_round else label_scores), label_targets

    return transform

In [25]:
def attach_metrics(engine, metric_name, MetricClass, use_round):
    """Attach one metric per trained disease to an engine.

    For every entry of ``TRAINED_WITH_DISEASES`` a ``MetricClass`` instance
    is created that only sees that disease's column, and it is attached
    under the name ``"<metric_name>_<disease>"``.

    Args:
        engine: Engine that receives the metrics.
        metric_name: Prefix of the attached metric names.
        MetricClass: Metric constructor accepting ``output_transform``.
        use_round: Forwarded to ``get_transform_one_label``.
    """
    for label_position, disease_name in enumerate(TRAINED_WITH_DISEASES):
        per_label_metric = MetricClass(
            output_transform=get_transform_one_label(label_position, use_round=use_round)
        )
        attached_name = "{}_{}".format(metric_name, disease_name)
        per_label_metric.attach(engine, attached_name)

### Train run

In [26]:
# Base names of the per-disease metrics that write_results sends to TensorBoard.
log_metrics = ["roc_auc"]

In [27]:
# Key under which the loss is attached to both engines and read back from state.metrics.
loss_name = "wbce_loss"

In [28]:
# Passed to SummaryWriter as its flush_secs argument.
flush_secs = 10 # Use a low value when debugging

In [29]:
# Timestamp identifying this run (get_timestamp is not defined in the visible cells).
run_timestamp = get_timestamp()

run_name = run_timestamp # + "_something"
# Report the name actually used for the log folder, so the message stays
# correct if a suffix is appended to run_name.
print("This run: ", run_name)

# Build the path from log_dir (base_dir + '/runs') instead of repeating the
# '/runs' segment by hand; the resulting folder is the same.
writer = SummaryWriter(log_dir=log_dir + "/experiments/" + run_name, flush_secs=flush_secs)

This run:  2019-09-28-23-31-31


### Validator engine

In [30]:
# Evaluation engine: its step function is built with training=False, so no optimizer update is performed.
validator = Engine(get_step_fn(False))

In [31]:
# Expose the batch loss (element 0 of the step output) as a metric named loss_name.
# NOTE(review): alpha=1. If RunningAverage computes alpha * previous + (1 - alpha) * new,
# the value would never move past the first batch it sees; confirm against the ignite documentation.
avg_loss = RunningAverage(output_transform=lambda x: x[0], alpha=1)
avg_loss.attach(validator, loss_name)

In [32]:
# Per-disease metrics for the validator. Precision, recall and accuracy receive
# rounded predictions (use_round=True); RocAucMetric receives the raw scores.
# NOTE(review): RocAucMetric is not defined in the visible cells.
attach_metrics(validator, "prec", Precision, True)
attach_metrics(validator, "recall", Recall, True)
attach_metrics(validator, "acc", Accuracy, True)
attach_metrics(validator, "roc_auc", RocAucMetric, False)

### Trainer engine

In [33]:
# Training engine: its step function is built with training=True, so it back-propagates and updates the optimizer.
trainer = Engine(get_step_fn(True))

In [34]:
def write_results(run_type, metrics, epoch, wall_time):
    """Send the loss and the logged per-disease metrics to TensorBoard.

    Args:
        run_type: Suffix of every scalar tag (the callers use "train" and "val").
        metrics: Mapping of metric name to value, e.g. ``engine.state.metrics``.
        epoch: Step value recorded with every scalar.
        wall_time: Timestamp recorded with every scalar.

    A missing loss is written as 0 and a missing metric as -1.
    """
    writer.add_scalar("Loss/" + run_type, metrics.get(loss_name, 0), epoch, wall_time)

    # One key per (logged metric, disease) pair, e.g. "roc_auc_Infiltration".
    metric_keys = [
        "{}_{}".format(base_name, disease)
        for base_name in log_metrics
        for disease in TRAINED_WITH_DISEASES
    ]

    for key in metric_keys:
        writer.add_scalar("{}/{}".format(key, run_type), metrics.get(key, -1), epoch, wall_time)

In [35]:
@trainer.on(Events.EPOCH_COMPLETED)
def log_results(trainer):
    """Validate, log to TensorBoard and print a summary after every epoch.

    Registered on the trainer's EPOCH_COMPLETED event. It runs the validator
    once over ``val_dataloader``, writes the train and validation metrics
    with a shared wall time, and prints both losses.
    """
    # Run on evaluation (a single pass over the validation set)
    validator.run(val_dataloader, 1)

    # One timestamp shared by the train and val points of this epoch
    logged_at = time.time()

    finished_epoch = trainer.state.epoch
    train_metrics = trainer.state.metrics
    val_metrics = validator.state.metrics

    for run_type, run_metrics in (("train", train_metrics), ("val", val_metrics)):
        write_results(run_type, run_metrics, finished_epoch, logged_at)

    print("Finished epoch {}/{}, loss {} (val {})".format(
        finished_epoch,
        trainer.state.max_epochs,
        train_metrics.get(loss_name, 0),
        val_metrics.get(loss_name, 0),
    ))

In [36]:
# Expose the training batch loss (element 0 of the step output) as a metric named loss_name.
# NOTE(review): alpha=1. If RunningAverage computes alpha * previous + (1 - alpha) * new,
# the reported loss would stay at the first batch it sees; confirm against the ignite documentation.
avg_loss = RunningAverage(output_transform=lambda x: x[0], alpha=1)
avg_loss.attach(trainer, loss_name)

In [37]:
# Per-disease metrics for the trainer. Accuracy, precision and recall receive
# rounded predictions (use_round=True); RocAucMetric receives the raw scores.
attach_metrics(trainer, "acc", Accuracy, True)
attach_metrics(trainer, "prec", Precision, True)
attach_metrics(trainer, "recall", Recall, True)
attach_metrics(trainer, "roc_auc", RocAucMetric, False)

## Train

In [38]:
%%time
# %%capture train_output
# Train for 10 epochs; the EPOCH_COMPLETED handler validates and logs after each one.
# NOTE(review): the epoch count is hard-coded here; N_EPOCHS is not used.
trainer.run(train_dataloader, 10)

Finished epoch 1/10, loss 22.10311508178711 (val 22.23708152770996)
Finished epoch 2/10, loss 21.754119873046875 (val 22.12881851196289)
Finished epoch 3/10, loss 22.06107521057129 (val 22.014162063598633)
Finished epoch 4/10, loss 21.454994201660156 (val 21.79087257385254)
Finished epoch 5/10, loss 21.158710479736328 (val 21.436138153076172)
Finished epoch 6/10, loss 21.601560592651367 (val 21.173686981201172)
Finished epoch 7/10, loss 20.78107261657715 (val 21.03730010986328)
Finished epoch 8/10, loss 19.821331024169922 (val 20.71066665649414)
Finished epoch 9/10, loss 20.337358474731445 (val 20.58563804626465)
Finished epoch 10/10, loss 20.258663177490234 (val 20.3443603515625)
CPU times: user 55.1 s, sys: 12.1 s, total: 1min 7s
Wall time: 49.3 s


<ignite.engine.engine.State at 0x7fd8117aed30>

In [251]:
# NOTE(review): train_output is only created when the `%%capture train_output` line of the
# training cell is enabled. It is commented out there, so as far as the visible cells show
# this call fails on a fresh kernel.
train_output.show()

Finished epoch 1/3, loss 45.58734893798828 (val 44.312461853027344)
Finished epoch 2/3, loss 44.71006774902344 (val 44.178382873535156)
Finished epoch 3/3, loss 45.79662322998047 (val 43.93351745605469)


<ignite.engine.engine.State at 0x7fed9c0452e8>

### Write graph to Tensorboard (optional)

In [54]:
# Take the images of the first training batch as an example input for the model graph.
images = next(iter(train_dataloader))[0]
images = images.to(device)
images.size()

torch.Size([4, 1, 512, 512])

In [55]:
# Record the model graph in TensorBoard, using the example batch as input.
writer.add_graph(model, images)

### Close SummaryWriter

In [51]:
# Close the TensorBoard writer; nothing should be logged through it afterwards.
# NOTE(review): the execution counts show this cell ran before the graph cells above it.
writer.close()

In [138]:
# Inspect the validation label table. The rendered output shows one row per image file
# and one 0/1 column per disease.
val_dataset.label_index

Unnamed: 0,FileName,Atelectasis,Cardiomegaly,Effusion,Infiltration,Mass,Nodule,Pneumonia,Pneumothorax,Consolidation,Edema,Emphysema,Fibrosis,Pleural_Thickening,Hernia
0,00000011_000.png,0,0,1,0,0,0,0,0,0,0,0,0,0,0
1,00000011_001.png,0,0,0,0,0,0,0,0,0,0,0,0,0,0
2,00000011_002.png,0,0,0,0,0,0,0,0,0,0,0,0,0,0
3,00000011_003.png,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,00000011_004.png,0,0,0,0,0,0,0,0,0,0,0,0,0,0
5,00000011_005.png,0,0,0,1,0,0,0,0,0,0,0,0,0,0
6,00000011_006.png,1,0,0,0,0,0,0,0,0,0,0,0,0,0
7,00000011_007.png,0,0,0,1,0,0,0,0,0,0,0,0,0,0
8,00000011_008.png,0,0,0,0,0,0,0,0,0,0,0,0,0,0
9,00000016_000.png,0,0,0,0,0,0,0,0,0,0,0,0,0,0


***
***

# Old

In [27]:
# FIXME: the code below computes the loss of the bbox prediction.
# Open question: should it use a different criterion instead of binary cross entropy?

#     # Get valid bbox_list
#     # REVIEW: make bbox_list a tensor? and then send to device? same on seg_list below
#     bbox_list = []
#     n_samples, n_labels = bbox_valids.size()
#     for i in range(n_samples):
#         bbox_list.append([])
#         for j in range(n_labels):
#             if bbox_valids[i][j] == 1:
#                 bbox_list[i].append(bboxes[i][j])
#         bbox_list[i] = torch.stack(bbox_list[i]).to(device)
    
    
#     # Segmentation lists
#     seg_list = []
#     for i in range(n_samples):
#         seg_list.append([])
#         for j in range(n_labels):
#             if bbox_valids[i][j] == 1:
#                 seg_list[i].append(segments[i][j])
#         seg_list[i] = torch.stack(seg_list[i]).to(device)
    
#     # Compute bbox loss
#     for i in range(len(seg_list)):
#         # REVIEW: do this with a tensor? avoid the loop
#         loss += 5/(512*512) * weighted_cross_entropy(seg_list[i], bbox_list[i], weights=(10, 1))
        
#         break
# #         print(seg_list[i].size())
# #         print(bbox_list[i].size())
# #         break

In [20]:
def train_iteration(model, data_batch, training=True):
    """Run one batch through the model and return its loss.

    Relies on the notebook-level ``optimizer``, ``device`` and ``weighted_bce``.

    Args:
        model: Network to run; switched to train or eval mode according to
            ``training``.
        data_batch: 5-tuple whose first two elements are the images and the
            labels; the remaining elements are ignored.
        training: When True the loss is back-propagated and the optimizer
            is stepped; when False only the forward pass runs, with gradient
            recording disabled.

    Returns:
        The batch loss as a Python float.
    """
    # Only the images and the labels of the batch are used here.
    images, labels, _, _, _ = data_batch

    # Move tensors to the selected device
    images = images.to(device)
    labels = labels.to(device)

    # Switch the model between train and eval behaviour
    model.train(training)

    # Scope the grad mode to this call. Calling torch.set_grad_enabled() as a
    # plain function flips a global switch that stays set after returning.
    with torch.set_grad_enabled(training):
        # Forward, receive outputs from the model and segments (bboxes)
        outputs, segments = model(images)

        # Compute classification loss
        loss = weighted_bce(outputs, labels)

        batch_loss = loss.item()
        if training:
            # Gradients only need clearing when a backward pass follows.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return batch_loss

In [21]:
def train(dataloader, n_epochs=N_EPOCHS, training=True):
    """Legacy training loop: iterate over the data and print the loss per epoch.

    Args:
        dataloader: Iterable of batches accepted by ``train_iteration``.
        n_epochs: Number of passes over ``dataloader``.
        training: Forwarded to ``train_iteration``.

    After each epoch the zero-based epoch index and the sum of its batch
    losses are printed.
    """
    for epoch_index in range(n_epochs):
        # Sum (not mean) of the batch losses of this epoch.
        epoch_loss = 0.0

        for batch in dataloader:
            epoch_loss += train_iteration(model, batch, training=training)

        print(epoch_index, epoch_loss)

In [24]:
# Five epochs with the legacy loop; each printed line is "<epoch index> <summed batch loss>".
train(train_dataloader, n_epochs=5)

0 497.5148229598999
1 487.6466898918152
2 474.2359209060669
3 477.64395093917847
4 477.08784341812134
