In [1]:
!nvidia-smi

Thu Jan 21 01:56:07 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   73C    P0    39W /  70W |      0MiB / 15079MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# New Section

In [2]:
!pip install transformers
!pip install datasets



In [3]:
import numpy as np
import json
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import collections

import transformers
from transformers import (
    CONFIG_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
)
from transformers.trainer_utils import is_main_process

from datasets import load_dataset

In [4]:
def load_json(path):
  with open(path, 'r') as f:
    return json.load(f)

In [5]:
!wget https://mtgjson.com/api/v5/AllPrintings.json

--2021-01-21 01:56:15--  https://mtgjson.com/api/v5/AllPrintings.json
Resolving mtgjson.com (mtgjson.com)... 104.21.64.186, 172.67.154.80, 2606:4700:3030::6815:40ba, ...
Connecting to mtgjson.com (mtgjson.com)|104.21.64.186|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 210653155 (201M) [application/json]
Saving to: ‘AllPrintings.json.1’


2021-01-21 01:56:17 (105 MB/s) - ‘AllPrintings.json.1’ saved [210653155/210653155]



In [6]:
with open('AllPrintings.json', 'r') as f:
  data = json.load(f)['data']

In [7]:
card_list = []
keys = ['name', 'type', 'manaCost', 'rarity', 'text', 'power', 'toughness', 'loyalty', 'flavorText']
seen_names = set([])
pbar = tqdm(total = len(data), leave=False)
for set_name, set_list in data.items():
  for card in set_list['cards']:
    if card['name'] in seen_names:
      continue
    else:
      seen_names.add(card['name'])

    card_info = [str(card.get(k, None)) for k in keys]
    if '|' in ' '.join(card_info):
      print('PIPES FOUND IN %s, %s'%(set_name, card['name']))
    else:
      card_str = '| '.join(card_info) + ' end_of_card '
      card_str = card_str.replace('\n', ' line_break ')
    card_list.append(card_str)

  pbar.update(1)

HBox(children=(FloatProgress(value=0.0, max=548.0), HTML(value='')))

PIPES FOUND IN UNH, Magical Hacker


In [8]:
len(card_list)

21696

In [9]:
card_list = np.array(card_list)

In [10]:
card_list[0]

"Ancestor's Chosen| Creature — Human Cleric| {5}{W}{W}| uncommon| First strike (This creature deals combat damage before creatures without first strike.) line_break When Ancestor's Chosen enters the battlefield, you gain 1 life for each card in your graveyard.| 4| 4| None| None end_of_card "

In [11]:
np.random.seed(42)
train_indexer = np.random.rand(card_list.shape[0]) < .9
train_data = card_list[train_indexer]
test_data = card_list[~train_indexer]

In [12]:
train_data.shape, test_data.shape

((19532,), (2164,))

In [13]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [14]:
def save_txt(path, to_save):
  with open(path, 'w') as f:
    for line in to_save:
      f.write(line + '\n')

In [15]:
save_path = './drive/My Drive/models/mtg_card_gen/'
train_data_save_path = save_path + 'train_data.txt'
test_data_save_path = save_path + 'test_data.txt'

In [16]:
save_txt(train_data_save_path, train_data)
save_txt(test_data_save_path, test_data)

In [17]:
datasets = load_dataset('text', data_files = {'train' : train_data_save_path, 'test' :  test_data_save_path})

Using custom data configuration default


Downloading and preparing dataset text/default-2bfd794794bc0b0e (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /root/.cache/huggingface/datasets/text/default-2bfd794794bc0b0e/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab...


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Dataset text downloaded and prepared to /root/.cache/huggingface/datasets/text/default-2bfd794794bc0b0e/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab. Subsequent calls will reuse this data.


In [18]:
load_path = 'gpt2'

In [19]:
config = AutoConfig.from_pretrained(load_path)
tokenizer = AutoTokenizer.from_pretrained(load_path)
model = AutoModelForCausalLM.from_pretrained(
            load_path,
            from_tf=False,
            config=config,) 

In [20]:
model.num_parameters()

124439808

In [21]:
model.resize_token_embeddings(len(tokenizer))

Embedding(50257, 768)

In [22]:
column_names = datasets["train"].column_names

In [23]:
text_column_name = "text" if "text" in column_names else column_names[0]

In [24]:
def tokenize_function(examples):
  return tokenizer(examples[text_column_name])

In [25]:
tokenized_datasets = datasets.map(
        tokenize_function,
        batched=True,
        num_proc=1,
        remove_columns=column_names,
        load_from_cache_file=True,
    )

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))




In [26]:
block_size = tokenizer.model_max_length

In [27]:
def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    total_length = (total_length // block_size) * block_size
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

In [28]:
lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=1,
        load_from_cache_file=True,
    )

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))




In [29]:
lm_datasets

DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'input_ids', 'labels'],
        num_rows: 1521
    })
    test: Dataset({
        features: ['attention_mask', 'input_ids', 'labels'],
        num_rows: 169
    })
})

In [30]:
training_args = TrainingArguments(
    output_dir=save_path,
    overwrite_output_dir = False,
    num_train_epochs = 100,
    per_device_train_batch_size = 1,
    per_device_eval_batch_size = 1,
    save_steps = 50,
    gradient_accumulation_steps = 64,
    logging_steps = 50,
    
    eval_steps = 50,
    evaluation_strategy = 'steps',

    save_total_limit = 10,
)



trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=lm_datasets["train"],
        eval_dataset=lm_datasets["test"],
        tokenizer=tokenizer,
        # Data collator will default to DataCollatorWithPadding, so we change it.
        data_collator=default_data_collator,
    )

In [31]:
trainer.train()

Step,Training Loss,Validation Loss,Runtime,Samples Per Second
50,1.9269,1.660821,23.5342,7.181
100,1.6485,1.567201,23.6769,7.138
150,1.5543,1.515008,23.4972,7.192
200,1.4905,1.480472,23.462,7.203
250,1.4364,1.454522,23.4741,7.199
300,1.4172,1.43483,23.5236,7.184


KeyboardInterrupt: ignored

In [None]:
-