In [4]:
%load_ext autoreload
%autoreload 2

import sklearn_crfsuite
from sklearn_crfsuite import scorers
from sklearn_crfsuite import metrics
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

In [5]:
from fastai.text.all import *
from fastai.vision.all import *
import pandas as pd
import torch
from tqdm.notebook import tqdm
from sklearn.metrics import classification_report

In [6]:
seed = 42

# python RNG
import random
random.seed(seed)

# pytorch RNGs
import torch
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
# `deterministic` alone is not enough: cuDNN's autotuner (benchmark mode) may still
# select different kernels from run to run, so switch it off explicitly.
torch.backends.cudnn.benchmark = False
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# numpy RNG
import numpy as np
np.random.seed(seed)

In [25]:
OUT_DIM = 6  # number of output classes of the fusion head (used as `n_out` when building it)

In [12]:
class GetActs(Transform):
    """Load the cached image/text activations for one row and flag missing modalities.

    `x` is a row-like mapping with keys ``has_text``, ``has_image`` and
    ``activation_path``. Returns ``(img_act, text_act, img_none, text_none)``:
    a missing modality is replaced by a zero vector and its ``*_none`` flag is True.
    """
    IMG_DIM = 4096   # length of the zero placeholder used when there is no image activation
    TEXT_DIM = 3840  # length of the zero placeholder used when there is no text activation

    def encodes(self, x):
        img_file = text_file = None

        if x["has_text"]:
            text_file = Path(x["activation_path"] + ".npy")
            if x["has_image"]:
                # NOTE(review): str.replace rewrites *every* "text"/"npy" substring of the
                # path, not just the modality directory and the extension — confirm layout.
                img_file = Path(text_file.as_posix().replace("text", "img").replace("npy", "pt"))
        else:
            # presumably rows without text always have an image activation — TODO confirm
            img_file = Path(x["activation_path"] + ".pt")

        img_none = img_file is None
        text_none = text_file is None

        img_act = torch.zeros(self.IMG_DIM) if img_none else torch.load(img_file)
        text_act = torch.zeros(self.TEXT_DIM) if text_none else tensor(np.load(text_file))

        return (img_act, text_act, img_none, text_none)

In [22]:
class ImgTextFusion(Module):
    """Concatenate image and text activations and classify them with `head`.

    Args:
        head: module applied to the concatenated ``[img_act, text_act]`` vector, so it
            must accept ``img_emb_dim + text_emb_dim`` input features.
        embs_for_none: if True, a missing modality (flagged by ``img_none`` /
            ``text_none``) is replaced by a single learned vector instead of the
            zero placeholder supplied by the data pipeline.
        img_emb_dim, text_emb_dim: widths of the image / text activations.
        device: where to place the sub-modules. Defaults to CUDA when available and
            CPU otherwise (the original always called `.cuda()`, which crashes on
            CPU-only machines).
    """
    def __init__(self, head, embs_for_none=True, img_emb_dim=4096, text_emb_dim=3840, device=None):
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.head = head.to(device)
        self.embs_for_none = embs_for_none
        if embs_for_none:
            # One-row embeddings: row 0 is the learned "modality missing" vector.
            # Attribute names are kept so saved state_dicts still load.
            self.img_none_emb = torch.nn.Embedding(num_embeddings=1, embedding_dim=img_emb_dim).to(device)
            self.text_none_emb = torch.nn.Embedding(num_embeddings=1, embedding_dim=text_emb_dim).to(device)

    def forward(self, x):
        img_act, text_act, img_none, text_none = x
        if self.embs_for_none:
            img_act = self._fill_missing(img_act, img_none, self.img_none_emb)
            text_act = self._fill_missing(text_act, text_none, self.text_none_emb)
        return self.head(torch.cat([img_act, text_act], dim=-1))

    @staticmethod
    def _fill_missing(acts, missing, emb):
        "Return `acts` with rows flagged in `missing` replaced by `emb`'s learned vector (no in-place write)."
        mask = missing.to(device=acts.device, dtype=torch.bool).unsqueeze(-1)
        # emb.weight[0] equals emb(0) and keeps gradients flowing into the embedding.
        fill = emb.weight[0].to(acts.dtype).expand_as(acts)
        return torch.where(mask, fill, acts)

In [23]:
def create_head(nf, n_out, lin_ftrs=None, ps=0.5, bn_final=False, lin_first=False):
    "Model head that takes `nf` features, runs through `lin_ftrs`, and out `n_out` classes."
    # Appears adapted from fastai's `create_head` with the pooling stage removed: there is no
    # pool/flatten layer, so inputs are presumably already flat (batch, nf) vectors — TODO confirm.
    # Layer widths: nf -> *lin_ftrs -> n_out (default: a single hidden layer of 512).
    lin_ftrs = [nf, 512, n_out] if lin_ftrs is None else [nf] + lin_ftrs + [n_out]
    ps = L(ps)
    # A single dropout value p becomes p/2 for every hidden block and p for the last block.
    if len(ps) == 1: ps = [ps[0]/2] * (len(lin_ftrs)-2) + ps
    # ReLU after each hidden block, no activation on the output block.
    # List multiplication shares ONE nn.ReLU instance across blocks (harmless: it has no parameters).
    actns = [nn.ReLU(inplace=True)] * (len(lin_ftrs)-2) + [None]
    layers = []
    # lin_first: the first dropout value becomes a leading Dropout. `ps` is then one element
    # shorter, so the `zip` below stops one block early and the final nn.Linear is appended
    # separately after the loop. NOTE: `zip` also silently truncates if the caller passes
    # fewer dropout values than blocks.
    if lin_first: layers.append(nn.Dropout(ps.pop(0)))
    for ni,no,p,actn in zip(lin_ftrs[:-1], lin_ftrs[1:], ps, actns):
        layers += LinBnDrop(ni, no, bn=True, p=p, act=actn, lin_first=lin_first)
    if lin_first: layers.append(nn.Linear(lin_ftrs[-2], n_out))
    if bn_final: layers.append(nn.BatchNorm1d(lin_ftrs[-1], momentum=0.01))
    # NOTE(review): Sequential indices become state_dict keys; weights restored with
    # `learn.load` rely on this exact layer order, so do not reorder layers.
    return nn.Sequential(*layers)

In [19]:
# Restore the pickled fusion DataLoaders (trusted local file only: torch.load unpickles).
# NOTE(review): torch >= 2.6 defaults to weights_only=True, which rejects arbitrary objects
# like DataLoaders; on such versions pass weights_only=False here.
dls = torch.load(f="./data/fusion_dl_v2.pth")

In [31]:
# Keep the training loader in `items` order so predictions line up with rows.
dls.train.shuffle = False
dls.train.get_idxs()[:5]  # sanity check: expected [0, 1, 2, 3, 4]

[0, 1, 2, 3, 4]

In [34]:
# Keep the final partial batch so every training row receives a prediction.
dls.train.drop_last = False
dls.train.drop_last

False

In [21]:
# Restore the pickled test DataLoader (trusted local file only: torch.load unpickles).
# NOTE(review): on torch >= 2.6 this needs weights_only=False.
test_dl = torch.load(f="./data/test_dl_fusion_text.pth")

In [26]:
# Input width = image activation size (4096) + text activation size (3840); one hidden layer of 128.
head = create_head(nf=4096 + 3840, n_out=OUT_DIM, lin_ftrs=[128])

In [27]:
# Wrap the head with the missing-modality handling; other options keep their defaults.
model = ImgTextFusion(head=head)

In [28]:
# No loss_func is passed: the learner is only used to load weights and predict below.
learn = Learner(dls=dls, model=model)

In [29]:
# Restore the trained fusion weights; the model architecture above must match the checkpoint.
learn.load("best_fusion_128_moreEpochs")

<fastai.learner.Learner at 0x7f070c3ee370>

In [36]:
# Fusion-model predictions for the training split (dataset index 0); targets discarded.
train_preds, _ = learn.get_preds(ds_idx=0)

In [84]:
# Fusion-model predictions for the validation split (dataset index 1, get_preds' default).
valid_preds, _ = learn.get_preds(ds_idx=1)

In [38]:
# Fusion-model predictions for the separately loaded test DataLoader; targets discarded.
test_preds, _ = learn.get_preds(
    dl=test_dl,
)

In [48]:
# Row *positions* (not DataFrame index labels) of training items that have text.
# `index_select` needs positions into `train_preds`, which follows `items` order;
# `.index.values` only coincides with positions when the index happens to be 0..n-1.
train_idx = torch.from_numpy(np.flatnonzero(dls.train.items["has_text"].to_numpy()))

In [67]:
# Keep only the prediction rows of training items that have text.
train_preds_filtered = train_preds.index_select(0, train_idx)

In [85]:
# Row *positions* (not DataFrame index labels) of validation items that have text.
# A split DataFrame usually keeps the original index labels, which are not valid
# positions into `valid_preds`; compute positions from the boolean column instead.
valid_idx = torch.from_numpy(np.flatnonzero(dls.valid.items["has_text"].to_numpy()))

In [86]:
# Keep only the prediction rows of validation items that have text.
valid_preds_filtered = valid_preds.index_select(0, valid_idx)

In [97]:
# Per-page metadata for each split, in train / validation / test order.
train, val, test_data = (
    pd.read_csv(csv_path)
    for csv_path in (
        "./data/train_small.csv",
        "./data/validation_small.csv",
        "./data/test_small.csv",
    )
)

In [98]:
def add_bio_prefix(df):
    """Return ``df["document_type"]`` tagged in BIO style: ``B-`` when ``pages == 1``, else ``I-``.

    Any existing ``B-``/``I-`` prefix is stripped first, so re-running this cell
    does not produce ``B-B-...`` labels (the original row-wise apply did).
    Vectorised replacement for the previous ``apply(..., axis=1)``.
    """
    base = df["document_type"].str.replace(r"^[BI]-", "", regex=True)
    return ("B-" + base).where(df["pages"] == 1, "I-" + base)

train["document_type"] = add_bio_prefix(train)
val["document_type"] = add_bio_prefix(val)
test_data["document_type"] = add_bio_prefix(test_data)

In [99]:
def data_to_process(data, vectors):
    """Group per-row vectors and labels into one sequence per ``process_id``.

    Parameters
    ----------
    data : pandas.DataFrame
        One row per page with at least ``process_id`` and ``document_type`` columns.
        NOTE: a ``"data"`` column holding each row's vector is written into this
        frame in place (kept for backward compatibility).
    vectors : tensor / array / sequence with ``len(data)`` rows
        Aligned positionally with the rows of ``data``; must support ``.tolist()``.

    Returns
    -------
    (xs, ys) : tuple of lists
        ``xs[k]`` / ``ys[k]`` are the vectors / labels of the k-th process, processes
        ordered by sorted ``process_id``. Within a process, rows keep their order in
        ``data`` (presumably page order — TODO confirm the CSVs are sorted by page).

    Raises
    ------
    ValueError
        If ``vectors`` and ``data`` have different lengths.
    """
    if len(vectors) != len(data):
        raise ValueError(
            f"got {len(vectors)} vectors for {len(data)} rows; they must align one-to-one"
        )
    data["data"] = vectors.tolist()
    xs, ys = [], []
    # Iterate groups directly: `.groups` yields index *labels*, which the previous
    # `.iloc[...]` lookup misused as positions whenever the index was not 0..n-1.
    for _, group in data.groupby("process_id"):
        xs.append(group["data"].tolist())
        ys.append(group["document_type"].tolist())
    return xs, ys

In [100]:
# One (feature sequences, label sequences) pair per split; each sequence is one process.
# Train/valid use the text-only filtered predictions; the test loader is already text-only.
X_train, y_train = data_to_process(train, train_preds_filtered)
X_valid, y_valid = data_to_process(val, valid_preds_filtered)
X_test, y_test = data_to_process(test_data, test_preds)

In [101]:
 # Sanity check: every feature sequence needs a matching label sequence.
assert len(X_train) == len(y_train), "feature and label sequence counts differ"
len(X_train), len(y_train)

(2743, 2743)

In [102]:
def data2feat(data):
    """Convert numeric token vectors into CRFsuite feature dicts.

    ``data`` is a list of sequences, each a list of per-token vectors. Token
    ``[v0, v1, ...]`` becomes ``{"0": v0, "1": v1, ...}``; the nesting
    (one list of dicts per sequence) is preserved.
    """
    def token_features(token):
        return {str(position): value for position, value in enumerate(token)}

    return [[token_features(token) for token in sequence] for sequence in data]

In [103]:
# Replace each split's raw prediction vectors with CRFsuite feature dicts.
X_train, X_valid, X_test = (data2feat(split) for split in (X_train, X_valid, X_test))

In [104]:
# Baseline CRF over the per-page fusion outputs with fixed L1 (c1) / L2 (c2) penalties.
crf = sklearn_crfsuite.CRF(
    algorithm='lbfgs',
    c1=0.1,
    c2=0.1,
    max_iterations=100,
    all_possible_states=True,
    all_possible_transitions=True,
    verbose=True,
)
crf.fit(X_train, y_train)

loading training data to CRFsuite: 100%|██████████| 2743/2743 [00:00<00:00, 4273.66it/s]



Feature generation
type: CRF1d
feature.minfreq: 0.000000
feature.possible_states: 1
feature.possible_transitions: 1
0....1....2....3....4....5....6....7....8....9....10
Number of features: 216
Seconds required: 0.290

L-BFGS optimization
c1: 0.100000
c2: 0.100000
num_memories: 6
max_iterations: 100
epsilon: 0.000010
stop: 10
delta: 0.000010
linesearch: MoreThuente
linesearch.max_iterations: 20

Iter 1   time=0.21  loss=248613.21 active=216   feature_norm=1.00
Iter 2   time=0.32  loss=160534.77 active=216   feature_norm=3.14
Iter 3   time=0.10  loss=122377.75 active=216   feature_norm=3.02
Iter 4   time=0.21  loss=116020.00 active=214   feature_norm=3.49
Iter 5   time=0.11  loss=103650.28 active=216   feature_norm=3.47
Iter 6   time=0.11  loss=99763.12 active=215   feature_norm=3.68
Iter 7   time=0.11  loss=96428.39 active=216   feature_norm=3.95
Iter 8   time=0.11  loss=92152.49 active=215   feature_norm=4.47
Iter 9   time=0.11  loss=85035.56 active=216   feature_norm=4.92
Iter 10  ti

CRF(algorithm='lbfgs', all_possible_states=True, all_possible_transitions=True,
    c1=0.1, c2=0.1, keep_tempfiles=None, max_iterations=100, verbose=True)

In [105]:
# BIO tag set learned by the CRF.
labels = crf.classes_
labels

['B-outros',
 'I-outros',
 'B-sentenca',
 'I-sentenca',
 'B-peticao_do_RE',
 'I-peticao_do_RE',
 'B-despacho_de_admissibilidade',
 'B-acordao_de_2_instancia',
 'B-agravo_em_recurso_extraordinario',
 'I-agravo_em_recurso_extraordinario',
 'I-acordao_de_2_instancia',
 'I-despacho_de_admissibilidade']

In [107]:
y_pred = crf.predict(X_test)
# Order tags by class name, then B before I, so each class's B-/I- rows are adjacent.
sorted_labels = sorted(labels, key=lambda tag: (tag[1:], tag[0]))
print(metrics.flat_classification_report(y_test, y_pred, labels=sorted_labels, digits=4))

                                    precision    recall  f1-score   support

          B-acordao_de_2_instancia     0.9050    0.9095    0.9073       199
          I-acordao_de_2_instancia     0.9403    0.8514    0.8936        74
B-agravo_em_recurso_extraordinario     0.6442    0.3146    0.4227       213
I-agravo_em_recurso_extraordinario     0.7667    0.4803    0.5906      1628
     B-despacho_de_admissibilidade     0.7757    0.5646    0.6535       147
     I-despacho_de_admissibilidade     0.2857    0.1569    0.2025        51
                          B-outros     0.8917    0.2274    0.3623     25744
                          I-outros     0.7161    0.9795    0.8273     59664
                   B-peticao_do_RE     0.8571    0.5000    0.6316       312
                   I-peticao_do_RE     0.9073    0.6810    0.7780      6019
                        B-sentenca     0.8472    0.6906    0.7609       265
                        I-sentenca     0.9658    0.7240    0.8276      1210

          

In [112]:
# PredefinedSplit fold ids: -1 = always in training, 0 = the single validation fold.
valid_fold = np.repeat([-1, 0], [len(X_train), len(X_valid)])
valid_fold.shape

(4636,)

In [117]:
# Plain list concatenation: np.concatenate on ragged nested lists raises on NumPy >= 1.24
# (and would build a 2-D array if all sequences happened to share a length).
# sklearn's search utilities index plain lists just fine.
X_train_val = X_train + X_valid
len(X_train_val)



(4636,)

In [118]:
# Keep labels as a list of per-process sequences, mirroring X_train_val (see the ragged-array note there).
y_train_val = y_train + y_valid
len(y_train_val)

(4636,)

In [126]:
import scipy.stats  # import the submodule explicitly; `import scipy` alone need not load it
from sklearn.metrics import make_scorer
from sklearn.model_selection import cross_val_score, RandomizedSearchCV, PredefinedSplit


# Unfitted template for the search. It gets its own name so the CRF already fitted
# above (`crf`) is not silently replaced by an unfitted estimator.
crf_base = sklearn_crfsuite.CRF(
    algorithm='lbfgs',
    max_iterations=100,
    all_possible_transitions=True,
    all_possible_states=True
)
# Exponential priors over the L1 (c1) and L2 (c2) regularisation strengths.
params_space = {
    'c1': scipy.stats.expon(scale=0.5),
    'c2': scipy.stats.expon(scale=0.05),
}
# One fixed split: rows marked -1 always train, rows marked 0 form the validation fold.
ps = PredefinedSplit(valid_fold)

# use the same metric for evaluation
f1_scorer = make_scorer(metrics.flat_f1_score,
                        average='macro', labels=labels)

# search
rs = RandomizedSearchCV(crf_base, params_space,
                        cv=ps,
                        verbose=1,
                        n_jobs=-1,
                        n_iter=150,
                        scoring=f1_scorer,
                        random_state=42)
rs.fit(X_train_val, y_train_val)

Fitting 1 folds for each of 150 candidates, totalling 150 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done  42 tasks      | elapsed:  3.1min
[Parallel(n_jobs=-1)]: Done 150 out of 150 | elapsed: 10.6min finished


RandomizedSearchCV(cv=PredefinedSplit(test_fold=array([-1, -1, ...,  0,  0])),
                   estimator=CRF(algorithm='lbfgs', all_possible_states=True,
                                 all_possible_transitions=True,
                                 keep_tempfiles=None, max_iterations=100),
                   n_iter=150, n_jobs=-1,
                   param_distributions={'c1': <scipy.stats._distn_infrastructure.rv_frozen object at 0x7f0731573490>,
                                        'c2': <scipy.stats._distn_infrastruct...
                   random_state=42,
                   scoring=make_scorer(flat_f1_score, average=macro, labels=['B-outros', 'I-outros', 'B-sentenca', 'I-sentenca', 'B-peticao_do_RE', 'I-peticao_do_RE', 'B-despacho_de_admissibilidade', 'B-acordao_de_2_instancia', 'B-agravo_em_recurso_extraordinario', 'I-agravo_em_recurso_extraordinario', 'I-acordao_de_2_instancia', 'I-despacho_de_admissibilidade']),
                   verbose=1)

In [127]:
# Summary of the randomized search: chosen penalties, validation macro-F1, model size.
best_estimator = rs.best_estimator_
print('best params:', rs.best_params_)
print('best CV score:', rs.best_score_)
print('model size: {:0.2f}M'.format(best_estimator.size_ / 1e6))

best params: {'c1': 0.5575775051486467, 'c2': 0.07169438131197314}
best CV score: 0.6822817430568014
model size: 0.01M


In [128]:
# Refit on the training split only, using the penalties chosen by the search.
# NOTE(review): the search ran with max_iterations=100, this refit allows 1000 — confirm intended.
crf = sklearn_crfsuite.CRF(
    algorithm='lbfgs',
    c1=rs.best_params_['c1'],
    c2=rs.best_params_['c2'],
    max_iterations=1000,
    all_possible_states=True,
    all_possible_transitions=True,
    verbose=True,
)
crf.fit(X_train, y_train)

loading training data to CRFsuite: 100%|██████████| 2743/2743 [00:00<00:00, 4367.83it/s]



Feature generation
type: CRF1d
feature.minfreq: 0.000000
feature.possible_states: 1
feature.possible_transitions: 1
0....1....2....3....4....5....6....7....8....9....10
Number of features: 216
Seconds required: 0.284

L-BFGS optimization
c1: 0.557578
c2: 0.071694
num_memories: 6
max_iterations: 1000
epsilon: 0.000010
stop: 10
delta: 0.000010
linesearch: MoreThuente
linesearch.max_iterations: 20

Iter 1   time=0.20  loss=248615.06 active=216   feature_norm=1.00
Iter 2   time=0.30  loss=160541.65 active=216   feature_norm=3.14
Iter 3   time=0.10  loss=122386.80 active=216   feature_norm=3.02
Iter 4   time=0.20  loss=116023.85 active=214   feature_norm=3.49
Iter 5   time=0.10  loss=103660.65 active=216   feature_norm=3.47
Iter 6   time=0.10  loss=99774.66 active=215   feature_norm=3.68
Iter 7   time=0.10  loss=96440.96 active=216   feature_norm=3.95
Iter 8   time=0.10  loss=92167.45 active=215   feature_norm=4.47
Iter 9   time=0.10  loss=85053.85 active=216   feature_norm=4.92
Iter 10  t

CRF(algorithm='lbfgs', all_possible_states=True, all_possible_transitions=True,
    c1=0.5575775051486467, c2=0.07169438131197314, keep_tempfiles=None,
    max_iterations=1000, verbose=True)

In [129]:
y_pred = crf.predict(X_test)
# Same B/I grouping as the baseline report: sort by class name, then B before I.
sorted_labels = sorted(labels, key=lambda tag: (tag[1:], tag[0]))
print(metrics.flat_classification_report(y_test, y_pred, labels=sorted_labels, digits=4))



                                    precision    recall  f1-score   support

          B-acordao_de_2_instancia     0.9188    0.9095    0.9141       199
          I-acordao_de_2_instancia     0.9841    0.8378    0.9051        74
B-agravo_em_recurso_extraordinario     0.5723    0.4272    0.4892       213
I-agravo_em_recurso_extraordinario     0.7429    0.5147    0.6081      1628
     B-despacho_de_admissibilidade     0.7636    0.5714    0.6537       147
     I-despacho_de_admissibilidade     0.3333    0.1765    0.2308        51
                          B-outros     0.8857    0.2295    0.3645     25744
                          I-outros     0.7185    0.9771    0.8281     59664
                   B-peticao_do_RE     0.8308    0.5353    0.6511       312
                   I-peticao_do_RE     0.8978    0.6993    0.7862      6019
                        B-sentenca     0.8060    0.7057    0.7525       265
                        I-sentenca     0.9693    0.7298    0.8326      1210

          

In [130]:
# Drop the two-character "B-"/"I-" prefix to score at document-class level.
y_pred_class = [[tag[2:] for tag in sequence] for sequence in y_pred]
y_test_class = [[tag[2:] for tag in sequence] for sequence in y_test]

In [131]:
# Document classes without BIO prefixes, in alphabetical order.
sorted_labels = [
    'acordao_de_2_instancia',
    'agravo_em_recurso_extraordinario',
    'despacho_de_admissibilidade',
    'outros',
    'peticao_do_RE',
    'sentenca',
]
print(metrics.flat_classification_report(y_test_class, y_pred_class, labels=sorted_labels, digits=4))



                                  precision    recall  f1-score   support

          acordao_de_2_instancia     0.9385    0.8938    0.9156       273
agravo_em_recurso_extraordinario     0.7382    0.5160    0.6074      1841
     despacho_de_admissibilidade     0.7664    0.5303    0.6269       198
                          outros     0.9634    0.9905    0.9767     85408
                   peticao_do_RE     0.9000    0.6950    0.7843      6331
                        sentenca     0.9554    0.7403    0.8342      1475

                        accuracy                         0.9567     95526
                       macro avg     0.8770    0.7277    0.7909     95526
                    weighted avg     0.9542    0.9567    0.9538     95526



In [132]:
import os
import pickle  # explicit: previously only available via a star import

# Persist the tuned CRF; create the target directory so a fresh checkout does not fail.
os.makedirs("models", exist_ok=True)
with open("models/crf.pkl", "wb") as file:
    pickle.dump(crf, file)

In [134]:
 # Inspect the learned label-transition weights, strongest and weakest first.
from collections import Counter

def print_transitions(trans_features):
    """Print ``((label_from, label_to), weight)`` pairs, one transition per line."""
    for (label_from, label_to), weight in trans_features:
        print("%-6s -> %-7s %0.6f" % (label_from, label_to, weight))

# Build the counter once and reuse it for both ends of the ranking.
transition_counts = Counter(crf.transition_features_)

print("Top likely transitions:")
print_transitions(transition_counts.most_common(20))

print("\nTop unlikely transitions:")
print_transitions(transition_counts.most_common()[-20:])

Top likely transitions:
I-peticao_do_RE -> I-peticao_do_RE 6.465655
I-agravo_em_recurso_extraordinario -> I-agravo_em_recurso_extraordinario 6.333905
B-peticao_do_RE -> I-peticao_do_RE 6.311590
B-agravo_em_recurso_extraordinario -> I-agravo_em_recurso_extraordinario 6.019876
I-sentenca -> I-sentenca 5.269288
I-outros -> I-outros 5.229679
B-acordao_de_2_instancia -> I-acordao_de_2_instancia 4.839585
B-despacho_de_admissibilidade -> I-despacho_de_admissibilidade 4.668769
B-sentenca -> I-sentenca 4.578805
I-acordao_de_2_instancia -> I-acordao_de_2_instancia 4.463847
I-despacho_de_admissibilidade -> I-despacho_de_admissibilidade 3.983685
I-outros -> B-sentenca 2.099486
I-outros -> B-peticao_do_RE 1.381329
I-outros -> B-acordao_de_2_instancia 1.225137
I-outros -> B-agravo_em_recurso_extraordinario 1.192601
B-outros -> I-outros 1.073978
B-despacho_de_admissibilidade -> B-despacho_de_admissibilidade 1.011868
I-despacho_de_admissibilidade -> B-despacho_de_admissibilidade 0.528252
I-outros -> B

In [135]:
from collections import Counter

def print_state_features(state_features, feature_names=None):
    """Print ``((attr, label), weight)`` CRF state features as ``weight label name``.

    ``attr`` is a CRF attribute name of the form ``"0"``, ``"1"``, ... (the position
    of a value in the per-page prediction vector) and is shown as
    ``feature_names[int(attr)]``. If ``feature_names`` is None, the notebook-global
    ``sorted_labels`` is used (the original behaviour). That global is reassigned by
    several cells, so pass the names explicitly to avoid depending on execution order.
    """
    names = sorted_labels if feature_names is None else feature_names
    for (attr, label), weight in state_features:
        print("%0.6f %-8s %s" % (weight, label, names[int(attr)]))

# Column order of the fusion model's prediction vector.
# NOTE(review): assumed to be the alphabetical class vocabulary of the fusion
# dataloaders (e.g. `dls.vocab`) — confirm before trusting the printed names.
fusion_output_names = ['acordao_de_2_instancia', 'agravo_em_recurso_extraordinario',
                       'despacho_de_admissibilidade', 'outros', 'peticao_do_RE', 'sentenca']
state_counts = Counter(crf.state_features_)

print("Top positive:")
print_state_features(state_counts.most_common(30), fusion_output_names)

print("\nTop negative:")
print_state_features(state_counts.most_common()[-30:], fusion_output_names)

Top positive:
7.199355 B-outros outros
6.367438 B-despacho_de_admissibilidade despacho_de_admissibilidade
6.260229 B-acordao_de_2_instancia acordao_de_2_instancia
4.636800 I-despacho_de_admissibilidade despacho_de_admissibilidade
4.585840 B-sentenca sentenca
4.401971 B-outros despacho_de_admissibilidade
4.224357 B-outros acordao_de_2_instancia
4.119760 B-agravo_em_recurso_extraordinario agravo_em_recurso_extraordinario
3.692636 B-peticao_do_RE peticao_do_RE
3.677365 B-outros sentenca
3.666268 I-acordao_de_2_instancia acordao_de_2_instancia
2.911459 I-sentenca sentenca
2.752903 B-outros agravo_em_recurso_extraordinario
2.046825 B-outros peticao_do_RE
1.511294 I-agravo_em_recurso_extraordinario agravo_em_recurso_extraordinario
0.912853 I-outros peticao_do_RE
0.617054 I-peticao_do_RE peticao_do_RE
0.392794 I-outros outros
-0.001515 I-peticao_do_RE acordao_de_2_instancia
-0.040824 I-agravo_em_recurso_extraordinario sentenca
-0.053651 B-sentenca despacho_de_admissibilidade
-0.070852 B-petic