In [1]:
# One-time setup, only if flash_attn is missing from the kernel's environment: %pip install flash_attn

In [2]:
import polars as pl
import orjson
import torch
# import yaml
from tqdm import tqdm

import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

In [3]:
# Load the card table, then draw a 100% sample with a fixed seed: every row is
# kept, but in a reproducible shuffled order.
df = (
    pl.read_parquet("mtg_data.parquet")
    .sample(fraction=1, shuffle=True, seed=42)
)

df

name,scryfallId,manaCost,type,text,power,toughness,loyalty,rarities,sets
str,str,str,str,str,str,str,str,list[enum],list[enum]
"""Marhault Elsdragon""","""2e805883-081b-478a-aa58-172b65…","""{3}{R}{R}{G}""","""Legendary Creature — Elf Warri…","""Rampage 1 (Whenever this creat…","""4""","""6""",,"[""uncommon""]","[""LEG"", ""BCHR"", … ""ME3""]"
"""Raven's Crime""","""7a67bfc7-da5e-443c-88c9-1c8739…","""{B}""","""Sorcery""","""Target player discards a card.…",,,,"[""common""]","[""EVE"", ""MMA"", ""PLST""]"
"""Shauku, Endbringer""","""06d94b21-7568-4e5c-a8ec-ff5bb4…","""{5}{B}{B}""","""Legendary Creature — Vampire""","""Flying\n~ can't attack if ther…","""5""","""5""",,"[""rare""]","[""MIR""]"
"""Awakened Awareness""","""51378e10-2f17-445b-97e9-5c3732…","""{X}{U}{U}""","""Enchantment — Aura""","""Enchant artifact or creature\n…",,,,"[""uncommon""]","[""NEO""]"
"""Body Snatcher""","""2bd51f0b-3c8a-4b2e-a0b6-7fa98c…","""{2}{B}{B}""","""Creature — Phyrexian Minion""","""When this creature enters, exi…","""2""","""2""",,"[""rare""]","[""UDS"", ""PLST"", ""DMR""]"
…,…,…,…,…,…,…,…,…,…
"""Minamo Scrollkeeper""","""821c99d3-389a-4138-99a2-0c4a4b…","""{1}{U}""","""Creature — Human Wizard""","""Defender\nYour maximum hand si…","""2""","""3""",,"[""common""]","[""SOK"", ""PSAL"", ""CNS""]"
"""Scapeshift""","""5edac3e9-6c18-4801-a035-2803d3…","""{2}{G}{G}""","""Sorcery""","""Sacrifice any number of lands.…",,,,"[""rare"", ""mythic""]","[""PRM"", ""MOR"", … ""SPG""]"
"""Fierce Guardianship""","""c103c4d8-5a3c-443b-9caf-f41f8b…","""{2}{U}""","""Instant""","""If you control a commander, yo…",,,,"[""rare""]","[""SLD"", ""C20"", ""CMM""]"
"""Flavor Disaster""","""ca8606d7-0da5-4fb5-b541-a12900…","""{4}{G}""","""Enchantment Creature — Element…","""Reach\nWhen ~ enters the battl…","""4""","""3""",,"[""rare""]","[""MB2""]"


In [4]:
# Turn each card into one pretty-printed JSON document (the text that will be
# embedded). Null fields and the opaque `scryfallId` are left out of the
# document.
docs = [
    orjson.dumps(
        {
            field: value
            for field, value in card.items()
            if field != "scryfallId" and value is not None
        },
        option=orjson.OPT_INDENT_2,
    ).decode("utf-8")
    for card in df.iter_rows(named=True)
]

# Show one document as a sanity check of the serialised format.
print(docs[0])

{
  "name": "Marhault Elsdragon",
  "manaCost": "{3}{R}{R}{G}",
  "type": "Legendary Creature — Elf Warrior",
  "text": "Rampage 1 (Whenever this creature becomes blocked, it gets +1/+1 until end of turn for each creature blocking it beyond the first.)",
  "power": "4",
  "toughness": "6",
  "rarities": [
    "uncommon"
  ],
  "sets": [
    "LEG",
    "BCHR",
    "CHR",
    "ME3"
  ]
}


In [5]:
# Checkpoint used for embedding, and the device inference runs on.
model_path = "Alibaba-NLP/gte-modernbert-base"
device = "cuda:0"

# Global PyTorch setting: permit faster, reduced-internal-precision kernels for
# float32 matrix multiplications.
torch.set_float32_matmul_precision("high")

# Load tokenizer and weights (weights are materialised on CPU first), then
# move the model onto the target device.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModel.from_pretrained(model_path).to(device)

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


In [6]:
# Smoke test: tokenize only the first document with the same settings used for
# the full run, and inspect the tensors that will be fed to the model.
encoded_sample = tokenizer(
    docs[0],
    truncation=True,
    padding=True,
    max_length=8192,
    return_tensors="pt",
)
tokenized_docs = encoded_sample.to(device)

tokenized_docs

{'input_ids': tensor([[50281,    92,   187, 50276,     3,  1590,  1381,   346,  9709,    73,
          1923,  3599,    84,  5267,  5154,   995,   187, 50276,     3,  1342,
            66, 25997,  1381, 36028,    20,  1217,    51,  1217,    51,  1217,
            40, 32722,   187, 50276,     3,   881,  1381,   346, 18596,   423,
           552, 13489,   459,  1905,  3599,    71, 46191,   995,   187, 50276,
             3,  1156,  1381,   346,    51,  1301,   486,   337,   313, 43835,
           436, 15906,  4916, 13230,    13,   352,  4850,   559,    18, 23615,
            18,  1919,   990,   273,  1614,   323,  1016, 15906, 14589,   352,
          4457,   253,   806,  2698,   995,   187, 50276,     3,  9177,  1381,
           346,    21,   995,   187, 50276,     3,    85,   602,  1255,  1381,
           346,    23,   995,   187, 50276,     3, 23537,  1005,  1381,   544,
           187, 50274,     3,   328,  9784,     3,   187, 50276,  1092,   187,
         50276,     3, 19598,  1381,  

In [7]:
# Embed every document in `docs` and collect one L2-normalised vector per card.
#
# The DataLoader is used purely to chunk the list of strings into batches of 64.
# `shuffle=False` keeps the embeddings in the same order as `docs`, which is
# required because they are later attached to the DataFrame by position.
# Memory pinning is intentionally not requested: each batch is a list of Python
# strings, which pinning passes through untouched, and the tensors only come
# into existence after tokenization below.
dataloader = torch.utils.data.DataLoader(docs, batch_size=64, shuffle=False)

batch_embeddings = []
with torch.no_grad():  # inference only: no autograd graph is recorded
    for batch in tqdm(dataloader, smoothing=0):
        # Tokenize per batch so padding is applied per batch rather than
        # across the whole dataset.
        tokenized_batch = tokenizer(
            batch, max_length=8192, padding=True, truncation=True, return_tensors="pt"
        ).to(device)

        outputs = model(**tokenized_batch)
        # Keep the hidden state at sequence position 0 as the document vector
        # (presumably the pooling this checkpoint expects -- confirm against
        # the model card), and move it off the GPU right away so GPU memory is
        # not held across batches.
        batch_embeddings.append(outputs.last_hidden_state[:, 0].cpu())

# Stack the per-batch results and scale each row to unit L2 norm.
dataset_embeddings = F.normalize(torch.cat(batch_embeddings), p=2, dim=1)

# Guard the positional join with the DataFrame that follows.
assert dataset_embeddings.size(0) == len(docs), "expected one embedding per document"

dataset_embeddings.size()

100%|██████████| 504/504 [01:17<00:00,  6.48it/s]


torch.Size([32254, 768])

In [8]:
# Attach the embeddings to the cards. The join is positional: row i of the
# matrix belongs to row i of `df`, so `df` must still be in the order that was
# used to build `docs`. Sorting by name happens only after the column is added.
embedding_matrix = dataset_embeddings.cpu().numpy()
df_2 = (
    df
    .with_columns(embedding=embedding_matrix)
    .sort("name")
)

df_2

name,scryfallId,manaCost,type,text,power,toughness,loyalty,rarities,sets,embedding
str,str,str,str,str,str,str,str,list[enum],list[enum],"array[f32, 768]"
"""""Ach! Hans, Run!""""","""84f2c8f5-8e11-4639-b7de-00e4a2…","""{2}{R}{R}{G}{G}""","""Enchantment""","""At the beginning of your upkee…",,,,"[""rare""]","[""UNH""]","[0.021458, -0.036102, … -0.00607]"
"""""Brims"" Barone, Midway Mobster""","""68832214-2943-4253-8884-ffa490…","""{3}{W}{B}""","""Legendary Creature — Human Rog…","""When ~ enters, put a +1/+1 cou…","""5""","""4""",,"[""uncommon""]","[""UNF""]","[0.000261, 0.025207, … 0.005641]"
"""""Lifetime"" Pass Holder""","""42293306-aaea-4542-8df4-813823…","""{B}""","""Creature — Zombie Guest""","""This creature enters tapped.\n…","""2""","""1""",,"[""rare""]","[""UNF""]","[-0.004467, -0.016707, … 0.001401]"
"""""Name Sticker"" Goblin""","""fd1442b4-da59-4042-835f-143c8d…","""{2}{R}""","""Creature — Goblin Guest""","""When this creature enters from…","""2""","""2""",,"[""common""]","[""UNF""]","[-0.03243, -0.006905, … -0.026888]"
"""""Rumors of My Death . . .""""","""cb3587b9-e727-4f37-b4d6-1baa73…","""{2}{B}""","""Enchantment""","""{3}{B}, Exile a permanent you …",,,,"[""uncommon""]","[""UST""]","[-0.008719, -0.009454, … 0.001481]"
…,…,…,…,…,…,…,…,…,…,…
"""Éomer, King of Rohan""","""f2c11695-f22b-44d5-937c-2578f2…","""{3}{R}{W}""","""Legendary Creature — Human Nob…","""Double strike\n~ enters with a…","""2""","""2""",,"[""rare""]","[""LTC""]","[0.014415, 0.012861, … 0.019452]"
"""Éomer, Marshal of Rohan""","""fba68512-f536-4961-9e24-563270…","""{2}{R}{R}""","""Legendary Creature — Human Kni…","""Haste\nWhenever one or more ot…","""4""","""4""",,"[""rare""]","[""PLTR"", ""LTR""]","[-0.022492, 0.017429, … 0.046833]"
"""Éowyn, Fearless Knight""","""c1b37891-5ed9-47e4-8d2f-c2bfd8…","""{2}{R}{W}""","""Legendary Creature — Human Kni…","""Haste\nWhen ~ enters, exile ta…","""3""","""4""",,"[""rare""]","[""PLTR"", ""LTR""]","[-0.033708, 0.007089, … 0.02833]"
"""Éowyn, Lady of Rohan""","""e59710c4-24de-419e-a8a0-e8392d…","""{2}{W}""","""Legendary Creature — Human Nob…","""At the beginning of combat on …","""2""","""4""",,"[""uncommon""]","[""LTR""]","[-0.001433, 0.00514, … 0.048259]"


In [9]:
# Persist the cards together with their embedding column to a local parquet file.
df_2.write_parquet("mtg_embeddings.parquet")

In [10]:
# Copy the local parquet file into the GCS bucket using the gsutil CLI (shell escape).
!gsutil cp mtg_embeddings.parquet gs://maxw-imdb-embeddings/

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Copying file://mtg_embeddings.parquet [Content-Type=application/octet-stream]...
- [1 files][ 89.8 MiB/ 89.8 MiB]                                                
Operation completed over 1 objects/89.8 MiB.                                     
