In [24]:
import torch, json, pandas as pd
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel
from transformers.data.data_collator import DataCollatorForLanguageModeling

In [25]:
tokenizer = AutoTokenizer.from_pretrained("/home/qhn/Codes/Models/Bert-base-chinese")

loading configuration file /home/qhn/Codes/Models/Bert-base-chinese/config.json
Model config BertConfig {
  "_name_or_path": "/home/qhn/Codes/Models/Bert-base-chinese",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "position_embedding_type": "absolute",
  "transformers_version": "4.41.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 21128
}

loading file vocab.txt
loading file tokenizer.json
loading file added_tokens.jso

In [26]:
def load_dataloader(dataset_name, tokenizer, args):
    collator = DataCollatorForLanguageModeling(tokenizer)
    data_paths = {
        "train": f"/home/qhn/Codes/Projects/KnowledgeDist/Data/{dataset_name}/train.tsv",
        "test": f"/home/qhn/Codes/Projects/KnowledgeDist/Data/{dataset_name}/test.tsv"
    }
    label2index = json.load(open(f"/home/qhn/Codes/Projects/KnowledgeDist/Data/{dataset_name}/label2index.json", "r"))
    train_data, test_data = {"text": [],"labels": []}, {"text": [],"labels": []}
    train = pd.read_csv(data_paths['train'], sep="\t")
    test  = pd.read_csv(data_paths['test'], sep="\t")
    for idx, row in train.iterrows():
        text, label = row['abstract'], label2index[row['discipline']]
        train_data['text'].append(text)
        train_data['labels'].append(label)
    for idx, row in test.iterrows():
        text, label = row['abstract'], label2index[row['discipline']]
        test_data['text'].append(text)
        test_data['labels'].append(label)
    
    train_dataset = Dataset.from_dict(train_data)
    test_dataset = Dataset.from_dict(test_data)
    
    train_dataset = train_dataset.map(
    lambda e: tokenizer(e['text'], truncation=True, padding='max_length', max_length=args["max_seq_length"]), batched=True)
    train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, drop_last=False,
                             num_workers=0)
    test_dataset = test_dataset.map(
    lambda e: tokenizer(e['text'], truncation=True, padding='max_length', max_length=args["max_seq_length"]), batched=True)
    test_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=True, drop_last=False, collate_fn=collator, 
                             num_workers=0)
    return train_loader, test_loader

In [27]:
args = {"max_seq_length": 256}
train, test = load_dataloader("agriculture", tokenizer, args)

Map: 100%|██████████| 2793/2793 [00:00<00:00, 5239.98 examples/s]
Map: 100%|██████████| 1197/1197 [00:00<00:00, 6703.71 examples/s]


In [34]:
for item in test:
    print(item['labels'][0])
    break

tensor([ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  4905,  -100,  -100,  -100,  2825,  3318,  -100,  -100,
         -100,  -100,  -100,  5299,  -100,  -100,  1765,  -100,  -100,  -100,
         -100,  -100,   119,   121,  -100,  -100,  -100,  -100,  1772,  -100,
         -100,  -100,  -100,  -100,   129,  8595,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,   119,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  3946,  -100, 12929,  -100,  -100,
         -100,  -100,  4288,  -100,  -100,  -100,  2108,  -100,  -100,  -100,
         -100,  -100,  -100,  -100, 11289,  -100,  -100,  8320,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  6574,   711,  -100,  5112,  -100,
         -100,   117,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          117,  -100,  -100,  3300,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 