In [1]:
import numpy as np
import torch
from torchvision import transforms
import apex
import data
import models

### Data Loading

In [2]:
batch_size = 8

train_dataset = data.XRayDataset(
    transform=transforms.Compose([
        transforms.Resize(2048),
        transforms.CenterCrop((2048,2048)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.25, 0.25, 0.25]),
        ]
))

print("Sample:")
image, impression = train_dataset.__getitem__(0)
print("* Image size:", image.size())
print("* Impression:", impression)
print("* Vocab:")
print(train_dataset.vocab)

train_dataloader = torch.utils.data.dataloader.DataLoader(train_dataset,
                                                          collate_fn=data.collate_fn,
                                                          pin_memory=True,
                                                          batch_size=batch_size,
                                                          num_workers=batch_size)

Number of reports: 3851
Skipped: 3648 images
Sample:
* Image size: torch.Size([3, 2048, 2048])
* Impression: tensor([37, 38, 41, 36, 24, 35,  0, 26, 31, 28, 42, 43,  0, 47,  6, 47, 47, 47,
        47,  7])
* Vocab:
['2', 'y', "'", '>', '-', 'm', 'e', 'v', 'o', 'b', 'u', '1', '6', 'a', 'c', '0', '3', 'k', '8', 'z', ':', 'r', '5', 'd', '.', '%', '<', '9', '(', 'i', 't', 'f', 'p', 'w', '"', 'q', ')', ' ', 'n', '7', '/', 'h', 's', 'x', ';', 'j', '[', 'g', 'l', '4']


### Build Model

In [3]:
embed_size = 128
hidden_size = 128
num_layers = 3
learning_rate = 0.001
memory_format = torch.channels_last
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = models.EncoderCNN(embed_size).to(device, memory_format=memory_format)
decoder = models.DecoderRNN(embed_size, hidden_size, len(train_dataset.vocab), num_layers).to(device)

criterion = torch.nn.CrossEntropyLoss()
params = list(decoder.parameters()) + list(encoder.parameters())
optimizer = apex.optimizers.FusedAdam(params, lr=learning_rate)

[encoder, decoder], optimizer = apex.amp.initialize([encoder, decoder], optimizer, opt_level="O1")

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


### Train Model

In [4]:
num_epochs = 3

total_step = len(train_dataloader.dataset)

print("Start training")

for epoch in range(num_epochs):
    for i, (images, captions, lengths) in enumerate(train_dataloader):

        # Set mini-batch dataset
        images = images.cuda(non_blocking=True).contiguous(memory_format=memory_format)
        captions = captions.cuda(non_blocking=True).contiguous()
        targets = torch.nn.utils.rnn.pack_padded_sequence(captions, lengths, batch_first=True)[0]
        
        encoder.zero_grad()
        decoder.zero_grad()

        # Forward, backward and optimize
        features = encoder(images)
        outputs = decoder(features, captions, lengths)
        loss = criterion(outputs, targets)
        
        with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()

        # Print log info
        if i % 100 == 0:
            print("Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}"
                  .format(epoch, num_epochs, i, total_step, loss.item(), np.exp(loss.item()))) 

Start training
Epoch [0/3], Step [0/3484], Loss: 3.9329, Perplexity: 51.0573
Epoch [0/3], Step [100/3484], Loss: 2.9804, Perplexity: 19.6951
Epoch [0/3], Step [200/3484], Loss: 2.7066, Perplexity: 14.9776
Epoch [0/3], Step [300/3484], Loss: 2.3290, Perplexity: 10.2675
Epoch [0/3], Step [400/3484], Loss: 2.1858, Perplexity: 8.8977
Epoch [1/3], Step [0/3484], Loss: 1.6587, Perplexity: 5.2522
Epoch [1/3], Step [100/3484], Loss: 1.1066, Perplexity: 3.0240
Epoch [1/3], Step [200/3484], Loss: 0.9805, Perplexity: 2.6659


Exception in thread Thread-5:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/opt/conda/lib/python3.6/threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/_utils/pin_memory.py", line 25, in _pin_memory_loop
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
  File "/opt/conda/lib/python3.6/multiprocessing/queues.py", line 113, in get
    return _ForkingPickler.loads(res)
  File "/opt/conda/lib/python3.6/site-packages/torch/multiprocessing/reductions.py", line 294, in rebuild_storage_fd
    fd = df.detach()
  File "/opt/conda/lib/python3.6/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/opt/conda/lib/python3.6/multiprocessing/resource_sharer.py", line 87, in get_connection
    c = Client(address, authkey=process.current_process().au

Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3331, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-4-be282afce6e4>", line 8, in <module>
    for i, (images, captions, lengths) in enumerate(train_dataloader):
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 841, in _next_data
    idx, data = self._get_data()
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 798, in _get_data
    success, data = self._try_get_data()
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 761, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/opt/conda/lib/python3.6/queue.py", line 173, in get
    self.not_empty.wait(remaining)
  File "/o

KeyboardInterrupt: 