Using kernel `conda_pytorch_latest_p36`

In [1]:
# !pip install fastai
# !pip install cloudpathlib

# Import

In [2]:
import sys
sys.path.append('../../../')

In [3]:
from pathlib import Path
import os
import random
import json
from datetime import datetime
from collections import defaultdict

In [4]:
import pandas as pd
import numpy as np
from sklearn.metrics import classification_report
import torch
import sagemaker
from sagemaker import get_execution_role
import boto3
import torch.nn as nn
from sklearn.preprocessing import MultiLabelBinarizer
from torch.utils.data import DataLoader,Dataset
from tqdm import tqdm
from fastai.text.all import *
from sklearn import metrics
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, multilabel_confusion_matrix

In [5]:
from deep.constants import *
from deep.utils import *

In [6]:
%load_ext autoreload
%autoreload 2

# Data

In [9]:
def preprocessing(df):
    """Return a copy of ``df`` whose label columns hold Python lists.

    The ``sectors``, ``pillars`` and ``subpillars`` columns are read from CSV as
    the textual repr of a list (e.g. ``"['Health', 'Logistics']"``). Each cell is
    parsed back into a list; ``pillars`` is additionally de-duplicated while
    keeping the first-seen order, so the result is identical from run to run.

    Args:
        df: DataFrame containing ``sectors``, ``pillars`` and ``subpillars``.

    Returns:
        A new DataFrame; ``df`` itself is left untouched.

    Raises:
        ValueError / SyntaxError: if a cell is not a plain Python literal.
    """
    import ast  # local import so the cell does not depend on the imports cell

    def _parse_labels(cell):
        # Cells that are already lists pass through, so re-running this
        # function on processed data is harmless.
        if isinstance(cell, list):
            return cell
        # literal_eval only accepts literals; unlike eval() it cannot execute
        # arbitrary expressions that may be embedded in the CSV.
        return ast.literal_eval(cell)

    df = df.copy()
    for label_column in ('sectors', 'pillars', 'subpillars'):
        df[label_column] = df[label_column].apply(_parse_labels)

    # dict.fromkeys drops duplicates but, unlike set(), keeps insertion order.
    df['pillars'] = df['pillars'].apply(lambda labels: list(dict.fromkeys(labels)))
    return df

In [14]:
# Label column passed to process_multiclass as `column` (rows with an empty list there are dropped).
column = 'subpillars'
# Passed to process_multiclass as `classes`.
# NOTE: this name is rebound to learn.dls.vocab[1] once the learner is created.
classes=SUBPILLARS
# Column holding the text fed to the classifier.
text_column = 'excerpt'
# Name of the column that will hold the ';'-joined pillar + subpillar labels.
merge_column = 'merge'

In [11]:
# Read the v0.4.3 train / val / test CSV splits and parse their list-valued label columns.
# NOTE(review): `val` is not referenced by any later cell visible in this notebook — confirm whether
# it should serve as the validation split.
train = preprocessing(pd.read_csv(LATEST_DATA_PATH / 'data_v0.4.3_train.csv'))
val = preprocessing(pd.read_csv(LATEST_DATA_PATH / 'data_v0.4.3_val.csv'))
test = preprocessing(pd.read_csv(LATEST_DATA_PATH / 'data_v0.4.3_test.csv'))

In [15]:
def process_multiclass(df, train, column='pillars', classes=PILLARS):
    """Select the labelled rows of ``df`` and attach the merged label string.

    Args:
        df: DataFrame with list-valued ``pillars`` and ``subpillars`` columns.
        train: truthy for the training split; the ``is_valid`` flag written to
            every row is the negation of this value.
        column: rows whose list in this column is empty are discarded.
        classes: accepted for call-site symmetry; not used in the body.

    Returns:
        A new DataFrame with two extra columns: the notebook-level
        ``merge_column`` (pillars followed by subpillars, joined with ';') and
        ``is_valid``.
    """
    has_labels = df[column].apply(len) > 0
    labelled = df.loc[has_labels].copy()

    # Row-wise list concatenation: pillars first, then subpillars.
    combined = labelled.pillars + labelled.subpillars
    labelled[merge_column] = combined.apply(';'.join)

    labelled['is_valid'] = not train

    return labelled
    
    

# Use fastai

In [16]:
# Training rows get is_valid=False and test rows get is_valid=True, so after concatenation
# the test split is what the dataloaders (valid_col='is_valid') treat as the validation set.
train_df = process_multiclass(train, True, column=column, classes=classes)
test_df = process_multiclass(test, False, column=column, classes=classes)
df = pd.concat([train_df, test_df])

In [17]:
train_df

Unnamed: 0,entry_id,lead_id,project_id,project_title,analysis_framework_id,excerpt,dropped_excerpt,created_by_id,modified_by_id,verified,verification_last_changed_by_id,sectors,pillars,subpillars,merge,is_valid
0,163664,35315,2028,IMMAP/DFS Syria,1306,Market monitoring by the World Food Programme recorded a 48 per cent increase in the average price of a standard reference food basket between May and June. Food prices are 240 per cent higher than in June last year.,,2232,2232,False,,[Food Security],[Impact],[Impact->Impact On Systems And Services],Impact;Impact->Impact On Systems And Services,False
1,162812,37820,2098,IMMAP/DFS Bangladesh,1306,Quarantine Facilities: ninety-three shelters in Camp 20 Extension are currently operating as a Quarantine Facility for contacts of confirmed cases to support early containment of the outbreak. IOM is working with World Concern/ Medair who are providing dedicated Community Health Workers (CHWs) to carry out contact follow-up and health check-up services.,,657,2233,False,,[Health],[Capacities & Response],[Capacities & Response->International Response],Capacities & Response;Capacities & Response->International Response,False
2,164560,39796,2098,IMMAP/DFS Bangladesh,1306,"Within dimensions, markets are broadly operating at high functionality across supply chain indicators (availability of stock and resilience) but at low to moderate capacity in terms of assortment and price volatility.",,1152,1152,False,,[Cross],[Impact],[Impact->Impact On Systems And Services],Impact;Impact->Impact On Systems And Services,False
3,157496,38706,2098,IMMAP/DFS Bangladesh,1306,"Frontline aid workers face a heightened risk of COVID-19 infection.Since September 2019, the Government of Bangladesh has suspended 3G and 4G mobile networks and internet access in the Rohingya settlements. These restrictions have hindered the rapid dissemination of important public health messages related to COVID-19 targeting both Rohingya refugees and Bangladeshis, as well as their ability to stay connected with family and loved ones",,2233,2233,False,,"[Health, Logistics]","[People At Risk, Impact]","[Impact->Driver/Aggravating Factors, People At Risk->Risk And Vulnerabilities, Impact->Impact On Systems And Services, Impact->Impact On People]",People At Risk;Impact;Impact->Driver/Aggravating Factors;People At Risk->Risk And Vulnerabilities;Impact->Impact On Systems And Services;Impact->Impact On People,False
5,162971,37820,2098,IMMAP/DFS Bangladesh,1306,"IOM MHPSS teams supported in coordinating and providing trainings on Protection from Sexual Exploitations and Abuse (PSEA), basic psychological skills from ITC Interpreters, mental health and psycho-social support during COVID-19 pandemic. Stress management, and the Mental Health Gap Action Programme (mhGAP) during the month seeking to build the capacity of MHPSS volunteers, teachers, doctors and community leaders. A total of 214 participants benefitted from the trainings.",,657,2233,False,,[Health],[Capacities & Response],"[Capacities & Response->Number Of People Reached, Capacities & Response->International Response]",Capacities & Response;Capacities & Response->Number Of People Reached;Capacities & Response->International Response,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
90643,283333,51241,2170,IMMAP/DFS Nigeria,1306,"[16th Mar 2021,North east Nigeria]Levels of food insecurity and malnutrition across the BAY states remain of public-health significance, and are predicted to worsen, with conflict causing mass population displacement, sub-optimal access to primary health care services, poor hygiene and sanitation conditions (including poor child feeding and care environments) and high food prices being the main drivers. The food security situation has deteriorated significantly in 2020.",,2230,1152,True,1152.0,"[Food Security, Nutrition]","[Humanitarian Conditions, People At Risk, Impact]","[Impact->Driver/Aggravating Factors, People At Risk->Risk And Vulnerabilities, Humanitarian Conditions->Physical And Mental Well Being]",Humanitarian Conditions;People At Risk;Impact;Impact->Driver/Aggravating Factors;People At Risk->Risk And Vulnerabilities;Humanitarian Conditions->Physical And Mental Well Being,False
90647,282475,51241,2170,IMMAP/DFS Nigeria,1306,"[16th Mar 2021,North east Nigeria] Lower availability of labour due to the conflict and the pandemic, in addition to higher prices and security-based restrictions on transportation of nitrate-based fertilizer, have reduced production. The government of Borno State, the epicentre of the conflict, estimates that the insurgency has caused $6 billion worth of destruction in the State .",,2230,1152,True,1152.0,"[Agriculture, Livelihoods]","[Humanitarian Conditions, Impact]","[Impact->Driver/Aggravating Factors, Humanitarian Conditions->Living Standards, Impact->Impact On Systems And Services]",Humanitarian Conditions;Impact;Impact->Driver/Aggravating Factors;Humanitarian Conditions->Living Standards;Impact->Impact On Systems And Services,False
90648,282949,51241,2170,IMMAP/DFS Nigeria,1306,"[16th Mar 2021,North east Nigeria]The governments of BAY states imposed a three-week lockdown from mid-April 2020 to slow the spread of COVID-19. The measures, which comprised restrictions on movement and closures of markets and inter-state borders, severely impaired delivery and access to critical services by government and humanitarian partners. These restrictions seem to have contained the spread of COVID-19 into the highly congested IDP camps and communities, thus likely preventing significant loss of lives",,2230,26,True,26.0,[Cross],[Impact],[Impact->Impact On Systems And Services],Impact;Impact->Impact On Systems And Services,False
90649,283375,51241,2170,IMMAP/DFS Nigeria,1306,"[16th Mar 2021,North east Nigeria] Impact on systems and services: Insecurity due to NSAG activities continues to limit the presence of civilian authorities outside state capitals, especially in northern Borno State, leaving people reliant on humanitarian aid. Basic social infrastructure and services including the police and judiciary, access to health care, education and livelihoods, and civilian personnel such as teachers, nurses, doctors and civil administrators are lacking in most locations. However, there is more frequent presence of local government officials in some LGAs (mostly in ...",,2230,1152,True,1152.0,"[Health, Education, Protection, Livelihoods]","[Humanitarian Conditions, Impact]","[Impact->Driver/Aggravating Factors, Humanitarian Conditions->Living Standards]",Humanitarian Conditions;Impact;Impact->Driver/Aggravating Factors;Humanitarian Conditions->Living Standards,False


In [20]:
# Multi-label text classification dataloaders: labels come from the ';'-delimited merge column,
# and rows flagged is_valid=True form the validation set.
dls = TextDataLoaders.from_df(
    df,
    text_col=text_column,
    label_col=merge_column,
    label_delim=';',
    valid_col='is_valid',
    is_lm = False,    # False: build classification dataloaders, NOT language-model dataloaders
    seq_len = 72,     # Sequence length passed to the text dataloaders
    bs = 64,     # Batch size
    y_block=MultiCategoryBlock,
)
# AWD-LSTM classifier; the thresholded metrics below all use a fixed cut-off of 0.35.
learn = text_classifier_learner(
    dls, 
    AWD_LSTM, 
    drop_mult=0.5, 
    metrics=[
        accuracy_multi, 
        RecallMulti(thresh=0.35), 
        PrecisionMulti(thresh=0.35), 
        F1ScoreMulti(thresh=0.35), 
        RocAucMulti()
    ]
)
# Rebinds the notebook-level `classes` to the label vocabulary built by the dataloaders;
# get_metrics and the analysis loop index predictions with this ordering.
classes = learn.dls.vocab[1]

  return array(a, dtype, copy=False, order=order)


In [None]:
# Positional arguments are (epochs, base_lr) per fastai's Learner.fine_tune signature: 3 epochs, base_lr 0.02.
learn.fine_tune(3, 0.02)

epoch,train_loss,valid_loss,accuracy_multi,recall_score,precision_score,f1_score,roc_auc_score,time
0,0.241613,0.23542,0.907648,0.244468,0.357179,0.246176,0.769242,02:29


  _warn_prf(average, modifier, msg_start, len(result))


epoch,train_loss,valid_loss,accuracy_multi,recall_score,precision_score,f1_score,roc_auc_score,time
0,0.213281,0.210538,0.915581,0.321233,0.48535,0.355466,0.834945,06:24


  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
def get_threshold_metrics(preds, targets, num_thresholds):
    """Sweep decision thresholds and score the pooled predictions at each one.

    ``preds`` and ``targets`` are flattened, so every (sample, label) cell is
    treated as one binary decision. A cell counts as positive when its score is
    strictly greater than the threshold.

    Args:
        preds: tensor of prediction scores.
        targets: tensor of 0/1 ground-truth labels with the same number of elements.
        num_thresholds: number of evenly spaced thresholds; the values tried are
            ``0, 1/n, ..., (n-1)/n``.

    Returns:
        DataFrame indexed by ``threshold`` with ``recall``, ``precision`` and
        ``f1_score`` columns.
    """
    thresholds = [step / num_thresholds for step in range(num_thresholds)]

    # Flatten and convert once: neither depends on the threshold. reshape()
    # (unlike view()) also accepts non-contiguous tensors.
    flat_scores = preds.reshape(-1).numpy()
    flat_targets = targets.reshape(-1).numpy()

    recalls = []
    precisions = []
    f1_scores = []

    for threshold in thresholds:
        decisions = (flat_scores > threshold).astype(int)
        precisions.append(precision_score(flat_targets, decisions))
        recalls.append(recall_score(flat_targets, decisions))
        f1_scores.append(f1_score(flat_targets, decisions))

    all_metrics = pd.DataFrame(
        {
            'threshold': thresholds,
            'recall': recalls,
            'precision': precisions,
            'f1_score': f1_scores
        }
    ).set_index('threshold', drop=True)

    return all_metrics

In [None]:
def get_best_threshold(learner, num_thresholds=20):
    """Find the threshold with the highest pooled F1 on ``learner``'s dataset 0.

    Args:
        learner: object exposing ``get_preds(ds_idx)`` returning
            ``(predictions, targets)``; dataset index 0 is queried
            (the training set in fastai's convention).
        num_thresholds: number of evenly spaced thresholds to evaluate.

    Returns:
        ``(best_threshold, metrics)`` where ``metrics`` is the per-threshold
        DataFrame produced by ``get_threshold_metrics``.
    """
    # Use the learner that was passed in rather than the notebook-global `learn`.
    train_preds, train_targets = learner.get_preds(0)
    train_metrics = get_threshold_metrics(train_preds, train_targets, num_thresholds)
    best_threshold = train_metrics.f1_score.idxmax()

    return best_threshold, train_metrics

In [None]:
def get_metrics(preds, targets):
    """Build a per-class metrics table plus a pooled 'all' row.

    Column ``i`` of ``preds`` / ``targets`` is scored against the ``i``-th entry
    of the notebook-level ``classes``. The final 'all' row scores the flattened
    tensors, i.e. every (sample, label) cell at once.

    Args:
        preds: 2-D tensor of 0/1 predictions.
        targets: 2-D tensor of 0/1 ground-truth labels.

    Returns:
        DataFrame indexed by ``class`` with ``recall``, ``precision``,
        ``f1_score`` and ``counts`` (number of positive targets) columns.
    """
    def _score(y_true, y_pred):
        # One table row: (precision, recall, f1, number of positive targets).
        return (
            precision_score(y_true, y_pred),
            recall_score(y_true, y_pred),
            f1_score(y_true, y_pred),
            int(y_true.sum()),
        )

    row_names = []
    rows = []
    for position, class_name in enumerate(classes):
        row_names.append(class_name)
        rows.append(_score(targets[:, position], preds[:, position]))

    # Pooled row over all labels.
    row_names.append('all')
    rows.append(_score(targets.view(-1), preds.view(-1)))

    precisions, recalls, f1_scores, counts = (list(col) for col in zip(*rows))

    table = pd.DataFrame(
        {
            'class': row_names,
            'recall': recalls,
            'precision': precisions,
            'f1_score': f1_scores,
            'counts': counts
        }
    )
    return table.set_index('class', drop=True)

In [None]:
def evaluate(learner, threshold):
    """Score ``learner`` on its dataset 1 at the given decision threshold.

    Mirrors the inline evaluation cells of this notebook: predictions strictly
    greater than ``threshold`` are treated as positive, then scored per class.

    Args:
        learner: object exposing ``get_preds(ds_idx)`` returning
            ``(predictions, targets)``; dataset index 1 is queried
            (the validation set in fastai's convention).
        threshold: decision threshold applied to the prediction scores.

    Returns:
        The metrics DataFrame produced by ``get_metrics``.
    """
    test_preds, test_targets = learner.get_preds(1)
    discrete_preds = (test_preds > threshold).int()
    discrete_targets = test_targets.int()
    return get_metrics(discrete_preds, discrete_targets)
    

In [None]:
# Choose the decision threshold that maximises pooled F1 (default sweep of 20 thresholds).
best_threshold, train_metrics = get_best_threshold(learn)
print(best_threshold)

In [None]:
train_metrics

In [None]:
# Predictions for dataset index 1, i.e. the rows flagged is_valid=True (the test split here).
test_preds, test_targets = learn.get_preds(1)
# A label is predicted when its score is strictly greater than the chosen threshold.
test_discrete_preds = (test_preds > best_threshold).int()
test_discrete_targets = test_targets.int()

In [None]:
test_preds

In [None]:
test_targets

In [None]:
# Count how many test rows carry each pillar label.
count = defaultdict(int)
for labels in test_df.pillars:
    # preprocessing() already turned this column into lists; calling .split()
    # on a list raises AttributeError. A ';'-joined string is still accepted.
    if isinstance(labels, str):
        labels = labels.split(';')
    for label in labels:
        count[label] += 1

In [None]:
# sklearn's signature is (y_true, y_pred): passing predictions first swaps the
# false-positive and false-negative cells of every per-label matrix.
multilabel_confusion_matrix(test_discrete_targets, test_discrete_preds, samplewise=False)

In [None]:
# Per-class and pooled ('all') precision / recall / F1 on the thresholded test predictions.
multi_label_metrics = get_metrics(test_discrete_preds,test_discrete_targets)

In [None]:
multi_label_metrics

In [None]:
# One x tick per table row (every class plus the 'all' row); y axis fixed to [0, 1] in steps of 0.1.
# NOTE(review): the integer `counts` column is plotted too and will mostly fall outside ylim — confirm intent.
multi_label_metrics.plot(figsize=(20, 10), xticks=range(len(classes)+1), yticks=[x/10 for x in range(11)], ylim=(0, 1), grid=True)

In [None]:
# Output directory for result artefacts.
# NOTE(review): absolute, machine-specific path — it only resolves on the SageMaker instance it was written for.
base = Path('/home/ec2-user/SageMaker/experiments-dfs/models/fastai/results')

In [None]:
# Persist the metrics table; `base` must already exist, open() does not create directories.
# NOTE(review): `pickle` is not imported explicitly in the visible import cells — presumably it
# arrives through one of the star imports; verify on a fresh kernel.
with open(base / 'multi_label_metrics.pickle', 'wb') as f:
    pickle.dump(multi_label_metrics, f)

# Analysis

In [None]:
# Short aliases used by the manual review loop below: tp = thresholded predictions, tt = targets.
tp = test_discrete_preds
tt = test_discrete_targets

In [None]:
test_df

In [None]:
# Row offset at which the manual review loop starts; raise it to resume a previous session.
start = 0

In [None]:
# Interactive review of test predictions: prints one excerpt with its expected
# and predicted labels, then waits for Enter (or 's' + Enter to stop).
# NOTE: input() blocks, so this cell cannot run unattended in a "Run All".
# Assumes the rows of tp / tt are in the same order as test_df — TODO confirm.

# False reproduces the previous behaviour (every row is shown); set to True to
# page through misclassified rows only.
only_errors = False

for excerpt, predicted_row, target_row in zip(
    # The frame has no `sentence_text` / `sector_ids` columns; the text lives
    # in `text_column` ('excerpt').
    test_df[text_column].iloc[start:],
    tp[start:],
    tt[start:],
):
    if only_errors and torch.equal(predicted_row, target_row):
        continue

    print(excerpt)
    expected = [classes[i] for i, flag in enumerate(target_row) if flag]
    print('Expected:', ', '.join(expected))
    predicted = [classes[i] for i, flag in enumerate(predicted_row) if flag]
    print('Predicted:', ', '.join(predicted))

    if input() == 's':
        break
    