In [2]:
import os
import argparse
import time
import json
import torch
import random
import torch.nn as nn
import numpy as np
from dataset_BERT import EurDataset, collate_data
import sys 
sys.path.append("..") 
from models.BERT2FC import DeepSC_BERT2FC
from utils import SNR_to_noise, initNetParams, train_step_bart2fc, val_step_bart2fc, NoamOpt, EarlyStopping
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from performance import performance

def str2bool(value):
    """Convert a command-line string to a bool for argparse.

    ``type=bool`` is wrong for argparse because ``bool('False')`` is True
    (every non-empty string is truthy). This accepts the usual spellings
    instead: true/false, t/f, yes/no, y/n, 1/0 (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: for any other string, so argparse
            reports a proper usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got {!r}'.format(value))


parser = argparse.ArgumentParser()
parser.add_argument('--data-dir', default='../data/BERT/train_data.pkl', type=str)
parser.add_argument('--checkpoint-path', default='../checkpoints/BERT2FC/lr=1e-5', type=str)
parser.add_argument('--channel', default='TEST', type=str, help='Please choose AWGN, Rayleigh, and Rician')
parser.add_argument('--MAX-LENGTH', default=70, type=int)
parser.add_argument('--batch-size', default=32, type=int)
parser.add_argument('--epochs', default=400, type=int)
# Resume from the highest-numbered checkpoint in --checkpoint-path.
parser.add_argument('--resume', default=True, type=str2bool)
parser.add_argument('--Test_epochs', default=1, type=int)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def setup_seed(seed):
    """Seed all RNGs used by this script for reproducible runs.

    Seeds, in order: torch (CPU), torch (all CUDA devices), NumPy's global
    RNG, and Python's ``random`` module. It then asks cuDNN to select
    deterministic algorithms.
    """
    seeders = (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed)
    for seed_fn in seeders:
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True


def validate(epoch, args, net):
    """Evaluate *net* on the 'test' split and update early stopping.

    Relies on module-level globals created in the ``__main__`` section:
    ``device``, ``pad_idx``, ``criterion`` and ``early_stopping``.

    Every batch is passed as both source and target to ``val_step_bart2fc``,
    together with the constant 0.1. That argument sits in the position
    ``train()`` uses for the noise std.

    Returns:
        The mean of the per-batch losses returned by ``val_step_bart2fc``.

    Raises:
        ValueError: if the 'test' split yields no batches.
        SystemExit: when ``early_stopping`` reports that training should stop.
    """
    test_eur = EurDataset('test')
    test_iterator = DataLoader(test_eur, batch_size=args.batch_size, num_workers=0,
                               pin_memory=True, collate_fn=collate_data)
    num_batches = len(test_iterator)
    if num_batches == 0:
        # Without this check the mean below fails with an uninformative ZeroDivisionError.
        raise ValueError("validation split 'test' is empty; cannot compute a mean loss")

    net.eval()
    pbar = tqdm(test_iterator)
    total = 0

    with torch.no_grad():
        for sents in pbar:
            sents = sents.to(device)
            loss = val_step_bart2fc(net, sents, sents, 0.1, pad_idx, criterion, args.channel)
            total += loss
            pbar.set_description('Epoch: {}; Type: VAL; Loss: {:.5f}'.format(epoch, loss))

    avg_loss = total / num_batches
    early_stopping(avg_loss, net)
    if early_stopping.early_stop:
        sys.exit("Early stopping")
    return avg_loss


def train(epoch, args, net):
    """Run one training epoch of *net* over the shuffled 'train' split.

    Relies on module-level globals created in the ``__main__`` section:
    ``device``, ``pad_idx``, ``opt`` and ``criterion``.

    One noise std is drawn per epoch, uniformly between ``SNR_to_noise(5)``
    and ``SNR_to_noise(10)``. It is used for every batch of that epoch.
    Every batch is passed as both source and target to ``train_step_bart2fc``.
    """
    train_eur = EurDataset('train')
    train_iterator = DataLoader(train_eur, batch_size=args.batch_size, num_workers=0, shuffle=True,
                                pin_memory=True, collate_fn=collate_data)
    # validate() puts the model in eval mode. Switch back explicitly so
    # training-mode layers (dropout etc.) are active. This is harmless if
    # train_step_bart2fc also does it internally (not visible here).
    net.train()
    pbar = tqdm(train_iterator)

    # np.random.uniform documents low > high as undefined behaviour, so order the
    # bounds explicitly rather than relying on how SNR_to_noise scales with SNR.
    low, high = sorted((SNR_to_noise(5), SNR_to_noise(10)))
    noise_std = np.random.uniform(low, high)

    for sents in pbar:
        sents = sents.to(device)
        loss = train_step_bart2fc(net, sents, sents, noise_std, pad_idx, opt, criterion, args.channel)
        pbar.set_description('Epoch: {};  Type: Train; Loss: {:.5f}'.format(epoch, loss))


def find_latest_checkpoint(checkpoint_dir):
    """Return ``(path, epoch_idx)`` for the highest-numbered checkpoint in *checkpoint_dir*.

    Checkpoints are ``.pth`` files whose stem ends in ``_<int>``, such as the
    ``checkpoint_12.pth`` files written by the training loop below. Other
    ``.pth`` files (e.g. ``best.pth``) and non-``.pth`` files are ignored.

    Returns:
        ``(path, epoch_idx)``, or ``None`` if the directory does not exist
        or contains no matching file.
    """
    if not os.path.isdir(checkpoint_dir):
        return None
    candidates = []
    for fn in os.listdir(checkpoint_dir):
        stem, ext = os.path.splitext(fn)
        if ext != '.pth':
            continue
        try:
            idx = int(stem.split('_')[-1])
        except ValueError:
            continue
        candidates.append((os.path.join(checkpoint_dir, fn), idx))
    if not candidates:
        return None
    return max(candidates, key=lambda item: item[1])


if __name__ == '__main__':
    setup_seed(42)
    # args=[] ignores sys.argv so the cell also runs under Jupyter, whose kernel
    # passes its own command-line arguments. The parser defaults are therefore used.
    args = parser.parse_args(args=[])

    # Special token ids, presumably those of the bert-base-uncased vocabulary
    # ([CLS]=101, [PAD]=0, [SEP]=102) -- confirm against dataset_BERT.
    start_idx = 101
    pad_idx = 0
    end_idx = 102

    # Model and early stopping.
    vocab_size = 30522
    deepsc_bert2fc = DeepSC_BERT2FC(vocab_size).to(device)
    early_stopping = EarlyStopping(args.checkpoint_path + '/best')

    # Resume from the latest checkpoint if requested and one exists; otherwise
    # Xavier-initialise the matrix-shaped weights of the trainable heads.
    latest = find_latest_checkpoint(args.checkpoint_path) if args.resume else None
    if latest is not None:
        model_path, last_epoch = latest
        print(model_path)
        checkpoint = torch.load(model_path, map_location='cpu')
        # strict=False tolerates key mismatches; report them so they are not silent.
        load_result = deepsc_bert2fc.load_state_dict(checkpoint, strict=False)
        if load_result.missing_keys or load_result.unexpected_keys:
            print('missing keys:', load_result.missing_keys)
            print('unexpected keys:', load_result.unexpected_keys)
        print('model load!')
        start_epoch = last_epoch + 1
    else:
        print('no existed checkpoint')
        for module in (deepsc_bert2fc.quantization, deepsc_bert2fc.dequantization, deepsc_bert2fc.dense):
            for p in module.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)
        start_epoch = 0

    # Loss and optimizer. Only model weights are checkpointed, so Adam's
    # moment estimates start fresh on every resume.
    criterion = nn.CrossEntropyLoss(reduction='none')
    opt = torch.optim.Adam(deepsc_bert2fc.parameters(), lr=1e-5, betas=(0.9, 0.98), eps=1e-8, weight_decay=5e-4)

    best_val_loss = float('inf')
    for epoch in range(start_epoch, start_epoch + args.epochs):
        train(epoch, args, deepsc_bert2fc)
        # validate() terminates the process via sys.exit when early stopping triggers.
        val_loss = validate(epoch, args, deepsc_bert2fc)

        if val_loss <= best_val_loss:
            best_val_loss = val_loss
            os.makedirs(args.checkpoint_path, exist_ok=True)
            ckpt_file = os.path.join(args.checkpoint_path,
                                     'checkpoint_{}.pth'.format(str(epoch).zfill(2)))
            torch.save(deepsc_bert2fc.state_dict(), ckpt_file)


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


../checkpoints/BERT2FC/lr=1e-5\checkpoint_12.pth
model load!


Epoch: 13;  Type: Train; Loss: 0.33088: 100%|██████████| 3348/3348 [09:54<00:00,  5.64it/s]
Epoch: 13; Type: VAL; Loss: 0.21232: 100%|██████████| 372/372 [00:19<00:00, 18.74it/s]


Validation loss decreased (inf --> 0.375567).  Saving model ...


Epoch: 14;  Type: Train; Loss: 0.20143: 100%|██████████| 3348/3348 [09:43<00:00,  5.74it/s]
Epoch: 14; Type: VAL; Loss: 0.20796: 100%|██████████| 372/372 [00:19<00:00, 18.77it/s]


EarlyStopping counter: 1 out of 10


Epoch: 15;  Type: Train; Loss: 0.42909: 100%|██████████| 3348/3348 [09:43<00:00,  5.74it/s]
Epoch: 15; Type: VAL; Loss: 0.20442: 100%|██████████| 372/372 [00:17<00:00, 21.77it/s]


EarlyStopping counter: 2 out of 10


Epoch: 16;  Type: Train; Loss: 0.45253: 100%|██████████| 3348/3348 [10:41<00:00,  5.22it/s] 
Epoch: 16; Type: VAL; Loss: 0.20224: 100%|██████████| 372/372 [00:19<00:00, 18.88it/s]


EarlyStopping counter: 3 out of 10


Epoch: 17;  Type: Train; Loss: 0.30617: 100%|██████████| 3348/3348 [11:39<00:00,  4.79it/s]
Epoch: 17; Type: VAL; Loss: 0.19939: 100%|██████████| 372/372 [00:37<00:00, 10.01it/s]


EarlyStopping counter: 4 out of 10


Epoch: 18;  Type: Train; Loss: 0.39619: 100%|██████████| 3348/3348 [16:09<00:00,  3.45it/s]
Epoch: 18; Type: VAL; Loss: 0.19626: 100%|██████████| 372/372 [00:18<00:00, 19.91it/s]


EarlyStopping counter: 5 out of 10


Epoch: 19;  Type: Train; Loss: 0.39865: 100%|██████████| 3348/3348 [10:54<00:00,  5.12it/s]  
Epoch: 19; Type: VAL; Loss: 0.19696: 100%|██████████| 372/372 [00:29<00:00, 12.46it/s]


EarlyStopping counter: 6 out of 10


Epoch: 20;  Type: Train; Loss: 0.34822:   0%|          | 5/3348 [00:06<1:05:27,  1.17s/it]