# How can I use a pre-trained neural network with grayscale images?
This work is aimed to validate this answer at StackOverflow: https://stackoverflow.com/questions/51995977/how-can-i-use-a-pre-trained-neural-network-with-grayscale-images#answer-54777347

The idea is to fix the first convolution layer by summing up the weights over the color channels.

In [1]:
!rm -Rf data*
!ls -l

total 4
drwxr-xr-x 1 root root 4096 Sep 14 13:44 sample_data


In [2]:
# https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000
!wget -O data.zip 'https://storage.googleapis.com/kaggle-data-sets/547506/998277/bundle/archive.zip?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gcp-kaggle-com%40kaggle-161607.iam.gserviceaccount.com%2F20220916%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20220916T190126Z&X-Goog-Expires=259200&X-Goog-SignedHeaders=host&X-Goog-Signature=250534aa0e9a8fb7744fd44ac4bcc01081c5feb875ae4765228e82049a7bd604a6d5028d9a30dd0a4e1bf1b0b652eb04e7c6b1cec61926fbe6894fae4dfb088043d807c650f46dd07d71fbc3c91b1d7520b58d9b5593260864fb78219208aa9da6b83cf57a89599d4624722fd0e26c58eaea20e71f80f6f1a29d2b249f78f1a66620c6baff0189b442532abb9e5b6c4db8bb78421b441e2c800f7765057e545eadc63a04d407afb6c6dfa54efb4080350d36e268c479d3e577d57d8b599a9f41e498293e2d80f769456e1e81cc73a7045f55392aef9b5cb4faffb8b7187b0a308ad65bd663380c067cc89a69c0e8f9d15b9abedc9d0c9eea78f275d47da822fc'
!unzip -q -d data data.zip
!du -sh data*

--2022-09-17 09:53:08--  https://storage.googleapis.com/kaggle-data-sets/547506/998277/bundle/archive.zip?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gcp-kaggle-com%40kaggle-161607.iam.gserviceaccount.com%2F20220916%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20220916T190126Z&X-Goog-Expires=259200&X-Goog-SignedHeaders=host&X-Goog-Signature=250534aa0e9a8fb7744fd44ac4bcc01081c5feb875ae4765228e82049a7bd604a6d5028d9a30dd0a4e1bf1b0b652eb04e7c6b1cec61926fbe6894fae4dfb088043d807c650f46dd07d71fbc3c91b1d7520b58d9b5593260864fb78219208aa9da6b83cf57a89599d4624722fd0e26c58eaea20e71f80f6f1a29d2b249f78f1a66620c6baff0189b442532abb9e5b6c4db8bb78421b441e2c800f7765057e545eadc63a04d407afb6c6dfa54efb4080350d36e268c479d3e577d57d8b599a9f41e498293e2d80f769456e1e81cc73a7045f55392aef9b5cb4faffb8b7187b0a308ad65bd663380c067cc89a69c0e8f9d15b9abedc9d0c9eea78f275d47da822fc
Resolving storage.googleapis.com (storage.googleapis.com)... 74.125.195.128, 172.253.117.128, 142.250.99.128, ...
Connecting to storage.

In [3]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from tqdm import tqdm # progress bar
import matplotlib.pyplot as plt # drawing

import os
fcnt = 0
for dirname, _, filenames in os.walk('data'):
  for filename in filenames:
    fcnt += 1
    if fcnt < 10:
      print(os.path.join(dirname, filename))
print(fcnt)

data/imagenet-mini/train/n02165105/n02165105_130.JPEG
data/imagenet-mini/train/n02165105/n02165105_1487.JPEG
data/imagenet-mini/train/n02165105/n02165105_9230.JPEG
data/imagenet-mini/train/n02165105/n02165105_2027.JPEG
data/imagenet-mini/train/n02165105/n02165105_413.JPEG
data/imagenet-mini/train/n02165105/n02165105_1003.JPEG
data/imagenet-mini/train/n02165105/n02165105_6979.JPEG
data/imagenet-mini/train/n02165105/n02165105_11806.JPEG
data/imagenet-mini/train/n02165105/n02165105_8750.JPEG
38668


In [4]:
data_dir = 'data/imagenet-mini/train'

In [5]:
import torch
import torchvision

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device, torch.cuda.get_device_properties(device) if torch.cuda.is_available() else 'CPU'

(device(type='cuda', index=0),
 _CudaDeviceProperties(name='Tesla T4', major=7, minor=5, total_memory=15109MB, multi_processor_count=40))

In [6]:
import PIL

class ImageNetDataset(torch.utils.data.Dataset):
  def __init__(self, data_dir, transform=None):
    self.transform = transform
    self.data = []
    for cl in tqdm(os.listdir(data_dir), desc='data'):
      for f in os.listdir(f'{data_dir}/{cl}'):
        self.data += [(f'{data_dir}/{cl}/{f}', cl)]
    
  def __len__(self):
    return len(self.data)
  
  def __getitem__(self, index):
    assert index >= 0
    assert index < len(self.data)
    f, cl = self.data[index]
    img = PIL.Image.open(f).convert('RGB')
    if self.transform:
      img = self.transform(img)
    return img, cl, f
  
  def classes(self):
    return list(set(cl for f, cl in self.data))
  
  def get_class(self, index):
    assert index >= 0
    assert index < len(self.data)
    f, cl = self.data[index]
    return cl

dataset = ImageNetDataset(data_dir)
len(dataset), len(dataset.classes())

data: 100%|██████████| 1000/1000 [00:00<00:00, 24987.51it/s]


(34745, 1000)

In [7]:
def predict(model, loader, out_cnt=1000):
  model.to(device)
  model.eval()
  predictions = []
  weights = []
  for x, _, _ in tqdm(loader, desc='eval'):
    ws = model(x.to(device))
    weights += [ws.cpu().detach().reshape(-1, out_cnt)]
    predictions += [torch.argmax(ws, dim=1).cpu().detach().reshape(-1, 1)]
  return np.vstack(predictions), np.vstack(weights)

In [8]:
import sklearn.metrics
from collections import Counter, defaultdict

def metrics(classes, dataset):
  assert len(classes) == len(dataset)
  classes = classes.reshape(-1)
  gt_classes = np.copy(classes)
  cl_id = defaultdict(list)
  for i in range(len(dataset)):
    # find sample_ids for each class
    cl = dataset.get_class(i)
    cl_id[cl] += [i]
  for cl in cl_id:
    # guess the class - pick the most common class_id
    cmn_cl = Counter(classes[cl_id[cl]]).most_common(1)[0][0]
    gt_classes[cl_id[cl]] = cmn_cl
  return sklearn.metrics.precision_score(gt_classes, classes, average='micro')

In [9]:
import psutil

def validate(model, dataset):
  data_loader = torch.utils.data.DataLoader(dataset, batch_size=16, num_workers=psutil.cpu_count())
  cl, w = predict(model, data_loader)
  return metrics(cl, dataset)

In [10]:
rn_transform = torchvision.transforms.Compose([
  torchvision.transforms.Resize((224, 224)),
  torchvision.transforms.ToTensor(),
  torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def resnet_transform(img):
  img = rn_transform(img)
  return img

In [11]:
validate(
  torchvision.models.resnet18(weights='DEFAULT'),
  ImageNetDataset(data_dir, resnet_transform)
)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

data: 100%|██████████| 1000/1000 [00:00<00:00, 21233.86it/s]
eval: 100%|██████████| 2172/2172 [03:51<00:00,  9.37it/s]


0.7440207224061016

In [12]:
validate(
  torchvision.models.resnet34(weights='DEFAULT'),
  ImageNetDataset(data_dir, resnet_transform)
)

Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth


  0%|          | 0.00/83.3M [00:00<?, ?B/s]

data: 100%|██████████| 1000/1000 [00:00<00:00, 21969.71it/s]
eval: 100%|██████████| 2172/2172 [03:48<00:00,  9.49it/s]


0.8000863433587566

In [13]:
validate(
  torchvision.models.resnet50(weights='DEFAULT'),
  ImageNetDataset(data_dir, resnet_transform)
)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

data: 100%|██████████| 1000/1000 [00:00<00:00, 26437.63it/s]
eval: 100%|██████████| 2172/2172 [03:47<00:00,  9.54it/s]


0.9063174557490287

In [14]:
validate(
  torchvision.models.resnet101(weights='DEFAULT'),
  ImageNetDataset(data_dir, resnet_transform)
)

Downloading: "https://download.pytorch.org/models/resnet101-cd907fc2.pth" to /root/.cache/torch/hub/checkpoints/resnet101-cd907fc2.pth


  0%|          | 0.00/171M [00:00<?, ?B/s]

data: 100%|██████████| 1000/1000 [00:00<00:00, 26661.82it/s]
eval: 100%|██████████| 2172/2172 [04:09<00:00,  8.72it/s]


0.9227802561519644

In [15]:
validate(
  torchvision.models.resnet152(weights='DEFAULT'),
  ImageNetDataset(data_dir, resnet_transform)
)

Downloading: "https://download.pytorch.org/models/resnet152-f82ba261.pth" to /root/.cache/torch/hub/checkpoints/resnet152-f82ba261.pth


  0%|          | 0.00/230M [00:00<?, ?B/s]

data: 100%|██████████| 1000/1000 [00:00<00:00, 27595.56it/s]
eval: 100%|██████████| 2172/2172 [04:40<00:00,  7.75it/s]


0.9347244207799683

In [16]:
gr_transform = torchvision.transforms.Compose([
  torchvision.transforms.Resize((224, 224)),
  torchvision.transforms.functional.to_grayscale,
  torchvision.transforms.ToTensor(),
])

def grayscale_transform(img):
  img = gr_transform(img)
  return img

def grayscale_fix_model(model):
  w = model.conv1.weight.data.sum(axis=1).reshape(64, 1, 7, 7)
  model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  model.conv1.weight.data = w
  return model

In [17]:
validate(
  grayscale_fix_model(torchvision.models.resnet18(weights='DEFAULT')),
  ImageNetDataset(data_dir, grayscale_transform)
)

data: 100%|██████████| 1000/1000 [00:00<00:00, 28783.11it/s]
eval: 100%|██████████| 2172/2172 [03:17<00:00, 11.02it/s]


0.33898402647863

In [18]:
validate(
  grayscale_fix_model(torchvision.models.resnet34(weights='DEFAULT')),
  ImageNetDataset(data_dir, grayscale_transform)
)

data: 100%|██████████| 1000/1000 [00:00<00:00, 26326.12it/s]
eval: 100%|██████████| 2172/2172 [03:17<00:00, 10.98it/s]


0.41977262915527414

In [19]:
validate(
  grayscale_fix_model(torchvision.models.resnet50(weights='DEFAULT')),
  ImageNetDataset(data_dir, grayscale_transform)
)

data: 100%|██████████| 1000/1000 [00:00<00:00, 26124.27it/s]
eval: 100%|██████████| 2172/2172 [03:32<00:00, 10.24it/s]


0.7534033673909916

In [20]:
validate(
  grayscale_fix_model(torchvision.models.resnet101(weights='DEFAULT')),
  ImageNetDataset(data_dir, grayscale_transform)
)

data: 100%|██████████| 1000/1000 [00:00<00:00, 25133.20it/s]
eval: 100%|██████████| 2172/2172 [03:48<00:00,  9.49it/s]


0.7789897827025472

In [21]:
validate(
  grayscale_fix_model(torchvision.models.resnet152(weights='DEFAULT')),
  ImageNetDataset(data_dir, grayscale_transform)
)

data: 100%|██████████| 1000/1000 [00:00<00:00, 25247.58it/s]
eval: 100%|██████████| 2172/2172 [04:25<00:00,  8.19it/s]


0.807454309972658

In [22]:
gr_transform = torchvision.transforms.Compose([
  torchvision.transforms.Resize((224, 224)),
  torchvision.transforms.ToTensor(),
  torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  torchvision.transforms.ToPILImage(),
  torchvision.transforms.functional.to_grayscale,
  torchvision.transforms.ToTensor(),
])

In [23]:
validate(
  grayscale_fix_model(torchvision.models.resnet18(weights='DEFAULT')),
  ImageNetDataset(data_dir, grayscale_transform)
)

data: 100%|██████████| 1000/1000 [00:00<00:00, 25875.59it/s]
eval: 100%|██████████| 2172/2172 [04:00<00:00,  9.03it/s]


0.26182184486976545

In [24]:
validate(
  grayscale_fix_model(torchvision.models.resnet34(weights='DEFAULT')),
  ImageNetDataset(data_dir, grayscale_transform)
)

data: 100%|██████████| 1000/1000 [00:00<00:00, 26301.52it/s]
eval: 100%|██████████| 2172/2172 [04:10<00:00,  8.67it/s]


0.21145488559504966

In [25]:
validate(
  grayscale_fix_model(torchvision.models.resnet50(weights='DEFAULT')),
  ImageNetDataset(data_dir, grayscale_transform)
)

data: 100%|██████████| 1000/1000 [00:00<00:00, 24403.37it/s]
eval: 100%|██████████| 2172/2172 [04:17<00:00,  8.45it/s]


0.19202762987480212

In [26]:
#2
validate(
  grayscale_fix_model(torchvision.models.resnet50(weights='DEFAULT')),
  ImageNetDataset(data_dir, grayscale_transform)
)

data: 100%|██████████| 1000/1000 [00:00<00:00, 23508.43it/s]
eval: 100%|██████████| 2172/2172 [04:17<00:00,  8.42it/s]


0.19202762987480212

In [27]:
validate(
  grayscale_fix_model(torchvision.models.resnet101(weights='DEFAULT')),
  ImageNetDataset(data_dir, grayscale_transform)
)

data: 100%|██████████| 1000/1000 [00:00<00:00, 23227.10it/s]
eval: 100%|██████████| 2172/2172 [04:43<00:00,  7.65it/s]


0.20851921139732335

In [28]:
validate(
  grayscale_fix_model(torchvision.models.resnet152(weights='DEFAULT')),
  ImageNetDataset(data_dir, grayscale_transform)
)

data: 100%|██████████| 1000/1000 [00:00<00:00, 25257.46it/s]
eval: 100%|██████████| 2172/2172 [05:04<00:00,  7.13it/s]


0.21289394157432723

In [29]:
gr_transform = torchvision.transforms.Compose([
  torchvision.transforms.Resize((224, 224)),
  torchvision.transforms.functional.autocontrast,
  torchvision.transforms.functional.to_grayscale,
  torchvision.transforms.ToTensor(),
])

In [30]:
validate(
  grayscale_fix_model(torchvision.models.resnet18(weights='DEFAULT')),
  ImageNetDataset(data_dir, grayscale_transform)
)

data: 100%|██████████| 1000/1000 [00:00<00:00, 26013.77it/s]
eval: 100%|██████████| 2172/2172 [03:23<00:00, 10.66it/s]


0.3505540365520219

In [31]:
validate(
  grayscale_fix_model(torchvision.models.resnet34(weights='DEFAULT')),
  ImageNetDataset(data_dir, grayscale_transform)
)

data: 100%|██████████| 1000/1000 [00:00<00:00, 22351.86it/s]
eval: 100%|██████████| 2172/2172 [03:30<00:00, 10.31it/s]


0.43243632177291697

In [32]:
validate(
  grayscale_fix_model(torchvision.models.resnet50(weights='DEFAULT')),
  ImageNetDataset(data_dir, grayscale_transform)
)

data: 100%|██████████| 1000/1000 [00:00<00:00, 25873.20it/s]
eval: 100%|██████████| 2172/2172 [03:41<00:00,  9.81it/s]


0.7551014534465391

In [33]:
validate(
  grayscale_fix_model(torchvision.models.resnet101(weights='DEFAULT')),
  ImageNetDataset(data_dir, grayscale_transform)
)

data: 100%|██████████| 1000/1000 [00:00<00:00, 26003.12it/s]
eval: 100%|██████████| 2172/2172 [04:01<00:00,  8.99it/s]


0.7812634911498058

In [34]:
validate(
  grayscale_fix_model(torchvision.models.resnet152(weights='DEFAULT')),
  ImageNetDataset(data_dir, grayscale_transform)
)

data: 100%|██████████| 1000/1000 [00:00<00:00, 25450.87it/s]
eval: 100%|██████████| 2172/2172 [04:31<00:00,  7.99it/s]


0.8091523960282055