In [1]:
!pip install transformers datasets accelerate -q

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m54.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m43.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m31.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m11.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
pip install --upgrade datasets fsspec

Collecting fsspec
  Downloading fsspec-2025.7.0-py3-none-any.whl.metadata (12 kB)


In [3]:
# -----------------------------------------------------------
# 0. 환경 설정 및 라이브러리 설치
# -----------------------------------------------------------
print("Step 0: Installing necessary libraries...")

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import GPT2Tokenizer
import torch.nn.functional as F
from torch.optim import AdamW
from datasets import load_dataset
import math
from tqdm import tqdm
import warnings
import re # 정규표현식을 사용하기 위해 import

warnings.filterwarnings("ignore")
print("Setup complete!")

# -----------------------------------------------------------
# 1. BioGPT 모델 아키텍처 정의 (이전 코드와 동일)
# -----------------------------------------------------------
print("\nStep 1: Defining BioGPT model architecture...")
class BioLinear(nn.Module):
    """
    A biologically-inspired linear layer with Hebbian-like updates
    and weight update clipping for stability.
    """
    def __init__(self, in_features, out_features, alpha=0.1, lr=1e-5, decay=1e-6, clip_value=0.01):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.register_buffer('post_activation_ema', torch.zeros(out_features))
        self.alpha, self.lr, self.decay = alpha, lr, decay

        # --- 클리핑 값을 위한 파라미터 추가 ---
        self.clip_value = clip_value
        # ------------------------------------

        self.linear.weight.requires_grad = False
        self.linear.bias.requires_grad = False

    def forward(self, x):
        pre_activations = x
        post_activations = torch.relu(self.linear(pre_activations))

        if self.training:
            with torch.no_grad():
                avg_pre = pre_activations.mean(dim=[0, 1]) if pre_activations.dim() == 3 else pre_activations.mean(dim=0)
                avg_post = post_activations.mean(dim=[0, 1]) if post_activations.dim() == 3 else post_activations.mean(dim=0)

                self.post_activation_ema.data.mul_(1 - self.alpha).add_(avg_post, alpha=self.alpha)
                delta_w = self.lr * torch.einsum('i,j->ji', avg_pre, self.post_activation_ema)

                # --- 💡 가중치 업데이트 클리핑 적용 ---
                # 가중치 변화량의 크기를 self.clip_value로 제한합니다.
                torch.clamp(delta_w, -self.clip_value, self.clip_value, out=delta_w)
                # ------------------------------------

                self.linear.weight.data.add_(delta_w)
                self.linear.weight.data.mul_(1 - self.decay)

        return post_activations

class BioFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.bio_linear1 = BioLinear(d_model, d_ff)
        self.bio_linear2 = BioLinear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.bio_linear2(self.dropout(self.bio_linear1(x)))

class DecoderBlock(nn.Module):
    def __init__(self, d_model, nhead, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.feed_forward = BioFeedForward(d_model, d_ff, dropout=dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask=None):
        attn_output, _ = self.self_attn(src, src, src, attn_mask=src_mask, need_weights=False)
        src = src + self.dropout(attn_output)
        src = self.norm1(src)
        ff_output = self.feed_forward(src)
        src = src + self.dropout(ff_output)
        src = self.norm2(src)
        return src

class BioGPT(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, d_ff, num_layers, max_seq_len, dropout=0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)
        self.layers = nn.ModuleList([DecoderBlock(d_model, nhead, d_ff, dropout) for _ in range(num_layers)])
        self.final_norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.token_embedding.weight = self.lm_head.weight
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        b, t = src.size()
        tok_emb = self.token_embedding(src)
        pos = torch.arange(0, t, dtype=torch.long, device=src.device).unsqueeze(0)
        pos_emb = self.position_embedding(pos)
        x = self.dropout(tok_emb + pos_emb)
        causal_mask = nn.Transformer.generate_square_subsequent_mask(t, device=src.device)
        for layer in self.layers:
            x = layer(x, src_mask=causal_mask)
        x = self.final_norm(x)
        logits = self.lm_head(x)
        return logits
print("Model architecture defined!")

# -----------------------------------------------------------
# 2. 하이퍼파라미터 및 설정
# -----------------------------------------------------------
print("\nStep 2: Setting up configurations...")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
VOCAB_SIZE = 50257
MAX_SEQ_LEN = 512 # Colab 메모리 한계를 고려해 시퀀스 길이를 512로 줄임

config_10m = {
    "vocab_size": VOCAB_SIZE,
    "max_seq_len": MAX_SEQ_LEN,
    "d_model": 192,
    "num_layers": 6,
    "nhead": 6,
    "d_ff": 192 * 4,
    "dropout": 0.1,
}

TRAIN_CONFIG = {
    "batch_size": 8,
    "num_epochs": 10,
    "lr": 1e-4, # AdamW의 학습률
}
print(f"Device set to: {DEVICE}")
print("Configurations set!")

# -----------------------------------------------------------
# 3. 데이터셋 로드 및 전처리
# -----------------------------------------------------------
print("\nStep 3: Loading and preparing WikiText dataset...")
# 데이터셋 로드 (wikitext-2는 작아서 Colab에서 빠르게 테스트하기 좋음)
datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')

# 토크나이저 로드
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token # GPT-2는 pad 토큰이 없으므로 eos 토큰으로 설정

# ==================================================================
# Step 2: 개선된 위키 형식 제거 함수 (이 부분을 교체)
# ==================================================================
print("\nStep 2: Cleaning comprehensive wiki formatting from the dataset...")

def clean_wiki_text(examples):
    """
    정규표현식을 사용하여 다양한 위키 마크업을 제거합니다.
    - 제목 (e.g., = Title =)
    - 링크 (e.g., [[Article]])
    - 템플릿 (e.g., {{cite web}})
    - 외부 링크 (e.g., [http://...])
    - HTML/XML 태그 (e.g., <ref>)
    - 강조 (e.g., ''italic'', '''bold''')
    """
    cleaned_texts = []
    for text in examples["text"]:
        # 1. 제목 제거 (모든 레벨)
        text = re.sub(r'=+\s*([^=]+?)\s* =+', r'\1', text) # 제목 마크업만 제거하고 내용은 남김
        # 2. 템플릿 제거
        text = re.sub(r'\{\{.*?\}\}', '', text)
        # 3. 파일 및 이미지 링크 제거
        text = re.sub(r'\[\[(File|Image):.*?\]\]', '', text)
        # 4. 내부 링크의 표시 텍스트만 남김
        text = re.sub(r'\[\[(?:[^|\]]*\|)?([^\]]+)\]\]', r'\1', text)
        # 5. 외부 링크 제거
        text = re.sub(r'\[https?://.*?\]', '', text)
        # 6. 강조 마크업 제거
        text = re.sub(r"''+", '', text)
        # 7. HTML/XML 태그 제거
        text = re.sub(r'<.*?>', '', text)
        # 8. 연속된 공백 및 줄바꿈 정리
        text = re.sub(r'\s{2,}', ' ', text).strip()
        cleaned_texts.append(text)

    examples["text"] = cleaned_texts
    return examples

# .map()을 사용하여 데이터셋 전체에 클리닝 함수를 적용합니다.
cleaned_datasets = datasets.map(
    clean_wiki_text,
    batched=True,
    num_proc=4 # 여러 프로세스를 사용하여 처리 속도를 높입니다.
)

# 내용이 비게 된 텍스트를 필터링하여 제거합니다.
cleaned_datasets = cleaned_datasets.filter(lambda example: len(example['text'].strip()) > 0)

def tokenize_function(examples):
    # 모든 텍스트를 하나로 합침
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # MAX_SEQ_LEN 단위로 자름
    total_length = (total_length // MAX_SEQ_LEN) * MAX_SEQ_LEN
    # 잘린 텍스트를 다시 배치로 나눔
    result = {
        k: [t[i : i + MAX_SEQ_LEN] for i in range(0, total_length, MAX_SEQ_LEN)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

# 텍스트만 있는 행을 필터링하고 토큰화
tokenized_datasets = cleaned_datasets.map(
    lambda x: tokenizer(x["text"]),
    batched=True,
    remove_columns=["text"]
).filter(lambda x: len(x['input_ids']) > 0)

# 고정된 길이로 데이터셋 재구성
processed_datasets = tokenized_datasets.map(
    tokenize_function,
    batched=True,
)

processed_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
# Re-initialize your DataLoaders with the new collate function


train_dataset = processed_datasets["train"]
val_dataset = processed_datasets["validation"]
train_loader = DataLoader(train_dataset, batch_size=TRAIN_CONFIG['batch_size'], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=TRAIN_CONFIG['batch_size'])

print("Dataset ready!")


# -----------------------------------------------------------
# 4. 모델, 옵티마이저, 손실 함수 초기화
# -----------------------------------------------------------
print("\nStep 4: Initializing model, optimizer, and loss function...")
model = BioGPT(**config_10m).to(DEVICE)

# 그래디언트가 필요한 파라미터만 옵티마이저에 전달
grad_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(grad_params, lr=TRAIN_CONFIG['lr'])

# 패딩 토큰은 손실 계산에서 제외
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

# 파라미터 수 계산
total, trainable, bio = sum(p.numel() for p in model.parameters()), sum(p.numel() for p in model.parameters() if p.requires_grad), 0
bio = total - trainable
print(f"Model initialized: Total params: {total/1e6:.2f}M, Trainable (Gradient): {trainable/1e6:.2f}M, Bio (Activation): {bio/1e6:.2f}M")


Step 0: Installing necessary libraries...
Setup complete!

Step 1: Defining BioGPT model architecture...
Model architecture defined!

Step 2: Setting up configurations...
Device set to: cuda
Configurations set!

Step 3: Loading and preparing WikiText dataset...


README.md: 0.00B [00:00, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/733k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/6.36M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]


Step 2: Cleaning comprehensive wiki formatting from the dataset...


Map (num_proc=4):   0%|          | 0/4358 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/36718 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/3760 [00:00<?, ? examples/s]

Filter:   0%|          | 0/4358 [00:00<?, ? examples/s]

Filter:   0%|          | 0/36718 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3760 [00:00<?, ? examples/s]

Map:   0%|          | 0/2880 [00:00<?, ? examples/s]

Map:   0%|          | 0/23764 [00:00<?, ? examples/s]

Map:   0%|          | 0/2461 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2880 [00:00<?, ? examples/s]

Filter:   0%|          | 0/23764 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2461 [00:00<?, ? examples/s]

Map:   0%|          | 0/2880 [00:00<?, ? examples/s]

Map:   0%|          | 0/23764 [00:00<?, ? examples/s]

Map:   0%|          | 0/2461 [00:00<?, ? examples/s]

Dataset ready!

Step 4: Initializing model, optimizer, and loss function...
Model initialized: Total params: 12.42M, Trainable (Gradient): 10.64M, Bio (Activation): 1.78M


In [None]:
print("Step 4.5: Checking for NaN in the training dataset...")

nan_found = False
# train_loader를 순회하며 각 배치를 확인
for i, batch in enumerate(train_loader):
    # 'input_ids' 텐서에 NaN이 있는지 확인
    if torch.isnan(batch['input_ids']).any():
        print(f"NaN found in 'input_ids' of batch {i}!")
        nan_found = True
        break # 첫 번째 발견 시 중단

    # 'labels' 텐서에 NaN이 있는지 확인
    if torch.isnan(batch['labels']).any():
        print(f"NaN found in 'labels' of batch {i}!")
        nan_found = True
        break # 첫 번째 발견 시 중단

if not nan_found:
    print("--- No NaN found in the training dataset. The data is clean. ---")

# 이제 원래의 학습 루프를 시작할 수 있습니다.
# print("\nStep 5: Starting training and validation...")
# for epoch in range(TRAIN_CONFIG['num_epochs']):
# ...

In [4]:
# -----------------------------------------------------------
# 5. 학습 및 검증 루프 실행 (수정된 버전)
# -----------------------------------------------------------
print("\nStep 5: Starting training and validation...")

for epoch in range(TRAIN_CONFIG['num_epochs']):
    # --- Training ---
    model.train()
    total_train_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{TRAIN_CONFIG['num_epochs']} [Training]")

    for batch in progress_bar:
        optimizer.zero_grad()

        inputs = batch['input_ids'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)

        # 모델로부터 logits 출력 받기
        outputs = model(inputs)

        # --- Loss Calculation (Shifted) ---
        # 1. Logits의 마지막 토큰 예측은 사용하지 않음
        shift_logits = outputs[:, :-1, :].contiguous()
        # 2. Labels의 첫 번째 토큰(BOS)은 예측 대상이 아님
        shift_labels = labels[:, 1:].contiguous()

        # 3. 정렬된 logits과 labels로 손실 계산
        loss = criterion(shift_logits.view(-1, VOCAB_SIZE), shift_labels.view(-1))
        # ------------------------------------

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_train_loss += loss.item()
        progress_bar.set_postfix({'train_loss': f'{loss.item():.3f}'})

    avg_train_loss = total_train_loss / len(train_loader)

    # --- Validation ---
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{TRAIN_CONFIG['num_epochs']} [Validation]")
        for batch in progress_bar:
            inputs = batch['input_ids'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)

            outputs = model(inputs)

            # --- Loss Calculation (Shifted) ---
            shift_logits = outputs[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            loss = criterion(shift_logits.view(-1, VOCAB_SIZE), shift_labels.view(-1))
            # ------------------------------------

            total_val_loss += loss.item()
            progress_bar.set_postfix({'val_loss': f'{loss.item():.3f}'})

    avg_val_loss = total_val_loss / len(val_loader)
    perplexity = math.exp(avg_val_loss)

    print(f"\nEpoch {epoch+1}/{TRAIN_CONFIG['num_epochs']} | "
          f"Avg Train Loss: {avg_train_loss:.3f} | "
          f"Avg Val Loss: {avg_val_loss:.3f} | "
          f"Validation Perplexity: {perplexity:.2f}")

print("\n--- Training complete! ---")


Step 5: Starting training and validation...


Epoch 1/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.72it/s, train_loss=7.567]
Epoch 1/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 16.93it/s, val_loss=7.473]



Epoch 1/10 | Avg Train Loss: 8.038 | Avg Val Loss: 7.583 | Validation Perplexity: 1964.01


Epoch 2/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.73it/s, train_loss=7.866]
Epoch 2/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.38it/s, val_loss=7.464]



Epoch 2/10 | Avg Train Loss: 7.588 | Avg Val Loss: 7.560 | Validation Perplexity: 1919.57


Epoch 3/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.73it/s, train_loss=7.669]
Epoch 3/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.33it/s, val_loss=7.487]



Epoch 3/10 | Avg Train Loss: 7.543 | Avg Val Loss: 7.586 | Validation Perplexity: 1970.41


Epoch 4/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.73it/s, train_loss=7.636]
Epoch 4/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.26it/s, val_loss=7.493]



Epoch 4/10 | Avg Train Loss: 7.562 | Avg Val Loss: 7.605 | Validation Perplexity: 2007.32


Epoch 5/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.73it/s, train_loss=7.749]
Epoch 5/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.37it/s, val_loss=7.505]



Epoch 5/10 | Avg Train Loss: 7.560 | Avg Val Loss: 7.609 | Validation Perplexity: 2015.43


Epoch 6/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.72it/s, train_loss=7.664]
Epoch 6/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.35it/s, val_loss=7.518]



Epoch 6/10 | Avg Train Loss: 7.554 | Avg Val Loss: 7.617 | Validation Perplexity: 2033.19


Epoch 7/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.73it/s, train_loss=7.705]
Epoch 7/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.32it/s, val_loss=7.518]



Epoch 7/10 | Avg Train Loss: 7.550 | Avg Val Loss: 7.615 | Validation Perplexity: 2028.41


Epoch 8/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.72it/s, train_loss=7.449]
Epoch 8/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.23it/s, val_loss=7.518]



Epoch 8/10 | Avg Train Loss: 7.545 | Avg Val Loss: 7.619 | Validation Perplexity: 2036.62


Epoch 9/10 [Training]: 100%|██████████| 569/569 [02:01<00:00,  4.69it/s, train_loss=7.557]
Epoch 9/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.39it/s, val_loss=7.524]



Epoch 9/10 | Avg Train Loss: 7.543 | Avg Val Loss: 7.621 | Validation Perplexity: 2041.09


Epoch 10/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.72it/s, train_loss=7.523]
Epoch 10/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.38it/s, val_loss=7.531]


Epoch 10/10 | Avg Train Loss: 7.540 | Avg Val Loss: 7.621 | Validation Perplexity: 2041.53

--- Training complete! ---





In [5]:
ADDITIONAL_EPOCHS =10 # 추가로 학습할 에포크 수
start_epoch=10
total_epochs = start_epoch + ADDITIONAL_EPOCHS

for epoch in range(start_epoch, total_epochs):
    # --- Training ---
    model.train()
    total_train_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{TRAIN_CONFIG['num_epochs']} [Training]")

    for batch in progress_bar:
        optimizer.zero_grad()

        inputs = batch['input_ids'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)

        # 모델로부터 logits 출력 받기
        outputs = model(inputs)

        # --- Loss Calculation (Shifted) ---
        # 1. Logits의 마지막 토큰 예측은 사용하지 않음
        shift_logits = outputs[:, :-1, :].contiguous()
        # 2. Labels의 첫 번째 토큰(BOS)은 예측 대상이 아님
        shift_labels = labels[:, 1:].contiguous()

        # 3. 정렬된 logits과 labels로 손실 계산
        loss = criterion(shift_logits.view(-1, VOCAB_SIZE), shift_labels.view(-1))
        # ------------------------------------

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_train_loss += loss.item()
        progress_bar.set_postfix({'train_loss': f'{loss.item():.3f}'})

    avg_train_loss = total_train_loss / len(train_loader)

    # --- Validation ---
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{TRAIN_CONFIG['num_epochs']} [Validation]")
        for batch in progress_bar:
            inputs = batch['input_ids'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)

            outputs = model(inputs)

            # --- Loss Calculation (Shifted) ---
            shift_logits = outputs[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            loss = criterion(shift_logits.view(-1, VOCAB_SIZE), shift_labels.view(-1))
            # ------------------------------------

            total_val_loss += loss.item()
            progress_bar.set_postfix({'val_loss': f'{loss.item():.3f}'})

    avg_val_loss = total_val_loss / len(val_loader)
    perplexity = math.exp(avg_val_loss)

    print(f"\nEpoch {epoch+1}/{TRAIN_CONFIG['num_epochs']} | "
          f"Avg Train Loss: {avg_train_loss:.3f} | "
          f"Avg Val Loss: {avg_val_loss:.3f} | "
          f"Validation Perplexity: {perplexity:.2f}")

print("\n--- Training complete! ---")

Epoch 11/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=7.642]
Epoch 11/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.20it/s, val_loss=7.527]



Epoch 11/10 | Avg Train Loss: 7.537 | Avg Val Loss: 7.620 | Validation Perplexity: 2039.10


Epoch 12/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.73it/s, train_loss=7.113]
Epoch 12/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.40it/s, val_loss=7.520]



Epoch 12/10 | Avg Train Loss: 7.533 | Avg Val Loss: 7.615 | Validation Perplexity: 2028.05


Epoch 13/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.73it/s, train_loss=7.381]
Epoch 13/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.38it/s, val_loss=7.519]



Epoch 13/10 | Avg Train Loss: 7.531 | Avg Val Loss: 7.610 | Validation Perplexity: 2018.98


Epoch 14/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.73it/s, train_loss=7.324]
Epoch 14/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.32it/s, val_loss=7.523]



Epoch 14/10 | Avg Train Loss: 7.528 | Avg Val Loss: 7.608 | Validation Perplexity: 2014.17


Epoch 15/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.73it/s, train_loss=7.683]
Epoch 15/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.25it/s, val_loss=7.504]



Epoch 15/10 | Avg Train Loss: 7.526 | Avg Val Loss: 7.600 | Validation Perplexity: 1998.59


Epoch 16/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.73it/s, train_loss=7.809]
Epoch 16/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.38it/s, val_loss=7.512]



Epoch 16/10 | Avg Train Loss: 7.524 | Avg Val Loss: 7.597 | Validation Perplexity: 1993.16


Epoch 17/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.73it/s, train_loss=7.490]
Epoch 17/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.32it/s, val_loss=7.504]



Epoch 17/10 | Avg Train Loss: 7.521 | Avg Val Loss: 7.594 | Validation Perplexity: 1986.92


Epoch 18/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.73it/s, train_loss=7.409]
Epoch 18/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.26it/s, val_loss=7.505]



Epoch 18/10 | Avg Train Loss: 7.520 | Avg Val Loss: 7.591 | Validation Perplexity: 1980.28


Epoch 19/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.73it/s, train_loss=7.893]
Epoch 19/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.34it/s, val_loss=7.505]



Epoch 19/10 | Avg Train Loss: 7.518 | Avg Val Loss: 7.587 | Validation Perplexity: 1972.04


Epoch 20/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.73it/s, train_loss=7.208]
Epoch 20/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.37it/s, val_loss=7.503]


Epoch 20/10 | Avg Train Loss: 7.516 | Avg Val Loss: 7.586 | Validation Perplexity: 1970.19

--- Training complete! ---





In [11]:
# -----------------------------------------------------------
# 6. 모델 체크포인트 저장
# -----------------------------------------------------------
print("\nStep 6: Saving model checkpoint...")

# 모델의 가중치, 옵티마이저 상태, 에포크 정보 등을 딕셔너리로 묶어 저장합니다.
checkpoint = {
    'epoch': TRAIN_CONFIG['num_epochs'],
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'config': config_10m # 모델 구조를 알기 위해 설정도 함께 저장
}

# torch.save를 이용해 파일로 저장합니다.
torch.save(checkpoint, 'biogpt_checkpoint.pth')

print("✅ Model checkpoint saved to biogpt_checkpoint.pth")


Step 6: Saving model checkpoint...
✅ Model checkpoint saved to biogpt_checkpoint.pth


In [12]:
# -----------------------------------------------------------
# 7. 모델 불러와서 추가 학습 진행
# -----------------------------------------------------------
print("\nStep 7: Loading model for continued training...")

# 먼저, 저장할 때와 동일한 구조로 모델과 옵티마이저를 다시 생성합니다.
model = BioGPT(**config_10m).to(DEVICE)
grad_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(grad_params, lr=TRAIN_CONFIG['lr'])

# torch.load로 체크포인트 파일을 불러옵니다.
checkpoint = torch.load('biogpt_checkpoint.pth', map_location=DEVICE)

# 저장된 state_dict를 모델과 옵티마이저에 로드합니다.
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] # 마지막으로 학습한 에포크

print(f"✅ Model loaded. Resuming training from epoch {start_epoch + 1}.")

# --- 추가 학습 진행 ---
ADDITIONAL_EPOCHS = 30 # 추가로 학습할 에포크 수
total_epochs = start_epoch + ADDITIONAL_EPOCHS

for epoch in range(start_epoch, total_epochs):
     # --- Training ---
    model.train()
    total_train_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{TRAIN_CONFIG['num_epochs']} [Training]")

    for batch in progress_bar:
        optimizer.zero_grad()

        inputs = batch['input_ids'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)

        # 모델로부터 logits 출력 받기
        outputs = model(inputs)

        # --- Loss Calculation (Shifted) ---
        # 1. Logits의 마지막 토큰 예측은 사용하지 않음
        shift_logits = outputs[:, :-1, :].contiguous()
        # 2. Labels의 첫 번째 토큰(BOS)은 예측 대상이 아님
        shift_labels = labels[:, 1:].contiguous()

        # 3. 정렬된 logits과 labels로 손실 계산
        loss = criterion(shift_logits.view(-1, VOCAB_SIZE), shift_labels.view(-1))
        # ------------------------------------

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_train_loss += loss.item()
        progress_bar.set_postfix({'train_loss': f'{loss.item():.3f}'})

    avg_train_loss = total_train_loss / len(train_loader)

    # --- Validation ---
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{TRAIN_CONFIG['num_epochs']} [Validation]")
        for batch in progress_bar:
            inputs = batch['input_ids'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)

            outputs = model(inputs)

            # --- Loss Calculation (Shifted) ---
            shift_logits = outputs[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            loss = criterion(shift_logits.view(-1, VOCAB_SIZE), shift_labels.view(-1))
            # ------------------------------------

            total_val_loss += loss.item()
            progress_bar.set_postfix({'val_loss': f'{loss.item():.3f}'})

    avg_val_loss = total_val_loss / len(val_loader)
    perplexity = math.exp(avg_val_loss)

    print(f"\nEpoch {epoch+1}/{TRAIN_CONFIG['num_epochs']} | "
          f"Avg Train Loss: {avg_train_loss:.3f} | "
          f"Avg Val Loss: {avg_val_loss:.3f} | "
          f"Validation Perplexity: {perplexity:.2f}")

print("\n--- Training complete! ---")


Step 7: Loading model for continued training...
✅ Model loaded. Resuming training from epoch 11.


Epoch 11/10 [Training]: 100%|██████████| 569/569 [02:01<00:00,  4.69it/s, train_loss=7.485]
Epoch 11/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.28it/s, val_loss=7.505]



Epoch 11/10 | Avg Train Loss: 7.515 | Avg Val Loss: 7.585 | Validation Perplexity: 1969.21


Epoch 12/10 [Training]: 100%|██████████| 569/569 [02:01<00:00,  4.70it/s, train_loss=7.423]
Epoch 12/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.27it/s, val_loss=7.511]



Epoch 12/10 | Avg Train Loss: 7.513 | Avg Val Loss: 7.583 | Validation Perplexity: 1964.44


Epoch 13/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=7.517]
Epoch 13/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.29it/s, val_loss=7.505]



Epoch 13/10 | Avg Train Loss: 7.512 | Avg Val Loss: 7.584 | Validation Perplexity: 1965.79


Epoch 14/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.70it/s, train_loss=7.711]
Epoch 14/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.40it/s, val_loss=7.504]



Epoch 14/10 | Avg Train Loss: 7.512 | Avg Val Loss: 7.583 | Validation Perplexity: 1965.40


Epoch 15/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=8.033]
Epoch 15/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.40it/s, val_loss=7.506]



Epoch 15/10 | Avg Train Loss: 7.511 | Avg Val Loss: 7.585 | Validation Perplexity: 1968.59


Epoch 16/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=7.347]
Epoch 16/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.33it/s, val_loss=7.511]



Epoch 16/10 | Avg Train Loss: 7.509 | Avg Val Loss: 7.585 | Validation Perplexity: 1969.00


Epoch 17/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=7.390]
Epoch 17/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.44it/s, val_loss=7.514]



Epoch 17/10 | Avg Train Loss: 7.509 | Avg Val Loss: 7.585 | Validation Perplexity: 1969.17


Epoch 18/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=7.517]
Epoch 18/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.43it/s, val_loss=7.508]



Epoch 18/10 | Avg Train Loss: 7.508 | Avg Val Loss: 7.586 | Validation Perplexity: 1970.95


Epoch 19/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=8.172]
Epoch 19/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.35it/s, val_loss=7.515]



Epoch 19/10 | Avg Train Loss: 7.509 | Avg Val Loss: 7.588 | Validation Perplexity: 1973.55


Epoch 20/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=7.260]
Epoch 20/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.43it/s, val_loss=7.513]



Epoch 20/10 | Avg Train Loss: 7.507 | Avg Val Loss: 7.589 | Validation Perplexity: 1976.72


Epoch 21/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=7.290]
Epoch 21/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.40it/s, val_loss=7.516]



Epoch 21/10 | Avg Train Loss: 7.507 | Avg Val Loss: 7.591 | Validation Perplexity: 1979.46


Epoch 22/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=7.306]
Epoch 22/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.25it/s, val_loss=7.527]



Epoch 22/10 | Avg Train Loss: 7.506 | Avg Val Loss: 7.590 | Validation Perplexity: 1978.76


Epoch 23/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=7.964]
Epoch 23/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.42it/s, val_loss=7.516]



Epoch 23/10 | Avg Train Loss: 7.507 | Avg Val Loss: 7.591 | Validation Perplexity: 1980.75


Epoch 24/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=7.767]
Epoch 24/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.44it/s, val_loss=7.522]



Epoch 24/10 | Avg Train Loss: 7.507 | Avg Val Loss: 7.593 | Validation Perplexity: 1983.47


Epoch 25/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=7.670]
Epoch 25/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.27it/s, val_loss=7.526]



Epoch 25/10 | Avg Train Loss: 7.506 | Avg Val Loss: 7.593 | Validation Perplexity: 1984.84


Epoch 26/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=7.696]
Epoch 26/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.38it/s, val_loss=7.518]



Epoch 26/10 | Avg Train Loss: 7.506 | Avg Val Loss: 7.594 | Validation Perplexity: 1985.28


Epoch 27/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=7.702]
Epoch 27/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.42it/s, val_loss=7.518]



Epoch 27/10 | Avg Train Loss: 7.506 | Avg Val Loss: 7.595 | Validation Perplexity: 1988.54


Epoch 28/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=7.656]
Epoch 28/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.20it/s, val_loss=7.520]



Epoch 28/10 | Avg Train Loss: 7.506 | Avg Val Loss: 7.596 | Validation Perplexity: 1989.96


Epoch 29/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=7.650]
Epoch 29/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.41it/s, val_loss=7.522]



Epoch 29/10 | Avg Train Loss: 7.506 | Avg Val Loss: 7.596 | Validation Perplexity: 1990.05


Epoch 30/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=7.526]
Epoch 30/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.29it/s, val_loss=7.524]



Epoch 30/10 | Avg Train Loss: 7.505 | Avg Val Loss: 7.597 | Validation Perplexity: 1992.62


Epoch 31/10 [Training]: 100%|██████████| 569/569 [02:01<00:00,  4.68it/s, train_loss=7.671]
Epoch 31/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.32it/s, val_loss=7.528]



Epoch 31/10 | Avg Train Loss: 7.505 | Avg Val Loss: 7.597 | Validation Perplexity: 1992.34


Epoch 32/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=7.701]
Epoch 32/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.41it/s, val_loss=7.531]



Epoch 32/10 | Avg Train Loss: 7.505 | Avg Val Loss: 7.598 | Validation Perplexity: 1994.67


Epoch 33/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=7.908]
Epoch 33/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.41it/s, val_loss=7.530]



Epoch 33/10 | Avg Train Loss: 7.505 | Avg Val Loss: 7.596 | Validation Perplexity: 1990.62


Epoch 34/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=7.662]
Epoch 34/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.39it/s, val_loss=7.530]



Epoch 34/10 | Avg Train Loss: 7.505 | Avg Val Loss: 7.598 | Validation Perplexity: 1994.72


Epoch 35/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=7.405]
Epoch 35/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.35it/s, val_loss=7.527]



Epoch 35/10 | Avg Train Loss: 7.504 | Avg Val Loss: 7.599 | Validation Perplexity: 1996.67


Epoch 36/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=7.441]
Epoch 36/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.41it/s, val_loss=7.529]



Epoch 36/10 | Avg Train Loss: 7.504 | Avg Val Loss: 7.598 | Validation Perplexity: 1994.38


Epoch 37/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=7.825]
Epoch 37/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.48it/s, val_loss=7.525]



Epoch 37/10 | Avg Train Loss: 7.505 | Avg Val Loss: 7.598 | Validation Perplexity: 1993.40


Epoch 38/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=7.791]
Epoch 38/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.39it/s, val_loss=7.523]



Epoch 38/10 | Avg Train Loss: 7.505 | Avg Val Loss: 7.597 | Validation Perplexity: 1992.70


Epoch 39/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=7.591]
Epoch 39/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.36it/s, val_loss=7.522]



Epoch 39/10 | Avg Train Loss: 7.504 | Avg Val Loss: 7.597 | Validation Perplexity: 1992.88


Epoch 40/10 [Training]: 100%|██████████| 569/569 [02:00<00:00,  4.71it/s, train_loss=7.113]
Epoch 40/10 [Validation]: 100%|██████████| 59/59 [00:03<00:00, 17.41it/s, val_loss=7.525]


Epoch 40/10 | Avg Train Loss: 7.503 | Avg Val Loss: 7.599 | Validation Perplexity: 1995.35

--- Training complete! ---





In [7]:
# -----------------------------------------------------------
# 8. 학습된 모델로 텍스트 생성
# -----------------------------------------------------------
print("\nStep 8: Generating text with the trained model...")

def generate_text(model, tokenizer, prompt, max_length=100, top_k=50, device="cpu"):
    """
    학습된 모델을 사용해 텍스트를 생성하는 함수

    Args:
        model (nn.Module): 학습된 BioGPT 모델
        tokenizer: GPT-2 토크나이저
        prompt (str): 생성을 시작할 텍스트
        max_length (int): 생성할 최대 토큰 길이
        top_k (int): Top-k 샘플링에서 사용할 k값
        device (str): 연산을 수행할 장치 (e.g., "cuda" or "cpu")
    """
    model.eval()  # 모델을 평가 모드로 설정
    model.to(device)

    # 입력 프롬프트를 토큰화하여 텐서로 변환
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    with torch.no_grad(): # 그래디언트 계산 비활성화
        for _ in range(max_length):
            # 현재까지 생성된 시퀀스를 모델에 입력
            outputs = model(input_ids)

            # 다음 토큰 예측을 위해 마지막 시점의 로짓(logits)만 사용
            next_token_logits = outputs[:, -1, :]

            # --- Top-k 샘플링 ---
            # 1. 로짓 값 중 상위 k개만 선택
            top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)

            # 2. k개 외의 로짓 값은 -inf로 필터링
            filter_logits = torch.full_like(next_token_logits, -float('Inf'))
            filter_logits.scatter_(1, top_k_indices, top_k_logits)

            # 3. 필터링된 로짓에 softmax를 적용해 확률 분포 생성
            probabilities = F.softmax(filter_logits, dim=-1)

            # 4. 생성된 확률 분포에 따라 다음 토큰을 샘플링
            next_token = torch.multinomial(probabilities, num_samples=1)

            # 생성된 토큰을 기존 시퀀스에 추가
            input_ids = torch.cat((input_ids, next_token), dim=1)

            # 만약 EOS(End-of-Sentence) 토큰이 생성되면 루프 종료
            if next_token.item() == tokenizer.eos_token_id:
                break

    # 최종적으로 생성된 토큰 시퀀스를 다시 텍스트로 디코딩
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)


# --- 텍스트 생성 실행 ---
prompt = "i will kill"
generated_output = generate_text(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_length=80,    # 생성할 문장의 최대 길이 (토큰 기준)
    top_k=50,         # 상위 50개 단어 중에서 샘플링
    device=DEVICE
)

print(f"\n🎬 Prompt: '{prompt}'")
print(f"🤖 Generated Text: '{generated_output}'")


Step 8: Generating text with the trained model...

🎬 Prompt: 'i will kill'
🤖 Generated Text: 'i will kill the . and ( . as@ " a in to the in , of the the to , ) ( the of had@ a the the the . the- the the the of , the@ of to as that for . and = " The was and to "- was the thes . is and and in@ , as ' with the , in be of . the- to to his which'


In [13]:

# --- 텍스트 생성 실행 ---
prompt = "gps technoloy "
generated_output = generate_text(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_length=80,    # 생성할 문장의 최대 길이 (토큰 기준)
    top_k=50,         # 상위 50개 단어 중에서 샘플링
    device=DEVICE
)

print(f"\n🎬 Prompt: '{prompt}'")
print(f"🤖 Generated Text: '{generated_output}'")


🎬 Prompt: 'gps technoloy '
🤖 Generated Text: 'gps technoloy  to@ in = ' the to The the- hass that a@ the from . @ to ,= to that@ a the , a , . @ is the who their@ @ to that fors the@ of ands had not in . " . .@ that , the he , but for . and by " wass not is , not in The " the the in and In'
