In [None]:
!pip install transformers sentence-transformers datasets

In [None]:
import os
import csv
import sys
import math
import json
import gzip
import random
from urllib.request import urlopen
from datetime import datetime
import pandas as pd

from sentence_transformers import models, losses, datasets
from sentence_transformers import LoggingHandler, SentenceTransformer, util, InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

In [None]:
# Checkpoint name handed to models.Transformer when the encoder is built.
model_name = "studio-ousia/luke-japanese-base-lite"
# Batch size shared by the training data loader and the dev-set evaluator.
train_batch_size = 128
# Forwarded to models.Transformer as its max_seq_length argument.
max_seq_length = 128
# Passed to model.fit as `epochs`; also used to size the warm-up schedule.
num_epochs = 1

# Passed to model.fit as `output_path`.
model_save_path = "output/sbert-jsnli-luke-japanese-base-lite"

# Model

In [None]:
# Transformer module built from the configured checkpoint and sequence length.
word_embedding_model = models.Transformer(
    model_name,
    max_seq_length=max_seq_length,
)

# Pooling module in 'mean' mode, sized to the transformer's embedding dimension.
embedding_dim = word_embedding_model.get_word_embedding_dimension()
pooling_model = models.Pooling(embedding_dim, pooling_mode='mean')

# Chain both modules into the sentence encoder that is trained below.
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

# Train Dataset
* JSNLI train data

In [None]:
# NOTE: this imports from the top-level `datasets` package installed above; it is
# not the `sentence_transformers.datasets` module that the imports cell binds to
# the name `datasets`.
from datasets import load_dataset

jsnli = load_dataset(path="shunk031/jsnli", name="with-filtering")

In [None]:
# Integer class id (as stored in the dataset rows) -> label name.
label2name = dict(enumerate(("entailment", "neutral", "contradiction")))

# anchor sentence -> {label name -> set of sentences paired with that anchor}
train_data = {}


def add_to_samples(sent1, sent2, label):
    """Record `sent2` under `label` for the anchor sentence `sent1`.

    Args:
        sent1: Anchor sentence; used as the key into `train_data`.
        sent2: Sentence to store in the anchor's bucket for `label`.
        label: One of 'contradiction', 'entailment' or 'neutral'.

    Raises:
        KeyError: If `label` is not one of the three bucket names.
    """
    buckets = train_data.get(sent1)
    if buckets is None:
        # First time this anchor is seen: create its three empty buckets.
        buckets = {'contradiction': set(), 'entailment': set(), 'neutral': set()}
        train_data[sent1] = buckets
    buckets[label].add(sent2)

# Register every JSNLI training pair in both directions, so that each sentence
# of a pair becomes an anchor key in `train_data`.
for row in jsnli["train"]:
    sent1 = row['premise'].strip()
    sent2 = row['hypothesis'].strip()
    label_name = label2name[row['label']]  # resolve the class id once per row

    add_to_samples(sent1, sent2, label_name)
    add_to_samples(sent2, sent1, label_name)  # Also add the opposite

# Build three-text training examples for every anchor that has at least one
# 'entailment' and one 'contradiction' partner.
#
# Reproducibility: candidates are sorted and drawn from a dedicated seeded RNG.
# Sorting is required because the iteration order of a set of strings can differ
# between interpreter runs, so seeding alone would not make the samples stable.
TRAIN_SAMPLE_SEED = 42
sample_rng = random.Random(TRAIN_SAMPLE_SEED)

train_samples = []
for sent1, others in train_data.items():
    if len(others['entailment']) > 0 and len(others['contradiction']) > 0:
        # Materialise each candidate pool once per anchor.
        entailments = sorted(others['entailment'])
        contradictions = sorted(others['contradiction'])
        # Anchor first, then an entailment partner, then a contradiction partner.
        train_samples.append(InputExample(texts=[sent1, sample_rng.choice(entailments), sample_rng.choice(contradictions)]))
        # Same idea with the entailment partner placed first.
        train_samples.append(InputExample(texts=[sample_rng.choice(entailments), sent1, sample_rng.choice(contradictions)]))
"Train samples: {}".format(len(train_samples))

# Validation Dataset
* JSTS (JGLUE v1.1) train split, used here as the development set

In [None]:
# Read the JSTS dataset and use it as the development set.
jsts_url = "https://raw.githubusercontent.com/yahoojapan/JGLUE/main/datasets/jsts-v1.1/train-v1.1.json"

# Download the file once and close the HTTP response when done. Each non-blank
# line is parsed as its own JSON object.
with urlopen(jsts_url) as response:
    jsts_rows = [json.loads(line) for line in response.readlines() if line.strip()]

# Tabular view of the same rows, kept for inspection.
jsts = pd.DataFrame(jsts_rows)

dev_samples = [
    InputExample(
        texts=[row['sentence1'], row['sentence2']],
        label=float(row['label']) / 5.0,  # Normalize score to range 0 ... 1
    )
    for row in jsts_rows
]

dev_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, batch_size=train_batch_size, name='sts-train')

# Train

In [None]:
# Our training loss
train_loss = losses.MultipleNegativesRankingLoss(model)

# Special data loader that avoids duplicates within a batch
train_dataloader = datasets.NoDuplicatesDataLoader(train_samples, batch_size=train_batch_size)

# Configure the training: warm up over the first 10% of all training steps
total_training_steps = len(train_dataloader) * num_epochs
warmup_steps = math.ceil(total_training_steps * 0.1)
"Warmup-steps: {}".format(warmup_steps)

In [None]:
# Evaluate on the dev set every 10% of the batches in one pass over the loader.
eval_every = int(len(train_dataloader) * 0.1)

# Train the model
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    evaluator=dev_evaluator,
    evaluation_steps=eval_every,
    output_path=model_save_path,
    use_amp=False,
)