In [1]:
import os
import argparse
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.loader import DataLoader
import time
import torch
from GEvaluator import Evaluator
from tqdm import tqdm
from Mid import GINGraphPooling
from Module import load_data,train,evaluate,test,prepartion,continue_train
# Record the installed PyTorch version in the notebook output for reproducibility.
print(f'torch version: {torch.__version__}')


torch version: 2.0.1+cu118


参数输入

In [2]:
class MyNamespace(argparse.Namespace):
    """Default run configuration, used in place of command-line parsing.

    Each entry in ``defaults`` below becomes an attribute of the instance.
    Keyword arguments passed to the constructor take precedence over the
    default of the same name and may also add extra attributes, e.g.
    ``MyNamespace(epochs=50, batch_size=64)``.

    Previously the defaults were assigned *after* ``super().__init__(**kwargs)``,
    so every keyword override was silently replaced by the hard-coded value.
    """

    def __init__(self, **kwargs):
        # Built inside __init__ so mutable defaults (dataset_split) are not
        # shared between instances.
        defaults = dict(
            batch_size=20,
            device=0,  # passed as-is to model.to(...) in main()
            drop_ratio=0.15,
            early_stop=30,
            early_stop_open=True,
            emb_dim=128,
            epochs=2,
            graph_pooling='mean',
            num_layers=3,
            n_head=3,
            num_workers=5,
            num_tasks=1,
            save_test=True,
            task_name='GINGraph_smiles_3d_test',
            weight_decay=0.1e-05,
            learning_rate=0.0001,
            data_type='smiles',
            dataset_pt='./PTs/smiles_3d',
            # presumably train/valid/test fractions — confirm against load_data
            dataset_split=[0.8, 0.19, 0.01],
            evaluate_epoch=1,
            continue_train=False,
            # NOTE(review): absolute, machine-specific path; only relevant when
            # continue_train is True. Consider making it relative/configurable.
            checkpoint_path='/home/ml/hctang/TGNN/saves/GINGraph2-730-v100_=0/checkpoint.pt',
            job_level='graph',  # graph,node
        )
        # Caller-supplied values win over the defaults.
        defaults.update(kwargs)
        super().__init__(**defaults)


数据载入与模型训练 (data loading and model training)

In [3]:
def main(args):
    """Train a GINGraphPooling model and keep the best-validation checkpoint.

    Workflow: call ``prepartion(args)``, load the train/valid/test loaders,
    build the model, Adam optimizer and StepLR scheduler, then loop over
    epochs: train, periodically evaluate, log scalars to TensorBoard, and on
    every new best validation score optionally save a checkpoint and a test
    prediction file.

    Args:
        args: Namespace-like configuration. Fields read here: ``num_layers``,
            ``emb_dim``, ``n_head``, ``drop_ratio``, ``graph_pooling``,
            ``num_tasks``, ``data_type``, ``job_level``, ``device``,
            ``learning_rate``, ``weight_decay``, ``continue_train``,
            ``epochs``, ``evaluate_epoch``, ``save_test``, ``early_stop`` and
            (optionally) ``early_stop_open``. ``args.output_file`` and
            ``args.save_dir`` are also used; they are presumably set by
            ``prepartion`` — confirm in Module.

    Side effects:
        Writes log lines to ``args.output_file`` and stdout, TensorBoard
        events and ``checkpoint.pt`` under ``args.save_dir``, and closes both
        the writer and ``args.output_file`` on exit — including when an
        exception aborts the run.
    """
    prepartion(args)
    writer = None
    try:
        nn_params = {
            'num_layers': args.num_layers,
            'emb_dim': args.emb_dim,
            'n_head': args.n_head,
            'drop_ratio': args.drop_ratio,
            'graph_pooling': args.graph_pooling,
            'num_tasks': args.num_tasks,
            'data_type': args.data_type,
            'job_level': args.job_level,
        }

        # automatic dataloading and splitting
        train_loader, valid_loader, test_loader = load_data(args)

        # evaluator used for validation scoring and for writing the test submission
        evaluator = Evaluator()
        criterion_fn = torch.nn.MSELoss()

        device = args.device

        model = GINGraphPooling(**nn_params).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.9)
        if args.continue_train:
            # NOTE(review): the checkpoint written below also stores the scheduler
            # state, but only model and optimizer are handed to continue_train —
            # confirm whether resuming should restore the scheduler as well.
            continue_train(args, model, optimizer)

        num_params = sum(p.numel() for p in model.parameters())
        print('train data:', len(train_loader), 'valid data:', len(valid_loader), file=args.output_file, flush=True)
        print(f'#Params: {num_params}', file=args.output_file, flush=True)
        print(model, file=args.output_file, flush=True)

        writer = SummaryWriter(log_dir=args.save_dir)

        # Honour the early_stop_open switch; it defaults to True so that
        # configurations without the attribute keep the previous behaviour.
        early_stop_enabled = getattr(args, 'early_stop_open', True)

        not_improved = 0
        eva = 1  # next epoch at which validation is run
        best_valid_mae = 9999
        valid_mae = 10000

        for epoch in range(1, args.epochs + 1):

            print('=====epoch:', epoch, time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))

            print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), "=====Epoch {}".format(epoch), file=args.output_file, flush=True)
            print('Training...', file=args.output_file, flush=True)
            train_mae, maxP, minN, avgP, avgN = train(model, device, train_loader, optimizer, criterion_fn, epoch, args.epochs)
            print(train_mae, maxP, minN, avgP, avgN)
            if epoch == eva:
                # Only announce evaluation on epochs where it actually runs.
                print('Evaluating...', file=args.output_file, flush=True)
                valid_mae = evaluate(model, device, valid_loader, evaluator)
                eva += args.evaluate_epoch

            # On epochs without evaluation, valid_mae still holds the most
            # recently evaluated value.
            print({'Train': train_mae, 'Validation': valid_mae}, file=args.output_file, flush=True)

            writer.add_scalar('valid/mae', valid_mae, epoch)
            writer.add_scalar('train/mae', train_mae, epoch)
            writer.add_scalar('train/maxP', maxP, epoch)
            writer.add_scalar('train/minN', minN, epoch)
            writer.add_scalar('train/avgP', avgP, epoch)
            writer.add_scalar('train/avgN', avgN, epoch)

            if valid_mae < best_valid_mae:
                print('valid_mae:', valid_mae, 'Saving checkpoint...')
                best_valid_mae = valid_mae
                if args.save_test:
                    print('Saving checkpoint...', file=args.output_file, flush=True)
                    checkpoint = {
                        'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(), 'best_val_mae': best_valid_mae, 'num_params': num_params
                    }
                    torch.save(checkpoint, os.path.join(args.save_dir, 'checkpoint.pt'))
                    print('Predicting on test data...', file=args.output_file, flush=True)
                    y_pred = test(model, device, test_loader)
                    print('Saving test submission file...', file=args.output_file, flush=True)
                    evaluator.save_test_submission({'y_pred': y_pred}, args.save_dir)

                not_improved = 0
            else:
                not_improved += 1
                if early_stop_enabled and not_improved == args.early_stop:
                    print(f"Have not improved for {not_improved} epoches.", file=args.output_file, flush=True)
                    break

            scheduler.step()
            print(f'Best validation MAE so far: {best_valid_mae}', file=args.output_file, flush=True)
    finally:
        # Release the TensorBoard writer and the log file even if the run
        # aborted part-way, so buffered events/log lines are flushed and the
        # handles are not leaked in a long-lived notebook kernel.
        if writer is not None:
            writer.close()
        output_file = getattr(args, 'output_file', None)
        if output_file is not None:
            output_file.close()

In [None]:
if __name__ == '__main__':
    # Run with the in-notebook defaults from MyNamespace; the command-line
    # parser path (p_args) is intentionally left disabled here.
    args = MyNamespace()
    main(args)
    print('finish')

data loading in dir: ./PTs/smiles_3d


100%|██████████| 10/10 [00:01<00:00,  7.18it/s]


train data: 8000 valid data: 1900
=====epoch: 1 2023-08-05 03:17:45
on training:


Epoch 1/2:   0%|          | 0/400 [00:00<?, ?it/s]

torch.Size([1000, 9])


Epoch 1/2:   1%|          | 4/400 [00:27<33:53,  5.13s/it, loss=24.55504]  

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:   3%|▎         | 13/400 [00:27<06:53,  1.07s/it, loss=21.07047]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:   6%|▌         | 23/400 [00:27<02:12,  2.85it/s, loss=16.76733]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:   8%|▊         | 32/400 [00:27<01:27,  4.19it/s, loss=15.21791]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  10%|█         | 42/400 [00:27<00:41,  8.71it/s, loss=10.66012]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  13%|█▎        | 52/400 [00:28<00:22, 15.69it/s, loss=5.89107] 

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  16%|█▌        | 62/400 [00:28<00:13, 24.19it/s, loss=3.13277]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  18%|█▊        | 72/400 [00:28<00:10, 31.89it/s, loss=1.74340]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  20%|██        | 81/400 [00:28<00:08, 36.89it/s, loss=1.02199]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  23%|██▎       | 91/400 [00:29<00:07, 41.02it/s, loss=0.26326]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  25%|██▌       | 101/400 [00:29<00:06, 44.14it/s, loss=0.38646]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  28%|██▊       | 111/400 [00:29<00:06, 47.65it/s, loss=0.52800]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  30%|███       | 122/400 [00:29<00:05, 50.13it/s, loss=0.53435]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  33%|███▎      | 132/400 [00:29<00:05, 49.42it/s, loss=0.57522]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  35%|███▌      | 141/400 [00:30<00:05, 48.79it/s, loss=0.44326]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  38%|███▊      | 151/400 [00:30<00:05, 47.67it/s, loss=0.76418]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  40%|████      | 162/400 [00:30<00:04, 48.64it/s, loss=0.38426]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  43%|████▎     | 171/400 [00:30<00:04, 47.78it/s, loss=0.64114]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  45%|████▌     | 181/400 [00:30<00:04, 46.98it/s, loss=0.74233]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  48%|████▊     | 191/400 [00:31<00:04, 47.17it/s, loss=0.34321]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  50%|█████     | 200/400 [00:31<00:04, 46.75it/s, loss=0.27895]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  52%|█████▏    | 206/400 [00:31<00:04, 45.96it/s, loss=0.31511]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  54%|█████▍    | 215/400 [00:31<00:05, 34.92it/s, loss=0.46061]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  56%|█████▋    | 225/400 [00:31<00:04, 40.45it/s, loss=0.30826]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  58%|█████▊    | 234/400 [00:32<00:03, 43.39it/s, loss=0.50893]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  61%|██████    | 243/400 [00:32<00:03, 43.99it/s, loss=0.28437]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  63%|██████▎   | 252/400 [00:32<00:03, 45.40it/s, loss=0.40750]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  65%|██████▌   | 261/400 [00:32<00:02, 46.67it/s, loss=0.46179]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  68%|██████▊   | 270/400 [00:32<00:02, 45.86it/s, loss=0.53908]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  70%|██████▉   | 279/400 [00:33<00:02, 45.76it/s, loss=0.37560]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  72%|███████▏  | 288/400 [00:33<00:02, 45.80it/s, loss=0.78611]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  74%|███████▍  | 297/400 [00:33<00:02, 46.83it/s, loss=0.31218]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  77%|███████▋  | 307/400 [00:33<00:02, 46.30it/s, loss=0.41175]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  79%|███████▉  | 317/400 [00:33<00:01, 46.25it/s, loss=0.51348]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  82%|████████▏ | 326/400 [00:34<00:01, 45.83it/s, loss=0.70273]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  83%|████████▎ | 333/400 [00:34<00:01, 45.10it/s, loss=0.69915]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  85%|████████▌ | 341/400 [00:34<00:01, 41.04it/s, loss=0.53976]

torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])
torch.Size([1000, 9])


Epoch 1/2:  87%|████████▋ | 349/400 [00:34<00:01, 42.52it/s, loss=0.50309]