# How can I use a pre-trained neural network with grayscale images?
This notebook validates the following StackOverflow answer: https://stackoverflow.com/questions/51995977/how-can-i-use-a-pre-trained-neural-network-with-grayscale-images#answer-54777347

The idea is to adapt the first convolution layer to single-channel (grayscale) input by summing its pretrained weights over the three color channels.

In [1]:
# Start from a clean slate: remove the downloaded archive and extracted data (data*),
# the on-disk result cache directory (cache/) and the debug CSV dumps (debug*),
# then list what is left in the working directory.
!rm -Rf data* cache debug*
!ls -l

total 4
drwxr-xr-x 1 root root 4096 Sep 14 13:44 sample_data


In [2]:
# https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000
# NOTE(review): the URL below is a time-limited signed GCS link (X-Goog-Date=20220916T190126Z,
# X-Goog-Expires=259200 seconds, i.e. valid for 3 days), so it has almost certainly expired and
# this cell will not reproduce. Re-generate the link from the Kaggle page above, or download with
# the Kaggle CLI (requires Kaggle API credentials), e.g.:
#   kaggle datasets download -d ifigotin/imagenetmini-1000 -p data --unzip
# Downloads the archive as data.zip, then extracts it quietly into ./data.
!wget -O data.zip 'https://storage.googleapis.com/kaggle-data-sets/547506/998277/bundle/archive.zip?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gcp-kaggle-com%40kaggle-161607.iam.gserviceaccount.com%2F20220916%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20220916T190126Z&X-Goog-Expires=259200&X-Goog-SignedHeaders=host&X-Goog-Signature=250534aa0e9a8fb7744fd44ac4bcc01081c5feb875ae4765228e82049a7bd604a6d5028d9a30dd0a4e1bf1b0b652eb04e7c6b1cec61926fbe6894fae4dfb088043d807c650f46dd07d71fbc3c91b1d7520b58d9b5593260864fb78219208aa9da6b83cf57a89599d4624722fd0e26c58eaea20e71f80f6f1a29d2b249f78f1a66620c6baff0189b442532abb9e5b6c4db8bb78421b441e2c800f7765057e545eadc63a04d407afb6c6dfa54efb4080350d36e268c479d3e577d57d8b599a9f41e498293e2d80f769456e1e81cc73a7045f55392aef9b5cb4faffb8b7187b0a308ad65bd663380c067cc89a69c0e8f9d15b9abedc9d0c9eea78f275d47da822fc'
!unzip -q -d data data.zip

--2022-09-16 20:38:24--  https://storage.googleapis.com/kaggle-data-sets/547506/998277/bundle/archive.zip?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gcp-kaggle-com%40kaggle-161607.iam.gserviceaccount.com%2F20220916%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20220916T190126Z&X-Goog-Expires=259200&X-Goog-SignedHeaders=host&X-Goog-Signature=250534aa0e9a8fb7744fd44ac4bcc01081c5feb875ae4765228e82049a7bd604a6d5028d9a30dd0a4e1bf1b0b652eb04e7c6b1cec61926fbe6894fae4dfb088043d807c650f46dd07d71fbc3c91b1d7520b58d9b5593260864fb78219208aa9da6b83cf57a89599d4624722fd0e26c58eaea20e71f80f6f1a29d2b249f78f1a66620c6baff0189b442532abb9e5b6c4db8bb78421b441e2c800f7765057e545eadc63a04d407afb6c6dfa54efb4080350d36e268c479d3e577d57d8b599a9f41e498293e2d80f769456e1e81cc73a7045f55392aef9b5cb4faffb8b7187b0a308ad65bd663380c067cc89a69c0e8f9d15b9abedc9d0c9eea78f275d47da822fc
Resolving storage.googleapis.com (storage.googleapis.com)... 173.194.202.128, 74.125.20.128, 108.177.98.128, ...
Connecting to storage.g

In [3]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from tqdm import tqdm # progress bar
import matplotlib.pyplot as plt # drawing

import os
# Sanity-check the extracted archive: echo the first nine file paths found
# under ./data, then report the total number of files.
fcnt = 0
for root, _subdirs, names in os.walk('data'):
  for name in names:
    fcnt += 1
    if fcnt <= 9:
      print(os.path.join(root, name))
print(fcnt)

data/imagenet-mini/train/n02165105/n02165105_130.JPEG
data/imagenet-mini/train/n02165105/n02165105_1487.JPEG
data/imagenet-mini/train/n02165105/n02165105_9230.JPEG
data/imagenet-mini/train/n02165105/n02165105_2027.JPEG
data/imagenet-mini/train/n02165105/n02165105_413.JPEG
data/imagenet-mini/train/n02165105/n02165105_1003.JPEG
data/imagenet-mini/train/n02165105/n02165105_6979.JPEG
data/imagenet-mini/train/n02165105/n02165105_11806.JPEG
data/imagenet-mini/train/n02165105/n02165105_8750.JPEG
38668


In [4]:
# Training split of imagenet-mini; each sub-directory holds the images of one class.
data_dir = '/'.join(['data', 'imagenet-mini', 'train'])

In [5]:
import torch
import torchvision

# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda:0")
else:
  device = torch.device("cpu")
device

device(type='cuda', index=0)

In [17]:
import PIL.Image  # `import PIL` alone does not guarantee the `PIL.Image` submodule is loaded

class ImageNetDataset(torch.utils.data.Dataset):
  """Image-folder dataset laid out as ``data_dir/<class_dir>/<image file>``.

  Each item is ``(image, class_dir_name, file_path)``. The image is opened with
  PIL, converted to RGB and passed through ``transform`` when one is given.
  Folders and files are enumerated in sorted order so the sample order (and
  therefore batches and any per-index output) is reproducible across runs.

  Args:
    data_dir: root directory; every entry in it is treated as a class folder.
    transform: optional callable applied to the PIL RGB image.
    use_cache: keep each transformed image in memory after its first access.
      Off by default: a 3x224x224 float32 tensor is ~600 KB, so caching the
      whole split needs tens of GB of RAM, and with a multi-worker
      ``DataLoader`` each worker process holds its own copy of the dataset,
      so entries cached inside a worker are never seen by the main process.
  """

  def __init__(self, data_dir, transform=None, use_cache=False):
    self.transform = transform
    self.use_cache = use_cache
    self.data = []
    # os.listdir order is arbitrary; sort for a deterministic sample order.
    for cl in tqdm(sorted(os.listdir(data_dir)), desc='data'):
      for f in sorted(os.listdir(f'{data_dir}/{cl}')):
        self.data.append((f'{data_dir}/{cl}/{f}', cl))
    # Kept (always allocated) for backward compatibility; only filled when use_cache=True.
    self.cache = [None] * len(self.data)

  def __len__(self):
    return len(self.data)

  def _entry(self, index):
    """Return the ``(path, class)`` pair at ``index``; negative/out-of-range indices are rejected."""
    assert index >= 0
    assert index < len(self.data)
    return self.data[index]

  def __getitem__(self, index):
    """Return ``(image, class_dir_name, file_path)`` for sample ``index``."""
    f, cl = self._entry(index)
    if self.cache[index] is not None:
      return self.cache[index], cl, f
    # convert() returns a new image, so the source file can be closed right away.
    with PIL.Image.open(f) as src:
      img = src.convert('RGB')
    if self.transform:
      img = self.transform(img)
    if self.use_cache:
      self.cache[index] = img
    return img, cl, f

  def classes(self):
    """Return the distinct class-folder names, sorted."""
    return sorted({cl for _, cl in self.data})

  def get_class(self, index):
    """Return the class-folder name of sample ``index`` without loading the image."""
    return self._entry(index)[1]

# Untransformed view of the training split, used here only to report
# (number of images, number of class folders).
dataset = ImageNetDataset(data_dir)
(len(dataset), len(dataset.classes()))

data: 100%|██████████| 1000/1000 [00:00<00:00, 24344.16it/s]


(34745, 1000)

In [7]:
def predict(model, loader, out_cnt=1000, target_device=None):
  """Run ``model`` in eval mode over ``loader`` and collect its outputs.

  Args:
    model: torch module mapping a batch ``x`` to ``out_cnt`` scores per sample.
    loader: iterable of ``(x, _, _)`` triples (e.g. a DataLoader over ImageNetDataset).
    out_cnt: number of output scores per sample (1000 for ImageNet classifiers).
    target_device: device to run on; ``None`` uses the notebook-level ``device``.

  Returns:
    ``(predictions, weights)`` as numpy arrays of shape ``(N, 1)`` (argmax
    class id per sample) and ``(N, out_cnt)`` (raw model outputs), in loader
    order. For an empty loader both arrays have zero rows.
  """
  dev = device if target_device is None else target_device
  model.to(dev)
  model.eval()
  predictions = []
  weights = []
  # Pure inference: don't record the autograd graph (saves memory and time).
  with torch.no_grad():
    for x, _, _ in tqdm(loader, desc='eval'):
      ws = model(x.to(dev)).cpu().reshape(-1, out_cnt)
      weights.append(ws)
      predictions.append(torch.argmax(ws, dim=1).reshape(-1, 1))
  if not weights:
    # np.vstack([]) would raise ValueError; return empty arrays of the right width.
    return np.empty((0, 1), dtype=np.int64), np.empty((0, out_cnt), dtype=np.float32)
  return torch.cat(predictions).numpy(), torch.cat(weights).numpy()

In [8]:
import sklearn.metrics
from collections import Counter

debug_id = 1  # sequence number for the debug<N>.csv dumps written by metrics()

def metrics(classes, dataset):
  """Score predicted class ids against the dataset's class folders.

  Folder names are not mapped to model class ids anywhere, so the reference id
  of each folder is estimated as the most frequent prediction among that
  folder's samples (majority vote). Each sample is then counted as correct
  when its prediction equals its folder's majority id. Note that this proxy
  ground truth is derived from the predictions themselves.

  Side effect: writes ``debug<debug_id>.csv`` with columns ``cl`` (prediction),
  ``gt`` (majority id) and ``orig`` (folder name), then increments ``debug_id``.

  Args:
    classes: numpy array of N predicted ids (reshapeable to ``(N,)``), in the
      same order as ``dataset``.
    dataset: object providing ``__len__`` and ``get_class(i)``.

  Returns:
    Micro-averaged precision of the predictions vs. the majority ids.
  """
  global debug_id
  assert len(classes) == len(dataset)
  predicted = classes.reshape(-1)
  majority = np.copy(predicted)

  labels = [dataset.get_class(i) for i in range(len(dataset))]
  members = {}
  for idx, label in enumerate(labels):
    members.setdefault(label, []).append(idx)
  for idxs in members.values():
    majority[idxs] = Counter(predicted[idxs]).most_common(1)[0][0]

  table = np.hstack([
    predicted.reshape(-1, 1),
    majority.reshape(-1, 1),
    np.array(labels).reshape(-1, 1),
  ])
  pd.DataFrame(table, columns = ['cl', 'gt', 'orig']).to_csv(f'debug{debug_id}.csv')
  debug_id += 1

  return sklearn.metrics.precision_score(majority, predicted, average='micro')

In [9]:
import psutil

def validate(model, dataset, batch_size=16, num_workers=None):
  """Evaluate ``model`` on ``dataset`` and return the score computed by ``metrics``.

  Args:
    model: classifier passed to ``predict``.
    dataset: ImageNetDataset-like object whose items are ``(x, class, path)``.
    batch_size: DataLoader batch size (default 16).
    num_workers: loader worker processes. ``None`` means one per logical CPU
      as reported by ``psutil.cpu_count()``, falling back to 0 (load in the
      main process) when psutil cannot determine the count.
  """
  if num_workers is None:
    # psutil.cpu_count() returns None when the count is undetermined.
    num_workers = psutil.cpu_count() or 0
  # Keep shuffle off: metrics() pairs the i-th prediction with dataset.get_class(i).
  data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
  predicted, _ = predict(model, data_loader)
  return metrics(predicted, dataset)

In [10]:
import hashlib
import tempfile

class Cache:
  """Disk-backed memoizer for an image -> tensor function.

  Results are stored as ``<root>/<realm>/<sha256>.npy``. The key hashes the
  input's type name, ``mode``/``size``/``shape`` metadata (when present) and
  its raw bytes from ``img.tobytes()``. Use a distinct ``realm`` per function.

  Args:
    realm: sub-directory name separating the caches of different functions.
    func: callable mapping an image to a CPU ``torch.Tensor``.
    root: base cache directory (default ``'cache'``, relative to the CWD).
  """

  def __init__(self, realm, func, root='cache'):
    self.realm = realm
    self.func = func
    self.root = root

  def __call__(self, img):
    """Return ``func(img)``, loading it from disk when it was computed before."""
    data, h = self.cache_get(img)
    if data is not None:
      return data
    data = self.func(img)
    self.cache_store(h, data)
    return data

  def _dir(self):
    """Create (if needed) and return this realm's cache directory."""
    cd = os.path.join(self.root, self.realm)
    os.makedirs(cd, 0o700, exist_ok=True)
    return cd

  def _path(self, h):
    """Cache file for key ``h``, including the ``.npy`` suffix that np.save appends."""
    return os.path.join(self._dir(), f'{h}.npy')

  def cache_get(self, img):
    """Return ``(tensor or None, key)`` for ``img``; the key is needed for cache_store."""
    digest = hashlib.sha256()
    # Raw bytes alone are ambiguous (e.g. 2x8 vs 4x4 pixels with identical bytes),
    # so also hash the metadata that determines how the bytes are laid out.
    meta = (
      type(img).__name__,
      getattr(img, 'mode', None),
      getattr(img, 'size', None),
      getattr(img, 'shape', None),
    )
    digest.update(repr(meta).encode())
    digest.update(img.tobytes())
    h = digest.hexdigest()
    cf = self._path(h)
    if os.path.exists(cf):
      return torch.from_numpy(np.load(cf)), h
    return None, h

  def cache_store(self, h, data):
    """Save CPU tensor ``data`` under key ``h``.

    Written to a temporary file and renamed into place, so a concurrent
    reader never sees a partially written entry.
    """
    fd, tmp = tempfile.mkstemp(dir=self._dir(), suffix='.tmp')
    try:
      with os.fdopen(fd, 'wb') as fh:
        np.save(fh, data.numpy())
      os.replace(tmp, self._path(h))
    except BaseException:
      if os.path.exists(tmp):
        os.remove(tmp)
      raise


In [11]:
# Preprocessing for the pretrained ResNets: resize straight to 224x224 (no
# aspect-preserving resize + center crop), convert to a [0, 1] float tensor,
# then normalize each RGB channel with the ImageNet mean/std.
rn_transform = torchvision.transforms.Compose(
  [
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
      mean=[0.485, 0.456, 0.406],
      std=[0.229, 0.224, 0.225],
    ),
  ]
)

def resnet_transform(img):
  """Apply ``rn_transform`` to a PIL RGB image, giving a 3x224x224 tensor."""
  return rn_transform(img)

In [12]:
#!echo data/imagenet-mini/train/n01* | wc
#!mv data data_orig
#!mkdir -pv data/imagenet-mini/train
#!cp -a data_orig/imagenet-mini/train/n01* data/imagenet-mini/train/
#!du -sh data*

In [19]:
# Baseline: pretrained ResNet-18 on RGB input with ImageNet normalization.
validate(
  model=torchvision.models.resnet18(weights='DEFAULT'),
  dataset=ImageNetDataset(data_dir, transform=resnet_transform),
)

data: 100%|██████████| 1000/1000 [00:00<00:00, 1924.21it/s]
eval:  43%|████▎     | 941/2172 [01:47<02:20,  8.79it/s]


RuntimeError: ignored

In [None]:
# Baseline: pretrained ResNet-50 on RGB input with ImageNet normalization.
validate(
  model=torchvision.models.resnet50(weights='DEFAULT'),
  dataset=ImageNetDataset(data_dir, transform=resnet_transform),
)

In [15]:
# Grayscale variant 1: resize, collapse to one luminance channel, and convert
# to a [0, 1] float tensor. No mean/std normalization is applied.
gr_transform = torchvision.transforms.Compose(
  [
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.functional.to_grayscale,
    torchvision.transforms.ToTensor(),
  ]
)

def grayscale_transform(img):
  """Turn a PIL RGB image into a 1x224x224 grayscale tensor via ``gr_transform``."""
  return gr_transform(img)

def grayscale_fix_model(model):
  """Adapt ``model`` (first convolution at ``model.conv1``) to 1-channel input.

  The kernels of ``conv1`` are summed over the input-channel axis. The new
  single-channel conv applied to a gray image ``g`` therefore produces the
  same output as the original conv applied to ``g`` replicated into every
  input channel. Kernel size, stride, padding, dilation, padding mode, bias,
  device and dtype are copied from the existing layer rather than hard-coded,
  so this works for first convolutions other than ResNet's 7x7/stride-2 one.

  The model is modified in place and also returned for convenience.

  Raises:
    ValueError: if ``conv1`` is a grouped convolution (channels can't be summed).
  """
  old = model.conv1
  if old.groups != 1:
    raise ValueError(f'conv1 must have groups=1 to be collapsed, got groups={old.groups}')
  new = torch.nn.Conv2d(
    1, old.out_channels,
    kernel_size=old.kernel_size,
    stride=old.stride,
    padding=old.padding,
    dilation=old.dilation,
    bias=old.bias is not None,
    padding_mode=old.padding_mode,
  ).to(device=old.weight.device, dtype=old.weight.dtype)
  with torch.no_grad():
    # (out, in, kh, kw) -> (out, 1, kh, kw)
    new.weight.copy_(old.weight.sum(dim=1, keepdim=True))
    if old.bias is not None:
      new.bias.copy_(old.bias)
  model.conv1 = new
  return model

In [16]:
# Grayscale experiment: ResNet-18 with conv1 collapsed to one input channel,
# fed through the most recently defined `grayscale_transform`.
validate(
  model=grayscale_fix_model(torchvision.models.resnet18(weights='DEFAULT')),
  dataset=ImageNetDataset(data_dir, transform=grayscale_transform),
)

data: 100%|██████████| 1000/1000 [00:00<00:00, 7595.25it/s]
eval:  44%|████▍     | 951/2172 [01:36<02:03,  9.89it/s]


KeyboardInterrupt: ignored

In [None]:
# Grayscale experiment: ResNet-50 with conv1 collapsed to one input channel,
# fed through the most recently defined `grayscale_transform`.
validate(
  model=grayscale_fix_model(torchvision.models.resnet50(weights='DEFAULT')),
  dataset=ImageNetDataset(data_dir, transform=grayscale_transform),
)

In [None]:
# Grayscale variant 2: apply the RGB ImageNet mean/std normalization first,
# then collapse to one channel. The grayscale step runs on the normalized
# tensor itself: a ToPILImage round trip would multiply the normalized values
# (partly negative or > 1) by 255 and cast them to uint8, wrapping them and
# corrupting the image.
gr_transform = torchvision.transforms.Compose([
  torchvision.transforms.Resize((224, 224)),
  torchvision.transforms.ToTensor(),
  torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  # Tensor input: weighted sum of the R, G, B channels -> shape (1, H, W).
  torchvision.transforms.Grayscale(num_output_channels=1),
])

def grayscale_transform(img):
  """Map a PIL RGB image to a normalized 1x224x224 grayscale tensor."""
  img = gr_transform(img)
  return img

In [None]:
# Same ResNet-18 grayscale experiment, now with the most recently defined
# `grayscale_transform` (normalize, then grayscale).
validate(
  model=grayscale_fix_model(torchvision.models.resnet18(weights='DEFAULT')),
  dataset=ImageNetDataset(data_dir, transform=grayscale_transform),
)

In [None]:
# Same ResNet-50 grayscale experiment, now with the most recently defined
# `grayscale_transform` (normalize, then grayscale).
validate(
  model=grayscale_fix_model(torchvision.models.resnet50(weights='DEFAULT')),
  dataset=ImageNetDataset(data_dir, transform=grayscale_transform),
)

In [None]:
# Grayscale variant 3: stretch each image's contrast (autocontrast) before
# converting to one channel; output is a [0, 1] tensor with no mean/std normalization.
gr_transform = torchvision.transforms.Compose(
  [
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.functional.autocontrast,
    torchvision.transforms.functional.to_grayscale,
    torchvision.transforms.ToTensor(),
  ]
)

def grayscale_transform(img):
  """Turn a PIL RGB image into a contrast-stretched 1x224x224 grayscale tensor."""
  return gr_transform(img)

In [None]:
# ResNet-18 grayscale experiment with the most recently defined
# `grayscale_transform` (autocontrast, then grayscale).
validate(
  model=grayscale_fix_model(torchvision.models.resnet18(weights='DEFAULT')),
  dataset=ImageNetDataset(data_dir, transform=grayscale_transform),
)

In [None]:
# ResNet-50 grayscale experiment with the most recently defined
# `grayscale_transform` (autocontrast, then grayscale).
validate(
  model=grayscale_fix_model(torchvision.models.resnet50(weights='DEFAULT')),
  dataset=ImageNetDataset(data_dir, transform=grayscale_transform),
)