# Setting

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
os.chdir('/content/drive/MyDrive/Stage2/code') 

In [3]:
os.getcwd()

'/content/drive/MyDrive/Stage2/code'

라이브러리 다운로드

In [4]:
!pip install mxnet
!pip install gluonnlp pandas tqdm
!pip install sentencepiece
!pip install transformers==3
!pip install torch
!pip install git+https://git@github.com/SKTBrain/KoBERT.git@master

Collecting mxnet
[?25l  Downloading https://files.pythonhosted.org/packages/30/07/66174e78c12a3048db9039aaa09553e35035ef3a008ba3e0ed8d2aa3c47b/mxnet-1.8.0.post0-py2.py3-none-manylinux2014_x86_64.whl (46.9MB)
[K     |████████████████████████████████| 46.9MB 95kB/s 
Collecting graphviz<0.9.0,>=0.8.1
  Downloading https://files.pythonhosted.org/packages/53/39/4ab213673844e0c004bed8a0781a0721a3f6bb23eb8854ee75c236428892/graphviz-0.8.4-py2.py3-none-any.whl
Installing collected packages: graphviz, mxnet
  Found existing installation: graphviz 0.10.1
    Uninstalling graphviz-0.10.1:
      Successfully uninstalled graphviz-0.10.1
Successfully installed graphviz-0.8.4 mxnet-1.8.0.post0
Collecting gluonnlp
[?25l  Downloading https://files.pythonhosted.org/packages/9c/81/a238e47ccba0d7a61dcef4e0b4a7fd4473cb86bed3d84dd4fe28d45a0905/gluonnlp-0.10.0.tar.gz (344kB)
[K     |████████████████████████████████| 348kB 3.9MB/s 
Building wheels for collected packages: gluonnlp
  Building wheel for gluon

라이브러리 불러오기

In [5]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import pandas as pd
import numpy as np
import re
import tarfile
import pickle as pickle
from tqdm import tqdm
from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model
from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup
from sklearn.model_selection import train_test_split,StratifiedKFold

from transformers import *
from tqdm import tqdm

GPU 설정

In [6]:
device = torch.device("cuda:0" if torch.cuda.is_available else "cpu")

kobert 불러오기

In [7]:
device

device(type='cuda', index=0)

# Preprocessing

In [8]:
data_path = "/content/drive/MyDrive/Stage2/input/data/"

In [9]:
def load_data(dataset_dir):
    with open('/content/drive/MyDrive/Stage2/input/data/label_type.pkl', 'rb') as f:
        label_type = pickle.load(f)
    dataset = pd.read_csv(dataset_dir, delimiter='\t', header=None)
    dataset = preprocessing_dataset(dataset, label_type)
    return dataset

def preprocessing_dataset(dataset, label_type):
    label = []
    for i in dataset[8]:
        if i == 'blind':
            label.append(100)
        else:
            label.append(label_type[i])
    out_dataset = pd.DataFrame({'sentence':dataset[1],'entity_01':dataset[2],'entity_02':dataset[5],'label':label,})
    return out_dataset

In [10]:
origin_data_path = data_path+"train/train.tsv"
eng_data_path = data_path+'train/new_data_NaN_en.tsv'
jap_data_path = data_path+'train/new_data_NaN_ja.tsv'
chi_data_path = data_path+'train/new_data_NaN_zh.tsv'

origin_dataset = load_data(origin_data_path)
eng_dataset = load_data(eng_data_path)
jap_dataset = load_data(jap_data_path)
chi_dataset = load_data(chi_data_path)


In [11]:
origin_dataset

Unnamed: 0,sentence,entity_01,entity_02,label
0,영국에서 사용되는 스포츠 유틸리티 자동차의 브랜드로는 랜드로버(Land Rover)...,랜드로버,자동차,17
1,"선거에서 민주당은 해산 전 의석인 230석에 한참 못 미치는 57석(지역구 27석,...",민주당,27석,0
2,유럽 축구 연맹(UEFA) 집행위원회는 2014년 1월 24일에 열린 회의를 통해 ...,유럽 축구 연맹,UEFA,6
3,"용병 공격수 챠디의 부진과 시즌 초 활약한 강수일의 침체, 시즌 중반에 영입한 세르...",강수일,공격수,2
4,람캄행 왕은 1237년에서 1247년 사이 수코타이의 왕 퍼쿤 씨 인트라팃과 쓰엉 ...,람캄행,퍼쿤 씨 인트라팃,8
...,...,...,...,...
8995,2002년 FIFA 월드컵 사우디아라비아와의 1차전에서 독일은 8-0으로 승리하였는...,사우디아라비아,2002년,0
8996,일본의 2대 메이커인 토요타와 닛산은 시장 점유율을 높이기 위한 신차 개발을 계속하...,토요타,일본,9
8997,방호의의 손자 방덕룡(方德龍)은 1588년(선조 21년) 무과에 급제하고 낙안군수로...,방덕룡,선무원종공신(宣武原從功臣),2
8998,LG전자는 올해 초 국내시장에 출시한 2020년형 ‘LG 그램’ 시리즈를 이달부터 ...,LG전자,북미,0


## back translation 취합

In [12]:
final_dataset = pd.DataFrame({
    'origin_sentence':origin_dataset['sentence'],
    'eng_sentence':eng_dataset['sentence'],
    'jap_sentence':jap_dataset['sentence'],
    'chi_sentence':chi_dataset['sentence'],
    'entity_01':origin_dataset['entity_01'],
    'entity_02':origin_dataset['entity_02'],
    'label':origin_dataset['label']
})

In [13]:
final_dataset.head()

Unnamed: 0,origin_sentence,eng_sentence,jap_sentence,chi_sentence,entity_01,entity_02,label
0,영국에서 사용되는 스포츠 유틸리티 자동차의 브랜드로는 랜드로버(Land Rover)...,"영국에서 사용되는 스포츠카 브랜드에는 랜드로버와 지프가 포함되어 있으며, 이들 브랜...","영국에서 사용되는 스포츠 유틸리티 자동차 브랜드에서는 랜드로버와 지프가 있고, 이 ...",영국 스포츠타운 자동차의 브랜드는 랜드로버와 제프입니다,랜드로버,자동차,17
1,"선거에서 민주당은 해산 전 의석인 230석에 한참 못 미치는 57석(지역구 27석,...","선거에서는 민주당이 해산 전 230석을 훨씬 밑도는 57석(구 27석, 비례대표 3...","선거에서 민주당은 해산 전 의석의 230석에 아직 못 미치는 57석(지역구 27석,...",,민주당,27석,0
2,유럽 축구 연맹(UEFA) 집행위원회는 2014년 1월 24일에 열린 회의를 통해 ...,,,,유럽 축구 연맹,UEFA,6
3,"용병 공격수 챠디의 부진과 시즌 초 활약한 강수일의 침체, 시즌 중반에 영입한 세르...","부진한 공격수 차디와 시즌 초반 활약을 펼친 강수일의 부진, 중반 영입한 세르비아 ...",,,강수일,공격수,2
4,람캄행 왕은 1237년에서 1247년 사이 수코타이의 왕 퍼쿤 씨 인트라팃과 쓰엉 ...,,,,람캄행,퍼쿤 씨 인트라팃,8


In [14]:
#train, vali = train_test_split(dataset, test_size=0.2, random_state=42)
#train[['sentence','label']].to_csv(data_path+"train/train_train.txt", sep='\t', index=False)
#vali[['sentence','label']].to_csv(data_path+"train/train_vali.txt", sep='\t', index=False)

In [15]:
#dataset_train = nlp.data.TSVDataset(data_path+"train/train_train.txt", field_indices=[0,1], num_discard_samples=1)
#dataset_vali = nlp.data.TSVDataset(data_path+"train/train_vali.txt", field_indices=[0,1], num_discard_samples=1)

In [16]:
from transformers import *
model = XLMRobertaForSequenceClassification.from_pretrained("xlm-roberta-large")
input_size = model.classifier.out_proj.in_features
model.classifier.out_proj = nn.Linear(in_features=input_size, out_features=42, bias=True)
model.classifier

tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-large")
config = XLMRobertaConfig.from_pretrained("xlm-roberta-large")


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=513.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2244861551.0, style=ProgressStyle(descr…




Some weights of the model checkpoint at xlm-roberta-large were not used when initializing XLMRobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight']
- This IS expected if you are initializing XLMRobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing XLMRobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at xlm-roberta-large and are newly initialized: ['classifier.dense.weight', 'classifier.dense.bias', 'classifier.out_proj.we

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=5069051.0, style=ProgressStyle(descript…




In [17]:
config

XLMRobertaConfig {
  "architectures": [
    "XLMRobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "xlm-roberta",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "output_past": true,
  "pad_token_id": 1,
  "type_vocab_size": 1,
  "vocab_size": 250002
}

In [18]:
class BERTDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length,training):
        self.origin_sentence = dataset['origin_sentence']
        self.eng_sentence = dataset['eng_sentence']
        self.jap_sentence = dataset['jap_sentence']
        self.chi_sentence = dataset['chi_sentence']
        self.entity_01 = dataset['entity_01']
        self.entity_02 = dataset['entity_02']
        self.labels = torch.tensor(dataset['label'])
        self.tokenizer = tokenizer
        self.training = training

        
    def __getitem__(self, idx):
        if self.training:
            sentences = [self.origin_sentence[idx],]
            if self._is_sentence(self.eng_sentence[idx]):
                sentences.append(self.eng_sentence[idx])
            if self._is_sentence(self.jap_sentence[idx]):
                sentences.append(self.jap_sentence[idx])
            if self._is_sentence(self.chi_sentence[idx]):
                sentences.append(self.chi_sentence[idx])
            sentence = sentences[np.random.randint(len(sentences))]
        else:
            sentence = self.origin_sentence[idx]
            

        e1 = self.entity_01[idx]
        e2 = self.entity_02[idx]
        
        item = self.tokenizer(e1+'RELATION'+e2+sentence , max_length=max_len, pad_to_max_length=True, truncation=True,return_tensors='pt')
        item['labels'] = self.labels[idx]

        #return (self.sentences[i] + (self.labels[i], ))
        return item
    
    def _is_sentence(self, sentence) :
        return False if sentence is np.NaN else True

    def __len__(self):
        return (len(self.labels))

In [19]:
max_len = 128
batch_size = 8
warmup_ratio = 0.01
num_epochs = 100
max_grad_norm = 1
log_interval = 50
learning_rate =1e-5

In [20]:
#(dataset, idx, tokenizer, max_len,training):
data_train = BERTDataset(final_dataset,tokenizer=tokenizer, max_length=max_len,training= True)
data_vali = BERTDataset( final_dataset, tokenizer=tokenizer, max_length=max_len, training= False)

In [21]:
skf = StratifiedKFold(n_splits=5, random_state=None, shuffle=False)
folds = []
for train_index, test_index in skf.split(final_dataset['origin_sentence'], final_dataset['label']):
    vali = final_dataset.loc[test_index]
    train = final_dataset.loc[train_index]
    folds.append({'train_idx':train_index,'valid_idx':test_index})



In [22]:
from torch.utils.data import Subset
train_subset = Subset(data_train,train_index)
valid_subset = Subset(data_vali,test_index)

In [23]:
train_dataloader = torch.utils.data.DataLoader(train_subset, batch_size=batch_size, num_workers=2)
vali_dataloader = torch.utils.data.DataLoader(valid_subset, batch_size=batch_size, num_workers=2)

=# Classification

In [24]:
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

In [25]:
class LabelSmoothingLoss(nn.Module):
    def __init__(self, classes=42, smoothing=0.0, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        pred = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))

In [26]:
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
loss_fn = LabelSmoothingLoss(classes=42, smoothing=0.5)

In [27]:
t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * warmup_ratio)

In [28]:
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)

In [29]:
def calc_accuracy(X,Y):
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]
    return train_acc

In [30]:
model = model.to(device)

In [31]:
cnt = 0
best_acc = 0.0

for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0
    model.train()
    for batch_id, item in enumerate(tqdm(train_dataloader)):
        optimizer.zero_grad()

        token_ids = item['input_ids'].squeeze().long().to(device)
        attention_mask = item['attention_mask'].squeeze().long().to(device)
        label = item['labels'].long().to(device)
        
        out = model(token_ids, attention_mask)[0]
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()
        train_acc += calc_accuracy(out, label)
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))
    model.eval()
    for batch_id, item in enumerate(tqdm(vali_dataloader)):
        token_ids = item['input_ids'].squeeze().long().to(device)
        attention_mask = item['attention_mask'].squeeze().long().to(device)
        label = item['labels'].long().to(device)
        out = model(token_ids, attention_mask)[0]
        test_acc += calc_accuracy(out, label)
    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))
    
    test_acc = test_acc / (batch_id+1)
    
    if test_acc > best_acc:
        cnt = 0
        best_acc = test_acc
        torch.save(model.state_dict(), "/content/drive/MyDrive/Stage2/model/xlm-roberta-large.pt")
    else:
        cnt+=1
        if cnt == 10:
            print('EarlyStop: '+str(e)+' Epochs')
            break
print('Best Score: ', best_acc)

100%|██████████| 900/900 [06:46<00:00,  2.22it/s]
  0%|          | 0/225 [00:00<?, ?it/s]

epoch 1 train acc 0.46375


100%|██████████| 225/225 [00:26<00:00,  8.64it/s]


epoch 1 test acc 0.5361111111111111


100%|██████████| 900/900 [06:45<00:00,  2.22it/s]
  0%|          | 0/225 [00:00<?, ?it/s]

epoch 2 train acc 0.6568055555555555


100%|██████████| 225/225 [00:26<00:00,  8.65it/s]


epoch 2 test acc 0.6805555555555556


100%|██████████| 900/900 [06:45<00:00,  2.22it/s]
  0%|          | 0/225 [00:00<?, ?it/s]

epoch 3 train acc 0.7515277777777778


100%|██████████| 225/225 [00:26<00:00,  8.65it/s]


epoch 3 test acc 0.7277777777777777


100%|██████████| 900/900 [06:45<00:00,  2.22it/s]
  0%|          | 0/225 [00:00<?, ?it/s]

epoch 4 train acc 0.8143055555555555


100%|██████████| 225/225 [00:26<00:00,  8.64it/s]
  0%|          | 0/900 [00:00<?, ?it/s]

epoch 4 test acc 0.7222222222222222


100%|██████████| 900/900 [06:45<00:00,  2.22it/s]
  0%|          | 0/225 [00:00<?, ?it/s]

epoch 5 train acc 0.8554166666666667


100%|██████████| 225/225 [00:26<00:00,  8.65it/s]


epoch 5 test acc 0.7372222222222222


100%|██████████| 900/900 [06:45<00:00,  2.22it/s]
  0%|          | 0/225 [00:00<?, ?it/s]

epoch 6 train acc 0.8880555555555556


100%|██████████| 225/225 [00:26<00:00,  8.65it/s]


epoch 6 test acc 0.7505555555555555


100%|██████████| 900/900 [06:44<00:00,  2.22it/s]
  0%|          | 0/225 [00:00<?, ?it/s]

epoch 7 train acc 0.9148611111111111


100%|██████████| 225/225 [00:26<00:00,  8.65it/s]
  0%|          | 0/900 [00:00<?, ?it/s]

epoch 7 test acc 0.7394444444444445


100%|██████████| 900/900 [06:43<00:00,  2.23it/s]
  0%|          | 0/225 [00:00<?, ?it/s]

epoch 8 train acc 0.9227777777777778


100%|██████████| 225/225 [00:26<00:00,  8.65it/s]
  0%|          | 0/900 [00:00<?, ?it/s]

epoch 8 test acc 0.7383333333333333


100%|██████████| 900/900 [06:42<00:00,  2.23it/s]
  0%|          | 0/225 [00:00<?, ?it/s]

epoch 9 train acc 0.9368055555555556


100%|██████████| 225/225 [00:26<00:00,  8.65it/s]
  0%|          | 0/900 [00:00<?, ?it/s]

epoch 9 test acc 0.7411111111111112


100%|██████████| 900/900 [06:42<00:00,  2.24it/s]
  0%|          | 0/225 [00:00<?, ?it/s]

epoch 10 train acc 0.9445833333333333


100%|██████████| 225/225 [00:26<00:00,  8.65it/s]


epoch 10 test acc 0.7638888888888888


100%|██████████| 900/900 [06:41<00:00,  2.24it/s]
  0%|          | 0/225 [00:00<?, ?it/s]

epoch 11 train acc 0.9543055555555555


100%|██████████| 225/225 [00:25<00:00,  8.66it/s]
  0%|          | 0/900 [00:00<?, ?it/s]

epoch 11 test acc 0.7594444444444445


100%|██████████| 900/900 [06:41<00:00,  2.24it/s]
  0%|          | 0/225 [00:00<?, ?it/s]

epoch 12 train acc 0.9598611111111112


100%|██████████| 225/225 [00:26<00:00,  8.65it/s]
  0%|          | 0/900 [00:00<?, ?it/s]

epoch 12 test acc 0.7588888888888888


100%|██████████| 900/900 [06:40<00:00,  2.25it/s]
  0%|          | 0/225 [00:00<?, ?it/s]

epoch 13 train acc 0.9618055555555556


100%|██████████| 225/225 [00:26<00:00,  8.65it/s]


epoch 13 test acc 0.775


100%|██████████| 900/900 [06:40<00:00,  2.25it/s]
  0%|          | 0/225 [00:00<?, ?it/s]

epoch 14 train acc 0.9702777777777778


100%|██████████| 225/225 [00:25<00:00,  8.65it/s]
  0%|          | 0/900 [00:00<?, ?it/s]

epoch 14 test acc 0.7744444444444445


100%|██████████| 900/900 [06:39<00:00,  2.25it/s]
  0%|          | 0/225 [00:00<?, ?it/s]

epoch 15 train acc 0.9694444444444444


100%|██████████| 225/225 [00:25<00:00,  8.66it/s]
  0%|          | 0/900 [00:00<?, ?it/s]

epoch 15 test acc 0.7727777777777778


100%|██████████| 900/900 [06:39<00:00,  2.25it/s]
  0%|          | 0/225 [00:00<?, ?it/s]

epoch 16 train acc 0.9766666666666667


100%|██████████| 225/225 [00:25<00:00,  8.66it/s]
  0%|          | 0/900 [00:00<?, ?it/s]

epoch 16 test acc 0.7505555555555555


100%|██████████| 900/900 [06:39<00:00,  2.25it/s]
  0%|          | 0/225 [00:00<?, ?it/s]

epoch 17 train acc 0.9794444444444445


100%|██████████| 225/225 [00:26<00:00,  8.64it/s]
  0%|          | 0/900 [00:00<?, ?it/s]

epoch 17 test acc 0.7483333333333333


100%|██████████| 900/900 [06:38<00:00,  2.26it/s]
  0%|          | 0/225 [00:00<?, ?it/s]

epoch 18 train acc 0.9808333333333333


100%|██████████| 225/225 [00:26<00:00,  8.65it/s]
  0%|          | 0/900 [00:00<?, ?it/s]

epoch 18 test acc 0.7711111111111111


100%|██████████| 900/900 [06:38<00:00,  2.26it/s]
  0%|          | 0/225 [00:00<?, ?it/s]

epoch 19 train acc 0.9811111111111112


100%|██████████| 225/225 [00:25<00:00,  8.67it/s]
  0%|          | 0/900 [00:00<?, ?it/s]

epoch 19 test acc 0.7594444444444445


100%|██████████| 900/900 [06:38<00:00,  2.26it/s]
  0%|          | 0/225 [00:00<?, ?it/s]

epoch 20 train acc 0.9791666666666666


100%|██████████| 225/225 [00:25<00:00,  8.66it/s]
  0%|          | 0/900 [00:00<?, ?it/s]

epoch 20 test acc 0.7622222222222222


100%|██████████| 900/900 [06:38<00:00,  2.26it/s]
  0%|          | 0/225 [00:00<?, ?it/s]

epoch 21 train acc 0.9805555555555555


100%|██████████| 225/225 [00:25<00:00,  8.67it/s]
  0%|          | 0/900 [00:00<?, ?it/s]

epoch 21 test acc 0.7583333333333333


100%|██████████| 900/900 [06:37<00:00,  2.26it/s]
  0%|          | 0/225 [00:00<?, ?it/s]

epoch 22 train acc 0.9848611111111111


100%|██████████| 225/225 [00:25<00:00,  8.68it/s]
  0%|          | 0/900 [00:00<?, ?it/s]

epoch 22 test acc 0.76


100%|██████████| 900/900 [06:37<00:00,  2.26it/s]
  0%|          | 0/225 [00:00<?, ?it/s]

epoch 23 train acc 0.9865277777777778


100%|██████████| 225/225 [00:25<00:00,  8.68it/s]

epoch 23 test acc 0.7583333333333333
EarlyStop: 22 Epochs
Best Score:  0.775





# Predict

In [32]:
dataset_path = r"/content/drive/MyDrive/Stage2/input/data/test/test.tsv"

dataset = load_data(dataset_path)

dataset['sentence'] = dataset['entity_01'] + ' [SEP] ' + dataset['entity_02'] + ' [SEP] ' + dataset['sentence']

dataset[['sentence','label']].to_csv(data_path+"test/test.txt", sep='\t', index=False)

In [33]:
dataset_test = nlp.data.TSVDataset(data_path+"test/test.txt", field_indices=[0,1], num_discard_samples=1)

data_test = BERTDataset(dataset_test, 0, 1, tokenizer, max_len, True, False)

test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, num_workers=5)

TypeError: ignored

In [None]:
model.load_state_dict(torch.load("/content/drive/MyDrive/Stage2/model/xlm-roberta-large.pt"))

model.eval()

Predict = []

for batch_id, (token_ids, attention_mask, label) in enumerate(test_dataloader):
    token_ids = token_ids.long().to(device)
    attention_mask = attention_mask.long().to(device)
    label = label.long().to(device)
    out = model(token_ids, attention_mask)[0]
    _, predict = torch.max(out,1)
    Predict.extend(predict.tolist())

In [None]:
output = pd.DataFrame(Predict, columns=['pred'])
output.to_csv('/content/drive/MyDrive/Stage2/result/xlm_roberta_large_stratified.csv', index=False)