전체 코드

1. 데이터 준비

In [7]:
from datasets import load_dataset
import torch

# Fetch the dataset identified by "daekeun-ml/naver-news-summarization-ko".
dataset = load_dataset("daekeun-ml/naver-news-summarization-ko")

data = dataset
# Concatenate every "document" entry of the train split into one long string;
# the character vocabulary and the token stream are both derived from it.
ko_text = "".join(data["train"]["document"])
ko_chars = sorted(set(ko_text))
ko_vocab_size = len(ko_chars)
print("총 글자 수 :", ko_vocab_size)

# Bidirectional lookup tables between characters and their integer ids
# (ids follow the sorted order of the distinct characters).
character_to_ids = dict(zip(ko_chars, range(ko_vocab_size)))
ids_to_character = dict(enumerate(ko_chars))


def token_encode(s):
    """Return the list of integer ids for the characters of ``s``."""
    return [character_to_ids[c] for c in s]


def token_decode(l):
    """Return the string obtained by mapping each id in ``l`` back to its character."""
    return "".join(ids_to_character[i] for i in l)


tokenized_data = torch.tensor(token_encode(ko_text), dtype=torch.long)

# First 90% of the token stream is used for training, the remainder is held out.
n = int(0.9 * len(tokenized_data))
train_dataset, test_dataset = tokenized_data[:n], tokenized_data[n:]

# Fix the global PyTorch RNG seed for the cells that follow.
torch.manual_seed(1234)

총 글자 수 : 2701


<torch._C.Generator at 0x7fab9e1af7f0>

2. 학습

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Number of windows sampled per batch.
batch_size = 32
# Length, in tokens, of each sampled window.
block_size = 8
# Total number of optimization steps in the training loop.
max_iteration = 50_000
# Losses are evaluated and printed every `eval_interval` steps.
eval_interval = 300
# Step size handed to the AdamW optimizer.
learning_rate = 0.01
# Number of batches averaged per split when estimating the loss.
eval_iteration = 200
# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def batch_function(mode):
    """Draw one random batch of input/target windows.

    Args:
        mode: ``"train"`` samples from ``train_dataset``; any other value
            samples from ``test_dataset``.

    Returns:
        A tuple ``(x, y)`` of stacked windows, ``batch_size`` rows of
        ``block_size`` tokens each, moved to ``device``. Each row of ``y`` is
        the slice of the source that starts one position after the matching
        row of ``x``.
    """
    if mode == "train":
        source = train_dataset
    else:
        source = test_dataset
    # Start offsets are capped so the one-step-shifted target window still fits.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    x = torch.stack([source[s:s + block_size] for s in starts])
    y = torch.stack([source[s + 1:s + block_size + 1] for s in starts])
    # Move both tensors to the configured device.
    return x.to(device), y.to(device)

@torch.no_grad()
def compute_loss_metrics():
    """Estimate the mean loss of the global ``model`` for two splits.

    For each of the keys ``"train"`` and ``"eval"``, ``eval_iteration`` batches
    are drawn via ``batch_function`` and their losses averaged. The model is
    put in eval mode for the measurement and switched back to train mode
    before returning. Gradient tracking is disabled by the decorator.

    Returns:
        Dict mapping ``"train"`` and ``"eval"`` to a scalar tensor holding the
        mean loss for that split.
    """
    model.eval()
    results = {}
    for mode in ("train", "eval"):
        per_batch = torch.zeros(eval_iteration)
        for k in range(eval_iteration):
            inputs, targets = batch_function(mode)
            _, loss = model(inputs, targets)
            per_batch[k] = loss.item()
        results[mode] = per_batch.mean()
    model.train()
    return results

class semiGPT(nn.Module):
    """Minimal language model made of a single embedding table.

    The table has shape ``(vocab_length, vocab_length)``: looking up a token id
    yields a vector that is used directly as the logits over the next token.
    """

    def __init__(self, vocab_length):
        """Create the ``vocab_length x vocab_length`` embedding table."""
        super().__init__()
        self.embedding_token_table = nn.Embedding(vocab_length, vocab_length)

    def forward(self, inputs, targets=None):
        """Compute logits and, when targets are given, the cross-entropy loss.

        Args:
            inputs: Integer tensor of token ids; the loss branch unpacks the
                resulting logits as ``(batch, seq_length, vocab_length)``.
            targets: Optional integer tensor with ``batch * seq_length``
                elements holding the expected next-token ids.

        Returns:
            ``(logits, loss)``. Without ``targets`` the logits keep the
            embedding output shape and ``loss`` is ``None``. With ``targets``
            the logits are flattened to ``(batch * seq_length, vocab_length)``
            and ``loss`` is the scalar cross-entropy.
        """
        logits = self.embedding_token_table(inputs)
        if targets is None:
            loss = None
        else:
            batch, seq_length, vocab_length = logits.shape
            # F.cross_entropy expects (N, C) logits against (N,) class indices.
            logits = logits.view(batch * seq_length, vocab_length)
            targets = targets.view(batch * seq_length)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()  # sampling never needs gradients; avoid building a graph per step
    def generate(self, inputs, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after ``inputs``.

        Args:
            inputs: Integer tensor of shape ``(batch, seq_length)`` used as the
                starting context.
            max_new_tokens: Number of tokens to append; ``0`` returns
                ``inputs`` unchanged.

        Returns:
            Integer tensor of shape ``(batch, seq_length + max_new_tokens)``.
        """
        for _ in range(max_new_tokens):
            # Call the module (not .forward) so registered hooks are honoured.
            logits, _loss = self(inputs)
            # Only the logits at the last position decide the next token.
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            next_inputs = torch.multinomial(probs, num_samples=1)
            inputs = torch.cat((inputs, next_inputs), dim=1)
        return inputs

# Build the model on the selected device and optimize all its parameters with AdamW.
model = semiGPT(ko_vocab_size).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for step in range(max_iteration):
    # Every `eval_interval` steps (including step 0), report the averaged losses.
    if not step % eval_interval:
        losses = compute_loss_metrics()
        print(f'step : {step}, train loss : {losses["train"]:.4f}, val loss : {losses["eval"]:.4f}')

    # One optimization step on a freshly sampled training batch.
    optimizer.zero_grad(set_to_none=True)
    example_x, example_y = batch_function("train")
    logits, loss = model(example_x, example_y)
    loss.backward()
    optimizer.step()

# Sample 100 new tokens starting from a single token with id 0 and print the decoded text.
inputs = torch.zeros((1,1), dtype=torch.long, device=device)
generated = model.generate(inputs, max_new_tokens=100)
print(token_decode(generated[0].tolist()))

step : 0, train loss : 8.3552, val loss : 8.3545
step : 300, train loss : 6.0840, val loss : 6.0975
step : 600, train loss : 4.7668, val loss : 4.7639
step : 900, train loss : 4.2182, val loss : 4.2527
step : 1200, train loss : 3.9751, val loss : 3.9498
step : 1500, train loss : 3.8110, val loss : 3.8238
step : 1800, train loss : 3.7284, val loss : 3.7125
step : 2100, train loss : 3.6604, val loss : 3.6491
step : 2400, train loss : 3.6326, val loss : 3.5933
step : 2700, train loss : 3.5815, val loss : 3.5987
step : 3000, train loss : 3.5630, val loss : 3.5603
step : 3300, train loss : 3.5380, val loss : 3.5452
step : 3600, train loss : 3.5020, val loss : 3.5332
step : 3900, train loss : 3.5006, val loss : 3.4992
step : 4200, train loss : 3.4940, val loss : 3.4854
step : 4500, train loss : 3.4734, val loss : 3.4746
step : 4800, train loss : 3.4622, val loss : 3.4616
step : 5100, train loss : 3.4861, val loss : 3.4637
step : 5400, train loss : 3.4492, val loss : 3.4620
step : 5700, train