<a href="https://colab.research.google.com/github/yunju-1118/EWHA/blob/main/seq_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **RNN**

In [8]:
import torch
import torch.nn as nn
import torch.functional as F

import torchvision.transforms as transforms
from torchvision.datasets import MNIST, CIFAR10
from torch.utils.data import DataLoader

## **MNIST**

In [10]:
import numpy as np

# Every MNIST sample is converted to a float tensor; no other preprocessing.
transform = transforms.Compose([transforms.ToTensor()])

# Location where torchvision stores (and downloads) the dataset files.
path = './datasets/'
batch_size = 100

train_data = MNIST(
    root=path,
    train=True,
    transform=transform,
    download=True,
)
test_data = MNIST(
    root=path,
    train=False,
    transform=transform,
    download=True,
)

# Only the training split is shuffled; evaluation order stays fixed.
train_loader = DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = DataLoader(
    dataset=test_data,
    batch_size=batch_size,
    shuffle=False,
)

# The first sample's tensor shape is unpacked as (ignored, seq_len, input_size):
# the middle axis is used as the sequence length and the last axis as the
# per-step feature count.
_, seq_len, input_size = train_data[0][0].shape
# One output unit per class label.
output_shape = len(train_data.classes)

In [11]:
# Which classifier family to build: "rnn", "lstm" or "gru".
model_name = "rnn"

# The recurrent hidden state is twice as wide as one input step.
hidden_size = 2 * input_size

**RNN Cell** 이용

In [12]:
class RNNClassifier(nn.Module):
    """Sequence classifier that steps an ``nn.RNNCell`` over its input.

    Every sample is viewed as a sequence of steps with ``input_size``
    features each. The hidden state left after the last step is mapped to
    class scores by a linear layer.

    Args:
        input_size: Number of features in one time step.
        hidden_size: Width of the recurrent hidden state.
        output_size: Number of output classes. When omitted, the
            notebook-level ``output_shape`` is used, which is what the
            original implementation always read.
    """

    def __init__(self, input_size, hidden_size, output_size=None):
        super().__init__()

        if output_size is None:
            # Backward-compatible default: the global set by the data cell.
            output_size = output_shape

        self.input_size = input_size
        self.hidden_size = hidden_size

        self.cell = nn.RNNCell(input_size=self.input_size,
                               hidden_size=self.hidden_size)

        self.fc = nn.Linear(self.hidden_size, output_size)

    def forward(self, x):
        """Compute class logits for a batch.

        Args:
            x: Tensor whose first dimension is the batch; the remaining
                elements of each sample are viewed as
                ``(seq_len, input_size)``.

        Returns:
            Tensor of shape ``(batch, output_size)``.
        """
        # Use the actual batch size so a smaller final batch also works.
        n_samples = x.size(0)

        # (batch, ...) -> (batch, seq, feature) -> (seq, batch, feature);
        # the sequence length is inferred from the data rather than from a
        # notebook global that later cells overwrite.
        steps = x.reshape(n_samples, -1, self.input_size).permute(1, 0, 2)

        # new_zeros keeps the state on the same device and dtype as the input.
        hidden_state = x.new_zeros(n_samples, self.hidden_size)
        for step in steps:
            hidden_state = self.cell(step, hidden_state)

        return self.fc(hidden_state)

In [13]:
class LSTMClassifier(nn.Module):
    """Sequence classifier that steps an ``nn.LSTMCell`` over its input.

    Every sample is viewed as a sequence of steps with ``input_size``
    features each. The hidden state left after the last step is mapped to
    class scores by a linear layer.

    Args:
        input_size: Number of features in one time step.
        output_size: Number of output classes produced by the final linear
            layer. NOTE(review): the original declared this parameter but
            ignored it and read the global ``output_shape`` instead; callers
            that passed a different value here should be checked.
        hidden_size: Width of the hidden and cell states. When omitted, the
            notebook-level ``hidden_size`` is used, which is what the
            original implementation always read.
    """

    def __init__(self, input_size, output_size, hidden_size=None):
        super().__init__()

        if hidden_size is None:
            # The parameter shadows the notebook global of the same name, so
            # the backward-compatible default is looked up explicitly.
            hidden_size = globals()["hidden_size"]

        self.input_size = input_size
        self.hidden_size = hidden_size

        # An LSTM needs LSTMCell; RNNCell has no cell state to return.
        self.cell = nn.LSTMCell(input_size=self.input_size,
                                hidden_size=self.hidden_size)

        self.fc = nn.Linear(self.hidden_size, output_size)

    def forward(self, x):
        """Compute class logits for a batch.

        Args:
            x: Tensor whose first dimension is the batch; the remaining
                elements of each sample are viewed as
                ``(seq_len, input_size)``.

        Returns:
            Tensor of shape ``(batch, output_size)``.
        """
        # Use the actual batch size so a smaller final batch also works.
        n_samples = x.size(0)

        # (batch, ...) -> (batch, seq, feature) -> (seq, batch, feature).
        steps = x.reshape(n_samples, -1, self.input_size).permute(1, 0, 2)

        # new_zeros keeps both states on the input's device and dtype.
        hidden_state = x.new_zeros(n_samples, self.hidden_size)
        cell_state = x.new_zeros(n_samples, self.hidden_size)
        for step in steps:
            # LSTMCell takes the two states as a single (h, c) tuple.
            hidden_state, cell_state = self.cell(step,
                                                 (hidden_state, cell_state))

        return self.fc(hidden_state)

## **CIFAR10**

In [14]:
import numpy as np

# Every CIFAR10 sample is converted to a float tensor; no other preprocessing.
transform = transforms.Compose([transforms.ToTensor()])

# Location where torchvision stores (and downloads) the dataset files.
path = './datasets/'
batch_size = 100

train_data = CIFAR10(
    root=path,
    train=True,
    transform=transform,
    download=True,
)
test_data = CIFAR10(
    root=path,
    train=False,
    transform=transform,
    download=True,
)

# Only the training split is shuffled; evaluation order stays fixed.
train_loader = DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = DataLoader(
    dataset=test_data,
    batch_size=batch_size,
    shuffle=False,
)

100%|██████████| 170M/170M [00:04<00:00, 42.3MB/s]


In [15]:
# The first sample's tensor shape is unpacked as
# (num_channel, seq_len, input_size).
num_channel, seq_len, input_size = train_data[0][0].shape
# One output unit per class label.
output_shape = len(train_data.classes)

# Recurrent network configuration.
hidden_size = 2 * input_size
num_layers = 3
batch_first = True  # put the batch dimension first
bidirectional = True

In [None]:
# Pick the classifier class that matches the configured model name. The
# branches stay lazy so only the selected class has to be defined.
if model_name == "rnn":
    classifier = RNNClassifier
elif model_name == "lstm":
    classifier = LSTMClassifier
elif model_name == "gru":
    classifier = GRUClassifier
else:
    # Fail loudly instead of leaving `classifier` unbound, or silently
    # reusing a value left over from an earlier notebook run.
    raise ValueError(f"Unknown model_name: {model_name!r}")

In [None]:
class RNNClassifier(nn.Module):
    """Image classifier built on a (possibly stacked, bidirectional) ``nn.RNN``.

    The channel axis of each image is folded into the last axis, so one time
    step carries ``input_size * num_channel`` features. The final hidden
    state of the top layer is mapped to class scores by a linear layer.

    Args:
        num_channel: Number of channels per image.
        input_size: Size of the last image axis before the channels are
            folded in.
        hidden_size: Width of the recurrent hidden state per direction.
        num_layers: Number of stacked recurrent layers.
        batch_first: Layout flag forwarded to ``nn.RNN``. ``forward`` always
            takes batch-first images and transposes internally when this is
            ``False``.
        bidirectional: Whether to run the RNN in both directions.
        output_size: Number of output classes. When omitted, the
            notebook-level ``output_shape`` is used, which is what the
            original implementation always read.
    """

    def __init__(self, num_channel, input_size, hidden_size, num_layers=1,
                 batch_first=True, bidirectional=False, output_size=None):
        super().__init__()

        if output_size is None:
            # Backward-compatible default: the global set by the data cell.
            output_size = output_shape

        self.num_channel = num_channel
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        # Number of directions; scales the feature width fed to the head.
        self.direction = 2 if self.bidirectional else 1

        self.seq = nn.RNN(input_size=self.input_size * self.num_channel,
                          hidden_size=self.hidden_size,
                          num_layers=self.num_layers,
                          batch_first=self.batch_first,
                          bidirectional=self.bidirectional)
        self.fc = nn.Linear(self.hidden_size * self.direction, output_size)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Tensor with four axes; the code moves axis 1 to the end, so
                it is treated as ``(batch, channel, seq, input_size)``.

        Returns:
            Tensor of shape ``(batch, output_size)``.
        """
        # Use the actual batch size so a smaller final batch also works.
        n_samples = x.size(0)

        # Move channels last, then fold them into the feature axis:
        # (batch, channel, seq, feat) -> (batch, seq, feat * channel).
        sequence = x.permute(0, 2, 3, 1).reshape(
            n_samples, -1, self.input_size * self.num_channel)
        if not self.batch_first:
            # nn.RNN then expects (seq, batch, feature).
            sequence = sequence.transpose(0, 1)

        # With no initial state given, nn.RNN starts from zeros on the
        # input's device, so no global batch size or device is needed.
        _, hidden = self.seq(sequence)

        # `hidden` is (num_layers * directions, batch, hidden_size) whatever
        # batch_first is. For a single direction, hidden[-1] equals the
        # top-layer output at the last step. For two directions the final
        # states of both are concatenated: the backward half of the last
        # output step has only seen one element.
        if self.bidirectional:
            features = torch.cat((hidden[-2], hidden[-1]), dim=1)
        else:
            features = hidden[-1]

        return self.fc(features)