<a href="https://colab.research.google.com/github/rbdus0715/Machine-Learning/blob/main/study/torch/8.lstm_generator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')
!unzip -qq "/content/drive/MyDrive/lstm.zip"

Mounted at /content/drive


In [4]:
import pandas as pd
import os
import string

df = pd.read_csv("./ArticlesApril2017.csv")
print(df.columns)
df.head(3)

Index(['abstract', 'articleID', 'articleWordCount', 'byline', 'documentType',
       'headline', 'keywords', 'multimedia', 'newDesk', 'printPage', 'pubDate',
       'sectionName', 'snippet', 'source', 'typeOfMaterial', 'webURL'],
      dtype='object')


Unnamed: 0,abstract,articleID,articleWordCount,byline,documentType,headline,keywords,multimedia,newDesk,printPage,pubDate,sectionName,snippet,source,typeOfMaterial,webURL
0,,58def1347c459f24986d7c80,716,By STEPHEN HILTNER and SUSAN LEHMAN,article,Finding an Expansive View of a Forgotten Peop...,"['Photography', 'New York Times', 'Niger', 'Fe...",3,Insider,2,2017-04-01 00:15:41,Unknown,One of the largest photo displays in Times his...,The New York Times,News,https://www.nytimes.com/2017/03/31/insider/nig...
1,,58def3237c459f24986d7c84,823,By GAIL COLLINS,article,"And Now, the Dreaded Trump Curse","['United States Politics and Government', 'Tru...",3,OpEd,23,2017-04-01 00:23:58,Unknown,Meet the gang from under the bus.,The New York Times,Op-Ed,https://www.nytimes.com/2017/03/31/opinion/and...
2,,58def9f57c459f24986d7c90,575,By THE EDITORIAL BOARD,article,Venezuela’s Descent Into Dictatorship,"['Venezuela', 'Politics and Government', 'Madu...",3,Editorial,22,2017-04-01 00:53:06,Unknown,A court ruling annulling the legislature’s aut...,The New York Times,Editorial,https://www.nytimes.com/2017/03/31/opinion/ven...


In [9]:
import numpy as np
import glob
from torch.utils.data.dataset import Dataset

class TextGeneration(Dataset):
    def clean_text(self, txt):
        # 모두 소문자 / 특수 문자 제거
        txt = "".join(v for v in txt if v not in string.punctuation).lower()
        return txt

    def __init__(self):
        all_headlines = []

        for filename in glob.glob("./*.csv"):
            if 'Articles' in filename:
                article_df = pd.read_csv(filename)
                all_headlines.extend(list(article_df.headline.values))
                break

        # unknown 제거
        all_headlines = [h for h in all_headlines if h != "Unknown"]

        # 모든 문장
        self.corpus = [self.clean_text(x) for x in all_headlines]
        self.BOW = {}

        for line in self.corpus:
            for word in line.split():
                if word not in self.BOW.keys():
                    self.BOW[word] = len(self.BOW.keys())

        self.data = self.generate_sequence(self.corpus)

    def generate_sequence(self, txt):
        seq = []

        for line in txt:
            line = line.split()
            line_bow = [self.BOW[word] for word in line]

            data = [
                (
                    # 입력: 단어 2개, 정답: 1개
                    [line_bow[i], line_bow[i+1]], line_bow[i+2]
                ) for i in range(len(line_bow)-2)
            ]

            seq.extend(data)

        return seq

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        data = np.array(self.data[i][0])
        label = np.array(self.data[i][1]).astype(np.float32)
        return data, label

In [10]:
import torch
import torch.nn as nn

class LSTM(nn.Module):
    def __init__(self, num_embeddings):
        super(LSTM, self).__init__()

        # num_embeddings: 전체 단어의 개수
        # embedding_dim: 몇 차원으로 압축할 것인가
        self.embed = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=16)

        self.lstm = nn.LSTM(input_size=16, hidden_size=64, num_layers=5, batch_first=True)
        self.fc1 = nn.Linear(128, num_embeddings)
        self.fc2 = nn.Linear(num_embeddings, num_embeddings)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.embed(x)
        x, _ = self.lstm(x)
        x = torch.reshape(x, (x.shape[0], -1))
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)

        return x


In [None]:
import tqdm

from torch.utils.data.dataloader import DataLoader
from torch.optim.adam import Adam

# 학습을 진행할 프로세서 정의
device = "cuda" if torch.cuda.is_available() else "cpu"

dataset = TextGeneration()  # 데이터셋 정의
model = LSTM(num_embeddings=len(dataset.BOW)).to(device)  # 모델 정의
loader = DataLoader(dataset, batch_size=64)
optim = Adam(model.parameters(), lr=0.001)

for epoch in range(200):
   iterator = tqdm.tqdm(loader)
   for data, label in iterator:
       # 기울기 초기화
       optim.zero_grad()

       # 모델의 예측값
       pred = model(torch.tensor(data, dtype=torch.long).to(device))

       # 정답 레이블은 long 텐서로 반환해야 함
       loss = nn.CrossEntropyLoss()(
           pred, torch.tensor(label, dtype=torch.long).to(device))

       # 오차 역전파
       loss.backward()
       optim.step()

       iterator.set_description(f"epoch{epoch} loss:{loss.item()}")

torch.save(model.state_dict(), "lstm.pth")

In [25]:
string = input("your input: ")

def generate(model, BOW, string="finding an ", strlen=10):
   device = "cuda" if torch.cuda.is_available() else "cpu"

   print(f"input word: {string}")

   with torch.no_grad():
       for p in range(strlen):
           # 입력 문장을 텐서로 변경
           words = torch.tensor(
               [BOW[w] for w in string.split()], dtype=torch.long).to(device)

           input_tensor = torch.unsqueeze(words[-2:], dim=0)
           output = model(input_tensor)  # 모델을 이용해 예측
           output_word = (torch.argmax(output).cpu().numpy())
           string += list(BOW.keys())[output_word]  # 문장에 예측된 단어를 추가
           string += " "

   print(f"predicted sentence: {string}")

model.load_state_dict(torch.load("lstm.pth", map_location=device))
pred = generate(model, dataset.BOW, string=string, strlen=len(string))

your input: over safety 
input word: over safety 
predicted sentence: over safety concerns and ‘living season 3 episode 2 swedish physicists a plates rationale 
