## 训练

In [None]:
import torch
from codes.model import  ResUnet, ResUnet_LSTM, Loss_With_Weight
from torch.utils.data import DataLoader
from codes.dataloader import init_dataset, STEAD_Dataset
from codes.visualize_and_evaluate import evaluate
from codes import configs

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# init_dataset() yields three splits; the third (noise_set) is not used in this cell.
train_set, val_set, noise_set = init_dataset()
train_set = STEAD_Dataset(train_set)
val_set = STEAD_Dataset(val_set)

train_dataloader = DataLoader(train_set, shuffle=True, batch_size=configs.BATCH_SIZE, num_workers=64, drop_last=True)
eval_dataloader = DataLoader(val_set, shuffle=True, batch_size=1, num_workers=16, drop_last=True)

# Not used by the training loop below (the loss comes from Loss_With_Weight).
mseLoss = torch.nn.MSELoss()

# Build exactly one architecture. Previously both were constructed, so a plain
# ResUnet was moved to the device and then immediately thrown away.
USE_LSTM = True  # True: ResUnet_LSTM (with LSTM); False: ResUnet (without LSTM)
QuackNet = (ResUnet_LSTM() if USE_LSTM else ResUnet()).to(device)

optimizer = torch.optim.Adam(QuackNet.parameters(),
                lr=configs.LEARNING_RATE,
                betas=(0.9, 0.999),
                eps=1e-08,
                weight_decay=0,
                amsgrad=False)


def _checkpoint_path(step):
    """Return the path the model is saved to after `step` optimizer steps."""
    return configs.model_save_dir + "QuackPicker_iter{}.pth".format(step)


total_step = 0
loss_sum = 0   # running sum of per-step losses since the last log line
loss_list = []  # one {"total_step", "loss"} record per LOSS_RECORD_ITER steps
for epoch in range(configs.EPOCH_NUM):
    # Re-enter training mode at the start of every epoch. The checkpoint save at
    # the end of the previous epoch calls QuackNet.eval(), which would otherwise
    # leave all later epochs training with eval-mode layer behaviour.
    QuackNet.train()
    for index, data in enumerate(train_dataloader, 0):
        stream, label_p, label_s, p_start, s_start, coda_end = data

        # Clear gradients accumulated by the previous step.
        QuackNet.zero_grad()
        output = QuackNet(stream.to(device))
        # Output channel 0 is scored against label_p, channel 1 against label_s.
        loss = (Loss_With_Weight(output[:, 0, :], label_p.to(device))
                + Loss_With_Weight(output[:, 1, :], label_s.to(device)))
        loss.backward()
        loss_sum += loss.item()
        optimizer.step()
        total_step += 1

        if total_step % configs.LOSS_RECORD_ITER == 0:
            # Record the same window-averaged loss that is printed, rather than
            # only the last step's loss.
            avg_loss = loss_sum / configs.LOSS_RECORD_ITER
            loss_list.append({
                "total_step": total_step,
                "loss": avg_loss,
            })
            print("total_step:{} loss:{}".format(total_step, avg_loss))
            loss_sum = 0

        # Periodic mid-epoch checkpoint + evaluation (currently disabled):
        # if total_step%configs.CHECKPOINT_ITER == 0:
        #     torch.save(QuackNet.eval(),_checkpoint_path(total_step))
        #     evaluate(_checkpoint_path(total_step),"batchsize:{} steps:{}".format(configs.BATCH_SIZE,total_step),eval_dataloader,50,40)

    # Save the whole module (not just its state_dict) at the end of each epoch.
    torch.save(QuackNet.eval(), _checkpoint_path(total_step))

# Evaluate the checkpoint written at the end of the final epoch.
evaluate(_checkpoint_path(total_step), "batchsize:{} steps:{}".format(configs.BATCH_SIZE, total_step), eval_dataloader, 50, 45)
        

## 评估

In [None]:
from codes.visualize_and_evaluate import evaluate

# Evaluate the checkpoint saved at the end of the most recent training run.
# Relies on `configs`, `total_step` and `eval_dataloader` from the training cell.
# (Previously `model_path` was set to '' and never used.)
model_path = configs.model_save_dir + "QuackPicker_iter{}.pth".format(total_step)
evaluate(model_path, "batchsize:{} steps:{}".format(configs.BATCH_SIZE, total_step), eval_dataloader, 50, 45)

## 保存损失变化曲线

In [None]:
# pandas is not imported anywhere else in this notebook.
import pandas as pd

# One row per logged step; columns come from the dict keys recorded by the
# training loop. (The original passed columns=header, but `header` is never
# defined, so the cell raised NameError.)
df_loss = pd.DataFrame(loss_list)
# The loss table was built but never written out; persist it next to the checkpoints.
df_loss.to_csv(configs.model_save_dir + "loss_curve.csv", index=False)

# NOTE(review): eva_df_old / eva_df_new are not defined in this notebook —
# confirm where they are meant to come from. DataFrame.append was removed in
# pandas 2.0; pd.concat is the equivalent (index labels kept, as append did).
pd.concat([eva_df_old, eva_df_new]).to_csv(configs.evaluation_csv_path, index=False)