In [None]:
import torch
import torch.nn as nn

#DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE = 'cpu'

import torchtext.transforms as T
from torch.hub import load_state_dict_from_url

padding_idx = 1
bos_idx = 0
eos_idx = 2
max_seq_len = 256
vocab_path = r"https://download.pytorch.org/models/text/xlmr.vocab.pt"
#spm_model_path = r"https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model"
spm_model_path = r"https://huggingface.co/IDEA-CCNL/Erlangshen-DeBERTa-v2-186M-Chinese-SentencePiece/resolve/main/spm.model"

text_transform = T.Sequential(
    T.SentencePieceTokenizer(spm_model_path),
    T.VocabTransform(load_state_dict_from_url(vocab_path)),
    T.Truncate(max_seq_len - 2),
    T.AddToken(token=bos_idx, begin=True),
    T.AddToken(token=eos_idx, begin=False),
)
from torchdata.datapipes.iter import IterableWrapper
from torch.utils.data import DataLoader

batch_size = 128

def batch_transform(x):
    return text_transform(x[0]), int(x[1])-1

In [None]:
from torchdata.datapipes.iter import IterableWrapper
train_datapipe = IterableWrapper(["./data/train.csv"])
train_datapipe = train_datapipe.open_files(encoding="utf-8").parse_csv()
train_datapipe = train_datapipe.shuffle().sharding_filter()
list(train_datapipe)

In [2]:
test_datapipe = IterableWrapper(["./data/test.csv"])
test_datapipe = test_datapipe.open_files(encoding="utf-8").parse_csv()
test_datapipe = test_datapipe.shuffle().sharding_filter()
list(test_datapipe)

[['能简单介绍下site', '7'],
 ['请问site在哪', '1'],
 ['讲讲city和locale的关系', '2'],
 ['site感觉很有名，能介绍一下吗', '7'],
 ['帮我查下site的官网', '6'],
 ['site这块有啥商店吗', '9'],
 ['你好，请问locale是哪个州的', '1'],
 ['city下面有哪几个区', '10'],
 ['我该怎么去site', '3'],
 ['site还有site之间有啥关系', '2'],
 ['site有官方网站没', '6'],
 ['region和locale间有啥关系', '2'],
 ['饿死我了，site附近有什么吃的吗', '8'],
 ['请问site有联系电话没有', '5'],
 ['locale的地址是啥', '1'],
 ['state与country间有关系', '2'],
 ['site附近还有什么好地方', '11'],
 ['请问locale在哪里呀', '1'],
 ['site这块有没有什么购物的地方', '9'],
 ['请讲讲state还有country的联系', '2'],
 ['我有点饿了，locale这块有啥吃的吗', '8'],
 ['locale里有没有值得逛逛的地方', '10'],
 ['讲讲country和site间有啥关系', '2'],
 ['site有什么介绍没有', '7'],
 ['这个city是state的什么', '2'],
 ['city是哪个国的', '1'],
 ['饿了，推荐点locale这里的饭店', '8'],
 ['你好，请问怎么到locale', '3'],
 ['site有啥联系方式没', '4'],
 ['site附近有什么好地方值得一去', '11']]

In [None]:
train_datapipe = train_datapipe.map(batch_transform)
train_datapipe = train_datapipe.batch(batch_size)
train_datapipe = train_datapipe.rows2columnar(["token_ids", "target"])

train_dataloader = DataLoader(train_datapipe, batch_size=None)
list(train_datapipe)


In [3]:
test_datapipe = test_datapipe.map(batch_transform)
test_datapipe = test_datapipe.batch(batch_size)
test_datapipe = test_datapipe.rows2columnar(["token_ids", "target"])
test_dataloader = DataLoader(test_datapipe, batch_size=None)
list(test_datapipe)

[defaultdict(list,
             {'token_ids': [[0, 91, 1486, 9679, 11090, 3, 88812, 11233, 2],
               [0, 7001, 1935, 2008, 87154, 24417, 3, 16076, 2],
               [0, 3, 55043, 1935, 164691, 41978, 2],
               [0, 3, 4, 200757, 55043, 1935, 3, 3, 2],
               [0, 91, 1486, 465, 220213, 7064, 2],
               [0, 3, 60089, 354, 61340, 43, 6711, 2],
               [0, 3, 11090, 465, 3, 3029, 2],
               [0, 11341, 1189, 238, 3, 1294, 11158, 3, 2],
               [0, 3, 3, 61340, 9679, 238, 3, 1294, 3, 2],
               [0, 3, 37085, 3, 11090, 2],
               [0, 6, 3, 4, 49732, 2391, 55043, 1935, 151521, 3, 2],
               [0, 6, 3, 60089, 264, 55043, 1935, 86282, 2],
               [0, 3, 11090, 3, 2],
               [0, 26349, 3, 3, 2],
               [0, 13129, 35679, 3, 4, 55043, 1935, 3, 3, 90146, 9131, 2],
               [0, 6, 3, 238, 3, 1294, 264, 11090, 11158, 3, 11233, 2],
               [0, 91, 1486, 3, 3, 58934, 9131, 2],
             

In [4]:
model = torch.load("D:\\projects\\python\\test\\src\\intent\\model.pt")
model.to(DEVICE)

RobertaModel(
  (encoder): RobertaEncoder(
    (transformer): TransformerEncoder(
      (token_embedding): Embedding(250002, 768, padding_idx=1)
      (layers): TransformerEncoder(
        (layers): ModuleList(
          (0-11): 12 x TransformerEncoderLayer(
            (self_attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (linear1): Linear(in_features=768, out_features=3072, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
            (linear2): Linear(in_features=3072, out_features=768, bias=True)
            (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (dropout1): Dropout(p=0.1, inplace=False)
            (dropout2): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (positional_embedding): PositionalEmbedding(
        (embedding): Embedding(5

In [5]:
import torchtext.functional as F

def eval_step(input, target):
    output = model(input)
    return (output.argmax(1) == target).type(torch.float).mean().item()

def evaluate():
    model.eval()
    correct_predictions = 0
    total_predictions = 0
    counter = 0
    with torch.no_grad():
        for batch in test_dataloader:
            input = F.to_tensor(batch["token_ids"], padding_value=padding_idx).to(DEVICE)
            print(input)
            target = torch.tensor(batch["target"]).to(DEVICE)
            predictions = eval_step(input, target)
            correct_predictions += predictions

            total_predictions += 1
            counter += 1

    return correct_predictions / total_predictions


In [8]:
accuracy = evaluate()
print("accuracy = [{}]".format(accuracy))

tensor([[     0,     91,   1486,      3,      3,  58934,   9131,      2,      1,
              1,      1,      1],
        [     0,     91,   1486,   9679,  11090,      3,  88812,  11233,      2,
              1,      1,      1],
        [     0,     91,   1486,  34643,      3,  10491,      3,   1677,      2,
              1,      1,      1],
        [     0,      3,  55043,   1935, 164691,  41978,      2,      1,      1,
              1,      1,      1],
        [     0,      3,  11090,    465,      3,   3029,      2,      1,      1,
              1,      1,      1],
        [     0,      3,      4, 200757,  12082,    789,  55043,   1935,      2,
              1,      1,      1],
        [     0,      6,      3,      3,      4,  11090,  34643,  47097,  90146,
           9131,      2,      1],
        [     0,  26349,  25735,      3,  28138,   3624,      2,      1,      1,
              1,      1,      1],
        [     0,     91,   1486,      3,      3,   7064,      2,      1,      1,