<a href="https://colab.research.google.com/github/project-ccap/project-ccap.github.io/blob/master/2023notebooks/2023_1223dasic_speech_errors_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 準備作業

In [None]:
%config InlineBackend.figure_format = 'retina'
import torch
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

from IPython import get_ipython
isColab =  'google.colab' in str(get_ipython())

if isColab:

    # GPU 情報を表示
    !nvidia-smi -L

    # `import bit` する前に termcolor を downgrade しないと colab ではテキストに色がつかない
    !pip install --upgrade termcolor==1.1
    import termcolor

    # 日本語 transformer をインストールするためには，以下のインストールが必要
    !pip install --upgrade xlrd
    !pip install --upgrade 'fugashi[ipadic]'
    !pip install --upgrade 'fugashi[unidic]'
    !python -m unidic download
    !pip install --upgrade ipadic
    !pip install --upgrade transformers
    !pip install --upgrade termcolor
    !pip install --upgrade jaconv
    !pip install jaconv
    #!git clone https://github.com/ShinAsakawa/RAM.git

import platform
HOSTNAME = platform.node().split('.')[0]

import os
HOME = os.environ['HOME']

import sys
from collections import OrderedDict

try:
    import ipynbname
except ImportError:
    !pip install ipynbname
    import ipynbname
FILEPATH = str(ipynbname.path()).replace(HOME+'/','')

import pwd
USER=pwd.getpwuid(os.geteuid())[0]

from datetime import date
TODAY=date.today()

import torch
TORCH_VERSION = torch.__version__

from termcolor import colored

try:
    import japanize_matplotlib
except ImportError:
    !pip install japanize_matplotlib
    import japanize_matplotlib

from tqdm.notebook import tqdm

color = 'green'
print('日付:',colored(f'{TODAY}', color=color, attrs=['bold']))
print('HOSTNAME:',colored(f'{HOSTNAME}', color=color, attrs=['bold']))
print('ユーザ名:',colored(f'{USER}', color=color, attrs=['bold']))
print('HOME:',colored(f'{HOME}', color=color,attrs=['bold']))
print('ファイル名:',colored(f'{FILEPATH}', color=color, attrs=['bold']))
print('torch.__version__:',colored(f'{TORCH_VERSION}', color=color, attrs=['bold']))

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [3]:
!cp /content/drive/MyDrive/2023_1211snow_transformer_gpu2.pt .

In [None]:
if isColab:
     !git clone https://github.com/ShinAsakawa/RAM.git
!ls -lt RAM

In [23]:
import numpy as np

# SNOW データの読み込み

In [None]:
import os
import pandas as pd
import requests
from termcolor import colored
import jaconv

# やさしい日本語をダウンロード
SNOWs={'T15': {'url':"https://filedn.com/lit4DCIlHwxfS1gj9zcYuDJ/SNOW/T15-2020.1.7.xlsx"},
       'T23': {'url':"https://filedn.com/lit4DCIlHwxfS1gj9zcYuDJ/SNOW/T23-2020.1.7.xlsx"},
      }
print('エクセルファイル読込', end='...')
for corpus in SNOWs:
    url = SNOWs[corpus]['url']
    excel_fname = corpus + '-2020.1.7.xlsx'

    if not os.path.exists(excel_fname):  # ファイルが存在しない場合ダウンロード
        print(f'url:{url}')
        r = requests.get(url)
        with open(excel_fname, 'wb') as f:
            total_length = int(r.headers.get('content-length'))
            print(f'{excel_fname} をダウンロード中 {total_length} バイト')
            f.write(r.content)

    SNOWs[corpus]['df'] = pd.read_excel(excel_fname)
    SNOWs[corpus]['df'] = SNOWs[corpus]['df'].rename(columns={'#日本語(原文)': 'ja',
                                                              '#やさしい日本語':'easy_ja',
                                                              '#英語(原文)':'en'})
# 2 つのデータをあわせる
_snow = SNOWs['T15']['df']['ja'].tolist() + SNOWs['T23']['df']['ja'].tolist()
#_snow = SNOWs['T15']['df']['easy_ja'].tolist() + SNOWs['T23']['df']['easy_ja'].tolist()
snow = [jaconv.normalize(line, 'NFKC') for line in _snow] # 正規化

# 訓練済 BERT を Huggingface から読み込み

In [9]:
from transformers import EncoderDecoderModel, BertTokenizer, BertConfig
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

_bertmodel_name = 'bert-base-uncased'
sbertmodel_name = 'sonoisa/sentence-bert-base-ja-mean-tokens-v2'
tknz = BertTokenizer.from_pretrained(sbertmodel_name)

class snow_Dataset(torch.utils.data.Dataset):
    def __init__(self,
                 data_list:list=snow):

        super().__init__()
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        sent = self.data_list[idx]
        return sent, sent

snow_ds = snow_Dataset()
ds = snow_ds
batch_size = 1024
batch_size = 128

def _collate_fn(batch):
    inps, tgts = list(zip(*batch))
    inps = list(inps)
    tgts = list(tgts)
    return inps, tgts

snow_dl = DataLoader(
    dataset=snow_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    collate_fn=_collate_fn)

dl = snow_dl
inp, tch = next(iter(dl))
encoded_input = tknz.batch_encode_plus(inp,
                                       padding="longest",
                                       truncation=True,
                                       return_tensors="pt").to(device)
print(encoded_input.input_ids.detach().cpu().numpy()[:3])
print(snow[:3])

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BertJapaneseTokenizer'. 
The class this function is called from is 'BertTokenizer'.


[[    2   859 23372   429   146  5553   609    77  6172 28457 11665     8
      3     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0]
 [    2  1325     9 10798 10118 28477 28491  4263     8     3     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0]
 [    2  5807 28450  3599 28452    69  1173  8747 28446  2160  6494 17234
   3721     8     3     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0]]
['誰が一番に着くか私には分かりません。', '多くの動物が人間によって滅ぼされた。', '私はテニス部員です。']


# 自作 Transformer の読み込み，モデルの定義

In [67]:
config = {
    'model_dim': 384,
    'num_heads': 4,
    'num_layers': 2,
    'max_seq_length': 64,
    'dropout': 0.,
    'ff_dim': 384,
    'device': device,
}

from RAM import Transformer
model = Transformer(src_vocab_size=tknz.vocab_size,
                    tgt_vocab_size=tknz.vocab_size,
                    model_dim=config['model_dim'],
                    num_heads=config['num_heads'],
                    num_layers=config['num_layers'],
                    max_seq_length=config['max_seq_length'],
                    dropout=config['dropout'],
                    ff_dim=config['ff_dim']).to(device)



# SNOW で訓練済パラメータの読み込み

In [68]:
#訓練済ファイルの読み込み
fname = '2023_1211snow_transformer_gpu.pt'
fname = '2023_1211snow_transformer_gpu2.pt'
state_dict = torch.load(fname, map_location=torch.device('cpu'))['state_dict']
#state_dict.keys()
model.load_state_dict(state_dict)

<All keys matched successfully>

In [16]:
def save_checkpoint(checkpoint_path, model):
    state = {'state_dict': model.state_dict() }
    torch.save(state, checkpoint_path)

def load_checkpoint(checkpoint_path, model):
    state = torch.load(checkpoint_path)
    model.load_state_dict(state['state_dict'])
    print(f'model loaded from {checkpoint_path}')

# 訓練の実施

In [None]:
# %%time
# criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
# # [Adam](https://arxiv.org/abs/1412.6980) による最適化関数の定義
# from torch.optim import AdamW
# optimizer = AdamW(model.parameters(), lr=5e-5)  # 最適化関数を初期化
# #optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

# epochs = 100
# epochs = 30
# epochs = 10
# for epoch in range(epochs):
#     epoch_loss = 0.

#     loop = tqdm(dl, leave=True)
#     for batch in loop:
#     #for inp, tch in tqdm(dl):
#         inp = batch[0]
#         encoded_input = tknz.batch_encode_plus(inp,
#                                                padding="longest",
#                                                truncation=True,
#                                                return_tensors="pt").to(device)

#         optimizer.zero_grad()
#         output = model(src=encoded_input.input_ids,
#                        tgt=encoded_input.input_ids).to(device)
#         loss = criterion(output[0], encoded_input.input_ids[0])  # 損失値の計算
#         for h in range(1,len(output)):
#             loss += criterion(output[h], encoded_input.input_ids[h])

#         loss.backward()                      # 誤差逆伝播
#         epoch_loss += loss.item()            # 損失値総和
#         optimizer.step()                     # 誤差に基づき学習ステップ実行

#         loop.set_description(f'エポック {epoch}')
#         loop.set_postfix(OrderedDict(loss=loss.item()/len(inp)))

#     #print(f'loss:{epoch_loss/ds.__len__()}')
#     print(f'epoch:{epoch:03d}',
#           f'eopch_loss:{epoch_loss/ds.__len__():.5f}')
#           #f'出力:{tknz.convert_ids_to_tokens(output_ids)}')
#     checkpoint_fname = f'2023_1211SNOW_transfomer_epoch{epoch:02d}.pt'
#     save_checkpoint(checkpoint_fname, model)

In [None]:
encoded_input = tknz.batch_encode_plus(inp[:3],
                                       padding="longest",
                                       truncation=True,
                                       return_tensors="pt").to(device)
print(tknz.convert_ids_to_tokens(encoded_input.input_ids.squeeze(0)[1]))
print(inp[:3])
#encoded_input.input_ids

# 検証

In [14]:
model.eval()
n_corrects, total, results = 0, 0, {}
isPrint = False
loop = tqdm(snow, leave=True)
for N in loop:
#for N in tqdm(snow, leave=True)
    total += 1
    encoded_input = tknz.batch_encode_plus([N],
                                           padding="longest",
                                           truncation=True,
                                           return_tensors="pt").to(device)
    output = model(src=encoded_input.input_ids,
                   tgt=encoded_input.input_ids).to(device)
    #print(tknz.convert_ids_to_tokens(output.squeeze(0).topk(1)[1]), inp[N])
    outstr = "".join(tknz.convert_ids_to_tokens(output.squeeze(0).topk(1)[1][1:-1])).replace('##','')

    yesno = outstr == N
    if yesno == True:
        n_corrects += 1
    else:
        results[total] = {'正誤':yesno, '出力':outstr, '教師':N}
        if isPrint:
            print(f'{total:3d}:{yesno} N:{N}, outstr:{outstr}')

    loop.set_description('正解率')
    loop.set_postfix(OrderedDict(cr=n_corrects/total))

print(n_corrects) # , N)
#output.size()

  0%|          | 0/84300 [00:00<?, ?it/s]

82317


# 寺尾先生のタテ，ヨコデータ読み込み

`2023_1211terao_tate2.txt` と `2023_1211terao_yoko2.txt` とをアップロードする

In [92]:
from google.colab import files
uploaded = files.upload()

Saving 2023_1221terao_tate2.txt to 2023_1221terao_tate2.txt


## 文を復唱するための関数を定義

In [20]:
def eval_a_sent(sent:str=None,
              model:torch.nn.Module=model,
              verbose:bool=False,
              isPrint:bool=False,
             )->None:

    model.eval()
    if sent == None:
        sent = jaconv.normalize(input('文を入力してください: '))
    encoded_input = tknz.batch_encode_plus(
            [sent], padding="longest", truncation=True, return_tensors="pt").to(device)
    output = model(src=encoded_input.input_ids, tgt=encoded_input.input_ids).to(device)
    outstr = "".join(tknz.convert_ids_to_tokens(output.squeeze(0).topk(1)[1][1:-1])).replace('##','')
    output_ids = output.squeeze(0).topk(1)[1][1:-1].detach().cpu().numpy().flatten()
    input_ids = encoded_input.input_ids.squeeze().detach().cpu().numpy()[1:-1]
    yesno = (input_ids == output_ids).all()

    if isPrint:
        if (yesno==False) or verbose:

            print('誤:' if not yesno else '正:', end=" 出力:")
            for idx, idy in zip(output_ids, input_ids):
                chr_x, chr_y = tknz.convert_ids_to_tokens([idx])[0], tknz.convert_ids_to_tokens([idy])[0]
                chr_x = chr_x.replace('#','')
                if idx == idy:
                    color='blue'
                    print(colored(chr_x, color), end="")
                else:
                    color='red'
                    print(colored(chr_x, color, attrs=['bold']), end="")
            print(f" <- 入力:{sent}")
    return yesno

eval_a_sent('やめられないとまらないかっぱびえせん', verbose=True, isPrint=True)


正: 出力:[34mやめ[0m[34mら[0m[34mれ[0m[34mない[0m[34mと[0m[34mまら[0m[34mない[0m[34mかっ[0m[34mぱ[0m[34mび[0m[34mえ[0m[34mせん[0m <- 入力:やめられないとまらないかっぱびえせん


True

## 未学習の文を使って検証

In [None]:
sents = ['行く川の流れ絶えずしてしかも元の水にあらず、流れに浮かぶ泡沫は、かつ消え、かつ結びて、久しくとどまることなし',
         '吾輩は猫である。名前はまだない。',
         'とうきょうとっきょきょかきょく',
         '人を殺す魔法ゾルトラークを使うフリーレンと超能力者アーニャの声優は同じだよ。',
         '天は人の上に人を作らず、人の下に人を作らず'
        ]
for sent in sents:
    eval_a_sent(sent, verbose=True, isPrint=True)

In [93]:
fname = '2023_1221terao_tate2.txt'
terao_ = [l.split(',') for l in open(fname).read().strip().split('\n')[1:]]

C_inp = [eval_a_sent(sent[0],isPrint=True, verbose=True) for sent in terao_]
C_err = [eval_a_sent(sent[1],isPrint=True, verbose=True) for sent in terao_]

print('-' * 77)
C_inp_ok = (np.array(C_inp) * 1).sum()
C_err_ok = (np.array(C_err) * 1).sum()
print(f'{C_inp_ok}, {C_inp_ok/len(terao_):.3f}')
print(f'{C_err_ok}, {C_err_ok/len(terao_):.3f}')


正: 出力:[34mまん[0m[34mなか[0m[34mの[0m[34mスト[0m[34mレート[0m[34mだ[0m[34mそう[0m[34mで[0m[34mす[0m <- 入力:まんなかのストレートだそうです
正: 出力:[34mこの[0m[34m景[0m[34m色[0m[34mは[0m[34m頑[0m[34m張[0m[34mって[0m[34m沢[0m[34m登[0m[34mり[0m[34mを[0m[34mした[0m[34m人[0m[34mだけ[0m[34mが[0m[34m見[0m[34mる[0m[34mこと[0m[34mが[0m[34mで[0m[34mき[0m[34mる[0m[34mご[0m[34mほう[0m[34mび[0m <- 入力:この景色は頑張って沢登りをした人だけが見ることができるごほうび
正: 出力:[34mテレビ[0m[34mつけ[0m[34mながら[0m[34m、[0m[34m電[0m[34m話[0m[34mを[0m[34mして[0m[34mる[0m[34mから[0m <- 入力:テレビつけながら、電話をしてるから
正: 出力:[34mみ[0m[34mたま[0m[34mま[0m[34mを[0m[34m克[0m[34m明[0m[34mに[0m[34m解[0m[34m説[0m[34mしま[0m[34mして[0m <- 入力:みたままを克明に解説しまして
正: 出力:[34m成[0m[34m人[0m[34mの[0m[34m日[0m[34mの[0m[34m前[0m[34m後[0m[34mは[0m[34m.[0m[34m.[0m[34m.[0m <- 入力:成人の日の前後は...
正: 出力:[34m年[0m[34m上[0m[34mの[0m[34m人[0m[34mに[0m[34m対[0m[34mし[0m[34mて[0m <- 入力:年上の人に対して
正: 出力:[34m非[0m[34m常[0m[34mに[0m[34m短[0m[34mい[0m[

In [98]:
fname = '2023_1221terao_tate2.txt'
terao_ = [l.split(',') for l in open(fname).read().strip().split('\n')[1:]]
terao_tate = terao_

fname = '2023_1221terao_yoko2.txt'
terao_ = [l.split(',') for l in open(fname).read().strip().split('\n')[1:]]
terao_yoko = terao_

In [99]:
fname = '2023_1221terao_tate2.txt'
terao_ = [l.split(',') for l in open(fname).read().strip().split('\n')[1:]]

C_inp = [eval_a_sent(sent[0],isPrint=True, verbose=True) for sent in terao_]
C_err = [eval_a_sent(sent[1],isPrint=True, verbose=True) for sent in terao_]

print('-' * 77)
C_inp_ok = (np.array(C_inp) * 1).sum()
C_err_ok = (np.array(C_err) * 1).sum()
print(f'{C_inp_ok}, {C_inp_ok/len(terao_):.3f}')
print(f'{C_err_ok}, {C_err_ok/len(terao_):.3f}')


正: 出力:[34mまん[0m[34mなか[0m[34mの[0m[34mスト[0m[34mレート[0m[34mだ[0m[34mそう[0m[34mで[0m[34mす[0m <- 入力:まんなかのストレートだそうです
正: 出力:[34mこの[0m[34m景[0m[34m色[0m[34mは[0m[34m頑[0m[34m張[0m[34mって[0m[34m沢[0m[34m登[0m[34mり[0m[34mを[0m[34mした[0m[34m人[0m[34mだけ[0m[34mが[0m[34m見[0m[34mる[0m[34mこと[0m[34mが[0m[34mで[0m[34mき[0m[34mる[0m[34mご[0m[34mほう[0m[34mび[0m <- 入力:この景色は頑張って沢登りをした人だけが見ることができるごほうび
正: 出力:[34mテレビ[0m[34mつけ[0m[34mながら[0m[34m、[0m[34m電[0m[34m話[0m[34mを[0m[34mして[0m[34mる[0m[34mから[0m <- 入力:テレビつけながら、電話をしてるから
正: 出力:[34mみ[0m[34mたま[0m[34mま[0m[34mを[0m[34m克[0m[34m明[0m[34mに[0m[34m解[0m[34m説[0m[34mしま[0m[34mして[0m <- 入力:みたままを克明に解説しまして
正: 出力:[34m成[0m[34m人[0m[34mの[0m[34m日[0m[34mの[0m[34m前[0m[34m後[0m[34mは[0m[34m.[0m[34m.[0m[34m.[0m <- 入力:成人の日の前後は...
正: 出力:[34m年[0m[34m上[0m[34mの[0m[34m人[0m[34mに[0m[34m対[0m[34mし[0m[34mて[0m <- 入力:年上の人に対して
正: 出力:[34m非[0m[34m常[0m[34mに[0m[34m短[0m[34mい[0m[

In [100]:
class terao_Dataset(torch.utils.data.Dataset):
    def __init__(self,
                 _list:list=terao_tate):

        super().__init__()
        self.list = _list

    def __len__(self):
        return len(self.list)

    def __getitem__(self, idx):
        inp_sent = self.list[idx][0]
        err_sent = self.list[idx][1]
        return inp_sent, err_sent

terao_V_ds = terao_Dataset(_list=terao_tate)
terao_H_ds = terao_Dataset(_list=terao_yoko)

bach_size = 10
V_dl = DataLoader(
    dataset=terao_V_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    collate_fn=_collate_fn)

H_dl = DataLoader(
    dataset=terao_H_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    collate_fn=_collate_fn)

dl = V_dl
inp, tch = next(iter(dl))
encoded_input = tknz.batch_encode_plus(inp,
                                       padding="longest",
                                       truncation=True,
                                       return_tensors="pt").to(device)
print(encoded_input.input_ids.detach().cpu().numpy()[:3])
#print(snow[:3])
#ds = terao_V_ds
#for idx in range(ds.__len__()):
#    inp, goal = ds.__getitem__(idx)
#    print(f'{inp}', f'{goal}')
#    #print(terao_V_ds.__getitem__(idx))
#    #print(terao_H_ds.__getitem__(idx))

[[    2  5461 28446    77  1018 28470     3     0     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0]
 [    2  1778    30 28452 28539 28531 28450    98   155    12 28761 28517
  28532 28453 28761   926 15076 28446  2122 11665 29124     3     0     0
      0     0     0     0     0     0]
 [    2   683  4058    11  1366  1582  2434  4551 28444 22684 28512 27919
   3721  2266     3     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0]]


In [101]:
import numpy as np
import copy

model_V = Transformer(src_vocab_size=tknz.vocab_size,
                      tgt_vocab_size=tknz.vocab_size,
                      model_dim=config['model_dim'],
                      num_heads=config['num_heads'],
                      num_layers=config['num_layers'],
                      max_seq_length=config['max_seq_length'],
                      dropout=config['dropout'],
                      ff_dim=config['ff_dim']).to(device)

model_H = Transformer(src_vocab_size=tknz.vocab_size,
                      tgt_vocab_size=tknz.vocab_size,
                      model_dim=config['model_dim'],
                      num_heads=config['num_heads'],
                      num_layers=config['num_layers'],
                      max_seq_length=config['max_seq_length'],
                      dropout=config['dropout'],
                      ff_dim=config['ff_dim']).to(device)
#model.eval()
#np.random.permutation(3)
#model_V = copy.deepcopy(model)
#model_H = copy.deepcopy(model)

## ヨコモデルの訓練

In [None]:
from torch.optim import AdamW

#model_V = copy.deepcopy(model)
_model = model_H
_model.train()
ds = terao_H_ds
dl = H_dl

criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
# [Adam](https://arxiv.org/abs/1412.6980) による最適化関数の定義

optimizer = AdamW(_model.parameters(), lr=5e-4)  # 最適化関数を初期化
#optimizer = AdamW(_model.parameters(), lr=5e-5)  # 最適化関数を初期化
#optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

#epochs = 100
epochs = 30
for epoch in range(epochs):
    epoch_loss = 0.
    loop = tqdm(dl, leave=True)
    for batch in loop:
        inp = batch[0]
        encoded_input = tknz.batch_encode_plus(
            inp, padding="longest", truncation=True,  return_tensors="pt").to(device)

        optimizer.zero_grad()
        output = _model(src=encoded_input.input_ids,
                        tgt=encoded_input.input_ids).to(device)
        loss = criterion(output[0], encoded_input.input_ids[0])  # 損失値の計算
        for h in range(1,len(output)):
            loss += criterion(output[h], encoded_input.input_ids[h])

        loss.backward()                      # 誤差逆伝播
        epoch_loss += loss.item()            # 損失値総和
        optimizer.step()                     # 誤差に基づき学習ステップ実行

        loop.set_description(f'エポック {epoch}')
        loop.set_postfix(OrderedDict(loss=loss.item()/len(inp)))

    print(f'epoch:{epoch:03d}',
          f'loss:{epoch_loss/ds.__len__():.5f}')


# ヨコモデルのパラメータを保存

In [35]:
save_checkpoint('2023_1222yoko_only.pt', _model)

## タテモデルの訓練

In [102]:
from torch.optim import AdamW

#model_V = copy.deepcopy(model)
_model = model_V
_model.train()
ds = terao_V_ds
dl = V_dl

criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
# [Adam](https://arxiv.org/abs/1412.6980) による最適化関数の定義

optimizer = AdamW(_model.parameters(), lr=5e-4)  # 最適化関数を初期化
#optimizer = AdamW(_model.parameters(), lr=5e-5)  # 最適化関数を初期化
#optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

#epochs = 100
epochs = 60
for epoch in range(epochs):
    epoch_loss = 0.
    loop = tqdm(dl, leave=True)
    for batch in loop:
        inp = batch[0]
        encoded_input = tknz.batch_encode_plus(
            inp, padding="longest", truncation=True,  return_tensors="pt").to(device)

        optimizer.zero_grad()
        output = _model(src=encoded_input.input_ids,
                        tgt=encoded_input.input_ids).to(device)
        loss = criterion(output[0], encoded_input.input_ids[0])  # 損失値の計算
        for h in range(1,len(output)):
            loss += criterion(output[h], encoded_input.input_ids[h])

        loss.backward()                      # 誤差逆伝播
        epoch_loss += loss.item()            # 損失値総和
        optimizer.step()                     # 誤差に基づき学習ステップ実行

        loop.set_description(f'エポック {epoch}')
        loop.set_postfix(OrderedDict(loss=loss.item()/len(inp)))

    print(f'epoch:{epoch:03d}',
          f'loss:{epoch_loss/ds.__len__():.5f}')


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:000 loss:10.18977


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:001 loss:8.79968


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:002 loss:8.03435


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:003 loss:7.31979


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:004 loss:6.64904


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:005 loss:6.02071


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:006 loss:5.43380


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:007 loss:4.88082


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:008 loss:4.37816


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:009 loss:3.92329


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:010 loss:3.52187


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:011 loss:3.16135


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:012 loss:2.84439


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:013 loss:2.55922


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:014 loss:2.30061


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:015 loss:2.06648


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:016 loss:1.84340


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:017 loss:1.63856


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:018 loss:1.44981


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:019 loss:1.27542


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:020 loss:1.11633


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:021 loss:0.97364


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:022 loss:0.84167


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:023 loss:0.72366


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:024 loss:0.61905


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:025 loss:0.52451


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:026 loss:0.44360


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:027 loss:0.37364


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:028 loss:0.31310


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:029 loss:0.26043


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:030 loss:0.21850


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:031 loss:0.18349


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:032 loss:0.15354


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:033 loss:0.13093


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:034 loss:0.11157


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:035 loss:0.09600


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:036 loss:0.08362


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:037 loss:0.07344


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:038 loss:0.06558


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:039 loss:0.05881


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:040 loss:0.05326


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:041 loss:0.04877


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:042 loss:0.04497


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:043 loss:0.04177


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:044 loss:0.03899


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:045 loss:0.03666


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:046 loss:0.03462


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:047 loss:0.03284


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:048 loss:0.03125


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:049 loss:0.02983


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:050 loss:0.02855


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:051 loss:0.02739


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:052 loss:0.02635


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:053 loss:0.02540


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:054 loss:0.02452


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:055 loss:0.02370


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:056 loss:0.02295


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:057 loss:0.02225


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:058 loss:0.02160


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:059 loss:0.02098


In [103]:
model_V = copy.deepcopy(_model)

# 事前訓練済モデルによる微調整

## ヨコモデルの微調整

In [40]:
from torch.optim import AdamW

_model = copy.deepcopy(model)
#_model = model_H
_model.train()
ds = terao_H_ds
dl = H_dl

criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
# [Adam](https://arxiv.org/abs/1412.6980) による最適化関数の定義

optimizer = AdamW(_model.parameters(), lr=5e-4)  # 最適化関数を初期化
#optimizer = AdamW(_model.parameters(), lr=5e-5)  # 最適化関数を初期化
#optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

#epochs = 100
epochs = 30
for epoch in range(epochs):
    epoch_loss = 0.
    loop = tqdm(dl, leave=True)
    for batch in loop:
        inp = batch[0]
        tch = batch[1]
        encoded_input = tknz.batch_encode_plus(
            inp, padding="longest", truncation=True,  return_tensors="pt").to(device)

        encoded_tch = tknz.batch_encode_plus(
            tch, padding="longest", truncation=True,  return_tensors="pt").to(device)

        optimizer.zero_grad()
        output = _model(src=encoded_input.input_ids,
                        tgt=encoded_tch.input_ids).to(device)
        loss = criterion(output[0], encoded_tch.input_ids[0])  # 損失値の計算
        for h in range(1,len(output)):
            loss += criterion(output[h], encoded_tch.input_ids[h])

        loss.backward()                      # 誤差逆伝播
        epoch_loss += loss.item()            # 損失値総和
        optimizer.step()                     # 誤差に基づき学習ステップ実行

        loop.set_description(f'エポック {epoch}')
        loop.set_postfix(OrderedDict(loss=loss.item()/len(inp)))

    print(f'epoch:{epoch:03d}',
          f'loss:{epoch_loss/ds.__len__():.5f}')


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:000 loss:0.17165


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:001 loss:0.15347


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:002 loss:0.11084


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:003 loss:0.09698


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:004 loss:0.07945


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:005 loss:0.06271


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:006 loss:0.05356


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:007 loss:0.04304


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:008 loss:0.03472


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:009 loss:0.02672


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:010 loss:0.02049


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:011 loss:0.01504


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:012 loss:0.01106


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:013 loss:0.00781


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:014 loss:0.00566


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:015 loss:0.00404


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:016 loss:0.00302


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:017 loss:0.00228


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:018 loss:0.00173


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:019 loss:0.00140


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:020 loss:0.00114


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:021 loss:0.00096


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:022 loss:0.00083


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:023 loss:0.00073


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:024 loss:0.00065


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:025 loss:0.00059


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:026 loss:0.00054


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:027 loss:0.00050


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:028 loss:0.00047


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:029 loss:0.00044


In [50]:
model_yoko = copy.deepcopy(_model)
#save_checkpoint('2023_1222yoko_only.pt', model_yoko)
for s in terao_yoko[:13]:
    #eval_a_sent(s[0], model=model, verbose=True, isPrint=True)
    eval_a_sent(s[0], model=model_yoko, verbose=True, isPrint=True)

正: 出力:[34m店[0m[34m歩[0m[34mい[0m[34mて[0m[34mたら[0m <- 入力:店歩いてたら
正: 出力:[34m松[0m[34m本[0m[34mに[0m[34mは[0m[34mポン[0m[34mポン[0m[34mと[0m[34mストラ[0m[34mイク[0m[34m入[0m[34mっ[0m[34mたん[0m[34mで[0m[34mす[0m <- 入力:松本にはポンポンとストライク入ったんです
正: 出力:[34m常[0m[34m磐[0m[34m線[0m[34mの[0m[34m中[0m[34mで[0m[34mタ[0m[34mバ[0m[34mコ[0m[34m吸[0m[34mって[0m[34mる[0m[34m人[0m[34mがい[0m[34mて[0m <- 入力:常磐線の中でタバコ吸ってる人がいて
正: 出力:[34m近[0m[34mくに[0m[34m銭[0m[34m湯[0m[34mが[0m[34mある[0m[34mん[0m[34mだ[0m <- 入力:近くに銭湯があるんだ
正: 出力:[34mホーム[0m[34mで[0m[34m電[0m[34m車[0m[34mを[0m[34m待[0m[34mって[0m[34mます[0m <- 入力:ホームで電車を待ってます
正: 出力:[34m生[0m[34m徒[0m[34mと[0m[34m先[0m[34m生[0m[34mの[0m[34m関[0m[34m係[0m <- 入力:生徒と先生の関係
正: 出力:[34mお[0m[34mせん[0m[34mべ[0m[34mには[0m[34mお[0m[34m茶[0m[34mが[0m[34mよ[0m[34mく[0m[34m似[0m[34m合[0m[34mう[0m <- 入力:おせんべにはお茶がよく似合う
正: 出力:[34m夜[0m[34mは[0m[34m山[0m[34m菜[0m[34mご[0m[34mはん[0m[34mだ[0m[34mった[0m <- 入力:夜は山

In [111]:
# from torch.optim import AdamW

#_model = copy.deepcopy(model)
_model = copy.deepcopy(model_yoko)
# #_model = model_H
#_model.train()
_model.eval()
ds = terao_H_ds
dl = H_dl

n_corrects = 0
for idx in range(ds.__len__()):
    batch = ds.__getitem__(idx)
    inp, tch = [batch[0]], [batch[1]]
    encoded_input = tknz.batch_encode_plus(
                    inp, padding="longest", truncation=True,  return_tensors="pt").to(device)

    encoded_tch = tknz.batch_encode_plus(
                    tch, padding="longest", truncation=True,  return_tensors="pt").to(device)

    output = _model(src=encoded_input.input_ids,
                    tgt=encoded_tch.input_ids).to(device)
    output_ids = output.squeeze(0).topk(1)[1][1:-1].detach().cpu().numpy().flatten()
    outstr = "".join(tknz.convert_ids_to_tokens(output.squeeze(0).topk(1)[1][1:-1])).replace('##','')
    input_ids = encoded_input.input_ids.squeeze().detach().cpu().numpy()[1:-1]
    #yesno = (input_ids == output_ids).all()
    yesno = outstr == tch[0]
    if yesno == True:
        n_corrects += 1

    if yesno == True:
        color = 'grey'
    else:
        color = 'red'
    print(f'出力文:{colored(outstr,color=color,attrs=["bold"])}',
          f'<- 入力文:{inp[0]}',
          f'{yesno}') #  tch) # output_ids)


print(f'n_corrects:{n_corrects}',
      f'正解率: {n_corrects/ds.__len__()*100:.3f} %'
      )

出力文:[1m[30mまち歩いてたら[0m <- 入力文:店歩いてたら True
出力文:[1m[30m松本にはポンポンとストレート入ったんです[0m <- 入力文:松本にはポンポンとストライク入ったんです True
出力文:[1m[30m常磐線の中で電車吸ってる人がいて[0m <- 入力文:常磐線の中でタバコ吸ってる人がいて True
出力文:[1m[30m近くに近所があるんだ[0m <- 入力文:近くに銭湯があるんだ True
出力文:[1m[30mホームで駅を待ってます[0m <- 入力文:ホームで電車を待ってます True
出力文:[1m[30m生徒と先生の関係[0m <- 入力文:生徒と先生の関係 True
出力文:[1m[30mおせんべにはおかきがよく似合う[0m <- 入力文:おせんべにはお茶がよく似合う True
出力文:[1m[30m夜はあさだったかもしれません[0m <- 入力文:夜は山菜ごはんだった True
出力文:[1m[30mマーガリンはバターにつかえない[0m <- 入力文:マーガリンは料理につかえない True
出力文:[1m[30mあの方は頼むと矢切りの渡しでも何でもわたってくれる[0m <- 入力文:あの方は頼むと矢切りの渡しでも何でもうたってくれる True
出力文:[1m[30m気をつけよう暗い言葉と暗い道[0m <- 入力文:気をつけよう甘い言葉と暗い道 True
出力文:[1m[30m晴天じゃないとてんきが作れない[0m <- 入力文:晴天じゃないとてんぐが作れない True
出力文:[1m[30m打ったことない[0m <- 入力文:ちょっと打ったところない True
出力文:[1m[30m今日から三連休という方があるかもしれませんが、この週末連休が気になる[0m <- 入力文:今日から三連休という方があるかもしれませんが、この週末 天気が気になる True
出力文:[1m[30m人の話に、人車にのる[0m <- 入力文:人の話に、口車にのる True
出力文:[1m[30mアイロンがけは明日にしよう、当分アイロンいらないから[0m <- 入力文:アイロンがけは明日にしよう、当分ワイシャツいらないから True
出力文:[1m

In [115]:
save_checkpoint('2023_1222yoko_finetuned.pt', model_yoko)

## タテモデルの微調整

In [106]:
from torch.optim import AdamW

#_model = copy.deepcopy(model)
_model = model_V
_model.train()
ds = terao_V_ds
dl = V_dl

criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
# [Adam](https://arxiv.org/abs/1412.6980) による最適化関数の定義

optimizer = AdamW(_model.parameters(), lr=5e-4)  # 最適化関数を初期化
#optimizer = AdamW(_model.parameters(), lr=5e-5)  # 最適化関数を初期化
#optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

#epochs = 100
epochs = 30
for epoch in range(epochs):
    epoch_loss = 0.
    loop = tqdm(dl, leave=True)
    for batch in loop:
        inp = batch[0]
        tch = batch[1]
        encoded_input = tknz.batch_encode_plus(
            inp, padding="longest", truncation=True,  return_tensors="pt").to(device)

        encoded_tch = tknz.batch_encode_plus(
            tch, padding="longest", truncation=True,  return_tensors="pt").to(device)

        optimizer.zero_grad()
        output = _model(src=encoded_input.input_ids,
                        tgt=encoded_tch.input_ids).to(device)
        loss = criterion(output[0], encoded_tch.input_ids[0])  # 損失値の計算
        for h in range(1,len(output)):
            loss += criterion(output[h], encoded_tch.input_ids[h])

        loss.backward()                      # 誤差逆伝播
        epoch_loss += loss.item()            # 損失値総和
        optimizer.step()                     # 誤差に基づき学習ステップ実行

        loop.set_description(f'エポック {epoch}')
        loop.set_postfix(OrderedDict(loss=loss.item()/len(inp)))

    print(f'epoch:{epoch:03d}',
          f'loss:{epoch_loss/ds.__len__():.5f}')


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:000 loss:0.65777


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:001 loss:0.48768


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:002 loss:0.45389


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:003 loss:0.39564


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:004 loss:0.34505


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:005 loss:0.29993


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:006 loss:0.26015


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:007 loss:0.22281


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:008 loss:0.18849


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:009 loss:0.15549


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:010 loss:0.12868


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:011 loss:0.10375


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:012 loss:0.08292


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:013 loss:0.06502


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:014 loss:0.05042


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:015 loss:0.03900


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:016 loss:0.02972


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:017 loss:0.02308


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:018 loss:0.01796


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:019 loss:0.01416


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:020 loss:0.01142


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:021 loss:0.00950


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:022 loss:0.00796


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:023 loss:0.00690


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:024 loss:0.00602


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:025 loss:0.00535


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:026 loss:0.00481


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:027 loss:0.00437


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:028 loss:0.00402


  0%|          | 0/2 [00:00<?, ?it/s]

epoch:029 loss:0.00371


In [107]:
model_tate = copy.deepcopy(_model)
#save_checkpoint('2023_1222tate_only.pt', model_tate)
for s in terao_tate[:13]:
    eval_a_sent(s[0], model=model_tate, verbose=True, isPrint=True)

正: 出力:[34mまん[0m[34mなか[0m[34mの[0m[34mスト[0m[34mレート[0m[34mだ[0m[34mそう[0m[34mで[0m[34mす[0m <- 入力:まんなかのストレートだそうです
正: 出力:[34mこの[0m[34m景[0m[34m色[0m[34mは[0m[34m頑[0m[34m張[0m[34mって[0m[34m沢[0m[34m登[0m[34mり[0m[34mを[0m[34mした[0m[34m人[0m[34mだけ[0m[34mが[0m[34m見[0m[34mる[0m[34mこと[0m[34mが[0m[34mで[0m[34mき[0m[34mる[0m[34mご[0m[34mほう[0m[34mび[0m <- 入力:この景色は頑張って沢登りをした人だけが見ることができるごほうび
正: 出力:[34mテレビ[0m[34mつけ[0m[34mながら[0m[34m、[0m[34m電[0m[34m話[0m[34mを[0m[34mして[0m[34mる[0m[34mから[0m <- 入力:テレビつけながら、電話をしてるから
正: 出力:[34mみ[0m[34mたま[0m[34mま[0m[34mを[0m[34m克[0m[34m明[0m[34mに[0m[34m解[0m[34m説[0m[34mしま[0m[34mして[0m <- 入力:みたままを克明に解説しまして
正: 出力:[34m成[0m[34m人[0m[34mの[0m[34m日[0m[34mの[0m[34m前[0m[34m後[0m[34mは[0m[34m.[0m[34m.[0m[34m.[0m <- 入力:成人の日の前後は...
正: 出力:[34m年[0m[34m上[0m[34mの[0m[34m人[0m[34mに[0m[34m対[0m[34mし[0m[34mて[0m <- 入力:年上の人に対して
正: 出力:[34m非[0m[34m常[0m[34mに[0m[34m短[0m[34mい[0m[

In [108]:
model_tate = copy.deepcopy(model_V)

In [None]:
# from torch.optim import AdamW

#_model = copy.deepcopy(model)
_model = copy.deepcopy(model_tate)
# #_model = model_H
#_model.train()
_model.eval()
ds = terao_V_ds
#dl = V_dl

n_corrects = 0
for idx in range(ds.__len__()):
    batch = ds.__getitem__(idx)
    inp, tch = [batch[0]], [batch[1]]
    encoded_input = tknz.batch_encode_plus(
                    inp, padding="longest", truncation=True,  return_tensors="pt").to(device)

    encoded_tch = tknz.batch_encode_plus(
                    tch, padding="longest", truncation=True,  return_tensors="pt").to(device)

    output = _model(src=encoded_input.input_ids,
                    tgt=encoded_tch.input_ids).to(device)
    output_ids = output.squeeze(0).topk(1)[1][1:-1].detach().cpu().numpy().flatten()
    outstr = "".join(tknz.convert_ids_to_tokens(output.squeeze(0).topk(1)[1][1:-1])).replace('##','')
    input_ids = encoded_input.input_ids.squeeze().detach().cpu().numpy()[1:-1]
    #yesno = (input_ids == output_ids).all()
    yesno = outstr == tch[0]
    if yesno == True:
        n_corrects += 1

    if yesno == True:
        color = 'grey'
    else:
        color = 'red'
    print(f'{idx:03d}',
          f'出力文:{colored(outstr,color=color,attrs=["bold"])}',
          f'<- 入力文:{inp[0]}',
          f'{yesno}') #  tch) # output_ids)


print(f'n_corrects:{n_corrects}',
      f'正解率: {n_corrects/ds.__len__()*100:.3f} %'
      )

In [114]:
save_checkpoint('2023_1222tate_finetuned.pt', model_tate)

In [117]:
from google.colab import files
files.download('2023_1222yoko_finetuned.pt')
#files.download('2023_1222tate_finetuned.pt')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [122]:
#!cp -p 2023_1222tate_finetuned.pt /content/drive/MyDrive/colab_data/
!cp -p 2023_1222yoko_finetuned.pt /content/drive/MyDrive/colab_data/

In [120]:
!mkdir -p /content/drive/MyDrive/colab_data/