In [42]:
%load_ext autoreload
%autoreload 2

from fastai.text.all import *
from fastai.vision.all import *
import pandas as pd
import torch
from tqdm.notebook import tqdm
from sklearn.metrics import classification_report, accuracy_score, f1_score

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [43]:
import tensorflow as tf
from tensorflow.keras.models import model_from_json
from tensorflow.keras.layers import Input
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.utils import to_categorical
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report
from os.path import join, split, splitext
from pathlib import Path

import pandas as pd
import pickle

import tqdm

In [44]:
seed = 42

# python RNG
import random
random.seed(seed)

# pytorch RNGs
import torch
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)

# numpy RNG
import numpy as np
np.random.seed(seed)

In [45]:
OUT_DIM=6

In [46]:
torch.cuda.set_device(1)

In [47]:
SEQUENCE_LEN = 500 # Size of input arrays

In [48]:
models_path = Path("./models/")
weights_path = models_path/"stf_no_weights.keras"
json_path = models_path/"cnn_text.json"
tokenizer_path = models_path/"tokenizer.pickle"

In [49]:
json_file = open(json_path,'r')
loaded_model_json = json_file.read()
json_file.close()
model = model_from_json(loaded_model_json)

In [50]:
model.load_weights(weights_path)

In [51]:
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

In [52]:
data_path = Path("/mnt/nas/backups/08-07-2020/desktopg01/lisa/Data/CSV")

In [53]:
val = pd.read_csv(data_path/"validation_small.csv")
test_data = pd.read_csv(data_path/"test_small.csv")

In [54]:
# file with no images ARE_1072325_312635894_1324_05092017.pdf
test_data_no_img = test_data[(test_data["file_name"]=="ARE_1072325_312635894_1324_05092017.pdf")]
test_data = test_data[~(test_data["file_name"]=="ARE_1072325_312635894_1324_05092017.pdf")]

In [55]:
test_data = pd.concat([test_data, test_data_no_img])

In [56]:
with open(tokenizer_path, 'rb') as handle:
    tokenizer = pickle.load(handle, encoding="utf-8")

In [57]:
sequences_validation = tokenizer.texts_to_sequences(val['body'])
sequences_test = tokenizer.texts_to_sequences(test_data['body'])

In [58]:
X_val = sequence.pad_sequences(sequences_validation, maxlen=SEQUENCE_LEN, padding='post')
X_test = sequence.pad_sequences(sequences_test, maxlen=SEQUENCE_LEN, padding='post')

In [59]:
encoder = LabelEncoder()

In [60]:
valid_label = val['document_type'] 
valid_label_toTest = encoder.fit_transform(valid_label)
valid_label = np.transpose(valid_label_toTest)
valid_label = to_categorical(valid_label)

test_label = test_data['document_type'] 
test_label_toTest = encoder.transform(test_label)
test_label = np.transpose(test_label_toTest)
test_label = to_categorical(test_label)

X_val = np.array(X_val)
X_test = np.array(X_test)

In [61]:
text_probs_val = model.predict(X_val, verbose=1)



In [62]:
pred = text_probs_val.argmax(axis=1)

target_names = ['acordao_de_2_instancia','agravo_em_recurso_extraordinario', 'despacho_de_admissibilidade', 'outros', 'peticao_do_RE', 'sentenca']
print(classification_report(valid_label_toTest, pred, target_names=target_names, digits=4))

                                  precision    recall  f1-score   support

          acordao_de_2_instancia     0.9116    0.7592    0.8285       299
agravo_em_recurso_extraordinario     0.7504    0.4742    0.5811      2149
     despacho_de_admissibilidade     0.7727    0.6503    0.7062       183
                          outros     0.9629    0.9797    0.9712     84104
                   peticao_do_RE     0.7645    0.7456    0.7549      6364
                        sentenca     0.9285    0.6822    0.7865      1636

                        accuracy                         0.9460     94735
                       macro avg     0.8484    0.7152    0.7714     94735
                    weighted avg     0.9437    0.9460    0.9437     94735



In [63]:
text_probs_test = model.predict(X_test, verbose=1)



In [64]:
pred = text_probs_test.argmax(axis=1)
print(classification_report(test_label_toTest, pred, target_names=target_names, digits=4))

                                  precision    recall  f1-score   support

          acordao_de_2_instancia     0.9132    0.8864    0.8996       273
agravo_em_recurso_extraordinario     0.7114    0.4579    0.5572      1841
     despacho_de_admissibilidade     0.7535    0.5404    0.6294       198
                          outros     0.9651    0.9813    0.9731     85408
                   peticao_do_RE     0.7804    0.7329    0.7559      6331
                        sentenca     0.9191    0.7166    0.8053      1475

                        accuracy                         0.9494     95526
                       macro avg     0.8405    0.7193    0.7701     95526
                    weighted avg     0.9467    0.9494    0.9472     95526



In [65]:
path = Path("/mnt/nas/backups/08-07-2020/desktopg01/lisa/Data/small_flow")

In [66]:
val["split"]= "val"

In [67]:
val["path"] = "val/" + val["document_type"] + "/" + val["file_name"].str.rstrip(".pdf") + "_" + val["pages"].astype(str) + ".jpg"

In [68]:
test_data["split"] = "test"
test_data["path"] = "test/" + test_data["document_type"] + "/" + test_data["file_name"].str.rstrip(".pdf") + "_" + test_data["pages"].astype(str) + ".jpg"

In [69]:
df = pd.concat([val, test_data], axis=0)

In [70]:
 df.reset_index(drop=True, inplace=True)

In [71]:
assert df["document_type"].tolist() == val["document_type"].tolist() + test_data["document_type"].tolist()

In [72]:
def splitter(df):
    valid = df[df['split']=='val'].index.tolist()
    test = df[df['split']=='test'].index.tolist()
    return valid,valid, test

In [73]:
def get_x(r): return path/f'{r["path"]}'
def get_y(r): return r['document_type']

In [74]:
dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   get_x=get_x,
                   get_y=get_y,
                   splitter=splitter,
                   item_tfms=Resize(460),
                   batch_tfms=[*aug_transforms(size=224, min_scale=0.75,
                                               do_flip=False, max_rotate=0,
                                               max_warp=0
                                               ),
                               Normalize.from_stats(*imagenet_stats)])

In [75]:
dls = dblock.dataloaders(df.iloc[:-4], bs=64)

In [76]:
dls.vocab

['acordao_de_2_instancia', 'agravo_em_recurso_extraordinario', 'despacho_de_admissibilidade', 'outros', 'peticao_do_RE', 'sentenca']

In [77]:
learn = cnn_learner(dls, resnet50, loss_func=CrossEntropyLossFlat())

In [397]:
learn.load("./img_model_no_weights/best_image_no_weights_224")

<fastai.learner.Learner at 0x7f395c21bdf0>

In [398]:
img_probs_val, labels_val = learn.get_preds()

In [399]:
preds = np.argmax(img_probs_val, axis=1)
print(classification_report(labels_val, preds, target_names=dls.vocab, digits=4))

                                  precision    recall  f1-score   support

          acordao_de_2_instancia     0.6667    0.0134    0.0262       299
agravo_em_recurso_extraordinario     0.0000    0.0000    0.0000      2149
     despacho_de_admissibilidade     0.0000    0.0000    0.0000       183
                          outros     0.9058    0.9955    0.9485     84104
                   peticao_do_RE     0.8092    0.2120    0.3359      6364
                        sentenca     0.9350    0.3429    0.5018      1636

                        accuracy                         0.9040     94735
                       macro avg     0.5528    0.2606    0.3021     94735
                    weighted avg     0.8768    0.9040    0.8734     94735



  _warn_prf(average, modifier, msg_start, len(result))


In [400]:
img_probs_test, labels_test = learn.get_preds(ds_idx=2)

In [401]:
dls[2].items.document_type.value_counts(), np.bincount(labels_test), np.bincount(np.argmax(test_label,axis=1))

(outros                              85404
 peticao_do_RE                        6331
 agravo_em_recurso_extraordinario     1841
 sentenca                             1475
 acordao_de_2_instancia                273
 despacho_de_admissibilidade           198
 Name: document_type, dtype: int64,
 array([  273,  1841,   198, 85404,  6331,  1475]),
 array([  273,  1841,   198, 85408,  6331,  1475]))

In [402]:
preds = np.argmax(img_probs_test, axis=1)
print(classification_report(labels_test, preds, target_names=dls.vocab, digits=4))

                                  precision    recall  f1-score   support

          acordao_de_2_instancia     1.0000    0.0366    0.0707       273
agravo_em_recurso_extraordinario     0.0000    0.0000    0.0000      1841
     despacho_de_admissibilidade     0.0000    0.0000    0.0000       198
                          outros     0.9116    0.9961    0.9520     85404
                   peticao_do_RE     0.8408    0.2161    0.3438      6331
                        sentenca     0.9328    0.3295    0.4870      1475

                        accuracy                         0.9101     95522
                       macro avg     0.6142    0.2631    0.3089     95522
                    weighted avg     0.8880    0.9101    0.8817     95522



In [403]:
img_probs_val.shape, img_probs_test.shape, text_probs_val.shape, text_probs_test.shape

(torch.Size([94735, 6]), torch.Size([95522, 6]), (94735, 6), (95526, 6))

In [404]:
labels_val.shape, labels_test.shape, valid_label.shape, test_label.shape

(torch.Size([94735]), torch.Size([95522]), (94735, 6), (95526, 6))

In [405]:
(tensor(np.argmax(valid_label, axis=1)) == labels_val).all()

TensorCategory(True)

In [406]:
(tensor(np.argmax(test_label, axis=1))[:-4] == labels_test).all()

TensorCategory(True)

In [409]:
labels_test = torch.cat([labels_test, tensor(3,3,3,3)])

In [410]:
assert (tensor(np.argmax(test_label, axis=1)) == labels_test).all()

In [172]:
class_priors = tensor(val["document_type"].value_counts().sort_index().to_list())/len(val)

In [420]:
img_probs_test = torch.cat([img_probs_test,tensor([[1/6]*6]*4)])

In [78]:
def late_fusion(pred_image, pred_text, img_weight=1, text_weight=1):
    return pred_image**img_weight * pred_text**text_weight

In [79]:
def evaluate(preds_image, preds_text, targets, img_weight=1, text_weight=1):
    probs = late_fusion(preds_image, preds_text, img_weight, text_weight)
    preds = np.argmax(probs, axis=1)
    print(classification_report(targets, preds, target_names=dls.vocab, digits=4))

In [80]:
weights = [0, 1e-2, 3e-2, 5e-2, 7e-2, 1e-1, 3e-1, 5e-1, 7e-1, 1]

In [81]:
def cross_val(preds_image, preds_text, targets, weights):
    max_f1 = 0
    for img_weight in weights:
        for text_weight in weights:
            probs = late_fusion(preds_image, preds_text, img_weight, text_weight)
            preds = np.argmax(probs, axis=1)
            f1 = f1_score(preds,  targets, average="macro")
            if f1 > max_f1:
                max_f1 = f1
                best_img_weight = img_weight
                best_text_weight = text_weight
    return max_f1, best_img_weight, best_text_weight

In [426]:
max_f1, best_img_weight, best_text_weight = cross_val(img_probs_val, text_probs_val, labels_val, weights)

In [427]:
max_f1, best_img_weight, best_text_weight

(0.7714059568776966, 0, 0.01)

In [428]:
evaluate(img_probs_val, text_probs_val, labels_val, best_img_weight, best_text_weight)

                                  precision    recall  f1-score   support

          acordao_de_2_instancia     0.9116    0.7592    0.8285       299
agravo_em_recurso_extraordinario     0.7504    0.4742    0.5811      2149
     despacho_de_admissibilidade     0.7727    0.6503    0.7062       183
                          outros     0.9629    0.9797    0.9712     84104
                   peticao_do_RE     0.7645    0.7456    0.7549      6364
                        sentenca     0.9285    0.6822    0.7865      1636

                        accuracy                         0.9460     94735
                       macro avg     0.8484    0.7152    0.7714     94735
                    weighted avg     0.9437    0.9460    0.9437     94735



In [429]:
evaluate(img_probs_test, text_probs_test, labels_test, best_img_weight, best_text_weight)

                                  precision    recall  f1-score   support

          acordao_de_2_instancia     0.9132    0.8864    0.8996       273
agravo_em_recurso_extraordinario     0.7114    0.4579    0.5572      1841
     despacho_de_admissibilidade     0.7535    0.5404    0.6294       198
                          outros     0.9651    0.9813    0.9731     85408
                   peticao_do_RE     0.7804    0.7329    0.7559      6331
                        sentenca     0.9191    0.7166    0.8053      1475

                        accuracy                         0.9494     95526
                       macro avg     0.8405    0.7193    0.7701     95526
                    weighted avg     0.9467    0.9494    0.9472     95526



In [82]:
learn = cnn_learner(dls, resnet50, loss_func=CrossEntropyLossFlat())

In [83]:
learn.load("./best_image_weights_224")

<fastai.learner.Learner at 0x7fde36ab3400>

In [84]:
img_probs_val, labels_val = learn.get_preds()

In [85]:
preds = np.argmax(img_probs_val, axis=1)
print(classification_report(labels_val, preds, target_names=dls.vocab, digits=4))

                                  precision    recall  f1-score   support

          acordao_de_2_instancia     0.1272    0.8094    0.2199       299
agravo_em_recurso_extraordinario     0.0530    0.7222    0.0987      2149
     despacho_de_admissibilidade     0.0411    0.6885    0.0776       183
                          outros     0.9864    0.4429    0.6113     84104
                   peticao_do_RE     0.2196    0.6596    0.3295      6364
                        sentenca     0.3368    0.7372    0.4623      1636

                        accuracy                         0.4705     94735
                       macro avg     0.2940    0.6766    0.2999     94735
                    weighted avg     0.8979    0.4705    0.5759     94735



In [86]:
img_probs_test, labels_test = learn.get_preds(ds_idx=2)

In [191]:
preds = np.argmax(img_probs_test, axis=1)
print(classification_report(labels_test, preds, target_names=dls.vocab, digits=4))

                                  precision    recall  f1-score   support

          acordao_de_2_instancia     0.1276    0.9011    0.2235       273
agravo_em_recurso_extraordinario     0.0444    0.7279    0.0837      1841
     despacho_de_admissibilidade     0.0448    0.6717    0.0840       198
                          outros     0.9841    0.4389    0.6071     85408
                   peticao_do_RE     0.2167    0.6381    0.3236      6331
                        sentenca     0.3074    0.7756    0.4403      1475

                        accuracy                         0.4647     95526
                       macro avg     0.2875    0.6922    0.2937     95526
                    weighted avg     0.9003    0.4647    0.5735     95526



In [88]:
labels_test = torch.cat([labels_test, tensor(3,3,3,3)])

In [190]:
img_probs_test = torch.cat([img_probs_test, class_priors.expand(4, -1)]); img_probs_test.shape

torch.Size([95526, 6])

In [192]:
max_f1, best_img_weight, best_text_weight = cross_val(img_probs_val, text_probs_val, labels_val, weights)

In [193]:
max_f1, best_img_weight, best_text_weight

(0.7799706199972015, 0.07, 0.05)

In [194]:
evaluate(img_probs_val, text_probs_val, labels_val, best_img_weight, best_text_weight)

                                  precision    recall  f1-score   support

          acordao_de_2_instancia     0.9160    0.7659    0.8342       299
agravo_em_recurso_extraordinario     0.6535    0.5784    0.6137      2149
     despacho_de_admissibilidade     0.6587    0.7486    0.7008       183
                          outros     0.9664    0.9746    0.9705     84104
                   peticao_do_RE     0.7592    0.7530    0.7561      6364
                        sentenca     0.9288    0.7097    0.8046      1636

                        accuracy                         0.9450     94735
                       macro avg     0.8138    0.7550    0.7800     94735
                    weighted avg     0.9440    0.9450    0.9442     94735



In [195]:
evaluate(img_probs_test, text_probs_test, labels_test, best_img_weight, best_text_weight)

                                  precision    recall  f1-score   support

          acordao_de_2_instancia     0.9308    0.8864    0.9081       273
agravo_em_recurso_extraordinario     0.6155    0.5470    0.5792      1841
     despacho_de_admissibilidade     0.6667    0.6465    0.6564       198
                          outros     0.9679    0.9763    0.9721     85408
                   peticao_do_RE     0.7704    0.7421    0.7560      6331
                        sentenca     0.9173    0.7444    0.8219      1475

                        accuracy                         0.9480     95526
                       macro avg     0.8114    0.7571    0.7823     95526
                    weighted avg     0.9466    0.9480    0.9470     95526



In [100]:
just_imgs_df_trainVal = pd.read_csv("./data/just_img_train_val.csv",index_col=0)

In [102]:
just_imgs_df_test = pd.read_csv("./data/just_img_test.csv", index_col=0)

In [112]:
just_imgs_df_trainVal["activation_path"] = just_imgs_df_trainVal["activation_path"].str[16:] + ".jpg"

In [113]:
just_imgs_df_test["activation_path"] = just_imgs_df_test["activation_path"].str[16:] + ".jpg"

In [114]:
just_imgs_df_trainVal.head()

Unnamed: 0,themes,process_id,file_name,document_type,pages,body,activation_path,is_valid,has_text,has_image
0,,ARE_1053618,ARE_1053618_311977154_1420_07062017_299.pdf,agravo_em_recurso_extraordinario,299,,val/agravo_em_recurso_extraordinario/ARE_1053618_311977154_1420_07062017_299.jpg,True,False,True
1,,ARE_1053618,ARE_1053618_311977154_1420_07062017_288.pdf,agravo_em_recurso_extraordinario,288,,val/agravo_em_recurso_extraordinario/ARE_1053618_311977154_1420_07062017_288.jpg,True,False,True
2,,ARE_1053618,ARE_1053618_311977154_1420_07062017_248.pdf,agravo_em_recurso_extraordinario,248,,val/agravo_em_recurso_extraordinario/ARE_1053618_311977154_1420_07062017_248.jpg,True,False,True
3,,ARE_1064575,ARE_1064575_312350997_1420_03082017_54.pdf,agravo_em_recurso_extraordinario,54,,val/agravo_em_recurso_extraordinario/ARE_1064575_312350997_1420_03082017_54.jpg,True,False,True
4,,ARE_1073427,ARE_1073427_312673900_1420_05092017_41.pdf,agravo_em_recurso_extraordinario,41,,val/agravo_em_recurso_extraordinario/ARE_1073427_312673900_1420_05092017_41.jpg,True,False,True


In [117]:
just_imgs_df_val = just_imgs_df_trainVal[just_imgs_df_trainVal["is_valid"]]

In [120]:
just_imgs_df_val["split"]= "val"

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  just_imgs_df_val["split"]= "val"


In [127]:
just_imgs_df_test["split"]= "test"

In [129]:
img_df = pd.concat([just_imgs_df_val, just_imgs_df_test], axis=0)

In [132]:
img_df.reset_index(drop=True, inplace=True)

In [133]:
img_df

Unnamed: 0,themes,process_id,file_name,document_type,pages,body,activation_path,is_valid,has_text,has_image,split
0,,ARE_1053618,ARE_1053618_311977154_1420_07062017_299.pdf,agravo_em_recurso_extraordinario,299,,val/agravo_em_recurso_extraordinario/ARE_1053618_311977154_1420_07062017_299.jpg,1.0,False,True,val
1,,ARE_1053618,ARE_1053618_311977154_1420_07062017_288.pdf,agravo_em_recurso_extraordinario,288,,val/agravo_em_recurso_extraordinario/ARE_1053618_311977154_1420_07062017_288.jpg,1.0,False,True,val
2,,ARE_1053618,ARE_1053618_311977154_1420_07062017_248.pdf,agravo_em_recurso_extraordinario,248,,val/agravo_em_recurso_extraordinario/ARE_1053618_311977154_1420_07062017_248.jpg,1.0,False,True,val
3,,ARE_1064575,ARE_1064575_312350997_1420_03082017_54.pdf,agravo_em_recurso_extraordinario,54,,val/agravo_em_recurso_extraordinario/ARE_1064575_312350997_1420_03082017_54.jpg,1.0,False,True,val
4,,ARE_1073427,ARE_1073427_312673900_1420_05092017_41.pdf,agravo_em_recurso_extraordinario,41,,val/agravo_em_recurso_extraordinario/ARE_1073427_312673900_1420_05092017_41.jpg,1.0,False,True,val
...,...,...,...,...,...,...,...,...,...,...,...
20874,,ARE_1130738,ARE_1130738_314300437_12_08052018_79.pdf,peticao_do_RE,79,,test/peticao_do_RE/ARE_1130738_314300437_12_08052018_79.jpg,,False,True,test
20875,,ARE_721616,ARE_721616_2218683_12_26072013_15.pdf,peticao_do_RE,15,,test/peticao_do_RE/ARE_721616_2218683_12_26072013_15.jpg,,False,True,test
20876,,ARE_1062698,ARE_1062698_312268390_12_27072017_22.pdf,peticao_do_RE,22,,test/peticao_do_RE/ARE_1062698_312268390_12_27072017_22.jpg,,False,True,test
20877,,ARE_1062698,ARE_1062698_312268390_12_27072017_19.pdf,peticao_do_RE,19,,test/peticao_do_RE/ARE_1062698_312268390_12_27072017_19.jpg,,False,True,test


In [134]:
assert img_df["document_type"].tolist() == just_imgs_df_val["document_type"].tolist() + just_imgs_df_test["document_type"].tolist()

In [135]:
img_df.rename(columns={"activation_path":"path"}, inplace=True)

In [137]:
dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   get_x=get_x,
                   get_y=get_y,
                   splitter=splitter,
                   item_tfms=Resize(460),
                   batch_tfms=[*aug_transforms(size=224, min_scale=0.75,
                                               do_flip=False, max_rotate=0,
                                               max_warp=0
                                               ),
                               Normalize.from_stats(*imagenet_stats)])

In [138]:
dls_img = dblock.dataloaders(img_df, bs=64)

In [141]:
learn = cnn_learner(dls_img, resnet50, loss_func=CrossEntropyLossFlat())

In [142]:
learn.load("./best_image_weights_224")

<fastai.learner.Learner at 0x7fdf2afb1b20>

In [143]:
just_img_probs_val, just_img_labels_val = learn.get_preds()

In [144]:
preds = np.argmax(just_img_probs_val, axis=1)
print(classification_report(just_img_labels_val, preds, target_names=dls.vocab, digits=4))

                                  precision    recall  f1-score   support

          acordao_de_2_instancia     0.0276    1.0000    0.0538        21
agravo_em_recurso_extraordinario     0.1643    0.7288    0.2681       649
     despacho_de_admissibilidade     0.0096    0.3333    0.0186         6
                          outros     0.9643    0.6242    0.7578     11498
                   peticao_do_RE     0.1245    0.2440    0.1649       623
                        sentenca     0.0091    0.0667    0.0160        45

                        accuracy                         0.6096     12842
                       macro avg     0.2166    0.4995    0.2132     12842
                    weighted avg     0.8778    0.6096    0.7002     12842



In [146]:
just_img_probs_test, just_img_labels_test = learn.get_preds(ds_idx=2)

In [147]:
preds = np.argmax(just_img_probs_test, axis=1)
print(classification_report(just_img_labels_test, preds, target_names=dls.vocab, digits=4))

                                  precision    recall  f1-score   support

          acordao_de_2_instancia     0.0204    1.0000    0.0400        14
agravo_em_recurso_extraordinario     0.2910    0.7654    0.4217       814
     despacho_de_admissibilidade     0.0000    0.0000    0.0000         1
                          outros     0.9729    0.5846    0.7303      7125
                   peticao_do_RE     0.0143    0.1455    0.0260        55
                        sentenca     0.0144    0.1071    0.0253        28

                        accuracy                         0.5989      8037
                       macro avg     0.2188    0.4338    0.2072      8037
                    weighted avg     0.8922    0.5989    0.6905      8037



In [148]:
all_img_probs_val = torch.cat([img_probs_val, just_img_probs_val]); all_img_probs_val.shape

torch.Size([107577, 6])

In [151]:
all_text_probs_val = torch.cat([tensor(text_probs_val), tensor([[1/6]*6]*12842)]); all_text_probs_val.shape

torch.Size([107577, 6])

In [152]:
all_labels_val = torch.cat([labels_val, just_img_labels_val]); all_labels_val.shape

torch.Size([107577])

In [153]:
evaluate(all_img_probs_val, all_text_probs_val, all_labels_val, best_img_weight, best_text_weight)

                                  precision    recall  f1-score   support

          acordao_de_2_instancia     0.2475    0.7812    0.3759       320
agravo_em_recurso_extraordinario     0.3589    0.6133    0.4528      2798
     despacho_de_admissibilidade     0.3333    0.7354    0.4587       189
                          outros     0.9663    0.9324    0.9490     95602
                   peticao_do_RE     0.6563    0.7076    0.6810      6987
                        sentenca     0.7367    0.6924    0.7139      1681

                        accuracy                         0.9050    107577
                       macro avg     0.5498    0.7437    0.6052    107577
                    weighted avg     0.9235    0.9050    0.9125    107577



In [154]:
all_img_probs_test = torch.cat([img_probs_test, just_img_probs_test]); all_img_probs_test.shape

torch.Size([103563, 6])

In [156]:
all_text_probs_test = torch.cat([tensor(text_probs_test), tensor([[1/6]*6]*8037)]); all_text_probs_test.shape

torch.Size([103563, 6])

In [157]:
all_labels_test = torch.cat([labels_test, just_img_labels_test]); all_labels_test.shape

torch.Size([103563])

In [158]:
evaluate(all_img_probs_test, all_text_probs_test, all_labels_test, best_img_weight, best_text_weight)

                                  precision    recall  f1-score   support

          acordao_de_2_instancia     0.2706    0.8920    0.4152       287
agravo_em_recurso_extraordinario     0.4316    0.6139    0.5068      2655
     despacho_de_admissibilidade     0.3636    0.6432    0.4646       199
                          outros     0.9682    0.9461    0.9570     92533
                   peticao_do_RE     0.7068    0.7369    0.7216      6386
                        sentenca     0.7831    0.7325    0.7570      1503

                        accuracy                         0.9209    103563
                       macro avg     0.5873    0.7608    0.6370    103563
                    weighted avg     0.9325    0.9209    0.9256    103563



In [179]:
all_text_probs_val = torch.cat([tensor(text_probs_val), class_priors.expand(12842,-1)]); all_text_probs_val.shape

torch.Size([107577, 6])

In [181]:
evaluate(all_img_probs_val, all_text_probs_val, all_labels_val, best_img_weight, best_text_weight)

                                  precision    recall  f1-score   support

          acordao_de_2_instancia     0.9160    0.7156    0.8035       320
agravo_em_recurso_extraordinario     0.6185    0.5250    0.5679      2798
     despacho_de_admissibilidade     0.6587    0.7249    0.6902       189
                          outros     0.9605    0.9733    0.9668     95602
                   peticao_do_RE     0.7387    0.6994    0.7185      6987
                        sentenca     0.9244    0.6907    0.7906      1681

                        accuracy                         0.9382    107577
                       macro avg     0.8028    0.7215    0.7563    107577
                    weighted avg     0.9360    0.9382    0.9366    107577



In [182]:
all_text_probs_test = torch.cat([tensor(text_probs_test), class_priors.expand(8037,-1)]); all_text_probs_test.shape

torch.Size([103563, 6])

In [183]:
evaluate(all_img_probs_test, all_text_probs_test, all_labels_test, best_img_weight, best_text_weight)

                                  precision    recall  f1-score   support

          acordao_de_2_instancia     0.9308    0.8432    0.8848       287
agravo_em_recurso_extraordinario     0.6203    0.4934    0.5496      2655
     despacho_de_admissibilidade     0.6632    0.6432    0.6531       199
                          outros     0.9642    0.9750    0.9696     92533
                   peticao_do_RE     0.7552    0.7366    0.7458      6386
                        sentenca     0.9150    0.7305    0.8124      1503

                        accuracy                         0.9434    103563
                       macro avg     0.8081    0.7370    0.7692    103563
                    weighted avg     0.9411    0.9434    0.9419    103563



In [180]:
val_preds =  late_fusion(img_probs_val, text_probs_val, img_weight=best_img_weight, text_weight=best_text_weight)

In [451]:
test_preds =  late_fusion(img_probs_test, text_probs_test, img_weight=best_img_weight, text_weight=best_text_weight)

In [470]:
train = pd.read_csv(data_path/"train_small.csv")

In [471]:
sequences_train = tokenizer.texts_to_sequences(train['body'])

In [472]:
X_train = sequence.pad_sequences(sequences_train, maxlen=SEQUENCE_LEN, padding='post')

In [473]:
train_label = train['document_type'] 
train_label_toTest = encoder.transform(train_label)
train_label = np.transpose(train_label_toTest)
train_label = to_categorical(train_label)

X_train = np.array(X_train)

In [474]:
text_probs_train = model.predict(X_train, verbose=1)



In [475]:
pred = text_probs_train.argmax(axis=1)

print(classification_report(train_label_toTest, pred, target_names=target_names, digits=4))

                                  precision    recall  f1-score   support

          acordao_de_2_instancia     0.9563    0.9892    0.9724       553
agravo_em_recurso_extraordinario     0.9790    0.8413    0.9049      2546
     despacho_de_admissibilidade     0.8943    0.8555    0.8744       346
                          outros     0.9950    0.9938    0.9944    134134
                   peticao_do_RE     0.9231    0.9741    0.9479      9509
                        sentenca     0.9853    0.9779    0.9816      2129

                        accuracy                         0.9894    149217
                       macro avg     0.9555    0.9386    0.9460    149217
                    weighted avg     0.9896    0.9894    0.9893    149217



In [476]:
path = Path("/mnt/nas/backups/08-07-2020/desktopg01/lisa/Data/small_flow")

In [478]:
train["path"] = "train/" + train["document_type"] + "/" + train["file_name"].str.rstrip(".pdf") + "_" + train["pages"].astype(str) + ".jpg"

In [495]:
train["split"]="val"

In [496]:
dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   get_x=get_x,
                   get_y=get_y,
                   splitter=splitter,
                   item_tfms=Resize(460),
                   batch_tfms=[*aug_transforms(size=224, min_scale=0.75,
                                               do_flip=False, max_rotate=0,
                                               max_warp=0
                                               ),
                               Normalize.from_stats(*imagenet_stats)])

In [497]:
dls = dblock.dataloaders(train, bs=64)

In [499]:
dls.vocab

['acordao_de_2_instancia', 'agravo_em_recurso_extraordinario', 'despacho_de_admissibilidade', 'outros', 'peticao_do_RE', 'sentenca']

In [507]:
learn = cnn_learner(dls, resnet50, loss_func=CrossEntropyLossFlat())

In [508]:
learn.load("./best_image_weights_224")

<fastai.learner.Learner at 0x7f39709a8df0>

In [509]:
img_probs_train, labels_train = learn.get_preds()

In [510]:
img_probs_train.shape,labels_train.shape

(torch.Size([149217, 6]), torch.Size([149217]))

In [511]:
preds = np.argmax(img_probs_train, axis=1)
print(classification_report(labels_train, preds, target_names=dls.vocab, digits=4))

                                  precision    recall  f1-score   support

          acordao_de_2_instancia     0.1600    0.7541    0.2639       553
agravo_em_recurso_extraordinario     0.0507    0.8425    0.0956      2546
     despacho_de_admissibilidade     0.0618    0.7688    0.1144       346
                          outros     0.9914    0.4850    0.6513    134134
                   peticao_do_RE     0.2195    0.6585    0.3293      9509
                        sentenca     0.2982    0.8192    0.4373      2129

                        accuracy                         0.5086    149217
                       macro avg     0.2969    0.7213    0.3153    149217
                    weighted avg     0.9110    0.5086    0.6156    149217



In [512]:
evaluate(img_probs_train, text_probs_train, labels_train, best_img_weight, best_text_weight)

                                  precision    recall  f1-score   support

          acordao_de_2_instancia     0.9786    0.9910    0.9847       553
agravo_em_recurso_extraordinario     0.8812    0.9643    0.9209      2546
     despacho_de_admissibilidade     0.7388    0.9566    0.8338       346
                          outros     0.9979    0.9901    0.9940    134134
                   peticao_do_RE     0.9124    0.9787    0.9444      9509
                        sentenca     0.9823    0.9901    0.9862      2129

                        accuracy                         0.9888    149217
                       macro avg     0.9152    0.9785    0.9440    149217
                    weighted avg     0.9896    0.9888    0.9890    149217



In [513]:
train_preds =  late_fusion(img_probs_train, text_probs_train, img_weight=best_img_weight, text_weight=best_text_weight)

In [514]:
train["document_type"] = train.apply(lambda x: "B-" + x["document_type"] if x["pages"] == 1 else "I-" + x["document_type"],
                                     axis=1)
val["document_type"] = val.apply(lambda x: "B-" + x["document_type"] if x["pages"] == 1 else "I-" + x["document_type"],
                                     axis=1)
test_data["document_type"] = test_data.apply(lambda x: "B-" + x["document_type"] if x["pages"] == 1 else "I-" + x["document_type"],
                                   axis=1)

In [515]:
def data_to_process(data, vectors):
    xs = []
    ys = []
    data["data"] = vectors.tolist()
    for k, v in data.groupby("process_id").groups.items():
        xs.append(data.iloc[v]["data"].tolist())
        ys.append(data.iloc[v]["document_type"].tolist())
    return xs, ys

In [516]:
X_train, y_train = data_to_process(train, train_preds)
X_valid, y_valid = data_to_process(val, val_preds)
X_test, y_test = data_to_process(test_data, test_preds)

In [517]:
 len(X_train), len(y_train)

(2743, 2743)

In [518]:
def data2feat(data):
    feat_data = []
    for i, sentence in enumerate(data):
        feat_data.append([])
        for j, token in enumerate(sentence):
            feat_data[i].append({ str(i) : d for i, d in enumerate(token)})
    return feat_data

In [519]:
X_train = data2feat(X_train)
X_valid = data2feat(X_valid)
X_test = data2feat(X_test)

In [522]:
import sklearn_crfsuite
from sklearn_crfsuite import scorers
from sklearn_crfsuite import metrics
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

In [523]:
crf = sklearn_crfsuite.CRF(
    verbose=True,
    algorithm='lbfgs',
    c1=0.1,
    c2=0.1,
    max_iterations=100,
    all_possible_transitions=True,
    all_possible_states=True
)
crf.fit(X_train, y_train)

loading training data to CRFsuite: 100%|██████████| 2743/2743 [00:00<00:00, 4619.85it/s]



Feature generation
type: CRF1d
feature.minfreq: 0.000000
feature.possible_states: 1
feature.possible_transitions: 1
0....1....2....3....4....5....6....7....8....9....10
Number of features: 216
Seconds required: 0.301

L-BFGS optimization
c1: 0.100000
c2: 0.100000
num_memories: 6
max_iterations: 100
epsilon: 0.000010
stop: 10
delta: 0.000010
linesearch: MoreThuente
linesearch.max_iterations: 20

Iter 1   time=0.24  loss=252834.33 active=216   feature_norm=1.00
Iter 2   time=0.39  loss=179108.18 active=215   feature_norm=3.27
Iter 3   time=0.12  loss=132033.25 active=210   feature_norm=2.97
Iter 4   time=0.24  loss=119826.12 active=212   feature_norm=3.23
Iter 5   time=0.12  loss=111947.48 active=214   feature_norm=3.43
Iter 6   time=0.12  loss=106541.50 active=213   feature_norm=3.90
Iter 7   time=0.12  loss=100475.00 active=207   feature_norm=4.63
Iter 8   time=0.12  loss=88958.23 active=215   feature_norm=5.01
Iter 9   time=0.12  loss=87917.94 active=215   feature_norm=5.27
Iter 10  



CRF(algorithm='lbfgs', all_possible_states=True, all_possible_transitions=True,
    c1=0.1, c2=0.1, keep_tempfiles=None, max_iterations=100, verbose=True)

In [524]:
labels = crf.classes_; labels

['B-outros',
 'I-outros',
 'B-sentenca',
 'I-sentenca',
 'B-peticao_do_RE',
 'I-peticao_do_RE',
 'B-despacho_de_admissibilidade',
 'B-acordao_de_2_instancia',
 'B-agravo_em_recurso_extraordinario',
 'I-agravo_em_recurso_extraordinario',
 'I-acordao_de_2_instancia',
 'I-despacho_de_admissibilidade']

In [525]:
# group B and I results
sorted_labels = sorted(
    labels,
    key=lambda name: (name[1:], name[0])
)
y_pred = crf.predict(X_test)
print(metrics.flat_classification_report(
    y_test, y_pred, labels=sorted_labels, digits=4
))



                                    precision    recall  f1-score   support

          B-acordao_de_2_instancia     0.9514    0.8844    0.9167       199
          I-acordao_de_2_instancia     0.9825    0.7568    0.8550        74
B-agravo_em_recurso_extraordinario     0.5455    0.3662    0.4382       213
I-agravo_em_recurso_extraordinario     0.7402    0.5092    0.6033      1628
     B-despacho_de_admissibilidade     0.7455    0.5578    0.6381       147
     I-despacho_de_admissibilidade     0.3333    0.1569    0.2133        51
                          B-outros     0.7914    0.2128    0.3354     25744
                          I-outros     0.7105    0.9670    0.8191     59664
                   B-peticao_do_RE     0.8258    0.4712    0.6000       312
                   I-peticao_do_RE     0.9140    0.6782    0.7786      6019
                        B-sentenca     0.8990    0.6717    0.7689       265
                        I-sentenca     0.9663    0.7347    0.8347      1210

          

In [526]:
valid_fold = np.array([-1]*len(X_train) + [0]*len(X_valid));  valid_fold.shape

(4636,)

In [527]:
X_train_val = np.concatenate([X_train, X_valid]); X_train_val.shape



(4636,)

In [528]:
y_train_val = np.concatenate([y_train, y_valid]); y_train_val.shape

(4636,)

In [531]:
import scipy
from sklearn.metrics import make_scorer
from sklearn.model_selection import cross_val_score, RandomizedSearchCV, PredefinedSplit


# define fixed parameters and parameters to search
crf = sklearn_crfsuite.CRF(
    algorithm='lbfgs',
    max_iterations=100,
    all_possible_transitions=True,
    all_possible_states=True
)
params_space = {
    'c1': scipy.stats.expon(scale=0.5),
    'c2': scipy.stats.expon(scale=0.05),
}
ps = PredefinedSplit(valid_fold)

# use the same metric for evaluation
f1_scorer = make_scorer(metrics.flat_f1_score,
                        average='macro', labels=labels)

# search
rs = RandomizedSearchCV(crf, params_space,
                        cv=ps,
                        verbose=1,
                        n_jobs=20,
                        n_iter=150,
                        scoring=f1_scorer,
                        random_state=42)
rs.fit(X_train_val, y_train_val)

Fitting 1 folds for each of 150 candidates, totalling 150 fits


[Parallel(n_jobs=20)]: Using backend LokyBackend with 20 concurrent workers.
[Parallel(n_jobs=20)]: Done  10 tasks      | elapsed:   29.2s
[Parallel(n_jobs=20)]: Done 150 out of 150 | elapsed:  2.8min finished


RandomizedSearchCV(cv=PredefinedSplit(test_fold=array([-1, -1, ...,  0,  0])),
                   estimator=CRF(algorithm='lbfgs', all_possible_states=True,
                                 all_possible_transitions=True,
                                 keep_tempfiles=None, max_iterations=100),
                   n_iter=150, n_jobs=20,
                   param_distributions={'c1': <scipy.stats._distn_infrastructure.rv_frozen object at 0x7f3970c75b50>,
                                        'c2': <scipy.stats._distn_infrastruct...
                   random_state=42,
                   scoring=make_scorer(flat_f1_score, average=macro, labels=['B-outros', 'I-outros', 'B-sentenca', 'I-sentenca', 'B-peticao_do_RE', 'I-peticao_do_RE', 'B-despacho_de_admissibilidade', 'B-acordao_de_2_instancia', 'B-agravo_em_recurso_extraordinario', 'I-agravo_em_recurso_extraordinario', 'I-acordao_de_2_instancia', 'I-despacho_de_admissibilidade']),
                   verbose=1)

In [532]:
print('best params:', rs.best_params_)
print('best CV score:', rs.best_score_)
print('model size: {:0.2f}M'.format(rs.best_estimator_.size_ / 1000000)) 

best params: {'c1': 0.048863070237470364, 'c2': 0.11375616125972615}
best CV score: 0.6803931788418751
model size: 0.01M


In [533]:
crf = sklearn_crfsuite.CRF(
    verbose=True,
    algorithm='lbfgs',
    c1=rs.best_params_['c1'],
    c2=rs.best_params_['c2'],
    max_iterations=1000,
    all_possible_transitions=True,
    all_possible_states=True
)
crf.fit(X_train, y_train)

loading training data to CRFsuite: 100%|██████████| 2743/2743 [00:00<00:00, 4754.34it/s]



Feature generation
type: CRF1d
feature.minfreq: 0.000000
feature.possible_states: 1
feature.possible_transitions: 1
0....1....2....3....4....5....6....7....8....9....10
Number of features: 216
Seconds required: 0.301

L-BFGS optimization
c1: 0.048863
c2: 0.113756
num_memories: 6
max_iterations: 1000
epsilon: 0.000010
stop: 10
delta: 0.000010
linesearch: MoreThuente
linesearch.max_iterations: 20

Iter 1   time=0.24  loss=252834.12 active=216   feature_norm=1.00
Iter 2   time=0.36  loss=179107.32 active=215   feature_norm=3.27
Iter 3   time=0.12  loss=132032.40 active=210   feature_norm=2.97
Iter 4   time=0.24  loss=119825.32 active=212   feature_norm=3.23
Iter 5   time=0.12  loss=111946.56 active=214   feature_norm=3.43
Iter 6   time=0.12  loss=106540.56 active=213   feature_norm=3.90
Iter 7   time=0.12  loss=100474.01 active=207   feature_norm=4.63
Iter 8   time=0.12  loss=88957.06 active=215   feature_norm=5.01
Iter 9   time=0.12  loss=87916.54 active=215   feature_norm=5.27
Iter 10 

CRF(algorithm='lbfgs', all_possible_states=True, all_possible_transitions=True,
    c1=0.048863070237470364, c2=0.11375616125972615, keep_tempfiles=None,
    max_iterations=1000, verbose=True)

In [534]:
# group B and I results
sorted_labels = sorted(
    labels,
    key=lambda name: (name[1:], name[0])
)
y_pred = crf.predict(X_test)
print(metrics.flat_classification_report(
    y_test, y_pred, labels=sorted_labels, digits=4
))



                                    precision    recall  f1-score   support

          B-acordao_de_2_instancia     0.9516    0.8894    0.9195       199
          I-acordao_de_2_instancia     0.9828    0.7703    0.8636        74
B-agravo_em_recurso_extraordinario     0.5399    0.4131    0.4681       213
I-agravo_em_recurso_extraordinario     0.7021    0.5240    0.6001      1628
     B-despacho_de_admissibilidade     0.7745    0.5374    0.6345       147
     I-despacho_de_admissibilidade     0.4242    0.2745    0.3333        51
                          B-outros     0.7590    0.2169    0.3373     25744
                          I-outros     0.7118    0.9592    0.8172     59664
                   B-peticao_do_RE     0.7799    0.5224    0.6257       312
                   I-peticao_do_RE     0.8896    0.6880    0.7759      6019
                        B-sentenca     0.8932    0.6943    0.7813       265
                        I-sentenca     0.9722    0.7521    0.8481      1210

          

In [535]:
y_pred_class = []
y_test_class = []

for i, sequence in enumerate(y_pred):
    y_pred_class.append([])
    for j, pred in enumerate(sequence):
        y_pred_class[i].append(pred[2:])
        
for i, sequence in enumerate(y_test):
    y_test_class.append([])
    for j, pred in enumerate(sequence):
        y_test_class[i].append(pred[2:])

In [536]:
sorted_labels=['acordao_de_2_instancia', 'agravo_em_recurso_extraordinario',
                'despacho_de_admissibilidade', 'outros', 'peticao_do_RE', 'sentenca']

print(metrics.flat_classification_report(
    y_test_class, y_pred_class, labels=sorted_labels, digits=4
))



                                  precision    recall  f1-score   support

          acordao_de_2_instancia     0.9795    0.8755    0.9246       273
agravo_em_recurso_extraordinario     0.7003    0.5242    0.5996      1841
     despacho_de_admissibilidade     0.7630    0.5202    0.6186       198
                          outros     0.9630    0.9896    0.9761     85408
                   peticao_do_RE     0.8906    0.6843    0.7739      6331
                        sentenca     0.9685    0.7498    0.8452      1475

                        accuracy                         0.9554     95526
                       macro avg     0.8775    0.7239    0.7897     95526
                    weighted avg     0.9529    0.9554    0.9525     95526



In [537]:
with open("models/crf_late_fusion.pkl", "wb") as file:
    pickle.dump(crf, file)

In [538]:
 from collections import Counter

def print_transitions(trans_features):
    for (label_from, label_to), weight in trans_features:
        print("%-6s -> %-7s %0.6f" % (label_from, label_to, weight))

print("Top likely transitions:")
print_transitions(Counter(crf.transition_features_).most_common(20))

print("\nTop unlikely transitions:")
print_transitions(Counter(crf.transition_features_).most_common()[-20:]) 

Top likely transitions:
I-peticao_do_RE -> I-peticao_do_RE 6.854842
B-peticao_do_RE -> I-peticao_do_RE 6.427525
B-agravo_em_recurso_extraordinario -> I-agravo_em_recurso_extraordinario 5.681058
I-agravo_em_recurso_extraordinario -> I-agravo_em_recurso_extraordinario 5.528226
I-sentenca -> I-sentenca 5.163614
I-outros -> I-outros 4.898594
B-sentenca -> I-sentenca 4.392003
B-acordao_de_2_instancia -> I-acordao_de_2_instancia 4.309193
I-acordao_de_2_instancia -> I-acordao_de_2_instancia 3.504246
B-despacho_de_admissibilidade -> I-despacho_de_admissibilidade 2.907754
I-despacho_de_admissibilidade -> I-despacho_de_admissibilidade 2.652916
I-outros -> B-sentenca 2.606884
B-outros -> I-outros 1.830688
B-acordao_de_2_instancia -> B-outros 1.487517
I-acordao_de_2_instancia -> B-outros 1.469131
I-outros -> B-acordao_de_2_instancia 1.333140
B-despacho_de_admissibilidade -> B-acordao_de_2_instancia 1.083396
I-outros -> B-outros 1.034921
B-despacho_de_admissibilidade -> B-outros 0.954576
I-outros -

In [539]:
def print_state_features(state_features):
    for (attr, label), weight in state_features:
        print("%0.6f %-8s %s" % (weight, label, sorted_labels[int(attr)]))

print("Top positive:")
print_state_features(Counter(crf.state_features_).most_common(30))

print("\nTop negative:")
print_state_features(Counter(crf.state_features_).most_common()[-30:])

Top positive:
11.299393 B-acordao_de_2_instancia acordao_de_2_instancia
9.899779 B-sentenca sentenca
9.070885 B-despacho_de_admissibilidade despacho_de_admissibilidade
8.078372 I-sentenca sentenca
8.037510 B-agravo_em_recurso_extraordinario agravo_em_recurso_extraordinario
7.938671 I-acordao_de_2_instancia acordao_de_2_instancia
6.532653 B-outros outros
6.518911 I-despacho_de_admissibilidade despacho_de_admissibilidade
5.869947 I-agravo_em_recurso_extraordinario agravo_em_recurso_extraordinario
5.779349 B-peticao_do_RE peticao_do_RE
3.440000 I-peticao_do_RE peticao_do_RE
1.852259 I-despacho_de_admissibilidade acordao_de_2_instancia
1.796359 I-outros outros
1.512521 B-despacho_de_admissibilidade acordao_de_2_instancia
1.204668 B-despacho_de_admissibilidade agravo_em_recurso_extraordinario
0.534983 B-outros agravo_em_recurso_extraordinario
0.396080 B-acordao_de_2_instancia sentenca
0.295299 B-peticao_do_RE agravo_em_recurso_extraordinario
0.175319 I-acordao_de_2_instancia despacho_de_adm