In [16]:
# -*- coding: utf-8 -*-

#
# Pytorchで用いるDatasetの定義
#

#!pip install janome

# sysモジュールをインポート
import sys
import os

import matplotlib.pyplot as plt
import pandas as pd
import torch
import random
from torch import nn, Tensor
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
import math
import janome
from janome.tokenizer import Tokenizer
from collections import Counter
from torch.utils.data.sampler import SubsetRandomSampler
import time
import levenshtein
import pickle
from timm.scheduler import CosineLRScheduler
from nltk import bleu_score
#from tqdm import tqdm
from tqdm.notebook import tqdm
from torch import autocast, GradScaler
from typing import Sequence, Dict, Tuple, Union
from transformers import  get_linear_schedule_with_warmup, BertTokenizer
from sklearn.metrics import accuracy_score
from datasets import load_dataset

#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device( "cpu")
use_amp = False
use_saved_pth = True

In [2]:
path = "../../python_image_recognition-main/6_img_captioning/6_7_ImgCap_pre_trained_Feature_extractor/CLIP_ENCODER/models--google-bert--bert-large-uncased/snapshots/6da4b6a26a1877e173fca3225479512db81a5e5b"
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path = path)

vocab_size = len(tokenizer)
print(vocab_size)

30522


In [3]:
class MyDataset(Dataset):
    def __init__(self, mode, tokenizer, pad_token_id, max_length ):

        #dataset = load_dataset("yelp_review_full") 
        #dataset = load_dataset("imdb")
        #dataset = load_dataset("kunishou/databricks-dolly-15k-ja")  
        # Dataset id from huggingface.co/dataset
        dataset_id = "argilla/synthetic-domain-text-classification"
        dataset = load_dataset(dataset_id, split='train')
 
        ## Load raw dataset
        #train_dataset = load_dataset(dataset_id, split='train')
        
        if mode == "train":
            dataset = dataset[:900]
        elif mode == "val":
            dataset = dataset[900:950]
        else:
            dataset = dataset[950:1000]
        self.dataset = dataset
        #print( "self.dataset[1]:", self.dataset['label'][1])
            
        #self.encoded_texts = [ tokenizer.encode(text, max_length = max_length, truncation = True ) for text in dataset['text']]
        self.encoded_texts = [ tokenizer.encode(text ) for text in dataset['text']]
        
        if max_length is None:
            self.max_length = self._longest_encoded_length()
        else:
            self.max_length = max_length
            # Truncate sequences if they are longer than max_length
            self.encoded_texts = [
                encoded_text[:self.max_length]
                for encoded_text in self.encoded_texts
            ]

        # Pad sequences to the longest sequence
        self.encoded_texts = [
            encoded_text + [pad_token_id] * (self.max_length - len(encoded_text))
            for encoded_text in self.encoded_texts
        ]
        
    def __getitem__(self, index):
        encoded = self.encoded_texts[index]
        label = self.dataset["label"][index]
        return (
            torch.tensor(encoded, dtype=torch.long),
            torch.tensor(label, dtype=torch.long)
        )

    def __len__(self):
        return len(self.encoded_texts)

    def _longest_encoded_length(self):
        max_length = 0
        for encoded_text in self.encoded_texts:
            encoded_length = len(encoded_text)
            if encoded_length > max_length:
                max_length = encoded_length
        return max_length
        # Note: A more pythonic version to implement this method
        # is the following, which is also used in the next chapter:
        # return max(len(encoded_text) for encoded_text in self.encoded_texts)

In [4]:
train_dataset = MyDataset(
    mode="train",
    tokenizer=tokenizer,
    pad_token_id = tokenizer.pad_token_id,
    max_length=None,
)

print(train_dataset.max_length)
print( len( train_dataset ))

408
900


In [5]:
val_dataset = MyDataset(
    mode="val",
    tokenizer=tokenizer,
    pad_token_id = tokenizer.pad_token_id,
    max_length=512,
)
test_dataset = MyDataset(
    mode="test",
    tokenizer=tokenizer,
    pad_token_id = tokenizer.pad_token_id,
    max_length=512,
)

In [6]:
#num_workers = 0 if device == torch.device( 'cpu' ) else 8
num_workers = 0 if device == torch.device( 'cpu' ) else 8
batch_size = 4

train_loader = DataLoader(train_dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=num_workers,
                            pin_memory=True,)
                            #collate_fn = collate_func_lambda)
print( "train_loader defiend" )
# 開発データのDataLoaderを呼び出す
# 開発データはデータはシャッフルしない
val_loader = DataLoader(val_dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=num_workers,)
                        #collate_fn = collate_func_lambda)
print( "val_loader defined" )
print( len( train_loader ))
print( len( val_loader ))

inputs, labels = next(iter(val_loader))
print( inputs )
print( labels )
#print(text[0], target[0], text_len[0],target_len[0], sep="\n")
#print( "text:", text )
#print( "text_len:", text_len )

train_loader defiend
val_loader defined
225
13
tensor([[ 101, 5983, 2092,  ...,    0,    0,    0],
        [ 101, 1996, 3795,  ...,    0,    0,    0],
        [ 101, 2023, 3720,  ...,    0,    0,    0],
        [ 101, 3795, 3119,  ...,    0,    0,    0]])
tensor([17,  9,  4,  0])


In [7]:
print( len( inputs ))
print( len( labels ))
print( batch_size )

4
4
4


In [8]:
class PositionalEmbedding(nn.Module):
    '''
    位置埋め込み （Positional embedding）
    dim_embedding: 埋込み次元
    max_len      : 入力の最大系列長
    '''
    def __init__(self, dim_embedding: int, max_len: int=2048):
        super().__init__()

        self.pos_emb = nn.Embedding(max_len, dim_embedding)

    '''
    位置エンコーディングの順伝播
    x: 位置エンコーディングを埋め込む対象のテンソル,
       [バッチサイズ, 系列長, 埋め込み次元]
    '''
    def forward(self, x: torch.Tensor):
        seq = x.shape[1]
        positions = torch.arange(start=0, end=seq, step=1, device=x.device).to(torch.long)
        positions = self.pos_emb(positions)[:seq,:]
        
        return positions

In [9]:
class TransformerEncoder(nn.Module):
    '''
    CaptioningTransformerのコンストラクタ
    dim_embedding  : 埋め込み次元
    dim_feedforward: FNNの中間特徴次元
    num_heads      : マルチヘッドアテンションのヘッド数
    num_layers     : Transformerデコーダ層の数
    vocab_size     : 辞書の次元
    pad_index      : PADのID
    dropout        : ドロップアウト確率
    '''
    def __init__(self, vocab_size: int, dim_embedding: int, dim_feedforward: int,
                 num_heads: int, num_layers: int ):
        super().__init__()

        self.encoder_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=dim_embedding, nhead=num_heads, batch_first=True, activation='gelu', norm_first=True)
            for _ in range(num_layers)
        ])

    ''' CaptioningTransformerの順伝播処理
    features: 画像特徴量 [バッチサイズ, 埋め込み次元]
    captions: 正解キャプション [バッチサイズ, 系列長]

    '''
    def forward(self, src: torch.Tensor, src_mask: torch.Tensor=None, \
                src_key_padding_mask: torch.Tensor=None ):


        for layer in self.encoder_layers:
            src = layer( src, src_mask = src_mask, src_key_padding_mask = src_key_padding_mask,\
                        #is_causal= True )
                        is_causal= False )

        return src

In [10]:
class Transformer(nn.Module):
    def __init__(self, dim_embedding: int, dim_feedforward: int,
                 num_heads: int, num_layers: int, vocab_size: int,
                 pad_index: int, dropout: float=0.5, us_rate: float=2.0 ):
        super().__init__()

        # 単語埋め込み
        self.embed = nn.Embedding(
            vocab_size, dim_embedding, padding_idx=pad_index)

        # 位置エンコーディング
        self.pos_emb = PositionalEmbedding(dim_embedding)        
        
        # dropout
        self.dropout = nn.Dropout( dropout )
        
        self.encoder = TransformerEncoder(vocab_size, dim_embedding, dim_feedforward, num_heads, num_layers)

        # 単語出力分布計算
        self.ln = nn.LayerNorm( dim_embedding )
        self.linear = nn.Linear(dim_embedding, vocab_size)

        self.pad_index = pad_index
        
    ''' CaptioningTransformerの順伝播処理
    features: 画像特徴量 [バッチサイズ, 埋め込み次元]
    captions: 正解キャプション [バッチサイズ, 系列長]

    '''
    def forward(self, text):

        device = text.device

        src = self.embed( text )
        
        src += self.pos_emb( src )
        src = self.dropout( src )
        src_key_padding_mask = torch.eq(text, self.pad_index)

        #ones = torch.ones( text.size(1) ).to( device = device )
        #src_mask = torch.diag( ones ).bool()
        src_mask = None
        
        preds = self.encoder( src, src_mask, src_key_padding_mask )
        
        preds = self.ln( preds )
        preds = preds[:,0,:]
        logits = self.linear( preds )

        return logits

In [11]:
epoch_num = 5
model = Transformer(768, 3072, 12, 6, vocab_size, tokenizer.pad_token_id ).to(device)
#criterion = nn.BCEWithLogitsLoss()
criterion = nn.CrossEntropyLoss()
#lr = 1e-4
lr = 5e-5
optimizer = optim.AdamW( model.parameters(), lr = lr )
# 全ステップ数
num_global_steps = len( train_loader ) * epoch_num
print( "num_global_steps:", num_global_steps )
num_warmup_steps = num_global_steps * 0.1
print( "num_warmup_steps:", num_warmup_steps )
#スケジューラーの定義
scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps, num_global_steps )
eps = 1e-4

#PATH = 'Diag_Mask3_curr.pt'
#if use_saved_pth and os.path.isfile(PATH):
#    print( "loaded Diag_Mask6_curr.pt")
#    checkpoint = torch.load(PATH)
#    model.load_state_dict(checkpoint['model_state_dict'])

#model.linear = nn.Linear( 768, 5 ) 
model.linear = nn.Linear( 768, 26 ) 


num_global_steps: 1125
num_warmup_steps: 112.5


In [12]:
# WarmupとCosine Decayを行うスケジューラを利用
#scheduler = CosineLRScheduler(
#    optimizer, t_initial=epoch_num, lr_min=1e-1,
#    warmup_t=5, warmup_lr_init=5e-2, warmup_prefix=True)

#tr_print_coef = 64000
#tr_save_coef = 1000
#val_print_coef = 240
len_tr_loader = len( train_loader )
len_val_loader = len( val_loader )
tr_print_coef = len_tr_loader // 3
tr_save_coef = len_tr_loader // 30
val_print_coef = len_val_loader // 3
print( "len( train_loader ):", len_tr_loader )
print( "len( val_loader ):", len_val_loader )
print( "tr_print_coef:", tr_print_coef )
print( "tr_save_coef:", tr_save_coef )
print( "val_print_coef:", val_print_coef )
#tr_print_coef = 1
#tr_save_coef = 1
#val_print_coef = 1
#train_length = len(train_loader)
#train_int = train_length // tr_print_coef
#print( train_int )
#print( train_length )
#val_length = len(val_loader)
#val_int = val_length // val_print_coef
#print( val_int )
#print( val_length )

history = {"len_tr_loader":[],"len_val_loader":[], "train_loss":[], "val_loss": [], "train_acc": [], "val_acc": [] }
history["len_tr_loader"].append( len_tr_loader )
history["len_val_loader"].append( len_val_loader )
with open("Diag_Mask5_FT_argilla.pkl", "wb") as f:
    pickle.dump( history, f )      
n = 0
train_loss = 0
val_loss = 0


# 学習率の減衰やEarly stoppingの
# 判定を開始するエポック数
# (= 最低限このエポックまではどれだけ
# validation結果が悪くても学習を続ける)
lr_decay_start_epoch = 5

# 学習率を減衰する割合
# (減衰後学習率 <- 現在の学習率*lr_decay_factor)
# 1.0以上なら，減衰させない
lr_decay_factor = 0.5

# Early stoppingの閾値
# 最低損失値を更新しない場合が
# 何エポック続けば学習を打ち切るか
early_stop_threshold = 3

# 最も低い損失値，
# そのときのモデルとエポック数を記憶しておく
best_loss = -1
best_model = None
best_epoch = 0

# Early stoppingフラグ．Trueになると学習を打ち切る
early_stop_flag = False
# Early stopping判定用(損失値の最低値が
# 更新されないエポックが何回続いているか)のカウンタ
counter_for_early_stop = 0

fn = bleu_score.SmoothingFunction().method7

# AMP用のスケーラー
scaler = GradScaler(enabled=use_amp)

for epoch in range(epoch_num):
    # early stopフラグが立っている場合は，
    # 学習を打ち切る
    if early_stop_flag:
        print('    Early stopping.'\
            ' (early_stop_threshold = %d)' \
            % (early_stop_threshold))
        #log_file.write('\n    Early stopping.'\
        #        ' (early_stop_threshold = %d)' \
        #        % (early_stop_threshold))
        break

    with tqdm(train_loader) as pbar:
    #with tqdm(val_loader) as pbar:
        pbar.set_description(f'[Train エポック {epoch + 1}]')
    
        model.train()
        #scheduler.step(epoch)
        #print( "Train")
        train_loss = 0
        total_acc = 0
        n3 = 0
        for i, ( inputs, labels ) in enumerate( pbar ):
            optimizer.zero_grad()
            inputs = inputs.to(device)
            y_true = labels.tolist()
            labels = labels.to(device)
            #labels = nn.functional.one_hot(labels.to(device).long(), num_classes = 5 ).float()
            #labels = labels.to(device).long()
            
            with autocast(str(device),enabled=use_amp):
                logits = model( inputs )
                #print( "logits size:", logits.size() )
                y_pred = torch.argmax( logits, dim = 1 )
                # 損失の計算
                loss = criterion( logits, labels )

                
            # 誤差逆伝播
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            clip_grad_threshold = 5.0
            torch.nn.utils.clip_grad_norm_(\
                    model.parameters(),
                    clip_grad_threshold)
            # オプティマイザにより，パラメータを更新する
            scaler.step(optimizer)
            scaler.update()            
            
            scheduler.step()
            
            #lr = optimizer.param_groups[0]['lr']
            #print( "lr:", lr )
            
            n2 = 0
            input_sentence = []
            target_label = []
            pred_label = []
            for input_id, label, pred in zip( inputs, y_true, y_pred ):
                #print( "label:", label )
                #print( "pred:",pred )
                if n2 < 2 and ( i % tr_print_coef == tr_print_coef -1 ) :
                    input_s = tokenizer.decode( input_id, skip_special_tokens = True )
                    input_sentence.append( input_s )
                    target_label.append( label )
                    pred_label.append( pred.item() )
                
                n2 += 1

            #print( "y_true size:", torch.tensor( y_true).size() )
            #print( "y_pred size:", y_pred.size() )
            accuracy = accuracy_score(y_true, y_pred.detach().cpu() )
            total_acc += accuracy
            n3 += 1
    
            train_loss += loss.item()
            history["train_loss"].append( loss.item() )
            history["train_acc"].append( accuracy )
            if i % tr_save_coef == tr_save_coef - 1:
                with open("Diag_Mask5_FT_argilla.pkl", "wb") as f:
                    pickle.dump( history, f )            
            if i % tr_print_coef == tr_print_coef - 1:
                lr = optimizer.param_groups[0]['lr']
                print(f"Train epoch:{epoch+1}  index:{i+1} loss:{train_loss/n3:.10f} ACC:{ total_acc / n3 } lr:{lr:.10f}")
            #if i == len( train_loader ) - 1:
            for (input_s, target_lbl, pred_lbl ) in zip( input_sentence, target_label, pred_label ):
                print( "index:", i+1, "input :", input_s)
                print( "index:", i+1, "target:", target_lbl)
                print( "index:", i+1, "pred  :", pred_lbl )
            pbar.set_postfix({
                    'loss': train_loss / n3,
                    'acc': total_acc / n3,
                })

    with tqdm(val_loader) as pbar:
        pbar.set_description(f'[検証]')
        model.eval()
        #for i, ( text, target, text_len, target_len ) in enumerate(val_loader):
        val_loss = 0
        total_acc = 0
        n3 = 0
        for i, ( inputs, labels ) in enumerate( pbar ):
            inputs = inputs.to(device)
            y_true = labels.tolist()
            labels = labels.to(device)
            #labels = nn.functional.one_hot(labels.to(device).long(), num_classes = 2 ).float()
            #labels = labels.to(device).long()

            with torch.no_grad():
                logits = model( inputs )
                y_pred = torch.argmax( logits, dim = 1 )
                # 損失の計算
                loss = criterion( logits, labels )
           

            n2 = 0
            input_sentence = []
            target_label = []
            pred_label = []
            for input_id, label, pred in zip( inputs, y_true, y_pred ):
                if n2 < 2 and ( i % val_print_coef == val_print_coef -1 ) :
                    input_s = tokenizer.decode( input_id, skip_special_tokens = True )
                    input_sentence.append( input_s )
                    target_label.append( label )
                    pred_label.append( pred.item() )
                
                n2 += 1
                
            #print( "y_true size:", torch.tensor( y_true).size() )
            #print( "y_pred size:", y_pred.size() )
            accuracy = accuracy_score(y_true, y_pred.detach().cpu() )
            total_acc += accuracy
            n3 += 1
            
            val_loss += loss.item()
            history["val_loss"].append( loss.item() )
            history["val_acc"].append( accuracy )

            if i % val_print_coef == val_print_coef - 1:
                lr = optimizer.param_groups[0]['lr']
                print(f"Val epoch:{epoch+1}  index:{i+1} loss:{val_loss/n3:.10f} ACC:{total_acc / n3 } lr:{lr:.10f}")
                PATH = './Diag_Mask5_FT_argilla_curr.pt'
                torch.save({'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss,},
                     PATH)
                with open("Diag_Mask5_FT_argilla.pkl", "wb") as f:
                    pickle.dump( history, f )
            #if i == len( train_loader ) - 1:
            for (input_s, target_lbl, pred_lbl ) in zip( input_sentence, target_label, pred_label ):
                print( "index:", i+1, "input :", input_s)
                print( "index:", i+1, "target:", target_lbl)
                print( "index:", i+1, "pred  :", pred_lbl )
            pbar.set_postfix({
                    'loss': val_loss / n3,
                    'acc': total_acc / n3,
                })
    
    epoch_loss = val_loss/n3
    if epoch == 0 or best_loss > epoch_loss:
        # 損失値が最低値を更新した場合は，
        # その時のモデルを保存する
        best_loss = epoch_loss
        torch.save(model.state_dict(), 
                    './best_model_Mask5_FT_argilla.pt')
        best_epoch = epoch
        # Early stopping判定用の
        # カウンタをリセットする
        counter_for_early_stop = 0
    else:
        # 最低値を更新しておらず，
        if epoch+1 >= lr_decay_start_epoch:
            # かつlr_decay_start_epoch以上の
            # エポックに達している場合
            if counter_for_early_stop+1 \
                    >= early_stop_threshold:
                # 更新していないエポックが，
                # 閾値回数以上続いている場合，
                # Early stopping フラグを立てる
                early_stop_flag = True
            else:
                # Early stopping条件に
                # 達していない場合は
                # 学習率を減衰させて学習続行
                if lr_decay_factor < 1.0:
                    for i, param_group \
                            in enumerate(\
                            optimizer.param_groups):
                        if i == 0:
                            lr = param_group['lr']
                            dlr = lr_decay_factor \
                                * lr
                            print('    (Decay '\
                                'learning rate:'\
                                ' %f -> %f)' \
                                % (lr, dlr))
                        param_group['lr'] = dlr
                # Early stopping判定用の
                # カウンタを増やす
                counter_for_early_stop += 1
#torch.cuda.synchronize()    

len( train_loader ): 225
len( val_loader ): 13
tr_print_coef: 75
tr_save_coef: 7
val_print_coef: 4


  0%|          | 0/225 [00:00<?, ?it/s]

Train epoch:1  index:75 loss:3.3305014356 ACC:0.056666666666666664 lr:0.0000333333
index: 75 input : a recent forum discussion on various online communities has highlighted the unique characteristics of these platforms. users engage in diverse activities ranging from video gaming and anime discussions to more niche topics like diy woodworking projects. one particular thread caught my attention, where a user sought advice for creating a community - based website for garden enthusiasts. the responses were varied, offering tips on using wordpress or squarespace as a platform, emphasizing the importance of community guidelines to ensure respectful interactions, and even suggesting ways to monetize through affiliate marketing or sponsored content.
index: 75 target: 7
index: 75 pred  : 20
index: 75 input : the legal system in many countries operates through courts that interpret laws created by government bodies. these systems ensure fairness, justice, and the protection of rights. in recent

  0%|          | 0/13 [00:00<?, ?it/s]

Val epoch:1  index:4 loss:3.4758304358 ACC:0.0625 lr:0.0000444444
index: 4 input : the culinary techniques used in creating unique flavor profiles across various dishes have been the topic of many discussions and articles. from understanding how certain spices can enhance the taste of a dish to exploring the effects of fermentation on foods like yogurt, beer, and soy sauce, these aspects add layers of complexity and depth that are often appreciated by food enthusiasts. additionally, discussing the cultural significance of specific ingredients in various cuisines around the world, such as the use of turmeric in indian cuisine or miso paste in japanese cooking, brings to light how deeply ingrained food is within social identities.
index: 4 target: 13
index: 4 pred  : 17
index: 4 input : adult - themed content has long been part of human culture, explored in various mediums including literature, film, and the arts. this piece delves into the portrayal of adult relationships across differe

  0%|          | 0/225 [00:00<?, ?it/s]

Train epoch:2  index:75 loss:3.2178145854 ACC:0.09666666666666666 lr:0.0000407407
index: 75 input : the rise of digital privacy concerns has prompted governments worldwide to consider new legislation aimed at regulating data collection practices. this includes measures that would require companies to disclose their methods for gathering user information and implement stricter controls on how this data is used and shared. the implications extend beyond just consumer protection, influencing international trade agreements and the framework under which businesses operate globally.
index: 75 target: 24
index: 75 pred  : 24
index: 75 input : the supreme court of the united states has been at the center of many historic decisions that have shaped american jurisprudence over the years. from civil rights cases like brown v. board of education to recent debates on healthcare reform with the affordable care act, the role of the court in interpreting the constitution and impacting national policy 

  0%|          | 0/13 [00:00<?, ?it/s]

Val epoch:2  index:4 loss:3.1430642009 ACC:0.125 lr:0.0000333333
index: 4 input : the culinary techniques used in creating unique flavor profiles across various dishes have been the topic of many discussions and articles. from understanding how certain spices can enhance the taste of a dish to exploring the effects of fermentation on foods like yogurt, beer, and soy sauce, these aspects add layers of complexity and depth that are often appreciated by food enthusiasts. additionally, discussing the cultural significance of specific ingredients in various cuisines around the world, such as the use of turmeric in indian cuisine or miso paste in japanese cooking, brings to light how deeply ingrained food is within social identities.
index: 4 target: 13
index: 4 pred  : 17
index: 4 input : adult - themed content has long been part of human culture, explored in various mediums including literature, film, and the arts. this piece delves into the portrayal of adult relationships across differen

  0%|          | 0/225 [00:00<?, ?it/s]

Train epoch:3  index:75 loss:3.1156349659 ACC:0.11 lr:0.0000296296
index: 75 input : the siberian tiger is one of the largest cats in the world, known for its distinctive orange fur with black stripes and white underbelly. these powerful predators inhabit the dense forests of eastern russia, particularly in areas like sikhote - alin. they have a rich history dating back thousands of years, deeply intertwined with the myths and legends of the local people. siberian tigers are not only majestic but also play a crucial role in their ecosystem by maintaining balance among prey species, which is why conservation efforts for this endangered subspecies are vital.
index: 75 target: 8
index: 75 pred  : 4
index: 75 input : the market for luxury properties in downtown manhattan continues to rise, with many new developments featuring cutting - edge technology and sustainable materials. developers are focusing on high - end amenities such as private rooftop terraces, state - of - the - art fitness 

  0%|          | 0/13 [00:00<?, ?it/s]

Val epoch:3  index:4 loss:2.8422745168 ACC:0.125 lr:0.0000222222
index: 4 input : the culinary techniques used in creating unique flavor profiles across various dishes have been the topic of many discussions and articles. from understanding how certain spices can enhance the taste of a dish to exploring the effects of fermentation on foods like yogurt, beer, and soy sauce, these aspects add layers of complexity and depth that are often appreciated by food enthusiasts. additionally, discussing the cultural significance of specific ingredients in various cuisines around the world, such as the use of turmeric in indian cuisine or miso paste in japanese cooking, brings to light how deeply ingrained food is within social identities.
index: 4 target: 13
index: 4 pred  : 21
index: 4 input : adult - themed content has long been part of human culture, explored in various mediums including literature, film, and the arts. this piece delves into the portrayal of adult relationships across differen

  0%|          | 0/225 [00:00<?, ?it/s]

Train epoch:4  index:75 loss:2.8150397635 ACC:0.18666666666666668 lr:0.0000185185
index: 75 input : the global financial crisis of 2008 had profound impacts on various economies worldwide. in the united states, it led to widespread unemployment and decreased consumer spending. this period saw significant changes in banking regulations, including the dodd - frank act aimed at preventing future crises. financial experts point out that factors such as subprime mortgage lending played a key role in initiating the crisis. despite the measures taken post - crisis, many economists still believe that risks remain within certain financial sectors.
index: 75 target: 23
index: 75 pred  : 5
index: 75 input : the theory of relativity proposed by albert einstein changed our understanding of space and time. it suggests that the observed velocity of light in vacuum has the same value for all observers regardless of their relative motion. this revolutionary concept also implies that events can appear s

  0%|          | 0/13 [00:00<?, ?it/s]

Val epoch:4  index:4 loss:2.8851913214 ACC:0.125 lr:0.0000111111
index: 4 input : the culinary techniques used in creating unique flavor profiles across various dishes have been the topic of many discussions and articles. from understanding how certain spices can enhance the taste of a dish to exploring the effects of fermentation on foods like yogurt, beer, and soy sauce, these aspects add layers of complexity and depth that are often appreciated by food enthusiasts. additionally, discussing the cultural significance of specific ingredients in various cuisines around the world, such as the use of turmeric in indian cuisine or miso paste in japanese cooking, brings to light how deeply ingrained food is within social identities.
index: 4 target: 13
index: 4 pred  : 24
index: 4 input : adult - themed content has long been part of human culture, explored in various mediums including literature, film, and the arts. this piece delves into the portrayal of adult relationships across differen

  0%|          | 0/225 [00:00<?, ?it/s]

Train epoch:5  index:75 loss:2.5677829965 ACC:0.3 lr:0.0000074074
index: 75 input : the pet store was bustling with activity as people came in to purchase various supplies for their pets. among the items available were dog toys, cat treats, bird cages, and fish tanks full of colorful tropical fish. the staff at the counter helped a customer pick out suitable bedding material for her rabbit while explaining the importance of proper hygiene and nutrition for small mammals. nearby, an article in the local newspaper highlighted recent research showing how interaction with pets can improve mental health among older adults. online forums featured discussions about the best ways to train puppies using positive reinforcement techniques and the benefits of regular veterinary check - ups for maintaining pet health.
index: 75 target: 8
index: 75 pred  : 13
index: 75 input : the human body is an intricate system that relies on various nutrients for optimal performance. recent studies have shown th

  0%|          | 0/13 [00:00<?, ?it/s]

Val epoch:5  index:4 loss:2.6451095343 ACC:0.25 lr:0.0000000000
index: 4 input : the culinary techniques used in creating unique flavor profiles across various dishes have been the topic of many discussions and articles. from understanding how certain spices can enhance the taste of a dish to exploring the effects of fermentation on foods like yogurt, beer, and soy sauce, these aspects add layers of complexity and depth that are often appreciated by food enthusiasts. additionally, discussing the cultural significance of specific ingredients in various cuisines around the world, such as the use of turmeric in indian cuisine or miso paste in japanese cooking, brings to light how deeply ingrained food is within social identities.
index: 4 target: 13
index: 4 pred  : 24
index: 4 input : adult - themed content has long been part of human culture, explored in various mediums including literature, film, and the arts. this piece delves into the portrayal of adult relationships across different

In [None]:
PATH = 'Diag_Mask5_FT_argilla_curr.pt'
if use_saved_pth and os.path.isfile(PATH):
    print( "loaded Diag_Mask5_curr.pt")
    checkpoint = torch.load(PATH)
    model.load_state_dict(checkpoint['model_state_dict'])

In [13]:
test_loader = DataLoader(test_dataset,
                        batch_size=1,
                        shuffle=False,
                        num_workers=num_workers,)
                        #collate_fn = collate_func_lambda)
print( "test_loader defined" )
print( len( test_loader ))

test_loader defined
50


In [15]:
model.eval()
print("Test")

n = 0
total_acc = 0
total_loss = 0
test_print_coef = len( test_loader ) // 10

for i, ( inputs, labels ) in enumerate( test_loader ):
    inputs = inputs.to(device)
    y_true = labels
    labels = labels.to(device)
    #labels = nn.functional.one_hot(labels.to(device).long(), num_classes = 2 ).float()

    with torch.no_grad():
        logits = model( inputs )
        y_pred = torch.argmax( logits, dim = 1 )
        # 損失の計算
        loss = criterion( logits, labels )
           
    total_loss += loss.item()
    n2 = 0
    input_sentence = []
    target_label = []
    pred_label = []
    for input_id, label, pred in zip( inputs, y_true, y_pred ):
        if n2 < 2 and ( i % test_print_coef == test_print_coef -1 ) :
            input_s = tokenizer.decode( input_id, skip_special_tokens = True )
            input_sentence.append( input_s )
            target_label.append( label )
            pred_label.append( pred.item() )
                
        n2 += 1

    n3 += 1
    accuracy = accuracy_score(y_true, y_pred.detach().cpu() )
    total_acc += accuracy

print(f"index:{i+1} loss: {total_loss / i}")        
print(f"index:{i+1}  ACC: {total_acc / i}")


Test
index:50 loss: 2.5006457012222736
index:50  ACC: 0.32653061224489793
