In [13]:
import re
import torch
from nltk.tokenize import word_tokenize

from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
)

In [2]:
# Primary device: the first GPU when CUDA is usable, otherwise the CPU.
device_1 = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Secondary device: use cuda:1 only when a second GPU actually exists.
# `is_available()` is also true on single-GPU machines, where "cuda:1" would
# fail as soon as a model is moved there, so fall back to the primary device.
device_2 = torch.device("cuda:1") if torch.cuda.device_count() > 1 else device_1

In [6]:
# Checkpoint identifier handed to AutoTokenizer.from_pretrained.
MODEL_NAME = "sberbank-ai/ruRoberta-large"

In [7]:
# Load the tokenizer that matches MODEL_NAME.
# add_prefix_space=True: the tokenizer is later called with pre-split words
# (is_split_into_words=True); RoBERTa-style byte-level BPE tokenizers are
# documented to need this flag in that mode -- NOTE(review): confirm for the
# installed transformers version.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, add_prefix_space=True)

In [4]:
# Load the two fine-tuned token-classification models (aspect category and
# aspect sentiment) from disk.
#
# map_location="cpu": without it torch.load restores tensors onto the device
# they were saved from, which fails on a machine that lacks that GPU. The
# models are moved to their target devices explicitly afterwards.
#
# SECURITY: these files are unpickled, and unpickling can execute arbitrary
# code -- only load checkpoints from a trusted source.
# NOTE(review): recent PyTorch releases default torch.load to
# weights_only=True, which rejects whole-module pickles; if that applies to the
# installed version, pass weights_only=False here (trusted files only).
model_cat = torch.load("model_aspect_cat.pt", map_location="cpu")
model_sent = torch.load("model_aspect_sent.pt", map_location="cpu")

In [15]:
# Put each model on its own device so the two forward passes do not compete
# for the same accelerator memory.
# The result of .to() is bound to `_` so the notebook does not echo the
# (very long) module repr as the cell output.
_ = model_cat.to(device_1)
_ = model_sent.to(device_2)

In [9]:
def join_punctuation(seq, right_punct = '!$%\'),-.:;?', left_punct = '('):
    """Glue stand-alone punctuation tokens back onto their neighbouring tokens.

    Args:
        seq: iterable of string tokens (may be empty).
        right_punct: characters that, when they appear as a whole token, are
            appended to the token on their left.
        left_punct: characters that, when they are the whole accumulated
            token, have the following token appended to them.

    Yields:
        The merged tokens, in order. Nothing is yielded for empty input.
    """
    right_punct = set(right_punct)
    left_punct = set(left_punct)
    tokens = iter(seq)

    # An empty input must produce an empty stream. A bare next() here would
    # raise StopIteration inside the generator, which Python 3.7+ converts
    # into RuntimeError (PEP 479).
    try:
        current = next(tokens)
    except StopIteration:
        return

    for nxt in tokens:
        if nxt in right_punct or current in left_punct:
            # Closing punctuation attaches to the left; an opening bracket
            # swallows whatever follows it.
            current += nxt
        else:
            yield current
            current = nxt

    yield current

In [23]:
# Class index -> aspect category name. Index 0 is the "O" (no aspect) tag and
# every category occupies two consecutive indices -- presumably the B-/I-
# variants of the tag; TODO confirm against the training label map.
int_to_cat = ["O"] + [
    category
    for category in ("Food", "Interior", "Price", "Whole", "Service")
    for _ in range(2)
]

# Class index -> sentiment label.
int_to_sent = "positive negative neutral both".split()

def _first_subtoken_labels(model, tokenized, word_ids, labels):
    """Run `model` once and return one label per word (from its first sub-token)."""
    # Send the inputs to wherever the model's weights live instead of relying
    # on notebook-level device globals.
    device = next(model.parameters()).device
    model.eval()
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(
            tokenized["input_ids"].to(device),
            attention_mask=tokenized["attention_mask"].to(device),
        ).logits
    # One host transfer for the whole sequence rather than .item() per token.
    predicted = logits.argmax(dim=2)[0].tolist()

    word_labels = []
    prev_word = None
    for position, word_id in enumerate(word_ids):
        # word_id is None for special tokens; repeated ids are continuation
        # sub-tokens of a word that has already been labelled.
        if word_id is not None and word_id != prev_word:
            word_labels.append(labels[predicted[position]])
        prev_word = word_id
    return word_labels


def inference(review, model_cat, model_sent, tokenizer, idx):
    """Extract aspect spans with category and sentiment from one review.

    The review is split with `word_tokenize`, both models are run on the same
    tokenization, and consecutive words sharing a non-"O" category are merged
    into one span. Each span is then located in the original review text to
    obtain character offsets. Label names come from the module-level
    `int_to_cat` and `int_to_sent` tables.

    Args:
        review: raw review text.
        model_cat: token-classification model predicting aspect categories.
        model_sent: token-classification model predicting sentiments.
        tokenizer: tokenizer supporting `is_split_into_words=True` and
            `word_ids()`.
        idx: review identifier, copied verbatim into every output row.

    Returns:
        A list of rows `[idx, category, phrase, start, end, sentiment]` where
        `start`/`end` are character offsets into `review`, rendered as
        strings. Spans whose reconstructed phrase cannot be found in the
        review are omitted. An empty review yields an empty list.
    """
    tokens = word_tokenize(review)
    if not tokens:
        return []

    tokenized = tokenizer(tokens, is_split_into_words=True, return_tensors="pt")
    word_ids = tokenized.word_ids()

    categories = _first_subtoken_labels(model_cat, tokenized, word_ids, int_to_cat)
    sentiments = _first_subtoken_labels(model_sent, tokenized, word_ids, int_to_sent)

    # Merge runs of adjacent words with the same category; a span keeps the
    # sentiment predicted for its first word.
    spans = []
    prev_tag = None
    for token, tag, sent in zip(tokens, categories, sentiments):
        if tag != "O":
            if prev_tag == tag:
                spans[-1][0].append(token)
            else:
                spans.append([[token], tag, sent])
        prev_tag = tag

    output = []
    cursor = 0  # only search after the previous match so offsets stay ordered
    for words, tag, sent in spans:
        phrase = ' '.join(join_punctuation(words))
        # Literal search: the phrase is text, not a pattern. Passing it to
        # re.search unescaped mis-matches or raises on '(', '?', '.', '+', ...
        start = review.find(phrase, cursor)
        if start == -1:
            continue
        end = start + len(phrase)
        output.append([idx, tag, phrase, str(start), str(end), sent])
        cursor = end
    return output

In [27]:
# Input: one review per line as "<id>\t<text>".
# Output: one predicted aspect per line, tab-separated fields as returned by
# inference().
reviews_filename = "data/dev_reviews.txt"
aspect_filename = "data/dev_pred_aspects.txt"

# Open the input first so a missing reviews file does not truncate an
# existing predictions file.
with open(reviews_filename, encoding="utf-8") as file_read, \
        open(aspect_filename, "w", encoding="utf-8") as file_write:
    for line in file_read:
        # Strip only the newline; slicing off the last character would eat
        # real text when the final line has no trailing newline.
        line = line.rstrip("\n")
        if not line:
            continue  # tolerate blank lines
        # Split on the first tab only so tabs inside the review text survive.
        idx, review = line.split("\t", 1)
        res = inference(review, model_cat, model_sent, tokenizer, idx)
        for r in res:
            file_write.write("\t".join(r) + "\n")