### 임포트 & 환경설정

In [42]:
import os
import re
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Default maximum sequence length used at training time. The LSTM checkpoint
# config assigns MAX_SEQ_LEN again once that checkpoint is loaded further down.
MAX_SEQ_LEN = 256

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print("DEVICE:", DEVICE)

DEVICE: cuda


### 텍스트 전처리

In [43]:
# Cell 2: text preprocessing helpers (tokenization and ID encoding for the LSTM models)

def simple_tokenize(text: str):
    """Split text into lowercase word tokens for the LSTM models.

    The text is lowercased, then every maximal run of ASCII letters, digits
    and apostrophes becomes one token. All other characters act as
    separators and are dropped.

    NOTE(review): this is intended to mirror the preprocessing used when the
    LSTM was trained; adjust the rule here if the training-time one differs.

    Returns:
        The list of tokens in order of appearance (empty for text that
        contains no matching characters).
    """
    return re.findall(r"[a-z0-9']+", text.lower())

def encode_text_lstm(text: str, stoi, pad_idx: int, unk_idx: int, max_len: int = MAX_SEQ_LEN, device=DEVICE):
    """Turn raw text into a fixed-length batch of token IDs for an LSTM model.

    Args:
        text: Raw input text; tokenized with ``simple_tokenize``.
        stoi: Mapping from token string to integer ID.
        pad_idx: ID appended to sequences shorter than ``max_len``.
        unk_idx: ID substituted for tokens missing from ``stoi``.
        max_len: Output length; longer inputs are truncated on the right.
        device: Device the resulting tensor is moved to.

    Returns:
        A ``torch.long`` tensor of shape ``(1, max_len)``.
    """
    token_ids = []
    for token in simple_tokenize(text)[:max_len]:
        token_ids.append(stoi.get(token, unk_idx))

    # A non-positive count yields an empty list, so full-length inputs get no padding.
    padding = [pad_idx] * (max_len - len(token_ids))

    encoded = torch.tensor(token_ids + padding, dtype=torch.long)
    return encoded.unsqueeze(0).to(device)  # add the batch dimension -> (1, T)

### LSTM 모델 로드

In [44]:
LSTM_MODEL_PATH = "model/lstm.pt"

# Validate with an explicit exception rather than `assert`: assertions are
# removed when Python runs with -O, which would silently skip this check.
if not os.path.exists(LSTM_MODEL_PATH):
    raise FileNotFoundError(f"LSTM 모델 파일을 찾을 수 없습니다: {LSTM_MODEL_PATH}")

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. Consider weights_only=True if the checkpoint contents allow it.
ckpt_lstm = torch.load(LSTM_MODEL_PATH, map_location=DEVICE)

# Unpack the pieces stored in the checkpoint dictionary.
best_epoch_lstm = ckpt_lstm["epoch"]
model_state_lstm = ckpt_lstm["model_state_dict"]
stoi = ckpt_lstm["vocab"]["stoi"]
itos = ckpt_lstm["vocab"]["itos"]
genre_to_idx = ckpt_lstm["genre_to_idx"]
idx_to_genre = ckpt_lstm["idx_to_genre"]
config_lstm = ckpt_lstm["config"]

print(f"✅ LSTM 체크포인트 로드 (epoch={best_epoch_lstm})")
print("장르 개수:", config_lstm["num_labels"])

# Architecture hyperparameters come from the checkpoint's config.
EMBED_DIM   = config_lstm["embed_dim"]
HIDDEN_DIM  = config_lstm["hidden_dim"]
NUM_LAYERS  = config_lstm["num_layers"]
BIDIRECTIONAL = config_lstm["bidirectional"]
NUM_LABELS  = config_lstm["num_labels"]
PAD_IDX     = config_lstm["pad_idx"]
# Replaces the module-level default MAX_SEQ_LEN with the checkpoint's value.
MAX_SEQ_LEN = config_lstm["max_seq_len"]

# Dropout is not read from the config; it is hard-coded here.
DROPOUT     = 0.4
# Fall back to index 0 when the vocabulary has no "<unk>" entry.
UNK_IDX     = stoi.get("<unk>", 0)

class LSTMTextClassifier(nn.Module):
    """Text classifier: embedding -> LSTM -> dropout -> linear head.

    The classification features are the LSTM outputs at the last time step
    of the (batch-first) input sequence.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Embedding vector size (LSTM input size).
        hidden_dim: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
        num_labels: Number of output classes.
        dropout: Dropout probability applied before the linear head and,
            when ``num_layers > 1``, between LSTM layers.
        bidirectional: Whether the LSTM runs in both directions.
        pad_idx: Token ID passed to the embedding as ``padding_idx``.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        hidden_dim: int,
        num_layers: int,
        num_labels: int,
        dropout: float = 0.0,
        bidirectional: bool = False,
        pad_idx: int = 0,
    ):
        super().__init__()

        # Inter-layer LSTM dropout is only requested when layers are stacked.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        # A bidirectional LSTM emits both directions' hidden states per step.
        lstm_out_dim = hidden_dim * 2 if bidirectional else hidden_dim

        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=inter_layer_dropout,
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(lstm_out_dim, num_labels)

    def forward(self, input_ids):
        """Return class logits of shape (B, num_labels) for token IDs of shape (B, T)."""
        sequence_out, _ = self.lstm(self.embedding(input_ids))  # (B, T, H * num_directions)
        # NOTE(review): the last time step is taken regardless of padding, so
        # for right-padded inputs it sits on a pad position. Confirm this
        # matches how the model was trained before changing it.
        final_step = sequence_out[:, -1, :]
        return self.fc(self.dropout(final_step))

# Embedding table size, derived from the largest token ID.
# Assumes the IDs in stoi are laid out as 0..N-1.
VOCAB_SIZE = 1 + max(stoi.values())

# Rebuild the architecture from the checkpoint's hyperparameters.
lstm_model = LSTMTextClassifier(
    VOCAB_SIZE,
    EMBED_DIM,
    HIDDEN_DIM,
    NUM_LAYERS,
    NUM_LABELS,
    dropout=DROPOUT,
    bidirectional=BIDIRECTIONAL,
    pad_idx=PAD_IDX,
)
lstm_model = lstm_model.to(DEVICE)

# Restore the trained weights and switch to inference mode.
lstm_model.load_state_dict(model_state_lstm)
lstm_model.eval()

print("✅ LSTM 모델 로드 완료")

✅ LSTM 체크포인트 로드 (epoch=8)
장르 개수: 27
✅ LSTM high 모델 로드 완료


  ckpt_lstm = torch.load(LSTM_MODEL_PATH, map_location=DEVICE)


In [45]:
LSTM_MODEL_PATH_CW = "model/lstm_CW.pt"

# Validate with an explicit exception rather than `assert`: assertions are
# removed when Python runs with -O, which would silently skip this check.
if not os.path.exists(LSTM_MODEL_PATH_CW):
    raise FileNotFoundError(f"LSTM CW 모델 파일을 찾을 수 없습니다: {LSTM_MODEL_PATH_CW}")

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. Consider weights_only=True if the checkpoint contents allow it.
ckpt_lstm_CW = torch.load(LSTM_MODEL_PATH_CW, map_location=DEVICE)

# Unpack the pieces stored in the checkpoint dictionary.
best_epoch_lstm_CW = ckpt_lstm_CW["epoch"]
model_state_lstm_CW = ckpt_lstm_CW["model_state_dict"]
stoi_CW = ckpt_lstm_CW["vocab"]["stoi"]
itos_CW = ckpt_lstm_CW["vocab"]["itos"]
genre_to_idx_CW = ckpt_lstm_CW["genre_to_idx"]
idx_to_genre_CW = ckpt_lstm_CW["idx_to_genre"]
config_lstm_CW = ckpt_lstm_CW["config"]

print(f"✅ LSTM CW 체크포인트 로드 (epoch={best_epoch_lstm_CW})")
print("장르 개수:", config_lstm_CW["num_labels"])

# Architecture hyperparameters come from this checkpoint's own config.
EMBED_DIM_CW   = config_lstm_CW["embed_dim"]
HIDDEN_DIM_CW  = config_lstm_CW["hidden_dim"]
NUM_LAYERS_CW  = config_lstm_CW["num_layers"]
BIDIRECTIONAL_CW = config_lstm_CW["bidirectional"]
NUM_LABELS_CW  = config_lstm_CW["num_labels"]
PAD_IDX_CW     = config_lstm_CW["pad_idx"]
MAX_SEQ_LEN_CW_MODEL = config_lstm_CW["max_seq_len"]

# Dropout is not read from the config; it is hard-coded here.
DROPOUT_CW     = 0.4
# Fall back to index 0 when the vocabulary has no "<unk>" entry.
UNK_IDX_CW     = stoi_CW.get("<unk>", 0)

# Embedding table size, derived from the largest token ID.
# Assumes the IDs in stoi_CW are laid out as 0..N-1.
VOCAB_SIZE_CW = max(stoi_CW.values()) + 1

lstm_model_CW = LSTMTextClassifier(
    vocab_size=VOCAB_SIZE_CW,
    embed_dim=EMBED_DIM_CW,
    hidden_dim=HIDDEN_DIM_CW,
    num_layers=NUM_LAYERS_CW,
    num_labels=NUM_LABELS_CW,
    dropout=DROPOUT_CW,
    bidirectional=BIDIRECTIONAL_CW,
    pad_idx=PAD_IDX_CW,
).to(DEVICE)

# Restore the trained weights and switch to inference mode.
lstm_model_CW.load_state_dict(model_state_lstm_CW)
lstm_model_CW.eval()

print("✅ LSTM CW 모델 로드 완료")

✅ LSTM V8 체크포인트 로드 (epoch=9)
장르 개수: 27
✅ LSTM low 모델 로드 완료


  ckpt_lstm_v8 = torch.load(LSTM_MODEL_PATH_V8, map_location=DEVICE)


### LSTM 예측 함수

In [46]:
@torch.no_grad()
def predict_genre_lstm(text: str, model, stoi_vocab, pad_idx_val, unk_idx_val, max_len_val, idx_to_genre_map, device=DEVICE):
    """Predict the genre of a movie plot summary with an LSTM model.

    Args:
        text: Plot summary to classify.
        model: Model called on the encoded ``(1, T)`` ID tensor; expected to
            return logits.
        stoi_vocab: Token-to-ID mapping used for encoding.
        pad_idx_val: Padding token ID.
        unk_idx_val: ID used for out-of-vocabulary tokens.
        max_len_val: Length the input is truncated/padded to.
        idx_to_genre_map: Lookup from predicted class index to genre label.
        device: Device the encoded input is placed on.

    Returns:
        ``(pred_label, confidence)`` where ``confidence`` is the softmax
        probability of the predicted class as a float.
    """
    batch = encode_text_lstm(
        text,
        stoi_vocab,
        pad_idx_val,
        unk_idx_val,
        max_len=max_len_val,
        device=device,
    )
    class_probs = torch.softmax(model(batch), dim=-1)

    best_class = int(class_probs.argmax(dim=-1).item())
    confidence = float(class_probs[0, best_class].item())
    return idx_to_genre_map[best_class], confidence

# Quick smoke test using the first LSTM model (lstm_model) and its vocabulary.
example = "A young boy discovers he has magical powers and attends a school for wizards."
label, conf = predict_genre_lstm(
    example,
    lstm_model,
    stoi,
    pad_idx_val=PAD_IDX,
    unk_idx_val=UNK_IDX,
    max_len_val=MAX_SEQ_LEN,
    idx_to_genre_map=idx_to_genre,
    device=DEVICE,
)

print("[LSTM high 테스트]")
print("입력:", example)
print("예측 장르:", label, f"(confidence={conf:.3f})")

[LSTM high 테스트]
입력: A young boy discovers he has magical powers and attends a school for wizards.
예측 장르: comedy (confidence=0.118)


### BERT 모델 로드

In [47]:
BERT_SAVE_DIR = "model/bert"

# Validate with an explicit exception rather than `assert`: assertions are
# removed when Python runs with -O, which would silently skip this check.
if not os.path.exists(BERT_SAVE_DIR):
    raise FileNotFoundError(f"BERT 저장 폴더를 찾을 수 없습니다: {BERT_SAVE_DIR}")

bert_tokenizer = AutoTokenizer.from_pretrained(BERT_SAVE_DIR)
bert_model = AutoModelForSequenceClassification.from_pretrained(BERT_SAVE_DIR)
bert_model.to(DEVICE)
bert_model.eval()

print("✅ BERT 모델 / 토크나이저 로드 완료:", BERT_SAVE_DIR)

# Normalize id2label so it can be indexed with int class IDs. The config may
# hold keys as strings ({'0': 'Action', ...}) or as ints ({0: 'Action', ...}).
raw_id2label = bert_model.config.id2label

id2label_bert = {}
for k, v in raw_id2label.items():
    try:
        idx = int(k)
    except (TypeError, ValueError):
        # Keep keys that are not integer-like exactly as they are.
        idx = k
    id2label_bert[idx] = v

NUM_LABELS_BERT = len(id2label_bert)
print("BERT 클래스 수:", NUM_LABELS_BERT)
print("예시 라벨 매핑:", list(id2label_bert.items())[:5])


✅ BERT 모델 / 토크나이저 로드 완료: model/bert_base_full
BERT 클래스 수: 27
예시 라벨 매핑: [(0, 'action'), (1, 'adult'), (2, 'adventure'), (3, 'animation'), (4, 'biography')]


### BERT 예측 함수

In [48]:
@torch.no_grad()
def predict_genre_bert(text: str, max_len=None):
    """Predict the genre of a movie plot summary with the BERT model.

    Args:
        text: Plot summary to classify.
        max_len: Optional tokenizer ``max_length`` (padding/truncation
            length). When ``None``, the module-level ``MAX_SEQ_LEN`` is used,
            which is the value the LSTM checkpoint config assigns. Pass an
            explicit value if BERT was trained with a different length.

    Returns:
        ``(pred_label, confidence)`` where ``confidence`` is the softmax
        probability of the predicted class as a float.
    """
    if max_len is None:
        # Resolved at call time so the current MAX_SEQ_LEN is picked up.
        max_len = MAX_SEQ_LEN

    enc = bert_tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt"
    )
    # Move every tokenizer output tensor to the model's device.
    enc = {k: v.to(DEVICE) for k, v in enc.items()}

    outputs = bert_model(**enc)
    logits = outputs.logits              # (1, num_labels)
    probs = torch.softmax(logits, dim=-1)
    pred_id = int(torch.argmax(probs, dim=-1).item())
    conf = float(probs[0, pred_id].item())
    pred_label = id2label_bert[pred_id]
    return pred_label, conf

# Quick smoke test: classify one sample plot with the BERT model.
example = "A young boy discovers he has magical powers and attends a school for wizards."
prediction = predict_genre_bert(example)
label, conf = prediction

print("[BERT 테스트]")
print("입력:", example)
print("예측 장르:", label, f"(confidence={conf:.3f})")


[BERT 테스트]
입력: A young boy discovers he has magical powers and attends a school for wizards.
예측 장르: fantasy (confidence=0.331)


### 예측 (LSTM / BERT 선택)

In [52]:
print("\n--- 영화 줄거리 -> 장르 예측 ---")
print("사용할 모델을 선택하고, 영화 줄거리를 입력하면 장르를 예측합니다.")
print("모델: 'lstm', 'lstm_CW' 또는 'bert'")
print("종료: 모델에서 'quit' 또는 빈 줄\n")

# Dispatch table from model choice to a one-argument predictor.
# Keys are lowercase because the user's choice is lowercased before lookup;
# comparing against "lstm_CW" made the CW model impossible to select.
predictors_by_choice = {
    "lstm": lambda plot: predict_genre_lstm(
        plot,
        model=lstm_model,
        stoi_vocab=stoi,
        pad_idx_val=PAD_IDX,
        unk_idx_val=UNK_IDX,
        max_len_val=MAX_SEQ_LEN,
        idx_to_genre_map=idx_to_genre,
        device=DEVICE,
    ),
    "lstm_cw": lambda plot: predict_genre_lstm(
        plot,
        model=lstm_model_CW,
        stoi_vocab=stoi_CW,
        pad_idx_val=PAD_IDX_CW,
        unk_idx_val=UNK_IDX_CW,
        max_len_val=MAX_SEQ_LEN_CW_MODEL,
        idx_to_genre_map=idx_to_genre_CW,
        device=DEVICE,
    ),
    "bert": lambda plot: predict_genre_bert(plot),
}

# Outer loop: pick a model. Inner loop: classify plots until an empty line.
while True:
    model_choice = input("사용할 모델 선택 (lstm / lstm_CW / bert, 종료: quit): ").strip().lower()
    if model_choice in ("quit", "exit", ""):
        print("종료합니다.")
        break

    predict = predictors_by_choice.get(model_choice)
    if predict is None:
        print("❗ 잘못된 입력입니다. 'lstm', 'lstm_CW' 또는 'bert' 중 하나를 입력하세요.\n")
        continue

    while True:
        text = input(f"[{model_choice.upper()}] 영화 줄거리 입력 (종료: 빈 줄): ").strip()
        if text == "":
            print(f"[{model_choice.upper()}] 입력 종료.\n")
            break

        label, conf = predict(text)
        print(f"→ 예측 장르: {label} (confidence={conf:.3f})\n")


--- 영화 줄거리 -> 장르 예측 ---
사용할 모델을 선택하고, 영화 줄거리를 입력하면 장르를 예측합니다.
모델: 'lstm_high', 'lstm_low' 또는 'bert'
종료: 모델에서 'quit' 또는 빈 줄

사용할 모델 선택 (lstm_high / lstm_low / bert, 종료: quit): lstm_high
[LSTM_HIGH] 영화 줄거리 입력 (종료: 빈 줄): Jack, a free-spirited painter, falls in love at first sight with Rose. A passionate love story.
→ 예측 장르: comedy (confidence=0.118)

[LSTM_HIGH] 영화 줄거리 입력 (종료: 빈 줄): 
[LSTM_HIGH] 입력 종료.

사용할 모델 선택 (lstm_high / lstm_low / bert, 종료: quit): lstm_low
[LSTM_LOW] 영화 줄거리 입력 (종료: 빈 줄): Jack, a free-spirited painter, falls in love at first sight with Rose. A passionate love story.
→ 예측 장르: comedy (confidence=0.075)

[LSTM_LOW] 영화 줄거리 입력 (종료: 빈 줄): 
[LSTM_LOW] 입력 종료.

사용할 모델 선택 (lstm_high / lstm_low / bert, 종료: quit): bert
[BERT] 영화 줄거리 입력 (종료: 빈 줄): Jack, a free-spirited painter, falls in love at first sight with Rose. A passionate love story.
→ 예측 장르: drama (confidence=0.815)

[BERT] 영화 줄거리 입력 (종료: 빈 줄): 
[BERT] 입력 종료.

사용할 모델 선택 (lstm_high / lstm_low / bert, 종료: quit): 
종료합니다.
