In [1]:
import torch
import numpy as np
import pandas as pd
from mlp import MLP
import matplotlib.pyplot as plt
from sklearn.preprocessing import OneHotEncoder
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, accuracy_score

In [2]:
np.random.seed(0)

In [3]:
def load_file(file_path, include_target=True):
    data = []
    sentence_id = 0

    with open(file_path, 'r') as file:
        for line in file:
            line = line.strip()
            if "-DOCSTART-" in line:
                continue  # skip the docstart lines
            if line == "":
                sentence_id += 1  # Increment sentence ID at the end of each sentence
                continue
            if include_target:
                word, target = line.split()
                data.append((word, target, sentence_id))
            else:
                data.append((line, sentence_id))

    # Create a DataFrame
    if include_target:
        return pd.DataFrame(data, columns=['Word', 'Target', 'Sentence'])
    return pd.DataFrame(data, columns=['Word', 'Sentence'])


In [4]:
def random_embeddings(df, embedding_dim, empty_embeddings):
    embeddings = {word: np.random.rand(embedding_dim) for word in df['Word'].unique()}
    embeddings[''] = empty_embeddings
    return pd.DataFrame(list(embeddings.items()), columns=['Word','Embeddings'])

In [5]:
def create_context(df, window_size=2, include_target=True):
    context_data = []

    # Iterate through each sentence
    for sentence_id in df['Sentence'].unique():
        sentence_df = df[df['Sentence'] == sentence_id]
        words = sentence_df['Word'].tolist()
        if include_target:
            ners = sentence_df['Target'].tolist()
        
        # Create context for each word
        for i in range(len(words)):
            context = []
            for j in range(-window_size, window_size + 1):
                if 0 <= i + j < len(words):
                    context.append(words[i + j])
                else:
                    context.append('')
            if include_target:
                context.append(ners[i])
            context_data.append(context)

    # Create a DataFrame
    columns = [f'word_{i}' for i in range(-window_size, window_size + 1)]
    if include_target:
        columns += ['Target']
    context_df = pd.DataFrame(context_data, columns=columns)
    return context_df


In [6]:
def add_embeddings(df, embeddings_df, default_embeddings, window_size=2):
    for i in range(-window_size, window_size+1):
        word_col = f'word_{i}'
        df = df.merge(embeddings_df, how='left', left_on=word_col, right_on='Word')
        df.drop(columns=['Word'], inplace=True)
        df.loc[df['Embeddings'].isna(), 'Embeddings'] = df[df['Embeddings'].isna()].apply(lambda _: default_embeddings, axis=1)
        df.rename(columns={'Embeddings': f'Embeddings_{i}'}, inplace=True)

    return df

In [7]:
def concat_embeddings(df):
    def concat(row):
        return np.concatenate([row[c] for c in df.columns if 'Embeddings' in c])
        
    df['Embeddings'] = df.apply(concat, axis=1)
    return df

In [8]:
def load_data(file_path, vocablary, window_size=2):
    df = load_file('ner/dev', include_target=True)
    df = create_context(df, window_size, include_target=True)
    df = add_embeddings(df, vocablary, np.ones(embedding_dim), window_size)
    df = concat_embeddings(df)
    return df

In [9]:
embedding_dim = 50
vocablary = random_embeddings(load_file('ner/train'), embedding_dim, np.zeros(embedding_dim))
vocablary

Unnamed: 0,Word,Embeddings
0,EU,"[0.5488135039273248, 0.7151893663724195, 0.602..."
1,rejects,"[0.5701967704178796, 0.43860151346232035, 0.98..."
2,German,"[0.6778165367962301, 0.27000797319216485, 0.73..."
3,call,"[0.14944830465799375, 0.8681260573682142, 0.16..."
4,to,"[0.3117958819941026, 0.6963434888154595, 0.377..."
...,...,...
23619,217,"[0.7443155169127373, 0.17838038955337676, 0.19..."
23620,Swe,"[0.4201910042326812, 0.7328467101191916, 0.349..."
23621,Bradley,"[0.06132456847139178, 0.5485009220847855, 0.20..."
23622,Hughes,"[0.686228831252667, 0.43195766289570403, 0.858..."


In [10]:
df_train = load_data('ner/train', vocablary)
df_train

Unnamed: 0,word_-2,word_-1,word_0,word_1,word_2,Target,Embeddings_-2,Embeddings_-1,Embeddings_0,Embeddings_1,Embeddings_2,Embeddings
0,,,CRICKET,-,LEICESTERSHIRE,O,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.3084179063126684, 0.8723737149804298, 0.049...","[0.8158313019092511, 0.7218350953831094, 0.832...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
1,,CRICKET,-,LEICESTERSHIRE,TAKE,O,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.3084179063126684, 0.8723737149804298, 0.049...","[0.8158313019092511, 0.7218350953831094, 0.832...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.8172612817463183, 0.2791303964823778, 0.727...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
2,CRICKET,-,LEICESTERSHIRE,TAKE,OVER,ORG,"[0.3084179063126684, 0.8723737149804298, 0.049...","[0.8158313019092511, 0.7218350953831094, 0.832...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.8172612817463183, 0.2791303964823778, 0.727...","[0.25311905844002736, 0.49469351115011506, 0.1...","[0.3084179063126684, 0.8723737149804298, 0.049..."
3,-,LEICESTERSHIRE,TAKE,OVER,AT,O,"[0.8158313019092511, 0.7218350953831094, 0.832...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.8172612817463183, 0.2791303964823778, 0.727...","[0.25311905844002736, 0.49469351115011506, 0.1...","[0.3221873547257291, 0.10051427962582815, 0.86...","[0.8158313019092511, 0.7218350953831094, 0.832..."
4,LEICESTERSHIRE,TAKE,OVER,AT,TOP,O,"[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.8172612817463183, 0.2791303964823778, 0.727...","[0.25311905844002736, 0.49469351115011506, 0.1...","[0.3221873547257291, 0.10051427962582815, 0.86...","[0.14286211686287031, 0.4800564078207882, 0.21...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..."
...,...,...,...,...,...,...,...,...,...,...,...,...
51357,on,Thursday,.,,,O,"[0.8890599531897286, 0.7372785797141679, 0.005...","[0.8108386151289514, 0.3481919427465201, 0.211...","[0.4012595008036087, 0.9292914173027139, 0.099...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.8890599531897286, 0.7372785797141679, 0.005..."
51358,,,--,Dhaka,Newsroom,O,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.6533942786712941, 0.496042060304297, 0.0276...","[0.5202697323431233, 0.10982671160240787, 0.10...","[0.12428466610588107, 0.5390241453658275, 0.14...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
51359,,--,Dhaka,Newsroom,880-2-506363,ORG,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.6533942786712941, 0.496042060304297, 0.0276...","[0.5202697323431233, 0.10982671160240787, 0.10...","[0.12428466610588107, 0.5390241453658275, 0.14...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
51360,--,Dhaka,Newsroom,880-2-506363,,ORG,"[0.6533942786712941, 0.496042060304297, 0.0276...","[0.5202697323431233, 0.10982671160240787, 0.10...","[0.12428466610588107, 0.5390241453658275, 0.14...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.6533942786712941, 0.496042060304297, 0.0276..."


In [11]:
df_dev = load_data('ner/dev', vocablary)
df_dev

Unnamed: 0,word_-2,word_-1,word_0,word_1,word_2,Target,Embeddings_-2,Embeddings_-1,Embeddings_0,Embeddings_1,Embeddings_2,Embeddings
0,,,CRICKET,-,LEICESTERSHIRE,O,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.3084179063126684, 0.8723737149804298, 0.049...","[0.8158313019092511, 0.7218350953831094, 0.832...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
1,,CRICKET,-,LEICESTERSHIRE,TAKE,O,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.3084179063126684, 0.8723737149804298, 0.049...","[0.8158313019092511, 0.7218350953831094, 0.832...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.8172612817463183, 0.2791303964823778, 0.727...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
2,CRICKET,-,LEICESTERSHIRE,TAKE,OVER,ORG,"[0.3084179063126684, 0.8723737149804298, 0.049...","[0.8158313019092511, 0.7218350953831094, 0.832...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.8172612817463183, 0.2791303964823778, 0.727...","[0.25311905844002736, 0.49469351115011506, 0.1...","[0.3084179063126684, 0.8723737149804298, 0.049..."
3,-,LEICESTERSHIRE,TAKE,OVER,AT,O,"[0.8158313019092511, 0.7218350953831094, 0.832...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.8172612817463183, 0.2791303964823778, 0.727...","[0.25311905844002736, 0.49469351115011506, 0.1...","[0.3221873547257291, 0.10051427962582815, 0.86...","[0.8158313019092511, 0.7218350953831094, 0.832..."
4,LEICESTERSHIRE,TAKE,OVER,AT,TOP,O,"[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.8172612817463183, 0.2791303964823778, 0.727...","[0.25311905844002736, 0.49469351115011506, 0.1...","[0.3221873547257291, 0.10051427962582815, 0.86...","[0.14286211686287031, 0.4800564078207882, 0.21...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..."
...,...,...,...,...,...,...,...,...,...,...,...,...
51357,on,Thursday,.,,,O,"[0.8890599531897286, 0.7372785797141679, 0.005...","[0.8108386151289514, 0.3481919427465201, 0.211...","[0.4012595008036087, 0.9292914173027139, 0.099...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.8890599531897286, 0.7372785797141679, 0.005..."
51358,,,--,Dhaka,Newsroom,O,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.6533942786712941, 0.496042060304297, 0.0276...","[0.5202697323431233, 0.10982671160240787, 0.10...","[0.12428466610588107, 0.5390241453658275, 0.14...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
51359,,--,Dhaka,Newsroom,880-2-506363,ORG,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.6533942786712941, 0.496042060304297, 0.0276...","[0.5202697323431233, 0.10982671160240787, 0.10...","[0.12428466610588107, 0.5390241453658275, 0.14...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
51360,--,Dhaka,Newsroom,880-2-506363,,ORG,"[0.6533942786712941, 0.496042060304297, 0.0276...","[0.5202697323431233, 0.10982671160240787, 0.10...","[0.12428466610588107, 0.5390241453658275, 0.14...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.6533942786712941, 0.496042060304297, 0.0276..."


In [12]:
encoder = OneHotEncoder()
encoder = encoder.fit(df_train['Target'].unique().reshape(-1,1))

X_train = torch.tensor(np.stack(df_train['Embeddings']), dtype=torch.float32)
y_train = torch.tensor(encoder.transform(df_train[['Target']].values).toarray(), dtype=torch.float32)

X_dev = torch.tensor(np.stack(df_dev['Embeddings']), dtype=torch.float32)
y_dev = torch.tensor(encoder.transform(df_dev[['Target']].values).toarray(), dtype=torch.float32)

In [13]:
dataset = TensorDataset(X_train, y_train)
dataloader = DataLoader(dataset, batch_size=1024, shuffle=True)

In [14]:
class_counts = torch.bincount(torch.argmax(y_train, dim=1))
class_weights = 1.0 / class_counts.float()
class_weights /= class_weights.max()
# class_weights = class_weights ** 0.75
print("Class Weights:", class_weights)

Class Weights: tensor([0.6055, 1.0000, 0.0297, 0.6061, 0.4027])


In [15]:
model = MLP(len(X_train[0]), len(y_train[0]), hidden_layer_size=1024)
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [16]:
def compute_metrics(model, X, y, target_to_ignore):
    outputs = model(X)
    loss = criterion(outputs, y)

    y = torch.argmax(y, axis=1)
    outputs = torch.argmax(outputs.detach(), axis=1)
    acc = accuracy_score(y, outputs)

    mask = ~((y==outputs) & (y==target_to_ignore))
    return loss, acc, accuracy_score(y[mask], outputs[mask])

In [None]:
o_position = np.argmax(encoder.transform([['O']]).toarray()[0])
num_epochs = 50

for epoch in range(num_epochs):
    for X_batch, y_batch in dataloader:
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    with torch.no_grad():
        train_loss, train_acc, train_weird_acc = compute_metrics(model, X_train, y_train, o_position)
        dev_loss, dev_acc, dev_weird_acc = compute_metrics(model, X_dev, y_dev, o_position)
        
        print(f'Epoch [{epoch+1}/{num_epochs}]:')
        print(f'Train) Loss: {train_loss.item():.4f}, Accuracy: {train_acc}, No O acc: {train_weird_acc}')
        print(f'Dev) Loss: {dev_loss.item():.4f}, Accuracy: {dev_acc}, No O acc: {dev_weird_acc}')

Epoch [1/50]:
Train) Loss: 0.1883, Accuracy: 0.5402632296250146, No O acc: 0.12857511901686533
Dev) Loss: 0.1883, Accuracy: 0.5402632296250146, No O acc: 0.12857511901686533
Epoch [2/50]:
Train) Loss: 0.1810, Accuracy: 0.45506405513803977, No O acc: 0.11890071145249638
Dev) Loss: 0.1810, Accuracy: 0.45506405513803977, No O acc: 0.11890071145249638
Epoch [3/50]:
Train) Loss: 0.1776, Accuracy: 0.4536427709201355, No O acc: 0.12374707259953162
Dev) Loss: 0.1776, Accuracy: 0.4536427709201355, No O acc: 0.12374707259953162
Epoch [4/50]:
Train) Loss: 0.1746, Accuracy: 0.4997274249445115, No O acc: 0.14158286840610698
Dev) Loss: 0.1746, Accuracy: 0.4997274249445115, No O acc: 0.14158286840610698
Epoch [5/50]:
Train) Loss: 0.1732, Accuracy: 0.5074958140259336, No O acc: 0.14462516484631252
Dev) Loss: 0.1732, Accuracy: 0.5074958140259336, No O acc: 0.14462516484631252
Epoch [6/50]:
Train) Loss: 0.1716, Accuracy: 0.5275106109575173, No O acc: 0.15294938917975567
Dev) Loss: 0.1716, Accuracy: 0.52

In [None]:
cm = confusion_matrix(np.argmax(y.numpy(), axis=1), np.argmax(outputs.detach().numpy(), axis=1))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=encoder.categories_[0])
disp.plot(cmap=plt.cm.Blues)
plt.show()