# IMDB 감정 분류 실습

In [1]:
!pip install torch torchtext datasets

Collecting torchtext
  Downloading torchtext-0.18.0-cp311-cp311-manylinux1_x86_64.whl.metadata (7.9 kB)
Collecting datasets
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting aiohttp (from datasets)
  Downloading aiohttp-3.10.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.5 kB)
Collecting aiohappyeyeballs>=2.3.0 (from aiohttp->datasets)
  Downloading aiohappyeyeballs-2.3.7-py3-none-any.whl.metadata (5.9 kB)
Collecting aiosignal>=1.1.2 (from aiohttp->datasets)
  Do

# 1. Data download

In [7]:
import torch
from torch import nn
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from datasets import load_dataset
from torch.utils.data import DataLoader

# Load the IMDB dataset and set up the tokenizer.
# `dataset` is indexed by split name and column later in the notebook
# (e.g. dataset['train']['text'], dataset['train']['label']).
dataset = load_dataset("imdb")
# torchtext's "basic_english" tokenizer; it is called on raw review text and its
# result is fed to the vocabulary.
tokenizer = get_tokenizer("basic_english")

  from .autonotebook import tqdm as notebook_tqdm


# 2. Pre-process & DataLoader

In [22]:
# Build the vocabulary
def yield_tokens(data_iter):
    """Lazily tokenize every text produced by ``data_iter``.

    Args:
        data_iter: Iterable of raw text strings.

    Yields:
        The result of calling the notebook-level ``tokenizer`` on each text,
        in the same order as the input.
    """
    yield from map(tokenizer, data_iter)

# Build the vocabulary from the tokenized training texts, registering "<unk>"
# as a special token.
vocab = build_vocab_from_iterator(yield_tokens(dataset['train']['text']), specials=["<unk>"])
# Tokens that are not in the vocabulary are mapped to the index of "<unk>"
# instead of raising an error on lookup.
vocab.set_default_index(vocab["<unk>"])

# Data preprocessing
def preprocess(text):
    """Convert one raw text into a 1-D tensor of vocabulary indices.

    Args:
        text: Raw text string.

    Returns:
        ``torch.long`` tensor holding one vocabulary index per token, obtained
        by tokenizing ``text`` and looking the tokens up in ``vocab``.
    """
    tokens = tokenizer(text)
    token_ids = vocab(tokens)
    return torch.tensor(token_ids, dtype=torch.long)

# DataLoader setup
def collate_fn(batch):
    """Turn a list of ``(text, label)`` pairs into padded model inputs.

    Args:
        batch: Sequence of ``(text, label)`` pairs.

    Returns:
        Tuple ``(padded, targets)``: ``padded`` stacks the encoded texts
        batch-first, padded to the longest text in the batch; ``targets`` is a
        ``torch.float`` tensor of the labels.
    """
    raw_texts, raw_labels = zip(*batch)
    encoded = list(map(preprocess, raw_texts))
    targets = torch.tensor(raw_labels, dtype=torch.float)
    # pad_sequence is called with its default padding value.
    # NOTE(review): that default is 0, which presumably is also the "<unk>"
    # index of the vocabulary — confirm this overlap is acceptable.
    padded = nn.utils.rnn.pad_sequence(encoded, batch_first=True)
    return padded, targets

# Pair every training text with its label and serve the pairs in shuffled
# mini-batches of 32; collate_fn converts each list of pairs into tensors.
train_loader = DataLoader(list(zip(dataset['train']['text'], dataset['train']['label'])), 
                          batch_size=32, collate_fn=collate_fn, shuffle=True)

# 3. Model
# 3-1. RNN

In [8]:
class RNNCell(nn.Module):
    """One time step of a tanh recurrent layer.

    The new hidden state is ``tanh(W_ih(x) + W_hh(hidden))``, where both
    projections are ``nn.Linear`` layers (each with its own bias).
    """

    def __init__(self, input_size, hidden_size):
        """Create the two projections.

        Args:
            input_size: Number of features of each input vector.
            hidden_size: Number of features of the hidden state.
        """
        super().__init__()
        self.hidden_size = hidden_size
        # Input-to-hidden first, hidden-to-hidden second: the creation order
        # determines both the parameter order and the seeded initial values.
        self.W_ih = nn.Linear(input_size, hidden_size)
        self.W_hh = nn.Linear(hidden_size, hidden_size)
        self.tanh = nn.Tanh()

    def forward(self, x, hidden):
        """Advance the hidden state by one step.

        Args:
            x: Input whose last dimension is ``input_size``.
            hidden: Previous hidden state whose last dimension is ``hidden_size``.

        Returns:
            The updated hidden state.
        """
        input_term = self.W_ih(x)
        state_term = self.W_hh(hidden)
        return self.tanh(input_term + state_term)

class RNN(nn.Module):
    """Sequence classifier: embedding -> RNNCell unrolled over time -> linear head."""

    def __init__(self, vocab_size, embedding_dim, hidden_size, output_size):
        """Build the layers.

        Args:
            vocab_size: Number of rows in the embedding table.
            embedding_dim: Size of each token embedding.
            hidden_size: Size of the recurrent hidden state.
            output_size: Number of values produced per sequence.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn_cell = RNNCell(embedding_dim, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Run the recurrence over the sequence and project the final state.

        Every time step of the batch is fed to the cell, so for sequences that
        were padded the final state has also consumed the padding positions.

        Args:
            x: Integer token indices of shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, output_size)`` computed from the hidden
            state after the last time step.
        """
        embedded = self.embedding(x)
        # Create the initial state directly with the embeddings' dtype and
        # device. A plain torch.zeros(...) is always float32 on the CPU, which
        # breaks models converted with .double()/.half() and needs an extra
        # host-to-device copy on GPU.
        hidden = torch.zeros(
            embedded.size(0),
            self.hidden_size,
            dtype=embedded.dtype,
            device=embedded.device,
        )
        for t in range(embedded.size(1)):
            hidden = self.rnn_cell(embedded[:, t, :], hidden)
        return self.fc(hidden)

# 3-2. LSTM

In [None]:
class LSTMCell(nn.Module):
    """One time step of an LSTM built from separate ``nn.Linear`` projections.

    For each of the four gates (``i``, ``f``, ``g``, ``o``) there is an
    input projection ``W_i<gate>`` and a hidden projection ``W_h<gate>``.
    """

    def __init__(self, input_size, hidden_size):
        """Create the eight gate projections.

        Args:
            input_size: Number of features of each input vector.
            hidden_size: Number of features of the hidden and cell states.
        """
        super().__init__()
        self.hidden_size = hidden_size

        # Registration order (all input projections, then all hidden
        # projections, gates in i/f/g/o order) fixes the parameter order and
        # the seeded initial values.
        for gate in "ifgo":
            setattr(self, f"W_i{gate}", nn.Linear(input_size, hidden_size))
        for gate in "ifgo":
            setattr(self, f"W_h{gate}", nn.Linear(hidden_size, hidden_size))

        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(self, x, hidden, cell):
        """Advance the hidden and cell states by one step.

        Args:
            x: Input whose last dimension is ``input_size``.
            hidden: Previous hidden state.
            cell: Previous cell state.

        Returns:
            Tuple ``(hidden, cell)`` with the updated states.
        """
        input_gate = self.sigmoid(self.W_ii(x) + self.W_hi(hidden))
        forget_gate = self.sigmoid(self.W_if(x) + self.W_hf(hidden))
        candidate = self.tanh(self.W_ig(x) + self.W_hg(hidden))
        output_gate = self.sigmoid(self.W_io(x) + self.W_ho(hidden))

        # Keep part of the old memory and add the gated candidate values.
        next_cell = forget_gate * cell + input_gate * candidate
        next_hidden = output_gate * self.tanh(next_cell)
        return next_hidden, next_cell

class LSTM(nn.Module):
    """Sequence classifier: embedding -> LSTMCell unrolled over time -> linear head."""

    def __init__(self, vocab_size, embedding_dim, hidden_size, output_size):
        """Build the layers.

        Args:
            vocab_size: Number of rows in the embedding table.
            embedding_dim: Size of each token embedding.
            hidden_size: Size of the hidden and cell states.
            output_size: Number of values produced per sequence.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm_cell = LSTMCell(embedding_dim, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Run the recurrence over the sequence and project the final hidden state.

        Every time step of the batch is fed to the cell, so for sequences that
        were padded the final state has also consumed the padding positions.

        Args:
            x: Integer token indices of shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, output_size)`` computed from the hidden
            state after the last time step.
        """
        embedded = self.embedding(x)
        # Create both initial states directly with the embeddings' dtype and
        # device. A plain torch.zeros(...) is always float32 on the CPU, which
        # breaks models converted with .double()/.half() and needs an extra
        # host-to-device copy on GPU.
        state_shape = (embedded.size(0), self.hidden_size)
        hidden = torch.zeros(state_shape, dtype=embedded.dtype, device=embedded.device)
        cell = torch.zeros(state_shape, dtype=embedded.dtype, device=embedded.device)
        for t in range(embedded.size(1)):
            hidden, cell = self.lstm_cell(embedded[:, t, :], hidden, cell)
        return self.fc(hidden)

# 3-3. GRU

In [None]:
class GRUCell(nn.Module):
    """One time step of a GRU built from separate ``nn.Linear`` projections.

    Each of the reset (``r``), update (``z``) and candidate (``n``) paths has
    an input projection ``W_i*`` and a hidden projection ``W_h*``.
    """

    def __init__(self, input_size, hidden_size):
        """Create the six projections.

        Args:
            input_size: Number of features of each input vector.
            hidden_size: Number of features of the hidden state.
        """
        super().__init__()
        self.hidden_size = hidden_size

        # Projections applied to the current input.
        self.W_ir = nn.Linear(input_size, hidden_size)
        self.W_iz = nn.Linear(input_size, hidden_size)
        self.W_in = nn.Linear(input_size, hidden_size)

        # Projections applied to the previous hidden state.
        self.W_hr = nn.Linear(hidden_size, hidden_size)
        self.W_hz = nn.Linear(hidden_size, hidden_size)
        self.W_hn = nn.Linear(hidden_size, hidden_size)

        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(self, x, hidden):
        """Advance the hidden state by one step.

        Args:
            x: Input whose last dimension is ``input_size``.
            hidden: Previous hidden state.

        Returns:
            The updated hidden state.
        """
        reset_gate = self.sigmoid(self.W_ir(x) + self.W_hr(hidden))
        update_gate = self.sigmoid(self.W_iz(x) + self.W_hz(hidden))
        # The reset gate scales the hidden contribution to the candidate.
        candidate = self.tanh(self.W_in(x) + reset_gate * self.W_hn(hidden))
        # Blend the candidate with the previous state using the update gate.
        return (1 - update_gate) * candidate + update_gate * hidden

class GRU(nn.Module):
    """Sequence classifier: embedding -> GRUCell unrolled over time -> linear head."""

    def __init__(self, vocab_size, embedding_dim, hidden_size, output_size):
        """Build the layers.

        Args:
            vocab_size: Number of rows in the embedding table.
            embedding_dim: Size of each token embedding.
            hidden_size: Size of the recurrent hidden state.
            output_size: Number of values produced per sequence.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.gru_cell = GRUCell(embedding_dim, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Run the recurrence over the sequence and project the final state.

        Every time step of the batch is fed to the cell, so for sequences that
        were padded the final state has also consumed the padding positions.

        Args:
            x: Integer token indices of shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, output_size)`` computed from the hidden
            state after the last time step.
        """
        embedded = self.embedding(x)
        # Create the initial state directly with the embeddings' dtype and
        # device. A plain torch.zeros(...) is always float32 on the CPU, which
        # breaks models converted with .double()/.half() and needs an extra
        # host-to-device copy on GPU.
        hidden = torch.zeros(
            embedded.size(0),
            self.hidden_size,
            dtype=embedded.dtype,
            device=embedded.device,
        )
        for t in range(embedded.size(1)):
            hidden = self.gru_cell(embedded[:, t, :], hidden)
        return self.fc(hidden)

# 4. Model, hyperparameters, loss & optimizer

In [None]:
# Seed PyTorch's global RNG before any weights are created so that parameter
# initialisation (and anything else drawing from the global generator later,
# such as DataLoader shuffling) is repeatable after "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
vocab_size = len(vocab)  # one embedding row per vocabulary entry
embedding_dim = 128
hidden_size = 128
output_size = 1  # a single logit per review, as expected by BCEWithLogitsLoss

## Choose one of RNN, LSTM, or GRU
# model = RNN(vocab_size=vocab_size, embedding_dim=embedding_dim, hidden_size=hidden_size, output_size=output_size).to(device)
# model = LSTM(vocab_size=vocab_size, embedding_dim=embedding_dim, hidden_size=hidden_size, output_size=output_size).to(device)
model = GRU(vocab_size=vocab_size, embedding_dim=embedding_dim, hidden_size=hidden_size, output_size=output_size).to(device)

# Binary cross-entropy on raw logits; Adam with its default settings.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters())

# 5. train + (val) <- 직접 구현해보기

In [13]:
num_epochs = 10
log_interval = 100  # number of batches between intermediate progress lines

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0  # running sum of per-batch losses for the epoch average
    for batch_idx, (texts, labels) in enumerate(train_loader):
        texts, labels = texts.to(device), labels.to(device)
        outputs = model(texts)
        # Drop only the trailing size-1 output dimension: (batch, 1) -> (batch,).
        # A bare .squeeze() would also remove the batch dimension when a batch
        # contains a single sample, leaving a 0-d tensor whose shape no longer
        # matches `labels` and making the loss raise.
        loss = criterion(outputs.squeeze(-1), labels)

        # Standard update: clear old gradients, backpropagate, apply the step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        if (batch_idx + 1) % log_interval == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

    print(f'Epoch [{epoch+1}/{num_epochs}] Complete, Average Loss: {epoch_loss/len(train_loader):.4f}')

Epoch [1/10], Step [10/782], Loss: 0.7042
Epoch [1/10], Step [20/782], Loss: 0.7562
Epoch [1/10], Step [30/782], Loss: 0.7053
Epoch [1/10], Step [40/782], Loss: 0.6875
Epoch [1/10], Step [50/782], Loss: 0.7044
Epoch [1/10], Step [60/782], Loss: 0.7264
Epoch [1/10], Step [70/782], Loss: 0.7282
Epoch [1/10], Step [80/782], Loss: 0.6826
Epoch [1/10], Step [90/782], Loss: 0.6959
Epoch [1/10], Step [100/782], Loss: 0.6926
Epoch [1/10], Step [110/782], Loss: 0.6994
Epoch [1/10], Step [120/782], Loss: 0.6947
Epoch [1/10], Step [130/782], Loss: 0.7501
Epoch [1/10], Step [140/782], Loss: 0.6793
Epoch [1/10], Step [150/782], Loss: 0.6980
Epoch [1/10], Step [160/782], Loss: 0.6915
Epoch [1/10], Step [170/782], Loss: 0.7038
Epoch [1/10], Step [180/782], Loss: 0.6863
Epoch [1/10], Step [190/782], Loss: 0.7038
Epoch [1/10], Step [200/782], Loss: 0.6918
Epoch [1/10], Step [210/782], Loss: 0.7034
Epoch [1/10], Step [220/782], Loss: 0.6911
Epoch [1/10], Step [230/782], Loss: 0.7081
Epoch [1/10], Step [

KeyboardInterrupt: 