In [1]:
# -*- coding: utf-8 -*-
"""
Created on Tue May 26 16:59:14 2020

@author: HQ Xie
"""
import os
import argparse
import time
import json
import torch
import random
import torch.nn as nn
import numpy as np
import logging
from dataset_BART import EurDataset, collate_data
from performance_BARTEN2BARTDE import performance
import sys 
sys.path.append("..") 
from models.BARTEN2BARTDE import DeepSC_BARTEN2BARTDE
from utils import SNR_to_noise, initNetParams, train_step_barten2bartde, val_step_barten2bartde, NoamOpt, EarlyStopping
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

def _str2bool(value):
    """Convert a command-line string into a bool.

    ``type=bool`` must not be used with argparse: it calls ``bool("False")``,
    which is ``True`` because the string is non-empty. This converter accepts
    the usual spellings instead and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got {!r}'.format(value))


# Command-line options of the training script.
parser = argparse.ArgumentParser()
parser.add_argument('--data-dir', default='../data/BART/train_data.pkl', type=str)
parser.add_argument('--checkpoint-path', default='../checkpoints/BARTEN2BARTDE/lr=1e-5', type=str)
parser.add_argument('--channel', default='TEST', type=str, help='Please choose AWGN, Rayleigh, and Rician')
parser.add_argument('--MAX-LENGTH', default=65, type=int)
parser.add_argument('--batch-size', default=32, type=int)
parser.add_argument('--epochs', default=100, type=int)
# Parsed with _str2bool so that "--resume False" really disables resuming.
parser.add_argument('--resume', default=False, type=_str2bool)
parser.add_argument('--Test_epochs', default=1, type=int)
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Root-logger setup: INFO and above are written to 'myLog.log', which is
# truncated on every run (filemode 'w'). Each record is prefixed with a
# 12-hour timestamp.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(
    format=log_format,
    datefmt='%m/%d %I:%M:%S %p',
    level=logging.INFO,
    filename='myLog.log',
    filemode='w',
)
# The root logger is shared by the whole script.
logger = logging.getLogger()

def setup_seed(seed):
    """Seed every random number generator the script uses.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (the default generator and all CUDA devices) with ``seed``, and sets
    ``torch.backends.cudnn.deterministic`` to ``True``.
    """
    torch.backends.cudnn.deterministic = True
    random.seed(seed)
    np.random.seed(seed)
    for seed_torch in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_torch(seed)


def validate(epoch, args, net, noise_std):
    """Evaluate ``net`` on the 'test' split and return the mean per-batch loss.

    Every batch is passed to ``val_step_barten2bartde`` as both source and
    target, under ``torch.no_grad()``. The mean loss is then handed to the
    module-level ``early_stopping`` object; if that object reports
    ``early_stop``, the process is terminated through ``sys.exit``.

    Relies on names defined at module level: ``device``, ``pad_idx``,
    ``criterion``, ``early_stopping`` and ``logger``.

    Args:
        epoch: epoch number, used only in the progress-bar caption.
        args: parsed options; ``batch_size`` and ``channel`` are read.
        net: model to evaluate.
        noise_std: noise level forwarded to the validation step.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    loader = DataLoader(EurDataset('test'), batch_size=args.batch_size, num_workers=0,
                        pin_memory=True, collate_fn=collate_data)
    progress = tqdm(loader)
    running = 0

    with torch.no_grad():
        for batch in progress:
            batch = batch.to(device)
            batch_loss = val_step_barten2bartde(net, batch, batch, noise_std, pad_idx, criterion, args.channel)
            running += batch_loss
            progress.set_description('Epoch: {}; Type: VAL; Loss: {:.5f}'.format(epoch, batch_loss))

    mean_loss = running / len(loader)

    # Let the early-stopping tracker see this epoch's result before returning.
    early_stopping(mean_loss, net, logger)
    if early_stopping.early_stop:
        sys.exit("Early stopping")

    return mean_loss


def train(epoch, args, net, noise_std):
    """Train ``net`` for one pass over the 'train' split and return the mean loss.

    Batches are drawn in shuffled order, and each one is passed to
    ``train_step_barten2bartde`` as both source and target.

    Relies on names defined at module level: ``device``, ``pad_idx``, ``opt``
    and ``criterion``.

    Args:
        epoch: epoch number, used only in the progress-bar caption.
        args: parsed options; ``batch_size`` and ``channel`` are read.
        net: model to train.
        noise_std: noise level forwarded to the training step.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    loader = DataLoader(EurDataset('train'), batch_size=args.batch_size, num_workers=0, shuffle=True,
                        pin_memory=True, collate_fn=collate_data)
    progress = tqdm(loader)
    running = 0

    for batch in progress:
        batch = batch.to(device)
        batch_loss = train_step_barten2bartde(net, batch, batch, noise_std, pad_idx, opt, criterion, args.channel)
        running += batch_loss
        progress.set_description('Epoch: {};  Type: Train; Loss: {:.5f}'.format(epoch, batch_loss))

    mean_loss = running / len(loader)
    return mean_loss



if __name__ == '__main__':
    # Script entry point: build the model, optionally resume from the newest
    # numbered checkpoint, then alternate training and validation epochs.
    # Several names assigned here (pad_idx, criterion, opt, early_stopping)
    # are read as module globals by train() and validate(), so they must stay
    # at module level.

    # ---- parameter setting ----
    setup_seed(42)
    # An explicit empty list is parsed, so only the parser defaults apply and
    # sys.argv is ignored (presumably because this ran as a notebook cell;
    # confirm before switching to parser.parse_args()).
    args = parser.parse_args(args=[])
    early_stopping = EarlyStopping(args.checkpoint_path + '/best')
    loss_curve = SummaryWriter("../logs/BARTEN2BARTDE/lr=1e-5", flush_secs=1)
    logger.info('The args: {}'.format(args))

    # ---- special token idx ----
    # NOTE(review): these look like the BART tokenizer's vocabulary size and
    # <s>/<pad>/</s> ids; verify against the tokenizer used to build the data.
    vocab_size = 50265
    start_idx = 0
    pad_idx = 1
    end_idx = 2

    # ---- define model ----
    deepsc_barten2bartde = DeepSC_BARTEN2BARTDE(vocab_size).to(device)

    # ---- load existing model ----
    if args.resume:
        model_paths = []
        # A missing checkpoint directory simply means there is nothing to resume.
        if os.path.isdir(args.checkpoint_path):
            for fn in os.listdir(args.checkpoint_path):
                if not fn.endswith('.pth'):
                    continue
                try:
                    # The checkpoint index is the text after the last '_'.
                    idx = int(os.path.splitext(fn)[0].split('_')[-1])
                except ValueError:
                    # Not an epoch-numbered checkpoint; ignore it.
                    continue
                model_paths.append((os.path.join(args.checkpoint_path, fn), idx))

        if model_paths:
            model_paths.sort(key=lambda x: x[1])  # ascending by checkpoint index
            model_path, _ = model_paths[-1]       # the highest index is the newest
            print(model_path)
            checkpoint = torch.load(model_path, map_location='cpu')
            deepsc_barten2bartde.load_state_dict(checkpoint, strict=False)
            print('model load!')
        else:
            print('no existed checkpoint')
    else:
        print('no existed checkpoint')

    # ---- define optimizer and loss function ----
    criterion = nn.CrossEntropyLoss(reduction='none')
    opt = torch.optim.Adam(deepsc_barten2bartde.parameters(), lr=1e-5, betas=(0.9, 0.98), eps=1e-8, weight_decay=5e-4)
    # opt = NoamOpt(768, 1, 20000, optimizer)

    # Best validation loss seen so far. Starting from infinity guarantees that
    # the first epoch always produces a checkpoint. (The name says "acc", but
    # the value tracked is a loss.)
    record_acc = float('inf')
    try:
        for epoch in range(1, 1 + args.epochs):

            # One noise level is drawn per epoch, between the values
            # SNR_to_noise returns for 5 and for 10.
            n_std = np.random.uniform(SNR_to_noise(5), SNR_to_noise(10))

            tra_acc = train(epoch, args, deepsc_barten2bartde, n_std)
            avg_acc = validate(epoch, args, deepsc_barten2bartde, n_std)

            # Save a numbered checkpoint whenever validation loss does not get worse.
            if record_acc >= avg_acc:
                record_acc = avg_acc
                os.makedirs(args.checkpoint_path, exist_ok=True)
                with open(args.checkpoint_path + '/checkpoint_{}.pth'.format(str(epoch).zfill(2)), 'wb') as f:
                    torch.save(deepsc_barten2bartde.state_dict(), f)

            # NOTE(review): the meaning of the [0] argument is defined by
            # performance(); only element 0 of each returned score is used below.
            bleu1, bleu2, bleu3, bleu4 = performance(args, [0], deepsc_barten2bartde, pad_idx, start_idx, end_idx)

            # record the results
            logger.info('Epoch: {}; Type: Train; Loss: {:.5f}'.format(epoch, tra_acc))
            logger.info('Epoch: {}; Type: Vaild; Loss: {:.5f}'.format(epoch, avg_acc))
            logger.info('Epoch: {}; Type: Vaild; BLEU score: {:.5f} {:.5f} {:.5f} {:.5f}'.format(epoch, bleu1[0], bleu2[0], bleu3[0], bleu4[0]))

            loss_curve.add_scalar("Train loss", tra_acc, epoch)
            loss_curve.add_scalar("Vaild loss", avg_acc, epoch)
            loss_curve.add_scalar("BLEU score", bleu1[0], epoch)
    finally:
        # validate() ends the run through sys.exit() on early stopping; close
        # the writer on every exit path so buffered events are flushed.
        loss_curve.close()


no existed checkpoint


Epoch: 1;  Type: Train; Loss: 3.49511: 100%|██████████| 3348/3348 [11:01<00:00,  5.06it/s]
Epoch: 1; Type: VAL; Loss: 1.81708: 100%|██████████| 372/372 [00:18<00:00, 20.11it/s]


Validation loss decreased (inf --> 3.062255).  Saving model ...


Epoch: 2;  Type: Train; Loss: 2.74690:  48%|████▊     | 1596/3348 [04:48<05:16,  5.54it/s]


KeyboardInterrupt: 