In [2]:
# --- Imports & Paths ---
import pandas as pd
import numpy as np
import torch
from pathlib import Path

# Both artifacts live under one data directory. Only the TRAIN split is read,
# so class frequencies never see dev/test labels (avoids leakage).
DATA_DIR = Path("processed_data")
IN_PATH = DATA_DIR / "aligned_train_subwords.tsv"
OUT_PATH = DATA_DIR / "class_weights.pt"  # saved torch tensor of per-class CE weights

In [3]:
# --- Load aligned subwords (train) ---
df = pd.read_csv(IN_PATH, sep="\t")
required = {"Label_ID"}
assert required.issubset(df.columns), f"Missing columns: {required - set(df.columns)}"

# Keep valid labels (>= 0). -100 marks ignored positions (suffix subwords and
# special tokens), which are excluded from supervision.
labels = df["Label_ID"]
# A missing value would turn the column float and silently drop that row below.
assert not labels.isna().any(), "Label_ID has missing values"
# Cast to int so label ids can be used directly as array indices downstream
# (numpy rejects float indices).
valid = labels[labels >= 0].astype(np.int64)  # only supervised positions
n_tokens = len(labels)
n_supervised = len(valid)
assert n_supervised > 0, "No supervised positions (Label_ID >= 0) in input"
print(f"Rows (subwords): {n_tokens} | supervised (labels>=0): {n_supervised}")

Rows (subwords): 124489 | supervised (labels>=0): 51184


In [4]:
# --- Compute class weights ---
# freq[c]   = number of supervised (Label_ID >= 0) train positions labelled c
# weight[c] = freq[c] ** -ALPHA, rescaled so the mean weight over classes is 1
counts = valid.value_counts().sort_index()
assert len(counts) > 0, "No supervised labels to compute class weights from"
counts.index = counts.index.astype(np.int64)  # ids are used as array indices below

# Number of classes. Inferring it as (max id seen in train) + 1 yields a tensor
# that is too short whenever the highest label id never occurs in train, and
# CrossEntropyLoss needs exactly one weight per class. Set NUM_LABELS to the
# size of the label vocabulary to pin it explicitly; None keeps the inferred value.
NUM_LABELS = None
max_seen_id = int(counts.index.max())
if NUM_LABELS is None:
    num_labels = max_seen_id + 1  # assumes label ids are 0..K-1
else:
    num_labels = int(NUM_LABELS)
    assert max_seen_id < num_labels, (
        f"Label id {max_seen_id} out of range for NUM_LABELS={num_labels}"
    )

freq = np.zeros(num_labels, dtype=np.float64)
freq[counts.index.values] = counts.values

# Classes absent from train would divide by zero. Floor their count at 1
# ("seen once") rather than a tiny epsilon. With eps=1e-6 and ALPHA=0.5, such a
# class got a weight ~1000x that of a class seen once. That outlier then
# dominated the mean normalization below and shrank every observed class's weight.
missing_ids = np.flatnonzero(freq == 0)
if missing_ids.size:
    print(f"Warning: class id(s) absent from train: {missing_ids.tolist()}")
freq = np.maximum(freq, 1.0)

# Inverse-frequency^ALPHA (common choice: ALPHA in [0.5, 1.0]; 0.5 = sqrt damping)
ALPHA = 0.5
weights = 1.0 / (freq ** ALPHA)

# Normalize for stability: scale so the mean weight over classes = 1
weights = weights * (len(weights) / weights.sum())

# No extra slot for ignore_index=-100: CrossEntropyLoss skips those targets,
# so the tensor holds exactly num_labels entries.
class_weights = torch.tensor(weights, dtype=torch.float32)
assert class_weights.numel() == num_labels
print("Label counts:\n", counts.to_string())
print("\nClass weights (id -> weight):")
for i, w in enumerate(class_weights.tolist()):
    print(f"{i}: {w:.4f}")

Label counts:
 Label_ID
0     16964
1        29
2       249
3        68
4         3
5      1086
6       313
7      2329
8         5
9        24
10     3539
11     1234
12        1
13     4912
14     4929
15      225
16     5528
17        3
18       34
19      314
20       24
21      110
22       16
23      186
24     3002
25       97
26        5
27       22
28        1
29        2
30        1
31        4
32        1
33      132
34        1
35       32
36       75
37      823
38     1756
39        2
40     2916
41       51
42      136

Class weights (id -> weight):
0: 0.0283
1: 0.6840
2: 0.2334
3: 0.4467
4: 2.1268
5: 0.1118
6: 0.2082
7: 0.0763
8: 1.6474
9: 0.7519
10: 0.0619
11: 0.1049
12: 3.6837
13: 0.0526
14: 0.0525
15: 0.2456
16: 0.0495
17: 2.1268
18: 0.6318
19: 0.2079
20: 0.7519
21: 0.3512
22: 0.9209
23: 0.2701
24: 0.0672
25: 0.3740
26: 1.6474
27: 0.7854
28: 3.6837
29: 2.6048
30: 3.6837
31: 1.8419
32: 3.6837
33: 0.3206
34: 3.6837
35: 0.6512
36: 0.4254
37: 0.1284
38: 0.0879
39: 2.6048

In [5]:
# --- Save ---
# Write the weight tensor to OUT_PATH, creating its parent directory first if
# it does not exist yet.
target_dir = OUT_PATH.parent
target_dir.mkdir(exist_ok=True, parents=True)
torch.save(obj=class_weights, f=OUT_PATH)
print(f"\nSaved class weights -> {OUT_PATH}")


Saved class weights -> processed_data/class_weights.pt
