### ライブラリの準備

###モジュールのインポートとGoogleドライブのマウント

In [1]:
import os
import glob
import math
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import datetime
#from tqdm import tqdm
from tqdm.notebook import tqdm
import pickle
import random
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from PIL import Image
import skimage.transform
from collections import deque
from typing import Sequence, Dict, Tuple, Union

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import models
import torchvision.transforms as T
import torchvision.datasets as dataset
from torchvision.transforms import v2

from timm.scheduler import CosineLRScheduler
from transformers import  get_linear_schedule_with_warmup

#from transformers import AutoImageProcessor, AutoModel, AutoProcessor, CLIPVisionModel
from transformers import BertTokenizer, BertModel, CLIPVisionModel, BertForPreTraining

import sys

import util
import levenshtein
from nltk import bleu_score
import ssl
from torch.amp import autocast, GradScaler

2026-01-01 03:38:42.901590: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-01 03:38:42.947033: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-01-01 03:38:44.108227: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


### 位置エンコーディングの実装

In [2]:
class PositionalEmbedding(nn.Module):
    '''
    位置埋め込み （Positional embedding）
    dim_embedding: 埋込み次元
    max_len      : 入力の最大系列長
    '''
    def __init__(self, dim_embedding: int, max_len: int=2048):
        super().__init__()

        self.pos_emb = nn.Embedding(max_len, dim_embedding)

    '''
    位置エンコーディングの順伝播
    x: 位置エンコーディングを埋め込む対象のテンソル,
       [バッチサイズ, 系列長, 埋め込み次元]
    '''
    def forward(self, x: torch.Tensor):
        seq = x.shape[1]
        positions = torch.arange(start=0, end=seq, step=1, device=x.device).to(torch.long)
        positions = self.pos_emb(positions)[:seq,:]
        
        return positions

### Transformerデコーダの実装

### CaptioningTransformerの実装

In [3]:
def logsumexp(x, dim=1):
    return torch.logsumexp(x.float(), dim=dim).type_as(x)

class DynamicCRF(nn.Module):
    def __init__(self, num_embedding, low_rank=32, beam_size=64, use_repeat_logits_half=False, crf_coef=1.0):
        super().__init__()

        #low_rank = num_embedding
        self.E1 = nn.Embedding(num_embedding, low_rank)
        self.E2 = nn.Embedding(num_embedding, low_rank)

        self.vocb = num_embedding
        self.rank = low_rank
        self.beam = beam_size
        self.use_repeat_logits_half = use_repeat_logits_half
        self.crf_coef = crf_coef

    def extra_repr(self):
        return "vocab_size={}, low_rank={}, beam_size={}".format(
            self.vocb, self.rank, self.beam)

    def forward(self, emissions, top_logits, top_indices, targets, masks, beam=None):
        numerator = self._compute_score(emissions, targets, masks)
        denominator = self._compute_normalizer(emissions, targets, masks, beam )
        beam_probs = self._compute_normalizer2(top_logits, top_indices, targets, masks, beam)

        #return numerator - denominator, beam_probs, all_preds, sample_log_probs, sampled_beam_idx
        #return numerator - denominator, beam_probs, sampled_beam_idx
        return numerator - denominator, beam_probs
        #return numerator, beam_probs
    
    def forward_decoder(self, emissions, masks=None, beam=None):
        return self._viterbi_decode(emissions, masks, beam)

    def _compute_score(self, emissions, targets, masks=None):
        batch_size, seq_len = targets.size()

        emission_scores = emissions.gather(2, targets[:, :, None])[:, :, 0]  # B x T
        transition_scores = (self.E1(targets[:, :-1]) * self.E2(targets[:, 1:])).sum(2)
       
        scores = emission_scores
        scores[:, 1:] += transition_scores
        
        if masks is not None:
            scores = scores * masks.type_as(scores)

        return scores.sum(-1)
        
    def _compute_normalizer(self, emissions, targets=None, masks=None, beam=None):

        eps = 1e-8
        
        beam = beam if beam is not None else self.beam
        batch_size, seq_len = emissions.size()[:2]
        if targets is not None:
            #_emissions = emissions.scatter(2, targets[:, :, None], np.float('inf'))
            _emissions = emissions.scatter(2, targets[:, :, None], float('inf'))
            beam_targets = _emissions.topk(beam, 2)[1]
            beam_emission_scores = emissions.gather(2, beam_targets)
        else:
            beam_emission_scores, beam_targets = emissions.topk(beam, 2)
        beam_transition_score1 = self.E1(beam_targets[:, :-1])  # B x (T-1) x K x D; position i - 1, previous step.
        beam_transition_score2 = self.E2(beam_targets[:, 1:])   # B x (T-1) x K x D; position i, current step.
        beam_transition_matrix = torch.bmm(
            beam_transition_score1.view(-1, beam, self.rank),
            beam_transition_score2.view(-1, beam, self.rank).transpose(1, 2))
        beam_transition_matrix = beam_transition_matrix.view(batch_size, -1, beam, beam)

        # compute the normalizer in the log-space
        score = beam_emission_scores[:, 0]  # B x K
        for i in range(1, seq_len):
            next_score = score[:, :, None] + beam_transition_matrix[:, i-1]
            next_score = logsumexp(next_score, dim=1) + beam_emission_scores[:, i]
            if masks is not None:
                score = torch.where(masks[:, i:i+1], next_score, score)
            else:
                score = next_score

        return logsumexp(score, dim=1)

    def _compute_normalizer2(self, top_logits, top_indices, targets=None, masks=None, beam=None):

        eps = 1e-8
        
        beam = beam if beam is not None else self.beam
        batch_size, seq_len = top_logits.size()[:2]
        beam_emission_scores, beam_targets = top_logits, top_indices
        beam_probs = F.softmax( top_logits, dim = 2 )
        beam_transition_score1 = self.E1(beam_targets[:, :-1])  # B x (T-1) x K x D; position i - 1, previous step.
        beam_transition_score2 = self.E2(beam_targets[:, 1:])   # B x (T-1) x K x D; position i, current step.
        beam_transition_matrix = torch.bmm(
            beam_transition_score1.view(-1, beam, self.rank),
            beam_transition_score2.view(-1, beam, self.rank).transpose(1, 2))
        beam_transition_matrix = beam_transition_matrix.view(batch_size, -1, beam, beam)

        traj_scores = []
        '''
        # compute the normalizer in the log-space
        #score = beam_emission_scores[:, 0]  # B x K
        _score1 = beam_emission_scores[:, 0]  # B x K
        
        for i in range(1, seq_len):
            traj_scores.append(_score1)
            _score = score[:, :, None] + beam_transition_matrix[:, i-1] # bsz, beam, beam
            _score1, _index = _score.max(dim=1) # bsz, beam     bsz, beam 
            _score = _score1 + beam_emission_scores[:, i] # bsz, beam

            #if masks is not None:
            #    score = torch.where(masks[:, i: i+1], _score, score)
            #    index = torch.where(masks[:, i: i+1], _index, dummy)
            #else:
            #    score, index = _score, _index
            score = _score

        '''
        # compute the normalizer in the log-space
        score = beam_emission_scores[:, 0]  # B x K
        
        for i in range(1, seq_len):
            traj_scores.append(score)
            _score = score[:, :, None] + beam_transition_matrix[:, i-1] # bsz, beam, beam
            _score, _index = _score.max(dim=1) # bsz, beam     bsz, beam 
            _score = _score + beam_emission_scores[:, i] # bsz, beam

            #if masks is not None:
            #    score = torch.where(masks[:, i: i+1], _score, score)
            #    index = torch.where(masks[:, i: i+1], _index, dummy)
            #else:
            #    score, index = _score, _index
            score = _score
        
        
        all_scores = traj_scores
        all_scores.append( score )
        all_scores = torch.stack( all_scores, dim = 0 ).transpose( 0, 1 )
        #all_scores = all_scores.transpose(0,1)
        
        #print( "size of all_scores:", all_scores.size() )
        
        denominator1 = torch.logsumexp(all_scores, dim=2).type_as(all_scores)
        
        beam_log_probs = all_scores - denominator1.view(batch_size, seq_len, 1)

        beam_probs = torch.exp(beam_log_probs)

        return beam_probs
    '''
    def _viterbi_decode(self, emissions, masks=None, beam=None):
        beam = beam if beam is not None else self.beam
        batch_size, seq_len = emissions.size()[:2]
        beam_emission_scores, beam_targets = emissions.topk(beam, 2)
        beam_transition_score1 = self.E1(beam_targets[:, :-1])  # B x (T-1) x K x D
        beam_transition_score2 = self.E2(beam_targets[:, 1:])   # B x (T-1) x K x D
        beam_transition_matrix = torch.bmm(
            beam_transition_score1.view(-1, beam, self.rank),
            beam_transition_score2.view(-1, beam, self.rank).transpose(1, 2))
        beam_transition_matrix = beam_transition_matrix.view(batch_size, -1, beam, beam) # bsz, seq_len, beam, beam

        traj_tokens, traj_scores = [], []
        finalized_tokens, finalized_scores = [], []

        # compute the normalizer in the log-space
        score = beam_emission_scores[:, 0]  # B x K
        #print( "score size:", score.size() )
        dummy = torch.arange(beam, device=score.device).expand(*score.size()).contiguous()

        for i in range(1, seq_len):
            traj_scores.append(score)
            _score = score[:, :, None] + beam_transition_matrix[:, i-1] # bsz, beam, beam
            _score, _index = _score.max(dim=1) # bsz, beam     bsz, beam 
            _score = _score + beam_emission_scores[:, i] # bsz, beam

            if masks is not None:
                score = torch.where(masks[:, i: i+1], _score, score)
                index = torch.where(masks[:, i: i+1], _index, dummy)
            else:
                score, index = _score, _index
            traj_tokens.append(index)

        # now running the back-tracing and find the best
        best_score, best_index = score.max(dim=1)
        finalized_tokens.append(best_index[:, None])
        finalized_scores.append(best_score[:, None])

        for idx, scs in zip(reversed(traj_tokens), reversed(traj_scores)):
            previous_index = finalized_tokens[-1]
            finalized_tokens.append(idx.gather(1, previous_index))
            finalized_scores.append(scs.gather(1, previous_index))

        finalized_tokens.reverse()
        finalized_tokens = torch.cat(finalized_tokens, 1)
        finalized_tokens = beam_targets.gather(2, finalized_tokens[:, :, None])[:, :, 0]

        finalized_scores.reverse()
        finalized_scores = torch.cat(finalized_scores, 1)
        finalized_scores[:, 1:] = finalized_scores[:, 1:] - finalized_scores[:, :-1]

        return finalized_scores, finalized_tokens

   
    '''
    def _viterbi_decode(self, emissions, targets, top_indices = None, masks=None, beam=None):

        beam = beam if beam is not None else self.beam
        batch_size, seq_len = emissions.size()[:2]
        
        if self.use_repeat_logits_half == True:
            penalty = 1.2
            scores, preds = torch.max( emissions, 2 )
            masks = emissions == scores[:,:,None]
            masks = masks.permute( 1, 0, 2 )
            new_mask = torch.zeros( (  masks.size(1), masks.size(2)), device = emissions.device, dtype=torch.bool )
            new_masks = torch.zeros( ( masks.size(0), masks.size(1), masks.size(2)), device = emissions.device, dtype=torch.bool )
            for i, mask in enumerate( masks ):
                new_mask = torch.logical_or( mask,  new_mask  )
                new_masks[i] = new_mask
            new_masks = new_masks.transpose(0,1)
            first_true_mask = ( new_masks.int().cumsum(dim = 1 ) == 1 ) & new_masks
            new_masks = new_masks & ( ~first_true_mask )

            p_masks = emissions > 0
            m_masks = emissions < 0
            p_new_masks = p_masks & new_masks
            m_new_masks = p_masks & new_masks
            emissions2 = emissions.clone()
            emissions2[p_new_masks] = emissions[p_new_masks] / penalty
            emissions2[m_new_masks] = emissions2[m_new_masks] * penalty
            masks = None
            if top_indices == None:
                beam_emission_scores, beam_targets = emissions2.topk(beam, 2)
            else:
                beam_emission_scores = torch.gather( emissions2, -1, top_indices )
                beam_targets = top_indices
        else:
            if top_indices == None:
                beam_emission_scores, beam_targets = emissions.topk(beam, 2)
            else:
                beam_emission_scores = torch.gather( emissions, -1, top_indices )
                beam_targets = top_indices
        
        beam_transition_score1 = self.E1(beam_targets[:, :-1])  # B x (T-1) x K x D
        beam_transition_score2 = self.E2(beam_targets[:, 1:])   # B x (T-1) x K x D
        beam_transition_matrix = torch.bmm(
            beam_transition_score1.view(-1, beam, self.rank),
            beam_transition_score2.view(-1, beam, self.rank).transpose(1, 2))
        beam_transition_matrix = beam_transition_matrix.view(batch_size, -1, beam, beam) # bsz, seq_len, beam, beam

        traj_tokens, traj_scores = [], []
        finalized_tokens, finalized_scores = [], []

        # compute the normalizer in the log-space
        score = beam_emission_scores[:, 0]  # B x K
        #print( "score size:", score.size() )
        dummy = torch.arange(beam, device=score.device).expand(*score.size()).contiguous()

        for i in range(1, seq_len):
            traj_scores.append(score)
            _score = score[:, :, None] + beam_transition_matrix[:, i-1] # bsz, beam, beam
            _score, _index = _score.max(dim=1) # bsz, beam     bsz, beam 
            _score = _score + beam_emission_scores[:, i] # bsz, beam

            if masks is not None:
                score = torch.where(masks[:, i: i+1], _score, score)
                index = torch.where(masks[:, i: i+1], _index, dummy)
            else:
                score, index = _score, _index
            traj_tokens.append(index)

        #print( "traj_scores:", traj_scores )
        
        all_scores = traj_scores
        #print( "all_scores:", all_scores )
        all_scores.append( score )
        #print( "all_scores:", all_scores )
        all_scores = torch.stack( all_scores, dim = 0 ).transpose( 0, 1 )
        #print( "size of all_scores:", all_scores.size() )
        
        # now running the back-tracing and find the best
        best_score, best_index = score.max(dim=1)
        finalized_tokens.append(best_index[:, None])
        finalized_scores.append(best_score[:, None])

        for idx, scs in zip(reversed(traj_tokens), reversed(traj_scores)):
            previous_index = finalized_tokens[-1]
            finalized_tokens.append(idx.gather(1, previous_index))
            finalized_scores.append(scs.gather(1, previous_index))

        finalized_tokens.reverse()
        finalized_tokens = torch.cat(finalized_tokens, 1)
        finalized_tokens = beam_targets.gather(2, finalized_tokens[:, :, None])[:, :, 0]

        finalized_scores.reverse()
        finalized_scores = torch.cat(finalized_scores, 1)
        finalized_scores[:, 1:] = finalized_scores[:, 1:] - finalized_scores[:, :-1]
        if self.crf_coef != 0.0:
            if self.use_repeat_logits_half == True:
                numerator = self._compute_score(emissions2, targets)
                denominator = self._compute_normalizer(emissions2, targets)
            else:
                numerator = self._compute_score(emissions, targets)
                denominator = self._compute_normalizer(emissions, targets)
            crf_loss = - ( numerator - denominator ).mean() / seq_len
        else:
            crf_loss = torch.zeros( (1), device = emissions.device, dtype = torch.float )
        
        top_probs = F.softmax( all_scores, dim = 2 )
        
        return finalized_scores, finalized_tokens, top_probs, beam_targets, crf_loss   

In [4]:
class TopLayer(nn.Module):
    def __init__(self, vocab_size, embed_dim, crf_low_rank, crf_beam_size, dropout, padding_idx, use_repeat_logits_half=False,
                crf_coef = 1.0 ):
        super(TopLayer, self).__init__()

        self.embed_dim = embed_dim
        self.vocab_size = vocab_size
        self.dropout = dropout
        self.padding_idx = padding_idx
        print( "in TopLayer:" )
        self.crf_layer = DynamicCRF(num_embedding = vocab_size, low_rank = crf_low_rank, 
                                    beam_size = crf_beam_size, use_repeat_logits_half=use_repeat_logits_half,
                                   crf_coef=crf_coef)

        #self.one_more_layer_norm = nn.LayerNorm(embed_dim)
        #self.tgt_word_prj = nn.Linear(self.embed_dim, self.vocab_size)
        ## gae 学習用
        #self.linear_critical = nn.Linear(crf_beam_size, 1 )

    #def forward(self, src_representation, src_input, tgt_input, is_training):
    def forward(self, src_representation, top_logits, top_indices, src_input, tgt_input, is_training ):
        '''
            src_representation : bsz x seqlen x embed_dim
            src_input : bsz x seqlen
            tgt_input : bsz x seqlen
        '''
        #assert src_input.size() == tgt_input.size()

        src_input = src_input.transpose(0, 1) # src_len x bsz
        #seqlen, bsz = src_input.size()
        seqlen, bsz = src_input.shape[:2]

        src_representation = F.dropout(src_representation, p=self.dropout, training=is_training)
        src_representation = src_representation.transpose(0, 1) # seqlen x bsz x embed_dim

        src = src_representation

        #emissions = self.tgt_word_prj(src.contiguous().view(-1, self.embed_dim)).view(seqlen, bsz, self.vocab_size)
        emissions = src_representation
        #log_probs = torch.log_softmax(emissions, -1)
        #assert log_probs.size() == torch.Size([seqlen, bsz, self.vocab_size])

        emissions = emissions.transpose(0, 1) # [bsz x src_len x vocab_size]
        #emission_mask = ~tgt_input.eq(self.padding_idx) # [bsz x src_len] #pad のところは 0 padでないところが 1
        emission_mask = torch.ones_like( tgt_input, dtype=torch.bool ) #全部　pad でないとして 1
        batch_crf_loss, top_probs = self.crf_layer(emissions, top_logits, top_indices, tgt_input, emission_mask) # [bsz]
        #critical_value = self.linear_critical( top_probs )
        #critical_value = torch.zeros( ( 1,1,1) )
        batch_crf_loss = - batch_crf_loss
        assert batch_crf_loss.size() == torch.Size([bsz])
        return batch_crf_loss, top_probs

    def decoding(self, src_representation, src_input):
        '''
            src_representation : bsz x seqlen x embed_dim
            src_input : bsz x seqlen
            tgt_input : bsz x seqlen
        '''
        src_input = src_input.transpose(0, 1) # src_len x bsz
        seqlen, bsz = src_input.size()

        src_representation = src_representation.transpose(0, 1) # seqlen x bsz x embed_dim
        src = src_representation

        emissions = self.tgt_word_prj(src.contiguous().view(-1, self.embed_dim)).view(seqlen, bsz, self.vocab_size)

        emissions = emissions.transpose(0, 1) # [bsz, seqlen, vocab_size]
        _, finalized_tokens = self.crf_layer.forward_decoder(emissions)
        assert finalized_tokens.size() == torch.Size([bsz, seqlen])
        return finalized_tokens

    def length_ratio_decoding(self, src_representation, src_input, length_ratio):
        '''
            src_representation : 1 x seqlen x embed_dim
            src_input : 1 x seqlen
        '''
        src_input = src_input.transpose(0, 1) # src_len x bsz
        seqlen, bsz = src_input.size()

        src_representation = src_representation.transpose(0, 1) # seqlen x bsz x embed_dim
        src = src_representation

        emissions = self.tgt_word_prj(src.contiguous().view(-1, self.embed_dim)).view(seqlen, bsz, self.vocab_size)

        emissions = emissions.transpose(0, 1) # [bsz, seqlen, vocab_size]
        valid_len = int(seqlen * length_ratio) + 1
        valid_emissions = emissions[:, :valid_len+1,:]
        _, finalized_tokens = self.crf_layer.forward_decoder(valid_emissions)
        return finalized_tokens


In [5]:
class CaptioningTransformer(nn.Module):
    
    #CaptioningTransformerのコンストラクタ
    #dim_embedding  : 埋め込み次元
    #dim_feedforward: FNNの中間特徴次元
    #num_heads      : マルチヘッドアテンションのヘッド数
    #num_layers     : Transformerデコーダ層の数
    #vocab_size     : 辞書の次元
    #null_index     : NULLのID
    #dropout        : ドロップアウト確率
    
    def __init__(self, img_size: int,  dim_embedding: int, length_max: int, vocab_size: int, tokenizer, dropout: float = 0.0, \
                 pad_token_id: int=0, use_repeat_logits_half=False, crf_coef = 1.0):
        super().__init__()

        #CLIP
        model_id = "openai/clip-vit-large-patch14-336"
        self.clip_model = CLIPVisionModel.from_pretrained(model_id )
        memory = self.clip_model( torch.randn( 1, 3, 336, 336 ) )
        memory = memory.last_hidden_state
        img_length = memory.size(1)
        clip_dim = memory.size(2)
        self.connector_pool = nn.AdaptiveAvgPool1d(length_max - 1 )
        self.connector_ln = nn.LayerNorm( clip_dim )
        self.connector_linear1 = nn.Linear( clip_dim, dim_embedding )
        self.connector_gleu = nn.GELU()
        self.connector_linear2 = nn.Linear( dim_embedding, dim_embedding )

       
        # Connector
        self.connector_pool = nn.AdaptiveAvgPool1d(length_max - 1 )
       # Down Sampling
        cls_token = memory[:, :1, :] # (bsz, 1, 1024)
        patch_tokens = memory[:, 1:, :] # (bsz, 576, 1024)
        # パッチ部分を 576 -> 96 に圧縮
        patch_tokens = patch_tokens.transpose(1, 2) # (bsz, 1024, 576)
        patch_tokens = self.connector_pool(patch_tokens)
        patch_tokens = patch_tokens.transpose(1, 2) # (bsz, 96, 1024)
        # CLSと結合して合計 97 トークンにする
        memory = torch.cat([cls_token, patch_tokens], dim=1) # (bsz, 97, 1024)

        self.pos_emb = PositionalEmbedding( dim_embedding )

        model_id = "google-bert/bert-large-uncased"
        self.bert = BertModel.from_pretrained( model_id )

        ## 単語出力分布計算
        self.ln_outputs = nn.LayerNorm( dim_embedding )
        self.linear = nn.Linear(dim_embedding, vocab_size)

        crf_low_rank = 32
        crf_beam_size = 256
        self.crf_beam_size = crf_beam_size
        top_dropout = 0.0
        tgt_padding_idx = tokenizer.pad_token_id
        print( "initialize self.toplayer" )
        self.toplayer = TopLayer( vocab_size, dim_embedding, crf_low_rank, crf_beam_size, top_dropout, 
                                  tgt_padding_idx, use_repeat_logits_half=use_repeat_logits_half, crf_coef = crf_coef )

        ### GAE 用
        #self.ln_critical2 = nn.LayerNorm( crf_beam_size )
        #self.linear_critical2 = nn.Linear( crf_beam_size, 1)
        
        self.dim_embedding = dim_embedding


    # CaptioningTransformerの順伝播処理
    #features: 画像特徴量 [バッチサイズ, 埋め込み次元]
    #captions: 正解キャプション [バッチサイズ, 系列長]
    
    def forward(self, images: torch.Tensor ):

        self.device = images.device

        memory = self.clip_model( images ).last_hidden_state

        memory = self.mlp_connector( memory )

        #print( "memory size:", memory.size() )

        memory += self.pos_emb( memory )

        outputs = self.bert( inputs_embeds = memory ).last_hidden_state
        logits = self.ln_outputs( outputs )
        logits = self.linear( logits )
        #critical_value = self.ln_critical( outputs )
        #critical_value = self.linear_critical( critical_value )
        #masks = torch.isnan( critical_value )
        #print( critical_value[masks] )
        #critical_value = critical_value / ( logits.size(2) ** 2.0 )
        #critical_value = torch.zeros( ( 1, 1, 1 ) )
        
        return logits

    def mlp_connector(self, memory ):

        cls_token = memory[:, :1, :] # (bsz, 1, 1024)
        patch_tokens = memory[:, 1:, :] # (bsz, 576, 1024)

        # パッチ部分を 576 -> 96 に圧縮
        patch_tokens = patch_tokens.transpose(1, 2) # (bsz, 1024, 576)
        patch_tokens = self.connector_pool(patch_tokens)
        patch_tokens = patch_tokens.transpose(1, 2) # (bsz, 96, 1024)

        # CLSと結合して合計 97 トークンにする
        memory = torch.cat([cls_token, patch_tokens], dim=1) # (bsz, 97, 1024)

        memory = self.connector_ln( memory )
        memory = self.connector_linear1( memory )
        memory = self.connector_gleu( memory )
        memory = self.connector_linear2( memory )
        
        return memory

    def inference(self, images: torch.Tensor, targets: torch.Tensor, top_indices = None ):

        self.device = images.device
        
        memory = self.clip_model( images ).last_hidden_state
        memory = self.mlp_connector( memory )
        memory += self.pos_emb( memory )
        
        outputs = self.bert( inputs_embeds = memory ).last_hidden_state
        outputs = self.ln_outputs( outputs )
        emissions = self.linear( outputs )

        ##emissions = emissions.transpose(0, 1) # [bsz, seqlen, vocab_size]
        #seqlen = emissions.size(1) #added by T.uchi
        #length_ratio = 1.0         #added by T.uchi
        ##length_ratio = 0.7         #added by T.uchi
        #valid_len = int(seqlen * length_ratio) + 1
        #valid_emissions = emissions[:, :valid_len+1,:]
        finalized_scores, finalized_tokens, top_probs, top_indices, crf_loss  = \
            self.toplayer.crf_layer._viterbi_decode(emissions, targets, top_indices)
        
        #critical_value = self.ln_critical2( top_probs )
        #critical_value = self.linear_critical2( critical_value )

        #return finalized_scores, finalized_tokens, top_probs, top_indices, critical_value[:,:,0], crf_loss, emissions
        return finalized_scores, finalized_tokens, top_probs, top_indices, crf_loss, emissions


In [6]:
class MyDataset(Dataset):
    def __init__(self, file_path: str, img_directory: str, transforms, tokenizer, length_max = None ) -> None:
        super().__init__()
        self.img_directory = img_directory
        self.transforms = transforms
        # TODO: fix to original data
        #画像の前処理
        self.img_file = []
        self.tokens = []
        #vocab_size = len( tokenizer )
        #c1 = torch.zeros( ( vocab_size ) )
        #c2 = torch.zeros( ( vocab_size, vocab_size ) )
        if length_max == None:
            self.length_max = 0
        else:
            self.length_max = length_max
        length_sum = 0

        with open( file_path, 'rb') as f:
            data = pickle.load(f)
        for i, line_data in enumerate( data ):
            if i % 100000 == 0:
                print( "i:", i )
            self.img_file.append( line_data['img_file'] )
            id_tokens = line_data['id_tokens']
            id_tokens.append( eos_token_id )
            id_tokens.append( eos_token_id )
            length_sum += len( id_tokens )
            if length_max != None:
                id_tokens = torch.tensor( id_tokens )[:self.length_max]
            else:
                if self.length_max < len( id_tokens ):
                    self.length_max = len( id_tokens )
                id_tokens = torch.tensor( id_tokens )
            self.tokens.append( id_tokens )
        # w1, w2 を作る時は length_max = None　でお願いします。
        #    for i2 in range( len(id_tokens) ):
        #        if i2 == len( id_tokens ) - 1:
        #            c1[id_tokens[i2]] += 1
        #        else:
        #            c1[id_tokens[i2]] += 1
        #            c2[id_tokens[i2], id_tokens[i2+1] ] += 1
        '''
        c1avg = int( torch.sum( c1 ) / torch.sum( torch.ne( c1, 0 ).int()) )
        c2avg = int( torch.sum( torch.sum( c2, dim = 1 ), dim = 0 ) / torch.sum( torch.ne( c2, 0 ).int() ) )

        c1[0] = c1avg

        c2[:,0] = c2avg
        c2[0,:] = c2avg
        
        sumc1 = torch.sum( c1, dim = 0 )
        sumc2 = torch.sum( torch.sum( c2, dim = 1 ), dim = 0 )

        prob1 = c1 / sumc1
        prob2 = c2 / sumc2

        self.w1 = prob1 ** -0.4
        self.w1 = torch.nan_to_num( self.w1, nan = 0.0, posinf=0.0, neginf=0.0 )
        avg1 = torch.sum( self.w1, dim = 0 ) / torch.sum( torch.ne( self.w1, 0.0 ).int() )
        self.w1 = self.w1 / avg1

        self.w2 = prob2 ** -0.4
        self.w2 = torch.nan_to_num( self.w2, nan = 0.0, posinf=0.0, neginf=0.0 )
        avg2 = torch.sum( torch.sum( self.w2, dim = 1 ), dim = 0 ) / torch.sum( torch.ne( self.w2, 0.0 ).int() )
        self.w2 = self.w2 / avg2

        with open( "/mnt/ssd2/v7/w_unigrma.pkl", mode="wb" ) as f:
            pickle.dump( self.w1, f )

        with open( "/mnt/ssd2/v7/w_bigrma.pkl", mode="wb" ) as f:
            pickle.dump( self.w2, f )
        
        '''

        #with open( "/mnt/ssd2/v7/w_unigram.pkl", 'rb') as f:
        #    self.w1 = pickle.load(f)

        #with open( "/mnt/ssd2/v7/w_bigram.pkl", 'rb') as f:
        #    self.w2 = pickle.load(f)
        
        if length_max == None:
            print( "length max:", self.length_max )
            print( "avg length:", length_sum / len( self.tokens ) )
    
    # ここで取り出すデータを指定している
    def __getitem__(
        self,
        index: int
    ):
        tokens = self.tokens[index]
        img_file = self.img_file[index] + ".jpg"
        img_path = os.path.join( self.img_directory, img_file ) #index番目の画像のパスを取得
        img = Image.open(img_path) #PIL形式で画像を読み込み
        if img.mode != 'RGB':
            img = img.convert("RGB")
        img = self.transforms(img)
        
        return img, tokens

    # この method がないと DataLoader を呼び出す際にエラーを吐かれる
    def __len__(self) -> int:
        return len(self.tokens)

    def length_max(self):
        return self.length_max

    #def w1(self):
    #    return self.w1

    #def w2(self):
    #    return self.w2

In [7]:
def collate_func(batch: Sequence[Tuple[Union[torch.Tensor, str]]], pad_index, length_max ):
    imgs, tokens = zip(*batch)

    max_length = length_max
    #max_length = 0
    #for target in tokens:
    #    if max_length < len( target ):
    #        max_length = len( target )
    
    targets = []
    lengths = []
    for target in tokens:
        pad_len = max_length - len( target ) 
        #print( "target:", target )
        input2= F.pad( target, (0, pad_len), mode='constant', value = pad_index)
        targets.append( input2 )
        lengths.append( len( target ) )
    
    imgs = torch.stack( imgs, dim = 0 )
    targets = torch.stack( targets, dim = 0 )
    lengths = torch.tensor( lengths, requires_grad = False  )

    #if imgs.dim() != 4:
    #   print( "in collate imgs size:", imgs.size() )
    
    return imgs, targets, lengths

###学習におけるハイパーパラメータやオプションの設定

In [8]:
model_id = "google-bert/bert-large-uncased"
tokenizer = BertTokenizer.from_pretrained( model_id )
sos_token_id = tokenizer.encode( [ "[unused0]" ] )[1]
eos_token_id = tokenizer.encode( [ "[unused1]" ] )[1]

print( "sos_tokeni_d:", sos_token_id )
print( "eos_token_id:", eos_token_id )

class ConfigTrain(object):
    '''
    ハイパーパラメータ、システム共通変数の設定
    '''
    def __init__(self):

        # ハイパーパラメータ
        self.img_size = 336
        self.dim_embedding = 1024   # 埋め込み層の次元
        self.length_max = 97
        #self.lr = 5e-5            # 学習率
        #self.lr = 2e-5            # 学習率
        self.lr_clip = 2e-7
        self.lr_con = 1e-4
        self.lr_bert = 2e-5            # 学習率
        #self.lr_cri = 1e-4
        self.lr_others = 1e-4
        #self.lr_top = 1e-4
        #self.lr = 5e-6            # 学習率
        self.dropout = 0.0         # dropout確率
        #self.batch_size = 128       # ミニバッチ数
        self.batch_size = 40       # ミニバッチ数
        #self.batch_size = 32       # ミニバッチ数
        #self.batch_size = 16       # ミニバッチ数
        #self.batch_size = 8       # ミニバッチ数
        #self.batch_size = 4       # ミニバッチ数
        #self.batch_size = 1       # ミニバッチ数
        #self.num_epochs = 100       # エポック数→Colab無料版でテストする際は10未満に修正を推奨
        #self.num_epochs = 100       # エポック数→Colab無料版でテストする際は10未満に修正を推奨
        #self.num_epochs = 60       # エポック数→Colab無料版でテストする際は10未満に修正を推奨
        self.num_epochs = 5       # エポック数→Colab無料版でテストする際は10未満に修正を推奨
        self.use_amp = True
        #self.use_amp = False
        #self.use_saved_pth = True
        self.use_saved_pth = False
        self.model_id = "google-bert/bert-large-uncased"
        self.vocab_size = len( tokenizer )
        self.weight_decay = 0.1
        self.betas = (0.9, 0.999 )
        self.warmup = 0.2
        #self.alpha = 1.0
        self.crf_coef = 1.0
        self.ce_coef = 0.3
        self.use_repeat_logits_half = False
        self.window_size = 3
        
        # パスの設定
        self.img_directory = '/mnt/ssd2/v7/img'
        self.anno_file = '/mnt/ssd2/v7/data.pkl'
        self.save_directory = './model'

        # 検証に使う学習セット内のデータの割合
        self.test_ratio = 0.1
        self.val_ratio = 0.1
        #self.val_ratio = 0.0004
        #self.test_ratio = 0.0004
        
        # 学習に使うデバイス
        #self.device = 'cuda'
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        #self.device = 'cpu'
        
        # データローダーに使うCPUプロセスの数
        #self.num_workers = 4
        self.num_workers = 0 if self.device == torch.device('cpu') else 12
        #self.num_workers = 0
        
        # 移動平均で計算する損失の値の数
        self.moving_avg = 100

sos_tokeni_d: 1
eos_token_id: 2


In [9]:
#config = ""
#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
## 辞書（単語→単語ID）の読み込み
#with open('../PreTrain_Decoder/translateDatasetNTT_blank4_pad0/word_to_id2.pkl', 'rb') as f:
#    word_to_id = pickle.load(f)
#max_idx_en = len( word_to_id )
#word_to_id['<mask>'] = max_idx_en
#mask_value = word_to_id['<mask>']
#start_idx = word_to_id['<start>']
#bert_model_path = 'models--google-bert--bert-large-uncased/snapshots/6da4b6a26a1877e173fca3225479512db81a5e5b'
#tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path = bert_model_path )
model_id = "google-bert/bert-large-uncased"
tokenizer = BertTokenizer.from_pretrained(model_id)
model = CaptioningTransformer( img_size = 336, dim_embedding=1024, length_max = 97, vocab_size=len(tokenizer),
                 tokenizer=tokenizer, dropout=0.1, model_id =model_id).to(device)

#images = torch.randint( 0, 255, size = (10,3,256,256) )
images = torch.randn( ( 1, 3, 336,336 ) )
outputs = model( images )

print( outputs.size() )

img_length: 577
text_length_max: 97
stride: 6
bert in memory size: torch.Size([1, 97, 1024])
initialize self.toplayer
in TopLyaer:
torch.Size([1, 97, 30522])


### 学習率スケジューラ

### 学習を行う関数

In [9]:
def calc_loss_nll( train_log_prob_matrix, train_batch_tgt ):
    bsz, tgt_len = train_batch_tgt.size()
    
    train_log_prob_matrix_unbind = torch.unbind(train_log_prob_matrix, dim = 0)
    assert len(train_log_prob_matrix_unbind) == tgt_len
    train_truth_tgt_unbind = torch.unbind(train_batch_tgt, dim = 1)
    assert len(train_truth_tgt_unbind) == tgt_len

    nll_loss_list = []

    for index in range(tgt_len):
        curr_step_log_prob = train_log_prob_matrix_unbind[index]
        curr_step_tgt = train_truth_tgt_unbind[index].view(bsz)
        #one_nll_loss = NLL(curr_step_log_prob, curr_step_tgt)
        one_nll_loss = criterion_nll(curr_step_log_prob, curr_step_tgt)
        assert one_nll_loss.size() == torch.Size([bsz])
        nll_loss_list.append(one_nll_loss)

    nll_loss_matrix = torch.stack(nll_loss_list, dim = 1)
    #assert nll_loss_matrix.size() == torch.Size([bsz, tgt_len])
    #tgt_padding_matrix = ~train_batch_tgt.eq(data.tgt_padding_idx)
    #tgt_padding_matrix = tgt_padding_matrix.type(nll_loss_matrix.type())
    #assert tgt_padding_matrix.size() == nll_loss_matrix.size()
    #nll_loss_matrix = nll_loss_matrix * tgt_padding_matrix

    train_nll_loss = nll_loss_matrix.sum(-1)
    #assert train_batch_crf_loss.size() == train_nll_loss.size()

    return train_nll_loss

In [10]:
model = CaptioningTransformer( 336,
    97, 1024, len(tokenizer),
    tokenizer, 0.0)

# 損失関数の定義
log_softmax = nn.LogSoftmax( dim = 2 )
criterion_nll = nn.NLLLoss( reduction = 'none' )

imgs = torch.rand( ( 1, 3, 336,336), requires_grad = True )

outputs = model( imgs )

src_representation = outputs
src_input = outputs
captions = torch.randint( 0, len(tokenizer), size=(1,97) )
tgt_input = captions
train_batch_crf_loss, topk_probs, topk_indicies = model.toplayer( src_representation, src_input, tgt_input, is_training = True )
print( "train_batch:", train_batch_crf_loss.grad_fn )
print( "topk_probs:", topk_probs.grad_fn )
a = torch.mean( train_batch_crf_loss ) # mean over bsz
log_probs = log_softmax( outputs ).transpose(0,1)
train_nll_loss = calc_loss_nll( log_probs, captions )
b = 0.3 * torch.mean( train_nll_loss )
loss = a + b
print( "a:", a.grad_fn )
print( "b:", b.grad_fn )

initialize self.toplayer
in TopLayer:


RuntimeError: The size of tensor a (1024) must match the size of tensor b (97) at non-singleton dimension 2

In [9]:
config = ConfigTrain()

model_id = "google-bert/bert-large-uncased"
tokenizer = BertTokenizer.from_pretrained(model_id)

# 辞書サイズを保存
vocab_size = len( tokenizer )

# モデル出力用のディレクトリを作成
os.makedirs(config.save_directory, exist_ok=True)

# 画像のtransformsを定義
transforms = v2.Compose([
    v2.Resize((336, 336)),
    v2.AutoAugment(),
    #v2.ToTensor(),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    ## Coco データセット 2017 train の平均と標準偏差
    #v2.Normalize((0.456,0.427,0.401),(0.224,0.219,0.231) )
    # ImageNetデータセットの平均と標準偏差
    #v2.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    # Clip Model の config から引用。
    v2.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])

# v7 データセット
train_dataset = MyDataset( file_path=config.anno_file,
                           img_directory = config.img_directory,
                           transforms=transforms,tokenizer=tokenizer, length_max = config.length_max)

# Subset samplerの生成
test_set, val_set, train_set = util.generate_subset_test_val_train(
    train_dataset, config.test_ratio, config.val_ratio )
    
# 学習時にランダムにサンプルするためのサンプラー
train_sampler = SubsetRandomSampler(train_set)

# DataLoaderを生成
collate_func_lambda = lambda x: collate_func(x, tokenizer.pad_token_id, config.length_max)
train_loader = torch.utils.data.DataLoader(
                    train_dataset,
                    batch_size=config.batch_size,
                    num_workers=config.num_workers,
                    sampler=train_sampler,
                    collate_fn=collate_func_lambda)
val_loader = torch.utils.data.DataLoader(
                    train_dataset,
                    batch_size=config.batch_size,
                    num_workers=config.num_workers,
                    sampler=val_set,
                    collate_fn=collate_func_lambda)
test_loader = torch.utils.data.DataLoader(
                    train_dataset,
                    #batch_size=config.batch_size,
                    batch_size=1,
                    num_workers=config.num_workers,
                    sampler=test_set,
                    collate_fn=collate_func_lambda)


print( "config.device:", config.device )
print( "学習セット数:",len( train_loader ) )
print( "評価セット数:",len( val_loader ))
print( "テストセット数:",len( test_loader ))
print( "use_amp:", config.use_amp )
print( "use_saved_pth:", config.use_saved_pth )

# モデルの定義
model = CaptioningTransformer( config.img_size,
    config.dim_embedding, config.length_max, config.vocab_size,
    tokenizer, config.dropout, pad_token_id = tokenizer.pad_token_id,
    use_repeat_logits_half = config.use_repeat_logits_half,
    crf_coef = config.crf_coef )
model.to(config.device)
#crf_low_rank = 32
#crf_beam_size = 256
#top_dropout = 0.0
#tgt_padding_idx = tokenizer.pad_token_id
#toplayer = TopLayer( vocab_size, dim_embedding, crf_low_rank, crf_beam_size, top_dropout, tgt_padding_idx )


# 損失関数の定義
#criterion = nn.CrossEntropyLoss( ignore_index = tokenizer.pad_token_id, reduction = 'mean' )
#criterion = nn.CrossEntropyLoss( reduction = 'mean' )
log_softmax = nn.LogSoftmax( dim = 2 )
#softmax = nn.Softmax( dim = 2 )
criterion_nll = nn.NLLLoss( reduction = 'none' )
#criterion_nll = nn.NLLLoss( )
#criterion = nn.CrossEntropyLoss()

# 最適化手法の定義
# 最適化手法の定義
# Optimizerの生成, clipとそうでないモジュールとの
# パラメータで異なる学習率を適用
params_clip = []
params_bert = []
params_con = []
#params_cri = []
params_others = []
for name, parameter in model.named_parameters():
    if parameter.requires_grad:
        if 'clip_model' in name:
            params_clip.append(parameter)
        elif 'connector' in name:
            params_con.append(parameter)
        elif 'bert' in name and 'critical' not in name:
            params_bert.append(parameter)
        #elif 'critical' in name:
        #    params_cri.append(parameter)
        else:
            params_others.append(parameter)
param_groups = [
    {'params': params_clip, 'lr': config.lr_clip},
    {'params': params_con, 'lr': config.lr_con},
    {'params': params_bert, 'lr': config.lr_bert},
    #{'params': params_cri, 'lr': config.lr_cri},
    {'params': params_others, 'lr': config.lr_others}]
#optimizer = torch.optim.AdamW( model.parameters() , lr=config.lr)
#optimizer = torch.optim.AdamW( param_groups, weight_decay = config.weight_decay, betas=config.betas )
optimizer = torch.optim.AdamW( param_groups, weight_decay = config.weight_decay, betas=config.betas )

# 全ステップ数
num_global_steps = len( train_loader ) * config.num_epochs
print( "num_global_steps:", num_global_steps )
num_warmup_steps = num_global_steps * config.warmup
print( "num_warmup_steps:", num_warmup_steps )
#スケジューラーの定義
scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps, num_global_steps )    
#t_scheduler = get_linear_schedule_with_warmup( t_optimizer, num_warmup_steps, num_global_steps )    


PATH = "model/model_bert_large_NAR_PAD_curr.pth"
print( "use_saved_pth:", config.use_saved_pth )
print( "exist saved_pth:", os.path.isfile(PATH) ) 
use_saved_pth = config.use_saved_pth
if use_saved_pth and os.path.isfile(PATH):
    checkpoint = torch.load(PATH)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    #device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    ## optimizerのstateを現在のdeviceに移す。これをしないと、保存前後でdeviceの不整合が起こる可能性がある。
    #for state in optimizer.state.values():
        #for k, v in state.items():
            #if isinstance(v, torch.Tensor):
                #state[k] = v.to(device)
    begin_epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    global_step = checkpoint['global_step']    
else:
    begin_epoch = 0
    global_step = 0

print( "begin_epoch:", begin_epoch )
print( "global_ste:", global_step )

len_tr_loader = len( train_loader )
#train_param = len_tr_loader // 6
train_param = len_tr_loader // 100
len_val_loader = len( val_loader )
#train_param = len_val_loader // 3
val_param = len_val_loader // 3
print( "train_param:", train_param )
print( "val_param:", val_param )

print( "epochs:", config.num_epochs )
print( "batch_size:", config.batch_size )
print( "lr_clip:", config.lr_clip )
print( "lr_con:", config.lr_con )
print( "lr_bert:", config.lr_bert )
print( "lr_others:", config.lr_others )
print( "weight_decay:", config.weight_decay )
print( "betas:", config.betas )
print( "crf_coef:", config.crf_coef )
print( "ce_coef:", config.ce_coef )
print( "use_repeat_logits_half:", config.use_repeat_logits_half )

# 学習経過の書き込み
now = datetime.datetime.now()
train_loss_file = '{}/MyOriginal_train_loss_{}.csv'\
    .format(config.save_directory, now.strftime('%Y%m%d_%H%M%S'))
with open(train_loss_file, 'a') as f:
    print(f'{len_tr_loader}', file=f) 
print( "train_loss_file:", train_loss_file )
val_loss_file = '{}/MyOriginal_val_loss_{}.csv'\
    .format(config.save_directory, now.strftime('%Y%m%d_%H%M%S'))
with open(val_loss_file, 'a') as f:
    print(f'{len_val_loader}', file=f) 
norm_file = '{}/norm_{}.csv'\
    .format(config.save_directory, now.strftime('%Y%m%d_%H%M%S'))

# 学習
val_loss_best = float('inf')

fn = bleu_score.SmoothingFunction().method7

# AMP用のスケーラー
scaler = GradScaler(enabled=config.use_amp)

pad_token_id = tokenizer.pad_token_id

for epoch in range(config.num_epochs):
    with tqdm(train_loader) as pbar:
    #with tqdm(val_loader) as pbar:
        pbar.set_description(f'[エポック {epoch + 1}]')

        # 学習モードに設定
        model.train()

        train_losses = deque()
        train_a = deque()
        train_b = deque()
        train_errors = deque()
        train_bleus = deque()
        for n_batch, (imgs, captions, caption_lengths) in enumerate( pbar ):
            # ミニバッチを設定
            imgs = imgs.to(config.device)
            captions = captions.to(config.device)
                
            optimizer.zero_grad()

            # 最後の単語から次を予測する必要はないため最後の単語を除外
            with autocast(str(config.device),enabled=config.use_amp):
                #logits = model( imgs )

                # 損失の計算
                # 単語軸が第1軸である必要があるため、転置
                #train_batch_crf_loss = model.toplayer( src_representation, src_input, tgt_input, is_training = True )
                #train_batch_crf_loss, topk_probs, topk_indicies = model.toplayer( src_representation, src_input, tgt_input, is_training = True )
                #top_logits, top_indices = torch.topk( logits, 256, dim = -1 )
                #train_batch_crf_loss, topk_probs, topk_indicies = model.toplayer( logits, top_logits, top_indices, \
                #                                                                  top_logits, captions, is_training = True )
                finalized_scores, finalized_tokens, top_probs, top_indices, crf_loss, emissions = \
                    model.inference( imgs, captions, top_indices = None )
                a = config.crf_coef *crf_loss
                #log_probs = log_softmax( logits ).transpose(0,1)
                #train_nll_loss = calc_loss_nll( log_probs, captions )
                b = config.ce_coef * nn.CrossEntropyLoss()( emissions.transpose(1,2), captions )
                loss = a + b

            #hypo_ids = torch.argmax( logits, dim = 2 )
            hypo_ids = finalized_tokens
            
            scaler.scale(loss).backward()
            #scaler.unscale_(optimizer)
            #clip_grad_threshold = 1.0
            #torch.nn.utils.clip_grad_norm_(\
            #        model.parameters(),
            #        clip_grad_threshold)
            # オプティマイザにより，パラメータを更新する
            scaler.step(optimizer)
            scaler.update()            
            scheduler.step()

            #for name, param in model.named_parameters():
            #    print( name )
            
            norm0 = torch.sqrt( torch.norm( model.clip_model.vision_model.encoder.layers[0].self_attn.q_proj.weight.grad, p = 2 ) ).item()
            norm1 = torch.sqrt( torch.norm( model.bert.encoder.layer[23].attention.self.query.weight.grad, p = 2 ) ).item()
            norm_mean = torch.mean( torch.stack ([ torch.sqrt( torch.norm( param.grad, p = 2 ) ) \
                                                  for param in model.parameters() if param.grad is not None ] ) ).item()
            with open(norm_file, 'a') as f:
                print( "epcoch:", epoch, ", step:", global_step, ", norm0:", norm0, ", norm1:", norm1, ", norm_mean:", norm_mean, file=f  )
                f.flush()
            global_step += 1

            n = 0
            hypo_sentence = []
            ref_sentence = []
            hypo_sentence1 = []
            ref_sentence1 = []
            total_error = 0
            total_token_length = 0
            total_bleu = 0
            n2 = 0
            for (hypo_id, caption) in zip( hypo_ids, captions ):
                ##hypo = tokenizer.decode( hypo_id.tolist(), skip_special_tokens = True )
                #hypo = model.my_decode( hypo_id.tolist(), tokenizer )
                #hypo_tokens = tokenizer.tokenize( hypo )
                ##reference = tokenizer.decode( caption.tolist(), skip_special_tokens = True )
                #reference = model.my_decode( caption.tolist(), tokenizer )
                #ref_tokens = tokenizer.tokenize( reference )
                hypo = tokenizer.decode( [token_id for token_id in hypo_id.tolist() if token_id != pad_token_id] )
                reference = tokenizer.decode( [token_id for token_id in caption.tolist() if token_id != pad_token_id] )
                hypo_tokens = tokenizer.tokenize( hypo )
                ref_tokens = tokenizer.tokenize( reference )

                
                # 認識誤りを計算
                (error, substitute, 
                    delete, insert, ref_length) = \
                    levenshtein.calculate_error(hypo_tokens,
                                                    ref_tokens)
                
                # 誤り文字数を累積する
                total_error += error
                # 文字の総数を累積する
                total_token_length += ref_length

                bleu = bleu_score.sentence_bleu( [reference], hypo, smoothing_function=fn  )
        
                total_bleu += bleu                    
                    
                if n < 1 and n_batch == len( train_loader ) - 1 :
                    hypo_sentence.append( hypo )
                    ref_sentence.append( reference )
                if n < 1 and n_batch % train_param == 0:
                    hypo_sentence1.append( hypo )
                    ref_sentence1.append( reference )
                    
                n += 1
                n2 += 1
            
            avg_error = total_error / total_token_length * 100
            avg_bleu = total_bleu / n2 * 100
                
            # 学習時の損失をログに書き込み
            train_losses.append(loss.item())
            train_a.append(a.item())
            train_b.append(b.item())
            train_errors.append( avg_error )
            train_bleus.append( avg_bleu )
            #train_ciders.append( avg_cider )
            if len(train_losses) > config.moving_avg:
                train_losses.popleft()
                train_a.popleft()
                train_b.popleft()
                train_errors.popleft()
                train_bleus.popleft()
                #train_ciders.popleft()
            mean_loss = torch.Tensor(train_losses).mean().item()
            mean_a = torch.Tensor(train_a).mean().item()
            mean_b = torch.Tensor(train_b).mean().item()
            mean_error = torch.Tensor(train_errors).mean().item()
            mean_bleu = torch.Tensor(train_bleus).mean().item()
            pbar.set_postfix({
                'loss': mean_loss,
                'crf': mean_a,
                'ca': mean_b,
                'WER': mean_error,
                'BLEU': mean_bleu,
                #'CIDER': torch.Tensor(train_ciders).mean().item()
            })
            with open(train_loss_file, 'a') as f:
                print(f'{epoch}, {mean_loss}, {mean_a}, {mean_b}, {mean_error}, {mean_bleu}', file=f)
            print_flag = 1
            for ( hypo_se, ref_se ) in zip( hypo_sentence1, ref_sentence1 ):
                if print_flag == 1:
                    print( "lr clip     :", optimizer.param_groups[0]["lr"] )
                    print( "lr connector:", optimizer.param_groups[1]["lr"] )
                    print( "lr bert     :", optimizer.param_groups[2]["lr"] )
                    print( "lr others   :", optimizer.param_groups[3]["lr"] )
                    print_flag = 0
                print(f'Train epoch = {epoch}, loss = {mean_loss}, WER = {mean_error}, BLEU = {mean_bleu}')
                print( "refe:", ref_se )
                print( "hypo:", hypo_se )
                    
            for ( hypo_se, ref_se ) in zip( hypo_sentence, ref_sentence ):
                print(f'Train epoch = {epoch}, loss = {mean_loss}, WER = {mean_error}, BLEU = {mean_bleu}')
                print( "refe:", ref_se )
                print( "hypo:", hypo_se )
    # 学習率を表示
    #print(f'学習率: {optimizer.param_groups[0]['lr']}')
    train_loss = np.mean(train_losses)
    train_a1 = np.mean(train_a)
    train_b1 = np.mean(train_b)
    train_error = np.mean(train_errors )
    train_bleu = np.mean(train_bleus )
    print(f'Train loss: {train_loss}')
    print(f'Train crf: {train_a1}')
    print(f'Train ca: {train_b1}')
    print(f'Train WER: {train_error}')        
    print(f'Train BLEU: {train_bleu}')

    # 検証
    with tqdm(val_loader) as pbar:
        pbar.set_description(f'[検証]')

        # 評価モード
        model.eval()

        #val_losses = []
        #val_losses = deque()
        #val_a = deque()
        #val_b = deque()
        val_errors = deque()
        val_bleus = deque()
        for n_batch, (imgs, captions, caption_lengths) in enumerate( pbar ):

            # ミニバッチを設定
            imgs = imgs.to(config.device)
            captions = captions.to(config.device)
            #caption_lengths = torch.tensor( caption_lengths ).to(config.device)
                
            with torch.no_grad():
                #outputs = model( imgs )
                finalized_scores, finalized_tokens, top_probs, top_indices, crf_loss, emissions = \
                    model.inference( imgs, captions, top_indices = None )

                #hypo_ids = torch.argmax( outputs, dim = 2 )
                hypo_ids = finalized_tokens
                ## 損失の計算
                ## 単語軸が第1軸である必要があるため、転置
                ##outputs = outputs.transpose(1, 2)
                #src_representation = outputs
                ##src_input = imgs
                #src_input = outputs
                #tgt_input = captions
                #train_batch_crf_loss = model.toplayer( src_representation, src_input, tgt_input, is_training = True )
                #log_probs = log_softmax( outputs.transpose(0,1) )
                ##bsz, seq = captions.size()
                ##loss_CA = - criterion_nll( log_prob.view( bsz * seq, -1 ) , captions.view( bsz * seq ) )
                #train_nll_loss = calc_loss_ca( log_probs, captions )
                #a = torch.mean( train_batch_crf_loss )
                #b = config.alpha * torch.mean( train_nll_loss )
                #loss = a + b
               
            n = 0
            hypo_sentence = []
            ref_sentence = []
            hypo_sentence1 = []
            ref_sentence1 = []
            total_error = 0
            total_token_length = 0
            total_bleu = 0
            n2 = 0
            for (hypo_id, caption) in zip( hypo_ids, captions ):
                ##hypo = tokenizer.decode( hypo_id.tolist(), skip_special_tokens = True )
                #hypo = model.my_decode( hypo_id.tolist(), tokenizer )
                #hypo_tokens = tokenizer.tokenize( hypo )
                ##reference = tokenizer.decode( caption.tolist(), skip_special_tokens = True )
                #reference = model.my_decode( caption.tolist(), tokenizer )
                #ref_tokens = tokenizer.tokenize( reference )
                hypo = tokenizer.decode( [token_id for token_id in hypo_id.tolist() if token_id != pad_token_id] )
                reference = tokenizer.decode( [token_id for token_id in caption.tolist() if token_id != pad_token_id] )
                hypo_tokens = tokenizer.tokenize( hypo )
                ref_tokens = tokenizer.tokenize( reference )
       
                # 認識誤りを計算
                (error, substitute, 
                    delete, insert, ref_length) = \
                    levenshtein.calculate_error(hypo_tokens,
                                                ref_tokens)
                    
                # 誤り文字数を累積する
                total_error += error
                # 文字の総数を累積する
                total_token_length += ref_length

                bleu = bleu_score.sentence_bleu( [reference], hypo, smoothing_function=fn  )
        
                total_bleu += bleu

                if n < 1 and n_batch == len( val_loader ) - 1:
                    hypo_sentence.append( hypo )
                    ref_sentence.append( reference )
                        
                if n < 1 and n_batch % val_param == 0:
                    hypo_sentence1.append( hypo )
                    ref_sentence1.append( reference )
                    
                n += 1
                n2 += 1
                
            avg_error = total_error / total_token_length * 100                    
            avg_bleu = total_bleu / n2 * 100

            # 学習時の損失をログに書き込み
            #val_losses.append(loss.item())
            #val_a.append(a.item())
            #val_b.append(b.item())
            val_errors.append( avg_error )
            val_bleus.append( avg_bleu )
            if len(val_errors) > config.moving_avg:
                #val_losses.popleft()
                #val_a.popleft()
                #val_b.popleft()
                val_errors.popleft()
                val_bleus.popleft()
            #mean_loss = torch.Tensor(val_losses).mean().item()
            #mean_a = torch.Tensor(val_a).mean().item()
            #mean_b = torch.Tensor(val_b).mean().item()
            mean_error = torch.Tensor(val_errors).mean().item()
            mean_bleu = torch.Tensor(val_bleus).mean().item()
            pbar.set_postfix({
                #'loss': mean_loss,
                #'crf': mean_a,
                #'ca': mean_b,
                'WER': mean_error,
                'BLEU': mean_bleu,
            })
            # Validation Lossをログに書き込み
            with open(val_loss_file, 'a') as f:
                print(f'{epoch}, {mean_error}, {mean_bleu}', file=f)
            
            for ( hypo_se, ref_se ) in zip( hypo_sentence1, ref_sentence1 ):
                print(f'Val epoch = {epoch}, WER = {mean_error}, BLEU = {mean_bleu}')
                print( "refe:", ref_se )
                print( "hypo:", hypo_se )
                    
            for ( hypo_se, ref_se ) in zip( hypo_sentence, ref_sentence ):
                print(f'Val epoch = {epoch}, WER = {mean_error}, BLEU = {mean_bleu}')
                print( "refe:", ref_se )
                print( "hypo:", hypo_se )
                    
    # Loss 表示
    #val_loss = np.mean(val_losses)
    val_error = np.mean( val_errors )
    val_bleu = np.mean( val_bleus )
    #print(f'Validation loss: {val_loss}')
    print(f'Validation WER: {val_error}')
    print(f'Validation BLEU: {val_bleu}')

    ## より良い検証結果が得られた場合、モデルを保存
            
    # モデルを保存
    torch.save({'epoch': epoch,
                'global_step': global_step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss,},
        f'{config.save_directory}/model_bert_large_NAR_PAD_sft2_curr.pth')
    ## モデルを保存
        
# モデルを保存
torch.save({'epoch': epoch,
    'global_step': global_step,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'loss': loss,},
    f'{config.save_directory}/model_bert_large_NAR_PAD_sft2_final.pth')


i: 0
i: 100000
i: 200000
i: 300000
i: 400000
i: 500000
config.device: cuda:0
学習セット数: 10149
評価セット数: 1269
テストセット数: 50744
use_amp: True
use_saved_pth: False
initialize self.toplayer
in TopLayer:
num_global_steps: 50745
num_warmup_steps: 10149.0
use_saved_pth: False
exist saved_pth: False
begin_epoch: 0
global_ste: 0
train_param: 101
val_param: 423
epochs: 5
batch_size: 40
lr_clip: 2e-07
lr_con: 0.0001
lr_bert: 2e-05
lr_others: 0.0001
weight_decay: 0.1
betas: (0.9, 0.999)
crf_coef: 1.0
ce_coef: 0.3
use_repeat_logits_half: False
train_loss_file: ./model/MyOriginal_train_loss_20260101_033912.csv


  0%|          | 0/10149 [00:00<?, ?it/s]

lr clip     : 1.9706375012316484e-11
lr connector: 9.853187506158243e-09
lr bert     : 1.9706375012316485e-09
lr others   : 9.853187506158243e-09
Train epoch = 0, loss = 26.072107315063477, WER = 233.25942993164062, BLEU = 16.870601654052734
refe: [CLS] in this picture we can see buildings, trees and vehicles are on the road. [SEP] [unused1] [unused1]
hypo: ##ises bryce secretariat modesrgh exercises incorporatestish renault underworld shrink protegeroi mumbai droplets [unused89] claims lair shrink protegeroi mumbai droplets volley claims lair [unused755] markham prefect protects randolph debatesises participating堂roi mumbai ᵗ sustained kristin yorker bryce antiquity grit fancy prefect protects raj protects raj wagnerroi mumbai droplets fix [unused981] confines edison τ unsafeises antiquity grit contact adventures protects raj wagner insulation exercises incorporatesل [unused472] unsafe fix [unused981] claims lair [unused755] cleantick markham shrink protegeroi mumbai droplets exclusiv

  0%|          | 0/1269 [00:00<?, ?it/s]

Val epoch = 0, WER = 74.3861312866211, BLEU = 38.83528137207031
refe: [CLS] in this image i can see few trees which are green in color, few flowers which are red in color and in the background i can see a person standing, the road, few vehicles, few buildings, few trees and the sky. [SEP] [unused1] [unused1]
hypo: [CLS] in this picture we can see few trees. in the background there is a sky. [SEP] [unused1] [unused1]
Val epoch = 0, WER = 73.30298614501953, BLEU = 41.24933624267578
refe: [CLS] in this picture we can see planets, where we can see few people and some objects. [SEP] [unused1] [unused1]
hypo: [CLS] in this picture we can see a woman foreground uniform leaves coat i can see a. [SEP] [unused1] [unused1] [unused1] [unused1] [unused1] [unused1]
Val epoch = 0, WER = 73.75590515136719, BLEU = 41.32341003417969
refe: [CLS] in this picture we can see some graves and a memorial, in the background there are some trees, we can see christianity symbols here. [SEP] [unused1] [unused1]
hy

  0%|          | 0/10149 [00:00<?, ?it/s]

lr clip     : 1.999950734062469e-07
lr connector: 9.999753670312346e-05
lr bert     : 1.9999507340624694e-05
lr others   : 9.999753670312346e-05
Train epoch = 1, loss = 2.928704023361206, WER = 76.21749114990234, BLEU = 38.710941314697266
refe: [CLS] the image is inside the room. in the image there is a woman sitting on chair and holding a electronic device on her hand, in background we can see a couch, towel and a window which is closed and right side there is a card. [SEP] [unused1] [unused1]
hypo: [CLS] in this image i can see a woman sitting on the there is head we can see a person tapf red 2 tapf lamp pillow top toys
lr clip     : 1.9949748743718594e-07
lr connector: 9.974874371859297e-05
lr bert     : 1.9949748743718595e-05
lr others   : 9.974874371859297e-05
Train epoch = 1, loss = 2.710420608520508, WER = 76.31710815429688, BLEU = 40.806121826171875
refe: [CLS] in this picture there is a person holding the plant. at the back there are trees. at the bottom there is grass on the 

  0%|          | 0/1269 [00:00<?, ?it/s]

Val epoch = 1, WER = 68.5604248046875, BLEU = 42.22783660888672
refe: [CLS] in this image i can see few trees which are green in color, few flowers which are red in color and in the background i can see a person standing, the road, few vehicles, few buildings, few trees and the sky. [SEP] [unused1] [unused1]
hypo: [CLS] in this image we can see plants, plants, plants, plants, there are plants, we can see the background there are trees. [SEP] [unused1] [unused1]
Val epoch = 1, WER = 69.57406616210938, BLEU = 43.80345153808594
refe: [CLS] in this picture we can see planets, where we can see few people and some objects. [SEP] [unused1] [unused1]
hypo: [CLS] in this picture we can see two persons. in front of the background we can see the background there are trees. [SEP] [unused1]
Val epoch = 1, WER = 69.64454650878906, BLEU = 43.55693054199219
refe: [CLS] in this picture we can see some graves and a memorial, in the background there are some trees, we can see christianity symbols here. [

  0%|          | 0/10149 [00:00<?, ?it/s]

lr clip     : 1.4999507340624692e-07
lr connector: 7.499753670312346e-05
lr bert     : 1.4999507340624693e-05
lr others   : 7.499753670312346e-05
Train epoch = 2, loss = 1.318399429321289, WER = 69.16487884521484, BLEU = 42.602210998535156
refe: [CLS] in this picture i can see group of people sitting on the chairs. there are paper, bags and some other objects. there is a woman standing near the podium and there is a mike on the podium, and in the background there are televisions attached to the pillar and there are lights. [SEP] [unused1] [unused1]
hypo: [CLS] in this picture we can see a group of people are so acting in front of them there is a group of people standing and i can see number of people standing and there is clicked
lr clip     : 1.4949748743718594e-07
lr connector: 7.474874371859297e-05
lr bert     : 1.4949748743718595e-05
lr others   : 7.474874371859297e-05
Train epoch = 2, loss = 1.372731328010559, WER = 70.35518646240234, BLEU = 42.054744720458984
refe: [CLS] in this 

  0%|          | 0/1269 [00:00<?, ?it/s]

Val epoch = 2, WER = 89.79296875, BLEU = 47.026336669921875
refe: [CLS] in this image i can see few trees which are green in color, few flowers which are red in color and in the background i can see a person standing, the road, few vehicles, few buildings, few trees and the sky. [SEP] [unused1] [unused1]
hypo: [CLS] in this image we can see plants, plants, plants, we can see number of the road. [SEP] [unused1] [unused1] [unused1] [unused1] [unused1] [unused1] [unused1] [unused1] [unused1] [unused1] [unused1] [unused1] [unused1] [unused1] [unused1] [unused1]
Val epoch = 2, WER = 92.0322036743164, BLEU = 47.54921340942383
refe: [CLS] in this picture we can see planets, where we can see few people and some objects. [SEP] [unused1] [unused1]
hypo: [CLS] in this image we can see a man. in front of the background we can see a man. [SEP] [unused1] [unused1] [unused1] [unused1] [unused1] [unused1] [unused1]
Val epoch = 2, WER = 91.22148132324219, BLEU = 47.63743591308594
refe: [CLS] in this pi

  0%|          | 0/10149 [00:00<?, ?it/s]

lr clip     : 9.999507340624692e-08
lr connector: 4.999753670312347e-05
lr bert     : 9.999507340624693e-06
lr others   : 4.999753670312347e-05
Train epoch = 3, loss = 2.380382537841797, WER = 69.23076629638672, BLEU = 40.924827575683594
refe: [CLS] there are buildings. on the left side there is a brick wall. near to that we can see branch of a tree. in the background there is sky. [SEP] [unused1] [unused1]
hypo: [CLS] in this image we can see a building, buildings, we can see the background we can see the sky. [SEP] [unused1]
lr clip     : 9.949748743718592e-08
lr connector: 4.974874371859297e-05
lr bert     : 9.949748743718594e-06
lr others   : 4.974874371859297e-05
Train epoch = 3, loss = 2.129638910293579, WER = 73.48490142822266, BLEU = 43.421112060546875
refe: [CLS] in the middle of the image we can see a man, in front of him we can see a vehicle and a tree, in the background we can find few poles, water, hills, few more trees and clouds. [SEP] [unused1] [unused1]
hypo: [CLS] in 

  0%|          | 0/1269 [00:00<?, ?it/s]

Val epoch = 3, WER = 72.4121322631836, BLEU = 42.126094818115234
refe: [CLS] in this image i can see few trees which are green in color, few flowers which are red in color and in the background i can see a person standing, the road, few vehicles, few buildings, few trees and the sky. [SEP] [unused1] [unused1]
hypo: [CLS] in this image we can see plants, there are plants, there are trees. in the road. [SEP] [unused1] [unused1] [unused1] [unused1] [unused1] [unused1]
Val epoch = 3, WER = 71.71426391601562, BLEU = 43.468719482421875
refe: [CLS] in this picture we can see planets, where we can see few people and some objects. [SEP] [unused1] [unused1]
hypo: [CLS] in this image we can see a person is a person is a person is a wall. [SEP] [unused1] [unused1] [unused1]
Val epoch = 3, WER = 71.52983856201172, BLEU = 43.4749755859375
refe: [CLS] in this picture we can see some graves and a memorial, in the background there are some trees, we can see christianity symbols here. [SEP] [unused1] [u

  0%|          | 0/10149 [00:00<?, ?it/s]

lr clip     : 4.999507340624692e-08
lr connector: 2.499753670312346e-05
lr bert     : 4.999507340624693e-06
lr others   : 2.499753670312346e-05
Train epoch = 4, loss = 1.3410676717758179, WER = 66.93961334228516, BLEU = 45.6710205078125
refe: [CLS] in this image i can see few people are sitting and i can see the cup and few objects on the tables. background is in black color. [SEP] [unused1] [unused1]
hypo: [CLS] in this image, we can see two persons sitting on the background i can see the table. on the table. on the table. on the table. on the table. [SEP] [unused1]
lr clip     : 4.949748743718593e-08
lr connector: 2.4748743718592964e-05
lr bert     : 4.949748743718593e-06
lr others   : 2.4748743718592964e-05
Train epoch = 4, loss = 1.6869996786117554, WER = 77.20124816894531, BLEU = 44.702667236328125
refe: [CLS] a woman is sitting in the chair and working in the laptop she wear a black color dress. [SEP] [unused1] [unused1]
hypo: [CLS] in this image there is a woman sitting and she 

  0%|          | 0/1269 [00:00<?, ?it/s]

Val epoch = 4, WER = 77.90081787109375, BLEU = 44.23005676269531
refe: [CLS] in this image i can see few trees which are green in color, few flowers which are red in color and in the background i can see a person standing, the road, few vehicles, few buildings, few trees and the sky. [SEP] [unused1] [unused1]
hypo: [CLS] in this image we can see plants, there are plants, there are plants. in the road. [SEP] [unused1] [unused1] [unused1] [unused1] [unused1] [unused1] [unused1] [unused1]
Val epoch = 4, WER = 78.94837951660156, BLEU = 46.452850341796875
refe: [CLS] in this picture we can see planets, where we can see few people and some objects. [SEP] [unused1] [unused1]
hypo: [CLS] in this image we can see a man is a man is a plant. [SEP] [unused1] [unused1] [unused1] [unused1] [unused1] [unused1] [unused1] [unused1] [unused1] [unused1] [unused1] [unused1]
Val epoch = 4, WER = 78.86493682861328, BLEU = 46.22598648071289
refe: [CLS] in this picture we can see some graves and a memorial, i