In [1]:
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics import mean_squared_error
from tqdm import tqdm
import categories

# Use the GPU whenever torch can see one; otherwise everything runs on the CPU.
# (Assign 'cpu' here by hand to force CPU execution while debugging.)
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# A category is kept only if it occurs strictly more than this many times.
THRESHOLD = 100

# Hugging Face hub identifiers for the available HerBERT checkpoints, keyed by
# a short name. Each entry names the tokenizer and the model weights to load.
model_names = {
    "herbert-klej-cased-v1": {
        "tokenizer": "allegro/herbert-klej-cased-tokenizer-v1",
        "model": "allegro/herbert-klej-cased-v1",
    },
    "herbert-base-cased": {
        "tokenizer": "allegro/herbert-base-cased",
        "model": "allegro/herbert-base-cased",
    },
    "herbert-large-cased": {
        "tokenizer": "allegro/herbert-large-cased",
        "model": "allegro/herbert-large-cased",
    },
}

# Load the tokenizer/encoder pair of the checkpoint used in this notebook and
# place the encoder on the selected device.
_selected_checkpoint = model_names["herbert-base-cased"]
tokenizer = AutoTokenizer.from_pretrained(_selected_checkpoint["tokenizer"])
herbert = AutoModel.from_pretrained(_selected_checkpoint["model"]).to(device)

@torch.no_grad()
def herbert_forward(data, batch_size=256):
    """Embed texts with the module-level HerBERT encoder.

    The texts are processed in consecutive batches: each batch is tokenized
    with the module-level ``tokenizer``, moved to ``device``, passed through
    ``herbert``, and the model's ``pooler_output`` is collected on the CPU.
    Gradients are disabled by the decorator.

    Args:
        data: Sequence of strings to embed (callers in this file pass a list).
        batch_size: Number of texts tokenized and encoded per forward pass.

    Returns:
        A CPU tensor with one row per input text, in input order. For empty
        input an empty tensor of shape ``(0, hidden_size)`` is returned.
    """
    if len(data) == 0:
        # torch.cat() raises on an empty list, so answer the empty case
        # directly with a correctly shaped, zero-row tensor.
        return torch.empty((0, herbert.config.hidden_size))

    embeddings = []
    for start in tqdm(range(0, len(data), batch_size)):
        # list(...) lets any sliceable sequence of strings be passed in.
        batch = list(data[start:start + batch_size])
        tokens = tokenizer(
            batch,
            padding="longest",
            # Cap inputs at the tokenizer's configured maximum length (when it
            # defines one) so an unusually long text cannot overflow the model.
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        ).to(device)  # follows `device` directly; a no-op move when on CPU

        embeddings.append(herbert(**tokens)['pooler_output'].cpu())
    return torch.cat(embeddings)


import xgboost as xgb

# Shared module-level regressor, created with xgboost's default settings.
# It is fitted both by train() and by the notebook cell that calls
# model.fit(X, y) directly.
model = xgb.XGBRegressor()

def train(batch_size=256, validate=False):
    """Fit the module-level XGBoost regressor to predict ``position``.

    Reads ``places.csv.gz``, keeps Polish-language rows whose category occurs
    more than ``THRESHOLD`` times, and builds the feature matrix from the
    HerBERT embedding of ``query`` plus the category id and the audit
    coordinates. The module-level ``model`` is fitted in place.

    Args:
        batch_size: Batch size forwarded to ``herbert_forward``.
        validate: When true, print the mean squared error of the fitted model
            on the same rows it was trained on (an in-sample figure, not a
            held-out validation score).
    """
    columns = ['language', 'category', 'query', 'position', 'audit_latitude', 'audit_longitude']
    # Drop incomplete rows up front: otherwise rows with a missing `position`
    # would reach model.fit() as NaN labels.
    places = pd.read_csv('places.csv.gz')[columns].dropna()

    # One combined mask instead of chained `places[a][b]` indexing, which
    # selects the same rows but makes pandas reindex the second mask and warn.
    category_counts = places['category'].map(places['category'].value_counts())
    keep = (places['language'] == 'pl') & (category_counts > THRESHOLD)
    places = places[keep].reset_index(drop=True)

    X = pd.DataFrame(herbert_forward(list(places['query']), batch_size=batch_size).numpy())
    # Categories missing from categories.cat_id map to NaN; give them the id
    # len(categories.id_cat). The result is assigned, since fillna() returns
    # a new object rather than modifying in place.
    X['category'] = places['category'].map(categories.cat_id).fillna(len(categories.id_cat))
    X['audit_latitude'] = places['audit_latitude']
    X['audit_longitude'] = places['audit_longitude']
    y = places['position']

    model.fit(X, y)
    if validate:
        y_pred = model.predict(X)
        print(mean_squared_error(y, y_pred))




Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.sso.sso_relationship.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.sso.sso_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [18]:
validate = True
columns = ['language', 'category', 'query', 'position', 'audit_latitude', 'audit_longitude']
# Keep only the columns used below and drop rows missing any of them, so the
# target `position` contains no NaN.
places = pd.read_csv('places.csv.gz')[columns].dropna()

# Select Polish rows whose category occurs more than THRESHOLD times. A single
# combined mask picks the same rows as chained `places[a][b]` indexing without
# pandas having to reindex the second mask (which raises a UserWarning).
category_counts = places['category'].map(places['category'].value_counts())
keep = (places['language'] == 'pl') & (category_counts > THRESHOLD)
# reset_index() keeps the pre-filter row labels in an 'index' column.
places = places[keep].reset_index()

y = places['position']
X = pd.DataFrame(herbert_forward(list(places['query'])).numpy())
# Categories missing from categories.cat_id map to NaN; give them the id
# len(categories.id_cat). Only this column is filled: the coordinates were
# already cleaned by dropna() above.
X['category'] = places['category'].map(categories.cat_id).fillna(len(categories.id_cat))
X['audit_latitude'] = places['audit_latitude']
X['audit_longitude'] = places['audit_longitude']


  places = places[places['language'] == 'pl'][places['category'].map(places['category'].value_counts()) > THRESHOLD]
100%|██████████| 1761/1761 [00:39<00:00, 44.45it/s]


In [19]:
# Fit the shared regressor on the prepared features and target.
model.fit(X, y)
if validate:
    # Score on the very rows used for fitting, i.e. an in-sample MSE.
    y_pred = model.predict(X)
    in_sample_mse = mean_squared_error(y_true=y, y_pred=y_pred)
    print(in_sample_mse)

29.91313709764606


In [22]:
# Distribution of the target: number of rows for each `position` value.
y.value_counts()

1.0     23746
2.0     23679
3.0     23567
4.0     23508
5.0     23298
6.0     23115
7.0     23031
8.0     22901
9.0     22819
10.0    22632
11.0    22452
12.0    22343
13.0    22203
14.0    22037
15.0    21835
16.0    21699
17.0    21606
18.0    21527
19.0    21435
20.0    21218
Name: position, dtype: int64

In [15]:
# Sanity check: list the rows whose target value is missing.
# NOTE(review): this cell's execution count (15) is lower than that of the
# cell that builds `y` (18), so its recorded output may come from an earlier
# version of `y` -- re-run after a kernel restart to confirm.
y[y.isna()]

560758   NaN
560759   NaN
560760   NaN
560761   NaN
560762   NaN
          ..
902829   NaN
902830   NaN
902831   NaN
902832   NaN
902833   NaN
Name: position, Length: 315, dtype: float64