In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.functional import cross_entropy
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import lightning as pl

import torch.nn.functional as F
import torchmetrics


In [2]:
def create_label_encoder(labels):
    from functools import singledispatch
    unique_labels = sorted(set(labels))
    unique_labels = [unique_label.upper() for unique_label in unique_labels]
    num_classes = len(unique_labels)

    # Two way dictionary hack
    label_to_int = {label: i for i, label in enumerate(unique_labels)}
    int_to_label = {i: label for label, i in label_to_int.items()}

    identity_matrix = np.eye(num_classes)

    @singledispatch
    def one_hot_encoder(y):
        raise NotImplementedError("Unsupported input type")

    @one_hot_encoder.register(list)
    @one_hot_encoder.register(np.ndarray)
    def _one_hot_list_encoder(y):
        encoded_labels = [label_to_int[label] for label in y]
        one_hot_labels = identity_matrix[encoded_labels]
        return one_hot_labels

    @one_hot_encoder.register(int)
    def _(y):
        if y < 0 or y >= num_classes:
            raise ValueError(f"y is out of range for {num_classes} classes")
        one_hot_label = identity_matrix[y]
        return one_hot_label

    @one_hot_encoder.register(str)
    def _one_hot_str_encoder(y):
        # Sanitize
        y = y.upper()
        if y not in unique_labels:
            raise ValueError(f"{y} not in labels")

        # Assuming a single string label
        encoded_label = label_to_int[y]
        one_hot_label = identity_matrix[encoded_label]
        return one_hot_label

    @singledispatch
    def one_hot_decoder(encoded_labels):
        raise NotImplementedError("Unsupported input type")

    @one_hot_decoder.register(torch.Tensor)
    def _one_hot_torch_decoder(encoded_labels):
        return one_hot_decoder(encoded_labels.numpy())

    @one_hot_decoder.register(np.ndarray)
    def _one_hot_np_decoder(encoded_labels):
        if encoded_labels.ndim > 2:
            decoded_labels = [int_to_label[i]
                              for i in np.argmax(encoded_labels, axis=-1)]
        else:
            decoded_labels = int_to_label[np.argmax(encoded_labels, axis=-1)]
        return decoded_labels

    return one_hot_encoder, one_hot_decoder, num_classes

In [3]:
class RamanDataset(Dataset):
    def __init__(self, x_data, y_data):
        self.x_data = x_data
        self.y_data = y_data

    def __len__(self):
        return len(self.x_data)

    def __getitem__(self, idx):
        x = torch.from_numpy(self.x_data[idx]).float()
        y = torch.from_numpy(self.y_data[idx]).float()
        return x, y

In [4]:
class RamanSpectraDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: Path, batch_size: int = 128):
        super().__init__()
        self.data_dir = Path(data_dir)
        self.batch_size = batch_size
        self.x_data_raw = []
        self.y_data_raw = []

    def load_data(self):
        datasets = list(self.data_dir.rglob("*.csv"))

        for file in datasets:
            data = np.loadtxt(file, comments='#', delimiter=',')
            self.x_data_raw.append(data)

            for row in data:
                self.y_data_raw.append(file.parent.stem.split("-")[0])

        self.x_data = np.concatenate(self.x_data_raw, axis=0)
        self.y_data = np.array(self.y_data_raw)

        self.classification_list = sorted(set(self.y_data))
        self.label_encoder, _, _ = create_label_encoder(
            self.classification_list)
        self.y_data_encoded = self.label_encoder(self.y_data)

    def standard_scaler(self, x):
        mean_ = np.mean(x, axis=-1, keepdims=True)
        scale_ = np.std(x, axis=-1, keepdims=True)
        standardized_data = (x - mean_) / scale_
        return standardized_data

    def prepare_data(self):
        self.load_data()
        self.x_data = self.standard_scaler(self.x_data)
        self.dataset = RamanDataset(self.x_data, self.y_data_encoded)

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            # Split into train, validation, and test sets
            train_length = int(len(self.dataset) * 0.8)
            val_length = int(len(self.dataset) * 0.1)
            test_length = len(self.dataset) - (train_length + val_length)
            lengths = [train_length, val_length, test_length]

            self.train_dataset, self.val_dataset, self.test_dataset = random_split(
                self.dataset, lengths, generator=torch.Generator().manual_seed(42))

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)

In [5]:
datamodule = RamanSpectraDataModule("./Dataset/cells-raman-spectra/dataset_i/")

In [11]:
class LightningClassifier(pl.LightningModule):
    def __init__(self, input_size, num_classes, lr=1.e-3):
        super().__init__()

        # Model
        self.fc1 = nn.Linear(input_size, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 32)
        self.fc4 = nn.Linear(32, num_classes)

        self.accuracy = torchmetrics.classification.Accuracy(
            task="multiclass", num_classes=num_classes)
        self.lr = lr

    def forward(self, x):
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        x = self.fc4(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.mse_loss(y_hat, y)

        self.accuracy(y_hat, y)
        self.log('train_acc_step', self.accuracy, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.mse_loss(y_hat, y)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.mse_loss(y_hat, y)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        stepping_batches = self.trainer.estimated_stepping_batches
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=1e-3, total_steps=stepping_batches)
        return [optimizer], [scheduler]

    def on_train_epoch_end(self):
        # log epoch metric
        self.log('train_acc_epoch', self.accuracy)

In [12]:
input_shape = 2090  # todo: fix
num_classes = 6  # todo: fix

# PyTorch Lightning Trainer
trainer = pl.Trainer(max_epochs=50)
lightning_model = LightningClassifier(input_shape, num_classes)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [13]:
# Train the model
trainer.fit(lightning_model, datamodule=datamodule)

Loading `train_dataloader` to estimate number of stepping batches.
c:\Users\rjk217\Miniconda3\envs\guidestar\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.
c:\Users\rjk217\Miniconda3\envs\guidestar\lib\site-packages\lightning\pytorch\loops\fit_loop.py:293: The number of training batches (37) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.

  | Name     | Type               | Params
------------------------------------------------
0 | fc1      | Linear             | 535 K 
1 | fc2      | Linear             | 16.4 K
2 | fc3      | Linear             | 2.1 K 
3 | fc4      | Linear             | 198   
4 | accuracy | MulticlassAccuracy | 0     
----------

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\rjk217\Miniconda3\envs\guidestar\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.


In [14]:
trainer.test(datamodule=datamodule)

Restoring states from the checkpoint path at c:\Users\rjk217\OneDrive - University of Exeter\Documents\raman-cnn\lightning_logs\version_2\checkpoints\epoch=49-step=1850.ckpt
Loaded model weights from the checkpoint at c:\Users\rjk217\OneDrive - University of Exeter\Documents\raman-cnn\lightning_logs\version_2\checkpoints\epoch=49-step=1850.ckpt
c:\Users\rjk217\Miniconda3\envs\guidestar\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

[{}]

In [10]:
font_dict = {'color': 'black'}
i = 0
for x_test, y_test in test_loader.dataset:
    plt.plot(np.linspace(100, 4278, 2090), x_test)
    with torch.no_grad():
        y_pred = label_decoder(lightning_model(x_test).numpy())
    true_value = label_decoder(y_test)
    font_dict['color'] = "black" if y_pred == true_value else 'red'
    plt.title(
        f"Predicted: {y_pred}, Truth: {true_value}", fontdict=font_dict)
    plt.tight_layout()
    plt.show()

    if i > 4:
        break
    i+=1
    

NameError: name 'test_loader' is not defined