In [57]:
# Link to paper: https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=9971386
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import torch.nn.functional as F
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

from isar_dataset import ISARDataset

## Model Definitions

In [58]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
    
class CNN(nn.Module):
    def __init__(self, in_channels=1):
        super().__init__()
        self.cnn_layers = nn.Sequential(
            ConvBlock(in_channels, 8),
            ConvBlock(8, 8),
            ConvBlock(8, 8),
            ConvBlock(8, 8),
            nn.MaxPool2d(kernel_size=2, stride=2),

            ConvBlock(8, 16),
            ConvBlock(16, 16),
            ConvBlock(16, 16),
            nn.MaxPool2d(kernel_size=2, stride=2),

            ConvBlock(16, 32),
            ConvBlock(32, 32),
            ConvBlock(32, 32),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

    def forward(self, x):
        x = self.cnn_layers(x)
        return x

In [59]:
class BiLSTM(nn.Module):
    def __init__(self, hidden_size, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = 2
        self.lstm = nn.LSTM(7200, self.hidden_size, self.num_layers, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(self.hidden_size * 2, num_classes)

    def forward(self, x):
        # Initialize hidden and cell states (only needs to be done once)
        h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device) 
        c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)

        # Forward pass through LSTM
        out, _ = self.lstm(x, (h0, c0))  

        # Extract last hidden state
        return out[:, -1, :]

class CNN_BiLSTM(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.cnn = CNN()
        hidden_size = 1000
        self.bilstm = BiLSTM(hidden_size, num_classes)
        self.fc1 = nn.Linear(2 * hidden_size, 100)
        self.fc2 = nn.Linear(100, num_classes)

    def forward(self, x):
        print(x.shape)
        x = self.cnn(x)
        print(x.shape)
        x = torch.flatten(x, 1)
        print(x.shape)
        x = self.bilstm(x)
        print(x.shape)
        x = torch.flatten(x, 1)
        print(x.shape)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.softmax(x, dim=1)

        return x

## Training

In [60]:
# 1. Data Preparation
center_crop = transforms.CenterCrop((120, 120))

train_dataset = ISARDataset('test/test_labels.csv', 'test/data/', transform=center_crop)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=False)

# 2. Model and Optimization
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN_BiLSTM(num_classes=4).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adjust learning rate as needed

# 3. Training Loop
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")  # Progress bar

    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)  # CrossEntropyLoss automatically handles log_softmax

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item():.4f}')


Epoch 1/100:   0%|          | 0/10 [00:23<?, ?it/s]


torch.Size([4, 1, 120, 120])
torch.Size([4, 32, 9, 9])
torch.Size([4, 2592])


RuntimeError: For unbatched 2-D input, hx and cx should also be 2-D but got (3-D, 3-D) tensors