In [None]:
import os

import torch
import torch.nn as nn

from cnn_lstm_1d import CNNLSTM

In [None]:
# Load pretrained model weights.
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a CPU-only
# machine; load_state_dict then copies the tensors into the model's own parameters.
PRETRAINED_WEIGHTS = "weights/cnnlstm-1.pth"
pretrained_model = CNNLSTM()
pretrained_model.load_state_dict(torch.load(PRETRAINED_WEIGHTS, map_location="cpu"))

# Freeze the CNN and LSTM layers: no gradients are computed for their parameters.
# NOTE(review): any other submodules of CNNLSTM (e.g. `fc`) stay trainable here —
# confirm that is intended.
pretrained_model.cnn.requires_grad_(False)
pretrained_model.lstm.requires_grad_(False)

# Add adapter layers
class AdaptedModel(nn.Module):
    """Pretrained backbone followed by a small trainable classification adapter.

    Only the adapter's parameters are given to the optimizer, so fitting
    updates the adapter alone and the backbone acts as a feature extractor.

    Args:
        base_model: pretrained module. It must expose ``fc.in_features``. Its
            forward output is indexed as ``out[:, -1, :]``, which presumably
            means a (batch, seq, features) tensor whose feature size equals
            ``fc.in_features`` — TODO confirm against ``CNNLSTM.forward``.
        adapter_size: width of the adapter's hidden layer.
        num_classes: number of output logits.
        lr: Adam learning rate for the adapter (1e-3 is Adam's own default).
    """

    def __init__(self, base_model, adapter_size=32, num_classes=2, lr=1e-3):
        super().__init__()
        self.base = base_model
        self.adapter = nn.Sequential(
            nn.Linear(base_model.fc.in_features, adapter_size),
            nn.ReLU(),
            nn.Linear(adapter_size, num_classes),
        )
        self.optimizer = torch.optim.Adam(self.adapter.parameters(), lr=lr)
        # Must be an instantiated loss module (``nn.CrossEntropy`` does not exist).
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Run the backbone, then classify its last time step with the adapter."""
        features = self.base(x)
        return self.adapter(features[:, -1, :])

    def train(self, train_loader=True, num_epochs=50, save_path="weights/adapted-1.pth"):
        """Set training mode (``nn.Module`` API) or fit the adapter on a loader.

        ``nn.Module.eval()`` calls ``self.train(False)``, so this override must
        keep accepting a bool ``mode``. A bool is forwarded to
        ``nn.Module.train`` (which returns ``self``). Anything else is treated
        as an iterable of ``(x, y)`` batches and passed to :meth:`fit`.
        """
        if isinstance(train_loader, bool):
            return super().train(train_loader)
        return self.fit(train_loader, num_epochs=num_epochs, save_path=save_path)

    def fit(self, train_loader, num_epochs=50, save_path="weights/adapted-1.pth"):
        """Train the adapter, print per-epoch stats, then save the state dict.

        Args:
            train_loader: iterable of ``(x, y)`` batches; ``y`` holds class indices.
            num_epochs: number of passes over ``train_loader``.
            save_path: file the final ``state_dict`` is written to. Its parent
                directory is created if missing.

        Prints, per epoch, the mean per-batch loss and the accuracy over the
        samples actually seen.
        """
        for epoch in range(num_epochs):
            # Keep the backbone in eval mode so dropout and any batch-norm running
            # statistics inside the "frozen" layers do not change while the
            # adapter learns.
            self.base.eval()
            self.adapter.train()
            total_loss = 0.0
            correct = 0
            seen = 0
            num_batches = 0
            for x, y in train_loader:
                self.optimizer.zero_grad()
                outputs = self(x)
                loss = self.criterion(outputs, y)
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
                correct += (outputs.argmax(dim=1) == y).sum().item()
                # Count real samples: len(dataset) is wrong with drop_last/samplers.
                seen += y.size(0)
                num_batches += 1
            avg_loss = total_loss / max(num_batches, 1)
            accuracy = correct / max(seen, 1)
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(self.state_dict(), save_path)
        print("Model saved")


In [None]:
# Build the train / test DataLoaders from the custom CSV dataset.
from utils import get_train_test_loaders

DATA_CSV = "data/custom.csv"
train_loader, test_loader = get_train_test_loaders(DATA_CSV)

In [None]:
# Initialize the adapted model and train its adapter.
# Seed torch's global RNG first so adapter weight init and DataLoader shuffling
# (if the loader shuffles) are reproducible when this cell is re-run.
# NOTE(review): any random split inside get_train_test_loaders happens earlier
# and is not covered by this seed.
SEED = 42
torch.manual_seed(SEED)
model = AdaptedModel(pretrained_model)
model.train(train_loader)

In [None]:
# Evaluate the adapted model on the held-out split.
from utils import test_model

test_results = test_model(model, test_loader)
test_results