In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import importlib

#system
from pathlib import Path
import time

#ai
import torch
from torch import nn
import torchvision
from torch.utils.data import TensorDataset, DataLoader

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

import src
importlib.reload(src)

import src.utils.metrics as metrics
import src.utils.gau as gaussian
import src.utils.plots as plots

from src.models import AE, RNN, LinearClassifier, get_latent_features, rnn_loss_function, ae_loss_function
from src.training import EarlyStopping, train_ae, train_rnn, train_lp, augment_data

from src.final_model import FinalModel, combined_loss_function
from src.final_training import train_final_model as train_fm

importlib.reload(src.final_model)
importlib.reload(src.final_training)

  from .autonotebook import tqdm as notebook_tqdm


<module 'src.final_training' from 'c:\\Users\\pmarc\\Universidad\\inteli\\ZTF_alert_8a\\src\\final_training.py'>

In [2]:
# Load the pickled stamps dataset: a dict of splits ('Train', 'Validation',
# 'Test'), each itself a dict of array-likes. Only unpickle trusted files —
# unpickling can execute arbitrary code.
stamps_dataset = pd.read_pickle(filepath_or_buffer='data/5stamps_dataset.pkl')

def rename_labels(dataset, old_value, new_value):
    """Rename a key inside every split of ``dataset``, in place.

    For each inner dict ``dataset[split]`` that has ``old_value`` as a key, the
    entry is popped and re-inserted under ``new_value``. As a result it moves to
    the end of that dict's insertion order and overwrites any existing
    ``new_value`` entry. Inner dicts without ``old_value`` are left untouched.

    Parameters
    ----------
    dataset : dict
        Mapping of split name -> dict of arrays.
    old_value, new_value : hashable
        Current and desired key name.

    Returns
    -------
    None
    """
    for split_data in dataset.values():
        if old_value not in split_data:
            continue
        split_data[new_value] = split_data.pop(old_value)

# Normalise the split dicts to the key names used by the cells below:
# 'labels' -> 'class' and 'science' -> 'images'.
for old_key, new_key in (('labels', 'class'), ('science', 'images')):
    rename_labels(stamps_dataset, old_key, new_key)

In [3]:
def build_stamp_sequences(split_data):
    """Stack the template, science and difference stamps of one split.

    Parameters
    ----------
    split_data : dict
        One split of ``stamps_dataset`` holding array-likes under ``'template'``,
        ``'images'`` and ``'difference'``. ``'template'`` has one stamp per
        sample, shape (samples, H, W). ``'images'`` and ``'difference'`` have
        one stamp per frame, shape (samples, frames, H, W).

    Returns
    -------
    torch.Tensor
        float32 tensor of shape (samples, frames, 3, H, W). Channel 0 is the
        template (repeated over frames), 1 the science image and 2 the
        difference.
    """
    template = torch.tensor(split_data['template'], dtype=torch.float32)
    image = torch.tensor(split_data['images'], dtype=torch.float32)
    difference = torch.tensor(split_data['difference'], dtype=torch.float32)

    # Take the frame count from the data rather than hard-coding 5.
    n_frames = image.shape[1]
    # expand() is a view; torch.stack below materialises the only copy.
    template = template.unsqueeze(1).expand(-1, n_frames, -1, -1)
    # Stack along a new axis 2: (samples, frames, 3, H, W).
    return torch.stack((template, image, difference), dim=2)


# Per-split labels, kept as float32.
# NOTE(review): if combined_loss_function uses a classification loss such as
# nn.CrossEntropyLoss, targets must be int64 — confirm before changing the dtype.
train_class_0 = torch.tensor(stamps_dataset['Train']['class'], dtype=torch.float32)
validation_class_0 = torch.tensor(stamps_dataset['Validation']['class'], dtype=torch.float32)
test_class_0 = torch.tensor(stamps_dataset['Test']['class'], dtype=torch.float32)

# Model inputs: (samples, frames, 3, H, W) stamp tensors paired with labels.
# Each input tensor is built in one step, so no per-stamp intermediate copies
# stay alive in the kernel.
train_dataset = TensorDataset(build_stamp_sequences(stamps_dataset['Train']), train_class_0)
validation_dataset = TensorDataset(build_stamp_sequences(stamps_dataset['Validation']), validation_class_0)
test_dataset = TensorDataset(build_stamp_sequences(stamps_dataset['Test']), test_class_0)

# Class balance of the training split, plus a class-balanced subset that keeps
# `min_samples` examples of every class (the size of the rarest class).
from torch.utils.data import Subset

train_labels = train_dataset.tensors[1]
unique, counts = torch.unique(train_labels, return_counts=True)
# .tolist() yields plain Python numbers, so the dict prints the same under NumPy 1 and 2.
print(dict(zip(unique.tolist(), counts.tolist())))

# Indices of every sample of each class, keyed by the label value as a Python
# float. 0-d tensor keys hash by identity and could not be looked up by value.
class_indices = {cls.item(): (train_labels == cls).nonzero(as_tuple=True)[0] for cls in unique}

# Size of the rarest class.
min_samples = min(len(indices) for indices in class_indices.values())

# NOTE(review): this keeps the *first* min_samples indices per class, not a
# random draw. If the split is ordered (e.g. by time), the subset is biased —
# consider a seeded torch.randperm before slicing.
balanced_indices = torch.cat([indices[:min_samples] for indices in class_indices.values()])
balanced_train_dataset = Subset(train_dataset, balanced_indices)

# Verify the balance: every class must now have exactly min_samples examples.
balanced_counts = torch.unique(train_labels[balanced_indices], return_counts=True)
assert bool((balanced_counts[1] == min_samples).all()), balanced_counts
print(dict(zip(balanced_counts[0].tolist(), balanced_counts[1].tolist())))

{0.0: 28092, 1.0: 1516, 2.0: 93619}


In [7]:
# --- RNN parameters ---
rnn_type = 'LSTM'
hidden_dim = 128
num_layers = 2

# --- Autoencoder parameters ---
latent_dim = 50
n_channels = 3  # presumably the 3 stacked stamps (template, image, difference) — confirm
n_classes = 3

# --- Training parameters ---
# NOTE: `combined_loss_function` is the name bound by the first cell's
# `from src.final_model import ...`. Later `importlib.reload(src.final_model)`
# calls do not rebind it, so edits to that function are not picked up here
# unless the import cell is re-run.
loss_func = combined_loss_function
alpha = 0.5

max_epochs = 100
max_time = 600  # minutes
lr = 0.666e-3
batch_size = 128

random_sampler = True
use_gpu = True
augmentation = False
early_stop = 5
num_cpu = 16

# Seed the RNGs so weight initialisation and random sampling are repeatable.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

In [11]:
# Re-import and reload the project modules so that edits to
# src/final_model.py and src/final_training.py are picked up without
# restarting the kernel. The cell's last expression echoes the reloaded module.
# Note: importlib.reload only refreshes attributes looked up through the module
# (e.g. src.final_model.FinalModel). Names bound earlier with `from ... import`
# in the first cell (FinalModel, combined_loss_function, train_fm) still refer
# to the pre-reload objects.
import src.final_model
import src.final_training

importlib.reload(src.final_model)
importlib.reload(src.final_training)

<module 'src.final_training' from 'c:\\Users\\pmarc\\Universidad\\inteli\\ZTF_alert_8a\\src\\final_training.py'>

In [13]:
# Build and train the final model on the full (unbalanced) training split, then
# save its weights.
# NOTE(review): the last recorded run of this cell raised
#   RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d,
#   but got input of size: [128, 5, 3, 21, 21]
# i.e. batches shaped (batch, frames, channels, H, W) reach a Conv2d unchanged.
# FinalModel (src/final_model.py, not shown) presumably needs to fold the frame
# axis into the batch before its conv encoder — TODO confirm and fix there.
# NOTE(review): train_final_model is called positionally; confirm this argument
# order (notably augmentation / early_stop / use_gpu) against its signature.
modelo_final = src.final_model.FinalModel(latent_dim, n_channels, rnn_type, hidden_dim, num_layers, n_classes, name='final', description='entrenado con dataset completo, loss function combined.')

curves = src.final_training.train_final_model(modelo_final,
                  train_dataset,
                  validation_dataset,
                  loss_func,
                  alpha,
                  max_epochs,
                  max_time,
                  batch_size,
                  lr,
                  random_sampler,
                  augmentation,
                  early_stop,
                  use_gpu,
                  num_cpu)

# Create the output directory first so torch.save does not fail on a fresh checkout.
model_path = Path('models/modelo_final_0_1.pth')
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(modelo_final.state_dict(), model_path)


Setup finishied. Starting training...


RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [128, 5, 3, 21, 21]