In [1]:
import os
from pathlib import Path

import math
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
# import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torch.utils.tensorboard.writer import SummaryWriter
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision.transforms import Compose, Resize, ToTensor, Normalize

from torchmetrics import Metric
from torchmetrics.classification import MultilabelF1Score
import pytorch_lightning as pl

from PIL import Image
import matplotlib.pyplot as plt

import random

import timm
from timm.optim import Lookahead

2023-03-17 03:28:15.383406: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-17 03:28:16.326535: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/cray/pe/papi/6.0.0.15/lib64:/opt/cray/libfabric/1.15.0.0/lib64
2023-03-17 03:28:16.326617: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/cray/pe/papi/6.0.0.15/lib64:/opt/cray/libfabric/1.15.0.0/

In [12]:
class CFG:
    # data
    dataset_dir = Path("/lustrefs/disk/project/lt900011-ai2310/dataset")
    train_csv = dataset_dir / "train.csv"
    val_csv = dataset_dir / "val.csv"
    test_csv = dataset_dir / "test.csv"
    
    # model
    model_name = "gluon_resnet18_v1b"
    num_classes = 3
    learning_rate = 1e-3
    # checkpoint_path = "weights/coatnet_2_rw_224/coatnet_2_rw_224_e30.pth"
    checkpoint_path = "/home/superai052/weights/gluon_resnet18_v1b/gluon_resnet18_v1b_e30.pth"


    # trainer
    batch_size = 64
    shuffle = True
    num_workers = 4
    epoch = 30
    gpus = 1
    
    # save weight
    model_dir = Path("weights") / model_name
    model_path = model_dir / f"{model_name}_e{epoch}.pth"

In [13]:
class PulseRecall(Metric):
    def __init__(self):
        super().__init__()
        self.add_state("true_positives", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("pulse_class_total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        assert preds.shape == target.shape
        preds = torch.sigmoid(preds) > 0.5
        
        mask = target[:, 0].bool()  # Only consider instances where the first class is present in the ground truth
        masked_preds = preds[mask]
        masked_target = target[mask]

        true_positives = (masked_preds[:, 0].bool() & masked_target[:, 0].bool()).sum()
        pulse_class_total = masked_target[:, 0].sum().long()

        self.true_positives += true_positives.item()
        self.pulse_class_total += pulse_class_total.item()

    def compute(self):
        pulse_class_recall = self.true_positives.float() / self.pulse_class_total.float()
        return pulse_class_recall


class FRBAccuracy(Metric):
    def __init__(self):
        super().__init__()
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("true_positives", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("pulse_class_total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        assert preds.shape == target.shape
        preds = torch.sigmoid(preds) > 0.5
        correct = (preds == target).sum()
        total_elements = target.numel()
        true_positives = (preds[:, 0].bool() & target[:, 0].bool()).sum()
        pulse_class_total = target[:, 0].sum().long()

        self.correct = self.correct + correct.item()
        self.total = self.total + total_elements
        self.true_positives = self.true_positives + true_positives.item()
        self.pulse_class_total = self.pulse_class_total + pulse_class_total.item()

    def compute(self):
        all_class_accuracy = self.correct.float() / self.total.float()
        pulse_class_recall = self.true_positives.float() / self.pulse_class_total.float()
        return pulse_class_recall * all_class_accuracy




In [9]:
class MLPLayer(nn.Module):
    def __init__(self, input_size, output_size=3, hidden_sizes=None):
        super(MLPLayer, self).__init__()
        if hidden_sizes is None:
            hidden_sizes = [512, 256]

        layer_sizes = [input_size] + hidden_sizes + [output_size]
        layers = []
        for i in range(len(layer_sizes) - 1):
            layers.append(nn.Linear(layer_sizes[i], layer_sizes[i+1]))
            if i < len(layer_sizes) - 2:  # Add ReLU activation only between hidden layers
                layers.append(nn.ReLU())

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

class TimmLightningModel(pl.LightningModule):
    def __init__(self, 
                 model_name='resnet18',
                 pretrained=False, 
                 num_classes=3, 
                 learning_rate=1e-3, 
                 hidden_dim=None):
        super().__init__()
        self.save_hyperparameters()
        self.model = timm.create_model(model_name, 
                                       pretrained=pretrained, 
                                       num_classes=0)
        self.mlp = MLPLayer(input_size=self.model.num_features, 
                            output_size=num_classes,
                            hidden_sizes=hidden_dim)
        self.loss_fn = nn.BCEWithLogitsLoss()

        # accuracy
        self.train_f1 = MultilabelF1Score(
            num_labels=num_classes, 
            threshold=0.5, 
            average='weighted', 
            multidim_average='global')
        self.val_f1 = MultilabelF1Score(
            num_labels=num_classes, 
            threshold=0.5, 
            average='weighted', 
            multidim_average='global')
        
        # pulse recall
        self.train_pulse_recall = PulseRecall()
        self.val_pulse_recall = PulseRecall()
        
        # # remove
        # self.train_frb_accuracy = FRBAccuracy()
        # self.val_frb_accuracy = FRBAccuracy()

    def forward(self, x):
        x = self.model(x)
        x = self.mlp(x)
        return x
    
    @torch.no_grad()
    def predict(self, x):
        x = x.to(self.device)
        logits = self.forward(x)
        prob = torch.sigmoid(logits)
        return prob
    
    def configure_optimizers(self):
        base_optimizer = torch.optim.Adam(self.parameters(), 
                                          lr=self.hparams.learning_rate)
        optimizer = Lookahead(base_optimizer)
        
        scheduler = ReduceLROnPlateau(optimizer,
                                      mode='max', 
                                      factor=0.1, 
                                      patience=10, 
                                      verbose=True)
        
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_frb_accuracy",
            }
        }

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        y_hat_sig = torch.sigmoid(y_hat)
        
        loss = self.loss_fn(y_hat, y)
        self.log('train_loss', 
                 loss, 
                 on_step=False, 
                 on_epoch=True, 
                 logger=True, 
                 prog_bar=True)
        
        # accuracy
        self.train_f1(y_hat_sig, y)
        accuracy = self.train_f1.compute()
        self.log('train_f1', 
                 accuracy, 
                 on_step=False, 
                 on_epoch=True, 
                 logger=True, 
                 prog_bar=True)
        
        # pulse recall
        self.train_pulse_recall(y_hat, y)
        pulse_recall = self.train_pulse_recall.compute()
        self.log("train_pulse_recall",
                 pulse_recall,
                 on_step=False, 
                 on_epoch=True, 
                 logger=True,
                 prog_bar=True)
        
        # frb accuracy | pulse_recall * accuracy
        frb_accuracy = pulse_recall * accuracy
        self.log('train_frb_accuracy', 
                 frb_accuracy, 
                 on_step=False, 
                 on_epoch=True, 
                 logger=True,
                 prog_bar=True)
        
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        y_hat_sig = torch.sigmoid(y_hat)
        
        # loss
        loss = self.loss_fn(y_hat, y)
        self.log('val_loss', 
                 loss, 
                 on_step=False, 
                 on_epoch=True, 
                 logger=True, 
                 prog_bar=True)
        
        # accuracy
        self.val_f1(y_hat_sig, y)
        accuracy = self.val_f1.compute()
        self.log('val_f1', 
                 accuracy, 
                 on_step=False, 
                 on_epoch=True, 
                 logger=True, 
                 prog_bar=True)

        # pulse recall
        self.val_pulse_recall(y_hat, y)
        pulse_recall = self.val_pulse_recall.compute()
        self.log("val_pulse_recall",
                 pulse_recall,
                 on_step=False, 
                 on_epoch=True, 
                 logger=True,
                 prog_bar=True)
        
        
        # frb accuracy
        frb_accuracy = pulse_recall * accuracy
        self.log('val_frb_accuracy', 
                 frb_accuracy, 
                 on_step=False, 
                 on_epoch=True, 
                 logger=True,
                 prog_bar=True)

    def test_step(self, batch, batch_idx):
        # Implement the test step if needed
        pass

In [36]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = TimmLightningModel(
    model_name=CFG.model_name,
    num_classes=CFG.num_classes,
    learning_rate=CFG.learning_rate,
)

saved_state_dict = torch.load(CFG.checkpoint_path)
device = torch.device(device)
model.load_state_dict(saved_state_dict)
# model.to(device)
model.eval()
model

TimmLightningModel(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (drop_block): Identity()
        (act1): ReLU(inplace=True)
        (aa): Identity()
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): ReLU(inplace=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padd

In [20]:
model.device

device(type='cpu')

In [21]:
image_mean = [0.485, 0.456, 0.406]
image_std = [0.229, 0.224, 0.225]

transform = Compose([
    Resize((224, 224)),
    ToTensor(),
    Normalize(mean=image_mean, std=image_std),
])

def preprocess_image(sub_signal, transform):
    sub_signal = Image.fromarray(sub_signal, "L").convert("RGB")
    return transform(sub_signal).unsqueeze(0)

In [22]:
test = np.zeros((256, 256))
test = preprocess_image(test, transform=transform)
test.shape

torch.Size([1, 3, 224, 224])

In [23]:
prob = model.predict(test)
print(prob)

tensor([[2.3921e-05, 0.0000e+00, 1.0000e+00]])


In [24]:
prob.cpu()

tensor([[2.3921e-05, 0.0000e+00, 1.0000e+00]])

In [26]:
prob.cpu().numpy()[0] > 0.5

array([False, False,  True])

In [28]:
from tqdm import tqdm

In [31]:
%%time
list_data_prob = []
for _ in tqdm(list(range(10500))):
    test = np.zeros((256, 256))
    test = preprocess_image(test, transform=transform)
    prob = model.predict(test)
    list_data_prob.append(prob)
    # print(prob)

100%|██████████████████████████████████████████████████████████████████████████████████| 10500/10500 [01:39<00:00, 105.62it/s]

CPU times: user 26min 28s, sys: 624 ms, total: 26min 29s
Wall time: 1min 39s





In [32]:
list_data_prob

[tensor([[2.3921e-05, 0.0000e+00, 1.0000e+00]]),
 tensor([[2.3921e-05, 0.0000e+00, 1.0000e+00]]),
 tensor([[2.3921e-05, 0.0000e+00, 1.0000e+00]]),
 tensor([[2.3921e-05, 0.0000e+00, 1.0000e+00]]),
 tensor([[2.3921e-05, 0.0000e+00, 1.0000e+00]]),
 tensor([[2.3921e-05, 0.0000e+00, 1.0000e+00]]),
 tensor([[2.3921e-05, 0.0000e+00, 1.0000e+00]]),
 tensor([[2.3921e-05, 0.0000e+00, 1.0000e+00]]),
 tensor([[2.3921e-05, 0.0000e+00, 1.0000e+00]]),
 tensor([[2.3921e-05, 0.0000e+00, 1.0000e+00]]),
 tensor([[2.3921e-05, 0.0000e+00, 1.0000e+00]]),
 tensor([[2.3921e-05, 0.0000e+00, 1.0000e+00]]),
 tensor([[2.3921e-05, 0.0000e+00, 1.0000e+00]]),
 tensor([[2.3921e-05, 0.0000e+00, 1.0000e+00]]),
 tensor([[2.3921e-05, 0.0000e+00, 1.0000e+00]]),
 tensor([[2.3921e-05, 0.0000e+00, 1.0000e+00]]),
 tensor([[2.3921e-05, 0.0000e+00, 1.0000e+00]]),
 tensor([[2.3921e-05, 0.0000e+00, 1.0000e+00]]),
 tensor([[2.3921e-05, 0.0000e+00, 1.0000e+00]]),
 tensor([[2.3921e-05, 0.0000e+00, 1.0000e+00]]),
 tensor([[2.3921e-05

In [24]:
import torch

RuntimeError: THPDtypeType.tp_dict == nullptr INTERNAL ASSERT FAILED at "../torch/csrc/Dtype.cpp":135, please report a bug to PyTorch. 

In [33]:
checkpoint_path = "/home/superai052/super_workspace/model_suea/weights/gluon_resnet18_v1b/gluon_resnet18_v1b_e10.pth"
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

# For simple models, you can load the state dict directly
# model.load_state_dict(checkpoint)

# If the checkpoint contains other information (e.g., optimizer state, training progress), extract the model state dict
model_state_dict = checkpoint['model_state_dict']
model.load_state_dict(model_state_dict)


KeyError: 'model_state_dict'

In [None]:
# Set the model to evaluation mode
model.eval()

# Define a sample input tensor
dummy_input = torch.randn(1, 3, 224, 224) # Adjust the input size based on your model's input

# Export the model to ONNX format
onnx_model_path = "output_model.onnx"
torch.onnx.export(model, dummy_input, onnx_model_path, export_params=True, opset_version=11, do_constant_folding=True, input_names=['input'], output_names=['output'], dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})
