In [1]:
# Kaggle's kaggle/python Docker image (https://github.com/kaggle/docker-python)
# ships with many analytics libraries preinstalled; a few common ones:
import numpy as np  # linear algebra
import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)
import os

# Input datasets are mounted read-only under /kaggle/input. Print the full path
# of every file found there, one per line, in os.walk order.
input_files = (
    os.path.join(dirpath, fname)
    for dirpath, _subdirs, fnames in os.walk('/kaggle/input')
    for fname in fnames
)
for path in input_files:
    print(path)

# The current directory (/kaggle/working/) keeps up to 20GB of output when a
# version is created with "Save & Run All"; /kaggle/temp/ is scratch space whose
# contents are not kept beyond the current session.

/kaggle/input/chinese-couplets/couplet/vocabs
/kaggle/input/chinese-couplets/couplet/test/out.txt
/kaggle/input/chinese-couplets/couplet/test/in.txt
/kaggle/input/chinese-couplets/couplet/test/.in.txt.swp
/kaggle/input/chinese-couplets/couplet/test/.out.txt.swp
/kaggle/input/chinese-couplets/couplet/train/out.txt
/kaggle/input/chinese-couplets/couplet/train/in.txt


In [1]:
import torch.nn as nn
import math
from transformer_model_week09 import Seq2SeqTransformer 
from train_transformer_week09 import build_vocab,collate_fn,generate_square_subsequent_mask,MyDataset
from torch.utils.data import Dataset,DataLoader
import torch

In [2]:
# Toy task: split one line of a poem at every position i. The encoder sees the
# prefix chs[:i] and the decoder must generate the remaining suffix chs[i:],
# wrapped in <s> ... </s>.

## Run configuration
SEED = 42
PAD_IDX = 0  # index treated as padding by the loss and both attention masks;
             # presumably what build_vocab/collate_fn use for <pad> — confirm
NUM_EPOCHS = 50
LEARNING_RATE = 1e-3
BATCH_SIZE = 2

# Seed torch before the model is built and before the DataLoader shuffles, so
# weight init, dropout and batch order (hence the loss curve) are reproducible.
torch.manual_seed(SEED)

corpus="人生得意须尽欢，莫使金樽空对月"
chs=list(corpus)
enc_tokens,dec_tokens=[],[]
for i in range(1,len(chs)):
    enc=chs[:i]
    dec=['<s>']+chs[i:]+['</s>']
    enc_tokens.append(enc)
    dec_tokens.append(dec)

## Build vocabularies (token -> index); inverse decoder map for decoding output
enc_vocab=build_vocab(enc_tokens)
dec_vocab=build_vocab(dec_tokens)
inv_dec_vocab={v:k for k,v in dec_vocab.items()}

## Build the dataset and DataLoader
dataset=MyDataset(enc_tokens,dec_tokens,enc_vocab,dec_vocab)
dataloader=DataLoader(dataset,batch_size=BATCH_SIZE,shuffle=True,collate_fn=collate_fn)

## Model hyperparameters
d_model=32
nhead=4
num_enc_layers=2
num_dec_layers=2
dim_forward=64
dropout=0.1
enc_voc_size=len(enc_vocab)
dec_voc_size=len(dec_vocab)

device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model=Seq2SeqTransformer(d_model,nhead,num_enc_layers,num_dec_layers
                         ,dim_forward,dropout,enc_voc_size,dec_voc_size).to(device)
optimizer=torch.optim.Adam(model.parameters(),lr=LEARNING_RATE)
# Padding positions must not contribute to the loss.
loss_fn=torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

## Training
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss=0
    # dec_in / dec_out are presumably the teacher-forcing input and the target
    # shifted by one position (produced by collate_fn) — confirm.
    for enc_batch,dec_in,dec_out in dataloader:
        enc_batch,dec_in,dec_out=enc_batch.to(device),dec_in.to(device),dec_out.to(device)
        # Causal mask sized from dim 1, i.e. assumes batch-first (batch, seq) tensors.
        tgt_mask=generate_square_subsequent_mask(dec_in.size(1)).to(device)
        enc_pad_mask=(enc_batch==PAD_IDX)
        dec_pad_mask=(dec_in==PAD_IDX)
        logits=model(enc_batch,dec_in,tgt_mask,enc_pad_mask,dec_pad_mask)

        # Flatten to (tokens, vocab) vs (tokens,) for CrossEntropyLoss.
        loss=loss_fn(logits.reshape(-1,logits.size(-1)),dec_out.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss+=loss.item()
    # Mean of the per-batch mean losses for this epoch.
    print(f"Epoch{epoch+1},Loss:{total_loss/len(dataloader)}")

## Save the model
# Recommended: save only the model parameters (state_dict), not the pickled module.
torch.save(model.state_dict(), 'model_transformer.pth')
# The embedding and output layers are indexed by these vocabularies, and
# build_vocab's ordering is not visible here; persist the exact mappings so an
# inference run can rebuild the same token <-> index correspondence.
torch.save({'enc_vocab': enc_vocab, 'dec_vocab': dec_vocab}, 'vocab_transformer.pth')


        








Epoch1,Loss:2.637634447642735
Epoch2,Loss:2.3152852739606584
Epoch3,Loss:2.064793348312378
Epoch4,Loss:1.8241918427603585
Epoch5,Loss:1.546981964792524
Epoch6,Loss:1.4836711713245936
Epoch7,Loss:1.3184285163879395
Epoch8,Loss:1.2160266126905168
Epoch9,Loss:1.1633753180503845
Epoch10,Loss:1.0630184412002563
Epoch11,Loss:0.9997739706720624
Epoch12,Loss:0.9749292305537632
Epoch13,Loss:0.8870891758373806
Epoch14,Loss:0.8453975575310844
Epoch15,Loss:0.7804585354668754
Epoch16,Loss:0.6980357681001935
Epoch17,Loss:0.6932005115917751
Epoch18,Loss:0.6870882085391453
Epoch19,Loss:0.6303758450916835
Epoch20,Loss:0.5855180237974439
Epoch21,Loss:0.6186241379805973
Epoch22,Loss:0.6231013749326978
Epoch23,Loss:0.5645113289356232
Epoch24,Loss:0.535359902041299
Epoch25,Loss:0.5496102358613696
Epoch26,Loss:0.5309996179171971
Epoch27,Loss:0.4968283176422119
Epoch28,Loss:0.48055081708090647
Epoch29,Loss:0.4605814005647387
Epoch30,Loss:0.40119126439094543
Epoch31,Loss:0.4205294634614672
Epoch32,Loss:0.4172