In [1]:
import pynattas as pnas
import configparser
import torch
import pytorch_lightning as pl
from torchvision import transforms
from datasets.L0_thraws_classifier.dataset_weighted import SentinelDataset, SentinelDataModule
from datasets.wake_classifier.dataset import xAIWakesDataset_inf

In [2]:
# Load configuration from config.ini.
# ConfigParser.read() silently skips files it cannot open and returns the list of files
# it actually read, so check it here: otherwise a missing/misplaced config.ini only
# surfaces later as a confusing KeyError / NoSectionError.
CONFIG_PATH = 'config.ini'
config = configparser.ConfigParser()
if not config.read(CONFIG_PATH):
    raise FileNotFoundError(f"Could not read configuration file {CONFIG_PATH!r}")

# Torch settings
seed = config.getint(section='Computation', option='seed')
# Seeding is currently disabled; uncomment for reproducible runs.
#pl.seed_everything(seed=seed, workers=True)
# "medium" permits lower-precision (faster) float32 matmuls on supported GPUs;
# Lightning warns when this is left at the default.
torch.set_float32_matmul_precision("medium")
# NOTE(review): num_workers and accelerator are read but not used further in this notebook.
num_workers = config.getint(section='Computation', option='num_workers')
accelerator = config.get(section='Computation', option='accelerator')

# Dataset locations
csv_file = config['Dataset']['csv_path']
root_dir = config['Dataset']['data_path']

## IMAGE LOADER

In [3]:
""" # Read the input image
idx = -1
while idx < 1 or idx > 269:
    try:
        print("Valid indexes for xAIWakes are from 1 to 269.")
        idx = int(input("Input index: "))
        if idx < 1 or idx > 269:
            print("Invalid index. Please enter a value between 1 and 269.")
    except ValueError:
        print("Invalid input. Please enter a valid integer.") """

# Uncomment the line below, and comment the while loop, if you want to set a specific index (e.g., idx = 4)
idx = 12

In [4]:
# Resize + ToTensor pipeline for the (commented-out) SentinelDataset path.
# NOTE: xAIWakesDataset_inf below is built with transform=None, so this pipeline is
# NOT applied to the sample used for inference.
composed_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])
#dataset = SentinelDataset(
#    root_dir=root_dir,
#    transform=composed_transform,
#)

dataset = xAIWakesDataset_inf(
    root_dir=root_dir,
    csv_dir=csv_file,
    transform=None,
)

image, label = dataset[idx]  # Load one sample from the dataset
assert image.dim() == 3, f"expected a 3-D sample tensor, got shape {tuple(image.shape)}"
# No transform is applied, so the spatial size is whatever the dataset returns.
# NOTE(review): the sample appears to be laid out with channels on axis 1 (the batched
# tensor in the INFERENCE section has shape [1, 256, 4, 256] and is permuted to
# [1, 4, 256, 256]), so image.shape[0] would be a spatial size, not the channel count.
# Confirm against xAIWakesDataset_inf.
in_channels = image.shape[1]

# Give the tensor to the right device (CUDA when available, otherwise CPU).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
input_tensor = image.to(device)

## MODEL

In [5]:
# Architecture and default hyper-parameters are always read from config.ini
# (section 'NAS' for the architecture code, 'Search Space' for the defaults).
architecture_code = config['NAS']['architecture_code']
layers = pnas.functions.architecture_builder.parse_architecture_code(architecture_code)

# The search space stores the learning rate as log10(lr); convert it back.
log_lr = config.getfloat(section='Search Space', option='default_log_lr')
# NOTE(review): bs (batch size) is not used further in this single-image notebook.
bs = config.getint(section='Search Space', option='default_bs')
lr = 10 ** log_lr

In [6]:
# Turn the architecture code string into the layer specification passed to the model.
_parse_code = pnas.functions.architecture_builder.parse_architecture_code
parsed_layers = _parse_code(architecture_code)
del _parse_code  # keep the notebook namespace unchanged

In [7]:
# Model hyper-parameters. in_channels is hard-coded (overriding any value derived from
# the sample above) because it must match the checkpoint loaded in the next cell; the
# instantiated stem is Conv2d(4, 16, ...). num_classes=2 matches the 2 logits produced.
in_channels = 4
num_classes = 2
model = pnas.classes.GenericLightningNetwork(
    parsed_layers=parsed_layers,
    input_channels=in_channels,
    num_classes=num_classes,
    learning_rate=lr,  # only relevant for training; required by the constructor
)

In [8]:
# Trained weights to evaluate. Override via [Inference] checkpoint_path in config.ini;
# otherwise fall back to the run used originally. Earlier run kept for reference:
#   /media/warmachine/DBDISK/Andrea/DicDic/logs/tb_logs/checkpoints/OptimizedModel_2024-03-18_13-57-25/version_0/checkpoints/epoch=10-step=12386.ckpt
CHECKPOINT_PATH = config.get(
    'Inference',
    'checkpoint_path',
    fallback="/media/warmachine/DBDISK/Andrea/DicDic/logs/tb_logs/checkpoints/OptimizedModel_2024-04-24_11-53-18/version_0/checkpoints/epoch=5-step=246.ckpt",
)

# map_location='cpu' lets a GPU-saved checkpoint load on machines without CUDA; the
# model is still on the CPU here and is moved to the target device afterwards.
# torch.load unpickles the file: only load checkpoints from trusted sources.
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')
model.load_state_dict(checkpoint["state_dict"])
model.eval()

GenericLightningNetwork(
  (model): GenericNetwork(
    (layers): ModuleList(
      (0): ConvBnAct(
        (0): Conv2d(4, 16, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1))
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): GELU()
      )
      (1): MaxPool(
        (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (2): CSPMBConvBlock(
        (main_path): Sequential(
          (0): MBConv(
            (steps): Sequential(
              (0): ConvBnAct(
                (0): Conv2d(8, 40, kernel_size=(1, 1), stride=(1, 1))
                (1): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (2): ReLU()
              )
              (1): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=40)
              (2): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (3): ReLU()
   

In [9]:
# Move the model to the same device chosen for the input: CUDA when available, otherwise
# CPU (an unconditional model.cuda() fails on CPU-only machines).
model.to('cuda' if torch.cuda.is_available() else 'cpu')

GenericLightningNetwork(
  (model): GenericNetwork(
    (layers): ModuleList(
      (0): ConvBnAct(
        (0): Conv2d(4, 16, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1))
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): GELU()
      )
      (1): MaxPool(
        (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (2): CSPMBConvBlock(
        (main_path): Sequential(
          (0): MBConv(
            (steps): Sequential(
              (0): ConvBnAct(
                (0): Conv2d(8, 40, kernel_size=(1, 1), stride=(1, 1))
                (1): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (2): ReLU()
              )
              (1): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=40)
              (2): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (3): ReLU()
   

## INFERENCE

In [10]:
# Add a leading batch dimension. Guarded so that re-running this cell does not keep
# prepending new dimensions.
if input_tensor.dim() == 3:
    input_tensor = input_tensor.unsqueeze(0)
input_tensor.shape

torch.Size([1, 256, 4, 256])

In [11]:
# The batched sample has its channels on axis 2; swap axes 1 and 2 to get the
# (N, C, H, W) layout expected by the Conv2d stem. Guarded on the channel axis so that
# re-running this cell does not swap the axes back.
if input_tensor.shape[1] != in_channels and input_tensor.shape[2] == in_channels:
    input_tensor = input_tensor.permute(0, 2, 1, 3)
input_tensor.shape

torch.Size([1, 4, 256, 256])

In [12]:
# Pure inference: disable autograd so no computation graph is built for the forward pass.
with torch.no_grad():
    inferred = model(input_tensor)
inferred

tensor([[ 4.4528, -1.8159]], device='cuda:0', grad_fn=<AddmmBackward0>)