In [6]:
# necessary imports
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch import optim
from tqdm import tqdm
from dataload import MyDataset
from lgg_model import vanilla_LSTM, LSTM_enhanced

In [7]:
# Build the training / validation datasets and their DataLoaders.
BS = 32                # batch size
input_word_count = 10  # passed to MyDataset and the model; presumably the context window length — confirm in dataload.py
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def _corpus_path(name):
    """Map a bare corpus name typed by the user (no suffix) to its file under texts/."""
    return "texts/" + name.strip() + ".txt"


train_data_path = _corpus_path(input("请输入训练语料库文件名(无需添加后缀)："))
val_data_path = _corpus_path(input("请输入验证语料库文件名(无需添加后缀)："))

train_dataset = MyDataset(train_data_path, input_word_count, save_word_model=False)
train_dataloader = DataLoader(train_dataset, batch_size=BS, shuffle=True)

# NOTE(review): the validation dataset is built from its own corpus and reports a
# different vocabulary_length than the training one, so its word indices may not
# match the training vocabulary — TODO confirm MyDataset can reuse the train vocab.
val_dataset = MyDataset(val_data_path, input_word_count, save_word_model=False)
# Evaluation does not need shuffling; a fixed order keeps the per-epoch
# validation loss deterministic (the partial last batch is always the same).
val_dataloader = DataLoader(val_dataset, batch_size=BS, shuffle=False)

In [8]:
# Sanity check: vocabulary size of each corpus, as (train, val).
# NOTE(review): the two datasets are constructed independently and report
# different sizes, so they presumably use separate word indices — confirm.
(train_dataset.vocabulary_length, val_dataset.vocabulary_length)

(709, 191)

In [9]:
from torch.utils.tensorboard import SummaryWriter 

In [10]:
# Training setup.
LR = 0.001  # author's note: 0.001 still looked too large; lr decay (ctrl below) was added for this
num_epoches = 200
net = LSTM_enhanced(train_dataset.vocabulary_length, 100, 100, 4, input_word_count=input_word_count).to(device)
optimizer = optim.Adam(net.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

# Exponential lr decay (added 2022/2/27): lr is multiplied by 0.98 once per epoch.
ctrl = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)


def _batch_loss(model, batch, loss_fn):
    """Next-token cross-entropy for one batch of token-index sequences.

    The target at position t is the input token at t+1, so the model's
    prediction for the last position has no target and is dropped.
    """
    batch = batch.to(device=device, dtype=torch.long)
    target = batch[:, 1:]
    logits = model(batch)[:, :-1, :]
    # CrossEntropyLoss expects class scores on dim 1, i.e. (N, C, L) against a (N, L) target.
    logits = torch.transpose(logits, 2, 1)
    return loss_fn(logits, target)


def _train_one_epoch(model, loader, loss_fn, opt):
    """Run one optimisation pass over `loader` and return the mean batch loss.

    Raises:
        ValueError: if `loader` yields no batches.
    """
    model.train()
    total, n_batches = 0.0, 0
    for batch in loader:
        opt.zero_grad()
        loss = _batch_loss(model, batch, loss_fn)
        loss.backward()
        opt.step()
        total += loss.item()
        n_batches += 1
    if n_batches == 0:
        raise ValueError("training dataloader yielded no batches")
    return total / n_batches


@torch.no_grad()
def _evaluate(model, loader, loss_fn):
    """Return the mean batch loss over `loader` in eval mode, without autograd.

    Raises:
        ValueError: if `loader` yields no batches.
    """
    model.eval()
    total, n_batches = 0.0, 0
    for batch in loader:
        total += _batch_loss(model, batch, loss_fn).item()
        n_batches += 1
    if n_batches == 0:
        raise ValueError("validation dataloader yielded no batches")
    return total / n_batches


# Open one SummaryWriter per split for the whole run. Creating a new writer each
# epoch starts a new event file every time.
# NOTE(review): val_dataset reports a different vocabulary_length than train_dataset,
# so the validation loss may not be comparable — confirm the index mapping.
with SummaryWriter('lines/train') as train_writer, SummaryWriter('lines/val') as val_writer:
    for epoch in tqdm(range(num_epoches)):
        train_avg_loss = _train_one_epoch(net, train_dataloader, criterion, optimizer)
        train_writer.add_scalar('Loss/Epoch', train_avg_loss, epoch + 1)  # 1-based epoch on the x-axis
        train_writer.flush()  # keep TensorBoard current during the long run
        ctrl.step()  # lr decay, after this epoch's optimizer steps

        val_avg_loss = _evaluate(net, val_dataloader, criterion)
        val_writer.add_scalar('Loss/Epoch', val_avg_loss, epoch + 1)
        val_writer.flush()

print("Finish training!")

100%|██████████| 200/200 [40:08<00:00, 12.04s/it]

Finish training!





In [7]:
# Optionally persist the trained language model under lgg_model_paths/.
# torch.save(net, ...) pickles the whole module. Loading it later needs the
# lgg_model classes to be importable, and torch.load on untrusted files can run
# arbitrary code. Saving net.state_dict() is the more portable alternative.
import os

# Use a dedicated name; assigning to `str` would shadow the builtin for the rest of the session.
model_name = input("请输入语言模型的名称：").strip()
if not model_name:
    raise ValueError("model name must not be empty")
model_dir = "lgg_model_paths"
os.makedirs(model_dir, exist_ok=True)
torch.save(net, os.path.join(model_dir, model_name))