In [18]:
import torch
import torchvision.models as models
from torch.hub import load_state_dict_from_url
from torchvision.models import EfficientNet_B2_Weights

In [6]:
# Load the small preprocessed dataset and peek at its first record.
# NOTE(review): torch.load unpickles the file, so only point it at trusted data.
data = torch.load("preprocessed_images_with_metadata_and_target_small.pt.nosync")
train_record = data[0]

# A record is a pair: a 3-element metadata tensor followed by the image tensor.
metadata, image = train_record
target, age, implant = metadata

# Last expression of the cell: display the raw record.
train_record

(tensor([ 1., 79.,  0.], dtype=torch.float64),
 tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.5882, 0.5529, 0.5294,  ..., 0.0000, 0.0000, 0.0000],
         [0.0157, 0.0118, 0.0118,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
        dtype=torch.float64))

In [38]:
import torch
import torch.nn as nn
from torchvision.models._api import WeightsEnum


# Weights checksum patching: https://github.com/pytorch/vision/issues/7744
def get_state_dict(self, *args, **kwargs):
    """Fetch the state dict for this weights entry without hash checking.

    Written to be assigned over ``WeightsEnum.get_state_dict``: ``self`` is the
    weights entry and only its ``url`` attribute is read.

    Args:
        *args: Positional arguments forwarded to ``load_state_dict_from_url``.
        **kwargs: Keyword arguments forwarded to ``load_state_dict_from_url``.
            A ``check_hash`` entry, if present, is discarded first.

    Returns:
        Whatever ``load_state_dict_from_url`` returns for ``self.url``.
    """
    # Use a default so callers that never pass ``check_hash`` do not hit a
    # KeyError; the flag is dropped either way to skip checksum validation.
    kwargs.pop("check_hash", None)
    return load_state_dict_from_url(self.url, *args, **kwargs)


# Monkey-patch torchvision's WeightsEnum so every weights entry downloads via
# the get_state_dict defined above, which drops the check_hash keyword.
WeightsEnum.get_state_dict = get_state_dict


class EfficientNetGray(nn.Module):
    """EfficientNet-B2 wrapper that accepts single-channel input.

    The wrapped torchvision model is built with its default pretrained weights
    and switched to eval mode at construction time.
    """

    def __init__(self):
        super().__init__()
        # ``eval()`` returns the module itself, so the call can be chained.
        self.base_model = models.efficientnet_b2(
            weights=EfficientNet_B2_Weights.DEFAULT
        ).eval()

    def forward(self, single_channel_input):
        """Replicate the input three times along dim 1 and run the base model.

        Args:
            single_channel_input: Tensor whose dim 1 is the channel axis
                (assumed to have size 1 there — TODO confirm at call sites).

        Returns:
            The output of ``self.base_model`` on the float32, channel-replicated
            input. NOTE(review): the full model including its classifier head
            is used, so this is presumably class logits rather than pooled
            features — confirm this is intended.
        """
        as_float = single_channel_input.float()
        rgb_input = torch.cat((as_float, as_float, as_float), dim=1)
        return self.base_model(rgb_input)


# Shared instance used by the inference cells below; constructing it builds
# EfficientNet-B2 with its default pretrained weights.
model = EfficientNetGray()

In [41]:
# Smoke test on the single image loaded earlier: add two leading singleton
# dimensions so the 2-D image becomes a 4-D input for the model.
single_image_batch = image.float().unsqueeze(0).unsqueeze(0)

# Inference only, so gradient tracking is disabled.
with torch.no_grad():
    features = model(single_image_batch)

features

tensor([[-1.0724e+00, -1.0207e-02,  7.2670e-02,  9.9664e-01, -1.8497e-01,
         -3.0801e-01,  4.0875e-01, -1.9868e+00, -1.5537e+00, -1.2971e+00,
         -6.4864e-01, -8.1872e-01, -5.8035e-01, -1.8757e-01, -7.9672e-01,
         -1.1037e+00, -1.0372e+00, -1.0049e+00, -1.2584e+00, -5.7325e-01,
         -8.7652e-01, -2.2784e+00, -9.3959e-01, -2.1066e+00, -1.3285e-01,
         -7.0220e-01, -2.3458e-01, -1.1008e+00, -4.1303e-01, -6.2009e-01,
         -3.0854e-01,  4.2522e-01, -1.6176e-01, -3.7014e-01, -7.1760e-01,
         -2.1375e-01, -2.6798e-01, -1.0201e+00,  1.7759e-01, -2.5644e+00,
         -1.6201e-01, -7.3135e-01, -1.6144e+00, -2.2924e+00, -1.9307e-01,
          2.8706e-03, -1.9070e-01, -1.2422e+00, -1.5156e+00, -1.7021e+00,
          1.9035e-01, -1.0442e+00,  5.7231e-01, -6.5326e-01, -1.4210e+00,
         -1.4023e+00, -9.1073e-01, -1.2671e+00, -1.3012e-01, -1.7947e+00,
          3.6054e-01, -1.4435e+00, -8.7726e-01, -1.0232e+00,  1.5743e-01,
         -1.4760e+00, -1.6800e+00, -4.

In [43]:
import tqdm

# Run every image of the full dataset through the model and persist the outputs.
# NOTE(review): torch.load unpickles the file, so only point it at trusted data.
data = torch.load("preprocessed_images_with_metadata_and_target.pt.nosync")

# featured_data collects (model output, age, implant) per record;
# target_data collects the matching target in the same order.
featured_data, target_data = [], []

for record in tqdm.tqdm(data):
    (target, age, implant), image = record
    target_data.append(target)

    # Two leading singleton dims turn the 2-D image into a 4-D model input.
    batch = image.float()[None, None]
    with torch.no_grad():
        features = model(batch)

    # Strip the leading singleton dimension before storing.
    featured_data.append((features[0], age, implant))

torch.save(featured_data, "image_features_from_efficient_net_with_meta.pt.nosync")
torch.save(target_data, "targets_for_image_features_from_efficient_net_with_meta.pt.nosync")


100%|██████████| 2304/2304 [07:17<00:00,  5.27it/s]
