### Inference

<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->

In [None]:
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import classification_report
from tqdm import tqdm
from transformers import AutoConfig, AutoModel, AutoTokenizer

from en_grammar_checker.config import Config
from en_grammar_checker.trainer import MyLightningClassifierModel

In [None]:
# Register tqdm with pandas so that Series gain `progress_apply`, which the
# prediction cells in this notebook use to show a progress bar.
tqdm.pandas()

In [None]:
# Project configuration. This notebook reads `trained_model_path`,
# `base_model_name`, `context_length` and `base_dataset_path` from it.
cnfg = Config()

In [None]:
# model = torch.load(cnfg.trained_model_path)

In [None]:
# model.keys()

In [None]:
# Restore the Lightning classifier module from the checkpoint file located at
# `cnfg.trained_model_path`.
pl_module = MyLightningClassifierModel.load_from_checkpoint(cnfg.trained_model_path)



In [None]:
# Load the tokenizer that belongs to the restored model's backbone. The name is
# read from the backbone's own config through the public `name_or_path`
# attribute (not the private `_name_or_path` field), so the tokenizer always
# matches the checkpoint. `cnfg.base_model_name` holds the same value and could
# be passed to `AutoTokenizer.from_pretrained` instead.
my_tokenizer = AutoTokenizer.from_pretrained(
    pl_module.model.base_model.config.name_or_path
)



In [None]:
# Backbone name recorded in the project config.
cnfg.base_model_name

'microsoft/deberta-v3-large'

In [None]:
# Backbone name stored in the restored model's own config; expected to equal
# `cnfg.base_model_name`.
pl_module.model.base_model.config._name_or_path

'microsoft/deberta-v3-large'

In [None]:
# Reuse the device the restored module lives on; encoded inputs are moved to
# this device before being passed to the model.
device = pl_module.device

In [None]:
# Put the module in evaluation mode before running inference.
pl_module = pl_module.eval()

In [None]:
def get_encoded_tensor(context_length, my_tokenizer, sentence, device):
    """Tokenize one sentence into fixed-length tensors placed on ``device``.

    Args:
        context_length: Length every encoding is padded or truncated to.
        my_tokenizer: Hugging Face tokenizer; it is called directly.
        sentence: Text to encode.
        device: Device the encoded tensors are moved to.

    Returns:
        A tuple ``(input_ids, attention_mask, tokens_dict)``. The first two
        entries are looked up in ``tokens_dict``, which is the complete
        encoding returned by the tokenizer after being moved to ``device``.
    """
    # Call the tokenizer itself: `encode_plus` is deprecated in favour of
    # `__call__`, which accepts the same keyword arguments.
    tokens_dict = my_tokenizer(
        sentence,
        add_special_tokens=True,  # add the model's special tokens
        max_length=context_length,  # pad and truncate to a fixed length
        truncation=True,
        padding="max_length",
        return_attention_mask=True,
        return_tensors="pt",  # PyTorch tensors
    ).to(device)
    return tokens_dict["input_ids"], tokens_dict["attention_mask"], tokens_dict

In [None]:
def infer_for_sentence(sentence, module, context_length, tokenizer, device):
    """Run the classifier on a single sentence.

    Args:
        sentence: Text to classify.
        module: Object whose ``model`` attribute is called with ``input_ids``
            and ``attention_mask``.
        context_length: Fixed encoding length passed to the tokenizer helper.
        tokenizer: Tokenizer passed to ``get_encoded_tensor``.
        device: Device the encoded inputs are moved to.

    Returns:
        A tuple ``(output, probs, label)``: the raw model output, its softmax
        along dim 1 (detached, on CPU), and the index of the largest
        probability as a Python int.
    """
    ids, mask, _ = get_encoded_tensor(context_length, tokenizer, sentence, device)

    # Gradients are not needed for inference.
    with torch.no_grad():
        output = module.model(input_ids=ids, attention_mask=mask)

    to_probabilities = nn.Softmax(dim=1)
    probs = to_probabilities(output).detach().cpu()
    # `argmax()` without `dim` works on the flattened tensor, which equals the
    # class index only when the batch holds a single sentence.
    label = probs.argmax().item()
    return output, probs, label

In [None]:
# Smoke-test the pipeline on one sentence. This sentence also appears in the
# out-of-domain dev set with label 0.
sentence = "She knew French for Tom."
output, probs, label = infer_for_sentence(
    sentence=sentence,
    module=pl_module,
    context_length=cnfg.context_length,
    tokenizer=my_tokenizer,
    device=device,
)

In [None]:
# Raw model output for the example sentence, before softmax.
output

tensor([[-1.3638,  0.9380]], device='cuda:0')

In [None]:
# Softmax of the output along dim 1, detached and moved to the CPU.
probs

tensor([[0.0910, 0.9090]])

In [None]:
# Index of the largest probability, as a Python int.
label

1

## Test set predictions

In [None]:
# Load the in-domain and out-of-domain dev sets. The TSV files carry no header
# row, so the four column names are supplied explicitly.
test_set_1, test_set_2 = (
    pd.read_csv(
        tsv_path,
        sep="\t",
        header=None,
        names=["sentence_source", "label", "label_notes", "sentence"],
    )
    for tsv_path in (
        f"{cnfg.base_dataset_path}in_domain_dev.tsv",
        f"{cnfg.base_dataset_path}out_of_domain_dev.tsv",
    )
)

In [None]:
# Preview the first rows of the in-domain dev set.
test_set_1.head()

Unnamed: 0,sentence_source,label,label_notes,sentence
0,gj04,1,,The sailors rode the breeze clear of the rocks.
1,gj04,1,,The weights made the rope stretch over the pul...
2,gj04,1,,The mechanical doll wriggled itself loose.
3,cj99,1,,"If you had eaten more, you would want less."
4,cj99,0,*,"As you eat the most, you want the least."


In [None]:
# Preview the first rows of the out-of-domain dev set.
test_set_2.head()

Unnamed: 0,sentence_source,label,label_notes,sentence
0,clc95,1,,Somebody just left - guess who.
1,clc95,1,,"They claimed they had settled on something, bu..."
2,clc95,1,,"If Sam was going, Sally would know where."
3,clc95,1,,"They're going to serve the guests something, b..."
4,clc95,1,,She's reading. I can't imagine what.


In [None]:
# (rows, columns) of the in-domain and out-of-domain dev sets.
test_set_1.shape, test_set_2.shape

((527, 4), (516, 4))

In [None]:
# Gold-label distribution of the in-domain dev set.
test_set_1.label.value_counts()

label
1    365
0    162
Name: count, dtype: int64

In [None]:
# Gold-label distribution of the out-of-domain dev set.
test_set_2.label.value_counts()

label
1    354
0    162
Name: count, dtype: int64

In [None]:
# Predict a label for every in-domain sentence, one sentence per model call.
# Only the third element (the label) of the inference result is kept.
test_set_1["pred"] = test_set_1["sentence"].progress_apply(
    lambda text: infer_for_sentence(
        sentence=text,
        module=pl_module,
        context_length=cnfg.context_length,
        tokenizer=my_tokenizer,
        device=device,
    )[2]
)

100%|█████████████████████████████████████████| 527/527 [00:25<00:00, 20.52it/s]


In [None]:
# Predict a label for every out-of-domain sentence, one sentence per model
# call. Only the third element (the label) of the inference result is kept.
test_set_2["pred"] = test_set_2["sentence"].progress_apply(
    lambda text: infer_for_sentence(
        sentence=text,
        module=pl_module,
        context_length=cnfg.context_length,
        tokenizer=my_tokenizer,
        device=device,
    )[2]
)

100%|█████████████████████████████████████████| 516/516 [00:25<00:00, 20.26it/s]


In [None]:
# Distribution of predicted labels on the in-domain dev set.
test_set_1.pred.value_counts()

pred
1    379
0    148
Name: count, dtype: int64

In [None]:
# Distribution of predicted labels on the out-of-domain dev set.
test_set_2.pred.value_counts()

pred
1    370
0    146
Name: count, dtype: int64

In [None]:
# Out-of-domain rows whose gold label is 0, shown next to the model's
# prediction for error inspection.
test_set_2[test_set_2.label == 0]

Unnamed: 0,sentence_source,label,label_notes,sentence,pred
6,clc95,0,*,John ate dinner but I don't know who.,1
7,clc95,0,*,"She mailed John a letter, but I don't know to ...",1
10,clc95,0,*,"She was bathing, but I couldn't make out who.",0
11,clc95,0,*,She knew French for Tom.,1
12,clc95,0,*,John is tall on several occasions.,1
...,...,...,...,...,...
493,w_80,0,*,It is to give up to leave.,0
495,w_80,0,*,It was believed to be illegal by them to do that.,1
501,w_80,0,*,I gave Pete the book to impress.,1
504,w_80,0,*,I presented Bill with it to read.,1


In [None]:
# Third evaluation set, read from the GED validation CSV. Later cells use its
# `input` and `labels` columns. The Excel test file below is an alternative
# source that is currently disabled.
# test_set_3 = pd.read_excel("../data/ged_data/test_data.xlsx")
test_set_3 = pd.read_csv("../data/ged_data/val_data.csv")

In [None]:
# Gold-label distribution of the third evaluation set.
test_set_3.labels.value_counts()

labels
0    5000
1    5000
Name: count, dtype: int64

In [None]:
# (rows, columns) of the third evaluation set.
test_set_3.shape

(10000, 2)

In [None]:
# Predict a label for every row of the third set, one sentence per model call.
# Bracket access is used for the `input` column; only the third element (the
# label) of the inference result is kept.
test_set_3["pred"] = test_set_3["input"].progress_apply(
    lambda text: infer_for_sentence(
        sentence=text,
        module=pl_module,
        context_length=cnfg.context_length,
        tokenizer=my_tokenizer,
        device=device,
    )[2]
)

100%|█████████████████████████████████████| 10000/10000 [08:12<00:00, 20.29it/s]


### Classification Reports

In [None]:
# Display names for the classification reports, ordered by label value:
# index 0 -> "Incorrect", index 1 -> "Correct".
class_name = ["Incorrect", "Correct"]

In [None]:
# Per-class precision / recall / F1 on the in-domain dev set.
print(
    classification_report(
        y_true=test_set_1["label"].values,
        y_pred=test_set_1["pred"].values,
        target_names=class_name,
    )
)

              precision    recall  f1-score   support

   Incorrect       0.76      0.70      0.73       162
     Correct       0.87      0.90      0.89       365

    accuracy                           0.84       527
   macro avg       0.82      0.80      0.81       527
weighted avg       0.84      0.84      0.84       527



In [None]:
# Per-class precision / recall / F1 on the out-of-domain dev set.
print(
    classification_report(
        y_true=test_set_2["label"].values,
        y_pred=test_set_2["pred"].values,
        target_names=class_name,
    )
)

              precision    recall  f1-score   support

   Incorrect       0.77      0.70      0.73       162
     Correct       0.87      0.91      0.89       354

    accuracy                           0.84       516
   macro avg       0.82      0.80      0.81       516
weighted avg       0.84      0.84      0.84       516



In [None]:
# Per-class precision / recall / F1 on the third evaluation set.
# NOTE(review): this reuses `class_name`, which assumes `labels` follows the
# same polarity as the dev sets (0 = incorrect, 1 = correct) - confirm against
# the dataset's documentation.
print(
    classification_report(
        y_true=test_set_3["labels"].values,
        y_pred=test_set_3["pred"].values,
        target_names=class_name,
    )
)

              precision    recall  f1-score   support

   Incorrect       0.62      0.45      0.52      5000
     Correct       0.57      0.73      0.64      5000

    accuracy                           0.59     10000
   macro avg       0.60      0.59      0.58     10000
weighted avg       0.60      0.59      0.58     10000

