In [1]:
import os

import pandas as pd
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, random_split
from torchvision.datasets import ImageFolder
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights

In [2]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

Using device: cuda


In [3]:
class CustomCSVDataset(Dataset):
    """Row-wise dataset over a tabular frame that yields ``(id, features, targets)``.

    ``targets`` are the ``<name>_mean`` columns for the names in ``self.targets``.
    ``features`` are every remaining column except ``id`` and the matching
    ``<name>_sd`` columns.
    """

    def __init__(self, tabular_data, transform=None):
        """
        Args:
            tabular_data (pandas.DataFrame): Frame with an ``id`` column, one
                ``<target>_mean`` and one ``<target>_sd`` column per target, and
                numeric feature columns.
            transform (callable, optional): Applied to the feature tensor of each sample.
        """
        self.data_frame = tabular_data
        self.transform = transform
        self.targets = ["X4", "X11", "X18", "X26", "X50", "X3112"]
        # The column lists are fixed for the dataset; build them once, not per item.
        self._target_cols = [f"{target}_mean" for target in self.targets]
        self._drop_cols = ["id"] + self._target_cols + [f"{target}_sd" for target in self.targets]

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        """Return ``(id, features, targets)`` for positional row ``idx``; tensors are float32."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = self.data_frame.iloc[idx]

        sample_id = sample['id']
        # Convert through NumPy explicitly rather than handing torch a
        # label-indexed pandas Series (matches what PGLSDataset does).
        targets = torch.tensor(sample[self._target_cols].to_numpy(dtype="float32"))
        features = torch.tensor(sample.drop(self._drop_cols).to_numpy(dtype="float32"))

        if self.transform:
            features = self.transform(features)

        return sample_id, features, targets

In [4]:
class PGLSDataset(Dataset):
    """Pairs each tabular row with its image and yields ``(image, features, targets)``."""

    def __init__(self, tabular_data, image_folder, transform_csv=None, image_subdir="0", image_ext=".jpeg"):
        """
        Args:
            tabular_data (pandas.DataFrame): Frame with an ``id`` column, one
                ``<target>_mean`` and one ``<target>_sd`` column per target, and
                numeric feature columns.
            image_folder: Object exposing ``root``, ``loader`` and ``transform``
                (e.g. a ``torchvision.datasets.ImageFolder``). Only those three
                attributes are used; its own sample index is ignored.
            transform_csv (callable, optional): Applied to the feature tensor of each sample.
            image_subdir (str): Sub-directory of ``image_folder.root`` that holds the
                images (default ``"0"``).
            image_ext (str): Image file extension including the dot (default ``".jpeg"``).
        """
        self.data_frame = tabular_data
        self.image_folder = image_folder
        self.transform_csv = transform_csv
        self.image_subdir = image_subdir
        self.image_ext = image_ext
        self.targets = ["X4", "X11", "X18", "X26", "X50", "X3112"]
        # The column lists are fixed for the dataset; build them once, not per item.
        self._target_cols = [f"{target}_mean" for target in self.targets]
        self._drop_cols = ["id"] + self._target_cols + [f"{target}_sd" for target in self.targets]

    def __len__(self):
        return len(self.data_frame)

    def _image_path(self, sample_id):
        """Return ``<root>/<image_subdir>/<sample_id><image_ext>``."""
        return os.path.join(self.image_folder.root, self.image_subdir, f"{sample_id}{self.image_ext}")

    def __getitem__(self, idx):
        """Return ``(image, features, targets)`` for positional row ``idx``; tensors are float32."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = self.data_frame.iloc[idx]
        # A numeric row may hold `id` as a float; file names use the integer form.
        sample_id = int(sample['id'])
        image = self.image_folder.loader(self._image_path(sample_id))

        if self.image_folder.transform is not None:
            image = self.image_folder.transform(image)

        targets = torch.tensor(sample[self._target_cols].to_numpy(dtype="float32"))
        features = torch.tensor(sample.drop(self._drop_cols).to_numpy(dtype="float32"))

        if self.transform_csv:
            features = self.transform_csv(features)

        return image, features, targets

In [5]:
# Per-channel statistics of the ImageNet data the torchvision EfficientNet-B0
# weights were trained on; the pretrained backbone expects inputs normalized this way.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.ToTensor(),  # PIL image (H x W x C) -> float tensor (C x H x W) in [0.0, 1.0]
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

batch_size = 32

In [6]:
train_images_path = 'data/train_images'
train_csv_path = 'data/train.csv'
# Fixed seed so the train/validation split and loader shuffling are reproducible.
SEED = 42
torch.manual_seed(SEED)

tabular_data = pd.read_csv(train_csv_path)
targets = ["X4", "X11", "X18", "X26", "X50", "X3112"]
target_mean_cols = [f"{target}_mean" for target in targets]
target_sd_cols = [f"{target}_sd" for target in targets]

# Outlier removal: for each target in turn, keep rows with 0 < value < 99th percentile.
# Each quantile is computed on the rows that survived the previous targets' filters.
upper_values = {}
for target in targets:
    col = f"{target}_mean"
    upper_values[target] = tabular_data[col].quantile(0.99)
    tabular_data = tabular_data[(tabular_data[col] < upper_values[target]) & (tabular_data[col] > 0)]

# Take an independent copy so the column assignments below don't write into a filtered slice.
tabular_data = tabular_data.copy()

# Z-score the targets; keep the statistics so predictions can be mapped back later.
original_means = tabular_data[target_mean_cols].mean()
original_stds = tabular_data[target_mean_cols].std()
tabular_data[target_mean_cols] = (tabular_data[target_mean_cols] - original_means) / original_stds

# Min-max scale the input features only. Target columns are named `<target>_mean`
# and `<target>_sd` (not the bare target names), so list them explicitly; otherwise
# the z-scored targets above would be rescaled again.
non_feature_cols = {"id", *target_mean_cols, *target_sd_cols}
for column in tabular_data.columns:
    if column in non_feature_cols:
        continue
    min_val = tabular_data[column].min()
    max_val = tabular_data[column].max()
    span = max_val - min_val
    # A constant column would divide by zero (-> NaN); map it to 0 instead.
    tabular_data[column] = (tabular_data[column] - min_val) / span if span != 0 else 0.0

train_csv_dataset = CustomCSVDataset(tabular_data=tabular_data, transform=None)
train_images_dataset = ImageFolder(root=train_images_path, transform=transform)

train_image_csv_dataset = PGLSDataset(tabular_data=tabular_data, image_folder=train_images_dataset, transform_csv=None)
n_train = int(0.8 * len(train_image_csv_dataset))
n_val = len(train_image_csv_dataset) - n_train
train, val = random_split(
    train_image_csv_dataset,
    [n_train, n_val],
    generator=torch.Generator().manual_seed(SEED),
)

train_data_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
# Validation order doesn't affect the metric; keep it deterministic.
val_data_loader = DataLoader(val, batch_size=batch_size, shuffle=False)

In [7]:
class PGLSModel(torch.nn.Module):
    """Late-fusion regressor: concatenates image and tabular embeddings, then applies one linear layer.

    Args:
        image_model: Module mapping an image batch to a ``(batch, image_dim)`` tensor.
        tabular_model: Module mapping a feature batch to a ``(batch, tabular_dim)`` tensor.
        image_dim: Width of ``image_model``'s output (default 1000).
        tabular_dim: Width of ``tabular_model``'s output (default 100).
        num_outputs: Number of regression targets (default 6).
    """

    def __init__(self, image_model, tabular_model, image_dim=1000, tabular_dim=100, num_outputs=6):
        super().__init__()
        self.image_model = image_model
        self.tabular_model = tabular_model
        self.fc = torch.nn.Linear(image_dim + tabular_dim, num_outputs)

    def forward(self, image, tabular):
        """Return ``(batch, num_outputs)`` predictions for a batch of images and feature rows."""
        image_features = self.image_model(image)
        tabular_features = self.tabular_model(tabular)
        # Join along the feature axis; both branches must produce the same batch size.
        features = torch.cat((image_features, tabular_features), dim=1)
        return self.fc(features)


class SimpleTabularModel(torch.nn.Module):
    """Two-layer ReLU MLP that embeds a flat feature vector.

    Args:
        input_data_len: Number of input features per row.
        output_dim: Width of the embedding (default 100, matching ``PGLSModel``'s
            default ``tabular_dim``).
        hidden_multiplier: Hidden width as a multiple of ``input_data_len`` (default 4).
    """

    def __init__(self, input_data_len, output_dim=100, hidden_multiplier=4):
        super().__init__()
        hidden_dim = input_data_len * hidden_multiplier
        self.fc1 = torch.nn.Linear(input_data_len, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map ``(batch, input_data_len)`` features to a ``(batch, output_dim)`` embedding."""
        x = torch.nn.functional.relu(self.fc1(x))
        # The output is ReLU'd too, so the embedding is non-negative.
        return torch.nn.functional.relu(self.fc2(x))

In [8]:
# `weights` must be a member of the weights enum (here the recommended DEFAULT),
# not the enum class itself.
effnet = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)

# The dataset drops `id` plus each target's `_mean` and `_sd` columns from the
# feature vector, so count exactly the columns that remain.
_target_names = train_image_csv_dataset.targets
_excluded_cols = {"id"} | {f"{t}_mean" for t in _target_names} | {f"{t}_sd" for t in _target_names}
n_tabular_features = sum(column not in _excluded_cols for column in tabular_data.columns)
assert n_tabular_features > 0, "no tabular feature columns left after dropping id/targets"

tabular_model = SimpleTabularModel(input_data_len=n_tabular_features)
model = PGLSModel(effnet, tabular_model)



In [9]:
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3

# Run on the device selected at the top of the notebook (GPU when available).
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(NUM_EPOCHS):
    # --- training pass ---
    model.train()
    train_loss_sum, train_batches = 0.0, 0
    # Named `batch_targets` so the module-level `targets` list of names is not overwritten.
    for image, features, batch_targets in train_data_loader:
        image = image.to(device)
        features = features.to(device)
        batch_targets = batch_targets.to(device)
        optimizer.zero_grad()
        outputs = model(image, features)
        loss = torch.nn.functional.mse_loss(outputs, batch_targets)
        loss.backward()
        optimizer.step()
        train_loss_sum += loss.item()
        train_batches += 1

    # --- validation pass (no gradient tracking) ---
    model.eval()
    val_loss_sum, val_batches = 0.0, 0
    with torch.no_grad():
        for image, features, batch_targets in val_data_loader:
            image = image.to(device)
            features = features.to(device)
            batch_targets = batch_targets.to(device)
            outputs = model(image, features)
            val_loss_sum += torch.nn.functional.mse_loss(outputs, batch_targets).item()
            val_batches += 1

    # Mean of per-batch MSE values; max(..., 1) guards an empty loader.
    train_loss = train_loss_sum / max(train_batches, 1)
    val_loss = val_loss_sum / max(val_batches, 1)
    print(f"epoch {epoch + 1}/{NUM_EPOCHS}  train_mse={train_loss:.4f}  val_mse={val_loss:.4f}")


: 