In [2]:
# -*- coding: utf-8 -*-
"""
Created on Tue May 26 16:59:14 2020

@author: HQ Xie
"""
import os
import argparse
import time
import json
import torch
import random
import torch.nn as nn
import numpy as np
from dataset_BART import EurDataset, collate_data
import sys 
sys.path.append("..") 
from models.BART2FC import DeepSC_BART2FC
from utils import SNR_to_noise, initNetParams, train_step_bart2fc, val_step_bart2fc, NoamOpt
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` does not work with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so ``--resume False`` used to
    switch resuming *on*. This converter interprets the usual spellings
    explicitly and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


# Command-line options of the training script.
parser = argparse.ArgumentParser()
parser.add_argument('--data-dir', default='../data/BART/train_data.pkl', type=str)
parser.add_argument('--checkpoint-path', default='../checkpoints/BART2FC/lr=1e-5', type=str)
# NOTE(review): the default 'TEST' is not one of the channels named in the help
# text; it is kept as-is because the channel handling lives outside this file.
parser.add_argument('--channel', default='TEST', type=str, help='Please choose AWGN, Rayleigh, and Rician')
parser.add_argument('--MAX-LENGTH', default=70, type=int)
parser.add_argument('--batch-size', default=32, type=int)
parser.add_argument('--epochs', default=400, type=int)
# Accepts both a bare ``--resume`` (-> True) and ``--resume <true|false>``.
parser.add_argument('--resume', default=False, type=_str2bool, nargs='?', const=True)
parser.add_argument('--Test_epochs', default=1, type=int)
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def setup_seed(seed):
    """Seed every random number generator used by this script.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    all CUDA devices) with the same ``seed``, then asks cuDNN to use
    deterministic algorithms.
    """
    # The generators are independent of each other, so seeding order is irrelevant.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True


def validate(epoch, args, net):
    """Run one pass over the 'test' split and return the mean per-batch loss.

    Args:
        epoch: epoch number, used only for the progress-bar label.
        args: parsed options; ``batch_size`` and ``channel`` are read.
        net: the model; it is switched to eval mode and left that way.

    Returns:
        The sum of the values returned by ``val_step_bart2fc`` divided by the
        number of batches.

    Relies on the module-level globals ``device``, ``pad_idx`` and
    ``criterion``.
    """
    loader = DataLoader(
        EurDataset('test'),
        batch_size=args.batch_size,
        num_workers=0,
        pin_memory=True,
        collate_fn=collate_data,
    )
    net.eval()
    progress = tqdm(loader)
    running_loss = 0

    # No gradients are needed while evaluating.
    with torch.no_grad():
        for batch in progress:
            batch = batch.to(device)
            # The batch is passed as both input and target; 0.1 sits in the
            # argument slot where train() passes its sampled noise level.
            batch_loss = val_step_bart2fc(net, batch, batch, 0.1, pad_idx, criterion, args.channel)
            running_loss += batch_loss
            progress.set_description('Epoch: {}; Type: VAL; Loss: {:.5f}'.format(epoch, batch_loss))

    return running_loss / len(loader)


def train(epoch, args, net):
    """Run one training pass over the shuffled 'train' split.

    Args:
        epoch: epoch number, used only for the progress-bar label.
        args: parsed options; ``batch_size`` and ``channel`` are read.
        net: the model handed to ``train_step_bart2fc``.

    Returns nothing. Relies on the module-level globals ``device``,
    ``pad_idx``, ``opt`` and ``criterion``.
    """
    loader = DataLoader(
        EurDataset('train'),
        batch_size=args.batch_size,
        num_workers=0,
        shuffle=True,
        pin_memory=True,
        collate_fn=collate_data,
    )
    progress = tqdm(loader)

    # One value is drawn per call (i.e. per epoch) and reused for every batch,
    # uniformly between SNR_to_noise(5) and SNR_to_noise(10).
    lower = SNR_to_noise(5)
    upper = SNR_to_noise(10)
    epoch_noise = np.random.uniform(lower, upper, size=1)[0]

    for batch in progress:
        batch = batch.to(device)
        # The batch is passed as both input and target.
        batch_loss = train_step_bart2fc(
            net, batch, batch, epoch_noise, pad_idx, opt, criterion, args.channel
        )
        progress.set_description('Epoch: {};  Type: Train; Loss: {:.5f}'.format(epoch, batch_loss))

if __name__ == '__main__':
    # Script entry point: build the model, optionally restore the newest
    # checkpoint, then alternate training and validation epochs and save
    # checkpoints.

    setup_seed(42)
    # NOTE(review): passing an explicit empty list makes argparse ignore
    # sys.argv, so the defaults above are always used (handy in a notebook).
    # Switch to parser.parse_args() to honour real command-line arguments.
    args = parser.parse_args(args=[])

    # Special token ids. train() and validate() read pad_idx as a global.
    start_idx = 0
    pad_idx = 1
    end_idx = 2

    vocab_size = 50265
    deepsc_bart2fc = DeepSC_BART2FC(vocab_size).to(device)

    # Collect (path, index) pairs for files named like '<prefix>_<index>.pth'.
    # A missing directory or an unparsable file name is skipped instead of
    # crashing the run.
    model_paths = []
    if args.resume and os.path.isdir(args.checkpoint_path):
        for fn in os.listdir(args.checkpoint_path):
            if not fn.endswith('.pth'):
                continue
            try:
                idx = int(os.path.splitext(fn)[0].split('_')[-1])
            except ValueError:
                continue
            model_paths.append((os.path.join(args.checkpoint_path, fn), idx))

    if model_paths:
        # Restore the checkpoint with the highest index.
        model_paths.sort(key=lambda x: x[1])
        model_path, _ = model_paths[-1]
        print(model_path)
        checkpoint = torch.load(model_path, map_location='cpu')
        deepsc_bart2fc.load_state_dict(checkpoint, strict=False)
        print('model load!')
    else:
        # Nothing to resume from (or resuming not requested): Xavier-initialise
        # the weight matrices of these three sub-modules only; every other
        # parameter keeps whatever DeepSC_BART2FC's constructor gave it.
        print('no existed checkpoint')
        fresh_modules = (
            deepsc_bart2fc.quantization,
            deepsc_bart2fc.dequantization,
            deepsc_bart2fc.dense,
        )
        for module in fresh_modules:
            for p in module.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)

    # criterion and opt are read as globals by train() / validate().
    criterion = nn.CrossEntropyLoss(reduction='none')
    opt = torch.optim.Adam(deepsc_bart2fc.parameters(), lr=1e-5, betas=(0.9, 0.98), eps=1e-8, weight_decay=5e-4)

    # Despite the names, record_acc / avg_acc hold validation *loss* (lower is
    # better). Starting at +inf means the first epoch always counts as a best.
    record_acc = float('inf')
    # NOTE(review): range(1, args.epochs) runs args.epochs - 1 epochs.
    for epoch in range(1, args.epochs):

        train(epoch, args, deepsc_bart2fc)
        avg_acc = validate(epoch, args, deepsc_bart2fc)

        # Update the best value only on a real improvement. Previously a
        # periodic save also overwrote it, possibly with a worse loss, which
        # corrupted later "best so far" comparisons.
        improved = avg_acc <= record_acc
        if improved:
            record_acc = avg_acc

        # Save on every epoch where (epoch + 1) is a multiple of 5, and
        # whenever the loss improved.
        if (epoch + 1) % 5 == 0 or improved:
            os.makedirs(args.checkpoint_path, exist_ok=True)
            checkpoint_file = os.path.join(
                args.checkpoint_path, 'checkpoint_{}.pth'.format(str(epoch).zfill(2))
            )
            with open(checkpoint_file, 'wb') as f:
                torch.save(deepsc_bart2fc.state_dict(), f)


no existed checkpoint


Epoch: 1;  Type: Train; Loss: 5.19452:   1%|          | 30/3348 [00:10<19:18,  2.86it/s] 


KeyboardInterrupt: 