In [None]:
import os

import torch
import torchvision.models as models
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets.folder import pil_loader

# Directory prefix for the image files; each row's "filename" is appended to it.
root_path = "data/nahimova_f/"

In [None]:
class CustomDataset(Dataset):
    """Dataset that yields a cropped region of an image file plus its label.

    Args:
        data: table accessed positionally with ``.iloc``; each row must provide
            ``filename``, ``x_from``, ``y_from``, ``width`` and ``height``
            (presumably a pandas DataFrame — confirm against the caller).
        targets: labels accessed positionally with ``.iloc``; position ``i``
            is the label for row ``i`` of ``data``.
        root_path: directory the row's ``filename`` is resolved against.
        transform: optional callable applied to the cropped PIL image.
    """

    def __init__(self, data, targets, root_path, transform=None):
        self.data = data
        self.targets = targets
        self.root_path = root_path
        self.transform = transform

    def __len__(self):
        """Number of samples, i.e. the number of rows in ``data``."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(cropped_image, target)`` for the sample at position ``index``."""
        row = self.data.iloc[index]
        target = self.targets.iloc[index]

        # Load the image and cut out the bounding box described by the row.
        # os.path.join also works when root_path has no trailing separator.
        image_path = os.path.join(self.root_path, row["filename"])
        # pil_loader (torchvision) opens the file inside a `with` block and
        # converts to RGB, so grayscale/RGBA/palette files still yield the
        # 3 channels expected by a 3-channel Normalize downstream.
        image = pil_loader(image_path)

        left, top = row["x_from"], row["y_from"]
        cropped_image = image.crop(
            (left, top, left + row["width"], top + row["height"])
        )

        if self.transform is not None:
            cropped_image = self.transform(cropped_image)

        return cropped_image, target

In [None]:
# Preprocessing applied to every crop before it is fed to the ResNet:
# resize to a fixed 224x224, convert to a float tensor, then normalise each
# channel with the commonly used ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),  # input size the ResNet expects
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),  # normalisation for ResNet
    ]
)

In [None]:
# NOTE(review): x_test / y_test are not defined in any visible cell — they must
# come from an earlier (possibly deleted) cell; confirm before Restart & Run All.
test_dataset = CustomDataset(
    data=x_test,
    targets=y_test,
    root_path=root_path,
    transform=transform,
)
# One sample per batch, in the original order (no shuffling).
test_dataloader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

In [None]:
# Number of classes produced by the classifier head.
NUM_CLASSES = 155

# ResNet-50 created without ImageNet weights; its final fully connected layer
# is replaced so the network outputs NUM_CLASSES scores per image.
# NOTE(review): `pretrained=` is deprecated in newer torchvision (>= 0.13) in
# favour of `weights=None`; kept here for compatibility with older versions.
# TODO: no trained state_dict is loaded in any visible cell, so predictions
# come from a randomly initialised network — load the fine-tuned weights here.
resnet = models.resnet50(pretrained=False)
resnet.fc = torch.nn.Linear(resnet.fc.in_features, NUM_CLASSES)

device = torch.device("cpu")
# Assigning the result (instead of leaving `resnet.to(device)` as the last
# expression) keeps the cell from printing the whole model repr. eval() puts
# BatchNorm/Dropout into inference mode for the prediction cells below.
resnet = resnet.to(device).eval()

In [None]:
def model_predictions(model, loader, device=None):
    """Run ``model`` over every batch of ``loader`` and collect labels.

    Args:
        model: ``torch.nn.Module`` returning per-class scores; the last
            dimension of its output is treated as the class dimension.
        loader: iterable of ``(inputs, targets)`` tensor batches, e.g. a
            ``DataLoader``.
        device: device to run on. If given, the model is moved there; if
            ``None`` (default), the device of the model's first parameter is
            used (CPU for a model without parameters).

    Returns:
        ``(real_target, pred_target)`` as 1-D NumPy arrays: the concatenated
        targets and the arg-max class index for every sample. Both arrays are
        empty when ``loader`` yields no batches.

    The model runs in eval mode under ``torch.no_grad()``. The previous
    train/eval flag of every submodule is restored afterwards.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    else:
        model.to(device)

    # Remember each submodule's mode so mixed train/eval states survive the call.
    previous_modes = [(module, module.training) for module in model.modules()]
    model.eval()  # BatchNorm/Dropout must not use batch statistics at inference

    real_batches, pred_batches = [], []
    try:
        with torch.no_grad():  # no autograd graph kept alive across batches
            for X, Y in loader:
                logits = model(X.to(device))
                # softmax is monotonic, so arg-max over logits gives the same class.
                pred_batches.append(logits.argmax(dim=-1).cpu())
                real_batches.append(Y.detach().cpu())
    finally:
        for module, was_training in previous_modes:
            module.training = was_training

    if not pred_batches:
        # torch.cat([]) would raise; return empty label arrays instead.
        return (
            torch.empty(0, dtype=torch.long).numpy(),
            torch.empty(0, dtype=torch.long).numpy(),
        )

    return torch.cat(real_batches).numpy(), torch.cat(pred_batches).numpy()