In [1]:
from src.models import FtModel
from src.config import parse_args
from src.dataset import FinetuneDataset
from src.utils import *

from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler, Subset, WeightedRandomSampler
import os
from tqdm import tqdm,trange
import pandas as pd
import torch

def inference():
    """Run the fine-tuned classifier over the test set and write a TSV of predictions.

    Settings come from ``parse_args()`` and are then overridden by the values
    hard-coded below (GPU id, BERT sequence length, checkpoint path, EMA
    options, result file name and pretrained BERT directory).

    Side effects:
        * Sets the ``CUDA_VISIBLE_DEVICES`` environment variable.
        * Loads the checkpoint at ``args.ckpt_file``.
        * Writes ``data/<args.result_file>`` containing a ``tweet_id<TAB>label``
          header followed by one row per prediction.
        * Prints progress information to stdout.

    Returns:
        None.
    """
    args = parse_args()
    args.gpu_ids = '1'
    args.bert_seq_length = 128
    args.ckpt_file = "data/checkpoint/model_epoch_3_f1_0.9058_1000.bin"
    args.ema_start = 0
    args.ema_decay = 0.99
    args.result_file = "result.tsv"
    args.bert_dir = 'digitalepidemiologylab/covid-twitter-bert'

    # Restrict the visible GPUs before the device helper runs.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids
    setup_device(args)
    setup_seed(args)

    test_dataset = FinetuneDataset(args, args.test_path, True)
    test_sampler = SequentialSampler(test_dataset)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=args.val_batch_size,
                                 sampler=test_sampler,
                                 drop_last=False,
                                 pin_memory=True)

    # NOTE: len() of a DataLoader is the number of batches, not the number of samples.
    print('The test data length: ',len(test_dataloader))

    model = FtModel(args)
    model = model.to(args.device)
    if args.distributed_train:
        model = torch.nn.parallel.DataParallel(model)

    # map_location keeps loading independent of the device the checkpoint was
    # saved from (e.g. a GPU index that is not visible in this process).
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpoint = torch.load(args.ckpt_file, map_location=args.device)
    model.load_state_dict(ckpoint['model_state_dict'])
    print("The epoch {} and the best mean f1 {:.4f} of the validation set.".format(ckpoint['epoch'],ckpoint['mean_f1']))

    if args.ema_start >= 0:
        # Restore the EMA state stored in the checkpoint and switch the model
        # to the shadow weights before evaluating.
        ema = EMA(model, args.ema_decay)
        ema.resume(ckpoint['shadow'][0], ckpoint['backup'][0])
        ema.apply_shadow()

    model.eval()
    predictions = []
    with torch.no_grad():
        for batch in tqdm(test_dataloader, desc="Evaluating"):
            # Move inputs to the same device as the model instead of assuming CUDA.
            for k in batch:
                batch[k] = batch[k].to(args.device)

            probability = model(batch, True)
            pred_label_id = torch.argmax(probability, dim=1)
            predictions.extend(pred_label_id.cpu().numpy())

    # The file is only written, so plain "w" is enough; pin the encoding so the
    # output does not depend on the platform default.
    with open(f"data/{args.result_file}", "w", encoding="utf-8") as f:
        print('task')
        f.write("tweet_id\tlabel\n")
        tweet_ids = test_dataset.data['tweet_id']
        for i in trange(len(predictions)):
            i_d = tweet_ids.iloc[i]
            label = int(predictions[i])
            f.write(f"{i_d}\t{label}\n")

In [2]:
# Execute the inference pipeline defined above; it writes the predictions TSV under data/.
inference()

The test data length:  21


Some weights of the model checkpoint at digitalepidemiologylab/covid-twitter-bert were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


The epoch 3 and the best mean f1 0.9058 of the validation set.


  probability = nn.functional.softmax(logits)
Evaluating: 100%|██████████| 21/21 [00:09<00:00,  2.21it/s]


task


100%|██████████| 1291/1291 [00:00<00:00, 80778.82it/s]
