In [1]:
# Link to paper: https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=9971386
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.io import read_image
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt

from isar_dataset import ISARDataset

## Model Definitions

In [2]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
    
class CNN(nn.Module):
    def __init__(self, in_channels=1):
        super().__init__()
        self.cnn_layers = nn.Sequential(
            ConvBlock(in_channels, 8),
            ConvBlock(8, 8),
            ConvBlock(8, 8),
            ConvBlock(8, 8),
            nn.MaxPool2d(kernel_size=2, stride=2),

            ConvBlock(8, 16),
            ConvBlock(16, 16),
            ConvBlock(16, 16),
            nn.MaxPool2d(kernel_size=2, stride=2),

            ConvBlock(16, 32),
            ConvBlock(32, 32),
            ConvBlock(32, 32),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

    def forward(self, x):
        x = self.cnn_layers(x)
        return x

In [3]:
class BiLSTM(nn.Module):
    def __init__(self, hidden_size, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = 2
        self.lstm = nn.LSTM(2592, self.hidden_size, self.num_layers, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(self.hidden_size * 2, num_classes)

    def forward(self, x):
        # Initialize hidden and cell states (only needs to be done once)
        h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device) 
        c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)

        # Forward pass through LSTM
        out, _ = self.lstm(x, (h0, c0))  

        # Extract last hidden state
        return out[:, -1, :]

class CNN_BiLSTM(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.cnn = CNN()
        hidden_size = 1000
        self.bilstm = BiLSTM(hidden_size, num_classes)
        self.fc1 = nn.Linear(2 * hidden_size, 100)
        self.fc2 = nn.Linear(100, num_classes)

    def forward(self, x):
        B, seq_length, C, H, W = x.shape

        # Process images in sequence via CNN independently
        cnn_out = []
        for i in range(seq_length):
            img = x[:, i, :, :, :]
            img = self.cnn(img)
            img = img.flatten(1)
            cnn_out.append(img)
        x = torch.stack(cnn_out, dim=1)
        
        x = self.bilstm(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.softmax(x, dim=1)

        return x

## Training

In [4]:
def evaluate(model, test_loader, criterion, device):
    model.eval()  # Set model to evaluation mode

    total_loss = 0.0
    num_correct = 0
    num_samples = len(test_loader.dataset)

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            
            # Compute the outputs and loss
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()

            # Compute the accuracy
            _, predictions = torch.max(outputs, dim=1)
            num_correct += (predictions == labels).sum().item()

    # Compute the average loss and accuracy
    avg_loss = total_loss / num_samples
    accuracy = num_correct / num_samples

    return avg_loss, accuracy

In [5]:
torch.manual_seed(0)

# 0. Declare constants
NUM_CLASSES = 4
# DATA_PATH = 'data/raw/'
# LABEL_PATH = 'labels.csv'
DATA_PATH = 'test/data/'
LABEL_PATH = 'test/test_labels.csv'
MODEL_PATH = 'weights/test_cnn_bilstm_model.pth'

# Hyperparameters
SEQ_LENGTH = 3
BATCH_SIZE = 4
LEARNING_RATE = 1e-3
NUM_EPOCHS=50

TRAIN_RATIO = 0.7
VAL_RATIO = 0.15
TEST_RATIO = 0.15

# 1. Data Preparation
center_crop = transforms.CenterCrop((120, 120))

# Load in dataset
dataset = ISARDataset(LABEL_PATH, DATA_PATH, seq_length=SEQ_LENGTH, transform=center_crop)

# Partition dataset into train, validation, and test sets
train_size = int(TRAIN_RATIO * len(dataset))
val_size = int(VAL_RATIO * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# 2. Model and Optimization
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN_BiLSTM(num_classes=NUM_CLASSES).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)  # Decays LR by a factor of 0.1 every 30 epochs

# 3. Training Loop
for epoch in range(NUM_EPOCHS):
    model.train()

    train_correct = 0
    train_total = 0
    train_loss = 0.0
    
    with tqdm(total=len(train_loader), unit="batch", desc=f"Epoch {epoch + 1}/{NUM_EPOCHS}") as pbar:
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            # Forward, backward pass + optimize
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Calculate loss
            train_loss += loss.item()

            # Calculate accuracy
            _, predictions = torch.max(outputs, dim=1)
            train_correct += (predictions == labels).sum().item()
            train_total += len(labels)

            # Update progress bar
            pbar.update(1)
            pbar.set_postfix(loss=train_loss / train_total, accuracy=train_correct / train_total)

    # Validation every 5 epochs
    if (epoch + 1) % 5 == 0:
        val_loss, val_accuracy = evaluate(model, val_loader, criterion, device)
        print(f"Validation Set: Loss={val_loss}, Accuracy={val_accuracy}")

torch.save(model.state_dict(), MODEL_PATH)

Epoch 1/50: 100%|██████████| 6/6 [00:13<00:00,  2.22s/batch, accuracy=0.476, loss=0.353]
Epoch 2/50: 100%|██████████| 6/6 [00:12<00:00,  2.04s/batch, accuracy=0.667, loss=0.278]
Epoch 3/50: 100%|██████████| 6/6 [00:14<00:00,  2.41s/batch, accuracy=0.952, loss=0.247]
Epoch 4/50: 100%|██████████| 6/6 [00:16<00:00,  2.69s/batch, accuracy=0.952, loss=0.226]
Epoch 5/50: 100%|██████████| 6/6 [00:14<00:00,  2.43s/batch, accuracy=0.952, loss=0.224]


Validation Set: Loss=0.24842366576194763, Accuracy=0.75


Epoch 6/50: 100%|██████████| 6/6 [00:11<00:00,  1.94s/batch, accuracy=0.952, loss=0.224]
Epoch 7/50: 100%|██████████| 6/6 [00:11<00:00,  1.85s/batch, accuracy=0.952, loss=0.224]
Epoch 8/50: 100%|██████████| 6/6 [00:12<00:00,  2.04s/batch, accuracy=0.952, loss=0.224]
Epoch 9/50: 100%|██████████| 6/6 [00:09<00:00,  1.51s/batch, accuracy=0.952, loss=0.224]
Epoch 10/50: 100%|██████████| 6/6 [00:09<00:00,  1.51s/batch, accuracy=0.952, loss=0.224]


Validation Set: Loss=0.2484246790409088, Accuracy=0.75


Epoch 11/50: 100%|██████████| 6/6 [00:09<00:00,  1.57s/batch, accuracy=0.952, loss=0.224]
Epoch 12/50: 100%|██████████| 6/6 [00:09<00:00,  1.56s/batch, accuracy=0.952, loss=0.224]
Epoch 13/50: 100%|██████████| 6/6 [00:09<00:00,  1.65s/batch, accuracy=0.952, loss=0.224]
Epoch 14/50: 100%|██████████| 6/6 [00:09<00:00,  1.61s/batch, accuracy=0.952, loss=0.224]
Epoch 15/50: 100%|██████████| 6/6 [00:10<00:00,  1.72s/batch, accuracy=0.952, loss=0.224]


Validation Set: Loss=0.24842104315757751, Accuracy=0.75


Epoch 16/50: 100%|██████████| 6/6 [00:09<00:00,  1.56s/batch, accuracy=0.952, loss=0.224]
Epoch 17/50: 100%|██████████| 6/6 [00:09<00:00,  1.58s/batch, accuracy=0.952, loss=0.224]
Epoch 18/50: 100%|██████████| 6/6 [00:09<00:00,  1.57s/batch, accuracy=0.952, loss=0.224]
Epoch 19/50: 100%|██████████| 6/6 [00:09<00:00,  1.58s/batch, accuracy=0.952, loss=0.224]
Epoch 20/50: 100%|██████████| 6/6 [00:09<00:00,  1.57s/batch, accuracy=0.952, loss=0.224]


Validation Set: Loss=0.24841980636119843, Accuracy=0.75


Epoch 21/50: 100%|██████████| 6/6 [00:11<00:00,  1.87s/batch, accuracy=0.952, loss=0.224]
Epoch 22/50: 100%|██████████| 6/6 [00:09<00:00,  1.58s/batch, accuracy=0.952, loss=0.224]
Epoch 23/50: 100%|██████████| 6/6 [00:10<00:00,  1.80s/batch, accuracy=0.952, loss=0.224]
Epoch 24/50: 100%|██████████| 6/6 [00:11<00:00,  1.93s/batch, accuracy=0.952, loss=0.224]
Epoch 25/50: 100%|██████████| 6/6 [00:11<00:00,  1.98s/batch, accuracy=0.952, loss=0.224]


Validation Set: Loss=0.24841934442520142, Accuracy=0.75


Epoch 26/50: 100%|██████████| 6/6 [00:10<00:00,  1.80s/batch, accuracy=0.952, loss=0.224]
Epoch 27/50: 100%|██████████| 6/6 [00:10<00:00,  1.83s/batch, accuracy=0.952, loss=0.224]
Epoch 28/50: 100%|██████████| 6/6 [00:09<00:00,  1.65s/batch, accuracy=0.952, loss=0.224]
Epoch 29/50: 100%|██████████| 6/6 [00:10<00:00,  1.79s/batch, accuracy=0.952, loss=0.224]
Epoch 30/50: 100%|██████████| 6/6 [00:10<00:00,  1.80s/batch, accuracy=0.952, loss=0.224]


Validation Set: Loss=0.24841901659965515, Accuracy=0.75


Epoch 31/50: 100%|██████████| 6/6 [00:13<00:00,  2.22s/batch, accuracy=0.952, loss=0.224]
Epoch 32/50: 100%|██████████| 6/6 [00:12<00:00,  2.02s/batch, accuracy=0.952, loss=0.224]
Epoch 33/50: 100%|██████████| 6/6 [00:10<00:00,  1.71s/batch, accuracy=0.952, loss=0.26]
Epoch 34/50: 100%|██████████| 6/6 [00:10<00:00,  1.72s/batch, accuracy=0.952, loss=0.224]
Epoch 35/50: 100%|██████████| 6/6 [00:10<00:00,  1.76s/batch, accuracy=0.952, loss=0.224]


Validation Set: Loss=0.2484186887741089, Accuracy=0.75


Epoch 36/50: 100%|██████████| 6/6 [00:13<00:00,  2.24s/batch, accuracy=0.952, loss=0.224]
Epoch 37/50: 100%|██████████| 6/6 [00:12<00:00,  2.17s/batch, accuracy=0.952, loss=0.224]
Epoch 38/50: 100%|██████████| 6/6 [00:10<00:00,  1.81s/batch, accuracy=0.952, loss=0.224]
Epoch 39/50: 100%|██████████| 6/6 [00:11<00:00,  1.92s/batch, accuracy=0.952, loss=0.26]
Epoch 40/50: 100%|██████████| 6/6 [00:11<00:00,  1.96s/batch, accuracy=0.952, loss=0.26]


Validation Set: Loss=0.24841848015785217, Accuracy=0.75


Epoch 41/50: 100%|██████████| 6/6 [00:13<00:00,  2.18s/batch, accuracy=0.952, loss=0.224]
Epoch 42/50: 100%|██████████| 6/6 [00:11<00:00,  1.92s/batch, accuracy=0.952, loss=0.224]
Epoch 43/50: 100%|██████████| 6/6 [00:10<00:00,  1.82s/batch, accuracy=0.952, loss=0.224]
Epoch 44/50: 100%|██████████| 6/6 [00:09<00:00,  1.62s/batch, accuracy=0.952, loss=0.224]
Epoch 45/50: 100%|██████████| 6/6 [00:09<00:00,  1.52s/batch, accuracy=0.952, loss=0.224]


Validation Set: Loss=0.24841824173927307, Accuracy=0.75


Epoch 46/50: 100%|██████████| 6/6 [00:09<00:00,  1.58s/batch, accuracy=0.952, loss=0.224]
Epoch 47/50: 100%|██████████| 6/6 [00:08<00:00,  1.46s/batch, accuracy=0.952, loss=0.224]
Epoch 48/50: 100%|██████████| 6/6 [00:09<00:00,  1.61s/batch, accuracy=0.952, loss=0.224]
Epoch 49/50: 100%|██████████| 6/6 [00:09<00:00,  1.55s/batch, accuracy=0.952, loss=0.224]
Epoch 50/50: 100%|██████████| 6/6 [00:10<00:00,  1.77s/batch, accuracy=0.952, loss=0.224]


Validation Set: Loss=0.24841804802417755, Accuracy=0.75


In [7]:
test_loss, test_acc = evaluate(model, test_loader, criterion, device)
print(f"Test Set: Loss={test_loss}, Accuracy={test_acc}")

Test Set: Loss=0.39746745824813845, Accuracy=0.6


In [None]:

# # Testing
# IMG_PATH = 'data/raw/image_50.png'
# MODEL_PATH = 'weights/cnn_bilstmv1.pth'

# image = read_image(IMG_PATH).float()
# plt.imshow(image.permute(1, 2, 0).int())
# plt.axis('off')

# saved_model = CNN_BiLSTM(num_classes=4)
# saved_model.load_state_dict(torch.load(MODEL_PATH))
# saved_model.eval()

# test_loss, test_acc = evaluate(saved_model, test_loader, criterion, device)
# print(f"Test Set: Loss={test_loss}, Accuracy={test_acc}")

# # Add batch and sequence dimensions
# image = center_crop(image).unsqueeze(0).unsqueeze(0)

# with torch.no_grad():
#     output = saved_model(image)
#     prediction = torch.argmax(output, dim=1).item()
#     print(f'Prediction: Class {prediction}')
#     print(f'Probability: {output[0][prediction]:.2f}')