### ライブラリの準備

###モジュールのインポートとGoogleドライブのマウント

In [1]:
import os
import glob
import math
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import datetime
#from tqdm import tqdm
from tqdm.notebook import tqdm
import pickle
import random
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from PIL import Image
import skimage.transform
from collections import deque
from typing import Sequence, Dict, Tuple, Union

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import models
import torchvision.transforms as T
import torchvision.datasets as dataset
from torchvision.transforms import v2

from timm.scheduler import CosineLRScheduler
from transformers import  get_linear_schedule_with_warmup

#from transformers import AutoImageProcessor, AutoModel, AutoProcessor, CLIPVisionModel
from transformers import BertTokenizer, BertModel, CLIPVisionModel, BertForPreTraining

import sys

import util
import levenshtein
from nltk import bleu_score
import ssl
from torch.amp import autocast, GradScaler

### 位置エンコーディングの実装

In [2]:
class PositionalEmbedding(nn.Module):
    '''
    位置埋め込み （Positional embedding）
    dim_embedding: 埋込み次元
    max_len      : 入力の最大系列長
    '''
    def __init__(self, dim_embedding: int, max_len: int=2048):
        super().__init__()

        self.pos_emb = nn.Embedding(max_len, dim_embedding)

    '''
    位置エンコーディングの順伝播
    x: 位置エンコーディングを埋め込む対象のテンソル,
       [バッチサイズ, 系列長, 埋め込み次元]
    '''
    def forward(self, x: torch.Tensor):
        seq = x.shape[1]
        positions = torch.arange(start=0, end=seq, step=1, device=x.device).to(torch.long)
        positions = self.pos_emb(positions)[:seq,:]
        
        return positions

### Transformerデコーダの実装

### CaptioningTransformerの実装

In [3]:
def logsumexp(x, dim=1):
    return torch.logsumexp(x.float(), dim=dim).type_as(x)

class DynamicCRF(nn.Module):
    def __init__(self, num_embedding, low_rank=32, beam_size=64):
        super().__init__()

        self.E1 = nn.Embedding(num_embedding, low_rank)
        self.E2 = nn.Embedding(num_embedding, low_rank)

        self.vocb = num_embedding
        self.rank = low_rank
        self.beam = beam_size

    def extra_repr(self):
        return "vocab_size={}, low_rank={}, beam_size={}".format(
            self.vocb, self.rank, self.beam)

    def forward(self, emissions, targets, masks, beam=None):
        numerator = self._compute_score(emissions, targets, masks)
        denominator = self._compute_normalizer(emissions, targets, masks, beam)
        return numerator - denominator

    def forward_decoder(self, emissions, masks=None, beam=None):
        return self._viterbi_decode(emissions, masks, beam)

    def _compute_score(self, emissions, targets, masks=None):
        batch_size, seq_len = targets.size()
        emission_scores = emissions.gather(2, targets[:, :, None])[:, :, 0]  # B x T
        transition_scores = (self.E1(targets[:, :-1]) * self.E2(targets[:, 1:])).sum(2)

        scores = emission_scores
        scores[:, 1:] += transition_scores

        if masks is not None:
            scores = scores * masks.type_as(scores)
        
        return scores.sum(-1)

    def _compute_normalizer(self, emissions, targets=None, masks=None, beam=None):
        beam = beam if beam is not None else self.beam
        batch_size, seq_len = emissions.size()[:2]
        if targets is not None:
            #_emissions = emissions.scatter(2, targets[:, :, None], np.float('inf'))
            _emissions = emissions.scatter(2, targets[:, :, None], float('inf'))
            beam_targets = _emissions.topk(beam, 2)[1]
            beam_emission_scores = emissions.gather(2, beam_targets)
        else:
            beam_emission_scores, beam_targets = emissions.topk(beam, 2)
        beam_transition_score1 = self.E1(beam_targets[:, :-1])  # B x (T-1) x K x D; position i - 1, previous step.
        beam_transition_score2 = self.E2(beam_targets[:, 1:])   # B x (T-1) x K x D; position i, current step.
        beam_transition_matrix = torch.bmm(
            beam_transition_score1.view(-1, beam, self.rank),
            beam_transition_score2.view(-1, beam, self.rank).transpose(1, 2))
        beam_transition_matrix = beam_transition_matrix.view(batch_size, -1, beam, beam)

        # compute the normalizer in the log-space
        score = beam_emission_scores[:, 0]  # B x K
        for i in range(1, seq_len):
            next_score = score[:, :, None] + beam_transition_matrix[:, i-1]
            next_score = logsumexp(next_score, dim=1) + beam_emission_scores[:, i]

            if masks is not None:
                score = torch.where(masks[:, i:i+1], next_score, score)
            else:
                score = next_score

        # Sum (log-sum-exp) over all possible tags
        return logsumexp(score, dim=1)

    def _viterbi_decode(self, emissions, masks=None, beam=None):
        beam = beam if beam is not None else self.beam
        batch_size, seq_len = emissions.size()[:2]
        beam_emission_scores, beam_targets = emissions.topk(beam, 2)
        beam_transition_score1 = self.E1(beam_targets[:, :-1])  # B x (T-1) x K x D
        beam_transition_score2 = self.E2(beam_targets[:, 1:])   # B x (T-1) x K x D
        beam_transition_matrix = torch.bmm(
            beam_transition_score1.view(-1, beam, self.rank),
            beam_transition_score2.view(-1, beam, self.rank).transpose(1, 2))
        beam_transition_matrix = beam_transition_matrix.view(batch_size, -1, beam, beam)

        traj_tokens, traj_scores = [], []
        finalized_tokens, finalized_scores = [], []

        # compute the normalizer in the log-space
        score = beam_emission_scores[:, 0]  # B x K
        #print( "score size:", score.size() )
        dummy = torch.arange(beam, device=score.device).expand(*score.size()).contiguous()

        for i in range(1, seq_len):
            traj_scores.append(score)
            _score = score[:, :, None] + beam_transition_matrix[:, i-1]
            _score, _index = _score.max(dim=1)
            _score = _score + beam_emission_scores[:, i]

            if masks is not None:
                score = torch.where(masks[:, i: i+1], _score, score)
                index = torch.where(masks[:, i: i+1], _index, dummy)
            else:
                score, index = _score, _index
            traj_tokens.append(index)

        # now running the back-tracing and find the best
        best_score, best_index = score.max(dim=1)
        finalized_tokens.append(best_index[:, None])
        finalized_scores.append(best_score[:, None])

        for idx, scs in zip(reversed(traj_tokens), reversed(traj_scores)):
            previous_index = finalized_tokens[-1]
            finalized_tokens.append(idx.gather(1, previous_index))
            finalized_scores.append(scs.gather(1, previous_index))

        finalized_tokens.reverse()
        finalized_tokens = torch.cat(finalized_tokens, 1)
        finalized_tokens = beam_targets.gather(2, finalized_tokens[:, :, None])[:, :, 0]

        finalized_scores.reverse()
        finalized_scores = torch.cat(finalized_scores, 1)
        finalized_scores[:, 1:] = finalized_scores[:, 1:] - finalized_scores[:, :-1]

        return finalized_scores, finalized_tokens

In [4]:
class TopLayer(nn.Module):
    def __init__(self, vocab_size, embed_dim, crf_low_rank, crf_beam_size, dropout, padding_idx):
        super(TopLayer, self).__init__()

        self.embed_dim = embed_dim
        self.vocab_size = vocab_size
        self.dropout = dropout
        self.padding_idx = padding_idx
        print( "in TopLyaer:" )
        self.crf_layer = DynamicCRF(num_embedding = vocab_size, low_rank = crf_low_rank, 
                                    beam_size = crf_beam_size)

        #self.one_more_layer_norm = nn.LayerNorm(embed_dim)
        #self.tgt_word_prj = nn.Linear(self.embed_dim, self.vocab_size)

    def forward(self, src_representation, src_input, tgt_input, is_training):
        '''
            src_representation : bsz x seqlen x embed_dim
            src_input : bsz x seqlen
            tgt_input : bsz x seqlen
        '''
        #assert src_input.size() == tgt_input.size()

        src_input = src_input.transpose(0, 1) # src_len x bsz
        #seqlen, bsz = src_input.size()
        seqlen, bsz = src_input.shape[:2]

        src_representation = F.dropout(src_representation, p=self.dropout, training=is_training)
        src_representation = src_representation.transpose(0, 1) # seqlen x bsz x embed_dim

        #src = src_representation

        #emissions = self.tgt_word_prj(src.contiguous().view(-1, self.embed_dim)).view(seqlen, bsz, self.vocab_size)
        emissions = src_representation
        #log_probs = torch.log_softmax(emissions, -1)
        #assert log_probs.size() == torch.Size([seqlen, bsz, self.vocab_size])

        emissions = emissions.transpose(0, 1) # [bsz x src_len x vocab_size]
        #emission_mask = ~tgt_input.eq(self.padding_idx) # [bsz x src_len] #pad のところは 0 padでないところが 1
        emission_mask = torch.ones_like( tgt_input, dtype=torch.bool ) #全部　pad でないとして 1
        batch_crf_loss = -1 * self.crf_layer(emissions, tgt_input, emission_mask) # [bsz]
        assert batch_crf_loss.size() == torch.Size([bsz])
        #return log_probs, batch_crf_loss
        #return torch.mean( batch_crf_loss )
        #print( "batch_crf_loss:", batch_crf_loss )
        return batch_crf_loss

    def decoding(self, src_representation, src_input):
        '''
            src_representation : bsz x seqlen x embed_dim
            src_input : bsz x seqlen
            tgt_input : bsz x seqlen
        '''
        src_input = src_input.transpose(0, 1) # src_len x bsz
        seqlen, bsz = src_input.size()

        src_representation = src_representation.transpose(0, 1) # seqlen x bsz x embed_dim
        src = src_representation

        emissions = self.tgt_word_prj(src.contiguous().view(-1, self.embed_dim)).view(seqlen, bsz, self.vocab_size)

        emissions = emissions.transpose(0, 1) # [bsz, seqlen, vocab_size]
        _, finalized_tokens = self.crf_layer.forward_decoder(emissions)
        assert finalized_tokens.size() == torch.Size([bsz, seqlen])
        return finalized_tokens

    def length_ratio_decoding(self, src_representation, src_input, length_ratio):
        '''
            src_representation : 1 x seqlen x embed_dim
            src_input : 1 x seqlen
        '''
        src_input = src_input.transpose(0, 1) # src_len x bsz
        seqlen, bsz = src_input.size()

        src_representation = src_representation.transpose(0, 1) # seqlen x bsz x embed_dim
        src = src_representation

        emissions = self.tgt_word_prj(src.contiguous().view(-1, self.embed_dim)).view(seqlen, bsz, self.vocab_size)

        emissions = emissions.transpose(0, 1) # [bsz, seqlen, vocab_size]
        valid_len = int(seqlen * length_ratio) + 1
        valid_emissions = emissions[:, :valid_len+1,:]
        _, finalized_tokens = self.crf_layer.forward_decoder(valid_emissions)
        return finalized_tokens


In [5]:
class CaptioningTransformer(nn.Module):
    '''
    CaptioningTransformerのコンストラクタ
    dim_embedding  : 埋め込み次元
    dim_feedforward: FNNの中間特徴次元
    num_heads      : マルチヘッドアテンションのヘッド数
    num_layers     : Transformerデコーダ層の数
    vocab_size     : 辞書の次元
    null_index     : NULLのID
    dropout        : ドロップアウト確率
    '''
    def __init__(self, img_size: int, dim_embedding: int, length_max: int,
                  vocab_size: int, tokenizer, dropout: float=0.1, model_id: str=''):
        super().__init__()

        #CLIP
        clip_model_id = "openai/clip-vit-large-patch14-336"
        self.clip_model = CLIPVisionModel.from_pretrained(clip_model_id, output_hidden_states = True)
        images = torch.randn( ( 1, 3, img_size, img_size ) )
        memory = self.clip_model( images )
        memory = memory.last_hidden_state
        img_length = memory.size(1)
        clip_dim = memory.size(2)
        
        # Dense Connector
        self.dc_linear = nn.Linear( clip_dim * 3, dim_embedding )
        self.dropout = nn.Dropout( dropout )        
        self.ln_memory = nn.LayerNorm( dim_embedding )

        # Down Sampling
        stride = img_length // ( length_max - 1 )
        self.conv1 = nn.Conv1d( dim_embedding, dim_embedding, 1, stride )
        print( "img_length:", img_length )
        print( "text_length_max:", length_max )
        print( "stride:", stride )
        memory = self.conv1( memory.transpose(1,2) ).transpose(1,2)
        print( "bert in memory size:", memory.size() )

        self.pos_emb = PositionalEmbedding( dim_embedding )
        
        self.bert = BertModel.from_pretrained( model_id )

        ## 単語出力分布計算
        self.ln_outputs = nn.LayerNorm( dim_embedding )
        self.linear = nn.Linear(dim_embedding, vocab_size)

        crf_low_rank = 32
        crf_beam_size = 256
        top_dropout = 0.0
        tgt_padding_idx = tokenizer.pad_token_id
        print( "initialize self.toplayer" )
        self.toplayer = TopLayer( vocab_size, dim_embedding, crf_low_rank, crf_beam_size, top_dropout, tgt_padding_idx )
        
        self.dim_embedding = dim_embedding

    ''' CaptioningTransformerの順伝播処理
    features: 画像特徴量 [バッチサイズ, 埋め込み次元]
    captions: 正解キャプション [バッチサイズ, 系列長]
    '''
    def forward(self, images: torch.Tensor ):

        self.device = images.device

        memory = self.clip_model( images )
        memory = self.dense_connector( memory )
        memory = self.dropout( memory )
        memory = self.ln_memory( memory )

        memory = self.conv1( memory.transpose(1,2) ).transpose(1,2)

        memory += self.pos_emb( memory )
        
        outputs = self.bert( inputs_embeds = memory ).last_hidden_state
        outputs = self.ln_outputs( outputs )
        logits = self.linear( outputs )
        
        return logits

    def inference(self, images: torch.Tensor ):

        self.device = images.device

        memory = self.clip_model( images )
        memory = self.dense_connector( memory )
        memory = self.dropout( memory )
        memory = self.ln_memory( memory )

        memory = self.conv1( memory.transpose(1,2) ).transpose(1,2)

        memory += self.pos_emb( memory )
        
        outputs = self.bert( inputs_embeds = memory ).last_hidden_state
        outputs = self.ln_outputs( outputs )
        emissions = self.linear( outputs )

        #emissions = emissions.transpose(0, 1) # [bsz, seqlen, vocab_size]
        seqlen = emissions.size(1) #added by T.uchi
        length_ratio = 1.0         #added by T.uchi
        #length_ratio = 0.7         #added by T.uchi
        valid_len = int(seqlen * length_ratio) + 1
        valid_emissions = emissions[:, :valid_len+1,:]
        _, finalized_tokens = self.toplayer.crf_layer.forward_decoder(valid_emissions)
        return finalized_tokens

    def dense_connector(self, memory ):
        tmp1 = torch.tensor([], device = self.device )
        tmp2 = torch.tensor([], device = self.device )
        tmp_full = len( memory.hidden_states )
        tmp_half = tmp_full // 2
        for i in range( 0, tmp_half ):
            tmp1 = torch.cat( [tmp1, memory.hidden_states[i][None]], dim = 0 )
        tmp1 = torch.sum(tmp1, dim=0) / tmp_half
        for i in range( tmp_half, tmp_full ):
            tmp2 = torch.cat( [tmp2, memory.hidden_states[i][None]], dim = 0 )
        tmp2 = torch.sum(tmp2, dim=0 ) / ( tmp_full - tmp_half )
        tmp3 = torch.cat([tmp1, tmp2], dim=-1)
        tmp3 = torch.cat( [ memory.last_hidden_state, tmp3], dim = -1 )
        tmp3 = self.dc_linear( tmp3 )
        return tmp3

    def my_decode(self, token_list, tokenizer ):

        def my_index( l, x ):
            if x in l:
                return l.index(x)
            else:
                return -1
        if my_index( token_list, eos_token_id ) != -1:
            token_list = token_list[:my_index( token_list, eos_token_id )]
        else:
            token_list = token_list
            
        text = tokenizer.decode( token_list, skip_special_tokens = True )
        
        return text

In [6]:
class MyDataset(Dataset):
    def __init__(self, file_path: str, img_directory: str, transforms, tokenizer, length_max = None ) -> None:
        super().__init__()
        self.img_directory = img_directory
        self.transforms = transforms
        # TODO: fix to original data
        #画像の前処理
        self.img_file = []
        self.tokens = []
        #vocab_size = len( tokenizer )
        #c1 = torch.zeros( ( vocab_size ) )
        #c2 = torch.zeros( ( vocab_size, vocab_size ) )
        if length_max == None:
            self.length_max = 0
        else:
            self.length_max = length_max
        length_sum = 0

        with open( file_path, 'rb') as f:
            data = pickle.load(f)
        for i, line_data in enumerate( data ):
            if i % 100000 == 0:
                print( "i:", i )
            self.img_file.append( line_data['img_file'] )
            id_tokens = line_data['id_tokens']
            id_tokens.append( eos_token_id )
            id_tokens.append( eos_token_id )
            length_sum += len( id_tokens )
            if length_max != None:
                id_tokens = torch.tensor( id_tokens )[:self.length_max]
            else:
                if self.length_max < len( id_tokens ):
                    self.length_max = len( id_tokens )
                id_tokens = torch.tensor( id_tokens )
            self.tokens.append( id_tokens )
        # w1, w2 を作る時は length_max = None　でお願いします。
        #    for i2 in range( len(id_tokens) ):
        #        if i2 == len( id_tokens ) - 1:
        #            c1[id_tokens[i2]] += 1
        #        else:
        #            c1[id_tokens[i2]] += 1
        #            c2[id_tokens[i2], id_tokens[i2+1] ] += 1
        '''
        c1avg = int( torch.sum( c1 ) / torch.sum( torch.ne( c1, 0 ).int()) )
        c2avg = int( torch.sum( torch.sum( c2, dim = 1 ), dim = 0 ) / torch.sum( torch.ne( c2, 0 ).int() ) )

        c1[0] = c1avg

        c2[:,0] = c2avg
        c2[0,:] = c2avg
        
        sumc1 = torch.sum( c1, dim = 0 )
        sumc2 = torch.sum( torch.sum( c2, dim = 1 ), dim = 0 )

        prob1 = c1 / sumc1
        prob2 = c2 / sumc2

        self.w1 = prob1 ** -0.4
        self.w1 = torch.nan_to_num( self.w1, nan = 0.0, posinf=0.0, neginf=0.0 )
        avg1 = torch.sum( self.w1, dim = 0 ) / torch.sum( torch.ne( self.w1, 0.0 ).int() )
        self.w1 = self.w1 / avg1

        self.w2 = prob2 ** -0.4
        self.w2 = torch.nan_to_num( self.w2, nan = 0.0, posinf=0.0, neginf=0.0 )
        avg2 = torch.sum( torch.sum( self.w2, dim = 1 ), dim = 0 ) / torch.sum( torch.ne( self.w2, 0.0 ).int() )
        self.w2 = self.w2 / avg2

        with open( "/mnt/ssd2/v7/w_unigrma.pkl", mode="wb" ) as f:
            pickle.dump( self.w1, f )

        with open( "/mnt/ssd2/v7/w_bigrma.pkl", mode="wb" ) as f:
            pickle.dump( self.w2, f )
        
        '''

        #with open( "/mnt/ssd2/v7/w_unigram.pkl", 'rb') as f:
        #    self.w1 = pickle.load(f)

        #with open( "/mnt/ssd2/v7/w_bigram.pkl", 'rb') as f:
        #    self.w2 = pickle.load(f)
        
        if length_max == None:
            print( "length max:", self.length_max )
            print( "avg length:", length_sum / len( self.tokens ) )
    
    # ここで取り出すデータを指定している
    def __getitem__(
        self,
        index: int
    ):
        tokens = self.tokens[index]
        img_file = self.img_file[index] + ".jpg"
        img_path = os.path.join( self.img_directory, img_file ) #index番目の画像のパスを取得
        img = Image.open(img_path) #PIL形式で画像を読み込み
        if img.mode != 'RGB':
            img = img.convert("RGB")
        img = self.transforms(img)
        
        return img, tokens

    # この method がないと DataLoader を呼び出す際にエラーを吐かれる
    def __len__(self) -> int:
        return len(self.tokens)

    def length_max(self):
        return self.length_max

    #def w1(self):
    #    return self.w1

    #def w2(self):
    #    return self.w2

In [7]:
def collate_func(batch: Sequence[Tuple[Union[torch.Tensor, str]]], pad_index, length_max ):
    imgs, tokens = zip(*batch)

    max_length = length_max
    #max_length = 0
    #for target in tokens:
    #    if max_length < len( target ):
    #        max_length = len( target )
    
    targets = []
    lengths = []
    for target in tokens:
        pad_len = max_length - len( target ) 
        #print( "target:", target )
        input2= F.pad( target, (0, pad_len), mode='constant', value = pad_index)
        targets.append( input2 )
        lengths.append( len( target ) )
    
    imgs = torch.stack( imgs, dim = 0 )
    targets = torch.stack( targets, dim = 0 )
    lengths = torch.tensor( lengths, requires_grad = False  )

    #if imgs.dim() != 4:
    #   print( "in collate imgs size:", imgs.size() )
    
    return imgs, targets, lengths

###学習におけるハイパーパラメータやオプションの設定

In [8]:
model_id = "google-bert/bert-large-uncased"
tokenizer = BertTokenizer.from_pretrained( model_id )
sos_token_id = tokenizer.encode( [ "[unused0]" ] )[1]
eos_token_id = tokenizer.encode( [ "[unused1]" ] )[1]

print( "sos_tokeni_d:", sos_token_id )
print( "eos_token_id:", eos_token_id )

class ConfigTrain(object):
    '''
    ハイパーパラメータ、システム共通変数の設定
    '''
    def __init__(self):

        # ハイパーパラメータ
        self.img_size = 336
        self.dim_embedding = 1024   # 埋め込み層の次元
        self.length_max = 97
        #self.lr = 5e-5            # 学習率
        #self.lr = 2e-5            # 学習率
        self.lr_clip = 2e-7
        self.lr_bert = 2e-5            # 学習率
        self.lr_others = 4e-5
        #self.lr_top = 1e-4
        #self.lr = 5e-6            # 学習率
        self.dropout = 0.1         # dropout確率
        #self.batch_size = 128       # ミニバッチ数
        self.batch_size = 32       # ミニバッチ数
        #self.batch_size = 16       # ミニバッチ数
        #self.batch_size = 8       # ミニバッチ数
        #self.batch_size = 4       # ミニバッチ数
        #self.batch_size = 1       # ミニバッチ数
        #self.num_epochs = 100       # エポック数→Colab無料版でテストする際は10未満に修正を推奨
        #self.num_epochs = 100       # エポック数→Colab無料版でテストする際は10未満に修正を推奨
        #self.num_epochs = 60       # エポック数→Colab無料版でテストする際は10未満に修正を推奨
        self.num_epochs = 10       # エポック数→Colab無料版でテストする際は10未満に修正を推奨
        self.use_amp = True
        #self.use_amp = False
        #self.use_saved_pth = True
        self.use_saved_pth = False
        self.model_id = "google-bert/bert-large-uncased"
        self.vocab_size = len( tokenizer )
        self.weight_decay = 0.1
        self.betas = (0.9, 0.999 )
        self.warmup = 0.1
        #self.alpha = 1.0
        self.alpha = 0.5
        self.window_size = 3
        
        # パスの設定
        self.img_directory = '/mnt/ssd2/v7/img'
        self.anno_file = '/mnt/ssd2/v7/data.pkl'
        self.save_directory = './model'

        # 検証に使う学習セット内のデータの割合
        self.test_ratio = 0.1
        self.val_ratio = 0.1
        #self.val_ratio = 0.0004
        #self.test_ratio = 0.0004
        
        # 学習に使うデバイス
        #self.device = 'cuda'
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        #self.device = 'cpu'
        
        # データローダーに使うCPUプロセスの数
        #self.num_workers = 4
        self.num_workers = 0 if self.device == torch.device('cpu') else 12
        #self.num_workers = 0
        
        # 移動平均で計算する損失の値の数
        self.moving_avg = 100

sos_tokeni_d: 1
eos_token_id: 2


In [9]:
#config = ""
#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
## 辞書（単語→単語ID）の読み込み
#with open('../PreTrain_Decoder/translateDatasetNTT_blank4_pad0/word_to_id2.pkl', 'rb') as f:
#    word_to_id = pickle.load(f)
#max_idx_en = len( word_to_id )
#word_to_id['<mask>'] = max_idx_en
#mask_value = word_to_id['<mask>']
#start_idx = word_to_id['<start>']
#bert_model_path = 'models--google-bert--bert-large-uncased/snapshots/6da4b6a26a1877e173fca3225479512db81a5e5b'
#tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path = bert_model_path )
model_id = "google-bert/bert-large-uncased"
tokenizer = BertTokenizer.from_pretrained(model_id)
model = CaptioningTransformer( img_size = 336, dim_embedding=1024, length_max = 97, vocab_size=len(tokenizer),
                 tokenizer=tokenizer, dropout=0.1, model_id =model_id).to(device)

#images = torch.randint( 0, 255, size = (10,3,256,256) )
images = torch.randn( ( 1, 3, 336,336 ) )
outputs = model( images )

print( outputs.size() )

img_length: 577
text_length_max: 97
stride: 6
bert in memory size: torch.Size([1, 97, 1024])
initialize self.toplayer
in TopLyaer:
torch.Size([1, 97, 30522])


### 学習率スケジューラ

### 学習を行う関数

In [9]:
def calc_loss_nll( train_log_prob_matrix, train_batch_tgt ):
    bsz, tgt_len = train_batch_tgt.size()
    
    train_log_prob_matrix_unbind = torch.unbind(train_log_prob_matrix, dim = 0)
    assert len(train_log_prob_matrix_unbind) == tgt_len
    train_truth_tgt_unbind = torch.unbind(train_batch_tgt, dim = 1)
    assert len(train_truth_tgt_unbind) == tgt_len

    nll_loss_list = []

    for index in range(tgt_len):
        curr_step_log_prob = train_log_prob_matrix_unbind[index]
        curr_step_tgt = train_truth_tgt_unbind[index].view(bsz)
        #one_nll_loss = NLL(curr_step_log_prob, curr_step_tgt)
        one_nll_loss = criterion_nll(curr_step_log_prob, curr_step_tgt)
        assert one_nll_loss.size() == torch.Size([bsz])
        nll_loss_list.append(one_nll_loss)

    nll_loss_matrix = torch.stack(nll_loss_list, dim = 1)
    #assert nll_loss_matrix.size() == torch.Size([bsz, tgt_len])
    #tgt_padding_matrix = ~train_batch_tgt.eq(data.tgt_padding_idx)
    #tgt_padding_matrix = tgt_padding_matrix.type(nll_loss_matrix.type())
    #assert tgt_padding_matrix.size() == nll_loss_matrix.size()
    #nll_loss_matrix = nll_loss_matrix * tgt_padding_matrix

    train_nll_loss = nll_loss_matrix.sum(-1)
    #assert train_batch_crf_loss.size() == train_nll_loss.size()

    return train_nll_loss

In [9]:
'''
def calc_loss_ca( logits, captions, c ):

    eps = 1e-4

    B, T, V = logits.size()
    
    one_hot_cap = F.one_hot( captions, num_classes = len( tokenizer ) ) # B * T * V

    lcabi = torch.zeros( (B, T),  dtype=torch.float, device = logits.device )
    zeroB = torch.zeros( (B),  dtype=torch.float, device = logits.device )
    for i in range( T ):
        tmp = torch.stack( [ torch.log( (  1.0 - torch.exp( torch.sum( logits[:,i,:] * one_hot_cap[:,j,:], dim = 1 ) ) / \
            ( torch.sum( torch.exp(logits[:,i]), dim = 1 ) + eps ) + eps ) ) if j != i  else zeroB \
            for j in range( max( 0, i - c ), min(  T, i + c ) ) ], dim = 0 )  # window 幅 * B
        lcabi[:,i] = torch.sum( tmp, dim = 0 ) # wubdiow 幅 * B を window 幅について sum
    
    # lcabi は B * T
    
    #pbi = torch.exp( torch.sum( logits * one_hot_cap, dim = 2 ) ) / ( torch.sum( torch.exp( logits ), dim = 2 ) + eps ) # B * T
    #logpbi = torch.log( pbi ) # B * T
    #print( "logpbi:", torch.mean( torch.mean( logpbi )))
    logpbi = torch.sum( logits * one_hot_cap, dim = 2 ) - torch.log( torch.sum( torch.exp(logits), dim = 2 ) + eps ) # B * T
    #print( "logpbi:", torch.mean( torch.mean(logpbi)) )
    
    loss_ca = -torch.mean( torch.mean( logpbi + lcabi, dim = 1 ), dim = 0 )
    
    return loss_ca
'''

In [8]:
logits = torch.randn( ( 4, 97, 30522 ) )
captions = torch.randint( 0, len( tokenizer ), size=( 4,97 ) )

loss_ca = calc_loss_ca( logits, captions, 3 )

print( loss_ca )

tensor(10.8275)


In [10]:
config = ConfigTrain()

#tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path = config.bert_model_path)
#model_id = "google-bert/bert-large-uncased"
tokenizer = BertTokenizer.from_pretrained(config.model_id)

# 辞書サイズを保存
vocab_size = len( tokenizer )

# モデル出力用のディレクトリを作成
os.makedirs(config.save_directory, exist_ok=True)

# 画像のtransformsを定義
transforms = v2.Compose([
    v2.Resize((336, 336)),
    v2.AutoAugment(),
    #v2.ToTensor(),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    ## Coco データセット 2017 train の平均と標準偏差
    #v2.Normalize((0.456,0.427,0.401),(0.224,0.219,0.231) )
    # ImageNetデータセットの平均と標準偏差
    #v2.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    # Clip Model の config から引用。
    v2.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])

# v7 データセット
train_dataset = MyDataset( file_path=config.anno_file,
                           img_directory = config.img_directory,
                           transforms=transforms,tokenizer=tokenizer, length_max = config.length_max)

# Subset samplerの生成
test_set, val_set, train_set = util.generate_subset_test_val_train(
    train_dataset, config.test_ratio, config.val_ratio )
    
# 学習時にランダムにサンプルするためのサンプラー
train_sampler = SubsetRandomSampler(train_set)

# DataLoaderを生成
collate_func_lambda = lambda x: collate_func(x, tokenizer.pad_token_id, config.length_max)
train_loader = torch.utils.data.DataLoader(
                    train_dataset,
                    batch_size=config.batch_size,
                    num_workers=config.num_workers,
                    sampler=train_sampler,
                    collate_fn=collate_func_lambda)
val_loader = torch.utils.data.DataLoader(
                    train_dataset,
                    batch_size=config.batch_size,
                    num_workers=config.num_workers,
                    sampler=val_set,
                    collate_fn=collate_func_lambda)
test_loader = torch.utils.data.DataLoader(
                    train_dataset,
                    #batch_size=config.batch_size,
                    batch_size=1,
                    num_workers=config.num_workers,
                    sampler=test_set,
                    collate_fn=collate_func_lambda)


print( "config.device:", config.device )
print( "学習セット数:",len( train_loader ) )
print( "評価セット数:",len( val_loader ))
print( "テストセット数:",len( test_loader ))
print( "use_amp:", config.use_amp )
print( "use_saved_pth:", config.use_saved_pth )

# モデルの定義
model = CaptioningTransformer( config.img_size,
    config.dim_embedding, config.length_max, config.vocab_size,
    tokenizer, config.dropout, config.model_id)
model.to(config.device)
#crf_low_rank = 32
#crf_beam_size = 256
#top_dropout = 0.0
#tgt_padding_idx = tokenizer.pad_token_id
#toplayer = TopLayer( vocab_size, dim_embedding, crf_low_rank, crf_beam_size, top_dropout, tgt_padding_idx )


# 損失関数の定義
#criterion = nn.CrossEntropyLoss( ignore_index = tokenizer.pad_token_id, reduction = 'mean' )
#criterion = nn.CrossEntropyLoss( reduction = 'mean' )
log_softmax = nn.LogSoftmax( dim = 2 )
#softmax = nn.Softmax( dim = 2 )
criterion_nll = nn.NLLLoss( reduction = 'none' )
#criterion_nll = nn.NLLLoss( )
#criterion = nn.CrossEntropyLoss()

# 最適化手法の定義
# 最適化手法の定義
# Optimizerの生成, clipとそうでないモジュールとの
# パラメータで異なる学習率を適用
params_clip = []
params_bert = []
params_others = []
for name, parameter in model.named_parameters():
    if parameter.requires_grad:
        if 'clip_model' in name:
            params_clip.append(parameter)
        elif 'bert' in name:
            params_bert.append(parameter)
        else:
            params_others.append(parameter)
param_groups = [
    {'params': params_clip, 'lr': config.lr_clip},
    {'params': params_bert, 'lr': config.lr_bert},
    {'params': params_others, 'lr': config.lr_others}]
#optimizer = torch.optim.AdamW( model.parameters() , lr=config.lr)
#optimizer = torch.optim.AdamW( param_groups, weight_decay = config.weight_decay, betas=config.betas )
optimizer = torch.optim.AdamW( param_groups, weight_decay = config.weight_decay, betas=config.betas )
#t_optimizer = torch.optim.AdamW( toplayer.parameters(), lr = config.lr_top, weight_decay = config.weight_decay, betas=config.betas )


# 全ステップ数
num_global_steps = len( train_loader ) * config.num_epochs
print( "num_global_steps:", num_global_steps )
num_warmup_steps = num_global_steps * config.warmup
print( "num_warmup_steps:", num_warmup_steps )
#スケジューラーの定義
scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps, num_global_steps )    
#t_scheduler = get_linear_schedule_with_warmup( t_optimizer, num_warmup_steps, num_global_steps )    


PATH = "model/model_bert_large_NAR_PAD_curr.pth"
print( "use_saved_pth:", config.use_saved_pth )
print( "exist saved_pth:", os.path.isfile(PATH) ) 
use_saved_pth = config.use_saved_pth
if use_saved_pth and os.path.isfile(PATH):
    checkpoint = torch.load(PATH)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    #device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    ## optimizerのstateを現在のdeviceに移す。これをしないと、保存前後でdeviceの不整合が起こる可能性がある。
    #for state in optimizer.state.values():
        #for k, v in state.items():
            #if isinstance(v, torch.Tensor):
                #state[k] = v.to(device)
    begin_epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    global_step = checkpoint['global_step']    
else:
    begin_epoch = 0
    global_step = 0

print( "begin_epoch:", begin_epoch )
print( "global_ste:", global_step )

len_tr_loader = len( train_loader )
train_param = len_tr_loader // 6
#train_param = len_tr_loader // 100
len_val_loader = len( val_loader )
#train_param = len_val_loader // 3
val_param = len_val_loader // 3
print( "train_param:", train_param )
print( "val_param:", val_param )

print( "epochs:", config.num_epochs )
print( "batch_size:", config.batch_size )
print( "lr_clip:", config.lr_clip )
print( "lr_bert:", config.lr_bert )
print( "lr_others:", config.lr_others )
#print( "lr_top   :", config.lr_top )
print( "weight_decay:", config.weight_decay )
print( "betas:", config.betas )

# 学習経過の書き込み
now = datetime.datetime.now()
train_loss_file = '{}/MyOriginal_train_loss_{}.csv'\
    .format(config.save_directory, now.strftime('%Y%m%d_%H%M%S'))
with open(train_loss_file, 'a') as f:
    print(f'{len_tr_loader}', file=f) 
print( "train_loss_file:", train_loss_file )
val_loss_file = '{}/MyOriginal_val_loss_{}.csv'\
    .format(config.save_directory, now.strftime('%Y%m%d_%H%M%S'))
with open(val_loss_file, 'a') as f:
    print(f'{len_val_loader}', file=f) 
norm_file = '{}/norm_{}.csv'\
    .format(config.save_directory, now.strftime('%Y%m%d_%H%M%S'))

# 学習
val_loss_best = float('inf')

fn = bleu_score.SmoothingFunction().method7

# AMP用のスケーラー
scaler = GradScaler(enabled=config.use_amp)

for epoch in range(config.num_epochs):
    with tqdm(train_loader) as pbar:
    #with tqdm(val_loader) as pbar:
        pbar.set_description(f'[エポック {epoch + 1}]')

        # 学習モードに設定
        model.train()

        train_losses = deque()
        train_a = deque()
        train_b = deque()
        train_errors = deque()
        train_bleus = deque()
        for n_batch, (imgs, captions, caption_lengths) in enumerate( pbar ):
            # ミニバッチを設定
            imgs = imgs.to(config.device)
            captions = captions.to(config.device)
                
            optimizer.zero_grad()

            # 最後の単語から次を予測する必要はないため最後の単語を除外
            with autocast(str(config.device),enabled=config.use_amp):
                outputs = model( imgs )

                # 損失の計算
                # 単語軸が第1軸である必要があるため、転置
                #outputs = outputs.transpose(1, 2)
                src_representation = outputs
                #src_input = imgs
                src_input = outputs
                tgt_input = captions
                train_batch_crf_loss = model.toplayer( src_representation, src_input, tgt_input, is_training = True )
                #a = torch.mean( train_batch_crf_loss ) / 200
                a = torch.mean( train_batch_crf_loss ) # mean over bsz
                log_probs = log_softmax( outputs ).transpose(0,1)
                ##log_probs = log_softmax( outputs.transpose(0,1) )
                ##bsz, seq = captions.size()
                train_nll_loss = calc_loss_nll( log_probs, captions )
                b = config.alpha * torch.mean( train_nll_loss )
                #b = config.alpha * train_nll_loss
                #b = torch.tensor( 0 )
                #train_nll_loss = calc_loss_ca( outputs, captions, config.window_size )
                #loss_ca = calc_loss_ca( outputs, captions, config.window_size )
                #b = config.alpha * loss_ca
                loss = a + b
                #loss = a 
            
                #loss = loss_CRF + config.alpha * loss_CA

            hypo_ids = torch.argmax( outputs, dim = 2 )
            
            scaler.scale(loss).backward()
            #scaler.unscale_(optimizer)
            #clip_grad_threshold = 1.0
            #torch.nn.utils.clip_grad_norm_(\
            #        model.parameters(),
            #        clip_grad_threshold)
            # オプティマイザにより，パラメータを更新する
            scaler.step(optimizer)
            scaler.update()            
            scheduler.step()

            #for name, param in model.named_parameters():
            #    print( name )
            
            norm0 = torch.sqrt( torch.norm( model.clip_model.vision_model.encoder.layers[0].self_attn.q_proj.weight.grad, p = 2 ) ).item()
            norm1 = torch.sqrt( torch.norm( model.bert.encoder.layer[23].attention.self.query.weight.grad, p = 2 ) ).item()
            norm_mean = torch.mean( torch.stack ([ torch.sqrt( torch.norm( param.grad, p = 2 ) ) \
                                                  for param in model.parameters() if param.grad is not None ] ) ).item()
            with open(norm_file, 'a') as f:
                print( "epcoch:", epoch, ", step:", global_step, ", norm0:", norm0, ", norm1:", norm1, ", norm_mean:", norm_mean, file=f  )
                f.flush()
            global_step += 1

            n = 0
            hypo_sentence = []
            ref_sentence = []
            hypo_sentence1 = []
            ref_sentence1 = []
            total_error = 0
            total_token_length = 0
            total_bleu = 0
            n2 = 0
            for (hypo_id, caption) in zip( hypo_ids, captions ):
                #hypo = tokenizer.decode( hypo_id.tolist(), skip_special_tokens = True )
                hypo = model.my_decode( hypo_id.tolist(), tokenizer )
                hypo_tokens = tokenizer.tokenize( hypo )
                #reference = tokenizer.decode( caption.tolist(), skip_special_tokens = True )
                reference = model.my_decode( caption.tolist(), tokenizer )
                ref_tokens = tokenizer.tokenize( reference )
                        
                # 認識誤りを計算
                (error, substitute, 
                    delete, insert, ref_length) = \
                    levenshtein.calculate_error(hypo_tokens,
                                                    ref_tokens)
                
                # 誤り文字数を累積する
                total_error += error
                # 文字の総数を累積する
                total_token_length += ref_length

                bleu = bleu_score.sentence_bleu( [reference], hypo, smoothing_function=fn  )
        
                total_bleu += bleu                    
                    
                if n < 1 and n_batch == len( train_loader ) - 1 :
                    hypo_sentence.append( hypo )
                    ref_sentence.append( reference )
                if n < 1 and n_batch % train_param == 0:
                    hypo_sentence1.append( hypo )
                    ref_sentence1.append( reference )
                    
                n += 1
                n2 += 1
            
            avg_error = total_error / total_token_length * 100
            avg_bleu = total_bleu / n2 * 100
                
            # 学習時の損失をログに書き込み
            train_losses.append(loss.item())
            train_a.append(a.item())
            train_b.append(b.item())
            train_errors.append( avg_error )
            train_bleus.append( avg_bleu )
            #train_ciders.append( avg_cider )
            if len(train_losses) > config.moving_avg:
                train_losses.popleft()
                train_a.popleft()
                train_b.popleft()
                train_errors.popleft()
                train_bleus.popleft()
                #train_ciders.popleft()
            mean_loss = torch.Tensor(train_losses).mean().item()
            mean_a = torch.Tensor(train_a).mean().item()
            mean_b = torch.Tensor(train_b).mean().item()
            mean_error = torch.Tensor(train_errors).mean().item()
            mean_bleu = torch.Tensor(train_bleus).mean().item()
            pbar.set_postfix({
                'loss': mean_loss,
                'crf': mean_a,
                'ca': mean_b,
                'WER': mean_error,
                'BLEU': mean_bleu,
                #'CIDER': torch.Tensor(train_ciders).mean().item()
            })
            with open(train_loss_file, 'a') as f:
                print(f'{epoch}, {mean_loss}, {mean_a}, {mean_b}, {mean_error}, {mean_bleu}', file=f)
            print_flag = 1
            for ( hypo_se, ref_se ) in zip( hypo_sentence1, ref_sentence1 ):
                if print_flag == 1:
                    print( "lr:", optimizer.param_groups[0]["lr"] )
                    print_flag = 0
                print(f'Train epoch = {epoch}, loss = {mean_loss}, WER = {mean_error}, BLEU = {mean_bleu}')
                print( "refe:", ref_se )
                print( "hypo:", hypo_se )
                    
            for ( hypo_se, ref_se ) in zip( hypo_sentence, ref_sentence ):
                print(f'Train epoch = {epoch}, loss = {mean_loss}, WER = {mean_error}, BLEU = {mean_bleu}')
                print( "refe:", ref_se )
                print( "hypo:", hypo_se )
    # 学習率を表示
    print(f'学習率: {optimizer.param_groups[0]['lr']}')
    train_loss = np.mean(train_losses)
    train_a1 = np.mean(train_a)
    train_b1 = np.mean(train_b)
    train_error = np.mean(train_errors )
    train_bleu = np.mean(train_bleus )
    print(f'Train loss: {train_loss}')
    print(f'Train crf: {train_a1}')
    print(f'Train ca: {train_b1}')
    print(f'Train WER: {train_error}')        
    print(f'Train BLEU: {train_bleu}')

    # 検証
    with tqdm(val_loader) as pbar:
        pbar.set_description(f'[検証]')

        # 評価モード
        model.eval()

        #val_losses = []
        #val_losses = deque()
        #val_a = deque()
        #val_b = deque()
        val_errors = deque()
        val_bleus = deque()
        for n_batch, (imgs, captions, caption_lengths) in enumerate( pbar ):

            # ミニバッチを設定
            imgs = imgs.to(config.device)
            captions = captions.to(config.device)
            #caption_lengths = torch.tensor( caption_lengths ).to(config.device)
                
            with torch.no_grad():
                #outputs = model( imgs )
                finalized_tokens = model.inference( imgs )
                #hypo_ids = torch.argmax( outputs, dim = 2 )
                hypo_ids = finalized_tokens
                ## 損失の計算
                ## 単語軸が第1軸である必要があるため、転置
                ##outputs = outputs.transpose(1, 2)
                #src_representation = outputs
                ##src_input = imgs
                #src_input = outputs
                #tgt_input = captions
                #train_batch_crf_loss = model.toplayer( src_representation, src_input, tgt_input, is_training = True )
                #log_probs = log_softmax( outputs.transpose(0,1) )
                ##bsz, seq = captions.size()
                ##loss_CA = - criterion_nll( log_prob.view( bsz * seq, -1 ) , captions.view( bsz * seq ) )
                #train_nll_loss = calc_loss_ca( log_probs, captions )
                #a = torch.mean( train_batch_crf_loss )
                #b = config.alpha * torch.mean( train_nll_loss )
                #loss = a + b
               
            n = 0
            hypo_sentence = []
            ref_sentence = []
            hypo_sentence1 = []
            ref_sentence1 = []
            total_error = 0
            total_token_length = 0
            total_bleu = 0
            n2 = 0
            for (hypo_id, caption) in zip( hypo_ids, captions ):
                #hypo = tokenizer.decode( hypo_id.tolist(), skip_special_tokens = True )
                hypo = model.my_decode( hypo_id.tolist(), tokenizer )
                hypo_tokens = tokenizer.tokenize( hypo )
                #reference = tokenizer.decode( caption.tolist(), skip_special_tokens = True )
                reference = model.my_decode( caption.tolist(), tokenizer )
                ref_tokens = tokenizer.tokenize( reference )

                        
                # 認識誤りを計算
                (error, substitute, 
                    delete, insert, ref_length) = \
                    levenshtein.calculate_error(hypo_tokens,
                                                ref_tokens)
                    
                # 誤り文字数を累積する
                total_error += error
                # 文字の総数を累積する
                total_token_length += ref_length

                bleu = bleu_score.sentence_bleu( [reference], hypo, smoothing_function=fn  )
        
                total_bleu += bleu

                if n < 1 and n_batch == len( val_loader ) - 1:
                    hypo_sentence.append( hypo )
                    ref_sentence.append( reference )
                        
                if n < 1 and n_batch % val_param == 0:
                    hypo_sentence1.append( hypo )
                    ref_sentence1.append( reference )
                    
                n += 1
                n2 += 1
                
            avg_error = total_error / total_token_length * 100                    
            avg_bleu = total_bleu / n2 * 100

            # 学習時の損失をログに書き込み
            #val_losses.append(loss.item())
            #val_a.append(a.item())
            #val_b.append(b.item())
            val_errors.append( avg_error )
            val_bleus.append( avg_bleu )
            if len(val_errors) > config.moving_avg:
                #val_losses.popleft()
                #val_a.popleft()
                #val_b.popleft()
                val_errors.popleft()
                val_bleus.popleft()
            #mean_loss = torch.Tensor(val_losses).mean().item()
            #mean_a = torch.Tensor(val_a).mean().item()
            #mean_b = torch.Tensor(val_b).mean().item()
            mean_error = torch.Tensor(val_errors).mean().item()
            mean_bleu = torch.Tensor(val_bleus).mean().item()
            pbar.set_postfix({
                #'loss': mean_loss,
                #'crf': mean_a,
                #'ca': mean_b,
                'WER': mean_error,
                'BLEU': mean_bleu,
            })
            # Validation Lossをログに書き込み
            with open(val_loss_file, 'a') as f:
                print(f'{epoch}, {mean_error}, {mean_bleu}', file=f)
            
            for ( hypo_se, ref_se ) in zip( hypo_sentence1, ref_sentence1 ):
                print(f'Val epoch = {epoch}, WER = {mean_error}, BLEU = {mean_bleu}')
                print( "refe:", ref_se )
                print( "hypo:", hypo_se )
                    
            for ( hypo_se, ref_se ) in zip( hypo_sentence, ref_sentence ):
                print(f'Val epoch = {epoch}, WER = {mean_error}, BLEU = {mean_bleu}')
                print( "refe:", ref_se )
                print( "hypo:", hypo_se )
                    
    # Loss 表示
    #val_loss = np.mean(val_losses)
    val_error = np.mean( val_errors )
    val_bleu = np.mean( val_bleus )
    #print(f'Validation loss: {val_loss}')
    print(f'Validation WER: {val_error}')
    print(f'Validation BLEU: {val_bleu}')

    ## より良い検証結果が得られた場合、モデルを保存
            
    # モデルを保存
    torch.save({'epoch': epoch,
                'global_step': global_step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss,},
        f'{config.save_directory}/model_bert_large_NAR_PAD_curr.pth')
    ## モデルを保存
        
# モデルを保存
torch.save({'epoch': epoch,
    'global_step': global_step,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'loss': loss,},
    f'{config.save_directory}/model_bert_large_NAR_PAD_final.pth')


i: 0
i: 100000
i: 200000
i: 300000
i: 400000
i: 500000
config.device: cuda:0
学習セット数: 12687
評価セット数: 1586
テストセット数: 50744
use_amp: True
use_saved_pth: False
img_length: 577
text_length_max: 97
stride: 6
bert in memory size: torch.Size([1, 97, 1024])
initialize self.toplayer
in TopLyaer:
num_global_steps: 126870
num_warmup_steps: 12687.0
use_saved_pth: False
exist saved_pth: False
begin_epoch: 0
global_ste: 0
train_param: 2114
val_param: 528
epochs: 10
batch_size: 32
lr_clip: 2e-07
lr_bert: 2e-05
lr_others: 4e-05
weight_decay: 0.1
betas: (0.9, 0.999)
train_loss_file: ./model/MyOriginal_train_loss_20251112_013826.csv


  0%|          | 0/12687 [00:00<?, ?it/s]

lr: 1.5764168046031368e-11
Train epoch = 0, loss = 2324.23681640625, WER = 248.5, BLEU = 13.77786636352539
refe: on the left side, there is a person holding a glass and standing. on the right side, there is a person in white color shirt, wearing a cap, smiling and standing. and the background is blurred.
hypo: tidessters prosecutorsters न grassland opportunity pretended grasslandtok 09 hurricanelston न scenario tak healy mug prosecutor orioles prosecutor am tak luggagesters pretended prosecutor prosecutor prosecutor mug prosecutor scenario prosecutor bai opportunity unnecessary prosecutorstersdilly scenariostersdilly inspire prosecutor prosecutor luggage luggagelston pretended scenario scenario prosecutor scenario scenariostersstersstersdilly consisting न prosecutor academies hurricane न prosecutor prosecutor न न न class prosecutor prosecutordilly न scenario न motive scenario न inspire prosecutor prosecutor scenariolston opportunity न न न groverdillysters scenariolston prosecutor inspi

  0%|          | 0/1586 [00:00<?, ?it/s]

Val epoch = 0, WER = 82.83582305908203, BLEU = 32.22361373901367
refe: in this image i can see few trees which are green in color, few flowers which are red in color and in the background i can see a person standing, the road, few vehicles, few buildings, few trees and the sky.
hypo: in this image we can see we can see windows some, there are, there are, plants tent street flowers some tent street flowers some, there are
Val epoch = 0, WER = 83.60577392578125, BLEU = 35.748111724853516
refe: in this picture, we see men in the uniform are standing. the man on the left side is holding a wooden stick. behind him, we see a man is holding a green color flag. we see people are holding flags which are in red, blue and green color. in the background, we see a car and buildings. there are poles and buildings in the background.
hypo: in this image we can see a group of people on the wearing we can see on the wearing trees person few some tent street flowers some, there are
Val epoch = 0, WER = 8

  0%|          | 0/12687 [00:00<?, ?it/s]

lr: 1.9999824842577267e-07
Train epoch = 1, loss = 355.1583251953125, WER = 84.11614227294922, BLEU = 20.631010055541992
refe: in this picture we can see few houses, and few cars on the road, and also we can see a sign board, plants and cables.
hypo: in this image i can see the building a a,,, and, and,, and, and, and and and and the and and and and
lr: 1.9629542050918262e-07
Train epoch = 1, loss = 322.8299255371094, WER = 82.82540893554688, BLEU = 19.79180145263672
refe: in this image i can see on the left side a man is smiling, he wore t - shirt, spectacles. beside him a woman is also smiling, at the top it is the sky.
hypo: in this image i the see two man and and and and and and and... and
lr: 1.9259259259259257e-07
Train epoch = 1, loss = 298.4232482910156, WER = 82.66929626464844, BLEU = 19.537235260009766
refe: in this image, we can see shells and a tube on a cloth.
hypo: in this image i can see a, on a and....
lr: 1.8888976467600255e-07
Train epoch = 1, loss = 291.9125366210937

  0%|          | 0/1586 [00:00<?, ?it/s]

Val epoch = 1, WER = 77.68656921386719, BLEU = 30.80805206298828
refe: in this image i can see few trees which are green in color, few flowers which are red in color and in the background i can see a person standing, the road, few vehicles, few buildings, few trees and the sky.
hypo: in this image i can see trees, plants, trees, there are trees, there are trees, we can see the sky.
Val epoch = 1, WER = 76.66458892822266, BLEU = 35.931549072265625
refe: in this picture, we see men in the uniform are standing. the man on the left side is holding a wooden stick. behind him, we see a man is holding a green color flag. we see people are holding flags which are in red, blue and green color. in the background, we see a car and buildings. there are poles and buildings in the background.
hypo: in this image we can see a group of people wearing flags flags flags flags flags flags flags flags flags flags flags flags flags flags flags flags flags flags flags their hands.
Val epoch = 1, WER = 77.03

  0%|          | 0/12687 [00:00<?, ?it/s]

lr: 1.7777602620355044e-07
Train epoch = 2, loss = 258.4427185058594, WER = 82.35736083984375, BLEU = 22.3997859954834
refe: as we can see in the image there is a wooden table. on table there is a paper and spectacles. on paper there is something written.
hypo: in this image i a a a,, a and and and and a
lr: 1.740731982869604e-07
Train epoch = 2, loss = 230.9816436767578, WER = 83.2561264038086, BLEU = 18.892454147338867
refe: in this image i can see a man holding a guitar and wearing a brown color shirt. on the left side a man stand and wearing a cap. on the left corner there is a text written on the image.
hypo: in this image a guitar a a guitar a guitar guitar guitar a
lr: 1.7037037037037037e-07
Train epoch = 2, loss = 227.8808135986328, WER = 83.55486297607422, BLEU = 18.98621368408203
refe: in this image there are few persons standing there are few bottles in fort of the persons on the table, at the side there is a popcorn cane, at the back ground there is a window and a wall.
hyp

  0%|          | 0/1586 [00:00<?, ?it/s]

Val epoch = 2, WER = 76.11940002441406, BLEU = 34.012752532958984
refe: in this image i can see few trees which are green in color, few flowers which are red in color and in the background i can see a person standing, the road, few vehicles, few buildings, few trees and the sky.
hypo: in this image we can see few plants, plants, plants, plants, we can see the background we can see the sky.
Val epoch = 2, WER = 76.2474365234375, BLEU = 36.7791633605957
refe: in this picture, we see men in the uniform are standing. the man on the left side is holding a wooden stick. behind him, we see a man is holding a green color flag. we see people are holding flags which are in red, blue and green color. in the background, we see a car and buildings. there are poles and buildings in the background.
hypo: in this image we can see a group of people standing and holding some their hands. in the their hands. in the background we can see the background we can see a flag.
Val epoch = 2, WER = 76.4533920288

  0%|          | 0/12687 [00:00<?, ?it/s]

lr: 1.555538039813282e-07
Train epoch = 3, loss = 209.12078857421875, WER = 82.47248077392578, BLEU = 25.95050621032715
refe: in this picture we can see a woman and she is holding a cat.
hypo: in this image there can a a a and a a and and and
lr: 1.5185097606473817e-07
Train epoch = 3, loss = 224.85279846191406, WER = 82.86471557617188, BLEU = 21.534082412719727
refe: in the foreground of this picture we can see a man seems to be standing. in the background we can see some other items.
hypo: in this image there can a a person. the and and and and
lr: 1.4814814814814815e-07
Train epoch = 3, loss = 232.16876220703125, WER = 82.41110229492188, BLEU = 21.437036514282227
refe: in this image there is a man standing on the ground. he is holding a golf stick in his hand. there is grass on the ground. in the background there are trees and a building. at the top there is the sky.
hypo: in this image of the a a person a the the,,,,,, the the the the,, the,,,,,,
lr: 1.444453202315581e-07
Train epo

  0%|          | 0/1586 [00:00<?, ?it/s]

Val epoch = 3, WER = 77.6119384765625, BLEU = 27.903629302978516
refe: in this image i can see few trees which are green in color, few flowers which are red in color and in the background i can see a person standing, the road, few vehicles, few buildings, few trees and the sky.
hypo: in this image we can see plants, plants, plants, we can see the background there are trees.
Val epoch = 3, WER = 76.19596862792969, BLEU = 31.117887496948242
refe: in this picture, we see men in the uniform are standing. the man on the left side is holding a wooden stick. behind him, we see a man is holding a green color flag. we see people are holding flags which are in red, blue and green color. in the background, we see a car and buildings. there are poles and buildings in the background.
hypo: in this image we can see a group of people standing and holding a flag army their hands. in the background we can see the background there are trees.
Val epoch = 3, WER = 76.4884033203125, BLEU = 30.7411136627197

  0%|          | 0/12687 [00:00<?, ?it/s]

lr: 1.3333158175910599e-07
Train epoch = 4, loss = 177.74798583984375, WER = 77.93696594238281, BLEU = 25.757362365722656
refe: in this image we can see a boat in which a person is sitting is floating on the water. here we can see this person is standing on the rock. in the background, we can see trees and sky with the clouds.
hypo: in this image there can see there, on water water the the the the the the the the the the the the
lr: 1.2962875384251594e-07
Train epoch = 4, loss = 201.44859313964844, WER = 82.182861328125, BLEU = 21.54418182373047
refe: here we can see a toy and puzzles. there is a carpet on the floor. in the background we can see a wall.
hypo: in this image of can see a a a a the, the the the, the the the the the the the
lr: 1.2592592592592592e-07
Train epoch = 4, loss = 195.2248992919922, WER = 82.06240844726562, BLEU = 23.070087432861328
refe: in this picture we can see a person is holding a dog belt and walking and the dog is also walking on the path.
hypo: in this i

  0%|          | 0/1586 [00:00<?, ?it/s]

Val epoch = 4, WER = 74.25373077392578, BLEU = 31.49457359313965
refe: in this image i can see few trees which are green in color, few flowers which are red in color and in the background i can see a person standing, the road, few vehicles, few buildings, few trees and the sky.
hypo: in this image we can see a group of the foreground of the background we can see the background we can see the sky.
Val epoch = 4, WER = 75.41302490234375, BLEU = 34.570377349853516
refe: in this picture, we see men in the uniform are standing. the man on the left side is holding a wooden stick. behind him, we see a man is holding a green color flag. we see people are holding flags which are in red, blue and green color. in the background, we see a car and buildings. there are poles and buildings in the background.
hypo: in this image we can see a group of people standing and holding a flag. in their hands. in the background we can see the background we can see a flag.
Val epoch = 4, WER = 75.88600158691406

  0%|          | 0/12687 [00:00<?, ?it/s]

lr: 1.1110935953688376e-07
Train epoch = 5, loss = 197.68997192382812, WER = 82.30277252197266, BLEU = 18.53028106689453
refe: in the middle of the picture, we see a toilet seat. on the right side, we see the tissue rolls and a sanitizer. behind that, we see a wall in brown color. in the background, we see a brown wall. this picture might be clicked in the washroom.
hypo: in this image i can a a a, a,,,,,,,,,,,
lr: 1.0740653162029374e-07
Train epoch = 5, loss = 203.94483947753906, WER = 82.1174545288086, BLEU = 22.897539138793945
refe: in this image, there is a person wearing clothes and cap. this person is standing in front of this mic and playing a guitar.
hypo: in this image a in a a a a a a a a and and
lr: 1.0370370370370369e-07
Train epoch = 5, loss = 209.6751708984375, WER = 81.61355590820312, BLEU = 22.411489486694336
refe: here in this picture we can see rock stones present in the middle of the river over there and we can see people travelling in the water with the help of life

  0%|          | 0/1586 [00:00<?, ?it/s]

Val epoch = 5, WER = 75.74626922607422, BLEU = 29.367403030395508
refe: in this image i can see few trees which are green in color, few flowers which are red in color and in the background i can see a person standing, the road, few vehicles, few buildings, few trees and the sky.
hypo: in this image we can see plants, plants, plants, trees, there are trees.
Val epoch = 5, WER = 75.7435302734375, BLEU = 32.318565368652344
refe: in this picture, we see men in the uniform are standing. the man on the left side is holding a wooden stick. behind him, we see a man is holding a green color flag. we see people are holding flags which are in red, blue and green color. in the background, we see a car and buildings. there are poles and buildings in the background.
hypo: in this image we can see a group of people holding a flag. in their hands. in their hands. in the background we can see a flag.
Val epoch = 5, WER = 76.16816711425781, BLEU = 31.8616943359375
refe: on the left side of the image we 

  0%|          | 0/12687 [00:00<?, ?it/s]

lr: 8.888713731466156e-08
Train epoch = 6, loss = 183.91311645507812, WER = 83.28267669677734, BLEU = 18.535940170288086
refe: in this picture we can see three persons are sitting in front of them there is a table and one person is talking and the rest of two persons are looking towards the opposite back side we can see a black color wall
hypo: in this image i sitting see three chairs on on chairs chairs chairs. the
lr: 8.518430939807151e-08
Train epoch = 6, loss = 197.7929229736328, WER = 81.56292724609375, BLEU = 23.555686950683594
refe: in the middle of the image we can see a paper, on the paper there is drawing. behind the paper there is wall.
hypo: in this image there can see a on on on on on on on and
lr: 8.148148148148148e-08
Train epoch = 6, loss = 184.09291076660156, WER = 81.39913177490234, BLEU = 23.057422637939453
refe: in this image in the center there are plants. on the right side there is a door and a red colour stand and there is a wall. on the right side of the wall th

  0%|          | 0/1586 [00:00<?, ?it/s]

Val epoch = 6, WER = 74.62686920166016, BLEU = 29.446977615356445
refe: in this image i can see few trees which are green in color, few flowers which are red in color and in the background i can see a person standing, the road, few vehicles, few buildings, few trees and the sky.
hypo: in this image we can see plants, plants, plants. in the background we can see the sky.
Val epoch = 6, WER = 75.63579559326172, BLEU = 31.968101501464844
refe: in this picture, we see men in the uniform are standing. the man on the left side is holding a wooden stick. behind him, we see a man is holding a green color flag. we see people are holding flags which are in red, blue and green color. in the background, we see a car and buildings. there are poles and buildings in the background.
hypo: in this image we can see a group of people standing and holding a flag. in their hands. in the background we can see a flag.
Val epoch = 6, WER = 76.09152221679688, BLEU = 31.39946937561035
refe: on the left side of 

  0%|          | 0/12687 [00:00<?, ?it/s]

lr: 6.666491509243932e-08
Train epoch = 7, loss = 198.85076904296875, WER = 79.14764404296875, BLEU = 21.86953353881836
refe: in this picture, we see a man in the white t - shirt is running. at the bottom, we see the pavement or the soil. in the left top, we see the grass. on the right side, we see a man in the black t - shirt is standing. in the right top, we see the grass. this is a black and white picture. this picture might be clicked in the playground.
hypo: in this image black and white image. in the the ground the the the the the the the.
lr: 6.296208717584928e-08
Train epoch = 7, loss = 196.3692626953125, WER = 81.69368743896484, BLEU = 22.000137329101562
refe: in this image i can see two women are standing in the front and i can also see smile on their faces. i can see both of them are wearing jackets and i can see the left one is holding few things. i can also see the right one is wearing a cap. in the background i can see an open grass ground, number of trees, few buildings,

  0%|          | 0/1586 [00:00<?, ?it/s]

Val epoch = 7, WER = 76.2686538696289, BLEU = 27.913833618164062
refe: in this image i can see few trees which are green in color, few flowers which are red in color and in the background i can see a person standing, the road, few vehicles, few buildings, few trees and the sky.
hypo: in this image we can see some plants, plants, plants. in the background we can see the sky.
Val epoch = 7, WER = 76.17823791503906, BLEU = 30.900772094726562
refe: in this picture, we see men in the uniform are standing. the man on the left side is holding a wooden stick. behind him, we see a man is holding a green color flag. we see people are holding flags which are in red, blue and green color. in the background, we see a car and buildings. there are poles and buildings in the background.
hypo: in this image we can see a group of people standing on the ground. in their hands. in the background there is a flag.
Val epoch = 7, WER = 76.39102172851562, BLEU = 30.62799644470215
refe: on the left side of the

  0%|          | 0/12687 [00:00<?, ?it/s]

lr: 4.444269287021711e-08
Train epoch = 8, loss = 170.0294952392578, WER = 82.37671661376953, BLEU = 22.276729583740234
refe: it looks like a black and white picture. we can see a woman and behind the woman there is a dark background.
hypo: in this image black and white image.. woman a a a.
lr: 4.073986495362707e-08
Train epoch = 8, loss = 184.02078247070312, WER = 81.2964859008789, BLEU = 22.610116958618164
refe: in this image we can see a group of plants. we can also see some poles, a fence, a pathway and some tents. on the backside we can see a group of trees, the mountains and the sky which looks cloudy.
hypo: in this image there can see the,,,,,,,, the, the, the the
lr: 3.7037037037037036e-08
Train epoch = 8, loss = 174.8050994873047, WER = 81.10934448242188, BLEU = 22.545743942260742
refe: in this image, there are some chairs which are in black color and there are some tables which are covered by a blue color cloth, there is a black color wall in the middle, there are some people

  0%|          | 0/1586 [00:00<?, ?it/s]

Val epoch = 8, WER = 76.3432846069336, BLEU = 29.510948181152344
refe: in this image i can see few trees which are green in color, few flowers which are red in color and in the background i can see a person standing, the road, few vehicles, few buildings, few trees and the sky.
hypo: in this image we can see plants, plants, plants, plants, there are trees.
Val epoch = 8, WER = 76.07099914550781, BLEU = 31.37432289123535
refe: in this picture, we see men in the uniform are standing. the man on the left side is holding a wooden stick. behind him, we see a man is holding a green color flag. we see people are holding flags which are in red, blue and green color. in the background, we see a car and buildings. there are poles and buildings in the background.
hypo: in this image we can see a group of people standing on the ground. in their hands. in the background there is a flag.
Val epoch = 8, WER = 76.35205078125, BLEU = 31.176040649414062
refe: on the left side of the image we can see a p

  0%|          | 0/12687 [00:00<?, ?it/s]

lr: 2.2220470647994886e-08
Train epoch = 9, loss = 164.63665771484375, WER = 82.31907653808594, BLEU = 25.356525421142578
refe: there are few persons playing on the ground. this is grass and there is a ball. in the background we can see a mesh.
hypo: in this image i can see playing playing playing playing the the...... the
lr: 1.851764273140485e-08
Train epoch = 9, loss = 178.0460205078125, WER = 81.39530944824219, BLEU = 23.11852264404297
refe: this picture shows a couple of men standing and holding a wine bottles in their hand and we see people seated on the chairs
hypo: in this image standing standing a holding. bottle holding holding holding holding bottle bottle bottle bottle the and bottle and and and and
lr: 1.4814814814814813e-08
Train epoch = 9, loss = 175.54296875, WER = 81.08969116210938, BLEU = 23.35457992553711
refe: in this image i can see a man is sitting and smiling. this picture is black and white in color.
hypo: in this image black and white image. a sitting a a a. a 

  0%|          | 0/1586 [00:00<?, ?it/s]

Val epoch = 9, WER = 75.74626922607422, BLEU = 28.057846069335938
refe: in this image i can see few trees which are green in color, few flowers which are red in color and in the background i can see a person standing, the road, few vehicles, few buildings, few trees and the sky.
hypo: in this image we can see plants, plants, we can see the background we can see the background there are buildings.
Val epoch = 9, WER = 76.31013488769531, BLEU = 31.327890396118164
refe: in this picture, we see men in the uniform are standing. the man on the left side is holding a wooden stick. behind him, we see a man is holding a green color flag. we see people are holding flags which are in red, blue and green color. in the background, we see a car and buildings. there are poles and buildings in the background.
hypo: in this image we can see a group of people standing on the ground. in their hands. in the background there is a flag.
Val epoch = 9, WER = 76.43274688720703, BLEU = 31.079444885253906
refe:

In [None]:
def calc_loss_ca( logits, captions, c ):

    eps = 1e-4

    B, T, V = logits.size()
    
    one_hot_cap = F.one_hot( captions, num_classes = len( tokenizer ) ) # B * T * V

    lcabi = torch.zeros( (B, T),  dtype=torch.float, device = logits.device )
    zeroB = torch.zeros( (B),  dtype=torch.float, device = logits.device )
    for i in range( T ):
        tmp = torch.stack( [ torch.log( (  1.0 - torch.exp( torch.sum( logits[:,j,:] * one_hot_cap[:,i,:], dim = 1 ) ) / \
            ( torch.sum( torch.exp(logits[:,j]), dim = 1 ) + eps ) + eps ) ) if j != i  else zeroB \
            for j in range( max( 0, i - c ), min(  T, i + c ) ) ], dim = 0 )  # window 幅 * B
        lcabi[:,i] = torch.sum( tmp, dim = 0 ) # wubdiow 幅 * B を window 幅について sum
    
    # lcabi は B * T

    lca = torch.sum( torch.sum( lcabi , dim = 1 ), dim = 0 )
    
    return lca

In [None]:
def calc_cnt_repeat( logits, c, tau ):

    B, T, V = logits.size()
    
    def differentiable_argamx( logits, tau ):

        tmp = F.gumbel_softmax( logits, tau, hard=True )
        tmp1 = torch.arange( 0, logits.size(2) )[None,None] * tmp
        tokens = torch.sum( tmp1, dim = 2 )

        return tokens

    tokens = differentiable_argamx( logits, tau ) #logits から token を算出。微分可能 B * T

    cnt = 0
    for i in range( T ):
        cntj = 0
        for j in range( max( 0, i - c ), min(  T, i + c ) ):
            if j != i:
                tmp = torch.abs( tokens[:,j] - tokens[:,i] ) # i と j が同じだったら0 その他は 1 以上の整数 B
                tmp = F.sigmoid( 10 - 100 * tmp ) # 同じところだけ 1, あとは 0. B
                cntj += torch.sum( tmp ) # B についての sum　を j　について足しこんでいる。
        cnt += cntj # i について足しこんでいる。

    return cnt