In [23]:
import torch
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset
import pandas as pd
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.transforms import v2
from torchvision.io import ImageReadMode
from torch import nn
import numpy as np
from sklearn.metrics import r2_score
import sales_prediction.sales_prediction as sp

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) #presi da timm
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

class SalesDataset(Dataset):
    def __init__(self, references, tabular_data, descriptions, img_path, target, transform=None, target_transform=None):
        self.img_ref = references 
        self.tabular = tabular_data
        self.descriptions = descriptions
        self.target = target
        self.img_path = img_path
        
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_ref)
    
    def __getitem__(self, idx):
        image = read_image(self.img_path + self.img_ref[idx], ImageReadMode.RGB)
        
        tabular_row = torch.from_numpy(self.tabular.iloc[idx].values).float()

        desc_tensor = self.descriptions[self.img_ref[idx]]
        
        target = self.target[idx]
        
        if self.transform: 
            image = self.transform(image)
        if self.target_transform:
            target = self.target_transform(target)
        
        return image, tabular_row, desc_tensor, target 

def getDataset(references, tabular_data, descriptions, target, img_path, batch_size, proportion):
    transform_img = v2.Compose([
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
    ])
    
    dataset = SalesDataset(references, tabular_data, descriptions, img_path, target, transform_img, None)
    
    dataset, _ = random_split(dataset, [proportion, 1 - proportion])

    validation_dataloader = DataLoader(dataset, batch_size=batch_size)

    return validation_dataloader

def validation_loop(dataloader, model, loss_fn, device):
    model.eval()
    num_batches = len(dataloader)
    mae = nn.L1Loss()
    avg_mse, avg_mae = 0, 0
    label, prediction = np.array([]), np.array([])
    i = 0
    with torch.no_grad():
        for img, tab, desc, y in dataloader:
            img, desc, tab, y = img.to(device), desc.to(device), tab.to(device), y.to(device)
            pred = model(img, tab, desc)
            print(pred)
            y_np, pred_np = y.cpu().detach().numpy(), pred.cpu().detach().numpy()
            label = np.append(y_np, label)
            prediction = np.append(pred_np, prediction)
            
            avg_mse += loss_fn(pred.squeeze(), y.float()).item()
            avg_mae += mae(pred.squeeze(), y.float()).item()
            
            if i % 500 == 0:
                print(i)
            i= i+1
            
    avg_mse /= num_batches
    avg_mae /= num_batches
    r2 = r2_score(label, prediction)
    bias = np.mean(prediction - label)
    
    print(f"Validation Error: \n Avg MSE: {avg_mse:>8f} \n Avg MAE: {avg_mae:>8f} \n R2: {r2:>8f}\n" + 
          f" Bias: {bias:>8f}\n")
    
    return avg_mse, avg_mae, r2, bias

def get_tabular(tabular_path, desc_path):
    data, references, target = get_data(tabular_path)
    #tokenized_desc = word_embedding(descriptions)
    descrizioni = torch.load(desc_path)
    return data, references, descrizioni, target

def get_data(path):
    data = pd.read_csv(path)
    data = data.sample(8000)
    references = data['IdProdotto'].values
    target = data['Quantity'].values
    
    data = data.drop(columns = ['Descrizione', 'IdProdotto', 'Quantity'], axis='columns')
  
    columns = ['CodiceColore', 'PianoTaglia', 'WaveCode', 'AstronomicalSeasonExternalID', 'SalesSeasonDescription']
    for col in columns:
        encoded_labels, _ = pd.factorize(data[col])
        data[col] = encoded_labels
    
    for col in data.columns: #normalizzo tutto tranne la quantità
        if col != 'Quantity':
            val = data[col]
            if val.std() != 0:
                normalized_labels = (val - val.mean())/val.std()
                data[col] = normalized_labels

    return data, references, target

img_path = 'C:\\Users\\GRVRLD00P\\Documents\\Progetto ORS\\Dati\\ResizedImages\\'
desc_path= 'C:\\Users\\GRVRLD00P\\Documents\\Progetto ORS\\Dati\\descrizioni\\descrizioni.pt'
desc_tot_path = 'C:\\Users\\GRVRLD00P\\Documents\\Progetto ORS\\Dati\\descrizioni\\descrizioni_tot.pt'

data_week = 'C:\\Users\\GRVRLD00P\\Documents\\Progetto ORS\\Dati\\front_img_week.csv'
data_month = 'C:\\Users\\GRVRLD00P\\Documents\\Progetto ORS\\Dati\\front_img_month.csv'
data_season = 'C:\\Users\\GRVRLD00P\\Documents\\Progetto ORS\\Dati\\front_img_season.csv'

noneg_week = 'C:\\Users\\GRVRLD00P\\Documents\\Progetto ORS\\Dati\\nonegozio_week.csv'
noneg_month = 'C:\\Users\\GRVRLD00P\\Documents\\Progetto ORS\\Dati\\nonegozio_month.csv'
noneg_season = 'C:\\Users\\GRVRLD00P\\Documents\\Progetto ORS\\Dati\\nonegozio_season.csv'

path = 'C:\\Users\\GRVRLD00P\\Documents\\Progetto ORS\\results\\month_negozi\\weights.pt'


device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Sto usando {device}")

modello = sp.create_model(0,1)
modello = torch.jit.load(path)
modello.eval()

data, references, descriptions, target = get_tabular(data_month, desc_path)
datashuffle = data
datashuffle['LocationId'] = datashuffle['LocationId'].sample(frac=1).reset_index(drop=True)
val = getDataset(references, data, descriptions, target, img_path, 128, 1)
validation_loop(val, modello, nn.MSELoss(), device)
print('Shuffle\n')
val = getDataset(references, datashuffle, descriptions, target, img_path, 128, 1)
validation_loop(val, modello, nn.MSELoss(), device)


Sto usando cuda


  descrizioni = torch.load(desc_path)


tensor([[   nan],
        [   nan],
        [   nan],
        [2.2897],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [2.4953],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [1.8267],
        [   nan],
        [ 

KeyboardInterrupt: 