In [4]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from importlib import reload
import cv2
import numpy as np 
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

import albumentations
from albumentations import augmentations
import albumentations.pytorch

In [2]:
def feat_pool(feat: torch.Tensor, operation: str):
    """
    Fuses the per-patch feature vectors of an image into a single vector.

    Pooling reduces dimension 0 of ``feat``, which is assumed to index the
    image's patches (TODO confirm against the feature extractor).

    :param feat: Tensor of patch features; dimension 0 is pooled away
    :param operation: Either "max" or "mean" for the pooling operation
    :returns: A tensor equal to ``feat`` with dimension 0 reduced. The original
        notes describe this as the 256-D image representation (confirm).
    :raises ValueError: If ``operation`` is neither "max" nor "mean"
    """
    if operation == "max":
        # Tensor.max(dim=...) returns a (values, indices) named tuple; keep only
        # the values so both branches return a plain tensor of the same shape.
        return feat.max(dim=0).values
    if operation == "mean":
        return feat.mean(dim=0)
    # ValueError subclasses Exception, so existing `except Exception` callers still work.
    raise ValueError("The operation can be either mean or max")

In [3]:
# Fold id held out as the test set; every other fold is used for training.
TEST_FOLD = 1

df = pd.read_csv('all_features.csv')
is_test_fold = df["fold"] == TEST_FOLD
train_df = df[~is_test_fold]
test_df = df[is_test_fold]

In [5]:
class FeatSet(Dataset):
    def __init__(self, dataframe, mode, val_fold, test_fold):
        super().__init__()

        self.dataframe = dataframe
        self.mode = mode
        self.val_fold = val_fold
        self.test_fold = test_fold

        if self.mode == "train":
            rows = self.dataframe[~self.dataframe["fold"].isin([self.val_fold, self.test_fold])]
        elif self.mode == "val":
            rows = self.dataframe[self.dataframe["fold"] == self.val_fold]
        else:
            rows = self.dataframe[self.dataframe["fold"] == self.test_fold]

        print(
            "real:{}, fakes:{}, mode = {}".format(
                len(rows[rows["label"] == 0]), len(rows[rows["label"] == 1]), self.mode
            )
        )

        self.data = []

        for row in rows.values:
            _, label, _, _, _, feature = row

            feature_array = torch.load(feature)
            self.data.append((feature_array, label))

        np.random.shuffle(self.data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index: int):
        