In [1]:
# Prerequisite (install once, before running the cells below): pip install flash_attn

In [2]:
import polars as pl
import orjson
import torch
from tqdm.autonotebook import tqdm

import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

  from tqdm.autonotebook import tqdm


In [3]:
# Read the card table from disk, then shuffle every row with a fixed seed so
# the row order is reproducible across runs.
cards_unshuffled = pl.read_parquet("mtg_data.parquet")
df = cards_unshuffled.sample(fraction=1, shuffle=True, seed=42)

df

name,manaCost,type,text,power,toughness,loyalty,rarities,sets
str,str,str,str,str,str,str,list[enum],list[enum]
"""Vilis, Broker of Blood""","""{5}{B}{B}{B}""","""Legendary Creature — Demon""","""Flying\n{B}, Pay 2 life: Targe…","""8""","""8""",,"[""rare""]","[""PM20"", ""M20"", … ""J25""]"
"""Cruel Celebrant""","""{W}{B}""","""Creature — Vampire""","""Whenever ~ or another creature…","""1""","""2""",,"[""uncommon""]","[""WAR"", ""PLST"", ""LCC""]"
"""Leyline of Lifeforce""","""{2}{G}{G}""","""Enchantment""","""If ~ is in your opening hand, …",,,,"[""rare""]","[""GPT""]"
"""Lumbering Falls""",,"""Land""","""~ enters tapped.\n{T}: Add {G}…",,,,"[""rare""]","[""PBFZ"", ""BFZ"", … ""PIO""]"
"""Grand Warlord Radha""","""{2}{R}{G}""","""Legendary Creature — Elf Warri…","""Haste\nWhenever one or more cr…","""3""","""4""",,"[""rare""]","[""PDOM"", ""DOM""]"
…,…,…,…,…,…,…,…,…
"""Mimeofacture""","""{3}{U}""","""Sorcery""","""Replicate {3}{U} (When you cas…",,,,"[""rare""]","[""GPT""]"
"""Scar""","""{B/R}""","""Instant""","""Put a -1/-1 counter on target …",,,,"[""common""]","[""SHM""]"
"""Field Marshal""","""{1}{W}{W}""","""Creature — Human Soldier""","""Other Soldier creatures get +1…","""2""","""2""",,"[""rare""]","[""CSP"", ""10E"", ""SLD""]"
"""Flanking Licid""","""{1}{R}""","""Summon Licid""","""{R}, {T}: ~ loses this ability…",,,,"[""rare""]","[""MB2""]"


In [4]:
# Turn each row into an indented JSON string, dropping any field whose value
# is None so a document only lists the attributes that row actually has.
docs = [
    orjson.dumps(
        {field: value for field, value in card.items() if value is not None},
        option=orjson.OPT_INDENT_2,
    ).decode("utf-8")
    for card in df.iter_rows(named=True)
]

print(docs[0])

{
  "name": "Vilis, Broker of Blood",
  "manaCost": "{5}{B}{B}{B}",
  "type": "Legendary Creature — Demon",
  "text": "Flying\\n{B}, Pay 2 life: Target creature gets -1/-1 until end of turn.\\nWhenever you lose life, draw that many cards. (Damage causes loss of life.)",
  "power": "8",
  "toughness": "8",
  "rarities": [
    "rare"
  ],
  "sets": [
    "PM20",
    "M20",
    "SLD",
    "GN3",
    "J25"
  ]
}


In [5]:
model_path = "Alibaba-NLP/gte-modernbert-base"
# Prefer the first GPU, but fall back to CPU so the notebook still runs
# (slowly) on a machine without CUDA instead of failing in `.to(device)`.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModel.from_pretrained(model_path)
# Move the weights to the target device and explicitly select eval mode,
# since the model is only used for inference here. Assigning to `_` keeps
# the module repr out of the cell output.
_ = model.to(device).eval()

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


In [6]:
# Sanity check: encode only the first document and inspect the resulting
# tensors after moving them to the model's device.
first_doc_encoding = tokenizer(
    docs[0],
    return_tensors="pt",
    truncation=True,
    padding=True,
    max_length=8192,
)
tokenized_docs = first_doc_encoding.to(device)

tokenized_docs

{'input_ids': tensor([[50281,    92,   187, 50276,     3,  1590,  1381,   346,    55, 27154,
            13,  4819,  6426,   273, 14169,   995,   187, 50276,     3,  1342,
            66, 25997,  1381, 36028,    22,  1217,    35,  1217,    35,  1217,
            35, 32722,   187, 50276,     3,   881,  1381,   346, 18596,   423,
           552, 13489,   459,  1905,  4281,   251,   995,   187, 50276,     3,
          1156,  1381,   346,    39,  2943,  3353,    79,    92,    35,  2023,
         12286,   374,  1495,    27, 17661, 15906,  4850,   428,    18,  7448,
            18,  1919,   990,   273,  1614,    15,  3353,    79, 43835,   368,
          7168,  1495,    13,  3812,   326,  1142,  8364,    15,   313, 21727,
           486,  5997,  2957,   273,  1495,  2698,   995,   187, 50276,     3,
          9177,  1381,   346,    25,   995,   187, 50276,     3,    85,   602,
          1255,  1381,   346,    25,   995,   187, 50276,     3, 23537,  1005,
          1381,   544,   187, 50274,  

In [7]:
# Embed the whole corpus in batches of 32. shuffle=False keeps the embedding
# rows in the same order as `docs`.
dataloader = torch.utils.data.DataLoader(docs, shuffle=False, batch_size=32)

cls_chunks = []
for doc_batch in tqdm(dataloader, smoothing=0):
    encoded = tokenizer(
        doc_batch,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=8192,
    ).to(device)

    with torch.no_grad():
        # The hidden state at position 0 of each sequence is taken as the
        # document embedding.
        cls_chunks.append(model(**encoded).last_hidden_state[:, 0])

# Stack the per-batch results and L2-normalize every row.
dataset_embeddings = F.normalize(torch.cat(cls_chunks), p=2, dim=1)
dataset_embeddings.size()

  0%|          | 0/990 [00:00<?, ?it/s]

torch.Size([31650, 768])

In [8]:
# Bring the embeddings back to host memory as a numpy array, attach them as
# the `embeds` column, and order the table by card name.
embeds_np = dataset_embeddings.cpu().numpy()
df_2 = df.with_columns(embeds=embeds_np).sort("name")

df_2

name,manaCost,type,text,power,toughness,loyalty,rarities,sets,embeds
str,str,str,str,str,str,str,list[enum],list[enum],"array[f32, 768]"
"""""Ach! Hans, Run!""""","""{2}{R}{R}{G}{G}""","""Enchantment""","""At the beginning of your upkee…",,,,"[""rare""]","[""UNH""]","[0.021551, -0.036194, … -0.006151]"
"""""Brims"" Barone, Midway Mobster""","""{3}{W}{B}""","""Legendary Creature — Human Rog…","""When ~ enters, put a +1/+1 cou…","""5""","""4""",,"[""uncommon""]","[""UNF""]","[0.000329, 0.025074, … 0.005499]"
"""""Intimidation Tactics""""","""{B}""","""Sorcery""","""Target opponent reveals their …",,,,"[""uncommon""]","[""DFT""]","[-0.030729, -0.004603, … 0.007073]"
"""""Lifetime"" Pass Holder""","""{B}""","""Creature — Zombie Guest""","""~ enters tapped.\nWhen ~ dies,…","""2""","""1""",,"[""rare""]","[""UNF""]","[0.001363, -0.003904, … 0.013427]"
"""""Name Sticker"" Goblin""","""{2}{R}""","""Creature — Goblin Guest""","""When this creature enters from…","""2""","""2""",,"[""common""]","[""UNF""]","[-0.032506, -0.006977, … -0.026924]"
…,…,…,…,…,…,…,…,…,…
"""Éomer, King of Rohan""","""{3}{R}{W}""","""Legendary Creature — Human Nob…","""Double strike\n~ enters with a…","""2""","""2""",,"[""rare""]","[""LTC""]","[0.014294, 0.012956, … 0.019557]"
"""Éomer, Marshal of Rohan""","""{2}{R}{R}""","""Legendary Creature — Human Kni…","""Haste\nWhenever one or more ot…","""4""","""4""",,"[""rare""]","[""PLTR"", ""LTR""]","[-0.022398, 0.017195, … 0.046952]"
"""Éowyn, Fearless Knight""","""{2}{R}{W}""","""Legendary Creature — Human Kni…","""Haste\nWhen ~ enters, exile ta…","""3""","""4""",,"[""rare""]","[""PLTR"", ""LTR""]","[-0.033406, 0.007331, … 0.028238]"
"""Éowyn, Lady of Rohan""","""{2}{W}""","""Legendary Creature — Human Nob…","""At the beginning of combat on …","""2""","""4""",,"[""uncommon""]","[""LTR""]","[-0.001306, 0.005256, … 0.048215]"


In [9]:
# Persist the name-sorted table, including the `embeds` column, to a local
# parquet file in the current working directory.
df_2.write_parquet("mtg_embeddings.parquet")

In [10]:
# Copy the local parquet file to the GCS bucket using the gsutil CLI
# (requires gsutil to be installed and authenticated in this environment).
!gsutil cp mtg_embeddings.parquet gs://maxw-imdb-embeddings/

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Copying file://mtg_embeddings.parquet [Content-Type=application/octet-stream]...
- [1 files][ 87.5 MiB/ 87.5 MiB]                                                
Operation completed over 1 objects/87.5 MiB.                                     
