In [1]:
%load_ext autoreload
%autoreload 2

from fastai.text.all import *
from fastai.vision.all import *
import pandas as pd
import torch
from tqdm.notebook import tqdm

from utils import get_dls

  return torch._C._cuda_getDeviceCount() > 0


In [2]:
import tensorflow as tf
from tensorflow.keras.models import model_from_json
from tensorflow.keras.layers import Input
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.utils import to_categorical
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report
from os.path import join, split, splitext

In [3]:
# Seed every RNG used in this notebook (Python, NumPy, PyTorch, TensorFlow)
# so that Restart & Run All reproduces the same results.
seed = 42

# python RNG
import random
random.seed(seed)

# pytorch RNGs
import torch
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
# Deterministic cuDNN kernels alone are not enough: cuDNN's benchmark-mode
# autotuner may pick different algorithms between runs, so switch it off too.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# numpy RNG
import numpy as np
np.random.seed(seed)

# tensorflow RNG (global seed; `tf` is imported in an earlier cell)
tf.random.set_seed(seed)

In [5]:
# Length every tokenized document is padded to (pad_sequences maxlen) before
# it is fed to the Keras text model.
SEQUENCE_LEN = 500

In [6]:
# Saved artifacts of the Keras text model and its tokenizer.
models_path = Path("models")
json_path = models_path / "cnn_text.json"            # architecture, read by model_from_json
weights_path = models_path / "stf_no_weights.keras"  # weights, read by load_weights
tokenizer_path = models_path / "tokenizer.pickle"    # pickled fitted tokenizer

In [7]:
# Directory holding the CSV exports (text + labels) on the NAS backup.
data_path = Path("/mnt/nas/backups/08-07-2020/desktopg01/lisa/Data") / "CSV"

In [8]:
# Test split: one row per document page, with `body`, `document_type`,
# `file_name` and `pages` columns used below.
test_csv = data_path / "test_small.csv"
test_data = pd.read_csv(test_csv)

In [9]:
# Restore the fitted tokenizer saved alongside the text model.
# NOTE(security): unpickling can execute arbitrary code — only load
# tokenizer files produced by a trusted source.
tokenizer = pickle.loads(tokenizer_path.read_bytes(), encoding="utf-8")

In [10]:
# Turn each document body into a list of token ids with the restored tokenizer.
body_texts = test_data['body']
sequences_test = tokenizer.texts_to_sequences(body_texts)

In [11]:
# Bring every id sequence to SEQUENCE_LEN (padding at the end, 'post')
# so they stack into a single input array.
X_test = sequence.pad_sequences(
    sequences_test,
    maxlen=SEQUENCE_LEN,
    padding='post',
)

In [12]:
# Maps the string `document_type` labels to integer class ids.
encoder = (
    LabelEncoder()
)

In [13]:
# Encode the string document types as integer class ids.
# Fit on the full, fixed class list instead of on the test labels themselves:
# LabelEncoder numbers classes by the sorted order of what it was fitted on, so a
# test split missing a class would silently shift every later id and no longer
# line up with the hard-coded `target_names` used in the reports. Labels outside
# this list now raise a ValueError instead of being encoded.
class_names = ['acordao_de_2_instancia', 'agravo_em_recurso_extraordinario',
               'despacho_de_admissibilidade', 'outros', 'peticao_do_RE', 'sentenca']
test_label = test_data['document_type']
encoder.fit(class_names)
test_label_toTest = encoder.transform(test_label)

# One-hot targets; num_classes is pinned so the width never depends on
# which classes happen to occur in the test split.
test_label = to_categorical(test_label_toTest, num_classes=len(class_names))

X_test = np.array(X_test)

In [14]:
# Rebuild the Keras text model from its JSON architecture plus saved weights
# and score the padded test sequences, pinned to the CPU.
# read_text() opens, reads and closes the file in one call, so the handle
# cannot leak if reading fails (the manual open/close pair could).
loaded_model_json = json_path.read_text()
with tf.device('/cpu:0'):
    model = model_from_json(loaded_model_json)
    model.load_weights(weights_path)
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    preds = model.predict(X_test, verbose=1)



In [15]:
# Collapse the text model's per-class scores to one predicted class id per row.
preds_text = np.argmax(preds, axis=1)

In [16]:
# Sanity check before concatenating with the image model's results:
# exactly one text prediction per encoded text label.
assert len(preds_text) == len(test_label_toTest), "text predictions and labels are misaligned"
preds_text.shape

(95526,)

In [17]:
# Root of the page-image dataset used by the image model.
path = Path("/mnt/nas/backups/08-07-2020/desktopg01/lisa/Data") / "small_flow"

In [18]:
# All page images under the `test` folder of the image dataset.
test_items = get_image_files(path, recurse=True, folders="test")

In [19]:
# Build "<file name without its 4-char extension>_<page>" keys for every page
# that has a row in the text CSV, to match against image file stems.
page_keys = test_data["file_name"].str[:-4] + "_" + test_data["pages"].astype(str)
text_files = set(page_keys)

In [20]:
# Keep only the page images with no matching text row; these pages are
# classified by the image model further down.
test_items_filtered = [
    img_path
    for img_path in test_items
    if img_path.with_suffix("").name not in text_files
]

In [21]:
# Number of image-only pages the image model will score.
n_image_only_pages = len(test_items_filtered)
n_image_only_pages

8037

In [22]:
# get_dls lives in utils (not shown). 64 and 224 are presumably the batch size
# and image size (224 matches the checkpoint names below) — confirm in utils.get_dls.
bs, img_size = 64, 224
dls = get_dls(path, bs, img_size)

In [23]:
# Test loader over the image-only pages; labels are kept so get_preds
# also returns the ground-truth targets.
test_dl = dls.test_dl(
    test_items_filtered,
    with_labels=True,
)

In [24]:
# ResNet-50 image classifier; weights are loaded from checkpoints below.
# NOTE(review): newer fastai releases rename cnn_learner to vision_learner.
learn = cnn_learner(
    dls,
    resnet50,
    loss_func=CrossEntropyLossFlat(),
)

In [25]:
# First image checkpoint to evaluate.
image_ckpt = "best_image_weights_224"
learn.load(image_ckpt)

<fastai.learner.Learner at 0x7ff9048cd490>

In [26]:
# Per-class scores and ground-truth targets for the image-only pages.
(preds_img, labels_img) = learn.get_preds(dl=test_dl)

In [27]:
# Collapse per-class scores to class ids. Guarded on ndim so re-running this
# cell does not argmax the already-1D tensor into a single scalar.
if preds_img.ndim == 2:
    preds_img = preds_img.argmax(dim=-1)
preds_img.shape

torch.Size([8037])

In [28]:
# Stack text-model and image-model predictions into one vector for the combined report.
# NOTE: this rebinds `preds`, which until now held the Keras score matrix, so the
# text argmax cell cannot be re-run after this one.
preds = np.hstack([preds_text, preds_img])
preds.shape

(103563,)

In [29]:
# Ground truth in the same order as `preds`: text rows first, then image-only pages.
labels = np.hstack([test_label_toTest, labels_img])
labels.shape

(103563,)

In [30]:
target_names = ['acordao_de_2_instancia','agravo_em_recurso_extraordinario', 'despacho_de_admissibilidade', 'outros', 'peticao_do_RE', 'sentenca']
# Combined (text + image) report. `labels` pins all six class ids, so the rows stay
# aligned with target_names even if a class is absent from both labels and preds.
# zero_division=0 reports ill-defined precision/recall as 0 without the warning.
print(classification_report(
    labels, preds,
    labels=list(range(len(target_names))),
    target_names=target_names,
    digits=4,
    zero_division=0,
))

                                  precision    recall  f1-score   support

          acordao_de_2_instancia     0.2692    0.8920    0.4136       287
agravo_em_recurso_extraordinario     0.4408    0.5522    0.4902      2655
     despacho_de_admissibilidade     0.3543    0.5377    0.4271       199
                          outros     0.9655    0.9507    0.9580     92533
                   peticao_do_RE     0.7144    0.7278    0.7211      6386
                        sentenca     0.7800    0.7053    0.7407      1503

                        accuracy                         0.9222    103563
                       macro avg     0.5874    0.7276    0.6251    103563
                    weighted avg     0.9307    0.9222    0.9258    103563



In [31]:
# Image-model-only report (pages without a text row). `labels` pins all six class ids
# so rows stay aligned with target_names; zero_division=0 reports ill-defined
# precision/recall as 0 without an UndefinedMetricWarning.
print(classification_report(
    labels_img, preds_img,
    labels=list(range(len(target_names))),
    target_names=target_names,
    digits=4,
    zero_division=0,
))

                                  precision    recall  f1-score   support

          acordao_de_2_instancia     0.0204    1.0000    0.0400        14
agravo_em_recurso_extraordinario     0.2910    0.7654    0.4217       814
     despacho_de_admissibilidade     0.0000    0.0000    0.0000         1
                          outros     0.9729    0.5846    0.7303      7125
                   peticao_do_RE     0.0143    0.1455    0.0260        55
                        sentenca     0.0144    0.1071    0.0253        28

                        accuracy                         0.5989      8037
                       macro avg     0.2188    0.4338    0.2072      8037
                    weighted avg     0.8922    0.5989    0.6905      8037



In [32]:
# Text-model-only report. `labels` pins all six class ids so rows stay aligned
# with target_names; zero_division=0 reports ill-defined precision/recall as 0
# without an UndefinedMetricWarning.
print(classification_report(
    test_label_toTest, preds_text,
    labels=list(range(len(target_names))),
    target_names=target_names,
    digits=4,
    zero_division=0,
))

                                  precision    recall  f1-score   support

          acordao_de_2_instancia     0.9132    0.8864    0.8996       273
agravo_em_recurso_extraordinario     0.7114    0.4579    0.5572      1841
     despacho_de_admissibilidade     0.7535    0.5404    0.6294       198
                          outros     0.9651    0.9813    0.9731     85408
                   peticao_do_RE     0.7804    0.7329    0.7559      6331
                        sentenca     0.9191    0.7166    0.8053      1475

                        accuracy                         0.9494     95526
                       macro avg     0.8405    0.7193    0.7701     95526
                    weighted avg     0.9467    0.9494    0.9472     95526



In [33]:
# Second image checkpoint; the evaluation cells that follow repeat the ones above.
no_weights_ckpt = "img_model_no_weights/best_image_no_weights_224"
learn.load(no_weights_ckpt)

<fastai.learner.Learner at 0x7ff9048cd490>

In [34]:
# Per-class scores and ground-truth targets from the second checkpoint.
(preds_img, labels_img) = learn.get_preds(dl=test_dl)

In [35]:
# Collapse per-class scores to class ids. Guarded on ndim so re-running this
# cell does not argmax the already-1D tensor into a single scalar.
if preds_img.ndim == 2:
    preds_img = preds_img.argmax(dim=-1)
preds_img.shape

torch.Size([8037])

In [36]:
# Re-stack text predictions with the second checkpoint's image predictions.
preds = np.hstack([preds_text, preds_img])
preds.shape

(103563,)

In [37]:
# Ground truth in the same order as `preds`: text rows first, then image-only pages.
labels = np.hstack([test_label_toTest, labels_img])
labels.shape

(103563,)

In [38]:
target_names = ['acordao_de_2_instancia','agravo_em_recurso_extraordinario', 'despacho_de_admissibilidade', 'outros', 'peticao_do_RE', 'sentenca']
# Combined (text + image) report for the second checkpoint. `labels` pins all six
# class ids so rows stay aligned with target_names; zero_division=0 reports
# ill-defined precision/recall as 0 without an UndefinedMetricWarning.
print(classification_report(
    labels, preds,
    labels=list(range(len(target_names))),
    target_names=target_names,
    digits=4,
    zero_division=0,
))

                                  precision    recall  f1-score   support

          acordao_de_2_instancia     0.9132    0.8432    0.8768       287
agravo_em_recurso_extraordinario     0.7114    0.3175    0.4391      2655
     despacho_de_admissibilidade     0.7279    0.5377    0.6185       199
                          outros     0.9585    0.9821    0.9702     92533
                   peticao_do_RE     0.7742    0.7274    0.7500      6386
                        sentenca     0.9191    0.7033    0.7968      1503

                        accuracy                         0.9441    103563
                       macro avg     0.8340    0.6852    0.7419    103563
                    weighted avg     0.9396    0.9441    0.9395    103563



In [39]:
# Image-model-only report for the second checkpoint. `labels` pins all six class ids
# so rows stay aligned with target_names; zero_division=0 reports ill-defined
# precision/recall as 0 without an UndefinedMetricWarning.
print(classification_report(
    labels_img, preds_img,
    labels=list(range(len(target_names))),
    target_names=target_names,
    digits=4,
    zero_division=0,
))

                                  precision    recall  f1-score   support

          acordao_de_2_instancia     0.0000    0.0000    0.0000        14
agravo_em_recurso_extraordinario     0.0000    0.0000    0.0000       814
     despacho_de_admissibilidade     0.0000    0.0000    0.0000         1
                          outros     0.8863    0.9924    0.9364      7125
                   peticao_do_RE     0.0926    0.0909    0.0917        55
                        sentenca     0.0000    0.0000    0.0000        28

                        accuracy                         0.8804      8037
                       macro avg     0.1632    0.1806    0.1714      8037
                    weighted avg     0.7864    0.8804    0.8307      8037



  _warn_prf(average, modifier, msg_start, len(result))
