# How can I use a pre-trained neural network with grayscale images?
This notebook aims to validate the following StackOverflow answer: https://stackoverflow.com/questions/51995977/how-can-i-use-a-pre-trained-neural-network-with-grayscale-images#answer-54777347

The idea is to adapt the first convolution layer to single-channel (grayscale) input by summing its pretrained weights over the color-channel axis.

In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from tqdm import tqdm # progress bar
import matplotlib.pyplot as plt # drawing

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Count every file below the input directory, echoing only the first nine paths.
fcnt = 0
for dirname, _, filenames in os.walk('/kaggle/input'):
    # Continue the running count across directories; an empty directory leaves it untouched.
    for fcnt, filename in enumerate(filenames, start=fcnt + 1):
        if fcnt < 10:
            print(os.path.join(dirname, filename))
print(fcnt)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/imagenetmini-1000/imagenet-mini/val/n01531178/ILSVRC2012_val_00029581.JPEG
/kaggle/input/imagenetmini-1000/imagenet-mini/val/n01531178/ILSVRC2012_val_00048710.JPEG
/kaggle/input/imagenetmini-1000/imagenet-mini/val/n01531178/ILSVRC2012_val_00001274.JPEG
/kaggle/input/imagenetmini-1000/imagenet-mini/val/n01531178/ILSVRC2012_val_00025151.JPEG
/kaggle/input/imagenetmini-1000/imagenet-mini/val/n02412080/ILSVRC2012_val_00017145.JPEG
/kaggle/input/imagenetmini-1000/imagenet-mini/val/n02412080/ILSVRC2012_val_00019446.JPEG
/kaggle/input/imagenetmini-1000/imagenet-mini/val/n02098413/ILSVRC2012_val_00049190.JPEG
/kaggle/input/imagenetmini-1000/imagenet-mini/val/n02098413/ILSVRC2012_val_00007497.JPEG
/kaggle/input/imagenetmini-1000/imagenet-mini/val/n02098413/ILSVRC2012_val_00033278.JPEG
38676


In [2]:
# Root directory of the images used for every evaluation in this notebook.
# ImageNetDataset treats each immediate subdirectory as one class.
data_dir = '/kaggle/input/imagenetmini-1000/imagenet-mini/train'

In [3]:
import torch
import torchvision

# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Display the chosen device together with its properties (or the literal 'CPU').
(device, 'CPU' if not torch.cuda.is_available() else torch.cuda.get_device_properties(device))

(device(type='cpu'), 'CPU')

In [4]:
import PIL

class ImageNetDataset(torch.utils.data.Dataset):
    """Map-style dataset over a directory laid out as ``data_dir/<class>/<image>``.

    Each item is a ``(image, class_name, file_path)`` triple, where
    ``class_name`` is the name of the sub-directory the image was found in.
    """

    def __init__(self, data_dir, transform=None):
        """Index every file found one level below ``data_dir``.

        Args:
            data_dir: directory whose immediate sub-directories are the classes.
            transform: optional callable applied to the RGB ``PIL`` image
                before it is returned by ``__getitem__``.
        """
        self.transform = transform
        self.data = []
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # sample order (and therefore every index) reproducible between runs.
        for cl in tqdm(sorted(os.listdir(data_dir)), desc='data'):
            class_dir = f'{data_dir}/{cl}'
            if not os.path.isdir(class_dir):
                # Stray files next to the class folders are not classes.
                continue
            for f in sorted(os.listdir(class_dir)):
                self.data.append((f'{class_dir}/{f}', cl))

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.data)

    def _check_index(self, index):
        """Raise ``IndexError`` unless ``0 <= index < len(self)``."""
        # IndexError (rather than an assert) lets ``for item in dataset``
        # terminate normally and keeps the check active under ``python -O``.
        if not 0 <= index < len(self.data):
            raise IndexError(f'index {index} is out of range for dataset of size {len(self.data)}')

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, class_name, file_path)``.

        Raises:
            IndexError: if ``index`` is negative or not smaller than ``len(self)``.
        """
        self._check_index(index)
        f, cl = self.data[index]
        # convert() yields an independent image, so the file can be closed here.
        with PIL.Image.open(f) as raw:
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, cl, f

    def classes(self):
        """Return the distinct class names, sorted for a stable order."""
        return sorted({cl for _, cl in self.data})

    def get_class(self, index):
        """Return the class name of sample ``index`` without loading the image.

        Raises:
            IndexError: if ``index`` is negative or not smaller than ``len(self)``.
        """
        self._check_index(index)
        _, cl = self.data[index]
        return cl

# Build the dataset without a transform and display
# (number of samples, number of distinct class folders).
dataset = ImageNetDataset(data_dir)
len(dataset), len(dataset.classes())

data: 100%|██████████| 1000/1000 [00:00<00:00, 1980.86it/s]


(34745, 1000)

In [5]:
def predict(model, loader, out_cnt=1000):
    """Run ``model`` in eval mode over ``loader`` and collect its outputs.

    The model is moved to the module-level ``device`` and switched to eval
    mode as a side effect.

    Args:
        model: a ``torch.nn.Module`` producing ``out_cnt`` scores per sample.
        loader: iterable of ``(batch, _, _)`` triples; only the first element
            is fed to the model.
        out_cnt: number of scores per sample, used to reshape each output.

    Returns:
        ``(predictions, scores)`` as NumPy arrays: ``predictions`` has shape
        ``(N, 1)`` and holds the arg-max index of each row, ``scores`` has
        shape ``(N, out_cnt)`` and holds the raw model outputs. Both are empty
        (``N == 0``) when the loader yields nothing.
    """
    model.to(device)
    model.eval()
    predictions = []
    scores = []
    # Inference only: without no_grad every forward pass would record an
    # autograd graph, costing memory and time for nothing.
    with torch.no_grad():
        for x, _, _ in tqdm(loader, desc='eval'):
            ws = model(x.to(device))
            scores.append(ws.cpu().reshape(-1, out_cnt).numpy())
            predictions.append(torch.argmax(ws, dim=1).cpu().reshape(-1, 1).numpy())
    if not predictions:
        # np.vstack rejects an empty list; return correctly shaped empties instead.
        return np.empty((0, 1), dtype=np.int64), np.empty((0, out_cnt), dtype=np.float32)
    return np.vstack(predictions), np.vstack(scores)

In [6]:
import sklearn.metrics
from collections import Counter, defaultdict

def metrics(classes, dataset):
    """Score predictions against a majority-vote estimate of the true labels.

    The dataset labels are folder names, not model output indices, so the
    "ground truth" index of each folder is guessed as the index predicted most
    often for its samples. Note that nothing prevents two folders from being
    mapped to the same index.

    Args:
        classes: NumPy array of predicted class indices, one per dataset
            sample (any shape that flattens to ``len(dataset)`` entries).
        dataset: object providing ``__len__`` and ``get_class(i)``.

    Returns:
        Micro-averaged precision of the predictions versus the guessed labels.
    """
    assert len(classes) == len(dataset)
    predicted = classes.reshape(-1)

    # Bucket the sample indices by the folder label they belong to.
    members = defaultdict(list)
    for idx in range(len(dataset)):
        members[dataset.get_class(idx)].append(idx)

    # Every sample of a folder is assigned that folder's most frequent prediction.
    expected = np.copy(predicted)
    for idxs in members.values():
        majority, _ = Counter(predicted[idxs]).most_common(1)[0]
        expected[idxs] = majority

    return sklearn.metrics.precision_score(expected, predicted, average='micro')

In [7]:
import psutil

def validate(model, dataset, batch_size=16, num_workers=None):
    """Evaluate ``model`` on ``dataset`` and return the score from ``metrics``.

    Args:
        model: network passed to ``predict``.
        dataset: map-style dataset yielding ``(image, class, path)`` triples
            and providing ``get_class`` for ``metrics``.
        batch_size: number of samples per batch (default 16).
        num_workers: DataLoader worker processes. ``None`` (the default) uses
            ``psutil.cpu_count()``, falling back to 0 (load in the main
            process) when the CPU count cannot be determined.

    Returns:
        The value returned by ``metrics`` for the model's predictions.
    """
    if num_workers is None:
        # psutil.cpu_count() may return None, which DataLoader rejects.
        num_workers = psutil.cpu_count() or 0
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
    # Only the arg-max predictions are scored; the raw outputs are discarded.
    cl, _ = predict(model, data_loader)
    return metrics(cl, dataset)

# Display the CPU count that validate() uses as its DataLoader worker count.
psutil.cpu_count()

4

In [8]:
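# Copy the ResNet checkpoints from the attached dataset into the torch hub cache, then rename
# three of them (resnet50/101/152) over the other checkpoint file names for the same architectures.
# NOTE(review): presumably the target names are the ones `pretrained=True` looks up, so the
# more recent weights are loaded instead - confirm against the installed torchvision version.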
!mkdir -pv ~/.cache/torch/hub/checkpoints/
!cp -av /kaggle/input/torchvision-resnet-pretrained/resnet*.pth ~/.cache/torch/hub/checkpoints/
!mv -vf ~/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth ~/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
!mv -vf ~/.cache/torch/hub/checkpoints/resnet101-cd907fc2.pth ~/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth
!mv -vf ~/.cache/torch/hub/checkpoints/resnet152-f82ba261.pth ~/.cache/torch/hub/checkpoints/resnet152-394f9c45.pth

mkdir: created directory '/root/.cache/torch'
mkdir: created directory '/root/.cache/torch/hub'
mkdir: created directory '/root/.cache/torch/hub/checkpoints/'
'/kaggle/input/torchvision-resnet-pretrained/resnet101-63fe2227.pth' -> '/root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth'
'/kaggle/input/torchvision-resnet-pretrained/resnet101-cd907fc2.pth' -> '/root/.cache/torch/hub/checkpoints/resnet101-cd907fc2.pth'
'/kaggle/input/torchvision-resnet-pretrained/resnet152-394f9c45.pth' -> '/root/.cache/torch/hub/checkpoints/resnet152-394f9c45.pth'
'/kaggle/input/torchvision-resnet-pretrained/resnet152-f82ba261.pth' -> '/root/.cache/torch/hub/checkpoints/resnet152-f82ba261.pth'
'/kaggle/input/torchvision-resnet-pretrained/resnet18-f37072fd.pth' -> '/root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth'
'/kaggle/input/torchvision-resnet-pretrained/resnet34-b627a593.pth' -> '/root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth'
'/kaggle/input/torchvision-resnet-pretrained

In [9]:
# Preprocessing for the RGB baseline runs: resize to 224x224, convert to a
# tensor, then normalize each of the three channels with the given mean/std.
rn_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def resnet_transform(img):
    """Apply the module-level ``rn_transform`` pipeline to ``img`` and return the result.

    ``rn_transform`` is looked up when the function is called, not when it is defined.
    """
    return rn_transform(img)

In [10]:
# RGB baseline: pretrained ResNet-18 on images preprocessed by resnet_transform.
validate(
    torchvision.models.resnet18(pretrained=True),
    ImageNetDataset(data_dir, resnet_transform)
)

data: 100%|██████████| 1000/1000 [00:00<00:00, 1724.49it/s]
eval: 100%|██████████| 2172/2172 [26:18<00:00,  1.38it/s]


0.7440207224061016

In [11]:
# RGB baseline: pretrained ResNet-34 on images preprocessed by resnet_transform.
validate(
    torchvision.models.resnet34(pretrained=True),
    ImageNetDataset(data_dir, resnet_transform)
)

data: 100%|██████████| 1000/1000 [00:01<00:00, 758.56it/s]
eval: 100%|██████████| 2172/2172 [44:53<00:00,  1.24s/it]


0.8000863433587566

In [12]:
# RGB baseline: pretrained ResNet-50 on images preprocessed by resnet_transform.
validate(
    torchvision.models.resnet50(pretrained=True),
    ImageNetDataset(data_dir, resnet_transform)
)

data: 100%|██████████| 1000/1000 [00:00<00:00, 1595.41it/s]
eval: 100%|██████████| 2172/2172 [1:01:14<00:00,  1.69s/it]


0.9063174557490287

In [13]:
# RGB baseline: pretrained ResNet-101 on images preprocessed by resnet_transform.
validate(
    torchvision.models.resnet101(pretrained=True),
    ImageNetDataset(data_dir, resnet_transform)
)

data: 100%|██████████| 1000/1000 [00:05<00:00, 183.54it/s]
eval: 100%|██████████| 2172/2172 [1:38:12<00:00,  2.71s/it]


0.9227802561519644

In [14]:
# RGB baseline: pretrained ResNet-152 on images preprocessed by resnet_transform.
validate(
    torchvision.models.resnet152(pretrained=True),
    ImageNetDataset(data_dir, resnet_transform)
)

data: 100%|██████████| 1000/1000 [00:00<00:00, 1183.82it/s]
eval: 100%|██████████| 2172/2172 [2:20:20<00:00,  3.88s/it]


0.9347244207799683

In [15]:
# Grayscale preprocessing, variant 1: resize, convert to grayscale, convert to
# a tensor. Unlike rn_transform there is no Normalize step.
# to_grayscale is called with its defaults - presumably one output channel,
# matching the 1-channel conv1 built by grayscale_fix_model; confirm.
gr_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.functional.to_grayscale,
    torchvision.transforms.ToTensor(),
])

def grayscale_transform(img):
    """Apply the module-level ``gr_transform`` pipeline to ``img`` and return the result.

    ``gr_transform`` is looked up at call time, so re-assigning it in a later
    cell changes what this function does without redefining the function.
    """
    return gr_transform(img)

def grayscale_fix_model(model):
    """Replace ``model.conv1`` with an equivalent single-input-channel convolution.

    The new layer's kernel is the old kernel summed over the input-channel
    axis, so feeding it a one-channel image gives the same result as feeding
    the old layer that image replicated across all input channels.

    Output channels, kernel size, stride, padding, dilation, padding mode,
    bias, dtype and device are taken from the existing ``conv1`` rather than
    hard-coded, so the function is not limited to a 64x7x7 stem.

    Args:
        model: module with a ``conv1`` attribute holding a ``torch.nn.Conv2d``.

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        ValueError: if ``conv1`` uses grouped convolution, where summing over
            the input-channel axis does not collapse the layer to one channel.
    """
    rgb_conv = model.conv1
    if rgb_conv.groups != 1:
        raise ValueError('grayscale_fix_model only supports conv1 with groups=1')
    gray_conv = torch.nn.Conv2d(
        1,
        rgb_conv.out_channels,
        kernel_size=rgb_conv.kernel_size,
        stride=rgb_conv.stride,
        padding=rgb_conv.padding,
        dilation=rgb_conv.dilation,
        bias=rgb_conv.bias is not None,
        padding_mode=rgb_conv.padding_mode,
    ).to(device=rgb_conv.weight.device, dtype=rgb_conv.weight.dtype)
    # Copy into the existing parameters instead of swapping ``.data``.
    with torch.no_grad():
        # keepdim=True leaves the summed axis in place: (out, in, kh, kw) -> (out, 1, kh, kw).
        gray_conv.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))
        if rgb_conv.bias is not None:
            gray_conv.bias.copy_(rgb_conv.bias)
    model.conv1 = gray_conv
    return model

In [16]:
# Grayscale run: ResNet-18 with conv1 collapsed to one input channel.
# grayscale_transform uses whichever gr_transform was assigned most recently.
validate(
    grayscale_fix_model(torchvision.models.resnet18(pretrained=True)),
    ImageNetDataset(data_dir, grayscale_transform)
)

data: 100%|██████████| 1000/1000 [00:07<00:00, 139.52it/s]
eval: 100%|██████████| 2172/2172 [25:19<00:00,  1.43it/s]


0.33898402647863

In [17]:
# Grayscale run: ResNet-50 with conv1 collapsed to one input channel.
# grayscale_transform uses whichever gr_transform was assigned most recently.
validate(
    grayscale_fix_model(torchvision.models.resnet50(pretrained=True)),
    ImageNetDataset(data_dir, grayscale_transform)
)

data: 100%|██████████| 1000/1000 [00:05<00:00, 184.31it/s]
eval: 100%|██████████| 2172/2172 [59:14<00:00,  1.64s/it]


0.7534033673909916

In [18]:
# Grayscale preprocessing, variant 2: apply the RGB mean/std normalization
# first, then go back to a PIL image to convert to grayscale.
# Re-assigning gr_transform changes what grayscale_transform does from here on,
# because that function looks the name up at call time.
# NOTE(review): ToPILImage presumably expects float tensors in [0, 1]; the
# normalized values fall outside that range, which may explain the lower
# scores recorded for the two runs that follow - confirm.
gr_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    torchvision.transforms.ToPILImage(),
    torchvision.transforms.functional.to_grayscale,
    torchvision.transforms.ToTensor(),
])

In [19]:
# Grayscale run: single-channel ResNet-18, using the gr_transform assigned most recently.
validate(
    grayscale_fix_model(torchvision.models.resnet18(pretrained=True)),
    ImageNetDataset(data_dir, grayscale_transform)
)

data: 100%|██████████| 1000/1000 [00:00<00:00, 1077.92it/s]
eval: 100%|██████████| 2172/2172 [26:15<00:00,  1.38it/s]


0.26182184486976545

In [20]:
# Grayscale run: single-channel ResNet-50, using the gr_transform assigned most recently.
validate(
    grayscale_fix_model(torchvision.models.resnet50(pretrained=True)),
    ImageNetDataset(data_dir, grayscale_transform)
)

data: 100%|██████████| 1000/1000 [00:04<00:00, 218.51it/s]
eval: 100%|██████████| 2172/2172 [1:00:42<00:00,  1.68s/it]


0.19202762987480212

In [21]:
# Grayscale preprocessing, variant 3: resize, apply torchvision's autocontrast,
# convert to grayscale, convert to a tensor. No mean/std normalization.
# Re-assigning gr_transform changes what grayscale_transform does from here on,
# because that function looks the name up at call time.
gr_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.functional.autocontrast,
    torchvision.transforms.functional.to_grayscale,
    torchvision.transforms.ToTensor(),
])

In [22]:
# Grayscale run: single-channel ResNet-18, using the gr_transform assigned most recently.
validate(
    grayscale_fix_model(torchvision.models.resnet18(pretrained=True)),
    ImageNetDataset(data_dir, grayscale_transform)
)

data: 100%|██████████| 1000/1000 [00:06<00:00, 164.24it/s]
eval: 100%|██████████| 2172/2172 [25:26<00:00,  1.42it/s]


0.3505540365520219

In [23]:
# Grayscale run: single-channel ResNet-50, using the gr_transform assigned most recently.
validate(
    grayscale_fix_model(torchvision.models.resnet50(pretrained=True)),
    ImageNetDataset(data_dir, grayscale_transform)
)

data: 100%|██████████| 1000/1000 [00:00<00:00, 1046.07it/s]
eval: 100%|██████████| 2172/2172 [1:00:13<00:00,  1.66s/it]


0.7551014534465391