### ライブラリの準備

### モジュールのインポートとGoogleドライブのマウント

In [None]:
import os
import glob
import math
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import datetime
#from tqdm import tqdm
from tqdm.notebook import tqdm
import pickle
import random
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from PIL import Image
import skimage.transform
from collections import deque
from typing import Sequence, Dict, Tuple, Union

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import models
import torchvision.transforms as T
import torchvision.datasets as dataset
from torchvision.transforms import v2

from timm.scheduler import CosineLRScheduler
from transformers import  get_linear_schedule_with_warmup

#from transformers import AutoImageProcessor, AutoModel, AutoProcessor, CLIPVisionModel
from transformers import CLIPVisionModel, BertTokenizer, BertModel, BertForMaskedLM

import sys

import util
import levenshtein
#from model_ import TopLayer
from nltk import bleu_score
import ssl
from torch.amp import autocast, GradScaler

In [None]:
class PositionalEmbedding(nn.Module):
    """Learned absolute positional embedding.

    Args:
        dim_embedding: size of each position vector.
        max_len: number of positions in the lookup table (longest
            sequence that can be embedded).
    """
    def __init__(self, dim_embedding: int, max_len: int=2048):
        super().__init__()

        self.pos_emb = nn.Embedding(max_len, dim_embedding)

    def forward(self, x: torch.Tensor):
        """Return the position vectors for the first ``x.shape[1]`` positions.

        Only the second dimension of ``x`` (the sequence length) and its
        device are used. The result has shape [sequence length,
        dim_embedding] and is meant to be broadcast-added by the caller.
        """
        seq_len = x.size(1)
        # One index per position, created directly on the input's device.
        index = torch.arange(seq_len, dtype=torch.long, device=x.device)
        return self.pos_emb(index)

### CaptioningTransformer

In [None]:
def logsumexp(x, dim=1):
    """Log-sum-exp over ``dim``, computed in float32 and cast back to ``x``'s dtype."""
    reduced = torch.logsumexp(x.to(torch.float32), dim=dim)
    return reduced.to(dtype=x.dtype)

class DynamicCRF(nn.Module):
    """Linear-chain CRF layer with low-rank, beam-restricted transitions.

    The transition score between two token ids is the dot product of
    ``E1(previous id)`` and ``E2(current id)``, two ``low_rank``-dimensional
    embeddings. The normalizer and the Viterbi search only look at the
    ``beam`` highest-scoring candidates of every position instead of the
    whole vocabulary.

    Args:
        num_embedding: number of token ids (rows of both embedding tables).
        low_rank: dimension of the two transition embeddings.
        beam_size: default number of candidates kept per position.
    """
    def __init__(self, num_embedding, low_rank=32, beam_size=64):
        super().__init__()

        self.E1 = nn.Embedding(num_embedding, low_rank)
        self.E2 = nn.Embedding(num_embedding, low_rank)

        self.vocb = num_embedding
        self.rank = low_rank
        self.beam = beam_size

    def extra_repr(self):
        """Extra text shown in the module's ``repr``."""
        return "vocab_size={}, low_rank={}, beam_size={}".format(
            self.vocb, self.rank, self.beam)

    def forward(self, emissions, targets, masks, beam=None):
        """Return the target-path score minus the beam-restricted log-normalizer.

        emissions: indexed as [batch, position, token id].
        targets  : token ids, [batch, position].
        masks    : [batch, position]; used as the condition of ``torch.where``
                   in the normalizer, so it is expected to be boolean.
        beam     : optional override of the default beam size.

        The result has one value per batch element.
        """
        numerator = self._compute_score(emissions, targets, masks)
        denominator = self._compute_normalizer(emissions, targets, masks, beam)
        return numerator - denominator

    def forward_decoder(self, emissions, masks=None, beam=None):
        """Viterbi-decode ``emissions``; returns ``(scores, tokens)``."""
        return self._viterbi_decode(emissions, masks, beam)

    def _compute_score(self, emissions, targets, masks=None):
        """Sum of emission and transition scores along the ``targets`` path."""
        batch_size, seq_len = targets.size()
        emission_scores = emissions.gather(2, targets[:, :, None])[:, :, 0]  # B x T
        # Transition score between every pair of consecutive target ids.
        transition_scores = (self.E1(targets[:, :-1]) * self.E2(targets[:, 1:])).sum(2)

        # `scores` aliases the gathered tensor (not `emissions`), so the
        # in-place add below does not modify the caller's input.
        scores = emission_scores
        scores[:, 1:] += transition_scores

        if masks is not None:
            scores = scores * masks.type_as(scores)
        
        return scores.sum(-1)

    def _compute_normalizer(self, emissions, targets=None, masks=None, beam=None):
        """Log-normalizer over the per-position beams (forward algorithm)."""
        beam = beam if beam is not None else self.beam
        batch_size, seq_len = emissions.size()[:2]
        if targets is not None:
            # Give the target id an infinite score so that top-k always keeps
            # it in the beam, then read the real scores back from `emissions`.
            _emissions = emissions.scatter(2, targets[:, :, None], float('inf'))
            beam_targets = _emissions.topk(beam, 2)[1]
            beam_emission_scores = emissions.gather(2, beam_targets)
        else:
            beam_emission_scores, beam_targets = emissions.topk(beam, 2)
        beam_transition_score1 = self.E1(beam_targets[:, :-1])  # B x (T-1) x K x D; position i - 1, previous step.
        beam_transition_score2 = self.E2(beam_targets[:, 1:])   # B x (T-1) x K x D; position i, current step.
        # K x K transition scores between the beams of consecutive positions.
        beam_transition_matrix = torch.bmm(
            beam_transition_score1.view(-1, beam, self.rank),
            beam_transition_score2.view(-1, beam, self.rank).transpose(1, 2))
        beam_transition_matrix = beam_transition_matrix.view(batch_size, -1, beam, beam)

        # compute the normalizer in the log-space
        # Position 0 always contributes; masks are only consulted from i = 1.
        score = beam_emission_scores[:, 0]  # B x K
        for i in range(1, seq_len):
            next_score = score[:, :, None] + beam_transition_matrix[:, i-1]
            next_score = logsumexp(next_score, dim=1) + beam_emission_scores[:, i]

            if masks is not None:
                # Masked-out positions keep the score of the previous step.
                score = torch.where(masks[:, i:i+1], next_score, score)
            else:
                score = next_score

        # Sum (log-sum-exp) over all possible tags
        return logsumexp(score, dim=1)

    def _viterbi_decode(self, emissions, masks=None, beam=None):
        """Best path through the per-position beams.

        Returns ``(finalized_scores, finalized_tokens)``, both
        [batch, position]. The tokens are vocabulary ids; the scores are
        the differences between consecutive back-traced path scores
        (see the last assignment below).
        """
        beam = beam if beam is not None else self.beam
        batch_size, seq_len = emissions.size()[:2]
        beam_emission_scores, beam_targets = emissions.topk(beam, 2)
        beam_transition_score1 = self.E1(beam_targets[:, :-1])  # B x (T-1) x K x D
        beam_transition_score2 = self.E2(beam_targets[:, 1:])   # B x (T-1) x K x D
        beam_transition_matrix = torch.bmm(
            beam_transition_score1.view(-1, beam, self.rank),
            beam_transition_score2.view(-1, beam, self.rank).transpose(1, 2))
        beam_transition_matrix = beam_transition_matrix.view(batch_size, -1, beam, beam)

        traj_tokens, traj_scores = [], []
        finalized_tokens, finalized_scores = [], []

        # Best path score ending in each beam slot, in log-space.
        score = beam_emission_scores[:, 0]  # B x K
        # Identity back-pointer (slot k points to slot k), used for masked positions.
        dummy = torch.arange(beam, device=score.device).expand(*score.size()).contiguous()

        for i in range(1, seq_len):
            traj_scores.append(score)
            _score = score[:, :, None] + beam_transition_matrix[:, i-1]
            # Max over the previous slot; `_index` is the back-pointer.
            _score, _index = _score.max(dim=1)
            _score = _score + beam_emission_scores[:, i]

            if masks is not None:
                score = torch.where(masks[:, i: i+1], _score, score)
                index = torch.where(masks[:, i: i+1], _index, dummy)
            else:
                score, index = _score, _index
            traj_tokens.append(index)

        # now running the back-tracing and find the best
        best_score, best_index = score.max(dim=1)
        finalized_tokens.append(best_index[:, None])
        finalized_scores.append(best_score[:, None])

        # Walk the back-pointers from the last position to the first.
        for idx, scs in zip(reversed(traj_tokens), reversed(traj_scores)):
            previous_index = finalized_tokens[-1]
            finalized_tokens.append(idx.gather(1, previous_index))
            finalized_scores.append(scs.gather(1, previous_index))

        finalized_tokens.reverse()
        finalized_tokens = torch.cat(finalized_tokens, 1)
        # Map beam slots back to vocabulary ids.
        finalized_tokens = beam_targets.gather(2, finalized_tokens[:, :, None])[:, :, 0]

        finalized_scores.reverse()
        finalized_scores = torch.cat(finalized_scores, 1)
        # Replace the accumulated scores (from position 1 on) by step-to-step differences.
        finalized_scores[:, 1:] = finalized_scores[:, 1:] - finalized_scores[:, :-1]

        return finalized_scores, finalized_tokens

In [None]:
class TopLayer(nn.Module):
    """CRF output layer placed on top of per-token emission scores.

    This layer owns no projection of its own: every method hands
    ``src_representation`` to the CRF unchanged, so the caller must pass
    emission scores over the vocabulary, [bsz, seqlen, vocab_size].

    Args:
        vocab_size: number of token ids handled by the CRF.
        embed_dim: stored for reference only; not used by this layer.
        crf_low_rank: rank of the CRF transition embeddings.
        crf_beam_size: default CRF beam size.
        dropout: dropout probability applied to the emissions in ``forward``.
        padding_idx: id of the padding token (stored; padding is currently
            not masked out, see ``forward``).
    """
    def __init__(self, vocab_size, embed_dim, crf_low_rank, crf_beam_size, dropout, padding_idx):
        super().__init__()

        self.embed_dim = embed_dim
        self.vocab_size = vocab_size
        self.dropout = dropout
        self.padding_idx = padding_idx
        print( "in TopLyaer:" )
        self.crf_layer = DynamicCRF(num_embedding = vocab_size, low_rank = crf_low_rank,
                                    beam_size = crf_beam_size)

    def forward(self, src_representation, src_input, tgt_input, is_training):
        """Return the CRF negative log-likelihood of ``tgt_input``, shape [bsz].

        src_representation: emission scores, [bsz, seqlen, vocab_size].
        src_input         : only its leading (batch) dimension is read.
        tgt_input         : target token ids, [bsz, seqlen].
        is_training       : enables dropout on the emissions.
        """
        bsz = src_input.shape[0]

        emissions = F.dropout(src_representation, p=self.dropout, training=is_training)

        # Every position is treated as valid: padding tokens are scored like
        # ordinary targets instead of being masked out.
        emission_mask = torch.ones_like(tgt_input, dtype=torch.bool)
        batch_crf_loss = -1 * self.crf_layer(emissions, tgt_input, emission_mask)  # [bsz]
        assert batch_crf_loss.size() == torch.Size([bsz])
        return batch_crf_loss

    def decoding(self, src_representation, src_input):
        """Viterbi-decode the emissions and return token ids, [bsz, seqlen].

        src_representation: emission scores, [bsz, seqlen, vocab_size].
        src_input         : [bsz, seqlen]; only used for the shape check.
        """
        bsz, seqlen = src_input.shape[:2]

        # The emissions go to the CRF as they are, exactly like in forward().
        # (The previous code called a projection layer that this module no
        # longer defines, which raised AttributeError.)
        _, finalized_tokens = self.crf_layer.forward_decoder(src_representation)
        assert finalized_tokens.size() == torch.Size([bsz, seqlen])
        return finalized_tokens

    def length_ratio_decoding(self, src_representation, src_input, length_ratio):
        """Viterbi-decode only a leading fraction of the sequence.

        src_representation: emission scores, [1, seqlen, vocab_size].
        src_input         : [1, seqlen]; only its length is read.
        length_ratio      : fraction of ``seqlen`` to decode.

        The first ``int(seqlen * length_ratio) + 2`` positions are kept
        (fewer if the sequence is shorter).
        """
        seqlen = src_input.shape[1]

        valid_len = int(seqlen * length_ratio) + 1
        valid_emissions = src_representation[:, :valid_len + 1, :]
        _, finalized_tokens = self.crf_layer.forward_decoder(valid_emissions)
        return finalized_tokens


In [None]:
class CaptioningTransformer(nn.Module):
    """Non-autoregressive captioner: CLIP vision encoder -> BERT -> CRF.

    Image tokens from CLIP are fused across layers (``dense_connector``),
    down-sampled along the sequence axis by a strided 1x1 convolution,
    given positional embeddings and fed to BERT as input embeddings. A
    linear head turns the BERT outputs into per-position vocabulary scores,
    which the CRF in ``self.toplayer`` decodes.

    Args:
        img_size: side length of the (square) input images.
        dim_embedding: working width of the model; must equal the hidden
            size of the BERT model named by ``model_id``.
        length_max: intended caption length; sets the down-sampling stride.
        vocab_size: size of the output vocabulary.
        tokenizer: only ``tokenizer.pad_token_id`` is read here.
        dropout: dropout probability applied after the dense connector.
        model_id: name or path given to ``BertModel.from_pretrained``.
    """
    def __init__(self, img_size: int, dim_embedding: int, length_max: int,
                  vocab_size: int, tokenizer, dropout: float=0.1, model_id: str=''):
        super().__init__()

        # CLIP vision encoder; all hidden states are needed by dense_connector.
        clip_model_id = "openai/clip-vit-large-patch14-336"
        self.clip_model = CLIPVisionModel.from_pretrained(clip_model_id, output_hidden_states = True)
        # Probe the encoder once to learn the token count and feature width.
        # no_grad: this is a shape probe only, no autograd graph is needed.
        with torch.no_grad():
            probe = self.clip_model( torch.randn( ( 1, 3, img_size, img_size ) ) ).last_hidden_state
        img_length = probe.size(1)
        clip_dim = probe.size(2)

        # Dense connector: last layer + two averaged halves of the hidden states.
        self.dc_linear = nn.Linear( clip_dim * 3, dim_embedding )
        self.dropout = nn.Dropout( dropout )
        self.ln_memory = nn.LayerNorm( dim_embedding )

        # Down-sampling of the image token sequence towards length_max.
        stride = img_length // ( length_max - 1 )
        self.conv1 = nn.Conv1d( dim_embedding, dim_embedding, 1, stride )
        print( "img_length:", img_length )
        print( "text_length_max:", length_max )
        print( "stride:", stride )
        # Report the sequence shape BERT will receive. The dummy input has
        # dim_embedding channels (what conv1 really sees in forward), so this
        # also works when the CLIP width differs from dim_embedding.
        with torch.no_grad():
            dummy = torch.zeros( ( 1, dim_embedding, img_length ) )
            memory = self.conv1( dummy ).transpose(1,2)
        print( "bert in memory size:", memory.size() )

        self.pos_emb = PositionalEmbedding( dim_embedding )

        self.bert = BertModel.from_pretrained( model_id )

        # Output word distribution.
        self.ln_outputs = nn.LayerNorm( dim_embedding )
        self.linear = nn.Linear(dim_embedding, vocab_size)

        crf_low_rank = 32
        crf_beam_size = 256
        top_dropout = 0.0
        tgt_padding_idx = tokenizer.pad_token_id
        print( "initialize self.toplayer" )
        self.toplayer = TopLayer( vocab_size, dim_embedding, crf_low_rank, crf_beam_size, top_dropout, tgt_padding_idx )

        self.dim_embedding = dim_embedding

    def _encode(self, images: torch.Tensor):
        """Shared image -> per-position vocabulary scores pipeline."""
        # Kept for backward compatibility with code that reads model.device.
        self.device = images.device

        memory = self.clip_model( images )
        memory = self.dense_connector( memory )
        memory = self.dropout( memory )
        memory = self.ln_memory( memory )

        # Conv1d works on [batch, channels, length], hence the transposes.
        memory = self.conv1( memory.transpose(1,2) ).transpose(1,2)

        # pos_emb returns [length, dim]; it broadcasts over the batch.
        memory = memory + self.pos_emb( memory )

        outputs = self.bert( inputs_embeds = memory ).last_hidden_state
        outputs = self.ln_outputs( outputs )
        return self.linear( outputs )

    def forward(self, images: torch.Tensor ):
        """Return vocabulary logits, [batch, sequence length, vocab_size].

        images: batch of images accepted by the CLIP vision model.
        """
        return self._encode( images )

    def inference(self, images: torch.Tensor, length_ratio: float = 1.0 ):
        """Decode captions with the CRF Viterbi search.

        images      : batch of images accepted by the CLIP vision model.
        length_ratio: fraction of the output sequence to decode; the default
                      1.0 decodes every position.

        Returns the decoded token ids, [batch, decoded length].
        """
        emissions = self._encode( images )

        seqlen = emissions.size(1)
        valid_len = int(seqlen * length_ratio) + 1
        valid_emissions = emissions[:, :valid_len+1,:]
        _, finalized_tokens = self.toplayer.crf_layer.forward_decoder(valid_emissions)
        return finalized_tokens

    def dense_connector(self, memory ):
        """Fuse the CLIP hidden states into one ``dim_embedding``-wide feature.

        memory: model output exposing ``hidden_states`` (sequence of equally
                shaped tensors) and ``last_hidden_state``.

        The hidden states are split into a first and a second half, each
        half is averaged, and [last_hidden_state, first, second] are
        concatenated on the feature axis and projected by ``dc_linear``.
        """
        hidden_states = memory.hidden_states
        n_total = len( hidden_states )
        n_half = n_total // 2
        # One stack per half instead of growing a tensor with repeated cat;
        # the device follows the inputs, so no stored self.device is needed.
        first = torch.stack( tuple( hidden_states[:n_half] ), dim = 0 ).sum( dim = 0 ) / n_half
        second = torch.stack( tuple( hidden_states[n_half:] ), dim = 0 ).sum( dim = 0 ) / ( n_total - n_half )
        fused = torch.cat( [ memory.last_hidden_state, first, second ], dim = -1 )
        return self.dc_linear( fused )

    def my_decode(self, token_list, tokenizer, eos_id=None ):
        """Cut ``token_list`` at the first EOS id and decode it to text.

        token_list: list of token ids.
        tokenizer : object providing ``decode(ids, skip_special_tokens=...)``.
        eos_id    : id that ends the caption; when omitted, the module-level
                    ``eos_token_id`` is used (the original behavior).
        """
        if eos_id is None:
            eos_id = eos_token_id
        if eos_id in token_list:
            token_list = token_list[:token_list.index( eos_id )]

        return tokenizer.decode( token_list, skip_special_tokens = True )

In [None]:
class MyDataset(Dataset):
    """Image / caption-token dataset backed by a pickle file.

    The pickle file must contain an iterable of records with the keys
    ``'img_file'`` (image name without the ``.jpg`` extension) and
    ``'id_tokens'`` (list of token ids). Two EOS ids are appended to every
    caption.

    Args:
        file_path: path of the pickle file.
        img_directory: directory that contains the ``.jpg`` images.
        transforms: callable applied to each RGB PIL image.
        tokenizer: accepted for interface compatibility; not used here.
        length_max: when given, captions are truncated to this length. When
            ``None``, nothing is truncated and ``self.length_max`` is set to
            the longest caption found (EOS ids included).
        eos_id: id appended as EOS; defaults to the module-level
            ``eos_token_id``.

    Attributes:
        length_max: the truncation length, or the longest caption length
            when no truncation was requested. It is a plain attribute; the
            former ``length_max()`` method was shadowed by it and could
            never be called, so it has been removed.
    """
    def __init__(self, file_path: str, img_directory: str, transforms, tokenizer, length_max = None, eos_id = None ) -> None:
        super().__init__()
        self.img_directory = img_directory
        self.transforms = transforms
        self.img_file = []
        self.tokens = []
        if eos_id is None:
            eos_id = eos_token_id
        self.length_max = 0 if length_max is None else length_max
        length_sum = 0

        # NOTE(security): pickle.load can execute arbitrary code; only load
        # dataset files that come from a trusted source.
        with open( file_path, 'rb') as f:
            data = pickle.load(f)
        for i, line_data in enumerate( data ):
            if i % 100000 == 0:
                print( "i:", i )
            self.img_file.append( line_data['img_file'] )
            # Build a new list so the loaded record itself is left untouched.
            id_tokens = list( line_data['id_tokens'] ) + [ eos_id, eos_id ]
            length_sum += len( id_tokens )
            if length_max is not None:
                id_tokens = torch.tensor( id_tokens )[:self.length_max]
            else:
                self.length_max = max( self.length_max, len( id_tokens ) )
                id_tokens = torch.tensor( id_tokens )
            self.tokens.append( id_tokens )

        # Unigram / bigram re-weighting tables (w1, w2) used to be computed
        # here; that code was disabled and has been dropped. If it is
        # restored, build the dataset with length_max=None.

        if length_max is None:
            print( "length max:", self.length_max )
            if self.tokens:
                print( "avg length:", length_sum / len( self.tokens ) )

    def __getitem__(
        self,
        index: int
    ):
        """Return ``(transformed image, token tensor)`` for sample ``index``."""
        tokens = self.tokens[index]
        img_file = self.img_file[index] + ".jpg"
        img_path = os.path.join( self.img_directory, img_file )
        # The context manager closes the file handle; convert() returns an
        # independent, fully loaded RGB image that outlives the handle.
        with Image.open(img_path) as img:
            img = img.convert("RGB")
        img = self.transforms(img)

        return img, tokens

    def __len__(self) -> int:
        """Number of samples (required by DataLoader)."""
        return len(self.tokens)

In [None]:
def collate_func(batch: Sequence[Tuple[Union[torch.Tensor, str]]], pad_index, length_max ):
    """Collate ``(image, tokens)`` samples into fixed-length batch tensors.

    batch     : sequence of ``(image tensor, 1-D token tensor)`` pairs.
    pad_index : value used to right-pad the token tensors.
    length_max: length every token tensor is padded to.

    Returns ``(images, targets, lengths)``: the stacked images, the stacked
    padded token tensors and the unpadded length of every token tensor.
    """
    images, token_seqs = zip(*batch)

    seq_lengths = [ len( seq ) for seq in token_seqs ]
    # Right-pad every caption up to the fixed length.
    padded = [
        F.pad( seq, (0, length_max - n), mode='constant', value = pad_index )
        for seq, n in zip( token_seqs, seq_lengths )
    ]

    return (
        torch.stack( images, dim = 0 ),
        torch.stack( padded, dim = 0 ),
        torch.tensor( seq_lengths, requires_grad = False ),
    )

In [None]:
# Tokenizer shared by the dataset, the model and the evaluation code.
model_id = "google-bert/bert-large-uncased"
tokenizer = BertTokenizer.from_pretrained( model_id )
# The otherwise unused vocabulary entries "[unused0]" / "[unused1]" serve as
# start / end-of-sentence ids. Index 1 presumably skips the special token
# that encode() puts in front of the sequence -- TODO confirm.
sos_token_id = tokenizer.encode( [ "[unused0]" ] )[1]
eos_token_id = tokenizer.encode( [ "[unused1]" ] )[1]

# Image preprocessing pipeline.
transforms = v2.Compose([
    v2.Resize((336, 336)),
    #v2.AutoAugment(),
    #v2.ToTensor(),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    ## Mean / std of the COCO 2017 train set
    #v2.Normalize((0.456,0.427,0.401),(0.224,0.219,0.231) )
    ## Mean / std of ImageNet
    #v2.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    # Mean / std taken from CLIP's preprocessor_config.json
    v2.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])

# v7 dataset; captions are truncated to 97 tokens.
train_dataset = MyDataset( file_path="/mnt/ssd2/v7/data.pkl",
                           img_directory = "/mnt/ssd2/v7/img",
                           #img_directory = "smb://192.168.1.2/img/v7/",
                           transforms=transforms, tokenizer = tokenizer, length_max = 97 )

# Split the dataset into test / validation / train subsets
# (presumably lists of sample indices -- see util.generate_subset_test_val_train).
test_set, val_set, train_set = util.generate_subset_test_val_train(
    train_dataset, 0.1, 0.1 )
    
# Sampler that draws the training samples in random order
# (not used by the loader created below).
train_sampler = SubsetRandomSampler(train_set)

# Collate function with the padding id and the fixed caption length bound in.
collate_func_lambda = lambda x: collate_func(x, tokenizer.pad_token_id, length_max = 97 )

# Test loader: one sample per batch, iterating over `test_set`.
test_loader = torch.utils.data.DataLoader(
                    train_dataset,
                    #batch_size=config.batch_size,
                    batch_size=1,
                    num_workers=0,
                    sampler=test_set,
                    collate_fn=collate_func_lambda)


In [None]:
# Evaluation runs on the CPU; switch to the commented line to use a GPU.
#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
model = CaptioningTransformer( img_size = 336, dim_embedding=1024, length_max = 97, vocab_size=len(tokenizer),
                 tokenizer=tokenizer, dropout=0.1, model_id =model_id).to(device)

PATH = "model/model_bert_large_NAR_PAD_curr.pth"
if os.path.isfile(PATH):
    # map_location remaps the stored tensors onto the selected device, so a
    # checkpoint saved on a GPU can also be loaded on a CPU-only machine.
    checkpoint = torch.load(PATH, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print( "paramerters were loaded." )
else:
    # Make it obvious that everything below runs with untrained weights.
    print( "checkpoint not found:", PATH, "- using freshly initialized weights" )

# Smoke test: push a random batch through the model and print the logits
# shape. no_grad avoids building an autograd graph for this check.
images = torch.randn( ( 2, 3, 336,336 ), device = device )
with torch.no_grad():
    logits = model( images )

print( logits.size() )

In [None]:
# Number of test samples to evaluate and display.
test_num = 21

my_decode = False
#my_decode = True

# Rebuild the subsets and keep only the first `test_num` test samples.
test_set, val_set, train_set = util.generate_subset_test_val_train(
    train_dataset, 0.1, 0.1 )

test_set = test_set[:test_num]

test_loader = torch.utils.data.DataLoader(
                    train_dataset,
                    #batch_size=config.batch_size,
                    batch_size=1,
                    num_workers=0,
                    sampler=test_set,
                    collate_fn=collate_func_lambda)

test_pr_coef = 1

fn = bleu_score.SmoothingFunction().method7

# Undo the CLIP normalization so the images can be displayed.
transforms_inv = v2.Compose([
    v2.Normalize((-0.48145466/0.26862954, -0.4578275/0.26130258, -0.40821073/0.27577711), (1/0.26862954,1/0.26130258,1/0.27577711)),
    v2.ToPILImage()
])

# Scores of every evaluated sample, used for the final report. The deques
# below only hold a moving window of the most recent samples.
all_errors = []
all_bleus = []

# test
with tqdm(test_loader) as pbar:
    pbar.set_description(f'[テスト]')

    # Evaluation mode (disables dropout).
    model.eval()

    test_errors = deque()
    test_bleus = deque()
    n_iter = 0
    for k, (imgs, captions, caption_lengths) in enumerate( pbar ):
        imgs = imgs.to(device)
        captions = captions.to(device)

        with torch.no_grad():
            hypo_ids = model.inference( imgs )

        n = 0
        hypo_sentence = []
        ref_sentence = []
        ref_imgs = []
        total_error = 0
        total_token_length = 0
        total_bleu = 0
        for (hypo_id, caption, img ) in zip( hypo_ids, captions, imgs ):
            # Decode up to the first EOS id, then re-tokenize for scoring.
            hypo_sent = model.my_decode( hypo_id.tolist(), tokenizer )
            hypo = tokenizer.tokenize( hypo_sent )
            reference = model.my_decode( caption.tolist(), tokenizer )
            lev_ref = tokenizer.tokenize( reference )

            # Token-level edit distance between hypothesis and reference.
            (error, substitute, delete, insert, ref_length) = levenshtein.calculate_error(hypo,lev_ref)

            # Accumulate the number of errors.
            total_error += error
            # Accumulate the number of reference tokens.
            total_token_length += ref_length

            # An empty reference would divide by zero; count it as length 1.
            wer = error / max( ref_length, 1 )

            # sentence_bleu expects token lists. Passing the raw strings made
            # NLTK iterate over characters and report a character-level BLEU.
            bleu = bleu_score.sentence_bleu( [lev_ref], hypo, smoothing_function=fn  )

            total_bleu += bleu

            inv_img = transforms_inv( img )
            plt.imshow( inv_img )
            plt.axis('off')
            plt.show()
            print( "hypo:", hypo_sent )
            print( "refe:", reference )
            print( "this pic. WER :", wer )
            print( "this pic. BLEU:", bleu )
            test_errors.append(wer)
            test_bleus.append(bleu)
            all_errors.append(wer)
            all_bleus.append(bleu)
            n_iter += 1
            print(f'test number = {n_iter} average, WER = {torch.Tensor(test_errors).mean().item()}, BLEU = {torch.Tensor(test_bleus).mean().item()}')
            print( "\n\n" )

            # Keep the moving window at 100 samples.
            if len(test_errors) > 100:
                test_errors.popleft()
                test_bleus.popleft()
            pbar.set_postfix({
                'WER': torch.Tensor(test_errors).mean().item(),
                'BLEU': torch.Tensor(test_bleus).mean().item()
            })

# Final report over every evaluated sample (not just the moving window).
test_error = np.mean( all_errors )
test_bleu = np.mean( all_bleus )
print(f'test {n_iter} average WER : {test_error}')
print(f'test {n_iter} average BLEU: {test_bleu}')