In [1]:
import numpy as np
import torch
from torchvision import transforms
import apex
import data
import models

### Data Loading

In [2]:
batch_size = 8

train_dataset = data.XRayDataset(
    transform=transforms.Compose([
        transforms.Resize(2048),
        transforms.CenterCrop((2048,2048)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
        ]
))

print("Sample:")
image, impression = train_dataset.__getitem__(0)
print("* Image size:", image.size())
print("* Impression:", impression)
print("* Vocab:")
print(train_dataset.vocab)

train_dataloader = torch.utils.data.dataloader.DataLoader(train_dataset,
                                                          collate_fn=data.collate_fn,
                                                          pin_memory=True,
                                                          shuffle=True,
                                                          batch_size=batch_size,
                                                          num_workers=batch_size)

Number of reports: 3851
Skipped: 3648 images
Sample:
* Image size: torch.Size([3, 2048, 2048])
* Impression: tensor([37, 38, 41, 36, 24, 35,  0, 26, 31, 28, 42, 43,  0, 47,  6, 47, 47, 47,
        47,  7])
* Vocab:
['k', 'o', 'i', 'p', 'j', "'", ')', 'd', 'x', '%', ':', '4', 'f', '5', 'e', '[', 'l', '1', 'c', 's', 'u', 'b', '9', '3', '0', '"', 'z', 'q', ' ', 'm', 'h', 'y', '6', 't', '-', 'r', 'n', 'w', '/', '>', '8', '7', 'g', '2', '.', 'a', 'v', '<', '(', ';']


### Build Model

In [3]:
embed_size = 128
hidden_size = 128
num_layers = 3
learning_rate = 0.001
memory_format = torch.channels_last
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = models.EncoderCNN(embed_size).to(device, memory_format=memory_format)
decoder = models.DecoderRNN(embed_size, hidden_size, len(train_dataset.vocab), num_layers).to(device)

criterion = torch.nn.CrossEntropyLoss()
params = list(decoder.parameters()) + list(encoder.parameters())
optimizer = apex.optimizers.FusedAdam(params, lr=learning_rate)

[encoder, decoder], optimizer = apex.amp.initialize([encoder, decoder], optimizer, opt_level="O1")

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


### Train Model

In [None]:
num_epochs = 4
total_step = len(train_dataloader.dataset)//batch_size
encoder.train()
decoder.train()

print("Start training")

for epoch in range(num_epochs):
    for i, (images, captions, lengths) in enumerate(train_dataloader):

        # Set mini-batch dataset
        images = images.cuda(non_blocking=True).contiguous(memory_format=memory_format)
        captions = captions.cuda(non_blocking=True).contiguous()
        targets = torch.nn.utils.rnn.pack_padded_sequence(captions, lengths, batch_first=True)[0]
        
        encoder.zero_grad()
        decoder.zero_grad()

        # Forward, backward and optimize
        features = encoder(images)
        outputs = decoder(features, captions, lengths)
        loss = criterion(outputs, targets)
        
        with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()

        # Print log info
        if i % 100 == 0:
            print("Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}"
                  .format(epoch, num_epochs, i, total_step, loss.item(), np.exp(loss.item()))) 

Start training
Epoch [0/4], Step [0/435], Loss: 3.9344, Perplexity: 51.1323
Epoch [0/4], Step [100/435], Loss: 2.9513, Perplexity: 19.1316
Epoch [0/4], Step [200/435], Loss: 2.5882, Perplexity: 13.3057


### Inference

In [None]:
import matplotlib.pyplot as plt
from PIL import Image

encoder.eval()
decoder.eval()

In [None]:
for index in range(0, 10):
    image, impression = train_dataset.__getitem__(index)
    image_tensor = image.unsqueeze(0).cuda()
    feature = encoder(image_tensor)
    sampled_ids = decoder.sample(feature)
    sampled_ids = list(sampled_ids[0].cpu().numpy())

    plt.title("Image: "+str(index))
    plt_img = np.moveaxis(image.numpy(), 0, -1)
    plt.imshow(plt_img)
    plt.show()

    print(" Original:", train_dataset.tokenizer.decode(impression))
    print("Generated:", train_dataset.tokenizer.decode(sampled_ids))
    print("")