In [12]:
# !pip install nltk
# !pip install accelerate
# !pip install datasets

In [13]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import nltk
from nltk.corpus import movie_reviews

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

import torch
from torch import nn
from torch.utils.data import DataLoader

from datasets import load_metric

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import Trainer, TrainingArguments
from transformers import BertModel

In [2]:
# nltk.download("movie_reviews")
fieldids = movie_reviews.fileids()
reviews = [ movie_reviews.raw(fieldid) for fieldid in fieldids ]
categories = [ movie_reviews.categories(fieldid) for fieldid in fieldids ]

In [4]:
# labeled_categories = LabelEncoder().fit_transform(categories)
labeled_categories = LabelEncoder().fit_transform(categories).tolist()
labeled_categories[:5]

  y = column_or_1d(y, warn=True)


[0, 0, 0, 0, 0]

In [5]:
X_train, X_test ,y_train, y_test = \
    train_test_split(reviews, labeled_categories, stratify=labeled_categories, random_state=42)

In [6]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

train_inputs = tokenizer(X_train, truncation=True, padding=True, return_tensors='pt')
test_inputs = tokenizer(X_test, truncation=True, padding=True, return_tensors='pt')

In [7]:
class ReviewDataset(torch.utils.data.Dataset):
    def __init__(self, inputs, labels):
        self.inputs = inputs
        self.labels = labels
        pass

    def __getitem__(self, idx):
        input = { k: torch.tensor(v[idx]) for k, v in self.inputs.items() }
        input['label'] = torch.tensor(self.labels[idx])
        return input
    
    def __len__(self):
        return len(self.labels)
    
train_dataset = ReviewDataset(train_inputs, y_train)
test_dataset = ReviewDataset(test_inputs, y_test)

In [8]:
device_name = "cuda:0" if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
device

device(type='cuda', index=0)

In [12]:
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=8)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=8)

In [14]:
bert_model = BertModel.from_pretrained('bert-base-uncased')

In [16]:
class ReviewClassificationModel(nn.Module):
    def __init__(self, pretrained_model, token_size, num_labels): 
        super(ReviewClassificationModel, self).__init__()
        self.token_size = token_size
        self.num_labels = num_labels
        
        self.pretrained_model = pretrained_model
        self.classifier = nn.Linear(self.token_size, self.num_labels)

    def forward(self, inputs):
        outputs = self.pretrained_model(**inputs)

        bert_clf_token = outputs.last_hidden_state[:,0,:]
        return self.classifier(bert_clf_token)

model = ReviewClassificationModel(bert_model, num_labels=2, token_size=bert_model.config.hidden_size) 


In [18]:
from transformers import AdamW
import torch.nn.functional as F
import time

model.to(device)
model.train() # 스위치 : 가중치 업데이터 활성화 모드

optim = AdamW(model.parameters(), lr=5e-5) 
criterion = torch.nn.CrossEntropyLoss() 

start = time.time()
num_epochs = 4
for epoch in range(num_epochs):
    total_epoch_loss = 0
    for step, batch in enumerate(train_loader):
        optim.zero_grad() # 기울기 초기화
        inputs = {k: v.to(device) for k, v in batch.items() if k != 'label'} 
        labels = batch['label'].to(device)
        outputs = model(inputs)

        loss = criterion(outputs, F.one_hot(labels, num_classes=2).float()) # 손실 계산

        if (step+1) % 100 == 0:
            elapsed = time.time() - start
            print(f'Epoch {epoch+1}, batch {step+1}, elapsed time: {elapsed}, loss: {loss}')
        total_epoch_loss += loss
        loss.backward() # 기울기 계산
        optim.step() # 가중치 업데이트
    avg_epoch_loss = total_epoch_loss / len(train_loader)
    print(f'Average loss for epoch {epoch+1}: {avg_epoch_loss}')

  input = { k: torch.tensor(v[idx]) for k, v in self.inputs.items() }


KeyboardInterrupt: 