In [None]:
!pip install sklearn_crfsuite

Collecting sklearn_crfsuite
  Downloading sklearn_crfsuite-0.5.0-py2.py3-none-any.whl (10 kB)
Installing collected packages: sklearn_crfsuite
Successfully installed sklearn_crfsuite-0.5.0




In [None]:
!pip install joblib

ERROR: Could not find a version that satisfies the requirement pickle (from versions: none)
ERROR: No matching distribution found for pickle


In [None]:
import os
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
import string
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import numpy as np
import sklearn_crfsuite
from sklearn_crfsuite import metrics
import joblib

nltk.download('punkt')
nltk.download('stopwords')

def read_file(filepath):
    with open(filepath, 'r', encoding='utf-8') as file:
        return file.read()

def parse_ann(ann_content):
    annotations = []
    for line in ann_content.strip().split('\n'):
        if line.startswith('T'):
            parts = line.split('\t')
            ann_id = parts[0]
            label_info = parts[1]
            text = parts[2]
            label_info_parts = label_info.split()
            label = label_info_parts[0]
            start = int(label_info_parts[1].split(';')[0])
            end = int(label_info_parts[2].split(';')[0])
            annotations.append({
                'id': ann_id,
                'label': label,
                'start': start,
                'end': end,
                'text': text
            })
    return annotations

def preprocess_text(text):
    stop_words = set(stopwords.words('english'))
    tokens = word_tokenize(text)
    filtered_tokens = [token for token in tokens if token.lower() not in stop_words and token not in string.punctuation]
    return filtered_tokens

def format_input(text, annotations):
    tokens = word_tokenize(text)
    token_annotations = ['O'] * len(tokens)

    for ann in annotations:
        ann_tokens = word_tokenize(ann['text'])
        ann_label = ann['label']
        start = ann['start']
        end = ann['end']

        for i in range(len(tokens)):
            if tokens[i:i+len(ann_tokens)] == ann_tokens:
                token_annotations[i] = f'B-{ann_label}'
                for j in range(1, len(ann_tokens)):
                    token_annotations[i + j] = f'I-{ann_label}'
                break

    return tokens, token_annotations

def process_files(txt_file, ann_file):
    text = read_file(txt_file)
    ann_content = read_file(ann_file)
    annotations = parse_ann(ann_content)
    tokens, labels = format_input(text, annotations)
    return tokens, labels

tokend_text = []
cor_labels = []

def process_all_files(directory):
    for filename in os.listdir(directory):
        if filename.endswith(".txt"):
            txt_file = os.path.join(directory, filename)
            ann_file = txt_file.replace(".txt", ".ann")
            if os.path.exists(ann_file):
                tokens, labels = process_files(txt_file, ann_file)
                tokend_text.append(tokens)
                cor_labels.append(labels)

directory = 'n2c2/n2c2/part2'
process_all_files(directory)

labels_list = ["O", "B-Drug", "I-Drug", "B-Strength", "I-Strength", "B-Form", "I-Form", "B-Dosage", "I-Dosage",
               "B-Duration", "I-Duration", "B-Frequency", "I-Frequency", "B-Route", "I-Route", "B-ADE", "I-ADE",
               "B-Reason", "I-Reason"]

label_map = {label: i for i, label in enumerate(labels_list)}
id2label = {i: label for i, label in enumerate(labels_list)}

def word2features(sent, i):
    word = sent[i]
    features = {
        'bias': 1.0,
        'word.lower()': word.lower(),
        'word[-3:]': word[-3:],
        'word[-2:]': word[-2:],
        'word.isupper()': word.isupper(),
        'word.istitle()': word.istitle(),
        'word.isdigit()': word.isdigit(),
    }
    if i > 0:
        word1 = sent[i - 1]
        features.update({
            '-1:word.lower()': word1.lower(),
            '-1:word.istitle()': word1.istitle(),
            '-1:word.isupper()': word1.isupper(),
        })
    else:
        features['BOS'] = True

    if i < len(sent) - 1:
        word1 = sent[i + 1]
        features.update({
            '+1:word.lower()': word1.lower(),
            '+1:word.istitle()': word1.istitle(),
            '+1:word.isupper()': word1.isupper(),
        })
    else:
        features['EOS'] = True

    return features

def sent2features(sent):
    return [word2features(sent, i) for i in range(len(sent))]

def sent2labels(sent):
    return [label for label in sent]

def sent2tokens(sent):
    return sent

X = [sent2features(s) for s in tokend_text]
y = [sent2labels(l) for l in cor_labels]

train_texts, test_texts, train_labels, test_labels = train_test_split(tokend_text, cor_labels, test_size=0.1, random_state=42)

kf = KFold(n_splits=10, shuffle=True, random_state=42)
all_metrics = {
    'accuracy': [],
    'precision': [],
    'recall': [],
    'f1_micro': [],
    'f1_macro': [],
    'confusion_matrix': []
}

for fold, (train_index, val_index) in enumerate(kf.split(train_texts)):
    print(f"Fold {fold + 1}")

    X_train = [sent2features(train_texts[i]) for i in train_index]
    y_train = [sent2labels(train_labels[i]) for i in train_index]
    X_val = [sent2features(train_texts[i]) for i in val_index]
    y_val = [sent2labels(train_labels[i]) for i in val_index]

    crf = sklearn_crfsuite.CRF(
        algorithm='lbfgs',
        c1=0.1,
        c2=0.1,
        max_iterations=100,
        all_possible_transitions=False
    )
    crf.fit(X_train, y_train)

    y_pred = crf.predict(X_val)

    flat_true_labels = [item for sublist in y_val for item in sublist]
    flat_pred_labels = [item for sublist in y_pred for item in sublist]

    cm = confusion_matrix(flat_true_labels, flat_pred_labels, labels=labels_list)

    accuracy = accuracy_score(flat_true_labels, flat_pred_labels)
    precision, recall, f1_micro, _ = precision_recall_fscore_support(flat_true_labels, flat_pred_labels, average='micro')
    _, _, f1_macro, _ = precision_recall_fscore_support(flat_true_labels, flat_pred_labels, average='macro')

    all_metrics['accuracy'].append(accuracy)
    all_metrics['precision'].append(precision)
    all_metrics['recall'].append(recall)
    all_metrics['f1_micro'].append(f1_micro)
    all_metrics['f1_macro'].append(f1_macro)
    all_metrics['confusion_matrix'].append(cm)

    print(f"Accuracy: {accuracy}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"Micro F1: {f1_micro}")
    print(f"Macro F1: {f1_macro}")

metrics_mean_std = {metric: (np.mean(all_metrics[metric]), np.std(all_metrics[metric])) for metric in all_metrics}

print("\nMetrics Mean and Standard Deviation:")
for metric, (mean, std) in metrics_mean_std.items():
    print(f"{metric.capitalize()} - Mean: {mean}, Std: {std}")

X_test = [sent2features(s) for s in test_texts]
y_test = [sent2labels(l) for l in test_labels]

y_pred_test = crf.predict(X_test)

flat_true_labels_test = [item for sublist in y_test for item in sublist]
flat_pred_labels_test = [item for sublist in y_pred_test for item in sublist]

cm_test = confusion_matrix(flat_true_labels_test, flat_pred_labels_test, labels=labels_list)

accuracy_test = accuracy_score(flat_true_labels_test, flat_pred_labels_test)
precision_test, recall_test, f1_micro_test, _ = precision_recall_fscore_support(flat_true_labels_test, flat_pred_labels_test, average='micro')
_, _, f1_macro_test, _ = precision_recall_fscore_support(flat_true_labels_test, flat_pred_labels_test, average='macro')

print("\nTest Set Metrics:")
print(f"Confusion Matrix:\n{cm_test}")
print(f"Accuracy: {accuracy_test}")
print(f"Precision: {precision_test}")
print(f"Recall: {recall_test}")
print(f"Micro F1: {f1_micro_test}")
print(f"Macro F1: {f1_macro_test}")

joblib.dump(crf, 'crf_model.pkl')

crf = joblib.load('crf_model.pkl')


[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\smrh1\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\smrh1\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


Fold 1
Accuracy: 0.9614500463369721
Precision: 0.9614500463369721
Recall: 0.9614500463369721
Micro F1: 0.9614500463369721
Macro F1: 0.5265026830256874
Fold 2
Accuracy: 0.9608558905632255
Precision: 0.9608558905632255
Recall: 0.9608558905632255
Micro F1: 0.9608558905632255
Macro F1: 0.5081435438696298
Fold 3
Accuracy: 0.9493786442075505
Precision: 0.9493786442075505
Recall: 0.9493786442075505
Micro F1: 0.9493786442075505
Macro F1: 0.512470680999912
Fold 4
Accuracy: 0.9591663122075712
Precision: 0.9591663122075712
Recall: 0.9591663122075712
Micro F1: 0.9591663122075712
Macro F1: 0.5087991167349745
Fold 5
Accuracy: 0.9586858555022533
Precision: 0.9586858555022533
Recall: 0.9586858555022533
Micro F1: 0.9586858555022533
Macro F1: 0.5153474259456265
Fold 6
Accuracy: 0.9581215101258438
Precision: 0.9581215101258438
Recall: 0.9581215101258438
Micro F1: 0.9581215101258438
Macro F1: 0.5123277108693453
Fold 7
Accuracy: 0.9531073190603359
Precision: 0.9531073190603359
Recall: 0.9531073190603359
Mi

In [None]:
from sklearn_crfsuite.metrics import flat_classification_report
y_pred_test = crf.predict(X_test)
report = flat_classification_report(y_pred=y_pred_test, y_true=y_test)
print(report)

              precision    recall  f1-score   support

       B-ADE       0.14      0.01      0.02        74
    B-Dosage       0.66      0.38      0.48       172
      B-Drug       0.70      0.66      0.68      1045
  B-Duration       0.67      0.36      0.46        45
      B-Form       0.69      0.55      0.61       201
 B-Frequency       0.68      0.48      0.56       331
    B-Reason       0.48      0.12      0.19       306
     B-Route       0.43      0.26      0.32       139
  B-Strength       0.68      0.73      0.70       427
       I-ADE       0.38      0.05      0.09        60
    I-Dosage       0.70      0.38      0.49       237
      I-Drug       0.73      0.68      0.70       257
  I-Duration       0.70      0.42      0.52        76
      I-Form       0.81      0.79      0.80       233
 I-Frequency       0.60      0.63      0.61       877
    I-Reason       0.31      0.09      0.14       280
     I-Route       0.75      0.21      0.33        14
  I-Strength       0.65    