### ライブラリの準備

### モジュールのインポートとGoogleドライブのマウント

In [None]:
import os
import torch.autograd
import math
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import datetime
#from tqdm import tqdm
from tqdm.notebook import tqdm
import pickle
import copy
import gc
import random
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from PIL import Image
import skimage.transform
from collections import deque
from typing import Sequence, Dict, Tuple, Union, Optional
#from argparse import Namespace
from dataclasses import dataclass, field
from sacrebleu.metrics import BLEU
from evaluate import load
import jiwer
#from comet import download_model, load_from_checkpoint

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import models
import torchvision.transforms as T
import torchvision.datasets as dataset
from torchvision.transforms import v2

#from timm.scheduler import CosineLRScheduler
from transformers import  get_linear_schedule_with_warmup

from transformers import BertTokenizer, BertModel, CLIPVisionModel, BertForPreTraining

import sys
from evaluate import load

import util
from torchmetrics.multimodal import CLIPScore
from torch.amp import autocast, GradScaler
from collections import OrderedDict
from rouge_score import rouge_scorer
from pycocoevalcap.cider.cider import Cider
import json
import collections
from collections import Counter
import plotly
import optuna
import torch.optim as optim
import numpy as np
from optuna.storages import JournalStorage
from optuna.storages.journal import JournalFileBackend
import time
from torch.profiler import profile, record_function, ProfilerActivity
from multiprocessing import Pool

### 位置エンコーディング


In [None]:
class PositionalEmbedding(nn.Module):
    """Learned positional embedding table.

    Args:
        dim_embedding: Size of each positional embedding vector.
        max_len: Number of rows in the table, i.e. the longest sequence
            length that can be looked up.
    """

    def __init__(self, dim_embedding: int, max_len: int=2048):
        super().__init__()

        # One trainable vector per position index in [0, max_len).
        self.pos_emb = nn.Embedding(max_len, dim_embedding)

    def forward(self, x: torch.Tensor):
        """Return the embeddings of positions ``0 .. x.shape[1] - 1``.

        Only the size of dimension 1 and the device of ``x`` are read; the
        values of ``x`` are not used and nothing is added to ``x`` here.

        Args:
            x: Tensor whose second dimension is the sequence length.

        Returns:
            Tensor of shape ``[x.shape[1], dim_embedding]``.
        """
        num_positions = x.shape[1]
        position_ids = torch.arange(num_positions, dtype=torch.long, device=x.device)
        return self.pos_emb(position_ids)

### 位置エンコーディングの実装

### Transformerデコーダの実装

### CaptioningTransformerの実装

In [None]:
# Pretrained BERT vocabulary shared by the reward code below.
model_id = "google-bert/bert-large-uncased"
tokenizer = BertTokenizer.from_pretrained(model_id)

# Special-token ids exposed directly by the tokenizer.
pad_token_id = tokenizer.pad_token_id
cls_token_id = tokenizer.cls_token_id
sep_token_id = tokenizer.sep_token_id


def _encoded_id_at_1(text):
    """Encode ``text`` with the module tokenizer and return the id at index 1.

    Index 0 is presumably the [CLS] token that ``encode`` prepends, which
    makes index 1 the first token of ``text`` -- TODO confirm.
    """
    return tokenizer.encode(text)[1]


# The reserved "[unused0]" / "[unused1]" vocabulary entries act as the
# start-of-sequence and end-of-sequence markers. They are passed as
# one-element lists.
sos_token_id = _encoded_id_at_1( [ "[unused0]" ] )
eos_token_id = _encoded_id_at_1( [ "[unused1]" ] )

# Ids of frequent function words and punctuation marks.
a_token_id = _encoded_id_at_1( "a" )
the_token_id = _encoded_id_at_1( "the" )
and_token_id = _encoded_id_at_1( "and" )
in_token_id = _encoded_id_at_1( "in" )
we_token_id = _encoded_id_at_1( "we" )
i_token_id = _encoded_id_at_1( "i" )
he_token_id = _encoded_id_at_1( "he" )
she_token_id = _encoded_id_at_1( "she" )
it_token_id = _encoded_id_at_1( "it" )
they_token_id = _encoded_id_at_1( "they" )
period_token_id = _encoded_id_at_1( "." )
comma_token_id = _encoded_id_at_1( "," )
dbl_token_id = _encoded_id_at_1( '"' )
sgl_token_id = _encoded_id_at_1( "'" )

# Keep the vocabulary size for later use.
vocab_size = len( tokenizer )

class ComputeReward(nn.Module):
    def __init__(self, reward_t = 'ordinary', decode_t = 'ordinary', device="cpu", 
                 repeat_thresh = (3,2,2,2,2), repeat_weight = (1, 1, 1, 1, 1), cider_coef = 1.0, rouge_coef = 1.0, clip_coef = 1.0, 
                 bert_coef = 1.0, use_amp = True ):
        """Create the metric objects and store the reward configuration.

        Args:
            reward_t: Name of the reward combination evaluated by ``forward``.
            decode_t: How token ids are turned into text in
                ``_compute_reward_ord``.
            device: Device the CLIP metric is moved to.
            repeat_thresh: Per n-gram-order count thresholds read by
                ``calc_ngram_repeat_fast``.
            repeat_weight: Per n-gram-order weights read by
                ``calc_ngram_repeat_fast``.
            cider_coef, rouge_coef, clip_coef, bert_coef: Weights of the four
                caption metrics in the combined reward.
            use_amp: Whether the CLIP/BERTScore section runs under autocast.
        """
        super().__init__()

        # Plain configuration values.
        self.reward_t = reward_t
        self.decode_t = decode_t
        self.repeat_thresh = repeat_thresh
        self.repeat_weight = repeat_weight
        self.cider_coef = cider_coef
        self.rouge_coef = rouge_coef
        self.clip_coef = clip_coef
        self.bert_coef = bert_coef
        self.use_amp = use_amp
        self.tgt_lang = "en"
        self.device = device

        # Module-level tokenizer is reused for decoding.
        self.tokenizer = tokenizer

        # Text metrics.
        self.scorer = Cider()
        self.rougeL = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
        self.bert = load('bertscore' )

        # CLIP image-text metric; its weights are frozen.
        self.metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch32").to(self.device)
        for frozen in self.metric.parameters():
            frozen.requires_grad = False
    
    def _compute_reward_ord(self, preds, targets, imgs2, sources=None):
        """Score generated captions with CIDEr, ROUGE-L, CLIP and BERTScore.

        Each metric yields one value per caption, which is broadcast along the
        sequence axis so every term has shape ``[batch, preds.shape[1]]``.

        Args:
            preds: Generated token ids, ``[batch, seq_len]``.
            targets: Reference token ids, one reference per prediction.
            imgs2: Images handed to the CLIP processor. Resizing and rescaling
                are switched off there, so they are presumably preprocessed by
                the caller -- TODO confirm.
            sources: Accepted for interface compatibility; not used.

        Returns:
            Tuple ``(reward, reward2)``: the coefficient-weighted sum of the
            four metric terms and their unweighted sum.
        """
        model_name = "distilbert-base-uncased"

        # Caption text for CLIP: eos is mapped to pad and special tokens are
        # dropped. It is built for every decode mode; previously only the
        # 'no-pad' branch defined it, so the other modes raised NameError.
        clip_ids = preds.clone().to(self.device)
        clip_ids[clip_ids == eos_token_id] = pad_token_id
        clip_text = self.tokenizer.batch_decode(clip_ids, skip_special_tokens=True)

        # detokenize (convert to str) preds & targets
        if self.decode_t == 'no-endoftext':
            # NOTE(review): `endoftext_token_id` is not defined in this part of
            # the file -- confirm it exists before using this mode. The token
            # at index 0 is skipped and runs of the end-of-text id collapse.
            preds_str = [self.tokenizer.decode(
                [pred[i] for i in range( 1,  len( pred )  ) if not (pred[i-1] == endoftext_token_id and pred[i] == endoftext_token_id) ]
                ) for pred in preds]
            targets_str = [self.tokenizer.decode(
                [target[i] for i in range( 1,  len( target )  ) if not (target[i-1] == endoftext_token_id and target[i] == endoftext_token_id) ]
                ) for target in targets]
        elif self.decode_t == 'no-pad':
            decoded = self.tokenizer.batch_decode(clip_ids, skip_special_tokens=False)
            preds_str = [s.replace("[PAD]", "").replace("[unused1]", "").strip() for s in decoded]
            target_ids = targets.clone().to(self.device)
            target_ids[target_ids == eos_token_id] = pad_token_id
            decoded = self.tokenizer.batch_decode(target_ids, skip_special_tokens=False)
            targets_str = [s.replace("[PAD]", "").replace("[unused1]", "").strip() for s in decoded]
        else:
            preds_str = [self.tokenizer.decode(pred) for pred in preds]
            targets_str = [self.tokenizer.decode(target) for target in targets]

        seq_len = preds.shape[1]

        # CIDEr expects {id: [sentence]} dictionaries.
        pred_dict = { str(i): [item] for i, item in enumerate( preds_str)}
        target_dict = { str(i): [item] for i, item in enumerate( targets_str)}
        score, scores = self.scorer.compute_score(target_dict, pred_dict)
        reward_cider = torch.tensor( scores ).to( self.device )[:,None].expand( -1, seq_len )

        # Element 0 of the ROUGE-L score tuple is used.
        reward_rouge = [[self.rougeL.score(target, pred)['rougeL'][0]]  * seq_len for pred, target in zip(preds_str, targets_str)]
        reward_rouge = torch.tensor( reward_rouge ).to( self.device )

        # autocast takes a device *type* ("cuda"/"cpu"), not "cuda:0".
        amp_device_type = torch.device(self.device).type
        with autocast(amp_device_type, enabled=self.use_amp):
            with torch.no_grad():
                processed = self.metric.processor(text=clip_text, images=imgs2, return_tensors="pt", padding=True, \
                                                  truncation=True, max_length=77, do_resize=False, do_rescale=False ).to(self.device)
                outputs = self.metric.model(**processed)
                # L2-normalise both embeddings so the dot product is a cosine.
                image_features = outputs.image_embeds / outputs.image_embeds.norm(p=2, dim=-1, keepdim=True)
                text_features = outputs.text_embeds / outputs.text_embeds.norm(p=2, dim=-1, keepdim=True)
                individual_scores = torch.clamp( (image_features.to(self.device) * \
                                                  text_features.to(self.device)).sum(axis=-1), min=0)
                reward_clip = individual_scores[:,None].expand( -1, seq_len ) / 100.0
                # NOTE(review): `config` is a module-level object that is not
                # defined in this part of the file.
                bert_scores = self.bert.compute(
                    predictions=preds_str, 
                    references=targets_str,
                    model_type=model_name,
                    use_fast_tokenizer=True, 
                    lang='en', 
                    device=self.device,
                    batch_size=config.batch_size * config.num_samples,
                    rescale_with_baseline=False
                )['f1']
                reward_bert = torch.tensor( bert_scores ).to( self.device )[:,None].expand( -1, seq_len )

        reward = self.cider_coef * reward_cider + self.rouge_coef * reward_rouge \
            + self.clip_coef * reward_clip + self.bert_coef * reward_bert
        reward2 = reward_cider + reward_rouge + reward_bert + reward_clip

        return reward, reward2

    def compute_length_reward(self, preds, targets):
        # preds.size(1) を L として取得
        max_len = preds.size(1)
        # 1から始まるインデックスを作成 (長さとして計算するため)
        arange_index = torch.arange(1, max_len + 1, device=self.device).float()
        
        def get_length(tokens):
            # eosの位置を特定
            eos_mask = (tokens == eos_token_id)
            # 最初のeosだけを抽出
            first_eos = (eos_mask.int().cumsum(dim=1) == 1) & eos_mask
            
            # eosがある場合はその位置(1-indexed)を、ない場合は最大長を返す
            lengths = torch.sum(first_eos.float() * arange_index, dim=1)
            # eosが一度も出現しなかった行は 0 になっているので、max_len で埋める
            no_eos_mask = (eos_mask.sum(dim=1) == 0)
            lengths[no_eos_mask] = float(max_len)
            
            # 正規化 (0.0 ~ 1.0)
            return lengths / max_len

        pred_lengths = get_length(preds)
        target_lengths = get_length(targets)
        #
        # MSEの計算 (負の値を報酬とする)
        # reduction='none' なので [batch_size] の形
        reward_lengths = - ((pred_lengths - target_lengths) ** 2)
        
        # [batch_size, 1] にしてから [batch_size, seq_len] に拡張
        reward = reward_lengths.unsqueeze(1).expand(-1, max_len)
        
        return reward

    def unique_ngram_ratio(self, preds):
        bsz, seq_len = preds.size()
        ng = 5
        device = preds.device
    
        # 1. 各シーケンスの有効な長さを特定 (eosまで)
        # cumsumを使って最初のeos以降をマスクする
        eos_mask = (preds == eos_token_id)
        first_eos_idx = (eos_mask.cumsum(dim=1) == 1) & eos_mask
        # eosがない場合はseq_lenとする
        lengths = torch.where(first_eos_idx.any(dim=1), 
                              first_eos_idx.float().argmax(dim=1), 
                              torch.tensor(seq_len, device=device)).unsqueeze(1)
    
        unr_list = []
        
        # n-gramのサイズごとに一括処理
        for n in range(1, ng + 1):
            # unfoldで全n-gramを抽出: (bsz, seq_len - n + 1, n)
            if seq_len < n:
                unr_list.append(torch.zeros(bsz, 1, device=device))
                continue
                
            ngrams = preds.unfold(1, n, 1)
            num_ngrams = ngrams.size(1)
    
            # マスク作成: 有効な長さ（eos前）に含まれるn-gramのみを残す
            # n-gramの開始位置が (length - n + 1) 未満である必要がある
            ngram_indices = torch.arange(num_ngrams, device=device).expand(bsz, -1)
            valid_mask = ngram_indices < (lengths - n + 1)
            
            # 非常に大きい値で無効なn-gramを埋める（uniqueカウントから除外するため）
            # または、バッチを跨いでユニーク判定するためにハッシュ化
            # ここでは各バッチ行ごとにユニーク数を数える必要があるため、
            # 完全なベクトル化には torch.unique の制限（バッチ非対応）を回避する工夫が必要
            
            # 解決策: 各行をユニークなオフセットでシフトして全体でuniqueをとる手法
            # ただし、シンプルさとメモリ効率のため、ngループのみ残すのが現実的です。
            
            # 行ごとのUniqueカウント (Vectorized version of unique per row)
            # 完全にforを消す場合、各n-gramを1つのスカラにパッキング(ハッシュ)して処理
            if n == 1:
                packed = ngrams.squeeze(-1)
            else:
                # 各要素を大きな基数でシフトして足し合わせ、1つの整数にする
                max_val = preds.max() + 1
                coeffs = max_val ** torch.arange(n, device=device)
                packed = (ngrams * coeffs).sum(dim=-1)
    
            # 無効な位置にユニークな（重ならない）負の値を代入
            invalid_fill = -1 - torch.arange(bsz * num_ngrams, device=device).reshape(bsz, num_ngrams)
            packed = torch.where(valid_mask, packed.float(), invalid_fill.float())
    
            # 行ごとにユニーク数をカウント
            # Note: PyTorchのuniqueはバッチ未対応なため、以下が最速の回避策の一つ
            def count_unique_rowwise(t, mask):
                # ソートして隣接要素との差を見ることでユニーク数を算出
                t_sorted, _ = torch.sort(t, dim=1)
                diffs = (t_sorted[:, 1:] != t_sorted[:, :-1]).int()
                # 最初の要素 + 変化があった回数 - 無効値の数(バッチ内の無効分を補正)
                unique_counts = diffs.sum(dim=1) + 1
                # 無効な埋め草（すべてユニークに設定済み）の数を引いて、
                # 有効なものが1つもなければ0にする処理
                invalid_count = (~mask).sum(dim=1)
                return (unique_counts - invalid_count).float()
    
            row_unique = count_unique_rowwise(packed, valid_mask)
            row_total = valid_mask.sum(dim=1).clamp(min=1).float()
            unr_list.append(row_unique / row_total)
    
        # 結果の整形
        unr_tensor = torch.stack(unr_list, dim=1) # (bsz, ng)
        return torch.mean(unr_tensor, dim=1)[:, None].expand(-1, seq_len)

    def calc_ngram_repeat_fast(self, preds):
        # preds が書き換わらないよう、この関数内では元の値を保護する
        bsz, seq_len = preds.size()
        ngram_cnt = torch.zeros(bsz, device=preds.device, dtype=torch.float)
        
        base_ignore = [pad_token_id, eos_token_id, cls_token_id, sep_token_id]
        extra_ignore = [a_token_id, the_token_id, period_token_id, comma_token_id, and_token_id, in_token_id]
    
        for n in range(1, 5):
            if seq_len < n:
                continue
    
            current_ignore_ids = base_ignore + (extra_ignore if n == 1 else [])
            # ignore_mask は bool なので preds に影響しません
            ignore_mask = torch.isin(preds, torch.tensor(current_ignore_ids, device=preds.device))
            
            # 【重要】unfoldの後に .clone() を入れてメモリを切り離す
            ngrams = preds.unfold(dimension=1, size=n, step=1).clone()
            num_ngrams = ngrams.size(1)
            
            ngram_ignore_mask = ignore_mask.unfold(dimension=1, size=n, step=1).any(dim=-1)
            
            if n > 1:
                # 語彙サイズに基づくハッシュ化
                vocab_size_max = max(preds.max().item(), 100000)
                weights = torch.pow(torch.tensor([vocab_size_max], device=preds.device), 
                                    torch.arange(n, device=preds.device)).long()
                hashed_ngrams = (ngrams.long() * weights).sum(dim=-1)
            else:
                hashed_ngrams = ngrams.squeeze(-1).long()
    
            # これで hashed_ngrams を書き換えても、clone 済みのデータなので preds は安全
            invalid_val = -1
            hashed_ngrams[ngram_ignore_mask] = invalid_val
    
            for b in range(bsz):
                b_ngrams = hashed_ngrams[b]
                valid_b_ngrams = b_ngrams[b_ngrams != invalid_val]
                
                if valid_b_ngrams.numel() == 0:
                    continue
                    
                unique_vals, counts = torch.unique(valid_b_ngrams, return_counts=True)
                
                mask = counts >= self.repeat_thresh[n-1]
                ngram_cnt[b] += counts[mask].sum().float() * self.repeat_weight[n-1]
    
        penalty = - torch.clamp(torch.pow(2.0, ngram_cnt - 1.0) / seq_len, max=1.0)
        return penalty[:, None].expand(-1, seq_len)
        
    def forward(self, b2_preds, b2_targets, b2_imgs2,  sources=None, masks=None):
        """
        outputs: batch x len x d_model
        targets: batch x len
        sources: batch x len
        masks:   batch x len
        """
        self.device = b2_preds.device
        # input to device
        targets = b2_targets.to(self.device)
        bsz, seq_len = b2_preds.size()
        eps = 1e-8

        if self.reward_t == 'ordinary':
            reward_ord = self.compute_reward(b2_preds, b2_targets, b2_imgs2, sources)   #  bsz
            reward_repeat = torch.zeros( ( bsz, seq_len ),  device = preds.device, dtype=torch.float ) # bsz * seq_len
            reward_length = torch.zeros( ( bsz, seq_len ),  device = preds.device, dtype=torch.float ) # bsz * seq_len
            reward_unr = torch.zeros( ( bsz, seq_len ),  device = preds.device, dtype=torch.float ) # bsz * seq_len
        elif self.reward_t == "ord+rep":
            reward_ord = self._compute_reward_ord(b2_preds, b2_targets, b2_imgs2, sources)  # bsz * seq_len
            reward_repeat = self.calc_ngram_repeat_fast( b2_preds ) # bsz * seq_len
            reward_length = torch.zeros( ( bsz, seq_len ),  device = preds.device, dtype=torch.float ) # bsz * seq_len
            reward_unr = torch.zeros( ( bsz, seq_len ),  device = preds.device, dtype=torch.float ) # bsz * seq_len
        elif self.reward_t == "ord+rep+len":
            reward_ord, reward_ord2 = self._compute_reward_ord(b2_preds, b2_targets, b2_imgs2, sources)  # bsz * seq_len
            reward_repeat = self.calc_ngram_repeat_fast( b2_preds ) # bsz * seq_len
            reward_length = self.compute_length_reward( b2_preds, b2_targets ) # bsz * seq_len
            reward_unr = torch.zeros( ( bsz, seq_len ),  device = preds.device, dtype=torch.float ) # bsz * seq_len
        elif self.reward_t == 'ord+len':
            reward_ord = self._compute_reward_ord(b2_preds, b2_targets, b2_imgs2, sources)  # bsz * seq_len
            reward_repeat = self.calc_ngram_repeat_fast( b2_preds ) # bsz * seq_len
            reward_length = self.compute_length_reward( b2_preds, b2_targets ) # bsz * seq_len
            reward_unr = torch.zeros( ( bsz, seq_len ),  device = preds.device, dtype=torch.float ) # bsz * seq_len
        elif self.reward_t == 'ord+rep+len+unr':
            reward_ord, reward_ord2 = self._compute_reward_ord(b2_preds, b2_targets, b2_imgs2, sources)  # bsz * seq_len
            reward_repeat = self.calc_ngram_repeat_fast( b2_preds ) # bsz * seq_len
            reward_length = self.compute_length_reward( b2_preds, b2_targets ) # bsz * seq_len
            reward_unr = self.unique_ngram_ratio(b2_preds)
        
        ## apply mask
        #if masks is not None:
        #    masks = masks.to(self.device)
        #    probs, targets = probs[masks], targets[masks]
        #    # outputs, targets = outputs[masks], targets[masks]
        #    reward, preds = reward[masks], preds[masks]
       
        #print(f'loss: {loss.item():.3f} | reward: {reward:.3f}')    
        
        return reward_ord, reward_ord2, reward_repeat, reward_length, reward_unr

In [None]:
def logsumexp(x, dim=1):
    """Log-sum-exp of ``x`` along ``dim``, computed in float32.

    The input is cast to float for the reduction and the result is cast back
    to the type of ``x``.
    """
    reduced = x.float().logsumexp(dim=dim)
    return reduced.type_as(x)

class DynamicCRF(nn.Module):
    def __init__(self, num_embedding, low_rank=32, beam_size=64, crf_coef=1.0, temp = 0.5, num_samples = 11, ref_t = False ):
        super().__init__()

        #low_rank = num_embedding
        self.E1 = nn.Embedding(num_embedding, low_rank)
        self.E2 = nn.Embedding(num_embedding, low_rank)

        self.vocb = num_embedding
        self.rank = low_rank
        self.beam = beam_size
        self.crf_coef = crf_coef
        self.temp = temp
        self.num_samples = num_samples
        self.ref_t = ref_t
        
    def extra_repr(self):
        return "vocab_size={}, low_rank={}, beam_size={}".format(
            self.vocb, self.rank, self.beam)

    def forward(self, emissions, top_logits, top_indices, targets, masks, beam=None):
        """Return the path log-likelihood and the beam probabilities.

        Returns:
            Tuple of (score of ``targets`` minus the beam normalizer, result
            of ``_compute_normalizer2``).
        """
        gold_score = self._compute_score(emissions, targets, masks)
        log_partition = self._compute_normalizer(emissions, targets, masks, beam )
        # NOTE(review): `_compute_normalizer2` is not defined in the visible
        # part of this class -- confirm it exists before calling forward().
        beam_probs = self._compute_normalizer2(top_logits, top_indices, targets, masks, beam)

        log_likelihood = gold_score - log_partition
        return log_likelihood, beam_probs
    
    def forward_decoder(self, emissions, masks=None, beam=None):
        """Decode ``emissions`` by delegating to ``_viterbi_decode``."""
        # NOTE(review): the only `_viterbi_decode` visible in this class sits
        # inside a string literal, i.e. it is disabled -- confirm a live
        # definition exists before calling this method.
        decoded = self._viterbi_decode(emissions, masks, beam)
        return decoded

    def _compute_score(self, emissions, targets, masks=None):
        batch_size, seq_len = targets.size()

        emission_scores = emissions.gather(2, targets[:, :, None])[:, :, 0]  # B x T
        transition_scores = (self.E1(targets[:, :-1]) * self.E2(targets[:, 1:])).sum(2)
       
        scores = emission_scores
        scores[:, 1:] += transition_scores
        
        if masks is not None:
            scores = scores * masks.type_as(scores)

        return scores.sum(-1)
        
    def _compute_normalizer(self, emissions, targets=None, masks=None, beam=None):
        """Log-sum-exp of the scores of all paths through the top candidates.

        At every position only ``beam`` candidate tokens are kept. When
        ``targets`` is given, the target token is always one of them.

        Returns:
            Tensor with one normalizer per batch row.
        """
        k = self.beam if beam is None else beam
        bsz, steps = emissions.size()[:2]

        if targets is None:
            cand_scores, cand_tokens = emissions.topk(k, 2)
        else:
            # Give the target +inf so the ranking keeps it, then read the
            # real scores of the selected candidates back.
            boosted = emissions.scatter(2, targets[:, :, None], float('inf'))
            cand_tokens = boosted.topk(k, 2)[1]
            cand_scores = emissions.gather(2, cand_tokens)

        # Pairwise transition scores between candidates of adjacent positions.
        left = self.E1(cand_tokens[:, :-1]).view(-1, k, self.rank)   # previous step
        right = self.E2(cand_tokens[:, 1:]).view(-1, k, self.rank)   # current step
        transitions = torch.bmm(left, right.transpose(1, 2)).view(bsz, -1, k, k)

        # Forward recursion in log space.
        running = cand_scores[:, 0]  # B x K
        for t in range(1, steps):
            stepped = logsumexp(running[:, :, None] + transitions[:, t-1], dim=1) + cand_scores[:, t]
            if masks is None:
                running = stepped
            else:
                # Masked-out positions keep the previous running score.
                running = torch.where(masks[:, t:t+1], stepped, running)

        return logsumexp(running, dim=1)

    #def _compute_grpo_samples(beam_emission_scores, beam_transition_matrix, beam_targets, targets=None, masks=None, beam=None):
    def _compute_grpo_samples(self, beam_emission_scores, beam_transition_matrix, beam_targets, targets=None, masks=None, beam=None):

        
        eps = 1e-8
        device = beam_emission_scores.device

        beam = beam if beam is not None else self.beam
        batch_size, seq_len = beam_emission_scores.size()[:2]

        traj_tokens, traj_scores = [], []
        finalized_tokens, finalized_scores = [], []
        traj_scores2 = []
        traj_tokens2 = []

        # compute the normalizer in the log-space
        score = beam_emission_scores[:, 0]  # B x K
        _score2 = beam_emission_scores[:,0][:,None,:].expand( -1, self.num_samples, -1 ) #B * self.num_samples * K

        for i in range(1, seq_len):
            traj_scores.append(score) # bsz * beam
            traj_scores2.append( _score2 )
            _score2 = _score2[:,:,:,None] + beam_transition_matrix[:, i-1,None,:,:].expand( -1, self.num_samples,-1,-1) 
                    # bsz, self.num_samples, bema, beam

            # greedy selection
            #_score, _index = _score.max(dim=1) # bsz, beam     bsz, beam 

            ## multinomial selection
            B, N, C, W = _score2.shape
            flat_score = _score2.permute(0, 3, 1, 2).reshape(-1, C)
            #flat_score = torch.clamp( flat_score, min = -100, max = 100 )
            probs = F.softmax(flat_score / self.temp, dim=-1)
            _index_flat = torch.multinomial(probs, num_samples=1, replacement = True )  

            _score_flat = torch.gather(flat_score, -1, _index_flat)
            _index2 = _index_flat.view(B, W, self.num_samples).transpose(1,2)
            _score2 = _score_flat.view(B, W, self.num_samples).transpose(1,2)

            _score2 = _score2 + beam_emission_scores[:, i][:,None,:].expand(-1,self.num_samples,-1) # bsz, self.num_samples, beam   

            #if masks is not None:
            #    score = torch.where(masks[:, i: i+1], _score, score)
            #    index = torch.where(masks[:, i: i+1], _index, dummy)
            #else:
            score, index = _score2[:,0,:], _index2[:,0,:]
            traj_tokens.append(index)
            traj_tokens2.append( _index2 )


        all_scores = traj_scores2
        all_scores.append( _score2 )
        all_scores = torch.stack( all_scores, dim = 0 ).transpose( 0, 1 ).to(device) #bsz, seq_len, beam, N
        beam_probs = F.softmax( all_scores.transpose( 2, 3 ), dim = 2 ) #bsz, seq_len, beam, N
        #beam_probs = F.softmax( all_scores.permute( 3, 0, 1, 2 ), dim = 3 ) #N, bsz, seq_len, beam

        # now running the back-tracing and find the best
        best_score, best_index = _score2.max(dim=2) # max( bsz, beam ), bsz, N
        finalized_tokens.append(best_index[:, None, :]) #bsz,1, N
        finalized_scores.append(best_score[:, None, :]) #bsz,1, N

        for idx, scs in zip(reversed(traj_tokens2), reversed(traj_scores2)): # each of seq_len -1, bsz, beam, N 
            previous_index = finalized_tokens[-1]
            finalized_tokens.append(idx.gather(2, previous_index))
            finalized_scores.append(scs.gather(2, previous_index))

        finalized_tokens.reverse() # seq_len, bsz, N
        sampled_beam_idx = torch.cat(finalized_tokens, 1) # seq_len, bsz, N
        finalized_tokens = beam_targets.gather(2, sampled_beam_idx)

        finalized_scores.reverse()
        finalized_scores = torch.cat(finalized_scores, 1)
        finalized_scores[:, 1:] = finalized_scores[:, 1:] - finalized_scores[:, :-1]

        #return beam_probs, sampled_beam_idx, finalized_tokens, finalized_scores 
        return beam_probs, sampled_beam_idx, finalized_tokens 
    
    '''
    def _viterbi_decode(self, emissions, masks=None, beam=None):
        beam = beam if beam is not None else self.beam
        batch_size, seq_len = emissions.size()[:2]
        beam_emission_scores, beam_targets = emissions.topk(beam, 2)
        beam_transition_score1 = self.E1(beam_targets[:, :-1])  # B x (T-1) x K x D
        beam_transition_score2 = self.E2(beam_targets[:, 1:])   # B x (T-1) x K x D
        beam_transition_matrix = torch.bmm(
            beam_transition_score1.view(-1, beam, self.rank),
            beam_transition_score2.view(-1, beam, self.rank).transpose(1, 2))
        beam_transition_matrix = beam_transition_matrix.view(batch_size, -1, beam, beam) # bsz, seq_len, beam, beam

        traj_tokens, traj_scores = [], []
        finalized_tokens, finalized_scores = [], []

        # compute the normalizer in the log-space
        score = beam_emission_scores[:, 0]  # B x K
        #print( "score size:", score.size() )
        dummy = torch.arange(beam, device=score.device).expand(*score.size()).contiguous()

        for i in range(1, seq_len):
            traj_scores.append(score)
            _score = score[:, :, None] + beam_transition_matrix[:, i-1] # bsz, beam, beam
            _score, _index = _score.max(dim=1) # bsz, beam     bsz, beam 
            _score = _score + beam_emission_scores[:, i] # bsz, beam

            if masks is not None:
                score = torch.where(masks[:, i: i+1], _score, score)
                index = torch.where(masks[:, i: i+1], _index, dummy)
            else:
                score, index = _score, _index
            traj_tokens.append(index)

        # now running the back-tracing and find the best
        best_score, best_index = score.max(dim=1)
        finalized_tokens.append(best_index[:, None])
        finalized_scores.append(best_score[:, None])

        for idx, scs in zip(reversed(traj_tokens), reversed(traj_scores)):
            previous_index = finalized_tokens[-1]
            finalized_tokens.append(idx.gather(1, previous_index))
            finalized_scores.append(scs.gather(1, previous_index))

        finalized_tokens.reverse()
        finalized_tokens = torch.cat(finalized_tokens, 1)
        finalized_tokens = beam_targets.gather(2, finalized_tokens[:, :, None])[:, :, 0]

        finalized_scores.reverse()
        finalized_scores = torch.cat(finalized_scores, 1)
        finalized_scores[:, 1:] = finalized_scores[:, 1:] - finalized_scores[:, :-1]

        return finalized_scores, finalized_tokens

   
    '''
    def _compute_many_values(self, emissions, targets, top_indices = None, masks=None, beam=None):

        beam = beam if beam is not None else self.beam
        batch_size, seq_len = emissions.size()[:2]
        
        if top_indices == None:
            beam_emission_scores, beam_targets = emissions.topk(beam, 2)
        else:
            beam_emission_scores = torch.gather( emissions, -1, top_indices )
            beam_targets = top_indices
        
        beam_transition_score1 = self.E1(beam_targets[:, :-1])  # B x (T-1) x K x D
        beam_transition_score2 = self.E2(beam_targets[:, 1:])   # B x (T-1) x K x D
        beam_transition_matrix = torch.bmm(
            beam_transition_score1.view(-1, beam, self.rank),
            beam_transition_score2.view(-1, beam, self.rank).transpose(1, 2))
        beam_transition_matrix = beam_transition_matrix.view(batch_size, -1, beam, beam) # bsz, seq_len, beam, beam

        if not self.ref_t:
            traj_tokens, traj_scores = [], []
            finalized_tokens, finalized_scores = [], []

            # compute the normalizer in the log-space
            score = beam_emission_scores[:, 0]  # B x K
            dummy = torch.arange(beam, device=score.device).expand(*score.size()).contiguous()

            for i in range(1, seq_len):
                traj_scores.append(score)
                _score = score[:, :, None] + beam_transition_matrix[:, i-1] # bsz, beam, beam
                _score, _index = _score.max(dim=1) # bsz, beam     bsz, beam  step i-1 における 256 → 256 の max から 256 への遷移確率と 
                                                # 256 → 256 の前の 256 の max のインデックストークン
                                                # index b * 256 の 位置が i の token で、値が i-1 のtoken   

                _score = _score + beam_emission_scores[:, i] # bsz, beam i における 256 の遷移確率ではない確率を加える。i における 256 の全確率。

                if masks is not None:
                    score = torch.where(masks[:, i: i+1], _score, score)
                    index = torch.where(masks[:, i: i+1], _index, dummy)
                else:
                    score, index = _score, _index
                traj_tokens.append(index)

            all_scores = traj_scores
            all_scores.append( score )
            all_scores = torch.stack( all_scores, dim = 0 ).transpose( 0, 1 )
        
            # now running the back-tracing and find the best
            best_score, best_index = score.max(dim=1)
            finalized_tokens.append(best_index[:, None]) # 時刻 T における b*256 の確率最大の token
            finalized_scores.append(best_score[:, None]) #時刻 T における b*256 の確率最大の score

            for idx, scs in zip(reversed(traj_tokens), reversed(traj_scores)): #idx,scs は、反転時刻 i と i-1における b * 256のトークンと確率
                previous_index = finalized_tokens[-1] # 時刻 Tなど 求めたいトークンと確率の一個後 における token　b * 1
                finalized_tokens.append(idx.gather(1, previous_index)) # 時刻 一個後iのトークン previou_index に至るための時刻i-1 のトークン
                                                                    # b* 256 の token から b * 1 の previous_idnex token で gather
                finalized_scores.append(scs.gather(1, previous_index)) # 時刻一個後 i のトークンに至るための時刻 i-1 の確率

            finalized_tokens.reverse()
            finalized_tokens = torch.cat(finalized_tokens, 1)
            finalized_tokens = beam_targets.gather(2, finalized_tokens[:, :, None])[:, :, 0]

            finalized_scores.reverse()
            finalized_scores = torch.cat(finalized_scores, 1)
            finalized_scores[:, 1:] = finalized_scores[:, 1:] - finalized_scores[:, :-1]
        
            if self.crf_coef != 0.0:
                numerator = self._compute_score(emissions, targets)
                denominator = self._compute_normalizer(emissions, targets)
                crf_loss = - ( numerator - denominator ).mean() / seq_len
            else:
                crf_loss = torch.zeros( (1), device = emissions.device, dtype = torch.float )
        
        top_probs, sampled_beam_idx, b_finalized_tokens = \
            self._compute_grpo_samples(beam_emission_scores, beam_transition_matrix, beam_targets)

        if not self.ref_t:
            return finalized_scores, finalized_tokens, top_probs, beam_targets, crf_loss, sampled_beam_idx, b_finalized_tokens
        else:
            return top_probs[:,:,:,0]

In [None]:
class TopLayer(nn.Module):
    """Output layer that owns a ``DynamicCRF`` and feeds it emission scores.

    No vocabulary projection is created here: the caller is expected to hand
    in tensors that can be used directly as CRF emissions.

    Args:
        vocab_size: Number of labels; passed to ``DynamicCRF`` as ``num_embedding``.
        embed_dim: Stored on the instance; only used by ``decoding`` when an
            optional ``tgt_word_prj`` projection has been attached.
        crf_low_rank: Passed to ``DynamicCRF`` as ``low_rank``.
        crf_beam_size: Passed to ``DynamicCRF`` as ``beam_size``.
        dropout: Dropout probability applied to the emissions in ``forward``.
        padding_idx: Stored on the instance (not used for masking; see ``forward``).
        crf_coef, temp, num_samples, ref_t: Forwarded unchanged to ``DynamicCRF``.
    """

    def __init__(self, vocab_size, embed_dim, crf_low_rank, crf_beam_size, dropout, padding_idx,
                crf_coef = 1.0, temp = 0.5, num_samples = 10, ref_t = False ):
        super().__init__()

        self.embed_dim = embed_dim
        self.vocab_size = vocab_size
        self.dropout = dropout
        self.padding_idx = padding_idx
        print( "in TopLayer:" )
        self.crf_layer = DynamicCRF(num_embedding = vocab_size, low_rank = crf_low_rank, beam_size = crf_beam_size,
                                    crf_coef=crf_coef, temp=temp, num_samples= num_samples, ref_t = ref_t )

    def forward(self, src_representation, top_logits, top_indices, src_input, tgt_input, is_training ):
        '''
        Run the CRF layer on the given emissions and return the negated loss.

            src_representation : bsz x seqlen x D, used directly as emissions
                                 (presumably D == vocab_size -- confirm against DynamicCRF)
            top_logits, top_indices : forwarded unchanged to the CRF layer
            src_input : bsz x seqlen, only its batch dimension is read
            tgt_input : bsz x seqlen
            is_training : enables dropout on the emissions when true

        Returns the pair (negated first CRF output of shape [bsz], second CRF output).
        '''
        bsz = src_input.shape[0]

        emissions = F.dropout(src_representation, p=self.dropout, training=is_training)

        # Every position is treated as a real token: padding is NOT masked out.
        emission_mask = torch.ones_like( tgt_input, dtype=torch.bool )
        batch_crf_loss, top_probs = self.crf_layer(emissions, top_logits, top_indices, tgt_input, emission_mask)
        batch_crf_loss = - batch_crf_loss
        assert batch_crf_loss.size() == torch.Size([bsz])
        return batch_crf_loss, top_probs

    def decoding(self, src_representation, src_input):
        '''
        Decode the best label sequence with the CRF layer.

            src_representation : bsz x seqlen x D
            src_input : bsz x seqlen

        The original implementation always called ``self.tgt_word_prj``, which is
        never created in ``__init__`` and therefore raised ``AttributeError``.
        Like ``forward``, the input is now used directly as emissions; the
        projection is applied only if a ``tgt_word_prj`` module has been attached.

        Returns a bsz x seqlen tensor of decoded token ids.
        '''
        bsz, seqlen = src_input.size()

        projection = getattr(self, "tgt_word_prj", None)
        if projection is None:
            emissions = src_representation  # already [bsz, seqlen, vocab_size]
        else:
            emissions = projection(src_representation)  # [bsz, seqlen, vocab_size]

        _, finalized_tokens = self.crf_layer.forward_decoder(emissions)
        assert finalized_tokens.size() == torch.Size([bsz, seqlen])
        return finalized_tokens


In [None]:
class CaptioningTransformer(nn.Module):
    """Captioning model: CLIP vision encoder -> pooling/MLP connector -> BERT -> vocabulary head -> CRF.

    Args:
        img_size: Accepted for interface compatibility; not used by this class.
        dim_embedding: Output width of the connector and input width of the
            vocabulary projection (presumably must equal the BERT hidden size,
            because the connector output is fed to BERT as ``inputs_embeds``).
        length_max: Number of tokens produced by the connector
            (1 CLS token + ``length_max - 1`` pooled patch tokens).
        vocab_size: Size of the output vocabulary.
        tokenizer: Only ``tokenizer.pad_token_id`` is read.
        dropout: Accepted for interface compatibility; not used by this class.
        pad_token_id: Accepted for interface compatibility; not used by this class.
        use_repeat_logits_half: Apply ``repeat_logits_half`` to the emissions in ``forward``.
        crf_coef, temp, num_samples, ref_t: Forwarded to ``TopLayer``.
            ``num_samples`` is forced to 1 when ``ref_t`` is true.
    """

    def __init__(self, img_size: int,  dim_embedding: int, length_max: int, vocab_size: int, tokenizer, dropout: float = 0.0, \
                 pad_token_id: int=0, use_repeat_logits_half=False, crf_coef = 1.0, temp=0.5, num_samples=10, ref_t = False):
        super().__init__()

        # CLIP vision encoder.
        model_id = "openai/clip-vit-large-patch14-336"
        self.clip_model = CLIPVisionModel.from_pretrained(model_id )
        # Read the feature width from the model config instead of running a
        # throw-away forward pass on random pixels.
        clip_dim = self.clip_model.config.hidden_size

        # Connector: pool the patch tokens down to length_max - 1, then a 2-layer MLP.
        self.connector_pool = nn.AdaptiveAvgPool1d(length_max - 1 )
        self.connector_ln = nn.LayerNorm( clip_dim )
        self.connector_linear1 = nn.Linear( clip_dim, dim_embedding )
        self.connector_gleu = nn.GELU()  # attribute name (sic) kept for compatibility
        self.connector_linear2 = nn.Linear( dim_embedding, dim_embedding )

        self.pos_emb = PositionalEmbedding( dim_embedding )

        model_id = "google-bert/bert-large-uncased"
        self.bert = BertModel.from_pretrained( model_id )

        # Word-distribution head.
        self.ln_outputs = nn.LayerNorm( dim_embedding )
        self.linear = nn.Linear(dim_embedding, vocab_size)

        crf_low_rank = 32
        crf_beam_size = 256
        self.crf_beam_size = crf_beam_size
        top_dropout = 0.0
        tgt_padding_idx = tokenizer.pad_token_id
        print( "initialize self.toplayer" )
        if ref_t:
            num_samples = 1
        self.ref_t = ref_t
        self.toplayer = TopLayer( vocab_size, dim_embedding, crf_low_rank, crf_beam_size, top_dropout,
                                  tgt_padding_idx, crf_coef = crf_coef, temp=temp, num_samples=num_samples, ref_t = ref_t )

        self.dim_embedding = dim_embedding
        self.use_repeat_logits_half = use_repeat_logits_half

    def mlp_connector(self, memory ):
        """Compress CLIP tokens and project them to ``dim_embedding``.

        ``memory`` is indexed as (bsz, tokens, channels) with the CLS token first.
        The patch tokens are average-pooled along the token axis to the length
        configured in ``connector_pool``, the CLS token is prepended again, and
        the result goes through LayerNorm -> Linear -> GELU -> Linear.
        """
        cls_token = memory[:, :1, :]
        patch_tokens = memory[:, 1:, :]

        # AdaptiveAvgPool1d pools the last axis, so move the token axis there and back.
        patch_tokens = self.connector_pool(patch_tokens.transpose(1, 2)).transpose(1, 2)

        memory = torch.cat([cls_token, patch_tokens], dim=1)

        memory = self.connector_ln( memory )
        memory = self.connector_linear1( memory )
        memory = self.connector_gleu( memory )
        memory = self.connector_linear2( memory )

        return memory

    def forward(self, images: torch.Tensor, targets: torch.Tensor, top_indices = None ):
        """Encode ``images`` and evaluate the CRF head against ``targets``.

        Returns:
            When ``ref_t`` is false: the seven values returned by
            ``crf_layer._compute_many_values`` with ``emissions`` inserted as the
            sixth element, i.e. ``(finalized_scores, finalized_tokens, top_probs,
            top_indices, crf_loss, emissions, sampled_beam_idx, b_finalized_tokens)``.
            When ``ref_t`` is true: only the single value returned by
            ``crf_layer._compute_many_values``.
        """
        self.device = images.device

        memory = self.clip_model( images ).last_hidden_state
        memory = self.mlp_connector( memory )
        memory += self.pos_emb( memory )

        outputs = self.bert( inputs_embeds = memory ).last_hidden_state
        outputs = self.ln_outputs( outputs )
        emissions = self.linear( outputs )

        if self.use_repeat_logits_half:
            # Was a bare ``repeat_logits_half(...)`` call, i.e. a NameError.
            emissions = self.repeat_logits_half( emissions )

        if not self.ref_t:
            finalized_scores, finalized_tokens, top_probs, top_indices, crf_loss, sampled_beam_idx, b_finalized_tokens  = \
                self.toplayer.crf_layer._compute_many_values(emissions, targets, top_indices)
            return finalized_scores, finalized_tokens, top_probs, top_indices, \
                 crf_loss, emissions, sampled_beam_idx, b_finalized_tokens
        else:
            top_probs = self.toplayer.crf_layer._compute_many_values(emissions, targets, top_indices)
            return top_probs

    def repeat_logits_half(self, emissions, penalty: float = 1.2 ):
        """Penalise tokens that were already the arg-max at an earlier position.

        For every position ``t`` and token ``v``: if ``v`` attained the maximum
        logit at some position ``< t`` of the same sequence, its logit at ``t``
        is pushed down: positive logits are divided by ``penalty`` and negative
        logits are multiplied by ``penalty``.

        The previous implementation built the "negative" mask from the positive
        mask, so positive logits were divided and then multiplied back while
        negative logits were never touched, making the whole penalty a no-op.

        Args:
            emissions: (bsz, seqlen, vocab) logits. Not modified.
            penalty: Repetition penalty factor (> 1 lowers repeated tokens).

        Returns:
            A new tensor with the same shape as ``emissions``.
        """
        is_argmax = emissions == emissions.max(dim=2, keepdim=True).values

        bsz, seqlen, vocab = emissions.shape
        seen = torch.zeros((bsz, vocab), device=emissions.device, dtype=torch.bool)
        repeated = torch.zeros_like(is_argmax)
        for step in range(seqlen):
            # Record what was seen strictly before this step, then update.
            repeated[:, step] = seen
            seen = seen | is_argmax[:, step]

        penalized = torch.where(emissions > 0, emissions / penalty, emissions * penalty)
        return torch.where(repeated, penalized, emissions)

In [None]:
class MyDataset(Dataset):
    """Image/caption dataset backed by a pickled annotation list.

    The annotation file must unpickle to an iterable of mappings with the keys
    ``'img_file'`` (image file name without the ``.jpg`` extension) and
    ``'id_tokens'`` (sequence of token ids). Two end-of-sequence ids are
    appended to every caption before optional truncation.

    Args:
        file_path: Path of the pickled annotation file.
        img_directory: Directory that contains the ``<img_file>.jpg`` images.
        transforms: Callable applied to the RGB PIL image (first view).
        transforms2: Callable applied to the same RGB PIL image (second view).
        tokenizer: Accepted for interface compatibility; not used.
        length_max: If given, every caption is truncated to this many ids
            (truncation happens after the EOS ids are appended, so they may be
            cut off). If ``None``, nothing is truncated and ``self.length_max``
            becomes the longest caption length found.
        eos_id: End-of-sequence token id. When ``None`` the module-level
            ``eos_token_id`` is used, which is the original behaviour.

    Attributes:
        length_max: Truncation length, or the observed maximum (see above).
        img_file: Image base names, aligned with ``tokens``.
        tokens: One 1-D tensor of token ids per sample.
    """

    def __init__(self, file_path: str, img_directory: str, transforms, transforms2, tokenizer, length_max = None,
                 eos_id = None ) -> None:
        super().__init__()
        self.img_directory = img_directory
        self.transforms = transforms
        self.transforms2 = transforms2
        self.img_file = []
        self.tokens = []
        self.length_max = 0 if length_max is None else length_max
        eos = eos_token_id if eos_id is None else eos_id
        length_sum = 0

        # NOTE(security): pickle.load executes arbitrary code from the file;
        # only load annotation files from a trusted source.
        with open( file_path, 'rb') as f:
            data = pickle.load(f)
        for i, line_data in enumerate( data ):
            if i % 100000 == 0:
                print( "i:", i )
            self.img_file.append( line_data['img_file'] )
            # Copy so the loaded annotation record is not mutated.
            id_tokens = list( line_data['id_tokens'] ) + [ eos, eos ]
            length_sum += len( id_tokens )
            if length_max is not None:
                id_tokens = torch.tensor( id_tokens )[:self.length_max]
            else:
                self.length_max = max( self.length_max, len( id_tokens ) )
                id_tokens = torch.tensor( id_tokens )
            self.tokens.append( id_tokens )

        if length_max is None:
            print( "length max:", self.length_max )
            if self.tokens:  # avoid ZeroDivisionError on an empty annotation file
                print( "avg length:", length_sum / len( self.tokens ) )

    def __getitem__(
        self,
        index: int
    ):
        """Return ``(transforms(img), transforms2(img), tokens)`` for sample ``index``."""
        tokens = self.tokens[index]
        img_path = os.path.join( self.img_directory, self.img_file[index] + ".jpg" )
        # convert() yields a loaded RGB image, so the file can be closed right away.
        with Image.open(img_path) as handle:
            img = handle.convert("RGB")
        img1 = self.transforms(img)
        img2 = self.transforms2(img)

        return img1, img2, tokens

    # DataLoader requires __len__ on map-style datasets.
    def __len__(self) -> int:
        return len(self.tokens)

    # NOTE: the former ``length_max()`` method was removed. It was always
    # shadowed by the ``length_max`` instance attribute set in ``__init__`` and
    # could never be called; read ``dataset.length_max`` directly.

In [None]:
def collate_func(batch: Sequence[Tuple[Union[torch.Tensor, str]]], pad_index, length_max ):
    """Collate ``(img1, img2, tokens)`` samples into batched tensors.

    Every token sequence is right-padded with ``pad_index`` to exactly
    ``length_max`` entries (a fixed length, not the batch maximum).

    Returns:
        ``(imgs1, imgs2, targets, lengths)`` where the first three are stacked
        along a new leading batch axis and ``lengths`` holds the original,
        unpadded length of each token sequence.
    """
    first_views, second_views, token_seqs = zip(*batch)

    seq_lengths = [len(seq) for seq in token_seqs]
    padded = [
        F.pad(seq, (0, length_max - n_tokens), mode='constant', value=pad_index)
        for seq, n_tokens in zip(token_seqs, seq_lengths)
    ]

    return (
        torch.stack(first_views, dim=0),
        torch.stack(second_views, dim=0),
        torch.stack(padded, dim=0),
        torch.tensor(seq_lengths, requires_grad=False),
    )

###学習におけるハイパーパラメータやオプションの設定

In [None]:
class ConfigTrain(object):
    """Hyper-parameters and shared settings for the training script.

    NOTE: ``__init__`` reads the module-level ``tokenizer`` to obtain the
    vocabulary size, so the tokenizer must exist before this class is
    instantiated.
    """

    def __init__(self):

        # ---- model geometry ----
        self.img_size = 336
        self.dim_embedding = 1024  # embedding width
        self.length_max = 97

        # ---- learning rates, one per parameter group ----
        self.lr_clip = 0.0
        self.lr_con = 1.66e-11
        self.lr_bert = 2.19e-9
        self.lr_cri = 7.20e-6
        self.lr_others = 5.82e-8

        # ---- clipping thresholds, one per parameter group ----
        self.clip_thresh_con = 3.3e-4
        self.clip_thresh_bert = 9e-3
        self.clip_thresh_cri = 8.60e-7
        self.clip_thresh_others = 3.3e-4

        # ---- optimisation ----
        self.dropout = 0.0  # dropout probability
        self.batch_size = 120  # mini-batch size
        self.num_epochs = 1  # keep below 10 when testing on the free Colab tier
        self.use_amp = True
        self.use_saved_pth = True
        self.vocab_size = len(tokenizer)
        self.weight_decay = 0.0232
        self.betas = (0.9, 0.999)
        self.warmup = 0.1

        # ---- evaluation / reward / decoding modes ----
        self.metric = "special"  # alternatives: bleu, meteor, wer, rouge
        self.decode_t = "no-pad"
        self.reward_t = "ord+rep+len+unr"

        # ---- clipping ----
        self.clip_range = 0.235
        self.clip_grad_threshold = 1.37

        # ---- reward-term coefficients ----
        self.ord_coef = 1.0
        self.cider_coef = 1.0
        self.rouge_coef = 2.53
        self.clip_coef = 1.65
        self.bert_coef = 6.68
        self.rep_coef = 5.84
        self.repeat_thresh = (3, 2, 2, 2, 2)
        self.repeat_weight = (1, 1, 1, 1, 1)
        self.len_coef = 4.17
        self.unr_coef = 3.79

        # ---- loss-term coefficients ----
        self.policy_coef = 1.0
        self.crf_coef = 0.205
        self.ce_coef = 0.661
        self.ent_coef = 0.00269
        self.cri_coef = 0.0  # Monte Carlo method
        self.gae_coef = 0.0  # GAE

        # ---- KL settings ----
        self.kl_coef = 0.0401
        self.target_kl = 8.0
        self.buffer_kl = 1.2
        self.kl_max = 0.1
        self.kl_min = 0.1

        # ---- discounting ----
        self.gamma = 0.972
        self.lam = 0.974

        # ---- feature switches ----
        self.use_repeat_logits_half = False
        self.use_ce_bert = True
        self.display_include_coef = True
        self.use_adaptive_KL = False

        # ---- sampling ----
        self.temp = 1.0
        self.num_samples = 8
        self.residual_samples = 3

        # ---- paths ----
        self.img_directory = '/mnt/ssd2/v7/img'
        self.anno_file = '/mnt/ssd2/v7/data.pkl'
        self.save_directory = './model'
        self.PATH = "../test/pre_train_crf/model/model_bert_large_NAR_PAD_sft2_final.pth"

        # ---- fractions of the data held out for testing / validation ----
        self.test_ratio = 0.1
        self.val_ratio = 0.1

        # ---- device used for training: first GPU when available, else CPU ----
        cuda_available = torch.cuda.is_available()
        self.device = torch.device("cuda:0" if cuda_available else "cpu")

        # ---- number of DataLoader worker processes ----
        self.num_workers = 0

        # ---- number of loss values in the moving average ----
        self.moving_avg = 100

### 学習を行う関数

In [None]:
#### os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

model_id = "google-bert/bert-large-uncased"
tokenizer = BertTokenizer.from_pretrained(model_id)

# ConfigTrain.__init__ reads the module-level ``tokenizer`` (for vocab_size),
# so the config must be created AFTER the tokenizer. Creating it first raised
# NameError on a fresh interpreter.
config = ConfigTrain()

# Special token ids.
pad_token_id = tokenizer.pad_token_id
cls_token_id = tokenizer.cls_token_id
sep_token_id = tokenizer.sep_token_id
# encode() returns [CLS] <token> [SEP], hence index 1 picks the token itself.
sos_token_id = tokenizer.encode( [ "[unused0]" ] )[1]
eos_token_id = tokenizer.encode( [ "[unused1]" ] )[1]

# Ids of frequent function words and punctuation marks.
a_token_id = tokenizer.encode( "a" )[1]
the_token_id = tokenizer.encode( "the" )[1]
and_token_id = tokenizer.encode( "and" )[1]
in_token_id = tokenizer.encode( "in" )[1]
we_token_id = tokenizer.encode( "we" )[1]
i_token_id = tokenizer.encode( "i" )[1]
he_token_id = tokenizer.encode( "he" )[1]
she_token_id = tokenizer.encode( "she" )[1]
it_token_id = tokenizer.encode( "it" )[1]
they_token_id = tokenizer.encode( "they" )[1]
period_token_id = tokenizer.encode( "." )[1]
comma_token_id = tokenizer.encode( "," )[1]
dbl_token_id = tokenizer.encode( '"' )[1]
sgl_token_id = tokenizer.encode( "'" )[1]


# Keep the vocabulary size.
vocab_size = len( tokenizer )

# Create the directory that receives model outputs.
os.makedirs(config.save_directory, exist_ok=True)

# Per-channel statistics quoted from the CLIP model config.
_clip_mean = (0.48145466, 0.4578275, 0.40821073)
_clip_std = (0.26862954, 0.26130258, 0.27577711)

# First view: 336x336, auto-augmented, normalised.
transforms = v2.Compose([
    v2.Resize((336, 336)),
    v2.AutoAugment(),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(_clip_mean, _clip_std),
])

# Second view: 224x224, scaled to [0, 1], not normalised.
transforms2 = v2.Compose([
    v2.Resize((224, 224)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
])

# v7 dataset
train_dataset = MyDataset(
    file_path=config.anno_file,
    img_directory=config.img_directory,
    transforms=transforms,
    transforms2=transforms2,
    tokenizer=tokenizer,
    length_max=config.length_max,
)

# Split the sample indices into test / validation / training subsets.
test_set, val_set, train_set = util.generate_subset_test_val_train(
    train_dataset, config.test_ratio, config.val_ratio)

# Sampler that draws the training indices in random order.
train_sampler = SubsetRandomSampler(train_set)


def collate_func_lambda(x):
    """Collate a batch with the tokenizer's pad id and the configured length."""
    return collate_func(x, tokenizer.pad_token_id, config.length_max)


# All three loaders read from the same dataset object and differ only in
# which indices their sampler yields.
_loader_kwargs = dict(
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    collate_fn=collate_func_lambda,
)
train_loader = DataLoader(train_dataset, sampler=train_sampler, **_loader_kwargs)
val_loader = DataLoader(train_dataset, sampler=val_set, **_loader_kwargs)
test_loader = DataLoader(train_dataset, sampler=test_set, **_loader_kwargs)

torch.backends.cudnn.benchmark = True

# The three counts below are numbers of batches (len of a DataLoader).
print( "config.device:", config.device )
print( "学習セット数:",len( train_loader ) )
print( "評価セット数:",len( val_loader ))
print( "テストセット数:",len( test_loader ))
print( "use_amp:", config.use_amp )
print( "use_saved_pth:", config.use_saved_pth )

# Model definition. The trainable policy and the reference model are built
# from the same arguments; only ``ref_t`` differs.
_captioner_args = (
    config.img_size,
    config.dim_embedding,
    config.length_max,
    config.vocab_size,
    tokenizer,
    config.dropout,
)
_captioner_kwargs = dict(
    pad_token_id=tokenizer.pad_token_id,
    use_repeat_logits_half=config.use_repeat_logits_half,
    crf_coef=config.crf_coef,
    temp=config.temp,
    num_samples=config.num_samples,
)

model = CaptioningTransformer(*_captioner_args, **_captioner_kwargs)
model.to(config.device)

ref_model = CaptioningTransformer(*_captioner_args, ref_t=True, **_captioner_kwargs)
ref_model.to(config.device)

# Reward calculator configured from the reward-related settings.
compute_reward = ComputeReward(
    reward_t=config.reward_t,
    decode_t=config.decode_t,
    device=config.device,
    repeat_thresh=config.repeat_thresh,
    repeat_weight=config.repeat_weight,
    cider_coef=config.cider_coef,
    rouge_coef=config.rouge_coef,
    clip_coef=config.clip_coef,
    bert_coef=config.bert_coef,
    use_amp=config.use_amp,
)


# Optimizer definition.
# Parameters are routed into groups by substrings of their names so that each
# group gets its own learning rate. CLIP parameters are frozen instead.
params_con = []
params_bert = []
params_others = []
params_cri = []
for name, parameter in model.named_parameters():
    if not parameter.requires_grad:
        continue
    if 'clip_model' in name:
        parameter.requires_grad = False
    elif 'connector' in name:
        params_con.append(parameter)
    elif 'critical' in name:
        # Checked before 'bert' so that a name containing both lands here.
        params_cri.append(parameter)
    elif 'bert' in name:
        params_bert.append(parameter)
    else:
        params_others.append(parameter)

# Group order (con, bert, cri, others) matches ``thresh_groups`` below.
param_groups = [
    {'params': params_con, 'lr': config.lr_con},
    {'params': params_bert, 'lr': config.lr_bert},
    {'params': params_cri, 'lr': config.lr_cri},
    {'params': params_others, 'lr': config.lr_others}]

optimizer = torch.optim.AdamW( param_groups, weight_decay = config.weight_decay, betas=config.betas )
thresh_groups = {
    'con': config.clip_thresh_con,
    'bert': config.clip_thresh_bert,
    'cri': config.clip_thresh_cri,
    'others':config.clip_thresh_others
}

# Total number of optimisation steps.
num_global_steps = len( train_loader ) * config.num_epochs
print( "num_global_steps:", num_global_steps )
print( "warmup:", config.warmup )
# The scheduler API documents an integer step count; the product with the
# warm-up fraction is a float, so truncate it explicitly.
num_warmup_steps = int( num_global_steps * config.warmup )
print( "num_warmup_steps:", num_warmup_steps )

scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps, num_global_steps )

print( "use_saved_pth:", config.use_saved_pth )
print( "PATH:", config.PATH )
print( "exist saved_pth:", os.path.isfile(config.PATH) )
use_saved_pth = config.use_saved_pth
if use_saved_pth and os.path.isfile(config.PATH):
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(config.PATH, map_location=torch.device('cpu'))
    # strict=False: keys missing from / unknown to the checkpoint are tolerated.
    model.load_state_dict(checkpoint['model_state_dict'], strict = False)
    print( "model parameters were loaded")
    ref_model.load_state_dict(checkpoint['model_state_dict'], strict = False)
    print( "ref_model parameters were loaded")

# The reference model must never train, whether or not a checkpoint was
# found: eval mode disables dropout, and freezing saves memory and prevents
# accidental updates. (Previously this only happened inside the branch above.)
# NOTE(review): without a checkpoint, ``model`` and ``ref_model`` keep their own
# independently initialised connector / output layers.
ref_model.eval()
for param in ref_model.parameters():
    param.requires_grad = False
ref_model = ref_model.to(config.device )

# Optimizer/scheduler state and the epoch/step counters stored in the
# checkpoint are intentionally not restored; training always starts at 0.
begin_epoch = 0
global_step = 0


file_param = 10
print( "begin_epoch:", begin_epoch )
print( "global_step:", global_step )
print( "file_param:", file_param )

def get_nearest_multiple(a, b):
    """Return the multiple of ``b`` that is nearest to ``a``.

    The quotient ``a / b`` is rounded with the built-in ``round()``, which
    resolves exact ``.5`` ties towards the even integer; use integer
    arithmetic instead if strict round-half-up is required.
    """
    quotient = a / b
    multiplier = round(quotient)
    return multiplier * b

# Logging cadence: roughly every 1/100 of an epoch, snapped to the nearest
# multiple of ``file_param``; validation cadence is a third of the val loader.
len_tr_loader = len( train_loader )
len_val_loader = len( val_loader )
train_param = get_nearest_multiple( len_tr_loader // 100, file_param )
val_param = len_val_loader // 3
print( "train_param:", train_param )
print( "val_param:", val_param )

# Dump the run configuration.
for _label, _value in (
    ( "epochs:", config.num_epochs ),
    ( "batch_size:", config.batch_size ),
    ( "lr_clip:", config.lr_clip ),
    ( "lr_con:", config.lr_con ),
    ( "lr_bert:", config.lr_bert ),
    ( "lr_cri:", config.lr_cri ),
    ( "lr_others:", config.lr_others ),
):
    print( _label, _value )

# The per-group thresholds are only reported when the global one is zero.
if config.clip_grad_threshold == 0.0:
    for _label, _value in (
        ( 'clip_thresh_con:', config.clip_thresh_con ),
        ( 'clip_thresh_bert:', config.clip_thresh_bert ),
        ( 'clip_thresh_cri:', config.clip_thresh_cri ),
        ( 'clip_thresh_others:', config.clip_thresh_others ),
    ):
        print( _label, _value )

for _label, _value in (
    ( "weight_decay:", config.weight_decay ),
    ( "betas:", config.betas ),
    ( "reward_type:", config.reward_t ),
    ( "decode_type:", config.decode_t ),
    ( "clip_range ppo clip:", config.clip_range ),
    ( "clip_grad_threshold gradient norm:", config.clip_grad_threshold ),
    ( "ord_coef:", config.ord_coef ),
    ( "cider_coef:", config.cider_coef ),
    ( "rouge_coef:", config.rouge_coef ),
    ( "clip_coef:", config.clip_coef ),
    ( "rep_coef:", config.rep_coef ),
    ( "repeat_thresh:", config.repeat_thresh ),
    ( "repeat_weight:", config.repeat_weight ),
    ( "len_coef:", config.len_coef ),
    ( "unr_coef:", config.unr_coef ),
    ( "policy_coef:", config.policy_coef ),
    ( "crf_coef:", config.crf_coef ),
    ( "ce_coef:", config.ce_coef ),
    ( "ent_coef:", config.ent_coef ),
    ( "gae_coef:", config.gae_coef ),
    ( "kl_coef:", config.kl_coef ),
    ( "target_kl:", config.target_kl ),
    ( "buffer_kl:", config.buffer_kl ),
    ( "kl_max:", config.kl_max ),
    ( "kl_min:", config.kl_min ),
    ( "gamma:", config.gamma ),
    ( "lambda:", config.lam ),
    ( "use_repeat_logits_half:", config.use_repeat_logits_half ),
    ( "use_ce_bert:", config.use_ce_bert ),
    ( "display_include_coef:", config.display_include_coef ),
    ( "temp:", config.temp ),
    ( "num_samples:", config.num_samples ),
    ( "residual_samples:", config.residual_samples ),
):
    print( _label, _value )

# Log files for this run share one timestamp; the loss files start with the
# number of batches of their loader on the first line.
now = datetime.datetime.now()
_stamp = now.strftime('%Y%m%d_%H%M%S')
train_loss_file = '{}/MyOriginal_train_loss_{}.csv'.format(config.save_directory, _stamp)
with open(train_loss_file, 'a') as f:
    print(f'{len_tr_loader}', file=f)
print( "train_loss_file:", train_loss_file )
val_loss_file = '{}/MyOriginal_val_loss_{}.csv'.format(config.save_directory, _stamp)
with open(val_loss_file, 'a') as f:
    print(f'{len_val_loader}', file=f)
norm_file = '{}/norm_{}.csv'.format(config.save_directory, _stamp)

# Training state.
val_loss_best = float('inf')

pad_token_id = tokenizer.pad_token_id

# Gradient scaler for automatic mixed precision.
scaler = GradScaler(enabled=config.use_amp)

eps = 1e-8
last_sample_log_probs = 0
model_name = "distilbert-base-uncased"

def entropy_func(probs):
    """Return the entropy ``-sum(p * log p)`` over the last axis of ``probs``.

    The argument of the logarithm is clamped from below at the module-level
    ``eps`` so that zero probabilities do not produce ``-inf``; the
    multiplying factor is the unclamped probability.
    """
    safe_log = torch.clamp( probs, eps ).log()
    weighted = probs * safe_log
    return - weighted.sum(-1)

def baseline( probs, targets, top_k ):
    """Probability-weighted sum of the rewards of the ``top_k`` candidates.

    For each of the ``top_k`` most probable entries along axis 2 of ``probs``,
    the module-level ``compute_reward._compute_reward`` is evaluated on the
    candidate indices and weighted by the candidate's (un-normalised) probability.

    Args:
        probs: Tensor with at least three axes; candidates are taken along
            axis 2 (assumes (batch, seq, vocab) -- TODO confirm with callers).
        targets: Passed unchanged to ``compute_reward._compute_reward``.
        top_k: Number of candidates per position.

    Returns:
        The sum over the ``top_k`` candidates of ``prob * reward``.

    NOTE(review): the original also computed probabilities re-normalised over
    the top-k set but never used them; the weighting is kept un-normalised
    here. Confirm which one is intended.
    """
    probs_k, preds_k = torch.topk( probs, dim = 2, k = top_k )
    # Fixed: the original referenced the misspelled name ``compue_reward``.
    weighted_rewards = [
        probs_k[:, :, k] * compute_reward._compute_reward( preds_k[:, :, k], targets, sources = None )
        for k in range( top_k )
    ]
    return torch.stack( weighted_rewards, dim = 0 ).sum( dim = 0 )

def custom_gradient_clipping(clip_params, gpt2_params, cri_params, others_params,
                             clip_threshold, gpt2_threshold, cri_threshold, others_threshold):
    """Clip the gradient norm of four parameter groups independently.

    Each non-empty group is passed to ``torch.nn.utils.clip_grad_norm_`` with
    its own threshold; empty groups are skipped.

    Args:
        clip_params, gpt2_params, cri_params, others_params: Parameter lists.
        clip_threshold, gpt2_threshold, cri_threshold, others_threshold:
            Maximum gradient norm for the corresponding group.
    """
    groups = (
        (clip_params, clip_threshold),
        (gpt2_params, gpt2_threshold),
        (cri_params, cri_threshold),
        # Fixed: this group used to be guarded by ``gpt2_params``, so it was
        # skipped whenever ``gpt2_params`` was empty.
        (others_params, others_threshold),
    )
    for params, threshold in groups:
        if params:
            torch.nn.utils.clip_grad_norm_(params, threshold)

def my_index( list1, target ):
    """Return the index of the first `target` in `list1`, or -1 if absent."""
    return list1.index( target ) if target in list1 else -1

for epoch in range(config.num_epochs):
    with tqdm(train_loader) as pbar:
    #with tqdm(val_loader) as pbar:
        pbar.set_description(f'[エポック {epoch + 1}]')

        # 学習モードに設定
        model.train()

        train_losses = deque()
        train_policys = deque()
        train_entropies = deque()
        train_critics = deque()
        train_kl_divs = deque()
        train_rewards = deque()
        train_rewards2 = deque()
        train_ord = deque()
        train_repeat = deque()
        train_length = deque()
        train_adv = deque()
        train_errors = deque()
        train_bleus = deque()
        train_crfs = deque()
        train_ces = deque()
        train_clips = deque()
        train_unrs = deque()
        train_berts = deque()

        start = time.time()
        for n_batch, (imgs, imgs2, captions, caption_lengths) in enumerate( pbar ):
            #print( "captions[0]:", captions[0] )
            # ミニバッチを設定
            imgs = imgs.to(config.device)
            imgs2 = imgs2.to(config.device)
            captions = captions.to(config.device)
            
            if imgs.size(0) != config.batch_size:
                print( f"bsz {imgs.size(0)} is not batch_size {config.batch_size}. skip")
                continue

            #if imgs.size(0) != config.batch_size:
            #    print( f"bsz {imgs.size(0)} is not batch_size {config.batch_size}. skip")
            #    continue
            
            #optimizer.zero_grad()
            optimizer.zero_grad(set_to_none=True)
            #start_time0 = time.time()
            
            # 最後の単語から次を予測する必要はないため最後の単語を除外
            with autocast(str(config.device),enabled=config.use_amp):
                #start_time = time.time()
                finalized_scores, finalized_tokens, top_probss, top_indices, \
                crf_loss, bert_logits, sampled_beam_idxs, b_finalized_tokens  = \
                model( imgs, captions, top_indices = None )
                #end_time = time.time()
                #print( "model time:", end_time - start_time )
                #preds = finalized_tokens
                bsz, seq_len, beam, N = top_probss.size()
                top_probs = top_probss[:,:,:,0]
                sampled_beam_idx = sampled_beam_idxs[:,:,0].unsqueeze( -1 ) 
                if config.use_ce_bert == False:
                    ce_tensor = torch.full((bsz, seq_len, vocab_size), float(eps), device=config.device)
                    ce_tensor = torch.scatter( ce_tensor, 2, top_indices, top_probs )
                    ce_tensor = torch.clamp( ce_tensor, eps )
                    log_ce_tensor = torch.log( ce_tensor )
                hypo_ids = finalized_tokens
                with torch.no_grad():
                    preds = b_finalized_tokens[:,:,0]
                    b2_finalized_tokens = b_finalized_tokens[:,:,:].permute( 2, 0, 1 ).reshape( N * bsz, seq_len )
                    b2_imgs2 = imgs2[None].expand( N, -1, -1, -1, -1 ).reshape( N * bsz, 3, 224, 224 )
                    b2_captions = captions[None].expand(N,bsz,seq_len).reshape( N * bsz, seq_len)
                    b_ord, b_ord2, b_repeat, b_length, b_unr \
                        = compute_reward( b2_finalized_tokens, b2_captions, b2_imgs2 )
                    b_ord = b_ord.view( N, bsz, seq_len )
                    b_ord2 = b_ord2.view( N, bsz, seq_len )
                    b_repeat = b_repeat.view( N , bsz, seq_len )
                    b_length = b_length.view( N, bsz, seq_len )
                    b_unr = b_unr.view( N , bsz, seq_len ) 
                with torch.no_grad():
                    ref_captions = preds
                    #start_time = time.time()
                    ref_top_probs = ref_model( imgs, ref_captions, top_indices = top_indices )
                    #end_time = time.time()
                    #print( "ref_model time:", end_time - start_time )
                    
                # 1. Policy側の対数確率（学習対象）
                tmp = torch.clamp( top_probs, min = eps )
                top_log_probs = torch.log( tmp )
                policy_lp = torch.gather(top_log_probs, -1, sampled_beam_idx).squeeze(-1) # lp は log_prob の略と思われる。 bsz * seq_len
                
                # 2. Reference側の対数確率（固定）
                tmp = torch.clamp( ref_top_probs, min = eps )
                ref_top_log_probs = torch.log(tmp)
                ref_lp = torch.gather(ref_top_log_probs, -1, sampled_beam_idx).squeeze(-1)
                
                # Build the token mask used to average the per-token log-ratio
                # (policy_lp - ref_lp) that serves as the KL estimate.
                if config.decode_t == "no-pad":
                    # Index of the first EOS in each sampled sequence; sequences
                    # without an EOS are treated as length 0 (unchanged behaviour).
                    lengths = [ max( my_index( pred.tolist(), eos_token_id ), 0 ) for pred in preds ]
                    lengths = torch.tensor( lengths, device = config.device )[:,None].expand( -1, preds.size(1) )

                    arange1 = torch.arange( 0, preds.size(1), device = config.device )
                    arange1 = arange1[None,:].expand( preds.size(0), -1 )

                    # Keep positions 0 .. (first EOS index + 1).
                    masks = arange1 < ( lengths + 2 )

                    kl_divs = ( policy_lp - ref_lp ) * masks.float()
                else:
                    # No length masking: every position counts. `masks` still has
                    # to exist because the KL normalisation and the later `del`
                    # use it unconditionally; it was undefined in this branch.
                    masks = torch.ones_like( policy_lp, dtype = torch.bool )
                    kl_divs = policy_lp - ref_lp
                
                # Weighted sum of the reward components. Each component was reshaped to
                # (N, bsz, seq_len) after compute_reward, so b_rewards has that shape too.
                b_rewards = config.ord_coef * b_ord + config.rep_coef * b_repeat + config.len_coef * b_length \
                    + config.unr_coef * b_unr
                # Average over dim 0 (the N sampled candidates); used for logging.
                rewards = torch.mean( b_rewards, dim = 0 )
                # Unweighted variant built from b_ord2 instead of b_ord; also only logged.
                rewards2 = torch.mean( (b_ord2 + b_repeat + b_length + b_unr ), dim = 0 )

                
                # Log-probabilities of the sampled beam entries for every candidate.
                # top_probss is (bsz, seq_len, beam, N); gathering along dim 2 and
                # permuting yields (N, bsz, seq_len).
                # assumes sampled_beam_idxs is (bsz, seq_len, N) -- TODO confirm
                tmp = torch.clamp( top_probss, eps )
                top_log_probss = torch.log( tmp )
                sample_log_probs = torch.gather( top_log_probss, 2, sampled_beam_idxs[:,:,None,:] )[:,:,0,:]
                sample_log_probs = sample_log_probs.permute(2,0,1)
                # On the very first step there is no previous value yet, so the
                # importance ratio below evaluates to exactly 1.
                if global_step == 0:
                    last_sample_log_probs = sample_log_probs.detach()

                  
                # 1. Compute advantages (top-K selection follows below)
                # Rewards are standardised across dim 0, i.e. across the N candidates.
                # NOTE(review): Tensor.std defaults to the unbiased estimator, which
                # is NaN when N == 1 -- confirm N > 1 is guaranteed.
                with torch.no_grad():
                    mean = b_rewards.mean(dim=0, keepdim=True)
                    std = b_rewards.std(dim=0, keepdim=True)
                    advantages_norm = (b_rewards - mean) / (std + eps)

                
                ## 2. Compute the probability ratio (importance ratio)
                ## r_t(θ) = π_θ(a|s) / π_old(a|s)
                # Keep only the config.residual_samples candidates with the largest
                # |advantage| at each (batch, position).
                _, sample_idx = torch.topk( torch.abs( advantages_norm ), config.residual_samples, dim = 0 ) # 3, b, seq_len
                topk_advantages_norm = torch.gather( advantages_norm, 0, sample_idx )
                # NOTE(review): last_sample_log_probs holds the values stored at the
                # end of the previous loop iteration (a different mini-batch), not
                # old-policy log-probs of the current samples -- confirm this is intended.
                ratio = torch.exp(sample_log_probs - last_sample_log_probs)
                topk_ratio = torch.gather( ratio, 0, sample_idx) # re_sample, b, seq_len

                # 2. Unclipped and clipped surrogate objectives
                surr1 = topk_ratio * topk_advantages_norm
                surr2 = torch.clamp(topk_ratio, 1.0 - config.clip_range, 1.0 + config.clip_range) * topk_advantages_norm
                
                # 3. Loss (negated because the objective is to be maximised)
                # Following the GRPO paper's formula, maximise min(surr1, surr2)
                loss_per_sample = -torch.min(surr1, surr2)

                # 4. Final policy loss
                policy_loss = loss_per_sample.mean()

                # Entropy over the last dimension of top_probs; the loss term is its
                # negated mean, so a positive ent_coef rewards higher entropy.
                entropy = entropy_func(top_probs)
                entropy_loss = - torch.mean(entropy)

                # Per-sequence mean of the masked log-ratio, then averaged over the batch.
                kl_per_sample = torch.sum(kl_divs, dim=1) / (torch.sum(masks, dim=1) + 1e-8)
                kl_div_loss = torch.mean(kl_per_sample)
                # Adaptive KL coefficient: checked on every 100th batch index of the
                # epoch; halved when the KL is below target_kl / buffer_kl, doubled
                # when above target_kl * buffer_kl, and bounded by kl_min / kl_max.
                if config.use_adaptive_KL:
                    if n_batch % 100 == 0:
                        if kl_div_loss < config.target_kl / config.buffer_kl:
                            print(f"kl_coef:", config.kl_coef )
                            print(f"Update kl (low KL): {config.kl_coef} -> {config.kl_coef / 2.0}")
                            config.kl_coef = max(config.kl_coef / 2.0, config.kl_min) # lower bound (0.05 per the original note)
                            print(f"kl_coef:", config.kl_coef )
                        elif kl_div_loss > config.target_kl * config.buffer_kl:
                            print(f"kl_coef:", config.kl_coef )
                            print(f"Update kl (high KL): {config.kl_coef} -> {config.kl_coef * 2.0}")
                            config.kl_coef = min(config.kl_coef * 2.0, config.kl_max)  # upper bound (5.0 per the original note)
                            print(f"kl_coef:", config.kl_coef )
                
                # Total loss: policy + entropy + KL terms, plus optional CRF and CE terms.
                loss =  config.policy_coef * policy_loss + config.ent_coef * entropy_loss + config.kl_coef * kl_div_loss

                if config.crf_coef != 0.0:
                    loss =  loss + config.crf_coef * crf_loss
                
                if config.ce_coef != 0.0:
                    if config.use_ce_bert:
                        # Cross-entropy on bert_logits against the reference captions;
                        # the transpose moves dim 2 to dim 1, where CrossEntropyLoss
                        # expects the class axis.
                        ce_loss = nn.CrossEntropyLoss()( bert_logits.transpose(1,2), captions )
                    else:
                        # log_ce_tensor is only built earlier when use_ce_bert is False.
                        ce_loss = nn.NLLLoss()( log_ce_tensor.view( bsz * seq_len, -1 ), captions.view( bsz * seq_len ) )
                    loss =  loss + config.ce_coef * ce_loss

                # Remember the current log-probs for the next iteration's ratio.
                with torch.no_grad():
                    last_sample_log_probs = sample_log_probs.detach()                    

            #start_time = time.time()
            torch.cuda.synchronize()
            scaler.scale(loss).backward()
            #end_time = time.time()
            #print( "backward time:", end_time - start_time )
            loss_item = loss.item()
            reward_item = rewards.mean().item()
            reward2_item = rewards2.mean().item()
            adv_item = advantages_norm.mean().item()
            if config.display_include_coef:
                policy_item = config.policy_coef * policy_loss.item()
                entropy_item = config.ent_coef * entropy_loss.item()
                #critic_item = config.gae_coef * gae_loss.item()
                critic_item = 0
                kl_div_item = config.kl_coef * kl_div_loss.item()
                ord_item = config.ord_coef * b_ord.mean().item()
                repeat_item = config.rep_coef * b_repeat.mean().item()
                length_item = config.len_coef * b_length.mean().item()
                if config.crf_coef != 0.0:
                    crf_item = config.crf_coef * crf_loss.mean().item()
                else:
                    crf_item = 0.0
                if config.ce_coef != 0.0:
                    ce_item = config.ce_coef * ce_loss.mean().item()
                else:
                    ce_item = 0.0
                unr_item = config.unr_coef * b_unr.mean().item()
            else:
                policy_item = policy_loss.item()
                entropy_item = entropy_loss.item()
                #critic_item = gae_loss.item()
                critic_item = 0
                kl_div_item = kl_div_loss.item()
                ord_item = b_ord.mean().item()
                repeat_item = b_repeat.mean().item()
                length_item = b_length.mean().item()
                if config.crf_coef != 0.0:
                    crf_item = crf_loss.mean().item()
                else:
                    crf_item = 0.0
                if config.ce_coef != 0.0:
                    ce_item = ce_loss.mean().item()
                else:
                    ce_item = 0.0
                unr_item = b_unr.mean().item()
            del loss, rewards, b_ord, b_repeat, b_length
            del policy_loss, entropy_loss, kl_div_loss, topk_advantages_norm, #mean_advantages_norm
            del sample_log_probs, advantages_norm, ratio, topk_ratio, surr1, surr2
            del finalized_scores, finalized_tokens
            del kl_divs, masks, kl_per_sample
            torch.cuda.empty_cache()
            scaler.unscale_(optimizer)
            if config.clip_grad_threshold != 0.0:
                torch.nn.utils.clip_grad_norm_(\
                   model.parameters(),
                   config.clip_grad_threshold)
            else:
                custom_gradient_clipping(
                    #params_clip, 
                    params_con, 
                    params_bert,
                    params_cri,
                    params_others,
                    #thresh_groups['clip'], 
                    thresh_groups['con'], 
                    thresh_groups['bert'], 
                    thresh_groups['cri'], 
                    thresh_groups['others'], 
                )
          
            # オプティマイザにより，パラメータを更新する
            #for i, params in enumerate(model.parameters()):
            #    params.grad = grad[i]
            #print( "model parameters i:",i )
            #norm0 = torch.sqrt( torch.norm( model.clip_model.vision_model.encoder.layers[0].self_attn.q_proj.weight.grad, p = 2 ) ).item()
            #norm1 = torch.sqrt( torch.norm( model.clip_model.vision_model.encoder.layers[12].self_attn.q_proj.weight.grad, p = 2 ) ).item()
            #norm0 = torch.sqrt( torch.norm( model.clip_model.vision_model.encoder.layers[0].self_attn.q_proj.weight.grad, p = 2 ) ).item()
            #norm1 = torch.sqrt( torch.norm( model.clip_model.vision_model.encoder.layers[12].self_attn.q_proj.weight.grad, p = 2 ) ).item()
            norm0 = 0
            norm1 = 0
            norm2 = torch.norm( model.bert.encoder.layer[11].attention.self.query.weight.grad, p = 2 ).item()
            norm3 = torch.norm( model.bert.encoder.layer[23].attention.self.query.weight.grad, p = 2 ).item()
            norm_mean = torch.mean( torch.stack ( [ torch.norm( param.grad, p = 2 ) \
                                                  for param in model.parameters() if param.grad is not None ] , dim = 0 ) ).item()
            #total_norm = torch.nn.utils.clip_grad_norm_(params_bert, thresh_groups['bert']).item()
            #print( norm0, norm1, norm2, norm3, norm_mean )
            with open(norm_file, 'a') as f:
                print( "epcoch:", epoch, ", step:", global_step, ", norm0:", norm0, ", norm1:", norm1, ", norm2:", norm2, \
                       ", norm3:", norm3, ", norm_mean:", norm_mean, file=f  )
                f.flush()

            #optimizer.step()
            #start_time = time.time()
            scaler.step(optimizer)
            scaler.update()            
            scheduler.step()
            #end_time = time.time()
            #print( "step time:", end_time - start_time )
            
            if global_step % file_param == 0:
                #start_time = time.time()
                #start_time = time.time()
                # Sentences printed for this step; they stay empty unless this is
                # also a `train_param` step. samp_sentence1 is initialised here too
                # so the later zip() never meets an undefined name.
                hypo_sentence1 = []
                ref_sentence1 = []
                samp_sentence1 = []

                # Decode hypotheses (hypo_ids), sampled sequences (preds) and
                # references (captions). Every branch defines preds_str, samps_str,
                # targets_str and preds_str2 (the text fed to the CLIP processor);
                # originally preds_str2 / samps_str existed only in some branches.
                if config.decode_t == 'no-endoftext':
                    # Drop a token when it and its predecessor are both end-of-text;
                    # index 0 is always skipped.
                    preds_str = [tokenizer.decode(
                        [pred[i] for i in range( 1, len( pred )  ) if not (pred[i-1] == endoftext_token_id and pred[i] == endoftext_token_id) ]
                        ) for pred in hypo_ids]
                    samps_str = [tokenizer.decode(
                        [pred[i] for i in range( 1, len( pred )  ) if not (pred[i-1] == endoftext_token_id and pred[i] == endoftext_token_id) ]
                        ) for pred in preds]
                    targets_str = [tokenizer.decode(
                        [target[i] for i in range( 1,  len( target )  ) if not (target[i-1] == endoftext_token_id and target[i] == endoftext_token_id) ]
                        ) for target in captions]
                    preds_str2 = tokenizer.batch_decode(hypo_ids, skip_special_tokens=True)
                elif config.decode_t == 'no-pad':
                    # Work on copies so the tensors used for training stay untouched.
                    hypo_ids1 = copy.deepcopy(hypo_ids)
                    captions1 = copy.deepcopy(captions)
                    preds1 = copy.deepcopy(preds)
                    # Turn EOS into PAD, then strip the PAD / [unused1] markers.
                    hypo_ids1[hypo_ids1 == eos_token_id] = pad_token_id
                    decoded = tokenizer.batch_decode(hypo_ids1, skip_special_tokens=False)
                    preds_str = [s.replace("[PAD]", "").replace("[unused1]", "").strip() for s in decoded]
                    preds_str2 = tokenizer.batch_decode(hypo_ids1, skip_special_tokens=True)
                    captions1[captions1 == eos_token_id] = pad_token_id
                    decoded = tokenizer.batch_decode(captions1, skip_special_tokens=False)
                    targets_str = [s.replace("[PAD]", "").replace("[unused1]", "").strip() for s in decoded]
                    preds1[preds1 == eos_token_id] = pad_token_id
                    decoded = tokenizer.batch_decode(preds1, skip_special_tokens=False)
                    samps_str = [s.replace("[PAD]", "").replace("[unused1]", "").strip() for s in decoded]
                else:
                    preds_str = [tokenizer.decode(pred) for pred in hypo_ids]
                    samps_str = [tokenizer.decode(pred) for pred in preds]
                    targets_str = [tokenizer.decode(target) for target in captions]
                    preds_str2 = tokenizer.batch_decode(hypo_ids, skip_special_tokens=True)

                pred_dict = { str(i): [item] for i, item in enumerate( preds_str)}
                target_dict = { str(i): [item] for i, item in enumerate( targets_str)}
                with torch.no_grad():
                    # CIDEr over the batch (stored in avg_bleu for historical reasons).
                    avg_bleu, scores = compute_reward.scorer.compute_score(target_dict, pred_dict)
                    # NOTE(review): index 0 of rouge_score's Score tuple is the
                    # precision, not the F-measure -- confirm this is the intended value.
                    rouge_scores = [compute_reward.rougeL.score(target, pred)['rougeL'][0] for pred, target in zip(preds_str, targets_str)]
                    avg_error = sum( rouge_scores ) / len( rouge_scores )
                with autocast(str(config.device),enabled=config.use_amp):
                    with torch.no_grad():
                        # CLIP similarity between the images and the decoded hypotheses.
                        processed = compute_reward.metric.processor(text=preds_str2, images=imgs2, return_tensors="pt", padding=True, \
                                                          truncation=True, max_length=77, do_resize=False, do_rescale=False ).to(config.device)
                        outputs = compute_reward.metric.model(**processed)
                        # L2-normalise the embeddings so the dot product is a cosine similarity.
                        image_features = outputs.image_embeds / outputs.image_embeds.norm(p=2, dim=-1, keepdim=True)
                        text_features = outputs.text_embeds / outputs.text_embeds.norm(p=2, dim=-1, keepdim=True)
                        individual_scores = torch.clamp( (image_features.to(config.device) * \
                                                          text_features.to(config.device)).sum(axis=-1), min=0)
                        # Mean of the per-sample scores (the former expand over seq_len
                        # did not change the mean). Stored as a Python float so the
                        # metric deque does not keep device tensors alive.
                        clip_score = individual_scores.mean().item()
                        model_name = 'distilbert-base-uncased'
                        bert_scores = compute_reward.bert.compute(
                            predictions=preds_str,
                            references=targets_str,
                            model_type=model_name,
                            use_fast_tokenizer=True,
                            lang='en',
                            device=config.device,
                            batch_size=config.batch_size,  # set as large as memory allows
                            rescale_with_baseline=False
                            )['f1']
                        bert_score = sum( bert_scores ) / len( bert_scores )
                if  global_step % train_param == 0:
                    hypo_sentence1 = [preds_str[0]]
                    samp_sentence1 = [samps_str[0]]
                    ref_sentence1 = [targets_str[0]]

                # 学習時の損失をログに書き込み
                #エポック内の平均
                train_losses.append(loss_item)
                train_policys.append(policy_item)
                train_entropies.append(entropy_item)
                train_critics.append(critic_item)
                train_kl_divs.append(kl_div_item)
                train_rewards.append(reward_item)
                train_rewards2.append(reward2_item)
                train_ord.append(ord_item)
                train_repeat.append(repeat_item)
                train_length.append(length_item)
                train_adv.append(adv_item)
                train_errors.append( avg_error )
                train_bleus.append( avg_bleu )
                train_crfs.append( crf_item )
                train_ces.append( ce_item )
                train_clips.append( clip_score )
                train_unrs.append( unr_item )
                train_berts.append( bert_score )
                if len(train_losses) > config.moving_avg:
                    train_losses.popleft()
                    train_policys.popleft()
                    train_entropies.popleft()
                    train_critics.popleft()
                    train_kl_divs.popleft()
                    train_rewards.popleft()
                    train_rewards2.popleft()
                    train_ord.popleft()
                    train_repeat.popleft()
                    train_length.popleft()
                    train_adv.popleft()
                    train_errors.popleft()
                    train_bleus.popleft()
                    train_crfs.popleft()
                    train_ces.popleft()
                    train_clips.popleft()
                    train_unrs.popleft()
                    train_berts.popleft()
                mean_loss = torch.Tensor(train_losses).mean().item()
                mean_policy = torch.Tensor(train_policys).mean().item()
                mean_entropy = torch.Tensor(train_entropies).mean().item()
                mean_critic = torch.Tensor(train_critics).mean().item()
                mean_kl_div = torch.Tensor(train_kl_divs).mean().item()
                mean_reward = torch.Tensor(train_rewards).mean().item()
                mean_reward2 = torch.Tensor(train_rewards2).mean().item()
                mean_ord = torch.Tensor(train_ord).mean().item()
                mean_repeat = torch.Tensor(train_repeat).mean().item()
                mean_length = torch.Tensor(train_length).mean().item()
                mean_adv = torch.Tensor(train_adv).mean().item()
                mean_error = torch.Tensor(train_errors).mean().item()
                mean_bleu = torch.Tensor(train_bleus).mean().item()
                mean_crf = torch.Tensor(train_crfs).mean().item()
                mean_ce = torch.Tensor(train_ces).mean().item()
                mean_clip = torch.Tensor(train_clips).mean().item()
                mean_unr = torch.Tensor(train_unrs).mean().item()
                mean_bert = torch.Tensor(train_berts).mean().item()
                #print( "mean_reward2:", mean_reward2 ) 
                pbar.set_postfix({
                    'loss': mean_loss,
                    'policy': mean_policy,
                    'entropy': mean_entropy,
                    'gae': mean_critic,
                    'kl_div': mean_kl_div,
                    'reward': mean_reward,
                    'reward2': mean_reward2,
                    'ord': mean_ord,
                    'repeat': mean_repeat,
                    'length': mean_length,
                    'adv': mean_adv,
                    'rougeL': mean_error,
                    'cider': mean_bleu,
                    'crf': mean_crf,
                    'ce': mean_ce,
                    'clip': mean_clip,
                    'unr': mean_unr,
                    'bert': mean_bert,
                })
                with open(train_loss_file, 'a') as f:
                    print(f' {global_step}, {mean_loss}, {mean_policy}, {mean_entropy}, {mean_critic}, {mean_kl_div}, {mean_reward}, ' \
                          f'{mean_ord}, {mean_repeat}, {mean_length}, {mean_adv}, {mean_error}, {mean_bleu}, {mean_crf}, {mean_ce}, '\
                          f'{mean_clip}, {mean_unr}, {mean_bert}', file=f)
                print_flag = 1
                for ( hypo_se, ref_se, samp_se ) in zip( hypo_sentence1, ref_sentence1, samp_sentence1 ):
                    if print_flag == 1:
                        print( "lr con   :", optimizer.param_groups[0]["lr"] )
                        print( "lr bert  :", optimizer.param_groups[1]["lr"] )
                        print( "lr cri   :", optimizer.param_groups[2]["lr"] )
                        print( "lr others:", optimizer.param_groups[3]["lr"] )
                        print_flag = 0
                    print(f'Train epoch = {global_step/len_tr_loader}, loss = {mean_loss}, policy = {mean_policy}, '\
                          f'entropy_loss = {mean_entropy}, gae = {mean_critic}, kl_div = {mean_kl_div}, reward = {mean_reward}, '\
                          f'ord = {mean_ord}, repeat = {mean_repeat}, length = {mean_length}, adv = {mean_adv}, '\
                          f'rougeL = {mean_error}, cider = {mean_bleu}, clip = {mean_clip}, crf = {mean_crf}, ce = {mean_ce}, '\
                          f'unr = {mean_unr}, bert = {mean_bert}' )
                    print( "refe:", ref_se )
                    print( "hypo:", hypo_se )
                    print( "samp:", samp_se )
                #end_time = time.time()
                #print( "display time:", end_time - start_time )
            
            global_step += 1
            #end_time0 = time.time()
            #print( "all train time:", end_time0 - start_time0 )
        end = time.time()
        print( "time:",end - start )
        #print(prof.key_averages().table(sort_by = "cuda_time", row_limit = 30))
        #print(prof.key_averages().table(sort_by = "cpu_time", row_limit = 30))
    #各値を表示
    print(f'Train loss: {mean_loss}')
    print(f'Train policy: {mean_policy}')
    print(f'Train entropy: {mean_entropy}')
    print(f'Train gae: {mean_critic}')
    print(f'Train kl_div: {mean_kl_div}')
    print(f'Train reward: {mean_reward}')
    print(f'Train reward2: {mean_reward2}')
    print(f'Train ord: {mean_ord}')
    print(f'Train repeat: {mean_repeat}')
    print(f'Train pad: {mean_length}')
    print(f'Train adv: {mean_adv}')
    print(f'Train clip: {mean_clip}')
    print(f'Train rougeL: {mean_error}')        
    print(f'Train cider: {mean_bleu}')
    print(f'Train crf: {mean_crf}')        
    print(f'Train ce: {mean_ce}')
    print(f'Train unr: {mean_unr}')        
    print(f'Train bert: {mean_bert}')
    
    # 検証
    with tqdm(val_loader) as pbar:
        pbar.set_description(f'[検証]')
    
        # 評価モード
        model.eval()
    
        #val_losses = deque()
        #val_rewards = deque()
        val_errors = deque()
        val_bleus = deque()
        val_clips = deque()
        val_berts = deque()
        for n_batch, (imgs, imgs2, captions, caption_lengths) in enumerate( pbar ):
    
            # ミニバッチを設定
            imgs = imgs.to(config.device)
            imgs2 = imgs2.to(config.device)
            captions = captions.to(config.device)
                
            with torch.no_grad():
                # Only the finalized token ids (second output) are needed for
                # evaluation. Indexing avoids the former unpacking, whose names
                # did not match the ones used for the same call during training.
                val_outputs = model( imgs, captions, top_indices = None )
                finalized_tokens = val_outputs[1]
                hypo_ids = finalized_tokens
                del val_outputs
                gc.collect()
                torch.cuda.empty_cache()

            # Sentences printed for this batch; empty unless it is a `val_param` batch.
            hypo_sentence1 = []
            ref_sentence1 = []

            # Decode hypotheses and references. Unlike the original, nothing here
            # reads `preds` or `seq_len`, which were left over from the last
            # training batch. Every branch defines preds_str2 for the CLIP processor.
            if config.decode_t == 'no-endoftext':
                # Drop a token when it and its predecessor are both end-of-text;
                # index 0 is always skipped.
                preds_str = [tokenizer.decode(
                    [pred[i] for i in range( 1, len( pred )  ) if not (pred[i-1] == endoftext_token_id and pred[i] == endoftext_token_id) ]
                    ) for pred in hypo_ids]
                targets_str = [tokenizer.decode(
                    [target[i] for i in range( 1,  len( target )  ) if not (target[i-1] == endoftext_token_id and target[i] == endoftext_token_id) ]
                    ) for target in captions]
                preds_str2 = tokenizer.batch_decode(hypo_ids, skip_special_tokens=True)
            elif config.decode_t == 'no-pad':
                # Work on copies so the original id tensors stay untouched.
                hypo_ids1 = copy.deepcopy(hypo_ids)
                captions1 = copy.deepcopy(captions)
                # Turn EOS into PAD, then strip the PAD / [unused1] markers.
                hypo_ids1[hypo_ids1 == eos_token_id] = pad_token_id
                decoded = tokenizer.batch_decode(hypo_ids1, skip_special_tokens=False)
                preds_str = [s.replace("[PAD]", "").replace("[unused1]", "").strip() for s in decoded]
                preds_str2 = tokenizer.batch_decode(hypo_ids1, skip_special_tokens=True)
                captions1[captions1 == eos_token_id] = pad_token_id
                decoded = tokenizer.batch_decode(captions1, skip_special_tokens=False)
                targets_str = [s.replace("[PAD]", "").replace("[unused1]", "").strip() for s in decoded]
            else:
                preds_str = [tokenizer.decode(pred) for pred in hypo_ids]
                targets_str = [tokenizer.decode(target) for target in captions]
                preds_str2 = tokenizer.batch_decode(hypo_ids, skip_special_tokens=True)

            pred_dict = { str(i): [item] for i, item in enumerate( preds_str)}
            target_dict = { str(i): [item] for i, item in enumerate( targets_str)}
            with torch.no_grad():
                # CIDEr over the batch (stored in avg_bleu for historical reasons).
                avg_bleu, scores = compute_reward.scorer.compute_score(target_dict, pred_dict)
                # NOTE(review): index 0 of rouge_score's Score tuple is the
                # precision, not the F-measure -- confirm this is the intended value.
                rouge_scores = [compute_reward.rougeL.score(target, pred)['rougeL'][0] for pred, target in zip(preds_str, targets_str)]
                avg_error = sum( rouge_scores ) / len( rouge_scores )
            with autocast(str(config.device),enabled=config.use_amp):
                with torch.no_grad():
                    # CLIP similarity between the images and the decoded hypotheses.
                    processed = compute_reward.metric.processor(text=preds_str2, images=imgs2, return_tensors="pt", padding=True, \
                                                      truncation=True, max_length=77, do_resize=False, do_rescale=False ).to(config.device)
                    outputs = compute_reward.metric.model(**processed)
                    # L2-normalise the embeddings so the dot product is a cosine similarity.
                    image_features = outputs.image_embeds / outputs.image_embeds.norm(p=2, dim=-1, keepdim=True)
                    text_features = outputs.text_embeds / outputs.text_embeds.norm(p=2, dim=-1, keepdim=True)
                    individual_scores = torch.clamp( (image_features.to(config.device) * \
                                                      text_features.to(config.device)).sum(axis=-1), min=0)
                    # Mean of the per-sample scores, stored as a Python float. The
                    # former expand over the training loop's seq_len did not change
                    # the mean and tied validation to training state.
                    clip_score = individual_scores.mean().item()
                    bert_scores = compute_reward.bert.compute(
                        predictions=preds_str,
                        references=targets_str,
                        model_type=model_name,
                        use_fast_tokenizer=True,
                        lang='en',
                        device=config.device,
                        batch_size=config.batch_size,  # set as large as memory allows
                        rescale_with_baseline=False
                        )['f1']
                    bert_score = sum( bert_scores ) / len( bert_scores )
            if  n_batch % val_param == 0:
                hypo_sentence1 = [preds_str[0]]
                ref_sentence1 = [targets_str[0]]
    
            val_errors.append( avg_error )
            val_bleus.append( avg_bleu )
            val_clips.append( clip_score )
            val_berts.append( bert_score )
            if len(val_errors) > config.moving_avg:
                #val_losses.popleft()
                #val_rewards.popleft()
                val_errors.popleft()
                val_bleus.popleft()
                val_clips.popleft()
                val_berts.popleft()
             #mean_loss = torch.Tensor(val_losses).mean().item()
            #mean_reward = torch.Tensor(val_rewards).mean().item()
            mean_error = torch.Tensor(val_errors).mean().item()
            mean_bleu = torch.Tensor(val_bleus).mean().item()
            mean_clip = torch.Tensor(val_clips).mean().item()
            mean_bert = torch.Tensor(val_berts).mean().item()
            pbar.set_postfix({
                #'loss': mean_loss,
                #'reward': mean_reward,
                'rougeL': mean_error,
                'CIDER': mean_bleu,
                'clip': mean_clip,
                'bert': mean_bert,
            })
            # Validation Lossをログに書き込み
            with open(val_loss_file, 'a') as f:
                print(f'{epoch}, {mean_error}, {mean_bleu}, {mean_clip}, {mean_bert}', file=f)
            
            for ( hypo_se, ref_se ) in zip( hypo_sentence1, ref_sentence1 ):
                print(f'Val epoch = {epoch}, rougeL = {mean_error}, cider = {mean_bleu}, clip = {mean_clip}, bert = {mean_bert}')
                print( "refe:", ref_se )
                print( "hypo:", hypo_se )
    
    # Loss 表示
    #print(f'Validation loss: {val_loss}')
    #print(f'Validation loss: {val_reward}')
    print(f'Validation rougeL: {mean_error}')
    print(f'Validation cider: {mean_bleu}')
    print(f'Validation clip: {mean_clip}')
    print(f'Validation bert: {mean_bert}')
    
    ## より良い検証結果が得られた場合、モデルを保存
            
    # モデルを保存
    torch.save({'epoch': epoch,
                'global_step': global_step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        #'loss': loss,
        },
        f'{config.save_directory}/model_rl_grpo_crf33_.pth')
    ## モデルを保存
        
## モデルを保存
#torch.save({'epoch': epoch,
#    'global_step': global_step,
#    'model_state_dict': model.state_dict(),
#    'optimizer_state_dict': optimizer.state_dict(),
#    #'scheduler_state_dict': scheduler.state_dict(),
#    #'loss': loss,
#    },
#    f'{config.save_directory}/model_rl_grpo_crf24_final.pth')


In [None]:
# Save the final checkpoint: epoch counter, global step, and the model,
# optimizer and scheduler states.
final_checkpoint = {
    'epoch': epoch,
    'global_step': global_step,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    #'loss': loss,
}
torch.save(final_checkpoint, f'{config.save_directory}/model_rl_grpo_crf33_.pth')