In [1]:
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.backends.cudnn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from transformers import BertTokenizer, BertModel
from transformers import AutoModel, AutoTokenizer

In [2]:
import sys
import os

# Turn off parallelism in the HuggingFace `tokenizers` library; set before any
# tokenizer is created so the setting is seen when it is first used.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Make the local ./src directory importable. The membership check keeps the
# cell idempotent: re-running it no longer stacks duplicate sys.path entries.
src_dir = os.path.join(os.getcwd(), 'src')
_src_path = os.path.abspath(src_dir)
if _src_path not in sys.path:
    sys.path.append(_src_path)

from importlib import reload

import src.utils.data_loader

# Reload so that edits made to the module on disk are picked up without
# restarting the kernel.
reload(src.utils.data_loader)

from src.utils.data_loader import DataEncoder

In [26]:
# Which evaluation file to load: the file name is built from the dataset
# version and the entity type (first entry of `data_type` is selected).
data_type = ["train_station", "city"]
version = "1"
selected_type = data_type[0]

# Reference table, semicolon-separated; its "LIBELLE" column is used later
# to fit the label encoder.
train_station_df = pd.read_csv('../data/liste-des-gares.csv', sep=';')

# Evaluation sentences with their expected departure / arrival.
test_df = pd.read_csv(f'../data/dataset/test/test{version}_{selected_type}.csv')

In [8]:
# Numeric settings.
batch_size = 30
epochs = 2
learning_rate = 2e-5
max_len = 128

# Everything in this notebook runs on the CPU.
device = "cpu"

# Tokenizer and encoder are loaded from the same pretrained checkpoint.
bert_checkpoint = "dbmdz/bert-base-french-europeana-cased"
tokenizer = AutoTokenizer.from_pretrained(bert_checkpoint)
bert_model = AutoModel.from_pretrained(bert_checkpoint)

In [10]:
# Fit the encoder on the distinct "LIBELLE" values so predicted class indices
# can be mapped back to names with inverse_transform.
# NOTE(review): the class set must match the encoder used when the checkpoint
# was trained, otherwise indices map to the wrong names — confirm.
le_train_stations = LabelEncoder().fit(train_station_df["LIBELLE"].unique())

In [11]:
class BertForSequenceClassification(nn.Module):
    """Encoder with two independent linear classification heads.

    Both heads read the same pooled encoder output (after dropout): one
    produces departure scores, the other arrival scores.

    Args:
        n_depart: number of output classes of the departure head.
        n_arrival: number of output classes of the arrival head.
    """

    def __init__(self, n_depart, n_arrival):
        super().__init__()
        # The notebook-level `bert_model` instance is reused as the encoder.
        self.bert = bert_model
        hidden_size = self.bert.config.hidden_size
        self.drop = nn.Dropout(p=0.3)
        self.out_depart = nn.Linear(hidden_size, n_depart)
        self.out_arrival = nn.Linear(hidden_size, n_arrival)

    def forward(self, input_ids, attention_mask):
        """Return the pair ``(departure_scores, arrival_scores)``."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Element 1 of the encoder output is taken as the pooled representation.
        pooled = encoded[1]

        # Dropout is applied separately for each head.
        depart_scores = self.out_depart(self.drop(pooled))
        arrival_scores = self.out_arrival(self.drop(pooled))
        return depart_scores, arrival_scores

    def freeze_arrival_layer(self):
        """Stop gradient updates for the arrival head."""
        self.out_arrival.requires_grad_(False)

    def unfreeze_arrival_layer(self):
        """Re-enable gradient updates for the arrival head."""
        self.out_arrival.requires_grad_(True)

    def freeze_depart_layer(self):
        """Stop gradient updates for the departure head."""
        self.out_depart.requires_grad_(False)

    def unfreeze_depart_layer(self):
        """Re-enable gradient updates for the departure head."""
        self.out_depart.requires_grad_(True)

In [13]:
# Both heads classify over the same label set, so they share one class count.
model = BertForSequenceClassification(len(le_train_stations.classes_), len(le_train_stations.classes_))
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a GPU still loads on the configured device.
model.load_state_dict(
    torch.load(
        "./processed/4000_classes_classification/departure_arrival_model3_trained.pth",
        map_location=device,
    )
)
# eval() disables dropout: this notebook only runs inference, and leaving the
# model in training mode makes the predictions non-deterministic.
model = model.to(device).eval()

In [15]:
# Display the evaluation DataFrame loaded above as this cell's output.
test_df

Unnamed: 0,arrival,departure,sentence
0,génelard,franxault,"Je pars de franxault et j'arrive à génelard, c..."
1,laillé,saubusse les bains,Je pars de saubusse les bains et j'arrive à la...
2,poligny,vougeot gilly lès cîteaux,Pour aller à poligny depuis vougeot gilly lès ...
3,varilhes,rennes,"Pour aller à varilhes depuis rennes, quelle es..."
4,terrenoire,naintré les barres,naintré les barres est ma gare de départ et j'...
...,...,...,...
205,rémy,vert galant,Si je commence mon voyage à vert galant avec u...
206,bièvres,bayonne,Il est important que je parte de bayonne à une...
207,versailles chantiers,st yorre,Il est important que je parte de st yorre à un...
208,melun,marmagne sous creusot,Pour m’assurer que je ne rate pas mon rendez-v...


In [33]:
def infer(sentence, model):
    """Predict the departure and arrival names for one sentence.

    Args:
        sentence: text to classify.
        model: module returning ``(departure_scores, arrival_scores)`` when
            called with ``input_ids`` and ``attention_mask``.

    Returns:
        Tuple ``(depart, arrival)``; each is the array returned by
        ``le_train_stations.inverse_transform`` for the highest-scoring class,
        so callers read the prediction with ``[0]``.
    """
    # Truncate to the notebook's configured token limit so that very long
    # inputs cannot exceed what the encoder accepts.
    # NOTE(review): assumes the checkpoint was trained with the same max_len — confirm.
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True, max_length=max_len)

    # Run on whatever device the model's weights live on.
    model_device = next(model.parameters()).device
    input_ids = inputs["input_ids"].to(model_device)
    attention_mask = inputs["attention_mask"].to(model_device)

    # Inference must run with dropout disabled and without building an
    # autograd graph; the caller's train/eval mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    finally:
        model.train(was_training)

    # Hand plain NumPy index arrays (on CPU) to the label encoder.
    depart_idx = torch.max(outputs[0], 1).indices.cpu().numpy()
    arrival_idx = torch.max(outputs[1], 1).indices.cpu().numpy()

    depart = le_train_stations.inverse_transform(depart_idx)
    arrival = le_train_stations.inverse_transform(arrival_idx)

    return depart, arrival

def sanitize_name(name: str) -> "str | None":
    """Normalize a place name so two spellings can be compared.

    Hyphens become spaces, parentheses are removed, every run of whitespace is
    collapsed to a single space (leading/trailing whitespace is dropped) and
    the result is lower-cased.

    Args:
        name: raw name, or a missing value (``None`` or float NaN).

    Returns:
        The normalized name, or ``None`` when ``name`` is missing.
    """
    # Missing values: None (column absent) or float NaN (empty CSV cell as
    # read by pandas). NaN is the only float that differs from itself.
    if name is None or (isinstance(name, float) and name != name):
        return None

    cleaned = name.replace("-", " ").replace("(", "").replace(")", "")
    # split()/join collapses whitespace runs of any length; a single
    # replace("  ", " ") pass leaves "a  b" for an input such as "A - B".
    return " ".join(cleaned.split()).lower()
            

In [None]:
def _none_if_missing(value):
    """Map absent cells (None, NaN, empty string) to None; pass others through."""
    if value is None or pd.isna(value) or not value:
        return None
    return value


# Optional expectation columns that a test file may or may not carry. The model
# only predicts a departure and an arrival, so a row that expects any of these
# is counted as a failure.
OPTIONAL_COLUMNS = ("intermediary_1", "intermediary_2", "none")

results = []

for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
    infer_depart, infer_arrival = infer(row.sentence, model)

    # NaN is truthy, so optional cells are normalized to None up front rather
    # than being tested with a bare `if value`.
    optional = {col: _none_if_missing(row.get(col)) for col in OPTIONAL_COLUMNS}

    content = {
        "sentence": row.sentence,
        "departure": sanitize_name(_none_if_missing(row.departure)),
        "arrival": sanitize_name(_none_if_missing(row.arrival)),
        "intermediary_1": sanitize_name(optional["intermediary_1"]),
        "intermediary_2": sanitize_name(optional["intermediary_2"]),
        "none": sanitize_name(optional["none"]),
        "infer_departure": sanitize_name(infer_depart[0]),
        "infer_arrival": sanitize_name(infer_arrival[0]),
    }

    # A row passes only if both predictions exist, both match the expected
    # names, and no optional expectation is set for the row.
    content["evaluation"] = (
        content["infer_departure"] is not None
        and content["infer_arrival"] is not None
        and content["infer_departure"] == content["departure"]
        and content["infer_arrival"] == content["arrival"]
        and all(value is None for value in optional.values())
    )

    results.append(content)

results_path = f"./processed/test_results/4000_classes_classification/test{version}_{selected_type}_results.csv"
# Create the output folder on first use so a fresh checkout does not fail
# after the whole evaluation has already run.
os.makedirs(os.path.dirname(results_path), exist_ok=True)

df = pd.DataFrame(results)
df.to_csv(results_path, index=False)

  0%|          | 0/210 [00:00<?, ?it/s]