In [1]:
import os
import json
import time
import random
import numpy as np
import pandas as pd
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from util.conf import *
from tqdm import tqdm
from model import Predictor as Net
from scipy.stats import pearsonr

def preprocess_data(data, crop_size):
    """预处理数据，包括随机裁剪和标准化"""
    # 随机裁剪
    original_size = data.shape[1]
    start_idx = np.random.randint(0, original_size - crop_size + 1)
    cropped_data = data[:, start_idx:start_idx + crop_size]
    
    # 使用StandardScaler进行标准化
    # scaler = StandardScaler()
    # scaled_data = scaler.fit_transform(cropped_data)
    
    return cropped_data

#定义加载数据集函数
def DataLoaderGenerator(directory, input_indices, test_size):
    x_data_list = []
    y_data_list = []

    for filename in os.listdir(directory):
        if filename.endswith('.csv'):
            file_path = os.path.join(directory, filename)
            df = pd.read_csv(file_path)
            
             # 确保数据行数为64行
            if df.shape[0] != 64:
                raise ValueError(f"CSV文件 {filename} 应包含64行数据")  
            
            for i in range(10):
                df = preprocess_data(df.to_numpy(), 8192)
                df = pd.DataFrame(df)
            
                x_data = df.iloc[input_indices].to_numpy()
                y_data = df.drop(index=input_indices).to_numpy()

                x_data_list.append(x_data)
                y_data_list.append(y_data)
            
    
    # 将列表转换为三维NumPy数组
    x_data_array = np.array(x_data_list)
    y_data_array = np.array(y_data_list)
    
    # 三维数据转换为四维数据，用于2D卷积
    # x_data_array = np.expand_dims(x_data_array, axis=1)
    # y_data_array = np.expand_dims(y_data_array, axis=1)

    # 分割数据集为训练集和测试集
    x_train, x_test, y_train, y_test = train_test_split(x_data_array, y_data_array, test_size=test_size, random_state=42)
    
    print("-----------------------------------------")
    print("训练集(x_train)的大小:",x_train.shape)
    print("测试集(x_test)的大小:",x_test.shape)
    print("-----------------------------------------")

    # 转换为PyTorch张量
    x_train_tensor = torch.tensor(x_train, dtype=torch.float32)
    y_train_tensor = torch.tensor(y_train, dtype=torch.float32)
    x_test_tensor = torch.tensor(x_test, dtype=torch.float32)
    y_test_tensor = torch.tensor(y_test, dtype=torch.float32)

    #创建训练集和测试集
    train_dataset = TensorDataset(x_train_tensor, y_train_tensor)
    test_dataset = TensorDataset(x_test_tensor, y_test_tensor)
    
    return train_dataset, test_dataset

# 训练模型的函数
def train_model(model, device, train_loader, criterion, optimizer, clip):
    model.train()
    epoch_loss = 0.0
    epoch_correlation_score = 0.0
    train_loss = 0.0
    train_correlation_score = 0.0
    cnt = 0
    for i, batch in tqdm(enumerate(train_loader)):
        inputs, labels = batch
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        outputs = outputs.cpu()
        labels = labels.cpu()

        loss = criterion(outputs, labels)
        loss.backward()
        
        #梯度裁剪与爆炸检测
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        update = True
        for pp in model.parameters():
            if pp.requires_grad and pp.grad is not None:
                if torch.isnan(pp.grad).any():
                    update=False
                    break
        if update:
            optimizer.step()
        else:
            print('Gradient explosion and discard this batch!!!')

        y_trues = labels.numpy()
        y_preds = outputs.detach().numpy()
        correlation_score = correlation(y_trues, y_preds)
        
        epoch_correlation_score += correlation_score
        epoch_loss += loss.item()
        cnt = i
    train_loss = epoch_loss / (cnt+1)
    train_correlation_score = epoch_correlation_score / (cnt+1)
    return train_loss, train_correlation_score

# 评估模型的函数
def evaluate_model(model, device, test_loader, criterion):
    model.eval()
    epoch_loss = 0.0
    epoch_correlation_score = 0.0
    test_loss = 0.0
    test_correlation_score = 0.0
    cnt = 0
    with torch.no_grad():
        for i, batch in tqdm(enumerate(test_loader)):
            inputs, labels = batch
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            outputs = outputs.cpu()
            labels = labels.cpu()
                
            loss = criterion(outputs, labels)

            y_trues = labels.numpy()
            y_preds = outputs.detach().numpy()
            correlation_score = correlation(y_trues, y_preds)

            epoch_correlation_score += correlation_score
            epoch_loss += loss.item()
            cnt = i
    test_loss = epoch_loss / (cnt+1)
    test_correlation_score = epoch_correlation_score / (cnt+1)
    return test_loss, test_correlation_score

# 参数量统计函数
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# 初始化权重函数
def initialize_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        nn.init.kaiming_uniform_(m.weight.data)
        
# 时间统计函数
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

# 计算相关系数函数
def correlation(y_trues, y_pred):
    all_correlation = []
    for i in range(BATCH_SIZE):
        for j in range(OUTPUT_CHANNELS):
            correlation, _ = pearsonr(y_trues[i,j,:], y_pred[i,j,:])
            all_correlation.append(correlation)
    return np.mean(all_correlation)

# 设置随机种子
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.determinsitic=True


def run(model, device, train_dataset, test_dataset, total_epoch, best_loss, optimizer, criterion, clip):
    losses = {'train':[], 'val':[]}
    acces = {'train':[], 'val':[]}
    for step in range(total_epoch):
        print('Epoch {} / {}:'.format(step + 1, total_epoch))
        start_time = time.time()
        print('Training...')
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        train_loss, train_acc = train_model(model, device, train_loader, criterion, optimizer, clip)
        print('Evaluating...')
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)
        val_loss, val_acc = evaluate_model(model, device, test_loader, criterion)
        end_time = time.time()

        if step > WARMUP:
            scheduler.step(val_loss)

        losses['train'].append(train_loss)
        losses['val'].append(val_loss)
        acces['train'].append(train_acc)
        acces['val'].append(val_acc)

        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), './saved/best_model_%s_%s_%d.pt'%(MODEL, DATASET, SEED))
        torch.save(model.state_dict(), './saved/latest_model_%s_%s_%d.pt'%(MODEL, DATASET, SEED))
        
        with open('result/result_%s_%s_%d.json'%(MODEL, DATASET, SEED), 'w') as f:
            json.dump({'loss':losses, 'acc':acces}, f)
        
        print(f'Epoch: {step + 1} | Time: {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc:.3f}')
        print(f'\tVal Loss: {val_loss:.3f} | Val Acc: {val_acc:.3f}')


if __name__ == '__main__':
    setup_seed(SEED)

    print('Loading dataset...')
    train_dataset, test_dataset= DataLoaderGenerator(DIRECTORY, INPUT_INDICES, TEST_SIZE)
    
    model = Net(INPUT_CHANNELS, HIDDEN_CHANNELS, OUTPUT_CHANNELS, model = MODEL)
    print(f'The model has {count_parameters(model):,} trainable parameters in total')
    
    model.apply(initialize_weights)
    model = model.to(DEVICE)
  
    optimizer = optim.Adam(model.parameters(), lr=INIT_LR, weight_decay=WEIGHT_DECAY,eps=ADAM_EPS)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, verbose=True, factor=FACTOR, patience=PATIENCE)
    criterion = nn.MSELoss()

    run(model=model, device=DEVICE, train_dataset=train_dataset, test_dataset=test_dataset, total_epoch=EPOCH, best_loss=INF, 
        optimizer=optimizer, criterion=criterion, clip=CLIP)

Loading dataset...
-----------------------------------------
训练集(x_train)的大小: (512, 8, 8192)
测试集(x_test)的大小: (128, 8, 8192)
-----------------------------------------
The model has 1,458,232 trainable parameters in total
Epoch 1 / 50:
Training...


8it [02:15, 16.88s/it]


Evaluating...


2it [00:14,  7.28s/it]


Epoch: 1 | Time: 2m 29s
	Train Loss: 42.701 | Train Acc: 0.576
	Val Loss: 213.554 | Val Acc: 0.754
Epoch 2 / 50:
Training...


8it [02:15, 16.89s/it]


Evaluating...


2it [00:14,  7.32s/it]


Epoch: 2 | Time: 2m 29s
	Train Loss: 26.818 | Train Acc: 0.766
	Val Loss: 67.287 | Val Acc: 0.771
Epoch 3 / 50:
Training...


8it [02:17, 17.20s/it]


Evaluating...


2it [00:14,  7.18s/it]


Epoch: 3 | Time: 2m 31s
	Train Loss: 25.298 | Train Acc: 0.781
	Val Loss: 30.592 | Val Acc: 0.795
Epoch 4 / 50:
Training...


8it [02:17, 17.15s/it]


Evaluating...


2it [00:14,  7.32s/it]


Epoch: 4 | Time: 2m 31s
	Train Loss: 23.593 | Train Acc: 0.795
	Val Loss: 25.217 | Val Acc: 0.804
Epoch 5 / 50:
Training...


8it [02:16, 17.07s/it]


Evaluating...


2it [00:14,  7.38s/it]


Epoch: 5 | Time: 2m 31s
	Train Loss: 22.692 | Train Acc: 0.803
	Val Loss: 24.491 | Val Acc: 0.816
Epoch 6 / 50:
Training...


8it [02:17, 17.21s/it]


Evaluating...


2it [00:14,  7.44s/it]


Epoch: 6 | Time: 2m 32s
	Train Loss: 22.130 | Train Acc: 0.810
	Val Loss: 25.350 | Val Acc: 0.819
Epoch 7 / 50:
Training...


8it [02:17, 17.23s/it]


Evaluating...


2it [00:15,  7.53s/it]


Epoch: 7 | Time: 2m 32s
	Train Loss: 21.339 | Train Acc: 0.816
	Val Loss: 26.945 | Val Acc: 0.823
Epoch 8 / 50:
Training...


8it [02:21, 17.71s/it]


Evaluating...


2it [00:15,  7.90s/it]


Epoch: 8 | Time: 2m 37s
	Train Loss: 20.856 | Train Acc: 0.820
	Val Loss: 26.429 | Val Acc: 0.827
Epoch 9 / 50:
Training...


8it [02:27, 18.46s/it]


Evaluating...


2it [00:15,  7.59s/it]


Epoch: 9 | Time: 2m 42s
	Train Loss: 20.396 | Train Acc: 0.824
	Val Loss: 23.704 | Val Acc: 0.832
Epoch 10 / 50:
Training...


8it [02:21, 17.68s/it]


Evaluating...


2it [00:15,  7.54s/it]


Epoch: 10 | Time: 2m 36s
	Train Loss: 19.857 | Train Acc: 0.826
	Val Loss: 23.428 | Val Acc: 0.834
Epoch 11 / 50:
Training...


8it [02:18, 17.28s/it]


Evaluating...


2it [00:15,  7.53s/it]


Epoch: 11 | Time: 2m 33s
	Train Loss: 19.724 | Train Acc: 0.828
	Val Loss: 23.039 | Val Acc: 0.837
Epoch 12 / 50:
Training...


8it [02:17, 17.17s/it]


Evaluating...


2it [00:14,  7.40s/it]


Epoch: 12 | Time: 2m 32s
	Train Loss: 19.267 | Train Acc: 0.830
	Val Loss: 23.553 | Val Acc: 0.835
Epoch 13 / 50:
Training...


8it [02:17, 17.18s/it]


Evaluating...


2it [00:14,  7.21s/it]


Epoch: 13 | Time: 2m 31s
	Train Loss: 18.739 | Train Acc: 0.833
	Val Loss: 21.238 | Val Acc: 0.841
Epoch 14 / 50:
Training...


8it [02:18, 17.35s/it]


Evaluating...


2it [00:14,  7.49s/it]


Epoch: 14 | Time: 2m 33s
	Train Loss: 18.673 | Train Acc: 0.834
	Val Loss: 22.100 | Val Acc: 0.835
Epoch 15 / 50:
Training...


8it [02:20, 17.54s/it]


Evaluating...


2it [00:14,  7.43s/it]


Epoch: 15 | Time: 2m 35s
	Train Loss: 18.592 | Train Acc: 0.834
	Val Loss: 20.614 | Val Acc: 0.843
Epoch 16 / 50:
Training...


8it [02:19, 17.48s/it]


Evaluating...


2it [00:15,  7.53s/it]


Epoch: 16 | Time: 2m 34s
	Train Loss: 17.882 | Train Acc: 0.838
	Val Loss: 20.630 | Val Acc: 0.841
Epoch 17 / 50:
Training...


8it [02:19, 17.38s/it]


Evaluating...


2it [00:14,  7.26s/it]


Epoch: 17 | Time: 2m 33s
	Train Loss: 17.827 | Train Acc: 0.838
	Val Loss: 20.131 | Val Acc: 0.848
Epoch 18 / 50:
Training...


8it [02:19, 17.43s/it]


Evaluating...


2it [00:14,  7.48s/it]


Epoch: 18 | Time: 2m 34s
	Train Loss: 17.605 | Train Acc: 0.840
	Val Loss: 19.934 | Val Acc: 0.850
Epoch 19 / 50:
Training...


8it [02:21, 17.72s/it]


Evaluating...


2it [00:14,  7.37s/it]


Epoch: 19 | Time: 2m 36s
	Train Loss: 17.408 | Train Acc: 0.841
	Val Loss: 21.090 | Val Acc: 0.847
Epoch 20 / 50:
Training...


8it [02:16, 17.08s/it]


Evaluating...


2it [00:14,  7.24s/it]


Epoch: 20 | Time: 2m 31s
	Train Loss: 17.447 | Train Acc: 0.841
	Val Loss: 18.759 | Val Acc: 0.850
Epoch 21 / 50:
Training...


8it [02:16, 17.01s/it]


Evaluating...


2it [00:14,  7.40s/it]


Epoch: 21 | Time: 2m 30s
	Train Loss: 17.367 | Train Acc: 0.841
	Val Loss: 21.044 | Val Acc: 0.839
Epoch 22 / 50:
Training...


8it [02:18, 17.26s/it]


Evaluating...


2it [00:14,  7.38s/it]


Epoch: 22 | Time: 2m 32s
	Train Loss: 17.185 | Train Acc: 0.842
	Val Loss: 18.255 | Val Acc: 0.852
Epoch 23 / 50:
Training...


8it [02:18, 17.32s/it]


Evaluating...


2it [00:15,  7.59s/it]


Epoch: 23 | Time: 2m 33s
	Train Loss: 17.138 | Train Acc: 0.843
	Val Loss: 18.203 | Val Acc: 0.852
Epoch 24 / 50:
Training...


8it [02:19, 17.39s/it]


Evaluating...


2it [00:15,  7.90s/it]


Epoch: 24 | Time: 2m 34s
	Train Loss: 16.800 | Train Acc: 0.845
	Val Loss: 18.470 | Val Acc: 0.852
Epoch 25 / 50:
Training...


8it [02:24, 18.07s/it]


Evaluating...


2it [00:15,  7.99s/it]


Epoch: 25 | Time: 2m 40s
	Train Loss: 16.731 | Train Acc: 0.845
	Val Loss: 19.088 | Val Acc: 0.854
Epoch 26 / 50:
Training...


8it [02:24, 18.08s/it]


Evaluating...


2it [00:14,  7.47s/it]


Epoch: 26 | Time: 2m 39s
	Train Loss: 16.747 | Train Acc: 0.846
	Val Loss: 19.950 | Val Acc: 0.852
Epoch 27 / 50:
Training...


8it [02:21, 17.66s/it]


Evaluating...


2it [00:14,  7.42s/it]


Epoch: 27 | Time: 2m 36s
	Train Loss: 16.612 | Train Acc: 0.846
	Val Loss: 18.124 | Val Acc: 0.854
Epoch 28 / 50:
Training...


8it [02:17, 17.25s/it]


Evaluating...


2it [00:15,  7.70s/it]


Epoch: 28 | Time: 2m 33s
	Train Loss: 16.400 | Train Acc: 0.847
	Val Loss: 17.696 | Val Acc: 0.854
Epoch 29 / 50:
Training...


8it [02:19, 17.50s/it]


Evaluating...


2it [00:15,  7.55s/it]


Epoch: 29 | Time: 2m 35s
	Train Loss: 16.361 | Train Acc: 0.847
	Val Loss: 18.897 | Val Acc: 0.856
Epoch 30 / 50:
Training...


8it [02:17, 17.18s/it]


Evaluating...


2it [00:14,  7.27s/it]


Epoch: 30 | Time: 2m 32s
	Train Loss: 16.201 | Train Acc: 0.849
	Val Loss: 19.286 | Val Acc: 0.854
Epoch 31 / 50:
Training...


8it [02:18, 17.28s/it]


Evaluating...


2it [00:15,  7.51s/it]


Epoch: 31 | Time: 2m 33s
	Train Loss: 16.249 | Train Acc: 0.849
	Val Loss: 19.675 | Val Acc: 0.855
Epoch 32 / 50:
Training...


8it [02:14, 16.78s/it]


Evaluating...


2it [00:14,  7.16s/it]


Epoch: 32 | Time: 2m 28s
	Train Loss: 16.180 | Train Acc: 0.849
	Val Loss: 17.507 | Val Acc: 0.854
Epoch 33 / 50:
Training...


8it [02:13, 16.69s/it]


Evaluating...


2it [00:14,  7.21s/it]


Epoch: 33 | Time: 2m 27s
	Train Loss: 16.332 | Train Acc: 0.849
	Val Loss: 17.866 | Val Acc: 0.857
Epoch 34 / 50:
Training...


8it [02:17, 17.23s/it]


Evaluating...


2it [00:14,  7.22s/it]


Epoch: 34 | Time: 2m 32s
	Train Loss: 16.261 | Train Acc: 0.850
	Val Loss: 17.513 | Val Acc: 0.858
Epoch 35 / 50:
Training...


8it [02:15, 16.97s/it]


Evaluating...


2it [00:14,  7.24s/it]


Epoch: 35 | Time: 2m 30s
	Train Loss: 16.039 | Train Acc: 0.850
	Val Loss: 17.627 | Val Acc: 0.858
Epoch 36 / 50:
Training...


8it [02:19, 17.38s/it]


Evaluating...


2it [00:15,  7.55s/it]


Epoch: 36 | Time: 2m 34s
	Train Loss: 16.024 | Train Acc: 0.850
	Val Loss: 16.786 | Val Acc: 0.859
Epoch 37 / 50:
Training...


8it [02:17, 17.23s/it]


Evaluating...


2it [00:14,  7.33s/it]


Epoch: 37 | Time: 2m 32s
	Train Loss: 16.018 | Train Acc: 0.850
	Val Loss: 17.364 | Val Acc: 0.859
Epoch 38 / 50:
Training...


8it [02:21, 17.67s/it]


Evaluating...


2it [00:16,  8.21s/it]


Epoch: 38 | Time: 2m 37s
	Train Loss: 15.963 | Train Acc: 0.850
	Val Loss: 18.710 | Val Acc: 0.854
Epoch 39 / 50:
Training...


8it [02:20, 17.55s/it]


Evaluating...


2it [00:15,  7.54s/it]


Epoch: 39 | Time: 2m 35s
	Train Loss: 16.014 | Train Acc: 0.851
	Val Loss: 17.737 | Val Acc: 0.861
Epoch 40 / 50:
Training...


8it [02:18, 17.31s/it]


Evaluating...


2it [00:14,  7.20s/it]


Epoch: 40 | Time: 2m 32s
	Train Loss: 15.735 | Train Acc: 0.852
	Val Loss: 16.637 | Val Acc: 0.859
Epoch 41 / 50:
Training...


8it [02:17, 17.14s/it]


Evaluating...


2it [00:14,  7.23s/it]


Epoch: 41 | Time: 2m 31s
	Train Loss: 15.702 | Train Acc: 0.852
	Val Loss: 17.027 | Val Acc: 0.859
Epoch 42 / 50:
Training...


8it [02:17, 17.23s/it]


Evaluating...


2it [00:14,  7.24s/it]


Epoch: 42 | Time: 2m 32s
	Train Loss: 15.794 | Train Acc: 0.852
	Val Loss: 16.933 | Val Acc: 0.860
Epoch 43 / 50:
Training...


8it [02:18, 17.36s/it]


Evaluating...


2it [00:14,  7.50s/it]


Epoch: 43 | Time: 2m 33s
	Train Loss: 15.829 | Train Acc: 0.852
	Val Loss: 16.732 | Val Acc: 0.859
Epoch 44 / 50:
Training...


8it [02:17, 17.18s/it]


Evaluating...


2it [00:14,  7.26s/it]


Epoch: 44 | Time: 2m 32s
	Train Loss: 15.821 | Train Acc: 0.851
	Val Loss: 16.867 | Val Acc: 0.860
Epoch 45 / 50:
Training...


8it [02:19, 17.43s/it]


Evaluating...


2it [00:14,  7.35s/it]


Epoch: 45 | Time: 2m 34s
	Train Loss: 15.733 | Train Acc: 0.852
	Val Loss: 16.934 | Val Acc: 0.862
Epoch 46 / 50:
Training...


8it [02:16, 17.11s/it]


Evaluating...


2it [00:14,  7.19s/it]


Epoch: 46 | Time: 2m 31s
	Train Loss: 15.331 | Train Acc: 0.854
	Val Loss: 16.897 | Val Acc: 0.864
Epoch 47 / 50:
Training...


8it [02:20, 17.60s/it]


Evaluating...


2it [00:14,  7.23s/it]


Epoch: 47 | Time: 2m 35s
	Train Loss: 15.392 | Train Acc: 0.854
	Val Loss: 16.602 | Val Acc: 0.864
Epoch 48 / 50:
Training...


8it [02:20, 17.58s/it]


Evaluating...


2it [00:15,  7.97s/it]


Epoch: 48 | Time: 2m 36s
	Train Loss: 15.433 | Train Acc: 0.854
	Val Loss: 16.284 | Val Acc: 0.863
Epoch 49 / 50:
Training...


8it [02:20, 17.56s/it]


Evaluating...


2it [00:14,  7.43s/it]


Epoch: 49 | Time: 2m 35s
	Train Loss: 15.774 | Train Acc: 0.853
	Val Loss: 16.355 | Val Acc: 0.862
Epoch 50 / 50:
Training...


8it [02:17, 17.21s/it]


Evaluating...


2it [00:15,  7.58s/it]

Epoch: 50 | Time: 2m 32s
	Train Loss: 15.504 | Train Acc: 0.853
	Val Loss: 16.283 | Val Acc: 0.861



