In [1]:
#imports
import json
import os
import argparse
import h5py
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence, PackedSequence, pack_padded_sequence, pad_packed_sequence
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torchmetrics.text.rouge import ROUGEScore
from torchmetrics.text.bleu import BLEUScore

from copy import deepcopy
import math

In [2]:
EMB_PATH = '/data04/shared/skapse/REG_processed/20x_512px_0px_overlap'
REPORTS_JSON_PATH = '/home/sbsingh/projects/report_gen/train.json'
model_ckpt_path = '/home/sbsingh/projects/report_gen/pretrained/model_best.pth'

In [3]:
def log_message(*message):
    if debug:
        print('_______'.join(message))

def read_json_file(json_path):
        with open(json_path) as f:
            d = json.load(f)
        return d

def penalty_builder(penalty_config):
    if penalty_config == '':
        return lambda x, y: y
    pen_type, alpha = penalty_config.split('_')
    alpha = float(alpha)
    if pen_type == 'wu':
        return lambda x, y: length_wu(x, y, alpha)
    if pen_type == 'avg':
        return lambda x, y: length_average(x, y, alpha)


def length_wu(length, logprobs, alpha=0.):
    """
    NMT length re-ranking score from
    "Google's Neural Machine Translation System" :cite:`wu2016google`.
    """

    modifier = (((5 + length) ** alpha) /
                ((5 + 1) ** alpha))
    return logprobs / modifier


def length_average(length, logprobs, alpha=0.):
    """
    Returns the average probability of tokens in a sequence.
    """
    if length<alpha:
        penalty= -1000
    else:
        penalty = logprobs / length
    log_message(f'length: {length}, logprobs: {logprobs}, penalty: {penalty}')
    return penalty


def split_tensors(n, x):
    if torch.is_tensor(x):
        assert x.shape[0] % n == 0
        x = x.reshape(x.shape[0] // n, n, *x.shape[1:]).unbind(1)
    elif type(x) is list or type(x) is tuple:
        x = [split_tensors(n, _) for _ in x]
    elif x is None:
        x = [None] * n
    return x


def repeat_tensors(n, x):
    """
    For a tensor of size Bx..., we repeat it n times, and make it Bnx...
    For collections, do nested repeat
    """
    if torch.is_tensor(x):
        x = x.unsqueeze(1)  # Bx1x...
        x = x.expand(-1, n, *([-1] * len(x.shape[2:])))  # Bxnx...
        x = x.reshape(x.shape[0] * n, *x.shape[2:])  # Bnx...
    elif type(x) is list or type(x) is tuple:
        x = [repeat_tensors(n, _) for _ in x]
    return x


def clones(module, N):
    return nn.ModuleList([deepcopy(module) for _ in range(N)])

def pad_tokens(att_feats):
    # ---->pad
    H = att_feats.shape[1]
    _H, _W = int(np.ceil(np.sqrt(H))), int(np.ceil(np.sqrt(H)))
    add_length = _H * _W - H
    att_feats = torch.cat([att_feats, att_feats[:, :add_length, :]], dim=1)  # [B, N, L]
    return att_feats

def sort_pack_padded_sequence(input, lengths):
    lengths =lengths.cpu()
    # log_message(f'lengths: {lengths}')
    sorted_lengths, indices = torch.sort(lengths, descending=True)
    # log_message(input[indices], sorted_lengths)
    tmp = pack_padded_sequence(input[indices], sorted_lengths, batch_first=True)
    inv_ix = indices.clone()
    inv_ix[indices] = torch.arange(0, len(indices)).type_as(inv_ix)
    return tmp, inv_ix


def pad_unsort_packed_sequence(input, inv_ix):
    tmp, _ = pad_packed_sequence(input, batch_first=True)
    tmp = tmp[inv_ix]
    return tmp

def pack_wrapper(module, att_feats, att_masks):
    # print(module, att_feats, att_masks)
    if att_masks is not None:
        packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1))
        return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
    else:
        return module(att_feats)

In [4]:

class Tokenizer:

    def __init__(self, reports_json_path, threshold = 1):
        self.__threshold = threshold
        self.__token2idx, self.__idx2token = self.create_vocabulary(reports_json_path)


    def create_vocabulary(self, reports_json_path):
        total_tokens = []
        reports = read_json_file(reports_json_path)

        for report in reports:
            tokens = report['report'].split()
            # for token in tokens:
            #     total_tokens.append(token)
            total_tokens.extend(tokens)

        counter = Counter(total_tokens)
        vocab = [k for k, v in counter.items() if v >= self.__threshold] + ['<unk>']
        vocab.sort()
        token2idx, idx2token = {}, {}
        for idx, token in enumerate(vocab):
            token2idx[token] = idx + 1
            idx2token[idx + 1] = token

        return token2idx, idx2token

    def get_token_by_id(self, id):
        return self.__idx2token[id]

    def get_id_by_token(self, token):
        if token not in self.__token2idx:
            return self.__token2idx['<unk>']
        return self.__token2idx[token]

    def get_vocab_size(self):
        return len(self.__token2idx)

    def __call__(self, report):
        tokens = report.split()
        ids = []
        for token in tokens:
            ids.append(self.get_id_by_token(token))
        ids = [0] + ids + [0]
        return ids

    def decode(self, ids):
        # print(ids)
        txt = ''
        for i, idx in enumerate(ids):
            if idx > 0:
                if i >= 1:
                    txt += ' '
                txt += self.get_token_by_id(idx)
            else:
                break
        return txt

    def batch_decode(self, ids_batch):
        # print(f'ids_batch: {ids_batch}, {ids_batch.shape}')
        out = []
        for ids in ids_batch:
            out.append(self.decode(ids))
        return out

In [5]:
class EmbeddingDataset(Dataset):

    def __init__(self, embeddings_path, reports_json_path, tokenizer, max_seq_length):
        reports = read_json_file(reports_json_path)
        self.__reports = {report['id'].split('.')[0]: report['report'] for report in reports}
        self.__tokenizer = tokenizer
        self.__embeddings_path = embeddings_path
        self.__max_seq_length = max_seq_length

        files = os.listdir(embeddings_path)
        self.__slides = [file.split('.')[0] for file in files]

    def __len__(self):
        return len(self.__slides)

    def __getitem__(self, idx):
        slide_id = self.__slides[idx]
        with h5py.File(f'{self.__embeddings_path}/{slide_id}.h5', "r") as h5_file:
            coords_np = h5_file["coords"][:]
            embeddings_np = h5_file["features"][:]

            coords = torch.tensor(coords_np).float() 
            embedding = torch.tensor(embeddings_np)
            report_text = self.__reports[slide_id]
            report_ids = self.__tokenizer(report_text)

            if len(report_ids) < self.__max_seq_length:
                padding = [0] * (self.__max_seq_length-len(report_ids))
                report_ids.extend(padding)

            report_masks = [1] * len(report_ids)
            seq_length = len(report_ids)


        return slide_id, embedding, coords, report_ids, report_masks, seq_length

In [6]:
tokenizer = Tokenizer(REPORTS_JSON_PATH)
embeddings_path = f'{EMB_PATH}/features_conch_v15'
max_seq_length = 300
embedding_dataset = EmbeddingDataset(embeddings_path, REPORTS_JSON_PATH, tokenizer, max_seq_length)

In [7]:
embedding_dataset[0]

('PIT_01_02983_01',
 tensor([[ 0.1380,  0.5441, -1.0661,  ...,  1.4840, -1.2144, -0.5514],
         [ 1.1093,  0.8226, -2.6591,  ...,  1.5781, -0.9008, -0.4615],
         [ 1.2667,  0.8689, -2.5663,  ...,  1.6840, -0.2095,  0.2890],
         ...,
         [-1.5927,  0.4193, -2.3413,  ...,  2.2291, -1.5176, -0.0998],
         [-0.4034,  0.9471, -1.8775,  ...,  2.4828, -1.8006,  0.1766],
         [-0.6539, -0.3768, -1.6462,  ...,  2.0273, -0.9711, -0.4058]]),
 tensor([[  512., 12288.],
         [ 1024., 11776.],
         [ 1024., 12288.],
         [ 1024., 12800.],
         [ 1024., 13312.],
         [ 1024., 13824.],
         [ 1024., 14336.],
         [ 1536., 10240.],
         [ 1536., 10752.],
         [ 1536., 11264.],
         [ 1536., 11776.],
         [ 1536., 12288.],
         [ 1536., 12800.],
         [ 1536., 13312.],
         [ 1536., 13824.],
         [ 1536., 14336.],
         [ 2048., 10752.],
         [ 2048., 11264.],
         [ 2048., 11776.],
         [ 2048., 12288.]

In [8]:
#modules
class CaptionModel(nn.Module):
    def __init__(self):
        super().__init__()
        # super(CaptionModel, self).__init__()

    # implements beam search
    # calls beam_step and returns the final set of beams
    # augments log-probabilities with diversity terms when number of groups > 1

    def forward(self, *args, **kwargs):
        mode = kwargs.get('mode', 'forward')
        if 'mode' in kwargs:
            del kwargs['mode']
        return getattr(self, '_' + mode)(*args, **kwargs)

    def beam_search(self, init_state, init_logprobs, *args, **kwargs):

        # function computes the similarity score to be augmented
        def add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash):
            local_time = t - divm
            unaug_logprobs = logprobs.clone()
            batch_size = beam_seq_table[0].shape[0]

            if divm > 0:
                change = logprobs.new_zeros(batch_size, logprobs.shape[-1])
                for prev_choice in range(divm):
                    prev_decisions = beam_seq_table[prev_choice][:, :, local_time]  # Nxb
                    for prev_labels in range(bdash):
                        change.scatter_add_(1, prev_decisions[:, prev_labels].unsqueeze(-1),
                                            change.new_ones(batch_size, 1))

                if local_time == 0:
                    logprobs = logprobs - change * diversity_lambda
                else:
                    logprobs = logprobs - repeat_tensors(bdash, change) * diversity_lambda

            return logprobs, unaug_logprobs

        # does one step of classical beam search

        def beam_step(logprobs, unaug_logprobs, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
            # INPUTS:
            # logprobs: probabilities augmented after diversity N*bxV
            # beam_size: obvious
            # t        : time instant
            # beam_seq : tensor contanining the beams
            # beam_seq_logprobs: tensor contanining the beam logprobs
            # beam_logprobs_sum: tensor contanining joint logprobs
            # OUPUTS:
            # beam_seq : tensor containing the word indices of the decoded captions Nxbxl
            # beam_seq_logprobs : log-probability of each decision made, NxbxlxV
            # beam_logprobs_sum : joint log-probability of each beam Nxb

            batch_size = beam_logprobs_sum.shape[0]
            vocab_size = logprobs.shape[-1]
            logprobs = logprobs.reshape(batch_size, -1, vocab_size)  # NxbxV
            if t == 0:
                assert logprobs.shape[1] == 1
                beam_logprobs_sum = beam_logprobs_sum[:, :1]
            candidate_logprobs = beam_logprobs_sum.unsqueeze(-1) + logprobs  # beam_logprobs_sum Nxb logprobs is NxbxV
            ys, ix = torch.sort(candidate_logprobs.reshape(candidate_logprobs.shape[0], -1), -1, True)
            ys, ix = ys[:, :beam_size], ix[:, :beam_size]
            beam_ix = ix // vocab_size  # Nxb which beam
            selected_ix = ix % vocab_size  # Nxb # which world
            state_ix = (beam_ix + torch.arange(batch_size).type_as(beam_ix).unsqueeze(-1) * logprobs.shape[1]).reshape(
                -1)  # N*b which in Nxb beams

            if t > 0:
                # gather according to beam_ix
                assert (beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) ==
                        beam_seq.reshape(-1, beam_seq.shape[-1])[state_ix].view_as(beam_seq)).all()
                beam_seq = beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq))

                beam_seq_logprobs = beam_seq_logprobs.gather(1, beam_ix.unsqueeze(-1).unsqueeze(-1).expand_as(
                    beam_seq_logprobs))

            beam_seq = torch.cat([beam_seq, selected_ix.unsqueeze(-1)], -1)  # beam_seq Nxbxl
            beam_logprobs_sum = beam_logprobs_sum.gather(1, beam_ix) + \
                                logprobs.reshape(batch_size, -1).gather(1, ix)
            assert (beam_logprobs_sum == ys).all()
            _tmp_beam_logprobs = unaug_logprobs[state_ix].reshape(batch_size, -1, vocab_size)
            beam_logprobs = unaug_logprobs.reshape(batch_size, -1, vocab_size).gather(1,
                                                                                      beam_ix.unsqueeze(-1).expand(-1,
                                                                                                                   -1,
                                                                                                                   vocab_size))  # NxbxV
            assert (_tmp_beam_logprobs == beam_logprobs).all()
            beam_seq_logprobs = torch.cat([
                beam_seq_logprobs,
                beam_logprobs.reshape(batch_size, -1, 1, vocab_size)], 2)

            new_state = [None for _ in state]
            for _ix in range(len(new_state)):
                #  copy over state in previous beam q to new beam at vix
                new_state[_ix] = state[_ix][:, state_ix]
            state = new_state
            return beam_seq, beam_seq_logprobs, beam_logprobs_sum, state

        # Start diverse_beam_search
        opt = kwargs['opt']
        temperature = opt.get('temperature', 1)  # This should not affect beam search, but will affect dbs
        beam_size = opt.get('beam_size', 10)
        group_size = opt.get('group_size', 1)
        diversity_lambda = opt.get('diversity_lambda', 0.5)
        decoding_constraint = opt.get('decoding_constraint', 0)
        suppress_UNK = opt.get('suppress_UNK', 0)
        length_penalty = penalty_builder(opt.get('length_penalty', ''))
        bdash = beam_size // group_size  # beam per group

        batch_size = init_logprobs.shape[0]
        device = init_logprobs.device
        # INITIALIZATIONS
        beam_seq_table = [torch.LongTensor(batch_size, bdash, 0).to(device) for _ in range(group_size)]
        beam_seq_logprobs_table = [torch.FloatTensor(batch_size, bdash, 0, self.vocab_size + 1).to(device) for _ in
                                   range(group_size)]
        beam_logprobs_sum_table = [torch.zeros(batch_size, bdash).to(device) for _ in range(group_size)]

        # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
        done_beams_table = [[[] for __ in range(group_size)] for _ in range(batch_size)]
        state_table = [[_.clone() for _ in init_state] for _ in range(group_size)]
        logprobs_table = [init_logprobs.clone() for _ in range(group_size)]
        # END INIT

        # Chunk elements in the args
        args = list(args)
        args = split_tensors(group_size, args)  # For each arg, turn (Bbg)x... to (Bb)x(g)x...
        if self.__class__.__name__ == 'AttEnsemble':
            args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in
                    range(group_size)]  # group_name, arg_name, model_name
        else:
            args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]

        for t in range(self.max_seq_length + group_size - 1):
            for divm in range(group_size):
                if t >= divm and t <= self.max_seq_length + divm - 1:
                    # add diversity
                    logprobs = logprobs_table[divm]
                    # suppress previous word
                    if decoding_constraint and t - divm > 0:
                        logprobs.scatter_(1, beam_seq_table[divm][:, :, t - divm - 1].reshape(-1, 1).to(device),
                                          float('-inf'))
                    # suppress UNK tokens in the decoding
                    if suppress_UNK:
                        idx_unk = self.tokenizer.get_id_by_token('<unk>')
                        logprobs[:, idx_unk] = logprobs[:, idx_unk] - 1000
                        # diversity is added here
                    # the function directly modifies the logprobs values and hence, we need to return
                    # the unaugmented ones for sorting the candidates in the end. # for historical
                    # reasons :-)
                    logprobs, unaug_logprobs = add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash)

                    # infer new beams
                    beam_seq_table[divm], \
                    beam_seq_logprobs_table[divm], \
                    beam_logprobs_sum_table[divm], \
                    state_table[divm] = beam_step(logprobs,
                                                  unaug_logprobs,
                                                  bdash,
                                                  t - divm,
                                                  beam_seq_table[divm],
                                                  beam_seq_logprobs_table[divm],
                                                  beam_logprobs_sum_table[divm],
                                                  state_table[divm])

                    # log_message(f' t: {t}, divm: {divm}, beam_seq_table[divm]: {beam_seq_table[divm]}, beam_logprobs_sum_table: {beam_logprobs_sum_table}')
                    # if time's up... or if end token is reached then copy beams
                    for b in range(batch_size):
                        is_end = beam_seq_table[divm][b, :, t - divm] == self.eos_idx
                        assert beam_seq_table[divm].shape[-1] == t - divm + 1
                        if t == self.max_seq_length + divm - 1:
                            is_end.fill_(1)
                        for vix in range(bdash):
                            if is_end[vix]:
                                final_beam = {
                                    'seq': beam_seq_table[divm][b, vix].clone(),
                                    'logps': beam_seq_logprobs_table[divm][b, vix].clone(),
                                    'unaug_p': beam_seq_logprobs_table[divm][b, vix].sum().item(),
                                    'p': beam_logprobs_sum_table[divm][b, vix].item()
                                }
                                log_message(f"final_beam : {final_beam['seq']}, {final_beam['p']}, penalty: {length_penalty(t - divm + 1, final_beam['p'])}")
                                final_beam['p'] = length_penalty(t - divm + 1, final_beam['p'])
                                done_beams_table[b][divm].append(final_beam)
                        beam_logprobs_sum_table[divm][b, is_end] -= 1000
                        
                    # log_message(f"done_beams_table: {list(map(lambda x: (x['seq'], x['p']), done_beams_table[0][0]))}")
                    # move the current group one step forward in time

                    it = beam_seq_table[divm][:, :, t - divm].reshape(-1)
                    logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it.cuda(), *(
                            args[divm] + [state_table[divm]]))
                    logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)

        # all beams are sorted by their log-probabilities
        done_beams_table = [[sorted(done_beams_table[b][i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)]
                            for b in range(batch_size)]
        done_beams = [sum(_, []) for _ in done_beams_table]

        # print(f'done_beams: {done_beams}')
        log_message(f"done_beams: {list(map(lambda x: (x['seq'], x['p']), done_beams[0]))}")
        return done_beams


In [9]:
# attention model
class AttModel(CaptionModel):
    def __init__(self, args, tokenizer):
        super().__init__()
        self.args = args
        self.tokenizer = tokenizer
        self.vocab_size = tokenizer.get_vocab_size()
        self.input_encoding_size = args.d_model
        self.rnn_size = args.d_ff
        self.num_layers = args.num_layers
        self.drop_prob_lm = args.drop_prob_lm
        self.max_seq_length = args.max_seq_length
        self.att_feat_size = args.d_vf
        self.att_hid_size = args.d_model

        self.bos_idx = args.bos_idx
        self.eos_idx = args.eos_idx
        self.pad_idx = args.pad_idx

        self.use_bn = args.use_bn

        self.embed = lambda x: x
        self.fc_embed = lambda x: x
        self.att_embed = nn.Sequential(*(
                ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ()) +
                (nn.Linear(self.att_feat_size, self.input_encoding_size),
                 nn.ReLU(),
                 nn.Dropout(self.drop_prob_lm)) +
                ((nn.BatchNorm1d(self.input_encoding_size),) if self.use_bn == 2 else ())))

        self.out1 = nn.Sequential(
            nn.Linear(self.att_feat_size, 1024),
            nn.Tanh()
        )
        self.out2 = nn.Sequential(
            nn.Linear(self.rnn_size, self.rnn_size),
            nn.Tanh()
        )
        self.ln = nn.LayerNorm(self.att_feat_size)

    def clip_att(self, att_feats, att_masks):
        # Clip the length of att_masks and att_feats to the maximum length
        if att_masks is not None:
            max_len = att_masks.data.long().sum(1).max()
            att_feats = att_feats[:, :max_len].contiguous()
            att_masks = att_masks[:, :max_len].contiguous()
        return att_feats, att_masks

    def multimodal_feat(self, att_feats, meshes):# Concate multimodal features
        return torch.cat((self.ln(att_feats),self.ln(meshes)),dim=1)
        # return torch.cat((self.ln(self.out1(att_feats)),self.ln(self.out2(meshes))),dim=1)

    def _prepare_feature(self, fc_feats, att_feats, att_masks):
        att_feats, att_masks = self.clip_att(att_feats, att_masks)

        # embed fc and att feats
        fc_feats = self.fc_embed(fc_feats)
        att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)

        # Project the attention feats first to reduce memory and computation comsumptions.
        p_att_feats = self.ctx2att(att_feats)

        return fc_feats, att_feats, p_att_feats, att_masks

    def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state, output_logsoftmax=1):
        # 'it' contains a word index
        xt = self.embed(it)

        output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks)
        if output_logsoftmax:
            logprobs = F.log_softmax(self.logit(output), dim=1)
        else:
            logprobs = self.logit(output)

        return logprobs, state

    def _sample_beam(self, fc_feats, att_feats, att_masks=None, meshes=None, opt={}):
        # print(f'opt: {opt}')
        beam_size = opt.get('beam_size', 10)
        group_size = opt.get('group_size', 1)
        sample_n = opt.get('sample_n', 10)
        # when sample_n == beam_size then each beam is a sample.
        assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
        batch_size = fc_feats.size(0)

        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)

        assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
        seq = fc_feats.new_full((batch_size * sample_n, self.max_seq_length), self.pad_idx, dtype=torch.long)
        seqLogprobs = fc_feats.new_zeros(batch_size * sample_n, self.max_seq_length, self.vocab_size + 1)
        # lets process every image independently for now, for simplicity

        self.done_beams = [[] for _ in range(batch_size)]

        state = self.init_hidden(batch_size)

        # first step, feed bos
        it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
        logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)

        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = repeat_tensors(beam_size,
                                                                                  [p_fc_feats, p_att_feats,
                                                                                   pp_att_feats, p_att_masks]
                                                                                  )
        self.done_beams = self.beam_search(state, logprobs, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, opt=opt)
        # print(f'self.done_beams: {self.done_beams}')
        for k in range(batch_size):
            if sample_n == beam_size:
                for _n in range(sample_n):
                    seq_len = self.done_beams[k][_n]['seq'].shape[0]
                    seq[k * sample_n + _n, :seq_len] = self.done_beams[k][_n]['seq']
                    seqLogprobs[k * sample_n + _n, :seq_len] = self.done_beams[k][_n]['logps']
                    # print(f'---> seq: {seq}, seqLogprobs: {seqLogprobs}')
            else:
                seq_len = self.done_beams[k][0]['seq'].shape[0]
                seq[k, :seq_len] = self.done_beams[k][0]['seq']  # the first beam has highest cumulative score
                seqLogprobs[k, :seq_len] = self.done_beams[k][0]['logps']
        # return the samples and their log likelihoods
        # print(f'_sample_beam: seq {seq}, seqLogprobs {seqLogprobs}')
        return seq, seqLogprobs

    def _sample(self, fc_feats, att_feats, meshes=None, att_masks=None):
        opt = self.args.__dict__
        sample_method = opt.get('sample_method', 'greedy')
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)
        sample_n = int(opt.get('sample_n', 1))
        group_size = opt.get('group_size', 1)
        output_logsoftmax = opt.get('output_logsoftmax', 1)
        decoding_constraint = opt.get('decoding_constraint', 0)
        block_trigrams = opt.get('block_trigrams', 0)
        if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
            return self._sample_beam(fc_feats, att_feats, att_masks, meshes, opt)
        if group_size > 1:
            return self._diverse_sample(fc_feats, att_feats, att_masks, meshes, opt)

        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size * sample_n)

        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks, meshes)

        if sample_n > 1:
            p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = repeat_tensors(sample_n,
                                                                                      [p_fc_feats, p_att_feats,
                                                                                       pp_att_feats, p_att_masks]
                                                                                      )

        trigrams = []  # will be a list of batch_size dictionaries

        seq = fc_feats.new_full((batch_size * sample_n, self.max_seq_length), self.pad_idx, dtype=torch.long)
        seqLogprobs = fc_feats.new_zeros(batch_size * sample_n, self.max_seq_length, self.vocab_size + 1)
        for t in range(self.max_seq_length + 1):
            if t == 0:  # input <bos>
                it = fc_feats.new_full([batch_size * sample_n], self.bos_idx, dtype=torch.long)

            logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state,
                                                      output_logsoftmax=output_logsoftmax)

            if decoding_constraint and t > 0:
                tmp = logprobs.new_zeros(logprobs.size())
                tmp.scatter_(1, seq[:, t - 1].data.unsqueeze(1), float('-inf'))
                logprobs = logprobs + tmp

            # Mess with trigrams
            # Copy from https://github.com/lukemelas/image-paragraph-captioning
            if block_trigrams and t >= 3:
                # Store trigram generated at last step
                prev_two_batch = seq[:, t - 3:t - 1]
                for i in range(batch_size):  # = seq.size(0)
                    prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                    current = seq[i][t - 1]
                    if t == 3:  # initialize
                        trigrams.append({prev_two: [current]})  # {LongTensor: list containing 1 int}
                    elif t > 3:
                        if prev_two in trigrams[i]:  # add to list
                            trigrams[i][prev_two].append(current)
                        else:  # create list
                            trigrams[i][prev_two] = [current]
                # Block used trigrams at next step
                prev_two_batch = seq[:, t - 2:t]
                mask = torch.zeros(logprobs.size(), requires_grad=False).cuda()  # batch_size x vocab_size
                for i in range(batch_size):
                    prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                    if prev_two in trigrams[i]:
                        for j in trigrams[i][prev_two]:
                            mask[i, j] += 1
                # Apply mask to log probs
                # logprobs = logprobs - (mask * 1e9)
                alpha = 2.0  # = 4
                logprobs = logprobs + (mask * -0.693 * alpha)  # ln(1/2) * alpha (alpha -> infty works best)

            # sample the next word
            if t == self.max_seq_length:  # skip if we achieve maximum length
                break
            it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature)

            # stop when all finished
            if t == 0:
                unfinished = it != self.eos_idx
            else:
                it[~unfinished] = self.pad_idx  # This allows eos_idx not being overwritten to 0
                logprobs = logprobs * unfinished.unsqueeze(1).float()
                unfinished = unfinished * (it != self.eos_idx)
            seq[:, t] = it
            seqLogprobs[:, t] = logprobs
            # quit loop if all sequences have finished
            if unfinished.sum() == 0:
                break

        return seq, seqLogprobs

    def _diverse_sample(self, fc_feats, att_feats, att_masks=None, opt={}):

        sample_method = opt.get('sample_method', 'greedy')
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)
        group_size = opt.get('group_size', 1)
        diversity_lambda = opt.get('diversity_lambda', 0.5)
        decoding_constraint = opt.get('decoding_constraint', 0)
        block_trigrams = opt.get('block_trigrams', 0)

        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)

        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)

        trigrams_table = [[] for _ in range(group_size)]  # will be a list of batch_size dictionaries

        seq_table = [fc_feats.new_full((batch_size, self.max_seq_length), self.pad_idx, dtype=torch.long) for _ in
                     range(group_size)]
        seqLogprobs_table = [fc_feats.new_zeros(batch_size, self.max_seq_length) for _ in range(group_size)]
        state_table = [self.init_hidden(batch_size) for _ in range(group_size)]

        for tt in range(self.max_seq_length + group_size):
            for divm in range(group_size):
                t = tt - divm
                seq = seq_table[divm]
                seqLogprobs = seqLogprobs_table[divm]
                trigrams = trigrams_table[divm]
                if t >= 0 and t <= self.max_seq_length - 1:
                    if t == 0:  # input <bos>
                        it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
                    else:
                        it = seq[:, t - 1]  # changed

                    logprobs, state_table[divm] = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats,
                                                                          p_att_masks, state_table[divm])  # changed
                    logprobs = F.log_softmax(logprobs / temperature, dim=-1)

                    # Add diversity
                    if divm > 0:
                        unaug_logprobs = logprobs.clone()
                        for prev_choice in range(divm):
                            prev_decisions = seq_table[prev_choice][:, t]
                            logprobs[:, prev_decisions] = logprobs[:, prev_decisions] - diversity_lambda

                    if decoding_constraint and t > 0:
                        tmp = logprobs.new_zeros(logprobs.size())
                        tmp.scatter_(1, seq[:, t - 1].data.unsqueeze(1), float('-inf'))
                        logprobs = logprobs + tmp

                    # Mess with trigrams
                    if block_trigrams and t >= 3:
                        # Store trigram generated at last step
                        prev_two_batch = seq[:, t - 3:t - 1]
                        for i in range(batch_size):  # = seq.size(0)
                            prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                            current = seq[i][t - 1]
                            if t == 3:  # initialize
                                trigrams.append({prev_two: [current]})  # {LongTensor: list containing 1 int}
                            elif t > 3:
                                if prev_two in trigrams[i]:  # add to list
                                    trigrams[i][prev_two].append(current)
                                else:  # create list
                                    trigrams[i][prev_two] = [current]
                        # Block used trigrams at next step
                        prev_two_batch = seq[:, t - 2:t]
                        mask = torch.zeros(logprobs.size(), requires_grad=False).cuda()  # batch_size x vocab_size
                        for i in range(batch_size):
                            prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                            if prev_two in trigrams[i]:
                                for j in trigrams[i][prev_two]:
                                    mask[i, j] += 1
                        # Apply mask to log probs
                        # logprobs = logprobs - (mask * 1e9)
                        alpha = 2.0  # = 4
                        logprobs = logprobs + (mask * -0.693 * alpha)  # ln(1/2) * alpha (alpha -> infty works best)

                    it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, 1)

                    # stop when all finished
                    if t == 0:
                        unfinished = it != self.eos_idx
                    else:
                        unfinished = seq[:, t - 1] != self.pad_idx & seq[:, t - 1] != self.eos_idx
                        it[~unfinished] = self.pad_idx
                        unfinished = unfinished & (it != self.eos_idx)  # changed
                    seq[:, t] = it
                    seqLogprobs[:, t] = sampleLogprobs.view(-1)

        return torch.stack(seq_table, 1).reshape(batch_size * group_size, -1), torch.stack(seqLogprobs_table,
                                                                                           1).reshape(
            batch_size * group_size, -1)

In [10]:
# common
def subsequent_mask(size):
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask) == 0

class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

class SublayerConnection(nn.Module):
    def __init__(self, d_model, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))

In [11]:
#encoder
class Encoder(nn.Module):
    def __init__(self, layer, N, PAM):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.d_model)
        self.PAM = clones(PAM, N)
        self.N = N

    def forward(self, x, mask):
        s=[]
        for i,layer in enumerate(self.layers):
            x = layer(x, mask)
            s.append(self.PAM[i](x))


        o = s[0]
        for i in range(1,len(s)):
            o +=  s[i]
        return o


class EncoderLayer(nn.Module):
    def __init__(self, d_model, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(d_model, dropout), 2)
        self.d_model = d_model

    def forward(self, x, mask):
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)

In [12]:
#decoder
class DecoderLayer(nn.Module):
    def __init__(self, d_model, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        self.d_model = d_model
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(d_model, dropout), 3)

    def forward(self, x, hidden_states, src_mask, tgt_mask):
        m = hidden_states
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
        return self.sublayer[2](x, self.feed_forward)


class Decoder(nn.Module):
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.d_model)

    def forward(self, x, hidden_states, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, hidden_states, src_mask, tgt_mask)
        return self.norm(x)

In [13]:
#transformer

class Transformer(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed


    def forward(self, src, tgt, src_mask, tgt_mask):
        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, hidden_states, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), hidden_states, src_mask, tgt_mask)


class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]

        x, self.attn = self.attention(query, key, value, mask=mask, dropout=self.dropout)

        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)

    def attention(self, query, key, value, mask=None, dropout=None):
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        p_attn = F.softmax(scores, dim=-1)
        if dropout is not None:
            p_attn = dropout(p_attn)
        return torch.matmul(p_attn, value), p_attn


class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))

class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super().__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

class PAM(nn.Module):
    def __init__(self, dim=512):
        super(PAM, self).__init__()
        self.proj = nn.Conv2d(dim, dim, 13, 1, 13//2, groups=dim)
        self.proj1 = nn.Conv2d(dim, dim, 7, 1, 7//2, groups=dim)
        self.proj2 = nn.Conv2d(dim, dim, 3, 1, 3//2, groups=dim)

    def forward(self, x):
        B, H, C = x.shape
        assert int(math.sqrt(H))**2==H, f'{x.shape}'
        cnn_feat = x.transpose(1, 2).view(B, C, int(math.sqrt(H)), int(math.sqrt(H))).contiguous()
        x = self.proj(cnn_feat)+cnn_feat+self.proj1(cnn_feat)+self.proj2(cnn_feat)
        x = x.flatten(2).transpose(1, 2)

        return x


class EncoderDecoder(AttModel):

    def __init__(self, args, tokenizer):
        super().__init__(args, tokenizer)
        self.args = args
        self.num_layers = args.num_layers
        self.d_model = args.d_model
        self.d_ff = args.d_ff
        self.num_heads = args.num_heads
        self.dropout = args.dropout

        tgt_vocab = self.vocab_size + 1

        self.embeded = Embeddings(args.d_vf, tgt_vocab)
        self.model = self.__build_model(tgt_vocab)
        self.__init_model()

        self.logit = nn.Linear(args.d_model, tgt_vocab)
        self.logit_mesh = nn.Linear(args.d_model, args.d_model)

    def __build_model(self, tgt_vocab):
        attn = MultiHeadedAttention(self.num_heads, self.d_model)
        ff = PositionwiseFeedForward(self.d_model, self.d_ff, self.dropout)
        position = PositionalEncoding(self.d_model, self.dropout)
        pp = PAM(self.d_model)
        model = Transformer(
            Encoder(EncoderLayer(self.d_model, deepcopy(attn), deepcopy(ff), self.dropout), self.num_layers, pp),
            Decoder(
                DecoderLayer(self.d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), self.dropout),
                self.num_layers),
            lambda x: x,
            nn.Sequential(Embeddings(self.d_model, tgt_vocab), deepcopy(position))
        )
        return model


    def init_hidden(self, bsz):
        return []

    def __init_model(self):
        for p in self.model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def _prepare_feature(self, fc_feats, att_feats, att_masks, meshes=None):
        att_feats = pad_tokens(att_feats)
        att_feats, seq, _, att_masks, seq_mask, _ = self._prepare_feature_forward(att_feats, att_masks, meshes)
        memory = self.model.encode(att_feats, att_masks)

        return fc_feats[..., :1], att_feats[..., :1], memory, att_masks

    def _prepare_feature_mesh(self, att_feats, att_masks=None, meshes=None):
        att_feats = pad_tokens(att_feats)
        att_feats, att_masks = self.clip_att(att_feats, att_masks)
        att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)

        if att_masks is None:
            att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long)
        att_masks = att_masks.unsqueeze(-2)

        if meshes is not None:
            # crop the last one
            meshes = meshes[:, :-1]
            meshes_mask = (meshes.data > 0)
            meshes_mask[:, 0] += True

            meshes_mask = meshes_mask.unsqueeze(-2)
            meshes_mask = meshes_mask & subsequent_mask(meshes.size(-1)).to(meshes_mask)
        else:
            meshes_mask = None

        return att_feats, meshes, att_masks, meshes_mask

    def _prepare_feature_forward(self, att_feats, att_masks=None, meshes=None, seq=None):

        att_feats, att_masks = self.clip_att(att_feats, att_masks)
        att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)

        if att_masks is None:
            att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long)
        att_masks = att_masks.unsqueeze(-2)

        if seq is not None:
            # crop the last one
            seq = seq[:, :-1]
            seq_mask = (seq.data > 0)
            seq_mask[:, 0] += True

            seq_mask = seq_mask.unsqueeze(-2)
            seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask)
        else:
            seq_mask = None

        if meshes is not None:
            # crop the last one
            meshes = meshes[:, :-1]
            meshes_mask = (meshes.data > 0)
            meshes_mask[:, 0] += True

            meshes_mask = meshes_mask.unsqueeze(-2)
            meshes_mask = meshes_mask & subsequent_mask(meshes.size(-1)).to(meshes_mask)
        else:
            meshes_mask = None

        return att_feats, seq, meshes, att_masks, seq_mask, meshes_mask

    def _forward(self, fc_feats, att_feats, report_ids, att_masks=None):
        # log_message(fc_feats, att_feats, report_ids, att_masks)
        att_feats, report_ids, att_masks, report_mask = self._prepare_feature_mesh(att_feats, att_masks, report_ids)
        out = self.model(att_feats, report_ids, att_masks, report_mask)

        # print(f'out: {out}')
        outputs = F.log_softmax(self.logit(out), dim=-1)
        # print(f'outputs: {outputs}')

        return outputs

    def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):

        if len(state) == 0:
            ys = it.long().unsqueeze(1)
        else:
            ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
        out = self.model.decode(memory, mask, ys, subsequent_mask(ys.size(1)).to(memory.device))
        return out[:, -1], [ys.unsqueeze(0)]

    def _encode(self, fc_feats, att_feats, att_masks=None):

        att_feats, _, att_masks, _ = self._prepare_feature_mesh(att_feats, att_masks)
        out = self.model.encode(att_feats, att_masks)
        return out

In [14]:

class ReportGenModel(nn.Module):

    def __init__(self, args, tokenizer):
        super().__init__()
        self.__tokenizer = tokenizer

        self.prompt = nn.Parameter(torch.randn(1, 1, args.d_vf))


        self.positional_encoder = nn.Sequential(
            nn.Linear(2, 64),
            nn.ReLU(),
            nn.Linear(64, args.d_vf)
        )

        self.encoder_decoder = EncoderDecoder(args, tokenizer)


    def forward(self, image_embeddings, pos_embeddings, report_ids, patch_masks, mode='train'):
        coords_encoded = self.positional_encoder(pos_embeddings)
        patch_feats = image_embeddings + coords_encoded

        att_feats = torch.cat([self.prompt, patch_feats], dim=1)
        fc_feats = torch.sum(att_feats, dim=1)
        if mode == 'train':
            output = self.encoder_decoder(fc_feats, att_feats, report_ids, mode='forward')
        elif mode == 'sample':
            output, _ = self.encoder_decoder(fc_feats, att_feats, mode='sample')
        elif mode == 'encode':
            output = self.encoder_decoder(fc_feats, att_feats, mode='encode')

            logits = self.fc(output[0,0,:]).unsqueeze(0)
            Y_hat = torch.argmax(logits, dim=1)
            Y_prob = F.softmax(logits, dim=1)
            return Y_hat, Y_prob
        else:
            raise ValueError


        return output

In [15]:
# loss

class LanguageModelCriterion(nn.Module):
    def __init__(self):
        super(LanguageModelCriterion, self).__init__()

    def forward(self, input, target, mask):
        # truncate to the same size
        target = target[:, :input.size(1)]
        mask = mask[:, :input.size(1)]
        output = -input.gather(2, target.long().unsqueeze(2)).squeeze(2) * mask
        output = torch.sum(output) / torch.sum(mask)

        return output

In [16]:
#models
class ReportModel(pl.LightningModule):

    def __init__(self, args, tokenizer, weight_decay=0.01):
        super().__init__()
        self.model = ReportGenModel(args, tokenizer)
        self.tokenizer = tokenizer
        self.__lr = args.lr
        self.__weight_decay = weight_decay
        self.rouge = ROUGEScore()
        self.bleu = BLEUScore(n_gram=1)

    def loss_fn(self, output, reports_ids, reports_masks):
        
        criterion = LanguageModelCriterion()
        loss = criterion(output, reports_ids[:, 1:], reports_masks[:, 1:]).mean()
        return loss

    def training_step(self, batch):
        # print('train ---------->')
        _, patch_feats, pos_feats, report_ids, report_masks, patch_masks = batch
        output = self.model(patch_feats, pos_feats, report_ids, patch_masks, mode='train')
        # print(f'train output: {output}')
        loss = self.loss_fn(output, report_ids, report_masks)
        self.log('train_loss', loss, on_epoch=True, prog_bar=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        # print('val ---------->')
        _, patch_feats, pos_feats, report_ids, report_masks, patch_masks = batch
        
        output_ = self.model(patch_feats, pos_feats, report_ids, patch_masks, mode='train')
        # print(f'val output: {output_}')
        loss = self.loss_fn(output_, report_ids, report_masks)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True, sync_dist=True)

        
        if batch_idx%100==0:
            output = self.model(patch_feats, pos_feats, report_ids,patch_masks, mode='sample')
            pred_texts = self.tokenizer.batch_decode(output.cpu().numpy())
            target_texts = self.tokenizer.batch_decode(report_ids[:, 1:].cpu().numpy())
            # print(f'val output: {output_}')
            # print(f'report_ids: {report_ids}, output: {output}')
            print(f'pred_texts: {pred_texts}, target_texts: {target_texts}')

            rouge_score = self.rouge(pred_texts, target_texts)
            bleu_score = self.bleu(pred_texts, target_texts)
    
            self.log('val_rouge', rouge_score['rouge1_fmeasure'], on_epoch=True, prog_bar=True, sync_dist=True)
            self.log('val_bleu', bleu_score, on_epoch=True, prog_bar=True, sync_dist=True)

    def configure_optimizers(self):
        d_params = filter(lambda p: p.requires_grad, self.model.parameters())
        optimizer = torch.optim.AdamW(d_params, lr=self.__lr, weight_decay=self.__weight_decay)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2)
        return [optimizer], [scheduler]

In [17]:
#datamodule

class PatchEmbeddingDataModule(pl.LightningDataModule):

    def __init__(self,args, tokenizer, split_frac, shuffle = False):
        super().__init__()
        self.test_ds = None
        self.val_ds = None
        self.train_ds = None
        self.__batch_size = args.batch_size
        self.__shuffle = shuffle
        self.__num_workers = args.num_workers
        self.__embeddings_path = args.embeddings_path
        self.__reports_json_path = args.reports_json_path
        self.__max_seq_length = args.max_seq_length
        self.__split_frac = split_frac
        self.__tokenizer = tokenizer

    def setup(self, stage=None):
        dataset = EmbeddingDataset(self.__embeddings_path, self.__reports_json_path, self.__tokenizer,
                              self.__max_seq_length)
        self.train_ds, self.val_ds, self.test_ds = random_split(dataset, self.__split_frac)

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.__batch_size, shuffle=self.__shuffle, collate_fn = self.collate_fn)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.__batch_size, collate_fn = self.collate_fn)

    def test_dataloader(self):
        return DataLoader(self.test_ds, batch_size=self.__batch_size, collate_fn = self.collate_fn)

    @staticmethod
    def collate_fn(batch):
        slide_ids, patch_feats, coord_feats, report_ids, report_masks, seq_length = zip(*batch)
        patch_feats_pad = pad_sequence(patch_feats, batch_first=True)
        coord_feats_pad =  pad_sequence(coord_feats, batch_first=True)
        patch_mask = torch.zeros(patch_feats_pad.shape[:2], dtype=torch.float32)
        for i, p in enumerate(patch_feats):
            patch_mask[i, :p.shape[0]] = 1

        return (slide_ids, patch_feats_pad, coord_feats_pad, torch.LongTensor(report_ids),
                torch.FloatTensor(report_masks), torch.FloatTensor(patch_mask))




In [18]:
#trainer

class Trainer:

    def __init__(self, args, model, tokenizer, split_frac=(0.75, 0.12, 0.13)):
        self.ckpt_path = args.ckpt_path
        self.max_epochs = args.max_epochs
        self.split_frac = split_frac
        self.datamodule = PatchEmbeddingDataModule(args, tokenizer, split_frac)
        self.model = model
        pl.seed_everything(42)

    def train(self, fast_dev_run=False):
        checkpoint_callback = ModelCheckpoint(
            dirpath=self.ckpt_path,  # Directory to save checkpoints
            filename="best_model",  # Naming convention
            monitor="val_loss",  # Metric to monitor for saving best checkpoints
            mode="min",  # Whether to minimize or maximize the monitored metric
            save_top_k=1,  # Number of best checkpoints to keep
            save_last=True  # Save the last checkpoint regardless of the monitored metric
        )
        early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=3, verbose=True, mode="min")
        trainer = pl.Trainer(
            max_epochs=self.max_epochs,
            callbacks=[checkpoint_callback, early_stop_callback],
            accelerator='gpu',
            devices=[1],
            strategy='auto',
            enable_progress_bar=True,
            log_every_n_steps=2,
            fast_dev_run=fast_dev_run
        )
        train_metrics = trainer.fit(
            self.model, datamodule=self.datamodule
        )

        return train_metrics

In [23]:
config = {
    'max_fea_length': 10000,
    'max_seq_length': 50,
    'threshold': 1,
    'batch_size': 1,
    'ckpt_path': 'checkpoints/1',
    'max_epochs': 100,
    'd_ff': 512,
    'd_vf': 768,
    'd_model': 512,
    'num_heads': 4,
    'num_layers': 4,
    'dropout': 0.1,
    'drop_prob_lm': 0.5,
    'bos_idx': 0,
    'eos_idx': 1,
    'pad_idx': 2,
    'use_bn': 0,
    'beam_size': 3,
    'num_workers': 2,
    'embeddings_path': embeddings_path,
    'reports_json_path': REPORTS_JSON_PATH,
    'sample_n': 1,
    'group_size': 1,
    'lr': 1e-6,
    'sample_method': 'beam_search',
    'temperature':1.0,
    'output_logsoftmax': 1,
    'decoding_constraint': 1,
    'suppress_UNK': 1,
    'block_trigrams': 1,
    'length_penalty': 'avg_5'
    
}
debug = False
args = argparse.Namespace(**config)
split_frac = [0.85, 0.10, 0.05]
model = ReportModel(args, tokenizer)

trainer = Trainer(args, model, tokenizer, split_frac)
# metrics = trainer.train(fast_dev_run=False)
# print(metrics)


Seed set to 42


In [24]:
ckpt = torch.load(model_ckpt_path)

In [25]:
model_dict = model.state_dict()
state_dict = {k:v for k,v in ckpt.items() if k in model_dict}
model_dict.update(state_dict) 
model.load_state_dict(model_dict)

<All keys matched successfully>

In [26]:
trainer = Trainer(args, model, tokenizer, split_frac)
metrics = trainer.train(fast_dev_run=False)
print(metrics)

Seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]

  | Name  | Type           | Params | Mode 
-------------------------------------------------
0 | model | ReportGenModel | 19.8 M | train
1 | rouge | ROUGEScore     | 0      | train
2 | bleu  | BLEUScore      | 0      | train
-------------------------------------------------
19.8 M    Trainable params
0         Non-trainable params
19.8 M    Total params
79.104    Total estimated model params size (MB)
244       Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …

pred_texts: ['Electrosurgical pneumocyte granulomatous 7% pneumocyte 7% pneumocyte 7% pneumocyte 7% pneumocyte 7% pneumocyte granulomatous 7% Squamous pneumocyte Squamous pneumocyte Squamous 100% Squamous 7% pneumocyte 100% 7% 100% 7% 100% 3 7% Squamous 7% 100% 7% 100% 3 Squamous pneumocyte 100% 7% grade: 7% Squamous 7% Squamous 7% pneumocyte 100% 7%'], target_texts: ['Stomach, endoscopic biopsy; Adenocarcinoma, moderately differentiated']


Training: |                                                                                                   …

Validation: |                                                                                                 …

pred_texts: [''], target_texts: ['Stomach, endoscopic biopsy; Adenocarcinoma, moderately differentiated']
pred_texts: [''], target_texts: ['Colon, colonoscopic biopsy; Chronic nonspecific inflammation']
pred_texts: [''], target_texts: ['Stomach, endoscopic biopsy; Adenocarcinoma, well differentiated']
pred_texts: [''], target_texts: ['Prostate, biopsy; No tumor present']
pred_texts: [''], target_texts: ['Breast, biopsy; Invasive carcinoma of no special type, grade III (Tubule formation: 3, Nuclear grade: 3, Mitoses: 2)']
pred_texts: [''], target_texts: ["Prostate, biopsy; Acinar adenocarcinoma, Gleason's score 6 (3+3), grade group 1, tumor volume: 10%"]
pred_texts: [''], target_texts: ['Breast, core-needle biopsy; Papillary neoplasm']


Metric val_loss improved. New best score: 1.795


Validation: |                                                                                                 …

pred_texts: [''], target_texts: ['Stomach, endoscopic biopsy; Adenocarcinoma, moderately differentiated']
pred_texts: [''], target_texts: ['Colon, colonoscopic biopsy; Chronic nonspecific inflammation']
pred_texts: [''], target_texts: ['Stomach, endoscopic biopsy; Adenocarcinoma, well differentiated']
pred_texts: [''], target_texts: ['Prostate, biopsy; No tumor present']
pred_texts: [''], target_texts: ['Breast, biopsy; Invasive carcinoma of no special type, grade III (Tubule formation: 3, Nuclear grade: 3, Mitoses: 2)']
pred_texts: [''], target_texts: ["Prostate, biopsy; Acinar adenocarcinoma, Gleason's score 6 (3+3), grade group 1, tumor volume: 10%"]
pred_texts: [''], target_texts: ['Breast, core-needle biopsy; Papillary neoplasm']


Metric val_loss improved by 0.133 >= min_delta = 0.0001. New best score: 1.662


Validation: |                                                                                                 …

pred_texts: [''], target_texts: ['Stomach, endoscopic biopsy; Adenocarcinoma, moderately differentiated']
pred_texts: [''], target_texts: ['Colon, colonoscopic biopsy; Chronic nonspecific inflammation']
pred_texts: [''], target_texts: ['Stomach, endoscopic biopsy; Adenocarcinoma, well differentiated']
pred_texts: [''], target_texts: ['Prostate, biopsy; No tumor present']
pred_texts: [''], target_texts: ['Breast, biopsy; Invasive carcinoma of no special type, grade III (Tubule formation: 3, Nuclear grade: 3, Mitoses: 2)']
pred_texts: [''], target_texts: ["Prostate, biopsy; Acinar adenocarcinoma, Gleason's score 6 (3+3), grade group 1, tumor volume: 10%"]
pred_texts: [''], target_texts: ['Breast, core-needle biopsy; Papillary neoplasm']


Metric val_loss improved by 0.007 >= min_delta = 0.0001. New best score: 1.654


Validation: |                                                                                                 …

pred_texts: [''], target_texts: ['Stomach, endoscopic biopsy; Adenocarcinoma, moderately differentiated']
pred_texts: [''], target_texts: ['Colon, colonoscopic biopsy; Chronic nonspecific inflammation']
pred_texts: [''], target_texts: ['Stomach, endoscopic biopsy; Adenocarcinoma, well differentiated']
pred_texts: [''], target_texts: ['Prostate, biopsy; No tumor present']
pred_texts: [''], target_texts: ['Breast, biopsy; Invasive carcinoma of no special type, grade III (Tubule formation: 3, Nuclear grade: 3, Mitoses: 2)']
pred_texts: [''], target_texts: ["Prostate, biopsy; Acinar adenocarcinoma, Gleason's score 6 (3+3), grade group 1, tumor volume: 10%"]
pred_texts: [''], target_texts: ['Breast, core-needle biopsy; Papillary neoplasm']


Metric val_loss improved by 0.010 >= min_delta = 0.0001. New best score: 1.644


Validation: |                                                                                                 …

pred_texts: [''], target_texts: ['Stomach, endoscopic biopsy; Adenocarcinoma, moderately differentiated']
pred_texts: [''], target_texts: ['Colon, colonoscopic biopsy; Chronic nonspecific inflammation']
pred_texts: [''], target_texts: ['Stomach, endoscopic biopsy; Adenocarcinoma, well differentiated']
pred_texts: [''], target_texts: ['Prostate, biopsy; No tumor present']
pred_texts: [''], target_texts: ['Breast, biopsy; Invasive carcinoma of no special type, grade III (Tubule formation: 3, Nuclear grade: 3, Mitoses: 2)']
pred_texts: [''], target_texts: ["Prostate, biopsy; Acinar adenocarcinoma, Gleason's score 6 (3+3), grade group 1, tumor volume: 10%"]
pred_texts: ['biopsy;'], target_texts: ['Breast, core-needle biopsy; Papillary neoplasm']


Metric val_loss improved by 0.001 >= min_delta = 0.0001. New best score: 1.643


Validation: |                                                                                                 …

pred_texts: [''], target_texts: ['Stomach, endoscopic biopsy; Adenocarcinoma, moderately differentiated']
pred_texts: [''], target_texts: ['Colon, colonoscopic biopsy; Chronic nonspecific inflammation']
pred_texts: [''], target_texts: ['Stomach, endoscopic biopsy; Adenocarcinoma, well differentiated']
pred_texts: [''], target_texts: ['Prostate, biopsy; No tumor present']
pred_texts: [''], target_texts: ['Breast, biopsy; Invasive carcinoma of no special type, grade III (Tubule formation: 3, Nuclear grade: 3, Mitoses: 2)']
pred_texts: [''], target_texts: ["Prostate, biopsy; Acinar adenocarcinoma, Gleason's score 6 (3+3), grade group 1, tumor volume: 10%"]
pred_texts: ['biopsy;'], target_texts: ['Breast, core-needle biopsy; Papillary neoplasm']


Metric val_loss improved by 0.001 >= min_delta = 0.0001. New best score: 1.643


Validation: |                                                                                                 …

pred_texts: [''], target_texts: ['Stomach, endoscopic biopsy; Adenocarcinoma, moderately differentiated']
pred_texts: [''], target_texts: ['Colon, colonoscopic biopsy; Chronic nonspecific inflammation']
pred_texts: [''], target_texts: ['Stomach, endoscopic biopsy; Adenocarcinoma, well differentiated']
pred_texts: [''], target_texts: ['Prostate, biopsy; No tumor present']
pred_texts: [''], target_texts: ['Breast, biopsy; Invasive carcinoma of no special type, grade III (Tubule formation: 3, Nuclear grade: 3, Mitoses: 2)']
pred_texts: [''], target_texts: ["Prostate, biopsy; Acinar adenocarcinoma, Gleason's score 6 (3+3), grade group 1, tumor volume: 10%"]
pred_texts: ['biopsy;'], target_texts: ['Breast, core-needle biopsy; Papillary neoplasm']


Validation: |                                                                                                 …

pred_texts: [''], target_texts: ['Stomach, endoscopic biopsy; Adenocarcinoma, moderately differentiated']
pred_texts: [''], target_texts: ['Colon, colonoscopic biopsy; Chronic nonspecific inflammation']
pred_texts: [''], target_texts: ['Stomach, endoscopic biopsy; Adenocarcinoma, well differentiated']
pred_texts: [''], target_texts: ['Prostate, biopsy; No tumor present']
pred_texts: [''], target_texts: ['Breast, biopsy; Invasive carcinoma of no special type, grade III (Tubule formation: 3, Nuclear grade: 3, Mitoses: 2)']
pred_texts: [''], target_texts: ["Prostate, biopsy; Acinar adenocarcinoma, Gleason's score 6 (3+3), grade group 1, tumor volume: 10%"]
pred_texts: ['biopsy;'], target_texts: ['Breast, core-needle biopsy; Papillary neoplasm']


Validation: |                                                                                                 …

pred_texts: [''], target_texts: ['Stomach, endoscopic biopsy; Adenocarcinoma, moderately differentiated']
pred_texts: [''], target_texts: ['Colon, colonoscopic biopsy; Chronic nonspecific inflammation']
pred_texts: [''], target_texts: ['Stomach, endoscopic biopsy; Adenocarcinoma, well differentiated']
pred_texts: [''], target_texts: ['Prostate, biopsy; No tumor present']
pred_texts: [''], target_texts: ['Breast, biopsy; Invasive carcinoma of no special type, grade III (Tubule formation: 3, Nuclear grade: 3, Mitoses: 2)']
pred_texts: [''], target_texts: ["Prostate, biopsy; Acinar adenocarcinoma, Gleason's score 6 (3+3), grade group 1, tumor volume: 10%"]
pred_texts: ['biopsy;'], target_texts: ['Breast, core-needle biopsy; Papillary neoplasm']


Monitored metric val_loss did not improve in the last 3 records. Best score: 1.643. Signaling Trainer to stop.


None
