In [None]:
# necessary imports
from datetime import datetime
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from dataload import MyDataset
# import all the models
from lgg_model import *

In [None]:
# generate dataset and dataloader
BS = 32
input_word_count = 10
SEED = 0  # fixes the train/val split so validation losses are comparable across runs
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Prompt reads: "Enter the dataset name (no file extension needed)".
# The corpus is then loaded from texts/<name>.txt.
dataset_name = input("请输入数据集名称（无需添加后缀）：").strip()
data_path = "texts/" + dataset_name + ".txt"
full_dataset = MyDataset(data_path, input_word_count, save_word_model=False)
vocabulary_length = full_dataset.vocabulary_length

# 80/20 train/validation split.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

# A seeded generator makes the split identical on every Restart & Run All.
train_dataset, val_dataset = torch.utils.data.random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED),
)
train_dataloader = DataLoader(train_dataset, batch_size=BS, shuffle=True)
# No shuffling for validation: the per-epoch loss is a mean of batch means, which
# depends on which samples land in the final partial batch. A fixed order keeps it
# deterministic from epoch to epoch.
val_dataloader = DataLoader(val_dataset, batch_size=BS, shuffle=False)

In [None]:
# Display the vocabulary size reported by the dataset; it is the first argument given to LSTM_enhanced.
vocabulary_length

In [None]:
# TensorBoard writers for the per-epoch train/val loss curves.
# TensorBoard merges all event files in one directory into a single run, so each
# notebook run gets its own timestamped sub-directory. This keeps a re-run from
# overlaying the previous run's curves. View with: tensorboard --logdir lines
LOG_ROOT = 'lines'
RUN_NAME = datetime.now().strftime('%Y%m%d-%H%M%S')
writer_train = SummaryWriter(f'{LOG_ROOT}/{RUN_NAME}/train')
writer_val = SummaryWriter(f'{LOG_ROOT}/{RUN_NAME}/val')

In [None]:
# some components in the training process
# NOTE(review): an earlier comment here said "the learning rate of 0.001 is still too
# large, maybe needs lr_decay or batch_norm", yet the value used is 0.005 -- confirm
# which starting LR is intended. The ExponentialLR scheduler below decays it each epoch.
LR = 0.005
LR_DECAY = 0.9  # multiplicative per-epoch decay factor passed to ExponentialLR
num_epoches = 100
net = LSTM_enhanced(vocabulary_length, 50, 100, 2).to(device)
optimizer = optim.Adam(net.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

# 2022/2/27 add a lr decay controller: lr <- lr * LR_DECAY after every epoch
ctrl = torch.optim.lr_scheduler.ExponentialLR(optimizer, LR_DECAY)


def _batch_loss(model, batch):
    """Return the next-token cross-entropy loss for one batch of token ids.

    The output at position t is scored against input token t + 1, so the last
    output position and the first input token are dropped.
    """
    tokens = batch.to(device).to(torch.long)
    targets = tokens[:, 1:]
    # Model output is indexed as (batch, seq, vocab) -- TODO confirm against LSTM_enhanced.
    # CrossEntropyLoss wants the class dim second, hence the transpose to (batch, vocab, seq - 1).
    logits = model(tokens)[:, :-1, :].transpose(2, 1)
    return criterion(logits, targets)


def _mean_loss(total, num_batches):
    """Average a summed loss over batches; NaN instead of ZeroDivisionError for an empty loader."""
    return total / num_batches if num_batches else float("nan")


# start training!
for epoch in tqdm(range(num_epoches)):
    step = epoch + 1  # TensorBoard step is the 1-based epoch number

    # train
    net.train()
    train_loss = 0.0
    for data in train_dataloader:
        loss = _batch_loss(net, data)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    train_avg_loss = _mean_loss(train_loss, len(train_dataloader))
    writer_train.add_scalar('Loss/Epoch', train_avg_loss, step)
    writer_train.flush()
    ctrl.step()  # lr decay, once per epoch after all optimizer steps

    # validation
    net.eval()
    val_loss = 0.0
    with torch.no_grad():
        for data in val_dataloader:
            val_loss += _batch_loss(net, data).item()

    val_avg_loss = _mean_loss(val_loss, len(val_dataloader))
    writer_val.add_scalar('Loss/Epoch', val_avg_loss, step)
    writer_val.flush()

print("Finish training!")

In [None]:
# if you want to save your language model...
# Prompt reads: "Enter a name for the language model". The whole nn.Module object
# is pickled to lgg_model_paths/<name>. Loading it back presumably needs the
# lgg_model classes importable -- verify against the loading code.
# The variable is `model_name`, not `str`, so the builtin str() stays usable in later cells.
model_name = input("请输入语言模型的名称：").strip()
if not model_name:
    raise ValueError("model name must not be empty")
save_dir = Path("lgg_model_paths")
save_dir.mkdir(exist_ok=True)  # torch.save does not create missing directories
torch.save(net, save_dir / model_name)