In [1]:
%%writefile requirements.txt
lightning
peft
trl
accelerate
fire
optimum
jax

Writing requirements.txt


In [2]:
%%capture
!pip install -r requirements.txt

In [11]:
# Attach JAX to the TPU runtime and list the TPU cores it can see.
import jax
from jax.tools import colab_tpu

# NOTE(review): jax.local_devices() initialises the TPU backend inside this
# notebook kernel. On a single TPU VM host usually only one process can hold
# the TPU at a time, so this may prevent the later `!python tpu_trainer.py`
# subprocess (torch_xla) from acquiring it — restart the kernel / skip this
# cell before training. Newer JAX releases also deprecate setup_tpu(); confirm.
colab_tpu.setup_tpu()
jax.local_devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [1]:
%%writefile data_module.py
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import GPT2Tokenizer
import lightning.pytorch as pl


class WikiTextDataModule(pl.LightningDataModule):
    """LightningDataModule serving tokenized WikiText-103 (raw) for causal-LM training.

    Every split is tokenized with a GPT-2 tokenizer, truncated/padded to
    ``max_length`` tokens and exposed as batches of dicts with the keys
    ``input_ids`` and ``attention_mask`` (torch tensors).

    Args:
        tokenizer_name: identifier passed to ``GPT2Tokenizer.from_pretrained``.
        max_length: sequence length every example is truncated/padded to.
        batch_size: per-process batch size used by all dataloaders.
        num_workers: DataLoader worker processes (was hard-coded to 4).
        num_proc: processes used by ``datasets`` ``filter``/``map`` while
            preprocessing (was hard-coded to 4).
    """

    def __init__(self, tokenizer_name='gpt2', max_length=512, batch_size=32,
                 num_workers=4, num_proc=4):
        super().__init__()
        self.tokenizer_name = tokenizer_name
        self.max_length = max_length
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.num_proc = num_proc
        # Created lazily in setup(); prepare_data() only warms the download cache.
        self.tokenizer = None

    def prepare_data(self):
        """Download the dataset and tokenizer into the local cache (no state is kept)."""
        load_dataset("wikitext", "wikitext-103-raw-v1")
        GPT2Tokenizer.from_pretrained(self.tokenizer_name)

    def setup(self, stage=None):
        """Load the dataset, drop blank lines and tokenize every split."""
        self.dataset = load_dataset("wikitext", "wikitext-103-raw-v1")
        self.tokenizer = GPT2Tokenizer.from_pretrained(self.tokenizer_name)
        # GPT-2 has no pad token; reuse EOS so padding="max_length" works.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # WikiText stores one line per row and many rows are blank. After
        # padding, a blank row is pure padding (attention_mask all zeros) and
        # contributes no trainable tokens, so drop those rows first.
        non_empty = self.dataset.filter(
            lambda batch: [bool(text.strip()) for text in batch["text"]],
            batched=True,
            num_proc=self.num_proc,
        )

        def tokenize_function(examples):
            return self.tokenizer(
                examples["text"],
                truncation=True,
                max_length=self.max_length,
                padding="max_length",
            )

        self.tokenized_datasets = non_empty.map(
            tokenize_function,
            batched=True,
            num_proc=self.num_proc,
            remove_columns=["text"],
        )

        # Return torch tensors so the default collate_fn stacks them into batches.
        self.tokenized_datasets.set_format("torch")

    def _dataloader(self, split, shuffle=False):
        """Build a DataLoader over one tokenized split with the shared settings."""
        return DataLoader(
            self.tokenized_datasets[split],
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._dataloader("train", shuffle=True)

    def val_dataloader(self):
        """Loader over the validation split (not shuffled)."""
        return self._dataloader("validation")

    def test_dataloader(self):
        """Loader over the test split (not shuffled)."""
        return self._dataloader("test")

Writing data_module.py


In [7]:
%%writefile llm_model.py
import os
import jax
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import lightning.pytorch as pl
from transformers import GPT2Config, GPT2LMHeadModel

# XLA runtime configuration: ask torch_xla to run tensors in bfloat16.
# NOTE(review): this is set after torch_xla has already been imported, and
# XLA_USE_BF16 is reportedly deprecated in newer torch_xla releases — confirm
# it still takes effect with the installed version.
os.environ.update(XLA_USE_BF16="1")

class ParameterParallelLLM(pl.LightningModule):
    """GPT-2 causal language model wrapped as a LightningModule.

    By default the model is built on the host and Lightning's accelerator /
    strategy is left to move it (and every batch) to the right XLA device.
    The manual layer-per-device placement in :meth:`parallelize_model` is kept
    but is opt-in (``parallelize=True``), for two reasons:

    - It touches XLA devices inside ``__init__``, i.e. in the launching process
      before the Trainer starts its per-core workers.
    - It competes with the Trainer's own device placement.

    NOTE(review): torch_xla's PJRT runtime is generally reported to reject
    device use before multiprocess spawn — confirm for your versions before
    enabling ``parallelize``.

    Args:
        vocab_size, n_layer, n_head, n_embd: forwarded to ``GPT2Config``.
        parallelize: when True, call :meth:`parallelize_model` at construction.
    """

    def __init__(self, vocab_size=50257, n_layer=12, n_head=12, n_embd=768,
                 parallelize=False):
        super().__init__()
        self.save_hyperparameters()

        # Model configuration.
        config = GPT2Config(
            vocab_size=vocab_size,
            n_layer=n_layer,
            n_head=n_head,
            n_embd=n_embd
        )
        self.model = GPT2LMHeadModel(config)

        # Manual device placement is opt-in; see the class docstring.
        self._manually_placed = bool(parallelize)
        if self._manually_placed:
            self.parallelize_model()

    def parallelize_model(self):
        """Spread GPT-2 sub-modules across XLA devices by hand.

        The embeddings go to device 0 and the LM head to the last device.
        Transformer blocks are assigned in chunks of ``n_layer // device_count``
        (at least 1), wrapping around with a modulo.

        NOTE(review): the Hugging Face GPT-2 forward does not appear to move
        hidden states between devices for this kind of placement. ``lm_head``
        is also presumably weight-tied to ``wte`` (transformers default), so
        moving the head may move the embedding too. Verify before relying on it.
        """
        device_count = max(1, xm.xrt_world_size())
        # Guard against ZeroDivisionError when there are more devices than layers.
        layers_per_device = max(1, self.hparams.n_layer // device_count)

        # Embedding layers on the first device.
        self.model.transformer.wte.to(xm.xla_device(0))
        self.model.transformer.wpe.to(xm.xla_device(0))

        # Distribute the transformer blocks over the devices.
        for i, layer in enumerate(self.model.transformer.h):
            device_id = (i // layers_per_device) % device_count
            layer.to(xm.xla_device(device_id))

        # LM head on the last device.
        self.model.lm_head.to(xm.xla_device(device_count - 1))

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run GPT-2 and return its Hugging Face output object.

        Args:
            input_ids: token id tensor.
            attention_mask: optional mask, 0 at padded positions.
            labels: optional target ids; when given, ``outputs.loss`` is set.

        Returns:
            The ``GPT2LMHeadModel`` output (``.loss``, ``.logits``, ...).
        """
        if self._manually_placed:
            # With manual placement the inputs must start on the embedding device.
            first_device = xm.xla_device(0)
            input_ids = input_ids.to(first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(first_device)

        # Hugging Face ModelOutput objects have no ``.to()``; the previous
        # ``outputs.to(...)`` raised AttributeError, so the output is returned as-is.
        return self.model(input_ids, attention_mask=attention_mask, labels=labels)

    @staticmethod
    def _unpack_batch(batch):
        """Return ``(input_ids, attention_mask)`` from a mapping or a 2-item sequence batch."""
        from collections.abc import Mapping

        # The HF dataset with set_format("torch") collates into a dict; unpacking
        # a dict as a tuple would yield its key strings instead of tensors.
        if isinstance(batch, Mapping):
            return batch["input_ids"], batch.get("attention_mask")
        input_ids, attention_mask = batch
        return input_ids, attention_mask

    def training_step(self, batch, batch_idx):
        """Compute the causal-LM loss for one batch and log it as ``train_loss``."""
        input_ids, attention_mask = self._unpack_batch(batch)

        # Targets are the inputs themselves; GPT2LMHeadModel shifts them
        # internally. Padded positions become -100 so the loss ignores them.
        labels = input_ids.clone()
        if attention_mask is not None:
            labels = labels.masked_fill(attention_mask == 0, -100)

        outputs = self(input_ids, attention_mask, labels=labels)
        loss = outputs.loss
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        """AdamW over all parameters with a fixed learning rate of 5e-5."""
        return torch.optim.AdamW(self.parameters(), lr=5e-5)


Overwriting llm_model.py


In [8]:
%%writefile tpu_trainer.py
from lightning.pytorch.plugins import TPUPrecisionPlugin
from data_module import WikiTextDataModule
from llm_model import ParameterParallelLLM

def main(max_epochs=10, devices=8, batch_size=32, max_length=512):
    """Train ParameterParallelLLM on WikiText-103 with one host's TPU cores.

    Args:
        max_epochs: number of training epochs.
        devices: number of TPU cores to use (the notebook's host exposes 8).
        batch_size: per-core batch size for the data module.
        max_length: token length every example is truncated/padded to.
    """
    # `pl` is not imported at the top of this script, so `pl.Trainer` used to
    # raise NameError.
    import lightning.pytorch as pl

    # NOTE(review): the module-level TPUPrecisionPlugin import is unused here and
    # may not exist in Lightning 2.x (renamed to an XLA* precision plugin) — confirm.

    # Data module.
    dm = WikiTextDataModule(tokenizer_name='gpt2', max_length=max_length, batch_size=batch_size)

    # Model (ParameterParallelLLM from llm_model.py).
    model = ParameterParallelLLM(vocab_size=50257, n_layer=12, n_head=12, n_embd=768)

    # Trainer configuration.
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator='tpu',
        devices=devices,  # number of TPU cores
        num_nodes=1,
        # 'ddp' is a GPU/CPU strategy and is rejected with the TPU accelerator;
        # 'auto' lets Lightning select its XLA strategy.
        strategy='auto',
        # Explicit name instead of the legacy 'bf16' alias (which maps to
        # 'bf16-mixed'); XLA training uses true bf16.
        precision='bf16-true',
    )

    # Run training.
    trainer.fit(model, dm)

if __name__ == '__main__':
    main()


Overwriting tpu_trainer.py


In [9]:
!python tpu_trainer.py

  self.pid = os.fork()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
I0000 00:00:1721009824.614279    6548 pjrt_api.cc:100] GetPjrtApi was found for tpu at /usr/local/lib/python3.10/dist-packages/torch_xla/lib/libtpu.so
I0000 00:00:1721009824.614360    6548 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1721009824.614375    6548 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
E0715 02:17:04.720415604    7621 oauth2_credentials.cc:176]          