In [1]:
import time
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
from models import Encoder, DecoderWithAttention,HIPT_LGP_FC
from dataloader import *
from utils import *
from nltk.translate.bleu_score import corpus_bleu

In [2]:
word_map=read_obj('./word_map_cls256.pickle')
# word_map

In [3]:
# Data parameters
# data_folder = '/media/ssd/caption data'  # folder with data files saved by create_input_files.py
df_path='/home/ss4yd/vision_transformer/captioning_vision_transformer/prepared_prelim_data_tokenized_cls256.pickle'
data_name = 'hipt_captioning_task'  # base name shared by data files

# Model parameters
emb_dim = 192  # dimension of word embeddings
encoder_dim=192
attention_dim = 192  # dimension of attention linear layers
decoder_dim = 512  # dimension of decoder RNN
dropout = 0.2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # sets device for model and PyTorch tensors
cudnn.benchmark = True  # set to true only if inputs to model are fixed size; otherwise lot of computational overhead

# Training parameters
start_epoch = 0
epochs = 120  # number of epochs to train for (if early stopping is not triggered)
epochs_since_improvement = 0  # keeps track of number of epochs since there's been an improvement in validation BLEU
batch_size = 1
workers = 1  # for data-loading; right now, only 1 works with h5py
encoder_lr = 1e-4  # learning rate for encoder if fine-tuning
decoder_lr = 4e-4  # learning rate for decoder
grad_clip = 5.  # clip gradients at an absolute value of
alpha_c = 1.  # regularization parameter for 'doubly stochastic attention', as in the paper
best_bleu4 = 0.  # BLEU-4 score right now
print_freq = 100  # print training/validation stats every __ batches
fine_tune_encoder = False  # fine-tune encoder?
checkpoint = None  # path to checkpoint, None if none

In [4]:
def main():
    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    if checkpoint is None:
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       encoder_dim=encoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout,
                                       device=device)
        decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        encoder =HIPT_LGP_FC(freeze_4k=True, pretrain_4k='vit4k_xs_dino', n_classes=2)
    #     encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                             lr=encoder_lr) if fine_tune_encoder else None
    else:
            checkpoint = torch.load(checkpoint)
            start_epoch = checkpoint['epoch'] + 1
            epochs_since_improvement = checkpoint['epochs_since_improvement']
            best_bleu4 = checkpoint['bleu-4']
            decoder = checkpoint['decoder']
            decoder_optimizer = checkpoint['decoder_optimizer']
            encoder = checkpoint['encoder']
            encoder_optimizer = checkpoint['encoder_optimizer']
            if fine_tune_encoder is True and encoder_optimizer is None:
    #             encoder.fine_tune(fine_tune_encoder)
                encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                                     lr=encoder_lr)

    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    train_loader = torch.utils.data.DataLoader(
            PreLoadedReps_v2(df_path,'train'),
            batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        PreLoadedReps_v2(df_path,'val'),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)

    for epoch in range(start_epoch, epochs):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer,
                        decoder_optimizer, recent_bleu4, is_best)


In [5]:
def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, epoch):
    """
    Performs one epoch's training.
    :param train_loader: DataLoader for training data
    :param encoder: encoder model
    :param decoder: decoder model
    :param criterion: loss layer
    :param encoder_optimizer: optimizer to update encoder's weights (if fine-tuning)
    :param decoder_optimizer: optimizer to update decoder's weights
    :param epoch: epoch number
    """

    decoder.train()  # train mode (dropout and batchnorm is used)
    encoder.train()

    batch_time = AverageMeter()  # forward prop. + back prop. time
    data_time = AverageMeter()  # data loading time
    losses = AverageMeter()  # loss (per word decoded)
    top5accs = AverageMeter()  # top5 accuracy

    start = time.time()

    # Batches
    for i, (imgs, caps, caplens) in enumerate(train_loader):
        data_time.update(time.time() - start)

        # Move to GPU, if available
        imgs = imgs.to(device).squeeze(0)
        caps = caps.to(device)
        caplens = caplens.to(device)
        
        # Forward prop.
        imgs,_ = encoder(imgs)
        scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)

        # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
        targets = caps_sorted[:, 1:]

        # Remove timesteps that we didn't decode at, or are pads
        # pack_padded_sequence is an easy trick to do this
        scores, _,_,_ = pack_padded_sequence(scores, decode_lengths, batch_first=True)
        targets, _,_,_ = pack_padded_sequence(targets, decode_lengths, batch_first=True)

        # Calculate loss
        loss = criterion(scores, targets)

        # Add doubly stochastic attention regularization
        loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

        # Back prop.
        decoder_optimizer.zero_grad()
        if encoder_optimizer is not None:
            encoder_optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        if grad_clip is not None:
            clip_gradient(decoder_optimizer, grad_clip)
            if encoder_optimizer is not None:
                clip_gradient(encoder_optimizer, grad_clip)

        # Update weights
        decoder_optimizer.step()
        if encoder_optimizer is not None:
            encoder_optimizer.step()

        # Keep track of metrics
        top5 = accuracy(scores, targets, 5)
        losses.update(loss.item(), sum(decode_lengths))
        top5accs.update(top5, sum(decode_lengths))
        batch_time.update(time.time() - start)

        start = time.time()

        # Print status
        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, i, len(train_loader),
                                                                          batch_time=batch_time,
                                                                          data_time=data_time, loss=losses,
                                                                          top5=top5accs))


def validate(val_loader, encoder, decoder, criterion):
    """
    Performs one epoch's validation.
    :param val_loader: DataLoader for validation data.
    :param encoder: encoder model
    :param decoder: decoder model
    :param criterion: loss layer
    :return: BLEU-4 score
    """
    decoder.eval()  # eval mode (no dropout or batchnorm)
    if encoder is not None:
        encoder.eval()

    batch_time = AverageMeter()
    losses = AverageMeter()
    top5accs = AverageMeter()

    start = time.time()

    references = list()  # references (true captions) for calculating BLEU-4 score
    hypotheses = list()  # hypotheses (predictions)

    # explicitly disable gradient calculation to avoid CUDA memory error
    # solves the issue #57
    with torch.no_grad():
        # Batches
        for i, (imgs, caps, caplens, allcaps) in enumerate(val_loader):

            # Move to device, if available
            imgs = imgs.to(device).squeeze(0)
            caps = caps.to(device)
            caplens = caplens.to(device)

            # Forward prop.
            if encoder is not None:
                imgs,_ = encoder(imgs)
#             print(imgs.shape)
            scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)

            # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
            targets = caps_sorted[:, 1:]

            # Remove timesteps that we didn't decode at, or are pads
            # pack_padded_sequence is an easy trick to do this
            scores_copy = scores.clone()
            scores, _,_,_ = pack_padded_sequence(scores, decode_lengths, batch_first=True)
            targets, _,_,_ = pack_padded_sequence(targets, decode_lengths, batch_first=True)

            # Calculate loss
            loss = criterion(scores, targets)

            # Add doubly stochastic attention regularization
            loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

            # Keep track of metrics
            losses.update(loss.item(), sum(decode_lengths))
            top5 = accuracy(scores, targets, 5)
            top5accs.update(top5, sum(decode_lengths))
            batch_time.update(time.time() - start)

            start = time.time()

            if i % print_freq == 0:
                print('Validation: [{0}/{1}]\t'
                      'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), batch_time=batch_time,
                                                                                loss=losses, top5=top5accs))

            # Store references (true captions), and hypothesis (prediction) for each image
            # If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
            # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]

            # References
            allcaps = allcaps[sort_ind]  # because images were sorted in the decoder
            for j in range(allcaps.shape[0]):
                img_caps = allcaps[j].tolist()
                img_captions = list(
                    map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<pad>']}],
                        img_caps))  # remove <start> and pads
                references.append(img_captions)

            # Hypotheses
            _, preds = torch.max(scores_copy, dim=2)
            preds = preds.tolist()
            temp_preds = list()
            for j, p in enumerate(preds):
                temp_preds.append(preds[j][:decode_lengths[j]])  # remove pads
            preds = temp_preds
            hypotheses.extend(preds)

            assert len(references) == len(hypotheses)

        # Calculate BLEU-4 scores
        bleu4 = corpus_bleu(references, hypotheses)
        
        rev_word_map = {v: k for k, v in word_map.items()}
        
        ind=random.randint(0, len(references)-1)
        tref = references[ind]
        thyp = hypotheses[ind]
        print(thyp)
        print('Random reference length: {}, ind: {}'.format(len(references), ind))
        tref = [[rev_word_map[i] for i in x] for x in tref]
        thyp = [rev_word_map[i] for i in thyp]
        print('References: {}'.format(tref))
        print('\n---------------------------------------\n')
        print(thyp)
        print('\n---------------------------------------\n')

        print(
            '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format(
                loss=losses,
                top5=top5accs,
                bleu=bleu4))

    return bleu4


In [6]:
if __name__ == '__main__':
    main()

# of Patches: 196
Loading Pretrained Local VIT model...
Done!
Freezing Pretrained Local VIT model
Done




Epoch: [0][0/110]	Batch Time 0.692 (0.692)	Data Load Time 0.432 (0.432)	Loss 69.0235 (69.0235)	Top-5 Accuracy 0.000 (0.000)
Epoch: [0][100/110]	Batch Time 0.086 (0.127)	Data Load Time 0.000 (0.004)	Loss 38.8560 (221.5551)	Top-5 Accuracy 57.143 (50.612)
Validation: [0/8]	Batch Time 0.169 (0.169)	Loss 228.7152 (228.7152)	Top-5 Accuracy 50.000 (50.000)	


The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


[1, 2, 3, 3, 152, 152, 152, 154]
Random reference length: 8, ind: 1
References: [['2', 'pieces', ';', 'chronic', 'inflammation', 'and', 'fibrosis', '<end>']]

---------------------------------------

['2', 'pieces', ',', ',', '<unk>', '<unk>', '<unk>', '<end>']

---------------------------------------


 * LOSS - 355.417, TOP-5 ACCURACY - 51.327, BLEU-4 - 3.6148560179405957e-78

Epoch: [1][0/110]	Batch Time 0.163 (0.163)	Data Load Time 0.057 (0.057)	Loss 51.9502 (51.9502)	Top-5 Accuracy 62.500 (62.500)
Epoch: [1][100/110]	Batch Time 0.170 (0.121)	Data Load Time 0.000 (0.001)	Loss 103.1827 (223.8084)	Top-5 Accuracy 54.545 (57.166)
Validation: [0/8]	Batch Time 0.217 (0.217)	Loss 49.9007 (49.9007)	Top-5 Accuracy 100.000 (100.000)	
[1, 2, 3, 152, 152, 154, 154, 154]
Random reference length: 8, ind: 2
References: [['2', 'pieces', ';', 'chronic', 'inflammation', 'and', 'fibrosis', '<end>']]

---------------------------------------

['2', 'pieces', ',', '<unk>', '<unk>', '<end>', '<end>', '<e

Epoch: [10][0/110]	Batch Time 0.182 (0.182)	Data Load Time 0.058 (0.058)	Loss 227.1035 (227.1035)	Top-5 Accuracy 81.250 (81.250)
Epoch: [10][100/110]	Batch Time 0.118 (0.121)	Data Load Time 0.000 (0.001)	Loss 227.1631 (221.4153)	Top-5 Accuracy 81.250 (86.932)
Validation: [0/8]	Batch Time 0.279 (0.279)	Loss 52.0193 (52.0193)	Top-5 Accuracy 62.500 (62.500)	
[1, 2, 3, 152, 124, 154, 152, 154]
Random reference length: 8, ind: 0
References: [['2', 'pieces', ';', 'chronic', 'inflammation', 'and', 'fibrosis', '<end>']]

---------------------------------------

['2', 'pieces', ',', '<unk>', 'macrovesicular', '<end>', '<unk>', '<end>']

---------------------------------------


 * LOSS - 355.012, TOP-5 ACCURACY - 56.637, BLEU-4 - 0.08876809667701213


Epochs since last improvement: 7

Epoch: [11][0/110]	Batch Time 0.148 (0.148)	Data Load Time 0.053 (0.053)	Loss 50.1035 (50.1035)	Top-5 Accuracy 87.500 (87.500)
Epoch: [11][100/110]	Batch Time 0.088 (0.123)	Data Load Time 0.000 (0.001)	Loss 36.464

Epoch: [19][0/110]	Batch Time 0.274 (0.274)	Data Load Time 0.078 (0.078)	Loss 842.1014 (842.1014)	Top-5 Accuracy 93.333 (93.333)
Epoch: [19][100/110]	Batch Time 0.183 (0.121)	Data Load Time 0.000 (0.001)	Loss 257.0497 (219.2267)	Top-5 Accuracy 88.235 (98.604)
Validation: [0/8]	Batch Time 0.078 (0.078)	Loss 844.8488 (844.8488)	Top-5 Accuracy 60.000 (60.000)	
[1, 2, 3, 152, 29, 32, 79, 32, 19, 25, 47, 136, 21, 3, 21, 152]
Random reference length: 8, ind: 2
References: [['2', 'pieces', ',', 'predominantly', 'fatty', 'with', '<unk>', '%', 'fibrous', 'stroma', 'and', 'ductal/lobular', 'elements', ',', '<unk>', '<end>']]

---------------------------------------

['2', 'pieces', ',', '<unk>', 'changes', 'elements', 'gynecomastoid', 'elements', 'tissue', 'fat', 'and', 'attached', 'delineated', ',', 'delineated', '<unk>']

---------------------------------------


 * LOSS - 355.550, TOP-5 ACCURACY - 59.292, BLEU-4 - 0.09831005086310621


Epochs since last improvement: 16


DECAYING learning rat