In [1]:
import os
import argparse
import time
import json
import torch
import random
import torch.nn as nn
import numpy as np
from dataset_BERT import EurDataset, collate_data
import sys 
sys.path.append("..") 
from models.BERT2FC import DeepSC_BERT2FC
from utils import SNR_to_noise, initNetParams, train_step_bart2fc, val_step_bart2fc, NoamOpt, EarlyStopping
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

parser = argparse.ArgumentParser()
parser.add_argument('--data-dir', default='../data/BERT/train_data.pkl', type=str)
parser.add_argument('--checkpoint-path', default='../checkpoints/BERT2FC/lr=1e-5', type=str)
parser.add_argument('--channel', default='TEST', type=str, help='Please choose AWGN, Rayleigh, and Rician')
parser.add_argument('--MAX-LENGTH', default=70, type=int)
parser.add_argument('--batch-size', default=32, type=int)
parser.add_argument('--epochs', default=400, type=int)
parser.add_argument('--resume', default=True, type=bool)
parser.add_argument('--Test_epochs', default=1, type=int)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


def validate(epoch, args, net):
    test_eur = EurDataset('test')
    test_iterator = DataLoader(test_eur, batch_size=args.batch_size, num_workers=0,
                               pin_memory=True, collate_fn=collate_data)
    net.eval()
    pbar = tqdm(test_iterator)
    total = 0

    with torch.no_grad():
        for sents in pbar:

            sents = sents.to(device)
            loss = val_step_bart2fc(net, sents, sents, 0.1, pad_idx, criterion, args.channel)
            total += loss
            pbar.set_description('Epoch: {}; Type: VAL; Loss: {:.5f}'.format(epoch, loss))
    
    early_stopping(total / len(test_iterator), net)
    if early_stopping.early_stop:
        sys.exit("Early stopping")   

    #val_loss.add_scalar("VAL loss", loss, epoch)
    return total / len(test_iterator)


def train(epoch, args, net):
    train_eur = EurDataset('train')
    train_iterator = DataLoader(train_eur, batch_size=args.batch_size, num_workers=0, shuffle=True,
                                pin_memory=True, collate_fn=collate_data)
    pbar = tqdm(train_iterator)

    noise_std = np.random.uniform(SNR_to_noise(5), SNR_to_noise(10), size=(1))

    for sents in pbar:
        sents = sents.to(device)
        loss = train_step_bart2fc(net, sents, sents, noise_std[0], pad_idx, opt, criterion, args.channel)
        pbar.set_description('Epoch: {};  Type: Train; Loss: {:.5f}'.format(epoch, loss))
            
    #tra_loss.add_scalar("Train loss", loss, epoch)


if __name__ == '__main__':

    #val_loss = SummaryWriter("logs/BART2FC/origin")
    #tra_loss = SummaryWriter("logs/BART2FC/origin")
    setup_seed(42)
    args = parser.parse_args(args=[])

    start_idx = 101
    pad_idx = 0
    end_idx = 102

    """ define optimizer and loss function """

    vocab_size = 30522
    deepsc_bert2fc = DeepSC_BERT2FC(vocab_size).to(device)
    early_stopping = EarlyStopping(args.checkpoint_path + '/best')

    """ load existed model"""
    if args.resume:
        model_paths = []
        for fn in os.listdir(args.checkpoint_path):
            if not fn.endswith('.pth'): continue
            idx = int(os.path.splitext(fn)[0].split('_')[-1])  # read the idx of image
            model_paths.append((os.path.join(args.checkpoint_path, fn), idx))

        model_paths.sort(key=lambda x: x[1])  # sort the image by the idx
        model_path, _ = model_paths[-1]
        print(model_path)
        checkpoint = torch.load(model_path, map_location='cpu')
        deepsc_bert2fc.load_state_dict(checkpoint,strict=False)
        print('model load!')
    else:
        print('no existed checkpoint')
        for p in deepsc_bert2fc.quantization.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p) 
        for p in deepsc_bert2fc.dequantization.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p) 
        for p in deepsc_bert2fc.dense.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p) 
 
    criterion = nn.CrossEntropyLoss(reduction='none')
    opt = torch.optim.Adam(deepsc_bert2fc.parameters(), lr=1e-7, betas=(0.9, 0.98), eps=1e-8, weight_decay=5e-4)
    # opt = NoamOpt(768, 1, 20000, optimizer)

    record_acc = 10
    for epoch in range(28 , 28 + args.epochs):

        start = time.time()
        train(epoch, args, deepsc_bert2fc)
        avg_acc = validate(epoch, args, deepsc_bert2fc)

        if record_acc >= avg_acc:
            record_acc = avg_acc
            if not os.path.exists(args.checkpoint_path):
                os.makedirs(args.checkpoint_path)
            with open(args.checkpoint_path + '/checkpoint_{}.pth'.format(str(epoch).zfill(2)), 'wb') as f:
                torch.save(deepsc_bert2fc.state_dict(), f)
                # bleu_score1, bleu_score2, bleu_score3, bleu_score4 = \
                #     performance(args, [0], deepsc_vqvae, token_to_idx, pad_idx, start_idx, end_idx)
                # print(bleu_score1, bleu_score2, bleu_score3, bleu_score4)

    #val_loss.close()
    #tra_loss.close()


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


../checkpoints/BERT2FC/lr=1e-5\checkpoint_27.pth
model load!


Epoch: 28;  Type: Train; Loss: 0.26864: 100%|██████████| 3348/3348 [08:52<00:00,  6.28it/s]
Epoch: 28; Type: VAL; Loss: 0.19291: 100%|██████████| 372/372 [00:18<00:00, 19.69it/s]


Validation loss decreased (inf --> 0.322283).  Saving model ...


Epoch: 29;  Type: Train; Loss: 0.16667: 100%|██████████| 3348/3348 [09:12<00:00,  6.06it/s]
Epoch: 29; Type: VAL; Loss: 0.19282: 100%|██████████| 372/372 [00:18<00:00, 19.90it/s]


Validation loss decreased (0.322283 --> 0.322146).  Saving model ...


Epoch: 30;  Type: Train; Loss: 0.38662: 100%|██████████| 3348/3348 [09:12<00:00,  6.06it/s]
Epoch: 30; Type: VAL; Loss: 0.19272: 100%|██████████| 372/372 [00:19<00:00, 19.50it/s]


Validation loss decreased (0.322146 --> 0.322065).  Saving model ...


Epoch: 31;  Type: Train; Loss: 0.40906: 100%|██████████| 3348/3348 [07:52<00:00,  7.09it/s]
Epoch: 31; Type: VAL; Loss: 0.19264: 100%|██████████| 372/372 [00:12<00:00, 28.65it/s]


Validation loss decreased (0.322065 --> 0.321935).  Saving model ...


Epoch: 32;  Type: Train; Loss: 0.28624: 100%|██████████| 3348/3348 [07:31<00:00,  7.41it/s]
Epoch: 32; Type: VAL; Loss: 0.19256: 100%|██████████| 372/372 [00:13<00:00, 28.59it/s]


Validation loss decreased (0.321935 --> 0.321892).  Saving model ...


Epoch: 33;  Type: Train; Loss: 0.37825: 100%|██████████| 3348/3348 [07:31<00:00,  7.41it/s]
Epoch: 33; Type: VAL; Loss: 0.19254: 100%|██████████| 372/372 [00:12<00:00, 28.69it/s]


Validation loss decreased (0.321892 --> 0.321816).  Saving model ...


Epoch: 34;  Type: Train; Loss: 0.38490: 100%|██████████| 3348/3348 [07:31<00:00,  7.41it/s]
Epoch: 34; Type: VAL; Loss: 0.19252: 100%|██████████| 372/372 [00:12<00:00, 28.64it/s]


Validation loss decreased (0.321816 --> 0.321750).  Saving model ...


Epoch: 35;  Type: Train; Loss: 0.42097: 100%|██████████| 3348/3348 [07:31<00:00,  7.42it/s]
Epoch: 35; Type: VAL; Loss: 0.19247: 100%|██████████| 372/372 [00:12<00:00, 28.68it/s]


Validation loss decreased (0.321750 --> 0.321666).  Saving model ...


Epoch: 36;  Type: Train; Loss: 0.37162: 100%|██████████| 3348/3348 [07:30<00:00,  7.43it/s]
Epoch: 36; Type: VAL; Loss: 0.19243: 100%|██████████| 372/372 [00:12<00:00, 28.90it/s]


Validation loss decreased (0.321666 --> 0.321629).  Saving model ...


Epoch: 37;  Type: Train; Loss: 0.41026: 100%|██████████| 3348/3348 [07:29<00:00,  7.44it/s]
Epoch: 37; Type: VAL; Loss: 0.19227: 100%|██████████| 372/372 [00:12<00:00, 29.01it/s]


Validation loss decreased (0.321629 --> 0.321570).  Saving model ...


Epoch: 38;  Type: Train; Loss: 0.43428: 100%|██████████| 3348/3348 [07:29<00:00,  7.45it/s]
Epoch: 38; Type: VAL; Loss: 0.19238: 100%|██████████| 372/372 [00:12<00:00, 29.01it/s]


Validation loss decreased (0.321570 --> 0.321517).  Saving model ...


Epoch: 39;  Type: Train; Loss: 0.18440: 100%|██████████| 3348/3348 [07:29<00:00,  7.45it/s]
Epoch: 39; Type: VAL; Loss: 0.19227: 100%|██████████| 372/372 [00:12<00:00, 29.07it/s]


Validation loss decreased (0.321517 --> 0.321433).  Saving model ...


Epoch: 40;  Type: Train; Loss: 0.38915: 100%|██████████| 3348/3348 [07:29<00:00,  7.45it/s]
Epoch: 40; Type: VAL; Loss: 0.19224: 100%|██████████| 372/372 [00:12<00:00, 29.06it/s]


Validation loss decreased (0.321433 --> 0.321385).  Saving model ...


Epoch: 41;  Type: Train; Loss: 0.35838: 100%|██████████| 3348/3348 [07:29<00:00,  7.45it/s]
Epoch: 41; Type: VAL; Loss: 0.19223: 100%|██████████| 372/372 [00:12<00:00, 29.10it/s]


Validation loss decreased (0.321385 --> 0.321337).  Saving model ...


Epoch: 42;  Type: Train; Loss: 0.36172: 100%|██████████| 3348/3348 [07:28<00:00,  7.46it/s]
Epoch: 42; Type: VAL; Loss: 0.19231: 100%|██████████| 372/372 [00:12<00:00, 29.07it/s]


Validation loss decreased (0.321337 --> 0.321291).  Saving model ...


Epoch: 43;  Type: Train; Loss: 0.27320: 100%|██████████| 3348/3348 [07:29<00:00,  7.45it/s]
Epoch: 43; Type: VAL; Loss: 0.19216: 100%|██████████| 372/372 [00:12<00:00, 29.05it/s]


Validation loss decreased (0.321291 --> 0.321205).  Saving model ...


Epoch: 44;  Type: Train; Loss: 0.32662: 100%|██████████| 3348/3348 [07:28<00:00,  7.46it/s]
Epoch: 44; Type: VAL; Loss: 0.19215: 100%|██████████| 372/372 [00:12<00:00, 29.11it/s]


Validation loss decreased (0.321205 --> 0.321176).  Saving model ...


Epoch: 45;  Type: Train; Loss: 0.22183: 100%|██████████| 3348/3348 [07:28<00:00,  7.46it/s]
Epoch: 45; Type: VAL; Loss: 0.19202: 100%|██████████| 372/372 [00:12<00:00, 29.12it/s]


Validation loss decreased (0.321176 --> 0.321105).  Saving model ...


Epoch: 46;  Type: Train; Loss: 0.23425: 100%|██████████| 3348/3348 [07:28<00:00,  7.46it/s]
Epoch: 46; Type: VAL; Loss: 0.19208: 100%|██████████| 372/372 [00:12<00:00, 29.08it/s]


Validation loss decreased (0.321105 --> 0.321053).  Saving model ...


Epoch: 47;  Type: Train; Loss: 0.32856: 100%|██████████| 3348/3348 [07:28<00:00,  7.46it/s]
Epoch: 47; Type: VAL; Loss: 0.19209: 100%|██████████| 372/372 [00:12<00:00, 29.08it/s]


Validation loss decreased (0.321053 --> 0.320992).  Saving model ...


Epoch: 48;  Type: Train; Loss: 0.48199: 100%|██████████| 3348/3348 [07:28<00:00,  7.47it/s]
Epoch: 48; Type: VAL; Loss: 0.19200: 100%|██████████| 372/372 [00:12<00:00, 29.06it/s]


Validation loss decreased (0.320992 --> 0.320915).  Saving model ...


Epoch: 49;  Type: Train; Loss: 0.44525: 100%|██████████| 3348/3348 [07:28<00:00,  7.46it/s]
Epoch: 49; Type: VAL; Loss: 0.19188: 100%|██████████| 372/372 [00:12<00:00, 29.11it/s]


Validation loss decreased (0.320915 --> 0.320893).  Saving model ...


Epoch: 50;  Type: Train; Loss: 0.24014: 100%|██████████| 3348/3348 [07:28<00:00,  7.47it/s]
Epoch: 50; Type: VAL; Loss: 0.19196: 100%|██████████| 372/372 [00:12<00:00, 29.09it/s]


Validation loss decreased (0.320893 --> 0.320855).  Saving model ...


Epoch: 51;  Type: Train; Loss: 0.45339: 100%|██████████| 3348/3348 [07:28<00:00,  7.47it/s]
Epoch: 51; Type: VAL; Loss: 0.19182: 100%|██████████| 372/372 [00:12<00:00, 29.10it/s]


Validation loss decreased (0.320855 --> 0.320790).  Saving model ...


Epoch: 52;  Type: Train; Loss: 0.34182: 100%|██████████| 3348/3348 [07:28<00:00,  7.47it/s]
Epoch: 52; Type: VAL; Loss: 0.19179: 100%|██████████| 372/372 [00:12<00:00, 29.11it/s]


Validation loss decreased (0.320790 --> 0.320746).  Saving model ...


Epoch: 53;  Type: Train; Loss: 0.26443: 100%|██████████| 3348/3348 [07:28<00:00,  7.47it/s]
Epoch: 53; Type: VAL; Loss: 0.19176: 100%|██████████| 372/372 [00:12<00:00, 29.17it/s]


Validation loss decreased (0.320746 --> 0.320693).  Saving model ...


Epoch: 54;  Type: Train; Loss: 0.58224: 100%|██████████| 3348/3348 [07:28<00:00,  7.47it/s]
Epoch: 54; Type: VAL; Loss: 0.19165: 100%|██████████| 372/372 [00:12<00:00, 29.14it/s]


Validation loss decreased (0.320693 --> 0.320631).  Saving model ...


Epoch: 55;  Type: Train; Loss: 0.35457: 100%|██████████| 3348/3348 [07:28<00:00,  7.47it/s]
Epoch: 55; Type: VAL; Loss: 0.19157: 100%|██████████| 372/372 [00:12<00:00, 29.06it/s]


Validation loss decreased (0.320631 --> 0.320571).  Saving model ...


Epoch: 56;  Type: Train; Loss: 0.39236: 100%|██████████| 3348/3348 [07:28<00:00,  7.47it/s]
Epoch: 56; Type: VAL; Loss: 0.19166: 100%|██████████| 372/372 [00:12<00:00, 29.06it/s]


Validation loss decreased (0.320571 --> 0.320525).  Saving model ...


Epoch: 57;  Type: Train; Loss: 0.40807: 100%|██████████| 3348/3348 [07:28<00:00,  7.46it/s]
Epoch: 57; Type: VAL; Loss: 0.19162: 100%|██████████| 372/372 [00:12<00:00, 29.05it/s]


Validation loss decreased (0.320525 --> 0.320476).  Saving model ...


Epoch: 58;  Type: Train; Loss: 0.19867: 100%|██████████| 3348/3348 [07:28<00:00,  7.46it/s]
Epoch: 58; Type: VAL; Loss: 0.19162: 100%|██████████| 372/372 [00:12<00:00, 29.05it/s]


Validation loss decreased (0.320476 --> 0.320425).  Saving model ...


Epoch: 59;  Type: Train; Loss: 0.35364:  15%|█▌        | 517/3348 [01:46<55:06,  1.17s/it]  