In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
from PIL import Image
import torch
import lightning.pytorch as pl
import torchmetrics
import torchvision
from torchinfo import summary
from torchview import draw_graph
from IPython.display import display
import sympy as sp
from datetime import datetime
import time
sp.init_printing(use_latex=True)

In [None]:
# get constant variables from environment
BUCKET_NAME = "csc7400-deepsight"
N = os.environ.get("SIZE", None)
BATCH_SIZE = os.environ.get("BATCH_SIZE", 50)
VAL_SPLIT = os.environ.get("VAL_SPLIT", 0.2)
NUM_WORKERS = os.environ.get("NUM_WORKERS", 2)

In [3]:
class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self,
                 batch_size=BATCH_SIZE,
                 val_split=VAL_SPLIT,
                 num_workers=NUM_WORKERS,
                 location="~/datasets",
                 **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.val_split = val_split
        self.num_workers = num_workers
        self.location = location
        self.input_shape = None
        self.output_shape = None
        self.data_train = None
        self.data_val = None
        self.data_test = None

    def setup(self, stage: str):
        if (stage == 'fit' or \
                stage == 'validate') and \
                not (self.data_train and self.data_val):
            start_time = time.perf_counter()
            training_dataset = torchvision.datasets.CIFAR10(root=self.location, download=True, train=True)
            end_time = time.perf_counter()
            elapsed_time = round(end_time - start_time, 3)
            print(f" Elapsed time of set training_dataset: {elapsed_time} seconds")
            # CIFAR10
            start_time = time.perf_counter()
            x_train = training_dataset.data.transpose((0, 3, 1, 2))[:N]
            end_time = time.perf_counter()
            elapsed_time = round(end_time - start_time, 3)
            print(f" Elapsed time of set x_train: {elapsed_time} seconds")
            # x_train - time com
            y_train = np.array(training_dataset.targets)[:N]
            self.input_shape = x_train.shape[1:]
            self.output_shape = (len(np.unique(y_train)),)
            rng = np.random.default_rng()
            permutation = rng.permutation(x_train.shape[0])
            split_point = int(x_train.shape[0]*(1.0-self.val_split))
            self.data_train = list(zip(torch.Tensor(x_train[permutation[:split_point]]).to(torch.float32),
                                       torch.Tensor(y_train[permutation[:split_point]]).to(torch.long)))
            self.data_val = list(zip(torch.Tensor(x_train[permutation[split_point:]]).to(torch.float32),
                                     torch.Tensor(y_train[permutation[split_point:]]).to(torch.long)))
        if (stage == 'test' or \
                stage == 'predict') and \
                not self.data_test:
            testing_dataset = torchvision.datasets.CIFAR10(root=self.location, download=True, train=False)
            x_test = testing_dataset.data.transpose((0, 3, 1, 2))[:N]
            y_test = np.array(testing_dataset.targets)[:N]
            self.input_shape = x_test.shape[1:]
            self.output_shape = (len(np.unique(y_test)),)
            self.data_test = list(zip(torch.Tensor(x_test).to(torch.float32),
                                      torch.Tensor(y_test).to(torch.long)))

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.data_train,
                                           batch_size=self.batch_size,
                                           num_workers=self.num_workers,
                                           shuffle=True)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.data_val,
                                           batch_size=self.batch_size,
                                           num_workers=self.num_workers,
                                           shuffle=False)

    def test_dataloader(self):
        return torch.utils.data.DataLoader(self.data_test,
                                           batch_size=self.batch_size,
                                           num_workers=self.num_workers,
                                           shuffle=False)

    def predict_dataloader(self):
        return torch.utils.data.DataLoader(self.data_test,
                                           batch_size=self.batch_size,
                                           num_workers=self.num_workers,
                                           shuffle=False)

In [4]:
data_module = CIFAR10DataModule(batch_size=20)
data_module.setup('fit')
dl = data_module.val_dataloader()

 Elapsed time of set training_dataset: 0.862 seconds


UnboundLocalError: cannot access local variable 'N' where it is not associated with a value

In [None]:
batch = next(iter(dl))

In [None]:
class SinePositionEmbedding(pl.LightningModule):
    def __init__(self,
                 max_wavelength=10000.0,
                 **kwargs):
        super().__init__(**kwargs)
        self.max_wavelength = torch.Tensor([max_wavelength])

    def forward(self, x):
        input_shape = x.shape
        seq_length = x.shape[-2]
        hidden_size = x.shape[-1]
        position = torch.arange(seq_length).type_as(x)
        min_freq = (1 / self.max_wavelength).type_as(x)
        timescales = torch.pow(
            min_freq,
            (2 * (torch.arange(hidden_size) // 2)).type_as(x)
            / torch.Tensor([hidden_size]).type_as(x)
        )
        angles = torch.unsqueeze(position, 1) * torch.unsqueeze(timescales, 0)
        cos_mask = (torch.arange(hidden_size) % 2).type_as(x)
        sin_mask = 1 - cos_mask
        positional_encodings = (
            torch.sin(angles) * sin_mask + torch.cos(angles) * cos_mask
        )
        return torch.broadcast_to(positional_encodings, input_shape)

In [None]:
class TransformerBlock(pl.LightningModule):
    def __init__(self,
                 latent_size=64,
                 num_heads=4,
                 dropout=0.1,
                 **kwargs):
        super().__init__(**kwargs)
        self.layer_norm1 = torch.nn.LayerNorm(latent_size)
        self.layer_norm2 = torch.nn.LayerNorm(latent_size)
        self.dropout = torch.nn.Dropout(dropout)
        self.activation = torch.nn.GELU()
        self.linear = torch.nn.Linear(latent_size,
                                      latent_size)
        self.mha = torch.nn.MultiheadAttention(latent_size,
                                               num_heads,
                                               dropout=dropout,
                                               batch_first=True)

    def forward(self, x):
        y = x
        y = self.layer_norm1(y)
        y = self.mha(y, y, y)[0]
        x = y = x + y
        y = self.layer_norm2(y)
        y = self.linear(y)
        y = self.dropout(y)
        y = self.activation(y)
        return x + y

In [None]:
# Define Trainable Module (Abstract Base Class)
class LightningBoilerplate(pl.LightningModule):
    def __init__(self, **kwargs):
        # This is the contructor, where we typically make
        # layer objects using provided arguments.
        super().__init__(**kwargs)  # Call the super class constructor

    def predict_step(self, predict_batch, batch_idx):
        x, y_true = predict_batch
        y_pred = self.predict(x)
        return y_pred, y_true

    def training_step(self, train_batch, batch_idx):
        x, y_true = train_batch
        y_pred = self(x)
        for metric_name, metric_function in self.network_metrics.items():
            metric_value = metric_function(y_pred, y_true)
            self.log('train_'+metric_name, metric_value, on_step=False, on_epoch=True)
        loss = self.network_loss(y_pred, y_true)
        self.log('train_loss', loss, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y_true = val_batch
        y_pred = self(x)
        for metric_name, metric_function in self.network_metrics.items():
            metric_value = metric_function(y_pred, y_true)
            self.log('val_'+metric_name, metric_value, on_step=False, on_epoch=True)
        loss = self.network_loss(y_pred, y_true)
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        return loss

    def test_step(self, test_batch, batch_idx):
        x, y_true = test_batch
        y_pred = self(x)
        for metric_name, metric_function in self.network_metrics.items():
            metric_value = metric_function(y_pred, y_true)
            self.log('test_'+metric_name, metric_value, on_step=False, on_epoch=True)
        loss = self.network_loss(y_pred, y_true)
        self.log('test_loss', loss, on_step=False, on_epoch=True)
        return loss

In [None]:
# Attach loss, metrics, and optimizer
class MultiClassLightningModule(LightningBoilerplate):
    def __init__(self,
                 num_classes,
                 **kwargs):
        # This is the contructor, where we typically make
        # layer objects using provided arguments.
        super().__init__(**kwargs)  # Call the super class constructor

        # This creates an accuracy function
        self.network_metrics = torch.nn.ModuleDict({
            'acc': torchmetrics.classification.Accuracy(task='multiclass',
                                                        num_classes=num_classes)
        })
        # This creates a loss function
        self.network_loss = torch.nn.CrossEntropyLoss()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        # change lr: 0.01, 0.0001
        return optimizer

In [None]:
# Attach standardization and augmentation
class StandardizeTransformModule(MultiClassLightningModule):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Source: https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html
        # Needs to always be applied to any incoming
        # image for this model. The Compose operation
        # takes a list of torchvision transforms and
        # applies them in sequential order, similar
        # to neural layers...
        self.standardize = torchvision.transforms.Compose([
            torchvision.transforms.Resize([256]),
            torchvision.transforms.CenterCrop([224]),
            torchvision.transforms.Lambda(lambda x: x / 255.0),
            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225]),
        ])
        # Besides just standardization, the images can also undergo
        # augmentation using torchvision. Again, we compose
        # these operations together - ranges are provided for
        # each of these augmentations.
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.RandomAffine(degrees=(-10.0,10.0),
                                                translate=(0.1,0.1),
                                                scale=(0.9,1.1),
                                                shear=(-10.0,10.0)),
            torchvision.transforms.RandomHorizontalFlip(0.5),
        ])

    def forward(self, x):
        y = x
        y = self.standardize(y)
        if self.training:
            y = self.transform(y)
        return y

In [None]:
class Channel_Att(torch.nn.Module):
    def __init__(self, embed_dim):
        super(Channel_Att, self).__init__()
        self.embed_dim = embed_dim
        self.bn = torch.nn.LayerNorm(self.embed_dim)

    def forward(self, x):
        residual = x
        x = self.bn(x)
        weight_bn = self.bn.weight.data.abs() / torch.sum(self.bn.weight.data.abs())
        x = torch.mul(weight_bn, x)
        x = torch.sigmoid(x) * residual
        return x


class Spatial_Att(torch.nn.Module):
    def __init__(self, num_tokens, channels):
        super(Spatial_Att, self).__init__()
        self.conv1d = torch.nn.Conv1d(2, 1, kernel_size=3, padding=1, bias=False)
        self.sigmoid = torch.nn.Sigmoid()

    def pixel_normalization(self, x):
        norm = torch.norm(x, p=2, dim=2, keepdim=True)
        return x / (norm + 1e-8)

    def forward(self, x):
        x = self.pixel_normalization(x)

        avg_out = torch.mean(x, dim=2, keepdim=True)
        max_out, _ = torch.max(x, dim=2, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=2)
        x = x.permute(0, 2, 1)
        x = self.conv1d(x)
        x = self.sigmoid(x).permute(0, 2, 1)
        return x

class NAM(torch.nn.Module):
    def __init__(self, num_tokens, embed_dim):
        super(NAM, self).__init__()
        self.channel_att = Channel_Att(embed_dim)
        self.spatial_att = Spatial_Att(num_tokens, embed_dim)

    def forward(self, x):
        x = self.channel_att(x)
        spatial_weight = self.spatial_att(x)
        x = x * spatial_weight
        return x

In [None]:
class ViTNetwork(StandardizeTransformModule):
    def __init__(self,
                 input_shape,
                 patch_shape,
                 output_size,
                 latent_size=64,
                 num_heads=4,
                 n_layers=4,
                 **kwargs):
        super().__init__(num_classes=output_size, **kwargs)
        self.save_hyperparameters()

        # For MNIST, use this...
        # self.normalize = torchvision.transforms.Compose([
        #     torchvision.transforms.Lambda(lambda x: x / 255.0),
        # ])
        # Note that this normalization only makes sense for CIFAR!

        self.patches = torch.nn.Conv2d(input_shape[1],
                                       latent_size,
                                       patch_shape,
                                       patch_shape,
                                       bias=False)
        # self.position_embedding = torch.nn.Embedding((input_shape[-1]//patch_shape[-1])*
        #                                              (input_shape[-2]//patch_shape[-2]),
        #                                              latent_size)

        self.position_embedding = SinePositionEmbedding()
        self.att = NAM(num_tokens=(input_shape[-1] // patch_shape[-1]) *
                                       (input_shape[-2] // patch_shape[-2]),
                             embed_dim=latent_size)
        self.transformer_blocks = torch.nn.Sequential(*[
            TransformerBlock(latent_size=latent_size,
                             num_heads=num_heads) for _ in range(n_layers)
        ])
        self.pooling = torch.nn.AdaptiveAvgPool1d(1)
        self.linear = torch.nn.Linear(latent_size,
                                      output_size)

    def forward(self, x):
        y = x
        y = super().forward(y)
        y = self.patches(y)
        y = y.reshape(y.shape[0:2] + (-1,)).permute(0, 2, 1)
        # y = y + self.position_embedding(torch.arange(0,y.shape[1]).type_as(x).long())
        y = y + self.position_embedding(y)
        y = self.att(y)
        y = self.transformer_blocks(y).permute(0, 2, 1)
        y = self.pooling(y).squeeze()
        y = self.linear(y)
        return y

In [None]:
batch[0].shape

In [None]:
data_module.output_shape

In [None]:
vit_net = ViTNetwork(input_shape=batch[0].shape,
                     patch_shape=(16, 16),
                     output_size=data_module.output_shape[0],
                     latent_size=64,
                     n_layers=4)
summary(vit_net, input_size=batch[0].shape)

In [None]:
logger = pl.loggers.CSVLogger("logs",
                              name="2024-10-10-Transformers",
                              version="vit-0")

In [None]:
trainer = pl.Trainer(logger=logger,
                     max_epochs=1,
                     # epochs
                     enable_progress_bar=True,
                     log_every_n_steps=0,
                     enable_checkpointing=True,
                     callbacks=[pl.callbacks.TQDMProgressBar(refresh_rate=50)])

GPU available: False, used: False


TPU available: False, using: 0 TPU cores


HPU available: False, using: 0 HPUs


In [None]:
trainer.validate(vit_net, data_module)


Validation DataLoader 0:   0%|                                                                                                      | 0/10 [00:00<?, ?it/s]


Validation DataLoader 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 22.07it/s]


Validation DataLoader 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 21.82it/s]


───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         val_acc            0.10000000149011612
        val_loss            2.3553812503814697
───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_acc': 0.10000000149011612, 'val_loss': 2.3553812503814697}]

In [None]:
trainer.fit(vit_net, data_module)


Sanity Checking DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 16.87it/s]


                                                                                                                                                           




Training: |                                                                                                                          | 0/? [00:00<?, ?it/s]


Training:   0%|                                                                                                                     | 0/40 [00:00<?, ?it/s]


Epoch 0:   0%|                                                                                                                      | 0/40 [00:00<?, ?it/s]


Epoch 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:11<00:00,  3.53it/s]


Epoch 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:11<00:00,  3.53it/s, v_num=it-0]





Validation: |                                                                                                                        | 0/? [00:00<?, ?it/s]

[A





Validation:   0%|                                                                                                                   | 0/10 [00:00<?, ?it/s]

[A





Validation DataLoader 0:   0%|                                                                                                      | 0/10 [00:00<?, ?it/s]

[A





Validation DataLoader 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 20.26it/s]

[A





                                                                                                                                                           

[A


Epoch 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:11<00:00,  3.37it/s, v_num=it-0]


Epoch 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:11<00:00,  3.37it/s, v_num=it-0]

`Trainer.fit` stopped: `max_epochs=1` reached.



Epoch 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:11<00:00,  3.36it/s, v_num=it-0]




In [None]:
results = pd.read_csv(logger.log_dir+"/metrics.csv")

In [None]:
plt.plot(results["epoch"][np.logical_not(np.isnan(results["train_loss"]))],
         results["train_loss"][np.logical_not(np.isnan(results["train_loss"]))],
         label="Training")
plt.plot(results["epoch"][np.logical_not(np.isnan(results["val_loss"]))],
         results["val_loss"][np.logical_not(np.isnan(results["val_loss"]))],
         label="Validation")
plt.legend()
plt.ylabel("CCE Loss")
plt.xlabel("Epoch")
plt.show()

In [None]:
plt.plot(results["epoch"][np.logical_not(np.isnan(results["train_acc"]))],
         results["train_acc"][np.logical_not(np.isnan(results["train_acc"]))],
         label="Training")
plt.plot(results["epoch"][np.logical_not(np.isnan(results["val_acc"]))],
         results["val_acc"][np.logical_not(np.isnan(results["val_acc"]))],
         label="Validation")
plt.legend()
plt.ylabel("Accuracy")
plt.xlabel("Epoch")
plt.show()

In [None]:
import boto3

model_path = f"/tmp/{N}_{VAL_SPLIT}vit_nam_model_weights.pth"
torch.save(vit_net.state_dict(), model_path)

# upload to s3
s3 = boto3.client('s3')
output_filepath = f"checkpoints/{N}_{VAL_SPLIT}_vit_nam_model_weights.pth"
s3.upload_file(model_path, BUCKET_NAME, output_filepath)
print("✅ Upload model to S3！")