In [1]:
# -*- coding: utf-8 -*-

#
# Pytorchで用いるDatasetの定義
#

#!pip install janome

# sysモジュールをインポート
import sys

import matplotlib.pyplot as plt
import pandas as pd
import torch
import random
from torch import nn, Tensor
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
import math
import janome
from janome.tokenizer import Tokenizer
from collections import Counter
from torch.utils.data.sampler import SubsetRandomSampler
import time
import levenshtein
import pickle
from timm.scheduler import CosineLRScheduler
from nltk import bleu_score
#from tqdm import tqdm
from tqdm.notebook import tqdm
from torch import autocast, GradScaler
from typing import Sequence, Dict, Tuple, Union
from transformers import  get_linear_schedule_with_warmup
from transformers import  CLIPVisionModel, get_linear_schedule_with_warmup, BertTokenizer

#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device( "cpu")
use_amp = False

In [2]:
path = "../../python_image_recognition-main/6_img_captioning/6_7_ImgCap_pre_trained_Feature_extractor/CLIP_ENCODER/models--google-bert--bert-large-uncased/snapshots/6da4b6a26a1877e173fca3225479512db81a5e5b"
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path = path)

vocab_size = len(tokenizer)
print(vocab_size)

30522


In [3]:
'''
BERT tokenizers utilize specific special tokens for various purposes, primarily to structure input sequences for the BERT model. These tokens are typically added automatically when processing text with a BERT tokenizer, especially when using methods like encode_plus or __call__ from the Hugging Face Transformers library.
The most common special tokens in a BERT tokenizer are:
[CLS] (Classifier Token):
This token is placed at the beginning of the input sequence. Its corresponding final hidden state is used as the aggregate representation of the entire sequence for classification tasks.
[SEP] (Separator Token):
This token is used to separate different segments within an input sequence, such as separating a question from an answer in question-answering tasks, or to mark the end of a single sentence.
[PAD] (Padding Token):
This token is used to pad sequences to a uniform length, ensuring that all sequences in a batch have the same dimensions for efficient processing by the model.
[UNK] (Unknown Token):
This token represents words or subwords that are not found in the tokenizer's vocabulary.
[MASK] (Mask Token):
This token is used during the pre-training phase of BERT for the Masked Language Model (MLM) objective, where a percentage of input tokens are randomly replaced with [MASK], and the model predicts the original token.
When you tokenize text and request the return of special tokens, these tokens will be included in the output token IDs and can be viewed by decoding the token IDs back to strings. For example, if you tokenize a sentence like "Hello world," the output might look something like ['[CLS]', 'Hello', 'world', '[SEP]'] when decoded.
'''

'\nBERT tokenizers utilize specific special tokens for various purposes, primarily to structure input sequences for the BERT model. These tokens are typically added automatically when processing text with a BERT tokenizer, especially when using methods like encode_plus or __call__ from the Hugging Face Transformers library.\nThe most common special tokens in a BERT tokenizer are:\n[CLS] (Classifier Token):\nThis token is placed at the beginning of the input sequence. Its corresponding final hidden state is used as the aggregate representation of the entire sequence for classification tasks.\n[SEP] (Separator Token):\nThis token is used to separate different segments within an input sequence, such as separating a question from an answer in question-answering tasks, or to mark the end of a single sentence.\n[PAD] (Padding Token):\nThis token is used to pad sequences to a uniform length, ensuring that all sequences in a batch have the same dimensions for efficient processing by the model.

In [3]:
class SequenceDataset(Dataset):
    ''' ミニバッチデータを作成するクラス
        torch.utils.data.Datasetクラスを継承し，
        以下の関数を定義する
        __len__: 総サンプル数を出力する関数
        __getitem__: 1サンプルのデータを出力する関数
    feat_scp:  特徴量リストファイル
    label_scp: ラベルファイル
    feat_mean: 特徴量の平均値ベクトル
    feat_std:  特徴量の次元毎の標準偏差を並べたベクトル 
    pad_index: バッチ化の際にフレーム数を合わせる
               ためにpaddingする整数値
    splice:    前後(splice)フレームを特徴量を結合する
               splice=1とすると，前後1フレーム分結合
               するので次元数は3倍になる．
               splice=0の場合は何もしない
    '''
    def __init__(self, 
                 filename,
                 tokenizer,
                 mask_prob = 0.15 * 0.8,
                 arbi_prob = 0.15 * 0.1,
                 ):

        self.mask_token_id = tokenizer.mask_token_id
        self.pad_token_id = tokenizer.pad_token_id
        #print( "self.mask_value:", self.mask_value)
        self.max_idx_en = len( tokenizer )
        #self.device = device
        self.mask_prob = mask_prob
        self.arbi_prob = arbi_prob
        # 元のtoken の mask_prob = 15%を置き換える。

        with open(filename, mode='r', encoding="utf-8") as file_f:
            for n, line in enumerate( file_f ):
                pass
        
        file_max = n
        
        # 特徴量リスト，ラベルを1行ずつ
        # 読み込みながら情報を取得する
        self.input_ids = []
        self.input_lens = []
        self.target_ids = []
        self.target_lens = []
        self.num_data = 0
        with open(filename, mode='r', encoding="utf-8") as file_f:
            for n, line in enumerate( file_f ):
                if n % 100000 == 0:
                    print( "n:", n )
                if n > file_max // 50:
                    break
                en1 = torch.tensor( tokenizer.encode(line), dtype=torch.int, requires_grad = False  )
                #print( "en1:", en1 )
                self.input_ids.append( en1 )
                self.input_lens.append( len( en1 ) )
                en2 = self.masking( en1 )
                self.target_ids.append( en2 )
                self.target_lens.append( len( en2 ) )
                #self.target_ids.append( en1 )
                #self.target_lens.append( len( en1 ) )

                self.num_data += 1

    def __len__(self):
        ''' 学習データの総サンプル数を返す関数
        本実装では発話単位でバッチを作成するため，
        総サンプル数=発話数である．
        '''
        #return self.num_data // 100
        return self.num_data


    def __getitem__(self, idx):
        ''' サンプルデータを返す関数
        本実装では発話単位でバッチを作成するため，
        idx=発話番号である．
        '''
        # 特徴量系列のフレーム数
        input_len = self.input_lens[idx]
        # ラベルの長さ
        target_len = self.target_lens[idx]

        # ラベル
        input = self.input_ids[idx]

        # 発話ID
        target = self.target_ids[idx]

        # 特徴量，ラベル，フレーム数，
        # ラベル長，発話IDを返す
        return (input, target, input_len, target_len )
    
    def masking(self, input_x: torch.Tensor) -> torch.Tensor:

        masks = torch.zeros( input_x.shape[0], device = input_x.device, dtype=torch.bool )
        #pad_length = torch.sum( torch.eq( input_x, self.pad_token_id ).int() )
        pad_length = 0
        #print( input_ids.shape )
        masks[1:-1-pad_length] = torch.rand( (input_x.shape[0] - 2 - pad_length), device=input_x.device ) < self.mask_prob
        #print( masks )
        #mask = (torch.rand( input_ids.shape) < mask_prob )
        masked_input_ids = input_x.clone()
        masked_input_ids[masks] = self.mask_token_id
        while True:
            arbi1 = torch.zeros( input_x.shape[0], device = input_x.device, dtype=torch.bool )
            arbi1[1:-1-pad_length] = torch.rand( ( input_x.shape[0] - 2 - pad_length), device=input_x.device ) < self.arbi_prob
            num_arbi1 = torch.sum( torch.eq( arbi1, True ).int() )
            arbi2 = [ a if m == False else False for a, m in zip( arbi1, masks )]
            num_arbi2 = torch.sum( torch.eq( torch.tensor( arbi2 ), True ).int() )
            if num_arbi1 == num_arbi2:
                break
        #print( num_arbi2 / len( masked_input_ids))
        for i, arbi in enumerate( arbi2 ):
            if arbi == True:
                masked_input_ids[i] = torch.randint( 0, len( tokenizer), size=(1,) )
        
        
        #mlm_labels = input_x.clone()
        #mlm_labels[~masks] = -100
        
        return masked_input_ids

In [4]:
def collate_func(batch: Sequence[Tuple[Union[torch.Tensor, str]]], pad_index ):
    inputs0, targets0, input_lens, target_lens = zip(*batch)

    inputs = []
    targets = []
    for input1, target1, input_len, target_len in zip( inputs0, targets0, input_lens, target_lens ):
        pad_len = max( input_lens ) - input_len 
        input2= F.pad( input1, (0, pad_len), mode='constant', value = pad_index)
        inputs.append( input2 )
        pad_len = max( target_lens ) - target_len
        target2= F.pad( target1, (0, pad_len), mode='constant', value = pad_index)
        targets.append( target2 )

    
    inputs = torch.stack( inputs, dim = 0 )
    targets = torch.stack( targets, dim = 0 )
    input_lens = torch.tensor( input_lens )
    target_lens = torch.tensor( target_lens )
    
    return inputs, targets, input_lens, target_lens


In [5]:
collate_func_lambda = lambda x: collate_func(x, tokenizer.pad_token_id)

In [6]:
train_dataset = SequenceDataset( "train.txt", tokenizer )
print( "train dataset defiened, len( train_dataset):", len( train_dataset) )

# 訓練データのDataLoaderを呼び出す
# 訓練データはシャッフルして用いる
#  (num_workerは大きい程処理が速くなりますが，
#   PCに負担が出ます．PCのスペックに応じて
#   設定してください)


n: 0
train dataset defiened, len( train_dataset): 35439


In [7]:
val_dataset = SequenceDataset( "val.txt", tokenizer  )
print( "val dataset defiened, len( val_dataset):", len( val_dataset ) )

n: 0
val dataset defiened, len( val_dataset): 3938


In [8]:
batch_size = 16

num_workers = 0 if device == torch.device( 'cpu' ) else 8

train_loader = DataLoader(train_dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=num_workers,
                            pin_memory=True,
                            collate_fn = collate_func_lambda)
print( "train_loader defiend" )
# 開発データのDataLoaderを呼び出す
# 開発データはデータはシャッフルしない
val_loader = DataLoader(val_dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=num_workers,
                        collate_fn = collate_func_lambda)
print( "val_loader defined" )
print( len( train_loader ))
print( len( val_loader ))

text, target, text_len, target_len = next(iter(train_loader))
#print(text[0], target[0], text_len[0],target_len[0], sep="\n")
print( "text:", text )
#print( "text_len:", text_len )

train_loader defiend
val_loader defined
2215
247
text: tensor([[  101,  2178,  2391,  ...,     0,     0,     0],
        [  101,  1997,  2607,  ...,     0,     0,     0],
        [  101,  2000, 16519,  ...,     0,     0,     0],
        ...,
        [  101, 15847,  1010,  ...,     0,     0,     0],
        [  101,  2612,  2009,  ...,     0,     0,     0],
        [  101,  2009,  2036,  ...,     0,     0,     0]], dtype=torch.int32)


In [36]:
n_word = 0
n_mask = 0
n_arbi = 0
n = 0
for text, target, _, _ in val_loader:
    for te, ta in zip( text, target ):
        n += 1
        #reference0 = []
        #for m in te:
        #    reference0.append( token_list_en[m.item()])
        #    if m == idx_list_en['<end>']:
        #        break
        ##print( "n:", n, "input :", ' '.join(reference ) )
        #reference0 = tokenizer.decode( te)
        #reference0 = tokenizer.tokenize( reference0 )
        
        #reference1 = []
        #for m in ta:
        #    reference1.append( token_list_en[m.item()])
        #    if m == idx_list_en['<end>']:
        #        break
        ##print( "n:", n, "target:", ' '.join(reference ) )
        #reference1 = tokenizer.decode( ta)
        #reference1 = tokenizer.tokenize( reference1 )
        
        for ref0, ref1 in zip( te, ta ):
            if ref0 != tokenizer.pad_token_id and ref1 != tokenizer.pad_token_id:
                n_word += 1
                if ref0 != ref1:
                    if ref1 == tokenizer.mask_token_id:
                        n_mask += 1    
                    if ref1 != tokenizer.mask_token_id:
                        n_arbi += 1
                        #print(" ref0:", ref0, "ref1:",ref1 )
                        #print( "ref0:", ref0, "ref1:", ref1 )
print( "n:", n )
print( "mask ratio:", n_mask / n_word )
print( "arbi ratio:", n_arbi / n_word )
print( "input: ", reference0 )
print( "target:", reference1 )

n: 3938
mask ratio: 0.11199696279422931
arbi ratio: 0.012000264104849626
input:  ['[CLS]', 'these', 'issues', 'will', 'be', 'discussed', 'with', 'israel', 'and', 'considered', 'by', 'the', 'council', 'in', 'coming', 'weeks', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]']
target: ['[CLS]', 'these', 'issues', 'will', 'be', 'discussed', 'with', 'israel', 'and', 'considered', 'by', 'the', 'council', 'in', 'coming', 'weeks', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]']


In [9]:
class PositionalEmbedding(nn.Module):
    '''
    位置埋め込み （Positional embedding）
    dim_embedding: 埋込み次元
    max_len      : 入力の最大系列長
    '''
    def __init__(self, dim_embedding: int, max_len: int=2048):
        super().__init__()

        self.pos_emb = nn.Embedding(max_len, dim_embedding)

    '''
    位置エンコーディングの順伝播
    x: 位置エンコーディングを埋め込む対象のテンソル,
       [バッチサイズ, 系列長, 埋め込み次元]
    '''
    def forward(self, x: torch.Tensor):
        seq = x.shape[1]
        positions = torch.arange(start=0, end=seq, step=1, device=x.device).to(torch.long)
        positions = self.pos_emb(positions)[:seq,:]
        
        return positions

In [10]:
class TransformerEncoder(nn.Module):
    '''
    CaptioningTransformerのコンストラクタ
    dim_embedding  : 埋め込み次元
    dim_feedforward: FNNの中間特徴次元
    num_heads      : マルチヘッドアテンションのヘッド数
    num_layers     : Transformerデコーダ層の数
    vocab_size     : 辞書の次元
    pad_index      : PADのID
    dropout        : ドロップアウト確率
    '''
    def __init__(self, vocab_size: int, dim_embedding: int, dim_feedforward: int,
                 num_heads: int, num_layers: int ):
        super().__init__()

        self.encoder_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=dim_embedding, nhead=num_heads, batch_first=True, activation='gelu', norm_first=True)
            for _ in range(num_layers)
        ])

    ''' CaptioningTransformerの順伝播処理
    features: 画像特徴量 [バッチサイズ, 埋め込み次元]
    captions: 正解キャプション [バッチサイズ, 系列長]

    '''
    def forward(self, src: torch.Tensor, src_mask: torch.Tensor=None, \
                src_key_padding_mask: torch.Tensor=None ):


        for layer in self.encoder_layers:
            src = layer( src, src_mask = src_mask, src_key_padding_mask = src_key_padding_mask,\
                        #is_causal= True )
                        is_causal= False )

        return src

In [11]:
class Transformer(nn.Module):
    def __init__(self, dim_embedding: int, dim_feedforward: int,
                 num_heads: int, num_layers: int, vocab_size: int,
                 pad_index: int, dropout: float=0.5, us_rate: float=2.0 ):
        super().__init__()

        # 単語埋め込み
        self.embed = nn.Embedding(
            vocab_size, dim_embedding, padding_idx=pad_index)

        # 位置エンコーディング
        self.pos_emb = PositionalEmbedding(dim_embedding)        
        
        # dropout
        self.dropout = nn.Dropout( dropout )
        
        self.encoder = TransformerEncoder(vocab_size, dim_embedding, dim_feedforward, num_heads, num_layers)

        # 単語出力分布計算
        self.ln = nn.LayerNorm( dim_embedding )
        self.linear = nn.Linear(dim_embedding, vocab_size)

        self.pad_index = pad_index
        self.num_heads = num_heads
        
    ''' CaptioningTransformerの順伝播処理
    features: 画像特徴量 [バッチサイズ, 埋め込み次元]
    captions: 正解キャプション [バッチサイズ, 系列長]

    '''
    def forward(self, text):

        device = text.device

        src = self.embed( text )
        
        src += self.pos_emb( src )
        src = self.dropout( src )
        src_key_padding_mask = torch.eq(text, self.pad_index)

        #ones = torch.ones( text.size(1) ).to( device = device )
        #src_mask = torch.diag( ones ).bool()
        src_mask = None

        preds = self.encoder( src, src_mask, src_key_padding_mask )
        
        preds = self.ln( preds )
        logits = self.linear( preds )

        return logits

In [9]:
print( device )

cpu


In [12]:
epoch_num = 10
model = Transformer(768, 3072, 12, 6, vocab_size, tokenizer.pad_token_id ).to(device)

input_texts = torch.randint( 0, vocab_size, size=(16, 45))
outputs_logits = model( input_texts )

#criterion = nn.CrossEntropyLoss(ignore_index = tokenizer.pad_token_id )
criterion = nn.CrossEntropyLoss( )
#criterion = nn.CTCLoss(blank=idx_list_en['<blank>'], reduction='mean', zero_infinity = True )
#optimizer = optim.Adam(model.parameters(), lr=0.0001 )
#optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay = 0.005 )
#optimizer = optim.AdamW(model.parameters(), lr=0.0001, weight_decay = 0.001 )
#optimizer = optim.Adadelta(model.parameters(), lr = 1.0, weight_decay = 0.0 )
#optimizer = optim.Adadelta(model.parameters(),lr=1.0,rho=0.95,eps=1e-8,weight_decay=0.0)
#optimizer = optim.AdamW( model.parameters(), lr = 1e-4 )
lr = 1e-4
#lr = 1e-3
optimizer = optim.AdamW( model.parameters(), lr = lr )
# 全ステップ数
num_global_steps = len( train_loader ) * epoch_num
print( "num_global_steps:", num_global_steps )
num_warmup_steps = num_global_steps * 0.1
print( "num_warmup_steps:", num_warmup_steps )
#スケジューラーの定義
scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps, num_global_steps )
eps = 1e-4

num_global_steps: 22150
num_warmup_steps: 2215.0


In [13]:
# WarmupとCosine Decayを行うスケジューラを利用
#scheduler = CosineLRScheduler(
#    optimizer, t_initial=epoch_num, lr_min=1e-1,
#    warmup_t=5, warmup_lr_init=5e-2, warmup_prefix=True)

#tr_print_coef = 64000
#tr_save_coef = 1000
#val_print_coef = 240
len_tr_loader = len( train_loader )
len_val_loader = len( val_loader )
tr_print_coef = len_tr_loader // 10
tr_save_coef = len_tr_loader // 100
#tr_print_coef = len_val_loader // 10
#tr_save_coef = len_val_loader // 10
val_print_coef = len_val_loader // 3
print( "len( train_loader ):", len_tr_loader )
print( "len( val_loader ):", len_val_loader )
print( "tr_print_coef:", tr_print_coef )
print( "tr_save_coef:", tr_save_coef )
print( "val_print_coef:", val_print_coef )
#tr_print_coef = 1
#tr_save_coef = 1
#val_print_coef = 1
#train_length = len(train_loader)
#train_int = train_length // tr_print_coef
#print( train_int )
#print( train_length )
#val_length = len(val_loader)
#val_int = val_length // val_print_coef
#print( val_int )
#print( val_length )

history = {"len_tr_loader":[],"len_val_loader":[], "train_loss":[], "val_loss": [], "train_wer": [], "val_wer": [], "train_bleu": [], "val_bleu": [] }
history["len_tr_loader"].append( len_tr_loader )
history["len_val_loader"].append( len_val_loader )
with open("Diag_Mask6.pkl", "wb") as f:
    pickle.dump( history, f )      
n = 0
train_loss = 0
val_loss = 0


# 学習率の減衰やEarly stoppingの
# 判定を開始するエポック数
# (= 最低限このエポックまではどれだけ
# validation結果が悪くても学習を続ける)
lr_decay_start_epoch = 5

# 学習率を減衰する割合
# (減衰後学習率 <- 現在の学習率*lr_decay_factor)
# 1.0以上なら，減衰させない
lr_decay_factor = 0.5

# Early stoppingの閾値
# 最低損失値を更新しない場合が
# 何エポック続けば学習を打ち切るか
early_stop_threshold = 3

# 最も低い損失値，
# そのときのモデルとエポック数を記憶しておく
best_loss = -1
best_model = None
best_epoch = 0

# Early stoppingフラグ．Trueになると学習を打ち切る
early_stop_flag = False
# Early stopping判定用(損失値の最低値が
# 更新されないエポックが何回続いているか)のカウンタ
counter_for_early_stop = 0

fn = bleu_score.SmoothingFunction().method7

# AMP用のスケーラー
scaler = GradScaler(enabled=use_amp)

for epoch in range(epoch_num):
    # early stopフラグが立っている場合は，
    # 学習を打ち切る
    if early_stop_flag:
        print('    Early stopping.'\
            ' (early_stop_threshold = %d)' \
            % (early_stop_threshold))
        #log_file.write('\n    Early stopping.'\
        #        ' (early_stop_threshold = %d)' \
        #        % (early_stop_threshold))
        break

    with tqdm(train_loader) as pbar:
    #with tqdm(val_loader) as pbar:
        pbar.set_description(f'[Train エポック {epoch + 1}]')
    
        model.train()
        #scheduler.step(epoch)
        #print( "Train")
        train_loss = 0
        mean_error = 0
        mean_bleu = 0
        n3 = 0
        for i, ( text, target, text_len, target_len ) in enumerate( pbar ):
            optimizer.zero_grad()
            text = text.to(device)
            target = target.to(device).long()
            
            with autocast(str(device),enabled=use_amp):
                outputs = model( text )
                # 損失の計算
                loss = criterion(outputs.transpose(1,2) + eps, target)

            preds = torch.argmax( outputs, dim = 2 )
            #print( "preds:",preds[0])
            #print( "text:",text[0])
            #print( "target:", target[0] )
                
            # 誤差逆伝播
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            clip_grad_threshold = 5.0
            torch.nn.utils.clip_grad_norm_(\
                    model.parameters(),
                    clip_grad_threshold)
            # オプティマイザにより，パラメータを更新する
            scaler.step(optimizer)
            scaler.update()            
            
            scheduler.step()
            
            #lr = optimizer.param_groups[0]['lr']
            #print( "lr:", lr )

            total_error = 0
            total_token_length = 0
            total_bleu = 0
            n2 = 0       
            hypo_sentence = []
            ref_sentence = []
            for hypo_id, caption in zip( preds, target ):
                hypo = tokenizer.decode( hypo_id, skip_special_tokens = True )
                #print( "hypo_id:", hypo_id )
                hypo_list = tokenizer.tokenize( hypo )
                reference = tokenizer.decode( caption, skip_special_tokens = True )
                ref_list =  tokenizer.tokenize( reference )
                #print( "hypo:", hypo )
                #print( "hypo_list:", hypo_list )
                #print( "reference:", reference )
                #print( "ref_list:", ref_list )
                
                # 認識誤りを計算
                #print( "hypo_list:", hypo_list)
                #print( "ref_list:", ref_list )
                (error, substitute, 
                    delete, insert, ref_length) = \
                    levenshtein.calculate_error(hypo_list,ref_list)

                #print( "error:", error, "ref_length:", ref_length )
                bleu = bleu_score.sentence_bleu( [reference], hypo, smoothing_function=fn )
        
                total_bleu += bleu
    
                total_error += error
                total_token_length += ref_length

                #if n2 < 2 and ( i % tr_print_coef == tr_print_coef -1  or  i == len( train_loader ) -1 ) :
                if n2 < 2 and ( i % tr_print_coef == tr_print_coef -1 ) :
                    hypo_sentence.append( hypo )
                    ref_sentence.append( reference )
                    
                n2 += 1
                
            train_loss += loss.item()
            history["train_loss"].append( loss.item() )
            history["train_wer"].append( total_error/total_token_length * 100)
            mean_error += total_error/total_token_length * 100
            history["train_bleu"].append( total_bleu/ n2 * 100 )
            mean_bleu += total_bleu/ n2 * 100
            n3 += 1
            if i % tr_save_coef == tr_save_coef - 1:
                with open("Diag_Mask6.pkl", "wb") as f:
                    pickle.dump( history, f )            
            if i % tr_print_coef == tr_print_coef - 1:
                lr = optimizer.param_groups[0]['lr']
                #print(f"Train epoch:{epoch+1}  index:{i+1} loss:{train_loss/n3:.10f} WER:{mean_error / n3:.10f} BLEU:{mean_bleu / n3:.10f } lr:{lr}")
                print(f"Train epoch:{epoch+1}  index:{i+1}  loss:{train_loss/n3:.10f}   WER:{mean_error / n3:.10f} BLEU:{mean_bleu / n3 } lr:{lr:.10f}")
            #if i == len( train_loader ) - 1:
            for (hypo_s, refe_s ) in zip( hypo_sentence, ref_sentence ):
                print( "index:", i+1, "target:", refe_s)
                print( "index:", i+1, "hypo  :", hypo_s )
            pbar.set_postfix({
                    'loss': train_loss / n3,
                    'WER': mean_error / n3,
                    'BLEU': mean_bleu / n3
            })

    with tqdm(val_loader) as pbar:
        pbar.set_description(f'[検証]')
        model.eval()
        #for i, ( text, target, text_len, target_len ) in enumerate(val_loader):
        val_loss = 0
        mean_error = 0
        mean_bleu = 0
        n3 = 0
        for i, ( text, target, text_len, target_len ) in enumerate( pbar ):
            text = text.to(device)
            target = target.to(device).long()

            with torch.no_grad():
                outputs = model(text)
                preds = torch.argmax( outputs, dim = 2 )
                loss = criterion( outputs.transpose(1, 2) + eps, target )
           
            total_error = 0
            total_token_length = 0
            total_bleu = 0
            n2 = 0
            hypo_sentence = []
            ref_sentence = []
            for hypo_id, caption in zip( preds, target ):
                #hypo = tokenizer.decode( hypo_id )
                hypo = tokenizer.decode( hypo_id, skip_special_tokens = True )
                #print( "hypo_id:", hypo_id )
                hypo_list = tokenizer.tokenize( hypo )
                #reference = tokenizer.decode( caption )
                reference = tokenizer.decode( caption, skip_special_tokens = True )
                ref_list =  tokenizer.tokenize( reference )
                
                # 認識誤りを計算
                (error, substitute, 
                    delete, insert, ref_length) = \
                    levenshtein.calculate_error(hypo_list,ref_list)

                bleu = bleu_score.sentence_bleu( [reference], hypo, smoothing_function=fn )
        
                total_bleu += bleu
            
                total_error += error
                total_token_length += ref_length

                if n2 < 2 and ( i % val_print_coef == val_print_coef -1 ) :
                    hypo_sentence.append( hypo )
                    ref_sentence.append( reference )
                    
                n2 += 1
        
            val_loss += loss.item()
            history["val_loss"].append( loss.item() )
            history["val_wer"].append( total_error/total_token_length * 100)
            mean_error += total_error/total_token_length * 100
            history["val_bleu"].append( total_bleu / n2 * 100 )
            mean_bleu += total_bleu / n2 * 100
            n3 += 1
            if i % val_print_coef == val_print_coef - 1:
                lr = optimizer.param_groups[0]['lr']
                print(f"Val epoch:{epoch+1}  index:{i+1}  loss:{val_loss/n3:.10f}   WER:{mean_error / n3:.10f} BLEU:{mean_bleu / n3 } lr:{lr:.10f}")
                PATH = './Diag_Mask6_curr.pt'
                torch.save({'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss,},
                     PATH)
                with open("Diag_Mask6.pkl", "wb") as f:
                    pickle.dump( history, f )
            for (hypo_s, refe_s ) in zip( hypo_sentence, ref_sentence ):
                print( "index:", i+1, "target:", refe_s)
                print( "index:", i+1, "hypo  :", hypo_s )
            pbar.set_postfix({
                    'loss': val_loss / n3,
                    'WER': mean_error/ n3,
                    'BLEU': mean_bleu / n3
                })

    epoch_loss = val_loss/n3
    if epoch == 0 or best_loss > epoch_loss:
        # 損失値が最低値を更新した場合は，
        # その時のモデルを保存する
        best_loss = epoch_loss
        torch.save(model.state_dict(), 
                    './best_model_Mask6.pt')
        best_epoch = epoch
        # Early stopping判定用の
        # カウンタをリセットする
        counter_for_early_stop = 0
    else:
        # 最低値を更新しておらず，
        if epoch+1 >= lr_decay_start_epoch:
            # かつlr_decay_start_epoch以上の
            # エポックに達している場合
            if counter_for_early_stop+1 \
                    >= early_stop_threshold:
                # 更新していないエポックが，
                # 閾値回数以上続いている場合，
                # Early stopping フラグを立てる
                early_stop_flag = True
            else:
                # Early stopping条件に
                # 達していない場合は
                # 学習率を減衰させて学習続行
                if lr_decay_factor < 1.0:
                    for i, param_group \
                            in enumerate(\
                            optimizer.param_groups):
                        if i == 0:
                            lr = param_group['lr']
                            dlr = lr_decay_factor \
                                * lr
                            print('    (Decay '\
                                'learning rate:'\
                                ' %f -> %f)' \
                                % (lr, dlr))
                        param_group['lr'] = dlr
                # Early stopping判定用の
                # カウンタを増やす
                counter_for_early_stop += 1
#torch.cuda.synchronize()    

len( train_loader ): 2215
len( val_loader ): 247
tr_print_coef: 221
tr_save_coef: 22
val_print_coef: 82


  0%|          | 0/2215 [00:00<?, ?it/s]

Train epoch:1  index:221  loss:7.3673456848   WER:174.0606712955 BLEU:7.681044542170696 lr:0.0000099774
index: 221 target: this report would undermine rather improve race.
index: 221 hypo  : 
index: 221 target: there have been times within parliament when on social security were actually quite awkward.
index: 221 hypo  : 
Train epoch:1  index:442  loss:5.4434230433   WER:133.9404583211 BLEU:3.853063383159605 lr:0.0000199549
index: 442 target: mr president - in - of the council, firstly i would like to thank the - in - office of the council for their high level presence this house and for their positive attitude.
index: 442 hypo  : in the, to the of the and.
index: 442 target: we look forward a fruitful dialogue with you on this.
index: 442 hypo  : to.
Train epoch:1  index:663  loss:4.3972555045   WER:113.3481861828 BLEU:3.186496846086864 lr:0.0000299323
index: 663 target: calls for design improvements in hull construction, especially double - hulled vessels, are sensible but take time 

  0%|          | 0/247 [00:00<?, ?it/s]

Val epoch:1  index:82  loss:2.1155501575   WER:43.8017875741 BLEU:35.31954036322815 lr:0.0001000000
index: 82 target: on the one disguise wantlargement while, on the hand, they are not prepared pay for it.
index: 82 hypo  : on the one they wantgement, on the, are not to for it.
index: 82 target: they cannot have it both ways.
index: 82 hypo  : they have it.
Val epoch:1  index:164  loss:2.0926694303   WER:43.7337955490 BLEU:34.936491130466834 lr:0.0001000000
index: 164 target: according to the information we have received from luxembourg authorities, the centre has been created within context of a cooperation between council of europe and the luxembourg authorities.
index: 164 hypo  : according to the we have from the, the has been the of a cooperation agreement the of europe and the.
index: 164 target: of course commission takes a favourable als of this because, given the importance of the european cultural network, the centre can make a positive contribution to improving the link betw

  0%|          | 0/2215 [00:00<?, ?it/s]

Train epoch:2  index:221  loss:0.7417906159   WER:29.4334871154 BLEU:72.33693187141314 lr:0.0000988914
index: 221 target: that have been made will not improve the status成 culture in europe, on the contrary bismarck
index: 221 hypo  : the that have been made will not improve the of culture in europe, on the contrary.
index: 221 target: mr president, the chief aim of an internal electricity market must be play.
index: 221 hypo  : mr president, the aim of an internal market must be fair play.
Train epoch:2  index:442  loss:0.7082203149   WER:28.5661682390 BLEU:74.36016350109264 lr:0.0000977828
index: 442 target: any event, for members not aware of the president s position, i will read the official communication and it thereby appear today's minutes :
index: 442 hypo  : in any event, for those members who not aware of the president's position, i will read the communication and it will in today's minutes :
index: 442 target: it concerns the implications of dam on access to fresh water in th

  0%|          | 0/247 [00:00<?, ?it/s]

Val epoch:2  index:82  loss:0.6445684033   WER:20.1114158669 BLEU:92.02405772880961 lr:0.0000888889
index: 82 target: on the one disguise wantlargement while, on the hand, they are not prepared pay for it.
index: 82 hypo  : on the one hand they want enlargement while, on the other hand, they are not prepared to pay for it.
index: 82 target: they cannot have it both ways.
index: 82 hypo  : they cannot have it both ways.
Val epoch:2  index:164  loss:0.6259904391   WER:19.6050034364 BLEU:92.37643963545725 lr:0.0000888889
index: 164 target: according to the information we have received from luxembourg authorities, the centre has been created within context of a cooperation between council of europe and the luxembourg authorities.
index: 164 hypo  : according to the information we have received from the luxembourg authorities, the centre has been created within the context of a cooperation agreement between the council of europe and the luxembourg authorities.
index: 164 target: of course c

  0%|          | 0/2215 [00:00<?, ?it/s]

Train epoch:3  index:221  loss:0.4099312987   WER:20.7469723429 BLEU:90.05559834867762 lr:0.0000877803
index: 221 target: we must reject what unacceptable, as our constituents demand.
index: 221 hypo  : we must reject what is unacceptable, as our demand.
index: 221 target: these recommendations, which, for the most part, are consistent with the guidelines in the portuguese presidency's work programme which will therefore be at the lisbon summit, have been rejected or distorted by the committee on economic and monetary affairs of this parliament, by abians which only managed to a fragile unity in its rejection of the idea of a european union which is on convergence national policies and just on the grouping together of markets.
index: 221 hypo  : these recommendations, which, for the most part, are consistent with the guidelines in the portuguese presidency's work programme and which will therefore be discussed at the lisbon, have been rejected or by the committee on economic and moneta

  0%|          | 0/247 [00:00<?, ?it/s]

Val epoch:3  index:82  loss:0.6134729956   WER:18.8497469398 BLEU:93.49357466083109 lr:0.0000777778
index: 82 target: on the one disguise wantlargement while, on the hand, they are not prepared pay for it.
index: 82 hypo  : on the one hand they want enlargement while, on the other hand, they are not prepared to pay for it.
index: 82 target: they cannot have it both ways.
index: 82 hypo  : they cannot have it both ways.
Val epoch:3  index:164  loss:0.5966466772   WER:18.2687025468 BLEU:93.78947281396239 lr:0.0000777778
index: 164 target: according to the information we have received from luxembourg authorities, the centre has been created within context of a cooperation between council of europe and the luxembourg authorities.
index: 164 hypo  : according to the information we have received from the luxembourg authorities, the centre has been created within the context of a cooperation agreement between the council of europe and the luxembourg authorities.
index: 164 target: of course c

  0%|          | 0/2215 [00:00<?, ?it/s]

Train epoch:4  index:221  loss:0.3312160546   WER:18.9054527382 BLEU:92.79788652572006 lr:0.0000766692
index: 221 target: grain legumes
index: 221 hypo  : grain legumes
index: 221 target: this relates in particular to the intended agenda, regards the amsterdam leftovers, which suggests that unacceptable inner cabinets will be created in the future. also areas relating the second third, which tend to be reflected in an undirable militaation of the union
index: 221 hypo  : this relates in particular to the intended agenda, as regards the amsterdam left, which suggests that unacceptable will be created in the future. this also areas relating to the second and third pillars, which tend to be reflected in an undesirablerisation of the european union.
Train epoch:4  index:442  loss:0.3308575180   WER:18.8313026892 BLEU:92.93386689663643 lr:0.0000755606
index: 442 target: for its part, the commission published an initial on trafficking in women in, that is say the year immediately after beiji

  0%|          | 0/247 [00:00<?, ?it/s]

Val epoch:4  index:82  loss:0.5830738254   WER:17.9110142413 BLEU:94.28892827422834 lr:0.0000666667
index: 82 target: on the one disguise wantlargement while, on the hand, they are not prepared pay for it.
index: 82 hypo  : on the one hand they want enlargement while, on the other hand, they are not prepared to pay for it.
index: 82 target: they cannot have it both ways.
index: 82 hypo  : they cannot have it both ways.
Val epoch:4  index:164  loss:0.5655758230   WER:17.2995335002 BLEU:94.73434822332987 lr:0.0000666667
index: 164 target: according to the information we have received from luxembourg authorities, the centre has been created within context of a cooperation between council of europe and the luxembourg authorities.
index: 164 hypo  : according to the information we have received from the luxembourg authorities, the centre has been created within the context of a cooperation agreement between the council of europe and the luxembourg authorities.
index: 164 target: of course c

  0%|          | 0/2215 [00:00<?, ?it/s]

Train epoch:5  index:221  loss:0.2897276276   WER:17.7825900472 BLEU:94.20799614259661 lr:0.0000655581
index: 221 target: and that is why we to the reached yesterday. that why we sustain that conclusion we continue to the values and the law
index: 221 hypo  : and that is why we came to the conclusion we reached yesterday morning. that is why we that conclusion as we continue to uphold the values and the law.
index: 221 target: it should be noted the regulation on social security systems does at present refer to the question of early retirement and the commission has proposed certain amendments to the regulation which are, however, still before the council.
index: 221 hypo  : it should be noted that the current regulation on social security systems does not at present refer to the question of early retirement and the commission has proposed certain amendments to the regulation which are, however, still pending before the council.
Train epoch:5  index:442  loss:0.2917003923   WER:17.7726

  0%|          | 0/247 [00:00<?, ?it/s]

Val epoch:5  index:82  loss:0.5955825317   WER:17.4400978587 BLEU:94.63114509576315 lr:0.0000555556
index: 82 target: on the one disguise wantlargement while, on the hand, they are not prepared pay for it.
index: 82 hypo  : on the one hand they want enlargement while, on the other hand, they are not prepared to pay for it.
index: 82 target: they cannot have it both ways.
index: 82 hypo  : they cannot have it both ways.
Val epoch:5  index:164  loss:0.5776998557   WER:16.9045599445 BLEU:95.05066067518385 lr:0.0000555556
index: 164 target: according to the information we have received from luxembourg authorities, the centre has been created within context of a cooperation between council of europe and the luxembourg authorities.
index: 164 hypo  : according to the information we have received from the luxembourg authorities, the centre has been created within the context of a cooperation agreement between the council of europe and the luxembourg authorities.
index: 164 target: of course c

  0%|          | 0/2215 [00:00<?, ?it/s]

Train epoch:6  index:221  loss:0.2729600422   WER:17.2049862760 BLEU:94.30167737475838 lr:0.0000544470
index: 221 target: they could meet a regular basis to the promotion of human rights in short and long term.
index: 221 hypo  : they could meet on a regular basis to the promotion of human rights in the short and long term.
index: 221 target: i have 194ed secure the wide support for the report.
index: 221 hypo  : i have endeavoured to secure the widest possible support for the report.
Train epoch:6  index:442  loss:0.2743794773   WER:17.3297545791 BLEU:94.39846352414229 lr:0.0000533383
index: 442 target: among others, it calling for equal rights of all matthews between same - sex couples and traditional family, together lowering the age consent for same - relations.
index: 442 hypo  : among others, it is calling for equal rights of all kinds between same - sex couples and the traditional family, together with lowering the age of consent for same - sex relations.
index: 442 target: the 

  0%|          | 0/247 [00:00<?, ?it/s]

Val epoch:6  index:82  loss:0.5683985992   WER:17.3390529586 BLEU:94.77177873800309 lr:0.0000444444
index: 82 target: on the one disguise wantlargement while, on the hand, they are not prepared pay for it.
index: 82 hypo  : on the one hand they want enlargement while, on the other hand, they are not prepared to pay for it.
index: 82 target: they cannot have it both ways.
index: 82 hypo  : they cannot have it both ways.
Val epoch:6  index:164  loss:0.5509899739   WER:16.7860227006 BLEU:95.1737700545365 lr:0.0000444444
index: 164 target: according to the information we have received from luxembourg authorities, the centre has been created within context of a cooperation between council of europe and the luxembourg authorities.
index: 164 hypo  : according to the information we have received from the luxembourg authorities, the centre has been created within the context of a cooperation agreement between the council of europe and the luxembourg authorities.
index: 164 target: of course co

  0%|          | 0/2215 [00:00<?, ?it/s]

Train epoch:7  index:221  loss:0.2546310193   WER:16.6106157180 BLEU:94.85359660680466 lr:0.0000433358
index: 221 target: during recent trip by the president european parliament to the middle east, which you mentioned, the various conflicting parties were agreement about asking for a active presence from the european union in the on - peace process the region.
index: 221 hypo  : during the recent trip by the president of the european parliament to the middle east, which you mentioned, the various conflicting parties were in agreement about asking for a more active presence from the european union in the on - going peace process in the region.
index: 221 target: the 2000 - 2006 donnie provision has been made, unfortunately, only a total contribution of eur 2 020 million, and an indicative allocation for each member has been set.
index: 221 hypo  : for the 2000 - 2006 provision has been made, unfortunately, only for a total contribution of eur 2 020 million, and an indicative allocation 

  0%|          | 0/247 [00:00<?, ?it/s]

Val epoch:7  index:82  loss:0.5787460033   WER:17.4711149569 BLEU:94.6900360738293 lr:0.0000333333
index: 82 target: on the one disguise wantlargement while, on the hand, they are not prepared pay for it.
index: 82 hypo  : on the one hand they want enlargement while, on the other hand, they are not prepared to pay for it.
index: 82 target: they cannot have it both ways.
index: 82 hypo  : they cannot have it both ways.
Val epoch:7  index:164  loss:0.5617745019   WER:16.9267892758 BLEU:95.10615613164333 lr:0.0000333333
index: 164 target: according to the information we have received from luxembourg authorities, the centre has been created within context of a cooperation between council of europe and the luxembourg authorities.
index: 164 hypo  : according to the information we have received from the luxembourg authorities, the centre has been created within the context of a cooperation agreement between the council of europe and the luxembourg authorities.
index: 164 target: of course co

  0%|          | 0/2215 [00:00<?, ?it/s]

Train epoch:8  index:221  loss:0.2484473283   WER:16.5460591462 BLEU:94.69418724058703 lr:0.0000322247
index: 221 target: i that measures that are proportionate to that objective is the way forward.
index: 221 hypo  : i believe that measures that are proportionate to that objective is the way forward.
index: 221 target: proposal for joint resolution rights
index: 221 hypo  : proposal for a joint resolution on human rights
Train epoch:8  index:442  loss:0.2468857727   WER:16.5927992328 BLEU:94.84163493448595 lr:0.0000311161
index: 442 target: consequently, my group support the rapporteur's and, above all, will particularly support no 2, which proposes the inclusion areas dedicated to the production of seed vetches in the aid system, as this will lead high - quality crops.
index: 442 hypo  : consequently, my group will support the rapporteur's report and, above all, particularly support amendment no 2, which proposes the inclusion of areas dedicated to the production of seed vetches in t

  0%|          | 0/247 [00:00<?, ?it/s]

Val epoch:8  index:82  loss:0.5721660155   WER:17.2800899777 BLEU:94.87665945319785 lr:0.0000222222
index: 82 target: on the one disguise wantlargement while, on the hand, they are not prepared pay for it.
index: 82 hypo  : on the one hand they want enlargement while, on the other hand, they are not prepared to pay for it.
index: 82 target: they cannot have it both ways.
index: 82 hypo  : they cannot have it both ways.
Val epoch:8  index:164  loss:0.5554680706   WER:16.7722429538 BLEU:95.23327270002187 lr:0.0000222222
index: 164 target: according to the information we have received from luxembourg authorities, the centre has been created within context of a cooperation between council of europe and the luxembourg authorities.
index: 164 hypo  : according to the information we have received from the luxembourg authorities, the centre has been created within the context of a cooperation agreement between the council of europe and the luxembourg authorities.
index: 164 target: of course c

  0%|          | 0/2215 [00:00<?, ?it/s]

Train epoch:9  index:221  loss:0.2398463875   WER:16.4668152866 BLEU:94.63751279653806 lr:0.0000211136
index: 221 target: basically, as as we are concerned, we need to adopt macroeconomic political measures which are based pepsi stimulating demand, public investment onל for improved public. these measures should be geared towards full employment and guarantee rate of employment around 75 % by 2010.
index: 221 hypo  : basically, as far as we are concerned, we need to adopt macroeconomic political measures which are based on stimulating demand, on public investment and on respect for improved public services. these measures should be geared towards full employment and guarantee a rate of employment of around 75 % by 2010.
index: 221 target: after the debacle in seattle, our duty from now on is to spin concerted and ensure that the of world trade develop and are used₤ fight poverty.
index: 221 hypo  : after the debacle in seattle, our duty from now on is to adopt a concerted position and 

  0%|          | 0/247 [00:00<?, ?it/s]

Val epoch:9  index:82  loss:0.5700652283   WER:17.1870050543 BLEU:94.93924801285996 lr:0.0000111111
index: 82 target: on the one disguise wantlargement while, on the hand, they are not prepared pay for it.
index: 82 hypo  : on the one hand they want enlargement while, on the other hand, they are not prepared to pay for it.
index: 82 target: they cannot have it both ways.
index: 82 hypo  : they cannot have it both ways.
Val epoch:9  index:164  loss:0.5532929823   WER:16.7048893404 BLEU:95.27829192758294 lr:0.0000111111
index: 164 target: according to the information we have received from luxembourg authorities, the centre has been created within context of a cooperation between council of europe and the luxembourg authorities.
index: 164 hypo  : according to the information we have received from the luxembourg authorities, the centre has been created within the context of a cooperation agreement between the council of europe and the luxembourg authorities.
index: 164 target: of course c

In [None]:
print( torch.__version__ )