사용한 모델: KR-BERT_character_sub-character  
모델 repo: https://github.com/snunlp/KR-BERT

In [1]:
import json
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler
from torch.nn.utils import clip_grad_norm_
from sklearn.model_selection import train_test_split
from torch.optim import AdamW
import pickle
import numpy as np
import pandas as pd
import re
from pathlib import Path
from typing import Union
from transformers import BertModel, BertForPreTraining, BertTokenizer, BertPreTrainedModel, BertConfig

In [2]:
import sklearn
import transformers
print("사용한 라이브러리 버전")
print("-"*22)
print("pytorch:", torch.__version__)
print("transformers:", transformers.__version__)
print("scikit-learn:", sklearn.__version__)
print("numpy:", np.__version__)
print("pandas:", pd.__version__)
print("-"*22)

사용한 라이브러리 버전
----------------------
pytorch: 1.10.2
transformers: 4.16.2
scikit-learn: 1.0.2
numpy: 1.21.5
pandas: 1.4.1
----------------------


In [3]:
class Config:
    """Config class"""

    def __init__(self, json_path_or_dict: Union[str, dict]) -> None:
        """Instantiating Config class
        Args:
            json_path_or_dict (Union[str, dict]): filepath of config or dictionary which has attributes
        """
        if isinstance(json_path_or_dict, dict):
            self.__dict__.update(json_path_or_dict)
        else:
            with open(json_path_or_dict, mode="r") as io:
                params = json.loads(io.read())
            self.__dict__.update(params)

    def save(self, json_path: Union[str, Path]) -> None:
        """Saving config to json_path
        Args:
            json_path (Union[str, Path]): filepath of config
        """
        with open(json_path, mode="w") as io:
            json.dump(self.__dict__, io, indent=4)

    def update(self, json_path_or_dict) -> None:
        """Updating Config instance
        Args:
            json_path_or_dict (Union[str, dict]): filepath of config or dictionary which has attributes
        """
        if isinstance(json_path_or_dict, dict):
            self.__dict__.update(json_path_or_dict)
        else:
            with open(json_path_or_dict, mode="r") as io:
                params = json.loads(io.read())
            self.__dict__.update(params)

    @property
    def dict(self) -> dict:
        return self.__dict__
    
def data_qc(paragrahp:str):
    paragrahp = re.sub(r"(\(.*?\))", "", paragrahp)
    paragrahp = re.sub("((http|https)\:\/\/)?[a-zA-Z0-9\.\/\?\:@\-_=#]+\.([a-zA-Z]){2,6}([a-zA-Z0-9\.\&\/\?\:@\-_=#])*", "", paragrahp)
    paragrahp = re.sub("'^[a-zA-Z0-9+-_.]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$'", "", paragrahp)
    return paragrahp

### read data

In [4]:
with open("./klue-sts-data/klue-sts-v1.1_train.json", "rt", encoding='utf8') as f:
    data = json.load(f)

In [5]:
data[0]

{'guid': 'klue-sts-v1_train_00000',
 'source': 'airbnb-rtt',
 'sentence1': '숙소 위치는 찾기 쉽고 일반적인 한국의 반지하 숙소입니다.',
 'sentence2': '숙박시설의 위치는 쉽게 찾을 수 있고 한국의 대표적인 반지하 숙박시설입니다.',
 'labels': {'label': 3.7, 'real-label': 3.714285714285714, 'binary-label': 1},
 'annotations': {'agreement': '0:0:0:2:5:0',
  'annotators': ['07', '13', '15', '10', '12', '02', '19'],
  'annotations': [3, 4, 4, 4, 3, 4, 4]}}

In [6]:
for idx, i in enumerate(data):
    if "\xa0" in i['sentence1']:
        print(idx, i)

150 {'guid': 'klue-sts-v1_train_00150', 'source': 'policy-rtt', 'sentence1': '먼저 데이터·인공지능(AI)분야에서는\xa0데이터 활용을 높이기 위해 건강 등 민감한 정보에 대한 가이드라인을 8월까지 마련한다.', 'sentence2': '우선 데이터 및 인공지능(AI) 분야에서는 데이터 활용도를 높이기 위해 8월까지 건강 등 중요 정보에 대한 가이드라인을 마련하기로 했습니다.', 'labels': {'label': 3.7, 'real-label': 3.666666666666667, 'binary-label': 1}, 'annotations': {'agreement': '0:0:1:1:3:1', 'annotators': ['17', '03', '12', '15', '05', '09'], 'annotations': [4, 2, 5, 3, 4, 4]}}
159 {'guid': 'klue-sts-v1_train_00159', 'source': 'policy-sampled', 'sentence1': '65세 이상 독거 노인 죽음, 전체\xa0고독사의 43% 달해', 'sentence2': '여성(51.6%)과 노후준비가 빈약한 고령층(65세 이상, 32.9%)의 고용률 역시 역대 최고치를 기록했습니다.', 'labels': {'label': 0.3, 'real-label': 0.3333333333333333, 'binary-label': 0}, 'annotations': {'agreement': '4:2:0:0:0:0', 'annotators': ['14', '10', '02', '06', '19', '13'], 'annotations': [0, 0, 0, 0, 1, 1]}}
276 {'guid': 'klue-sts-v1_train_00276', 'source': 'policy-sampled', 'sentence1': '이러한 정책 여건을 반영, 방통위는\xa0‘활력있는 방송통신, 신뢰받는 미디어’를 비전으로

6968 {'guid': 'klue-sts-v1_train_06968', 'source': 'policy-rtt', 'sentence1': '청약신청은 한국토지주택공사, 서울주택도시공사, 경기주택도시공사 등 공공주택사업자별 입주자모집 공고를 확인한 후\xa0누리집 또는\xa0현장접수 등으로 가능하다.', 'sentence2': '청약접수는 한국토지주택공사, 서울주택도시공사, 경기주택도시공사 등 공공주택사업자를 확인한 후 홈페이지 또는 현장 신청을 통해 가능합니다.', 'labels': {'label': 3.9, 'real-label': 3.857142857142857, 'binary-label': 1}, 'annotations': {'agreement': '0:0:0:2:4:1', 'annotators': ['17', '03', '07', '16', '19', '15', '04'], 'annotations': [5, 4, 4, 3, 4, 3, 4]}}
7093 {'guid': 'klue-sts-v1_train_07093', 'source': 'policy-sampled', 'sentence1': '또\xa0하루 300~400여명에 달하는\xa0외부 방문객 출입을 전면 금지하는 한편, 직원과의 대면 업무도\xa0금지하는 등 사회적 거리두기를 적극 실천중이다.', 'sentence2': '외교부는 주한 외교단과 함께 코로나19의 확산 억제를 위해 실시하는 범정부적인 ‘사회적 거리두기’를 적극 실천한다.', 'labels': {'label': 1.3, 'real-label': 1.333333333333333, 'binary-label': 0}, 'annotations': {'agreement': '1:3:1:1:0:0', 'annotators': ['12', '02', '06', '19', '04', '10'], 'annotations': [1, 1, 1, 3, 2, 0]}}
7184 {'guid': 'klue-sts-v1_train_07184', 'source':

In [7]:
re.sub(r'\[a-z0-9]+', '', data[7666]['sentence1'])

'올해 35개 기업연구소가\xa0공모 결과 우수 기업연구소로 선정됐다.'

In [8]:
re.sub(r'\xa0', '', data[7666]['sentence1'])

'올해 35개 기업연구소가공모 결과 우수 기업연구소로 선정됐다.'

In [9]:
import unicodedata

In [10]:
unicodedata.normalize("NFKD", data[7666]['sentence1'])

'올해 35개 기업연구소가 공모 결과 우수 기업연구소로 선정됐다.'

In [11]:
shape = np.full([len(data), 3], np.nan)
df = pd.DataFrame(shape, columns=['sentence1', 'sentence2', 'label'])

In [12]:
for idx, el in enumerate(data):
    df.loc[idx] = [el['sentence1'], el['sentence2'], el['labels']['real-label']]

In [13]:
df.loc[7:7]

Unnamed: 0,sentence1,sentence2,label
7,사례집은 국립환경과학원 누리집(ecolibrary.me.go.kr)에서 12일부터 ...,주말을 제외한 평일 오후 12시 30분부터 문예회관 공식 페이스북과 유튜브에서는 지...,0.0


In [14]:
df[['sentence1', 'sentence2']] = df[['sentence1', 'sentence2']].applymap(data_qc)

In [15]:
df.loc[7:7]

Unnamed: 0,sentence1,sentence2,label
7,사례집은 국립환경과학원 누리집에서 12일부터 볼 수 있다.,주말을 제외한 평일 오후 12시 30분부터 문예회관 공식 페이스북과 유튜브에서는 지...,0.0


In [16]:
train, valid = train_test_split(df, test_size=0.2, random_state=42)

In [17]:
del data

### read config files

In [18]:
conf_info = Config(f"./config_files/config_subchar12367_bert.json")

In [19]:
with open("./config_files/config_subchar12367_bert.json", "rt", encoding="utf8") as f:
    conf_subchar = json.load(f)

In [20]:
conf_subchar

{'config': 'config_files/bert_config_subchar12367.json',
 'bert': 'torch_model/pytorch_model_subchar12367_bert.bin',
 'tokenizer': 'config_files/vocab_snu_subchar12367.txt',
 'vocab': 'config_files/vocab_snu_subchar12367.pkl'}

In [21]:
tokenizer_krbert_sub = BertTokenizer.from_pretrained(f"{conf_info.tokenizer}")



In [22]:
tokenizer_krbert_sub.tokenize(df['sentence1'].loc[0])

['숙소',
 '위치',
 '##는',
 '찾기',
 '쉬',
 '##ᆸ',
 '##고',
 '일반적인',
 '한국의',
 '반',
 '##지',
 '##하',
 '숙소',
 '##입니다',
 '.']

In [23]:
def custom_collate_fn(batch):
    input_list, target_list = [], []
    
    for _input, _target in batch:
        input_list.append(_input)
        target_list.append(_target)
    
    tensorized_input = tokenizer_krbert_sub(
        input_list,
        add_special_tokens=True,
        padding="longest",
        truncation=True,
        max_length=128,
        return_tensors='pt'
    )
    
    tensorized_label = torch.tensor(target_list, dtype = torch.float)
    
    return tensorized_input, tensorized_label


class CustomDataset(Dataset):
    """
    - input_data: list of string
    - target_data: list of int
    """

    def __init__(self, input_data:list, target_data:list) -> None:
        self.X = input_data
        self.Y = target_data

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        X = self.X[index]
        Y = self.Y[index]
        return X, Y

In [24]:
class FCLayer(torch.nn.Module):
    def __init__(self, input_dim, output_dim, dropout_rate=0.0, use_activation=None):
        super(FCLayer, self).__init__()
        self.use_activation = use_activation
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        x = self.dropout(x)
        if self.use_activation:
            x = self.use_activation(x)
        return self.linear(x)


class BertSts(BertPreTrainedModel):
    def __init__(self, config) -> None:
        super(BertSts, self).__init__(config)
        self.bert = BertModel(config)
        self.Dense = FCLayer(config.hidden_size, config.hidden_size, config.hidden_dropout_prob)
        self.output_layer = FCLayer(config.hidden_size, 1)


    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None):
        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        dense_outputs = self.Dense(bert_outputs['pooler_output'])
        sim_score = self.output_layer(dense_outputs)
        return sim_score

In [25]:
conf_info.bert

'torch_model/pytorch_model_subchar12367_bert.bin'

In [26]:
config = BertConfig(conf_info.bert)

In [27]:
config.vocab_size = 12367

In [28]:
model = BertSts(config = config)

In [29]:
model

BertSts(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(12367, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    

In [30]:
weights = torch.load(conf_info.bert)

In [31]:
df.loc[7]['sentence1'], df.loc[7]['sentence2']

('사례집은 국립환경과학원 누리집에서 12일부터 볼 수 있다.',
 '주말을 제외한 평일 오후 12시 30분부터 문예회관 공식 페이스북과 유튜브에서는 지역 예술인들이 중심이 된 서양음악, 국악, 댄스 등의 공연을 실시간으로 감상할 수 있다.')

In [32]:
tokenizer_krbert_sub(df.loc[7]['sentence1'], df.loc[7]['sentence2'])

{'input_ids': [2, 2238, 518, 17, 2591, 3641, 3552, 80, 4891, 518, 23, 181, 828, 594, 33, 30, 5, 3, 3141, 6, 3634, 967, 22, 244, 181, 50, 281, 171, 230, 268, 615, 129, 187, 991, 2561, 24, 5641, 282, 463, 2329, 4809, 1818, 12, 698, 7052, 5401, 8, 644, 1329, 8, 5822, 646, 1078, 6, 4998, 28, 1509, 160, 49, 33, 30, 5, 3], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [33]:
param_names = []

for name, param in model.named_parameters():
    param_names.append(name)

In [34]:
from copy import deepcopy

In [35]:
weight_dict = deepcopy(model.state_dict())

In [36]:
for name, weight in weights.items():
    if name in param_names:
        weight_dict[name] = weight

In [37]:
model.load_state_dict(weight_dict)

<All keys matched successfully>

In [38]:
batch_size = 32

In [39]:
train_dataset = CustomDataset(train.iloc[:, :2].values.tolist(), train['label'].tolist())
valid_dataset = CustomDataset(valid.iloc[:, :2].values.tolist(), valid['label'].tolist())

In [40]:
train_dataloader = DataLoader(train_dataset,
                              batch_size = batch_size,
                              sampler = RandomSampler(train_dataset),
                              collate_fn = custom_collate_fn)

valid_dataloader = DataLoader(valid_dataset,
                              batch_size = batch_size,
                              sampler = RandomSampler(valid_dataset),
                              collate_fn = custom_collate_fn)

In [41]:
loss_fct = torch.nn.MSELoss()

def train(model, train_dataloader, valid_dataloader=None, epochs=2):
        global loss_fct, scheduler
        
        for epoch in range(epochs):
            print(f"*****Epoch {epoch} Train Start*****")
            
            total_loss, batch_loss, batch_count = 0,0,0
        
            model.train()
            model.to(device)
            
            for step, batch in enumerate(train_dataloader):
                batch_count+=1
                
                batch = tuple(item.to(device) for item in batch)
            
                batch_input, batch_label = batch
                
                model.zero_grad()
            
                logits = model(**batch_input)

                loss = loss_fct(logits.view(-1), batch_label.view(-1))
                
                batch_loss += loss.item()
                total_loss += loss.item()
            
                loss.backward()
                
                clip_grad_norm_(model.parameters(), 1.0)
                
                optimizer.step()
                scheduler.step()
                
                if (step % 10 == 0 and step != 0):
                    learning_rate = optimizer.param_groups[0]['lr']
                    print(f"Epoch: {epoch}, Step : {step}, LR : {learning_rate}, Avg Loss : {batch_loss / batch_count:.4f}")

                    batch_loss, batch_count = 0,0

            print(f"Epoch {epoch} Total Mean Loss : {total_loss/(step+1):.4f}")
            print(f"*****Epoch {epoch} Train Finish*****\n")
            
            if valid_dataloader:
                print(f"*****Epoch {epoch} Valid Start*****")
                valid_loss = validate(model, valid_dataloader)
                print(f"Epoch {epoch} Valid Loss : {valid_loss:.4f}")
                print(f"*****Epoch {epoch} Valid Finish*****\n")
            
#             save_checkpoint(".", model, optimizer, scheduler, epoch, total_loss/(step+1))
                
        print("Train Completed. End Program.")

In [42]:
def validate(model, valid_dataloader):
   
    # 모델을 evaluate 모드로 설정 & device 할당
    model.eval()
    model.to(device)
    
    total_loss = 0
        
    for step, batch in enumerate(valid_dataloader):
        
        # tensor 연산 전, 각 tensor에 device 할당
        batch = tuple(item.to(device) for item in batch)
            
        batch_input, batch_label = batch
            
        # gradient 계산하지 않음
        with torch.no_grad():
            logits = model(**batch_input)
            
        # loss
        loss = loss_fct(logits.view(-1), batch_label.view(-1))
        total_loss += loss.item()
        
    total_loss = total_loss/(step+1)

    return total_loss

In [43]:
def initializer(train_dataloader, epochs=2):
    """
    모델, 옵티마이저, 스케쥴러 초기화
    """

    optimizer = AdamW(
        model.parameters(),
        lr=2e-5,
        eps=1e-8
    )
    
    total_steps = len(train_dataloader) * epochs
    print(f"Total train steps with {epochs} epochs: {total_steps}")

    scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
        optimizer, 
        num_warmup_steps = 0,
        num_training_steps = total_steps
    )

    return optimizer, scheduler

In [44]:
# device type
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"# available GPUs : {torch.cuda.device_count()}")
    print(f"GPU name : {torch.cuda.get_device_name()}")
else:
    device = torch.device("cpu")
print(device)

# available GPUs : 1
GPU name : GeForce RTX 3070
cuda


In [45]:
from transformers import get_linear_schedule_with_warmup, get_constant_schedule, get_cosine_with_hard_restarts_schedule_with_warmup

In [46]:
epochs = 2
optimizer, scheduler = initializer(train_dataloader, epochs=epochs)

train(model, train_dataloader, valid_dataloader, epochs = epochs)

Total train steps with 2 epochs: 584
*****Epoch 0 Train Start*****
Epoch: 0, Step : 10, LR : 1.9982497394784285e-05, Avg Loss : 6.6261
Epoch: 0, Step : 20, LR : 1.9936258727371667e-05, Avg Loss : 1.3261
Epoch: 0, Step : 30, LR : 1.9861273081163116e-05, Avg Loss : 0.8416
Epoch: 0, Step : 40, LR : 1.9757757400064857e-05, Avg Loss : 0.5864
Epoch: 0, Step : 50, LR : 1.9626011169343043e-05, Avg Loss : 0.4091
Epoch: 0, Step : 60, LR : 1.9466415549171235e-05, Avg Loss : 0.4376
Epoch: 0, Step : 70, LR : 1.9279432271880972e-05, Avg Loss : 0.3990
Epoch: 0, Step : 80, LR : 1.9065602306105914e-05, Avg Loss : 0.4192
Epoch: 0, Step : 90, LR : 1.8825544291684332e-05, Avg Loss : 0.3561
Epoch: 0, Step : 100, LR : 1.855995274984798e-05, Avg Loss : 0.3343
Epoch: 0, Step : 110, LR : 1.8269596073875555e-05, Avg Loss : 0.3021
Epoch: 0, Step : 120, LR : 1.7955314306024055e-05, Avg Loss : 0.3402
Epoch: 0, Step : 130, LR : 1.7618016707169658e-05, Avg Loss : 0.3046
Epoch: 0, Step : 140, LR : 1.7258679126189515e

In [47]:
def predict(model, test_dataloader):
    """
    test_dataloader의 label별 확률값과 실제 label 값을 반환
    """

    model.eval()
    model.to(device)

    all_logits = []
    all_labels = []

    for step, batch in enumerate(test_dataloader):
        print(f"{step+1}/{len(test_dataloader)}\r", end = "")
        
        batch_input, batch_label = batch
        
        batch_input = batch_input.to(device)
        
        with torch.no_grad():
            logits = model(**batch_input)
            all_logits.append(logits)
        all_labels.extend(batch_label)

    all_logits = torch.cat(all_logits, dim=0)
    probs = torch.tensor(all_logits).cpu().numpy()
    all_labels = np.array(all_labels)

    return probs, all_labels

In [48]:
probs, labels = predict(model, valid_dataloader)

73/73

  probs = torch.tensor(all_logits).cpu().numpy()


In [49]:
for i in range(len(probs)):
    print(f"probs: {probs[i][0]:5.2f},  label: {labels[i]:5.2f},  diff: {(probs[i] - labels[i])[0]:5.2f}")

probs: -0.03,  label:  0.00,  diff: -0.03
probs:  4.31,  label:  4.20,  diff:  0.11
probs: -0.06,  label:  0.00,  diff: -0.06
probs:  0.04,  label:  0.71,  diff: -0.67
probs:  0.03,  label:  0.00,  diff:  0.03
probs:  3.34,  label:  2.00,  diff:  1.34
probs:  0.12,  label:  0.00,  diff:  0.12
probs:  3.53,  label:  2.43,  diff:  1.10
probs:  0.01,  label:  0.14,  diff: -0.14
probs:  0.23,  label:  0.33,  diff: -0.10
probs:  3.95,  label:  2.67,  diff:  1.29
probs:  3.74,  label:  3.71,  diff:  0.03
probs: -0.00,  label:  0.00,  diff: -0.00
probs:  0.67,  label:  0.86,  diff: -0.18
probs:  3.76,  label:  2.00,  diff:  1.76
probs:  3.69,  label:  3.29,  diff:  0.41
probs:  3.14,  label:  2.20,  diff:  0.94
probs:  0.57,  label:  0.57,  diff: -0.01
probs:  4.55,  label:  4.71,  diff: -0.16
probs:  4.02,  label:  4.00,  diff:  0.02
probs:  3.61,  label:  3.43,  diff:  0.18
probs:  3.78,  label:  4.00,  diff: -0.22
probs:  0.23,  label:  0.00,  diff:  0.23
probs:  0.33,  label:  0.00,  diff

In [50]:
from sklearn.metrics import mean_squared_error

In [51]:
mean_squared_error(probs.flatten(), labels)

0.23966944

In [52]:
s1 = tokenizer_krbert_sub([['코딩 못하는 개발자도 취직 가능할까요?ㅠㅠ', '포도빛 향기에 취해만 가는데']],
                            truncation = True,
                            padding = "longest",
                            max_length=128,
                            return_tensors = "pt")

s2 = tokenizer_krbert_sub([['가스비가 너무 많이 나왔어요.', '가스비가 왜 이렇게 많이 나왔어!']],
                            truncation = True,
                            padding = "longest",
                            max_length=128,
                            return_tensors = "pt")

In [53]:
s1.to(device)
s2.to(device)

{'input_ids': tensor([[   2, 5453, 3213, 1000,  549, 1623, 1844,    5,    3, 5453, 3213, 1464,
         1559,  549, 1623, 2875, 1218,    3]], device='cuda:0'), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='cuda:0')}

In [54]:
with torch.no_grad():
    print(model(**s1))

tensor([[0.1058]], device='cuda:0')


In [55]:
with torch.no_grad():
    print(model(**s2))

tensor([[4.0434]], device='cuda:0')
