In [1]:
import time
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
from models import Encoder, DecoderWithAttention
from dataloader import *
from utils import *
from nltk.translate.bleu_score import corpus_bleu

In [2]:
word_map=read_obj('./word_map.pickle')
# word_map

In [3]:
# Data parameters
# data_folder = '/media/ssd/caption data'  # folder with data files saved by create_input_files.py
df_path='/home/ss4yd/vision_transformer/captioning_vision_transformer/prepared_prelim_data_tokenized.pickle'
data_name = 'hipt_captioning_task'  # base name shared by data files

# Model parameters
emb_dim = 192  # dimension of word embeddings
encoder_dim=192
attention_dim = 192  # dimension of attention linear layers
decoder_dim = 512  # dimension of decoder RNN
dropout = 0.2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # sets device for model and PyTorch tensors
cudnn.benchmark = True  # set to true only if inputs to model are fixed size; otherwise lot of computational overhead

# Training parameters
start_epoch = 0
epochs = 120  # number of epochs to train for (if early stopping is not triggered)
epochs_since_improvement = 0  # keeps track of number of epochs since there's been an improvement in validation BLEU
batch_size = 32
workers = 1  # for data-loading; right now, only 1 works with h5py
encoder_lr = 1e-4  # learning rate for encoder if fine-tuning
decoder_lr = 4e-4  # learning rate for decoder
grad_clip = 5.  # clip gradients at an absolute value of
alpha_c = 1.  # regularization parameter for 'doubly stochastic attention', as in the paper
best_bleu4 = 0.  # BLEU-4 score right now
print_freq = 100  # print training/validation stats every __ batches
fine_tune_encoder = False  # fine-tune encoder?
checkpoint = None  # path to checkpoint, None if none

In [4]:
def main():
    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    if checkpoint is None:
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       encoder_dim=encoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout,
                                       device=device)
        decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        encoder = Encoder()
    #     encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                             lr=encoder_lr) if fine_tune_encoder else None
    else:
            checkpoint = torch.load(checkpoint)
            start_epoch = checkpoint['epoch'] + 1
            epochs_since_improvement = checkpoint['epochs_since_improvement']
            best_bleu4 = checkpoint['bleu-4']
            decoder = checkpoint['decoder']
            decoder_optimizer = checkpoint['decoder_optimizer']
            encoder = checkpoint['encoder']
            encoder_optimizer = checkpoint['encoder_optimizer']
            if fine_tune_encoder is True and encoder_optimizer is None:
    #             encoder.fine_tune(fine_tune_encoder)
                encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                                     lr=encoder_lr)

    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    train_loader = torch.utils.data.DataLoader(
            PreLoadedReps(df_path,'train'),
            batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        PreLoadedReps(df_path,'val'),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)

    for epoch in range(start_epoch, epochs):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer,
                        decoder_optimizer, recent_bleu4, is_best)


In [5]:
def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, epoch):
    """
    Performs one epoch's training.
    :param train_loader: DataLoader for training data
    :param encoder: encoder model
    :param decoder: decoder model
    :param criterion: loss layer
    :param encoder_optimizer: optimizer to update encoder's weights (if fine-tuning)
    :param decoder_optimizer: optimizer to update decoder's weights
    :param epoch: epoch number
    """

    decoder.train()  # train mode (dropout and batchnorm is used)
    encoder.train()

    batch_time = AverageMeter()  # forward prop. + back prop. time
    data_time = AverageMeter()  # data loading time
    losses = AverageMeter()  # loss (per word decoded)
    top5accs = AverageMeter()  # top5 accuracy

    start = time.time()

    # Batches
    for i, (imgs, caps, caplens) in enumerate(train_loader):
        data_time.update(time.time() - start)

        # Move to GPU, if available
        imgs = imgs.to(device)
        caps = caps.to(device)
        caplens = caplens.to(device)

        # Forward prop.
        imgs = encoder(imgs)
        scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)

        # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
        targets = caps_sorted[:, 1:]

        # Remove timesteps that we didn't decode at, or are pads
        # pack_padded_sequence is an easy trick to do this
        scores, _,_,_ = pack_padded_sequence(scores, decode_lengths, batch_first=True)
        targets, _,_,_ = pack_padded_sequence(targets, decode_lengths, batch_first=True)

        # Calculate loss
        loss = criterion(scores, targets)

        # Add doubly stochastic attention regularization
        loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

        # Back prop.
        decoder_optimizer.zero_grad()
        if encoder_optimizer is not None:
            encoder_optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        if grad_clip is not None:
            clip_gradient(decoder_optimizer, grad_clip)
            if encoder_optimizer is not None:
                clip_gradient(encoder_optimizer, grad_clip)

        # Update weights
        decoder_optimizer.step()
        if encoder_optimizer is not None:
            encoder_optimizer.step()

        # Keep track of metrics
        top5 = accuracy(scores, targets, 5)
        losses.update(loss.item(), sum(decode_lengths))
        top5accs.update(top5, sum(decode_lengths))
        batch_time.update(time.time() - start)

        start = time.time()

        # Print status
        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, i, len(train_loader),
                                                                          batch_time=batch_time,
                                                                          data_time=data_time, loss=losses,
                                                                          top5=top5accs))


def validate(val_loader, encoder, decoder, criterion):
    """
    Performs one epoch's validation.
    :param val_loader: DataLoader for validation data.
    :param encoder: encoder model
    :param decoder: decoder model
    :param criterion: loss layer
    :return: BLEU-4 score
    """
    decoder.eval()  # eval mode (no dropout or batchnorm)
    if encoder is not None:
        encoder.eval()

    batch_time = AverageMeter()
    losses = AverageMeter()
    top5accs = AverageMeter()

    start = time.time()

    references = list()  # references (true captions) for calculating BLEU-4 score
    hypotheses = list()  # hypotheses (predictions)

    # explicitly disable gradient calculation to avoid CUDA memory error
    # solves the issue #57
    with torch.no_grad():
        # Batches
        for i, (imgs, caps, caplens, allcaps) in enumerate(val_loader):

            # Move to device, if available
            imgs = imgs.to(device)
            caps = caps.to(device)
            caplens = caplens.to(device)

            # Forward prop.
            if encoder is not None:
                imgs = encoder(imgs)
#             print(imgs.shape)
            scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)

            # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
            targets = caps_sorted[:, 1:]

            # Remove timesteps that we didn't decode at, or are pads
            # pack_padded_sequence is an easy trick to do this
            scores_copy = scores.clone()
            scores, _,_,_ = pack_padded_sequence(scores, decode_lengths, batch_first=True)
            targets, _,_,_ = pack_padded_sequence(targets, decode_lengths, batch_first=True)

            # Calculate loss
            loss = criterion(scores, targets)

            # Add doubly stochastic attention regularization
            loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

            # Keep track of metrics
            losses.update(loss.item(), sum(decode_lengths))
            top5 = accuracy(scores, targets, 5)
            top5accs.update(top5, sum(decode_lengths))
            batch_time.update(time.time() - start)

            start = time.time()

            if i % print_freq == 0:
                print('Validation: [{0}/{1}]\t'
                      'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), batch_time=batch_time,
                                                                                loss=losses, top5=top5accs))

            # Store references (true captions), and hypothesis (prediction) for each image
            # If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
            # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]

            # References
            allcaps = allcaps[sort_ind]  # because images were sorted in the decoder
            for j in range(allcaps.shape[0]):
                img_caps = allcaps[j].tolist()
                img_captions = list(
                    map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<pad>']}],
                        img_caps))  # remove <start> and pads
                references.append(img_captions)

            # Hypotheses
            _, preds = torch.max(scores_copy, dim=2)
            preds = preds.tolist()
            temp_preds = list()
            for j, p in enumerate(preds):
                temp_preds.append(preds[j][:decode_lengths[j]])  # remove pads
            preds = temp_preds
            hypotheses.extend(preds)

            assert len(references) == len(hypotheses)

        # Calculate BLEU-4 scores
        bleu4 = corpus_bleu(references, hypotheses)
        
        rev_word_map = {v: k for k, v in word_map.items()}
        
        ind=random.randint(0, len(references)-1)
        tref = references[ind]
        thyp = hypotheses[ind]
        print(thyp)
        print('Random reference length: {}, ind: {}'.format(len(references), ind))
        tref = [[rev_word_map[i] for i in x] for x in tref]
        thyp = [rev_word_map[i] for i in thyp]
        print('References: {}'.format(tref))
        print('\n---------------------------------------\n')
        print(thyp)
        print('\n---------------------------------------\n')

        print(
            '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format(
                loss=losses,
                top5=top5accs,
                bleu=bleu4))

    return bleu4


In [6]:
if __name__ == '__main__':
    main()

Epoch: [0][0/40]	Batch Time 1.928 (1.928)	Data Load Time 1.780 (1.780)	Loss 199.5123 (199.5123)	Top-5 Accuracy 3.800 (3.800)
Validation: [0/3]	Batch Time 0.209 (0.209)	Loss 136.4376 (136.4376)	Top-5 Accuracy 47.696 (47.696)	


The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


[1, 2, 3, 3, 3, 650, 650, 650, 650, 650, 650, 650, 650, 650, 650]
Random reference length: 65, ind: 41
References: [['2', 'pieces', ';', 'congested', ',', 'includes', 'capsule', '(', 'target', 'is', '5mm', 'below', 'capsule', ')', '<end>']]

---------------------------------------

['2', 'pieces', ',', ',', ',', '<end>', '<end>', '<end>', '<end>', '<end>', '<end>', '<end>', '<end>', '<end>', '<end>']

---------------------------------------


 * LOSS - 162.493, TOP-5 ACCURACY - 48.285, BLEU-4 - 2.365110512942126e-78

Epoch: [1][0/40]	Batch Time 0.141 (0.141)	Data Load Time 0.074 (0.074)	Loss 176.9534 (176.9534)	Top-5 Accuracy 45.387 (45.387)
Validation: [0/3]	Batch Time 0.089 (0.089)	Loss 152.8171 (152.8171)	Top-5 Accuracy 50.900 (50.900)	
[1, 2, 16, 62, 3, 3, 650, 650, 650]
Random reference length: 65, ind: 49
References: [['2', 'pieces', ';', 'some', 'fibrosis', ',', 'pigmented', 'macrophages', '<end>']]

---------------------------------------

['2', 'pieces', ';', 'spermatogenesis'

Epoch: [12][0/40]	Batch Time 0.131 (0.131)	Data Load Time 0.073 (0.073)	Loss 180.9724 (180.9724)	Top-5 Accuracy 80.929 (80.929)
Validation: [0/3]	Batch Time 0.095 (0.095)	Loss 152.0077 (152.0077)	Top-5 Accuracy 67.769 (67.769)	
[47, 2, 3, 151, 151, 650, 650]
Random reference length: 65, ind: 58
References: [['6', 'pieces', ',', 'minimal', '<unk>', 'serosa', '<end>']]

---------------------------------------

['6', 'pieces', ',', 'atherosis', 'atherosis', '<end>', '<end>']

---------------------------------------


 * LOSS - 161.232, TOP-5 ACCURACY - 66.836, BLEU-4 - 0.13648932277729614

Epoch: [13][0/40]	Batch Time 0.133 (0.133)	Data Load Time 0.074 (0.074)	Loss 150.4352 (150.4352)	Top-5 Accuracy 83.467 (83.467)
Validation: [0/3]	Batch Time 0.093 (0.093)	Loss 185.8270 (185.8270)	Top-5 Accuracy 65.865 (65.865)	
[1, 2, 16, 6, 50, 86, 38, 16, 16, 648, 38, 16, 648, 38, 650]
Random reference length: 65, ind: 10
References: [['2', 'pieces', ';', 'fibrosis', 'and', 'lymphoid', 'infiltrate', '

Epoch: [24][0/40]	Batch Time 0.141 (0.141)	Data Load Time 0.074 (0.074)	Loss 157.5320 (157.5320)	Top-5 Accuracy 94.211 (94.211)
Validation: [0/3]	Batch Time 0.092 (0.092)	Loss 168.8684 (168.8684)	Top-5 Accuracy 64.450 (64.450)	
[1, 2, 3, 4, 334, 650, 191, 21, 21, 21, 650]
Random reference length: 65, ind: 16
References: [['2', 'pieces', ',', '<unk>', 'thyroiditis', ',', 'rep', 'lymphoid', 'aggregates', 'delineated', '<end>']]

---------------------------------------

['2', 'pieces', ',', 'no', 'colloid', '<end>', 'few', 'delineated', 'delineated', 'delineated', '<end>']

---------------------------------------


 * LOSS - 166.054, TOP-5 ACCURACY - 67.598, BLEU-4 - 0.14508619218381924

Epoch: [25][0/40]	Batch Time 0.138 (0.138)	Data Load Time 0.073 (0.073)	Loss 145.7591 (145.7591)	Top-5 Accuracy 97.347 (97.347)
Validation: [0/3]	Batch Time 0.087 (0.087)	Loss 156.6779 (156.6779)	Top-5 Accuracy 68.123 (68.123)	
[1, 2, 3, 6, 3, 72, 86, 50, 129, 185, 79, 650, 378, 315, 22, 22, 650]
Random r

Epoch: [35][0/40]	Batch Time 0.123 (0.123)	Data Load Time 0.073 (0.073)	Loss 123.9988 (123.9988)	Top-5 Accuracy 97.983 (97.983)
Validation: [0/3]	Batch Time 0.092 (0.092)	Loss 154.8651 (154.8651)	Top-5 Accuracy 64.785 (64.785)	
[1, 2, 3, 90, 419, 86, 50, 86, 220, 180, 79, 650, 648, 648, 648, 650]
Random reference length: 65, ind: 41
References: [['2', 'pieces', ';', 'diffuse', 'interstitial', 'fibrosis', 'with', '<unk>', 'and', 'chronic', 'inflammation', ',', 'abundant', 'pigmented', 'macrophages', '<end>']]

---------------------------------------

['2', 'pieces', ',', 'severe', 'alveolar', 'fibrosis', 'and', 'fibrosis', 'pneumonia', 'hemorrhage', 'inflammation', '<end>', '<unk>', '<unk>', '<unk>', '<end>']

---------------------------------------


 * LOSS - 160.692, TOP-5 ACCURACY - 66.709, BLEU-4 - 0.13074555838787116


Epochs since last improvement: 10

Epoch: [36][0/40]	Batch Time 0.135 (0.135)	Data Load Time 0.073 (0.073)	Loss 197.5754 (197.5754)	Top-5 Accuracy 99.535 (99.535)
V