In [1]:
import numpy as np
import torch
from torchvision import transforms
import apex
import csv
import data
import models
import ast
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm

### Data Loading

In [2]:
# Number of samples per training mini-batch.
batch_size = 8

def parse_list(input_str):
    """Turn the string form of a Python literal (e.g. ``"['a', 'b']"``) back into the object.

    ``ast.literal_eval`` only accepts literal syntax, so arbitrary code in the CSV
    cell cannot be executed; malformed input raises ``ValueError``/``SyntaxError``.
    """
    parsed_value = ast.literal_eval(input_str)
    return parsed_value

# uid (as str) -> (parsed problems list, findings text, impression text).
reports = {}

# The csv module expects files opened with newline=""; otherwise quoted fields that
# contain line breaks (plausible in free-text findings/impressions) can be mis-parsed.
with open("./cleaned_reports.csv", newline="") as csv_file:
    csv_reader = csv.reader(csv_file, delimiter=',')
    next(csv_reader, None)  # skip the header row (tolerates an empty file)
    for row in csv_reader:
        # Column 0 is ignored -- presumably a row index written when the CSV was saved; TODO confirm.
        uid, problems, findings, impression = row[1:]
        reports[str(uid)] = (parse_list(problems), findings, impression)

def create_report_splits(reports, seed=1337, holdout_size=0.2, test_fraction=0.5):
    """Split ``reports`` into train / validation / test dictionaries.

    The split is made in two stages with ``train_test_split``: first ``holdout_size``
    of the uids are held out from training, then ``test_fraction`` of that held-out
    set becomes the test split and the remainder the validation split. With the
    defaults this is roughly an 80 / 10 / 10 split.

    Args:
        reports: Mapping of uid -> report tuple.
        seed: ``random_state`` passed to both ``train_test_split`` calls, so the split
            is reproducible for a fixed key order of ``reports``.
        holdout_size: ``test_size`` of the first split (fraction or absolute count of
            uids excluded from training).
        test_fraction: ``test_size`` of the second split (fraction or count of the
            held-out uids assigned to the test set).

    Returns:
        A list ``[train_reports, valid_reports, test_reports]`` of dicts keyed by
        ``str(uid)``.
    """
    uid_list = list(reports.keys())
    train_uids, holdout_uids = train_test_split(
        uid_list, test_size=holdout_size, random_state=seed)
    valid_uids, test_uids = train_test_split(
        holdout_uids, test_size=test_fraction, random_state=seed)

    return [
        {str(uid): reports[str(uid)] for uid in split_uids}
        for split_uids in (train_uids, valid_uids, test_uids)
    ]

# Validation/test splits are kept under their own names (currently unused below) so they
# are available for evaluation instead of being discarded.
train_reports, valid_reports, test_reports = create_report_splits(reports)

# Training preprocessing: resize the shorter edge to 2048 px, center-crop to 2048x2048,
# random horizontal flip as augmentation, then tensor conversion and per-channel
# normalization with the standard ImageNet mean/std.
train_dataset = data.XRayDataset(
    reports=train_reports,
    transform=transforms.Compose([
        transforms.Resize(2048),
        transforms.CenterCrop((2048, 2048)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]),
)

# Sanity-check a single sample and the vocabulary before training.
print("Sample:")
image, impression = train_dataset[0]
print("* Image size:", image.size())
print("* Impression:", impression)
print("* Vocab:")
print(train_dataset.vocab)

# NOTE(review): num_workers is tied to batch_size here although the two are unrelated
# knobs; consider giving the worker count its own setting.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=data.collate_fn,
    pin_memory=True,
    num_workers=batch_size,
)

Sample:
* Image size: torch.Size([3, 2048, 2048])
* Impression: tensor([39, 40,  0, 26, 28, 46, 45, 30,  0, 28, 33, 26, 39, 32, 30, 44,  0, 31,
        43, 40, 38,  0, 41, 43, 34, 40, 43,  0, 34, 38, 26, 32, 34, 39, 32,  8])
* Vocab:
['4', 'g', '9', 'n', 'b', '3', '6', '1', 's', 'p', 'x', 'y', 'z', ';', ',', '/', "'", 'j', '&', 'e', '>', 'k', 'r', 'q', '<', 'v', ')', 'd', '2', 'm', '-', 'f', 'a', 'l', 'h', '.', 'o', ' ', 'w', 'i', '5', '%', '[', '7', 'u', ']', 'c', 't', '8', '0', '(', ':']


### Build Model

In [3]:
embed_size = 128
hidden_size = 128
num_layers = 3
learning_rate = 0.001
memory_format = torch.channels_last

# apex FusedAdam is GPU-only, so a CPU fallback device would only fail later in a less
# obvious place; fail fast with a clear message instead.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required: apex FusedAdam/amp cannot run on CPU")
device = torch.device("cuda")

vocab_size = len(train_dataset.vocab)
# Encoder weights are stored in channels_last memory format.
encoder = models.EncoderCNN(embed_size).to(device, memory_format=memory_format)
decoder = models.DecoderRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)

criterion = torch.nn.CrossEntropyLoss()
# Every encoder and decoder parameter is optimized.
# NOTE(review): if EncoderCNN wraps a pretrained backbone meant to stay frozen, restrict
# this list to the trainable layers -- TODO confirm against models.EncoderCNN.
params = list(decoder.parameters()) + list(encoder.parameters())
optimizer = apex.optimizers.FusedAdam(params, lr=learning_rate)

# Mixed precision with apex amp opt level O1.
# NOTE(review): apex.amp is deprecated upstream; native torch.cuda.amp (autocast +
# GradScaler) is the supported replacement.
[encoder, decoder], optimizer = apex.amp.initialize([encoder, decoder], optimizer, opt_level="O1")

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


### Train Model

In [None]:
num_epochs = 4
# len(DataLoader) also counts the final partial batch; len(dataset) // batch_size does not.
total_step = len(train_dataloader)
encoder.train()
decoder.train()

print("Start training")

for epoch in range(num_epochs):
    print("\nEpoch", epoch+1, "/", num_epochs, ":\n")
    # Sum the per-batch losses on-device to avoid forcing a GPU sync on every step.
    epoch_loss_sum = torch.zeros((), device=device)
    num_batches = 0

    for images, captions, lengths in tqdm(train_dataloader, total=total_step):

        # Move the mini-batch to the training device.
        images = images.to(device, non_blocking=True).contiguous(memory_format=memory_format)
        captions = captions.to(device, non_blocking=True).contiguous()
        # Flatten padded captions into packed order to line up with the decoder outputs.
        # NOTE(review): pack_padded_sequence needs `lengths` sorted descending unless
        # enforce_sorted=False; presumably data.collate_fn sorts the batch -- confirm.
        targets = torch.nn.utils.rnn.pack_padded_sequence(captions, lengths, batch_first=True)[0]

        encoder.zero_grad()
        decoder.zero_grad()

        # Forward, backward (with amp loss scaling) and optimize.
        features = encoder(images)
        outputs = decoder(features, captions, lengths)
        loss = criterion(outputs, targets)

        with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()

        epoch_loss_sum += loss.detach().float()
        num_batches += 1

    if num_batches == 0:
        print("train_dataloader yielded no batches; nothing to report")
        continue
    # Report the epoch-mean loss rather than just the final batch's loss.
    mean_loss = epoch_loss_sum.item() / num_batches
    print("train_loss - ", mean_loss, "- perplexity -", np.exp(mean_loss))

Start training

Epoch 1 / 4 :



HBox(children=(FloatProgress(value=0.0, max=339.0), HTML(value='')))

### Inference

In [None]:
import matplotlib.pyplot as plt
from PIL import Image

# Switch both networks to evaluation mode (changes the behaviour of layers such as
# dropout / batch-norm, if the models contain them).
# The trailing `;` stops Jupyter from echoing the whole decoder module repr as output.
encoder.eval()
decoder.eval();

In [None]:
# Show generated vs. original impressions for the first few samples.
# NOTE(review): these are training samples (with random flips), so results here overstate
# generalization; prefer a held-out split for evaluation.
num_samples = min(10, len(train_dataset))
# Same per-channel stats as the dataset's Normalize transform; used to undo it so that
# imshow receives RGB values in [0, 1] instead of clipped normalized values.
image_mean = np.array([0.485, 0.456, 0.406])
image_std = np.array([0.229, 0.224, 0.225])

for index in range(num_samples):
    image, impression = train_dataset[index]

    # No autograd graph is needed for inference; this saves substantial GPU memory.
    with torch.no_grad():
        image_tensor = image.unsqueeze(0).to(device).contiguous(memory_format=memory_format)
        feature = encoder(image_tensor)
        sampled_ids = decoder.sample(feature)
    sampled_ids = list(sampled_ids[0].cpu().numpy())

    # (C, H, W) -> (H, W, C), then de-normalize for display.
    plt_img = np.moveaxis(image.numpy(), 0, -1) * image_std + image_mean
    plt_img = np.clip(plt_img, 0.0, 1.0)

    fig, ax = plt.subplots()
    ax.set_title("Image: " + str(index))
    ax.imshow(plt_img)
    ax.axis("off")
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across iterations

    print(" Original:", train_dataset.tokenizer.decode(impression))
    print("Generated:", train_dataset.tokenizer.decode(sampled_ids))
    print("")