In [1]:
import numpy as np
import scipy
import pandas as pd
import torch
import matplotlib.pyplot as plt

from data import load_data

In [2]:
class ConvRNN(torch.nn.Module):
    
    def __init__(self, num_classes):
        super().__init__()
        
        self.conv = torch.nn.Conv2d(1, 256, kernel_size=(12, 5), padding=(0, int(5 / 2)))
        self.gru = torch.nn.RNN(256, 256, batch_first=True)
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(256, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, num_classes),
            torch.nn.ReLU()
        )
        
    def forward(self, x):
        x = self.conv(x)                # (N, 1, H, W) => (N, 256, 1, W)
        x = x.squeeze(2)                # (N, 256, 1, W) => (N, 256, W)
        x = torch.permute(x, (0, 2, 1)) # (N, 256, W) =>  (N, W, 256)
        _, hn = self.gru(x)             # (1, N, 256)
        hn = hn.squeeze(0)              # (1, N, 256) => (N, 256)
        x = self.fc(hn)                 # (N, 256) => (N, 12)
        return torch.nn.functional.softmax(x, dim=1)

In [3]:
X, Y = load_data(1)
def reshape_for_cnn(X):
    """
    Reshapes the data from (X_LEN, B_SZ, X_DIM) (e.g. (5000, 10, 12))
    to (B_SZ, 1, X_DIM, X_LEN), because we have 1 channel.
    """
    W, N, H = X.shape
    X = torch.permute(X, (1, 2, 0))
    X = torch.reshape(X, (N, 1, H, W))
    return X

X = reshape_for_cnn(X)

N, C, H, W = X.shape

mdl = ConvRNN(12)
optimizer = torch.optim.Adam(mdl.parameters(), lr = 1e-3, weight_decay = 1e-8)

# Train
N_EPOCHS = 10
BATCH_SZ = 1
for e_i in range(N_EPOCHS):
    
    # shuffle the indices
    idx = np.arange(N)
    np.random.shuffle(idx)
    
    # iterate over batches
    epoch_loss = 0.0
    num_batches = N // BATCH_SZ
    for b_i in range(num_batches):
        
        # Make batch.
        X_batch = X[idx[b_i * BATCH_SZ : (b_i + 1) * BATCH_SZ],:,:,:]
        Y_batch = Y[idx[b_i * BATCH_SZ : (b_i + 1) * BATCH_SZ],:]
        
        # Zero gradients.
        optimizer.zero_grad()
        
        # Forward pass.
        pred = mdl.forward(X_batch)
        
        # Calculate loss.
        loss = torch.nn.functional.cross_entropy(pred, Y_batch)
        
        # Compute gradients.
        loss.backward()
        
        # Update parameters.
        optimizer.step()
    
    # Print Loss.
    print("Epoch #{0} Loss:{1}".format(e_i + 1, loss.item()))

Loading 1 Examples.


0it [00:00, ?it/s]


Epoch #1 Loss:2.481394052505493
Epoch #2 Loss:2.442871332168579
Epoch #3 Loss:2.389258623123169
Epoch #4 Loss:2.30055570602417
Epoch #5 Loss:2.1826696395874023
Epoch #6 Loss:2.0382473468780518
Epoch #7 Loss:1.881241798400879
Epoch #8 Loss:1.7519906759262085
Epoch #9 Loss:1.6737626791000366
Epoch #10 Loss:1.6384851932525635


In [None]:
with torch.no_grad():
    # Test the model.