# TDT13 Project - Oskar Holm (F2023)

This project is based on the Social Media Geolocation (SMG) shared task from VarDial 2020 and 2021. VarDial is the Workshop on Natural Language Processing (NLP) for Similar Languages, Varieties and Dialects. Unlike typical VarDial tasks, which involve choosing from a set of variety labels, this task focuses on predicting the latitude and longitude from which a social media post was made.

The task remained the same in both 2020 and 2021, covering three language areas: Bosnian-Croatian-Montenegrin-Serbian, German (Germany and Austria), and German-speaking Switzerland. This project is limited to the German-speaking Switzerland area due to time constraints and resource availability.

The goal of the project is to replicate the results of a study that used a BERT-based model for this double regression task (jointly predicting latitude and longitude). The dataset from the 2020 VarDial challenge was chosen because it received more submissions than the 2021 dataset.

## Setup

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

import torch
import numpy as np
import pandas as pd
from pathlib import Path
from torch.utils.data import DataLoader
import pickle
import json 
from torch.optim import AdamW
from tqdm import tqdm
from transformers import logging
import random

from lib.preprocessing import scalers, get_reduced_dev_split
from lib.train_utils import TensorBoardCheckpoint, get_model, get_lossfn, get_scheduler, evaluate_geolocation_model_by_checkpoint
from lib.geo import to_projection, GeolocationDataset
from lib.metrics import median_distance, mean_distance
from lib.plotting import plot_switzerland, plot_barchart

# Only show errors from the transformers library.
logging.set_verbosity_error()

# Root folder for generated artefacts and location of the SMG2020 data.
data_path = '../data'
vardial_path = '../vardial-shared-tasks/SMG2020'

# Pick the experiment to run by its key in configs.json.
config_name = 'bert-finetuned-swiss-L1-reduced-dev-plateau'
with open('./configs.json', 'r', encoding='utf-8') as f:
    configs = json.load(f)

config = configs[config_name]

# Seed every random number generator from the configured seed so that the
# whole run (not only torch) is controlled by the experiment config.
torch.manual_seed(config['seed'])
random.seed(config['seed'])
np.random.seed(config['seed'])

# Display the selected configuration as the cell output.
config

In [None]:
# Load the tab-separated train/dev/test files (latitude, longitude, text).
train_data = pd.read_table(f'{vardial_path}/ch/train.txt', header=None, names=['lat', 'lon', 'text'])
dev_data = pd.read_table(f'{vardial_path}/ch/dev.txt', header=None, names=['lat', 'lon', 'text'])
test_gold_data = pd.read_table(f'{vardial_path}/ch/test_gold.txt', header=None, names=['lat', 'lon', 'text'])

# Get alternative split, if specified in config
if 'split' in config and config['split'] == 'reduced-dev':
    train_data, dev_data = get_reduced_dev_split(train_data, dev_data)

# Convert to specified projection, if any. The column names returned for the
# training data are reused to select the coordinate columns of every split.
train_data, col_names = to_projection(train_data, config)
dev_data, _ = to_projection(dev_data, config)
test_gold_data, _ = to_projection(test_gold_data, config)

# Fit the scaler on the training coordinates only and persist it so that it
# can be reloaded later.
scaler = scalers[config['scaler']]()
train_coords = scaler.fit_transform(train_data[col_names[:2]].values)
# parents=True: also create `data_path` itself when it does not exist yet.
Path(f'{data_path}/ch').mkdir(parents=True, exist_ok=True)
with open(f'{data_path}/ch/scaler.pkl', 'wb') as f:
    pickle.dump(scaler, f)

# Training set: shuffled on every pass.
train_dataset = GeolocationDataset(train_data['text'].tolist(), train_coords, config)
train_loader = DataLoader(train_dataset, batch_size=config['train_batch_size'], shuffle=True)

# Dev and test sets: scaled with the scaler fitted on train, kept in file order.
dev_coords = scaler.transform(dev_data[col_names[:2]].values)
dev_dataset = GeolocationDataset(dev_data['text'].tolist(), dev_coords, config)
dev_loader = DataLoader(dev_dataset, batch_size=config['train_batch_size'], shuffle=False)

test_gold_coords = scaler.transform(test_gold_data[col_names[:2]].values)
test_gold_dataset = GeolocationDataset(test_gold_data['text'].tolist(), test_gold_coords, config)
test_gold_loader = DataLoader(test_gold_dataset, batch_size=config['train_batch_size'], shuffle=False)

# Preview the (projected) training data as the cell output.
train_data.head()

## Training 

### Load Model

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Device: {device}')
torch.cuda.empty_cache()

# Build the model described by the config and move it to the chosen device.
model = get_model(config).to(device)

# Optimiser, loss and (only when the config asks for one) LR scheduler.
optimizer = AdamW(model.parameters(), lr=config['lr'])
loss_function = get_lossfn(config)
if 'scheduler' in config:
    scheduler = get_scheduler(optimizer, config)

### Training Loop

In [None]:
# Helper used by the training loop to log metrics and save checkpoints;
# logs and checkpoints go under `data_path`/ch, labelled with the config name.
tb_checkpoint = TensorBoardCheckpoint(
    log_dir=f'{data_path}/ch/logs',
    checkpoint_path=f'{data_path}/ch/checkpoints',
    run_name=config_name,
)

def train(model, train_loader, dev_loader, optimizer, loss_function, scaler, epochs=10,
          early_stop_patience=10):
    """Train ``model`` and evaluate it on the dev set after every epoch.

    Besides its arguments, this function relies on notebook-level names:
    ``device``, ``config``, ``tb_checkpoint`` and, when ``config`` contains a
    ``'scheduler'`` entry, ``scheduler``.

    Args:
        model: Model called as ``model(**inputs)``; its output must expose
            ``.logits``.
        train_loader: Iterable of ``(inputs, labels)`` batches, where
            ``inputs`` is a dict of tensors and ``labels`` a tensor.
        dev_loader: Same batch format as ``train_loader``, used for validation.
        optimizer: Optimiser updating the parameters of ``model``.
        loss_function: Module applied as ``loss_function(logits, labels)``.
        scaler: Coordinate scaler passed on to the distance metrics and to
            ``tb_checkpoint.save_checkpoint``.
        epochs: Maximum number of epochs to run.
        early_stop_patience: Number of consecutive epochs without a new best
            dev median distance after which training stops early.
    """
    loss_function = loss_function.to(device)

    # Early-stopping state: best (lowest) dev median distance seen so far and
    # the number of consecutive epochs that did not improve on it.
    best_metric = float('inf')
    epochs_no_improve = 0

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for inputs, labels in tqdm(train_loader):
            optimizer.zero_grad()

            # Forward pass
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device)
            outputs = model(**inputs)

            # Calculate loss
            loss = loss_function(outputs.logits, labels)
            total_loss += loss.item()

            # Backward pass
            loss.backward()
            optimizer.step()

        avg_train_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{epochs} - Training loss: {avg_train_loss:.4f}")

        # Evaluate on dev (validation): collect predictions and labels for
        # the whole split without tracking gradients.
        model.eval()
        dev_preds = []
        dev_labels = []
        with torch.no_grad():
            for inputs, labels in tqdm(dev_loader):
                inputs = {k: v.to(device) for k, v in inputs.items()}
                labels = labels.to(device)
                outputs = model(**inputs)
                dev_preds.append(outputs.logits.cpu().numpy())
                dev_labels.append(labels.cpu().numpy())

        # Metrics
        dev_preds = np.vstack(dev_preds)
        dev_labels = np.vstack(dev_labels)

        median_dist = median_distance(dev_preds, dev_labels, scaler, config)
        mean_dist = mean_distance(dev_preds, dev_labels, scaler, config)

        metrics = {
            'Loss/train': avg_train_loss,
            'Median_Distance/dev': median_dist,
            'Mean_Distance/dev': mean_dist,
        }
        tb_checkpoint.log_metrics(metrics, epoch)
        tb_checkpoint.save_checkpoint(model, optimizer, epoch, metrics, scaler)

        # Early stopping on the dev median distance (lower is better).
        if median_dist < best_metric:
            best_metric = median_dist
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            # `>=` so the loop also stops if the patience is zero or negative.
            if epochs_no_improve >= early_stop_patience:
                print(f"Early stopping triggered after {epoch + 1} epochs.")
                break

        # The scheduler is stepped with the monitored dev metric.
        # NOTE(review): this call signature fits a plateau-style scheduler;
        # confirm what get_scheduler returns for other configs.
        if 'scheduler' in config:
            scheduler.step(median_dist)

# Reload the scaler that was pickled during data preparation.
with open(f'{data_path}/ch/scaler.pkl', 'rb') as f:
    scaler = pickle.load(f)

try:
    train(model, train_loader, dev_loader, optimizer, loss_function, scaler, epochs=config['epochs'])
finally:
    # Always close the logging/checkpoint helper, even when training fails
    # or is interrupted from the notebook.
    tb_checkpoint.close()

## Evaluation

### Evaluate and compare all models 

In [None]:
chkp_dir = f'{data_path}/ch/checkpoints'

config_names = configs.keys()

# Best (lowest median distance) checkpoint found so far; stays None when no
# configuration produces a result better than infinity.
best_checkpoint = None
best_results = {
    'median_distance': np.inf,
    'mean_distance': np.inf
}

# The loop variable is deliberately NOT called `config_name`: reusing that
# name would overwrite the notebook-level selection and make later cells
# evaluate whichever configuration happens to be listed last.
for candidate_name in config_names:
    chkp_file = f'{candidate_name}_best_model.pth'
    chkp_config = configs[candidate_name]

    results, _ = evaluate_geolocation_model_by_checkpoint(
        chkp_dir,
        chkp_file,
        vardial_path,
        chkp_config,
    )

    # Rank checkpoints by median distance; keep the mean of the same run.
    if results['median_distance'] < best_results['median_distance']:
        best_checkpoint = chkp_file
        best_results['median_distance'] = results['median_distance']
        best_results['mean_distance'] = results['mean_distance']

print(f"{'-' * 40}\nBest Checkpoint:", best_checkpoint)
print("Best Results:", best_results)

### Evaluate single model

In [None]:
# Evaluates the run currently named by `config_name`. Uncomment one of the
# lines below to evaluate a different run instead; the name must be a key of
# `configs`.
# config_name = '20231107-230346'
# config_name = 'utm_lr2e-5'
# config_name = 'swissbert'

# Defined here as well so this cell can run without first running the
# cell that compares all models.
chkp_dir = f'{data_path}/ch/checkpoints'
chkp_file = f'{config_name}_best_model.pth'
chkp_config = configs[config_name]

# Only the predictions are needed here; the metrics result is discarded.
_, test_preds = evaluate_geolocation_model_by_checkpoint(
    chkp_dir,
    chkp_file,
    vardial_path,
    chkp_config,
)

# Reload the gold test file so it holds the original lat/lon columns.
test_gold_data = pd.read_table(f'{vardial_path}/ch/test_gold.txt', header=None, names=['lat', 'lon', 'text'])

## Plots

In [None]:
# Plot the test predictions together with the gold test data.
# NOTE(review): presumably drawn on a map of Switzerland, with `data_path`
# used to locate plotting resources; confirm in lib.plotting.
plot_switzerland(test_preds, test_gold_data, data_path)

In [None]:
# Bar chart built from the test predictions and the gold test data.
# NOTE(review): the quantity shown on the bars is defined in lib.plotting; confirm there.
plot_barchart(test_preds, test_gold_data)