In [2]:
import torch.nn as nn
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast, Trainer, TrainingArguments
from transformers import DistilBertTokenizer
from transformers import BertModel
import torch
import pandas as pd
import os
import numpy as np

2023-09-08 18:25:07.827592: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
os.environ["WANDB_DISABLED"] = "true"

In [4]:
train = pd.read_csv("data/food_aging_train.csv")
test = pd.read_csv("data/food_aging_test.csv")

In [5]:
train_list_names = train["식품오타"].values.tolist()
train_list_labels = train["label"].values.tolist()

test_list_names = test["식품오타"].values.tolist()
test_list_labels = test["label"].values.tolist()

In [6]:
import torch.nn as nn
from transformers import DistilBertModel

class DistilBertClassifier(nn.Module):
    def __init__(self, num_labels=775):
        super(DistilBertClassifier, self).__init__()
        self.bert = DistilBertModel.from_pretrained('distilbert-base-multilingual-cased')
        self.pre_classifier = nn.Linear(self.bert.config.hidden_size, self.bert.config.hidden_size)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
        self.dropout = nn.Dropout(self.bert.config.attention_dropout)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = outputs[0]
        pooled_output = hidden_state[:, 0]  # take [CLS] token representation
        pooled_output = self.pre_classifier(pooled_output)
        pooled_output = nn.ReLU()(pooled_output)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        
        return logits

In [7]:
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-multilingual-cased')  #
model = DistilBertClassifier()

In [8]:
words = train_list_names 
labels = train_list_labels 

In [12]:
MAX_LENGTH = 100  # Adjust as needed

def tokenize_data(texts, labels):
    input_ids = []
    attention_masks = []

    for text in texts:
        encoding = tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=MAX_LENGTH,
            truncation=True,
            padding='max_length',
            return_attention_mask=True
        )
        input_ids.append(encoding['input_ids'])
        attention_masks.append(encoding['attention_mask'])

    return input_ids, attention_masks, labels

input_ids, attention_masks, labels = tokenize_data(words, labels)

In [13]:
num_labels = len(set(labels))
num_labels

775

In [14]:
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim

# Convert data to tensors
input_ids = torch.tensor(input_ids)
attention_masks = torch.tensor(attention_masks)
labels = torch.tensor(labels)

dataset = TensorDataset(input_ids, attention_masks, labels)
dataloader = DataLoader(dataset, batch_size=16)  # Adjust batch size as needed

# Setup GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Optimizer and loss
optimizer = optim.AdamW(model.parameters(), lr=1e-5)
loss_fn = nn.CrossEntropyLoss()

In [15]:
%%time 
# Training loop
NUM_EPOCHS = 400  # Adjust as needed

for epoch in range(NUM_EPOCHS):
    for batch in dataloader:
        batch_input_ids, batch_attention_masks, batch_labels = [b.to(device) for b in batch]
        
        optimizer.zero_grad()
        logits = model(batch_input_ids, batch_attention_masks)
        
        loss = loss_fn(logits, batch_labels)
        loss.backward()
        optimizer.step()

        print(f"Epoch: {epoch}, Loss: {loss.item()}")

Epoch: 0, Loss: 6.698608875274658
Epoch: 0, Loss: 6.645479202270508
Epoch: 0, Loss: 6.673892021179199
Epoch: 0, Loss: 6.634433269500732
Epoch: 0, Loss: 6.629505157470703
Epoch: 0, Loss: 6.651381492614746
Epoch: 0, Loss: 6.729428291320801
Epoch: 0, Loss: 6.660944938659668
Epoch: 0, Loss: 6.651851654052734
Epoch: 0, Loss: 6.6818389892578125
Epoch: 0, Loss: 6.646271228790283
Epoch: 0, Loss: 6.619174003601074
Epoch: 0, Loss: 6.710256099700928
Epoch: 0, Loss: 6.698944568634033
Epoch: 0, Loss: 6.623620986938477
Epoch: 0, Loss: 6.627994060516357
Epoch: 0, Loss: 6.6838860511779785
Epoch: 0, Loss: 6.6706976890563965
Epoch: 0, Loss: 6.630701065063477
Epoch: 0, Loss: 6.705967426300049
Epoch: 0, Loss: 6.646645545959473
Epoch: 0, Loss: 6.652819633483887
Epoch: 0, Loss: 6.633697986602783
Epoch: 0, Loss: 6.607198715209961
Epoch: 0, Loss: 6.613197326660156
Epoch: 0, Loss: 6.613039016723633
Epoch: 0, Loss: 6.6612420082092285
Epoch: 0, Loss: 6.704387664794922
Epoch: 0, Loss: 6.637092113494873
Epoch: 0, 

In [16]:
save_path = "distilBERT-ko-wikipedia-classifier.pth"
torch.save(model.state_dict(), save_path)

In [17]:
save_path = "distilBERT-ko-wikipedia-classifier.prm"
torch.save(model.state_dict(), save_path, pickle_protocol=4)