# Colab-mmclassification

Original repo: [open-mmlab/mmclassification](https://github.com/open-mmlab/mmclassification)

My fork: [styler00dollar/Colab-mmclassification](https://github.com/styler00dollar/Colab-mmclassification)

In [None]:
# Show which NVIDIA GPU (if any) is attached to this runtime.
!nvidia-smi

In [None]:
#@title install
# Clone the upstream repository and install it in editable mode from its checkout.
!git clone https://github.com/open-mmlab/mmclassification.git
%cd mmclassification
!pip install -e .  # or "python setup.py develop"
# NOTE(review): the wheel index below is pinned to CUDA 10.1 / torch 1.7.0 (see the URL);
# confirm it matches the torch and CUDA versions of the current runtime.
!pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.7.0/index.html

In [None]:
#@title Mount Google Drive
# Mount the user's Google Drive at /content/drive so later cells can copy the dataset from it.
from google.colab import drive
drive.mount('/content/drive')
print('Google Drive connected.')

In [None]:
#@title copy and extract own dataset
# Copy the archive from the mounted Drive to local storage, then unpack it in /content/.
!cp "/content/drive/MyDrive/classification.7z" "/content/classification.7z"
%cd /content/
!7z x /content/classification.7z

# Dataset

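File paths must not contain spaces, because the annotation files separate the image path and the class number with a single space. A simple fix is to rename every file to its MD5 hash.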

In [None]:
#@title (Optional) Hash all files to md5 and move broken files to another folder
import glob
import hashlib
from tqdm import tqdm
import cv2
import os
import shutil
input_folder = '/content/images/' #@param {type:"string"}
broken_folder = '/content/broken/' #@param {type:"string"}


def _file_md5(path, chunk_size=4096):
    """Return the hex MD5 digest of the file at ``path``, read in chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# shutil.move fails when the destination folder for unreadable images is missing.
os.makedirs(broken_folder, exist_ok=True)

# Every direct sub-folder of input_folder is handled exactly once; the recursive
# glob below already reaches images in nested folders.
class_dirs = sorted(
    name for name in os.listdir(input_folder)
    if os.path.isdir(os.path.join(input_folder, name)))

for class_name in class_dirs:
    class_path = os.path.join(input_folder, class_name)
    image_paths = glob.glob(os.path.join(class_path, '**', '*.png'), recursive=True)
    image_paths.extend(glob.glob(os.path.join(class_path, '**', '*.jpg'), recursive=True))

    for path in tqdm(image_paths):
        # A fresh digest per file: the new name depends only on this file's bytes,
        # so re-running the cell yields the same names. Byte-identical files in the
        # same folder therefore end up under one name (only one copy is kept).
        new_name = _file_md5(path) + os.path.splitext(path)[1]
        if cv2.imread(path) is not None:
            # Readable image: rename in place.
            target_dir = os.path.dirname(path)
        else:
            # OpenCV could not decode it: move it out of the dataset.
            target_dir = broken_folder
        shutil.move(path, os.path.join(target_dir, new_name))

In [None]:
#@title (Optional) Re-saving png images with OpenCV to avoid ``libpng warning: iCCP: known incorrect sRGB profile``
import glob
import hashlib
from tqdm import tqdm
import cv2
import os
import shutil
input_folder = '/content/images/' #@param {type:"string"}
broken_folder = '/content/broken/' #@param {type:"string"}


def _file_md5(path, chunk_size=4096):
    """Return the hex MD5 digest of the file at ``path``, read in chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# shutil.move fails when the destination folder for unreadable images is missing.
os.makedirs(broken_folder, exist_ok=True)

# Every direct sub-folder of input_folder is handled exactly once; the recursive
# glob below already reaches images in nested folders.
class_dirs = sorted(
    name for name in os.listdir(input_folder)
    if os.path.isdir(os.path.join(input_folder, name)))

for class_name in class_dirs:
    png_paths = glob.glob(
        os.path.join(input_folder, class_name, '**', '*.png'), recursive=True)

    for path in tqdm(png_paths):
        # NOTE(review): imread is called with its default flags; presumably that
        # discards an alpha channel before the file is rewritten - confirm this
        # is acceptable for the dataset.
        image = cv2.imread(path)
        if image is not None:
            # Overwrite the file with OpenCV's own PNG encoding.
            cv2.imwrite(path, image)
        else:
            # Unreadable file: move it out of the dataset under its own MD5 name
            # (a fresh digest per file, not one shared across files).
            broken_name = _file_md5(path) + os.path.splitext(path)[1]
            shutil.move(path, os.path.join(broken_folder, broken_name))

In [None]:
#@title [dataset creation](https://github.com/bentrevett/pytorch-image-classification/blob/master/5_resnet.ipynb) (Split dataset in ```/train``` and ```/test```. Searches for ```/images```)

%cd /content/
TRAIN_RATIO = 0.9 #@param {type:"number"}
data_dir = '/content/' #@param {type:"string"}


images_dir = os.path.join(data_dir, 'images')
train_dir = os.path.join(data_dir, 'train')
test_dir = os.path.join(data_dir, 'test')

if os.path.exists(train_dir):
    shutil.rmtree(train_dir) 
if os.path.exists(test_dir):
    shutil.rmtree(test_dir)
    
os.makedirs(train_dir)
os.makedirs(test_dir)

classes = os.listdir(images_dir)

for c in classes:
    
    class_dir = os.path.join(images_dir, c)
    
    images = os.listdir(class_dir)
       
    n_train = int(len(images) * TRAIN_RATIO)
    
    train_images = images[:n_train]
    test_images = images[n_train:]
    
    os.makedirs(os.path.join(train_dir, c), exist_ok = True)
    os.makedirs(os.path.join(test_dir, c), exist_ok = True)
    
    for image in train_images:
        image_src = os.path.join(class_dir, image)
        image_dst = os.path.join(train_dir, c, image) 
        shutil.copyfile(image_src, image_dst)
        
    for image in test_images:
        image_src = os.path.join(class_dir, image)
        image_dst = os.path.join(test_dir, c, image) 
        shutil.copyfile(image_src, image_dst)


Annotation files are required. Warning: classes are represented by numbers. Generate the files once for the validation data and once for the training data. The class file (```classes_test.txt``` by default) shows the mapping from class name to number.

In [None]:
#@title Generate [needed annotation files](https://github.com/open-mmlab/mmclassification/blob/master/docs/tutorials/new_dataset.md).
import os
import glob
data_dir = '/content/test' #@param {type:"string"}
annotation_output = '/content/test.txt' #@param {type:"string"}
class_output = '/content/classes_test.txt' #@param {type:"string"}

# Classes are the direct sub-folders of data_dir. They are numbered in sorted
# order so the indices agree with the folder-based ImageNet loader
# (find_folders in mmcls/datasets/imagenet.py sorts folder names), which is what
# labels the training set when no ann_file is configured.
class_names = sorted(
    name for name in os.listdir(data_dir)
    if os.path.isdir(os.path.join(data_dir, name)))

# Mode 'w' discards output left over from an earlier run.
with open(annotation_output, 'w') as annotation_file, \
        open(class_output, 'w') as class_file:
    for counter, class_name in enumerate(class_names):
        folder_path = os.path.join(data_dir, class_name)

        files = glob.glob(folder_path + '/**/*.png', recursive=True)
        files.extend(glob.glob(folder_path + '/**/*.jpg', recursive=True))

        # One "<image path> <class index>" line per image.
        for image_path in files:
            annotation_file.write(image_path + " " + str(counter) + '\n')

        # One "<class name> <class index>" line per class.
        class_file.write(class_name + " " + str(counter) + '\n')

# Example with [mmclassification/configs/resnext/resnext50_32x4d_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnext/resnext50_32x4d_b32x8_imagenet.py).
```
_base_ = [
    '../_base_/models/resnext50_32x4d.py',
    '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
]
```

In [None]:
#@title imagenet_bs256.py (max epoch)
%%writefile /content/mmclassification/configs/_base_/schedules/imagenet_bs256.py
# optimizer
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
# No gradient clipping is configured.
optimizer_config = dict(grad_clip=None)
# learning policy
# NOTE(review): `step` presumably lists the epochs at which the learning rate is
# decayed - confirm against mmcv's step LR updater.
lr_config = dict(policy='step', step=[30, 60, 90])
# Edit max_epochs here to change the length of training (see the cell title).
runner = dict(type='EpochBasedRunner', max_epochs=100)

In [None]:
#@title resnext50_32x4d.py (amount classes)
%%writefile /content/mmclassification/configs/_base_/models/resnext50_32x4d.py
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNeXt',
        depth=50,
        num_stages=4,
        out_indices=(3, ),
        groups=32,
        width_per_group=4,
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        # Edit num_classes to the number of classes in your dataset (the value
        # this cell exists to change, per its title).
        num_classes=1000,
        in_channels=2048,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        # NOTE(review): top-5 presumably requires at least 5 classes - confirm
        # before using this head with a smaller dataset.
        topk=(1, 5),
    ))


In [None]:
#@title imagenet_bs32.py (edit paths)
%%writefile /content/mmclassification/configs/_base_/datasets/imagenet_bs32.py
# dataset settings
dataset_type = 'ImageNet'
# Per-channel mean/std handed to the Normalize step of both pipelines.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# Training pipeline: load, random crop to 224, random horizontal flip,
# normalize, convert image and label to tensors, collect both.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=224),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
# Validation/test pipeline: load, resize, center crop to 224, normalize,
# convert the image to a tensor, collect it.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(256, -1)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    samples_per_gpu=32,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        # Train only on the training split. '/content/images' also contains the
        # images that were copied to '/content/test', so using it here would
        # leak the validation data into training.
        data_prefix='/content/train/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_prefix='/content/test/',
        ann_file='/content/test.txt',
        pipeline=test_pipeline),
    test=dict(
        # Same held-out split as `val`; edit both paths to use a separate test set.
        type=dataset_type,
        data_prefix='/content/test/',
        ann_file='/content/test.txt',
        pipeline=test_pipeline))
# Evaluate accuracy on the validation set after every epoch.
evaluation = dict(interval=1, metric='accuracy')


In [None]:
#@title train (resnext50_32x4d_b32x8_imagenet.py)
# Run the training script from the repository root with the ResNeXt-50 example config.
%cd /content/mmclassification/
!python tools/train.py /content/mmclassification/configs/resnext/resnext50_32x4d_b32x8_imagenet.py

# Example with [mmclassification/configs/_base_/models/resnest50.py](https://github.com/open-mmlab/mmclassification/blob/24fd4fb62734cc87c0fec551be9185668c30c52f/configs/_base_/models/resnest50.py). 
ResNeSt is currently not in the documentation, but can be added manually.

In [None]:
#@title create resnest50.py
%%writefile /content/mmclassification/configs/resnest50.py
# Base configs this config inherits from: model, dataset, schedule and runtime.
_base_ = [
    '/content/mmclassification/configs/_base_/models/resnest50.py',
    '/content/mmclassification/configs/_base_/datasets/imagenet_bs32.py',
    '/content/mmclassification/configs/_base_/schedules/imagenet_bs256.py',
    '/content/mmclassification/configs/_base_/default_runtime.py',
]

In [None]:
#@title imagenet_bs256.py (max epoch)
%%writefile /content/mmclassification/configs/_base_/schedules/imagenet_bs256.py
# optimizer
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
# No gradient clipping is configured.
optimizer_config = dict(grad_clip=None)
# learning policy
# NOTE(review): every milestone in `step` (30, 60, 90) lies beyond max_epochs=20
# below, so presumably the learning rate is never decayed in this run - confirm
# and lower the milestones if decay is wanted.
lr_config = dict(policy='step', step=[30, 60, 90])
# Edit max_epochs here to change the length of training (see the cell title).
runner = dict(type='EpochBasedRunner', max_epochs=20)


In [None]:
#@title resnest50.py (amount classes / topk)
%%writefile /content/mmclassification/configs/_base_/models/resnest50.py
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNeSt',
        depth=50,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        # Number of classes in the dataset; edit to match your own data.
        num_classes=2,
        in_channels=2048,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        # Both entries are 1, so only top-1 accuracy is requested.
        # NOTE(review): presumably chosen because top-5 is not meaningful with
        # two classes - confirm whether (1, ) would work as well.
        topk=(1, 1),
    ))


In [None]:
#@title imagenet_bs32.py (edit paths)
%%writefile /content/mmclassification/configs/_base_/datasets/imagenet_bs32.py
# dataset settings
dataset_type = 'ImageNet'
# Per-channel mean/std handed to the Normalize step of both pipelines.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# Training pipeline: load, random crop to 224, random horizontal flip,
# normalize, convert image and label to tensors, collect both.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=224),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
# Validation/test pipeline: load, resize, center crop to 224, normalize,
# convert the image to a tensor, collect it.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(256, -1)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    samples_per_gpu=32,
    workers_per_gpu=2,
    # No ann_file for training: labels come from the class folders under data_prefix.
    train=dict(
        type=dataset_type,
        data_prefix='/content/train/',
        pipeline=train_pipeline),
    # Validation reads its samples and labels from the generated annotation file.
    val=dict(
        type=dataset_type,
        data_prefix='/content/test/',
        ann_file='/content/test.txt',
        pipeline=test_pipeline),
    test=dict(
        # replace `data/val` with `data/test` for standard test
        # (here `test` points at the same split and annotation file as `val`)
        type=dataset_type,
        data_prefix='/content/test/',
        ann_file='/content/test.txt',
        pipeline=test_pipeline))
# Evaluate accuracy on the validation set after every epoch.
evaluation = dict(interval=1, metric='accuracy')


In [None]:
#@title formatting.py (avoiding img_metas to avoid errors during validation/inference)
%%writefile /content/mmclassification/mmcls/datasets/pipelines/formating.py
from collections.abc import Sequence

import mmcv
import numpy as np
import torch
from mmcv.parallel import DataContainer as DC
from PIL import Image

from ..builder import PIPELINES


def to_tensor(data):
    """Convert objects of various python types to :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.

    Args:
        data: The object to convert.

    Returns:
        torch.Tensor: ``data`` itself if it already is a tensor, otherwise a
        tensor built from it. A single ``int`` or ``float`` becomes a
        one-element ``LongTensor`` / ``FloatTensor``.

    Raises:
        TypeError: If ``data`` is a string or any other unsupported type.
    """
    if isinstance(data, torch.Tensor):
        return data
    if isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    # Strings are sequences too, but must not be converted element-wise.
    if isinstance(data, Sequence) and not isinstance(data, str):
        return torch.tensor(data)
    if isinstance(data, int):
        return torch.LongTensor([data])
    if isinstance(data, float):
        return torch.FloatTensor([data])
    raise TypeError(
        f'Type {type(data)} cannot be converted to tensor. '
        'Supported types are: `numpy.ndarray`, `torch.Tensor`, '
        '`Sequence`, `int` and `float`')


@PIPELINES.register_module()
class ToTensor(object):
    """Pipeline step that converts selected entries of ``results`` to tensors.

    Args:
        keys (Sequence[str]): Keys of ``results`` whose values are passed
            through ``to_tensor``.
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        """Convert each configured entry in place and return ``results``."""
        for name in self.keys:
            value = results[name]
            results[name] = to_tensor(value)
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}(keys={self.keys})'


@PIPELINES.register_module()
class ImageToTensor(object):
    """Pipeline step that turns image arrays into tensors, last axis first.

    Args:
        keys (Sequence[str]): Keys of ``results`` holding the images.
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        """Convert each configured image in place and return ``results``."""
        for name in self.keys:
            image = results[name]
            # Arrays with fewer than three dimensions get a trailing axis so
            # the three-axis transpose below is valid.
            if len(image.shape) < 3:
                image = np.expand_dims(image, -1)
            # Move the last axis to the front (presumably HWC -> CHW).
            reordered = image.transpose(2, 0, 1)
            results[name] = to_tensor(reordered)
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}(keys={self.keys})'


@PIPELINES.register_module()
class Transpose(object):
    """Pipeline step that transposes selected entries of ``results``.

    Args:
        keys (Sequence[str]): Keys of ``results`` to transpose.
        order: Axis order passed to each value's ``transpose`` method.
    """

    def __init__(self, keys, order):
        self.keys = keys
        self.order = order

    def __call__(self, results):
        """Transpose each configured entry in place and return ``results``."""
        for name in self.keys:
            value = results[name]
            results[name] = value.transpose(self.order)
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}'
                f'(keys={self.keys}, order={self.order})')


@PIPELINES.register_module()
class ToPIL(object):
    """Pipeline step that replaces ``results['img']`` with a PIL image."""

    def __init__(self):
        """Nothing to configure."""

    def __call__(self, results):
        """Convert ``results['img']`` via ``Image.fromarray`` and return ``results``."""
        array = results['img']
        results['img'] = Image.fromarray(array)
        return results


@PIPELINES.register_module()
class ToNumpy(object):
    """Pipeline step that replaces ``results['img']`` with a float32 array."""

    def __init__(self):
        """Nothing to configure."""

    def __call__(self, results):
        """Convert ``results['img']`` to ``np.float32`` and return ``results``."""
        image = results['img']
        results['img'] = np.array(image, dtype=np.float32)
        return results


@PIPELINES.register_module()
class Collect(object):
    """
    Collect data from the loader relevant to the specific task.

    This is usually the last stage of the data loader pipeline. Typically keys
    is set to some subset of "img" and "gt_label".

    This variant deliberately does NOT add ``img_metas`` to its output, to
    avoid errors during validation/inference in this notebook.

    Args:
        keys (Sequence[str]): Keys of results to be collected in ``data``.
        meta_keys (Sequence[str], optional): Kept so existing configs that pass
            it still load, and shown in ``repr``; it has no effect on the
            collected data.
            Default: ``('filename', 'ori_shape', 'img_shape', 'flip',
            'flip_direction', 'img_norm_cfg')``

    Returns:
        dict: A new dict containing exactly the keys in ``self.keys``.

    Raises:
        KeyError: If one of ``keys`` is missing from ``results``.
    """

    def __init__(self,
                 keys,
                 meta_keys=('filename', 'ori_shape', 'img_shape', 'flip',
                            'flip_direction', 'img_norm_cfg')):
        self.keys = keys
        self.meta_keys = meta_keys

    def __call__(self, results):
        # Only the requested keys are copied; no meta information is gathered
        # because ``img_metas`` is intentionally left out of the output.
        return {key: results[key] for key in self.keys}

    def __repr__(self):
        return self.__class__.__name__ + \
            f'(keys={self.keys}, meta_keys={self.meta_keys})'


@PIPELINES.register_module()
class WrapFieldsToLists(object):
    """Wrap every value of the data dictionary into a one-element list.

    Intended as the final step of a test or validation pipeline for single
    image evaluation or inference.

    Example:
        >>> test_pipeline = [
        >>>    dict(type='LoadImageFromFile'),
        >>>    dict(type='Normalize',
                    mean=[123.675, 116.28, 103.53],
                    std=[58.395, 57.12, 57.375],
                    to_rgb=True),
        >>>    dict(type='ImageToTensor', keys=['img']),
        >>>    dict(type='Collect', keys=['img']),
        >>>    dict(type='WrapFieldsToLists')
        >>> ]
    """

    def __call__(self, results):
        """Replace each value with ``[value]`` in place and return ``results``."""
        # Snapshot the keys first; only existing entries are reassigned.
        for field in list(results):
            results[field] = [results[field]]
        return results

    def __repr__(self):
        return self.__class__.__name__ + '()'


In [None]:
#@title (Optional) accuracy.py (forcing top1 instead of topk, only do this if you have less than 5 classes)
%%writefile /content/mmclassification/mmcls/models/losses/accuracy.py
import numpy as np
import torch
import torch.nn as nn


def accuracy_numpy(pred, target, topk=1, thrs=None):
    """Compute top-k accuracies (in percent) for numpy predictions.

    Args:
        pred (np.ndarray): Scores, one row per sample and one column per class.
        target (np.ndarray): Ground-truth class index of each sample.
        topk (tuple[int]): The k values to evaluate. It is passed to ``max``
            and iterated, so it must be iterable.
        thrs (float | tuple[float], optional): Score thresholds; a prediction
            only counts when its score is strictly greater than the threshold.
            ``None`` is treated as ``0.0``.

    Returns:
        list: One entry per k: a single value when ``thrs`` is a float (or
        ``None``), otherwise a list with one value per threshold.

    Raises:
        TypeError: If ``thrs`` is neither ``None``, a float nor a tuple.
    """
    if thrs is None:
        thrs = 0.0
    if isinstance(thrs, float):
        thresholds, single_thr = (thrs, ), True
    elif isinstance(thrs, tuple):
        thresholds, single_thr = thrs, False
    else:
        raise TypeError(
            f'thrs should be float or tuple, but got {type(thrs)}.')

    largest_k = max(topk)
    n_samples = pred.shape[0]
    # Best-first labels and scores of the `largest_k` highest-scoring classes.
    ranked_labels = pred.argsort(axis=1)[:, -largest_k:][:, ::-1]
    ranked_scores = np.sort(pred, axis=1)[:, -largest_k:][:, ::-1]

    res = []
    for k in topk:
        hits = ranked_labels[:, :k] == target.reshape(-1, 1)
        per_threshold = []
        for thr in thresholds:
            # A hit only counts when its score exceeds the threshold.
            confident_hits = hits & (ranked_scores[:, :k] > thr)
            sample_correct = np.logical_or.reduce(confident_hits, axis=1)
            per_threshold.append(sample_correct.sum() * 100. / n_samples)
        res.append(per_threshold[0] if single_thr else per_threshold)
    return res


def accuracy_torch(pred, target, topk=1, thrs=None):
    """Compute top-k accuracies (in percent) for torch predictions.

    The number of candidates taken per sample is clamped to the number of
    classes, so asking for e.g. top-5 on a dataset with fewer than five
    classes does not fail; in that case all classes are considered. For
    datasets with enough classes every requested k is evaluated as a real
    top-k, matching what ``accuracy_numpy`` computes.

    Args:
        pred (torch.Tensor): Scores, one row per sample and one column per
            class.
        target (torch.Tensor): Ground-truth class index of each sample.
        topk (tuple[int]): The k values to evaluate. It is passed to ``max``
            and iterated, so it must be iterable.
        thrs (float | tuple[float], optional): Score thresholds; a prediction
            only counts when its score is strictly greater than the threshold.
            ``None`` is treated as ``0.0``.

    Returns:
        list: One entry per k: a one-element tensor when ``thrs`` is a float
        (or ``None``), otherwise a list with one such tensor per threshold.

    Raises:
        TypeError: If ``thrs`` is neither ``None``, a float nor a tuple.
    """
    if thrs is None:
        thrs = 0.0
    if isinstance(thrs, float):
        thrs = (thrs, )
        res_single = True
    elif isinstance(thrs, tuple):
        res_single = False
    else:
        raise TypeError(
            f'thrs should be float or tuple, but got {type(thrs)}.')

    res = []
    # Never ask torch for more candidates than there are classes.
    maxk = min(max(topk), pred.size(1))
    num = pred.size(0)
    pred_score, pred_label = pred.topk(maxk, dim=1)
    # Transposed so that row i holds every sample's (i+1)-th best guess.
    pred_label = pred_label.t()
    correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))
    for k in topk:
        res_thr = []
        for thr in thrs:
            # Only prediction values larger than thr are counted as correct
            _correct = correct & (pred_score.t() > thr)
            # Slicing past the last row is harmless when k exceeds maxk.
            correct_k = _correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res_thr.append(correct_k.mul_(100. / num))
        if res_single:
            res.append(res_thr[0])
        else:
            res.append(res_thr)
    return res


def accuracy(pred, target, topk=1, thrs=None):
    """Calculate accuracy from predictions and targets.

    Dispatches to the torch or numpy implementation depending on the type of
    the inputs, which must both be of the same kind.

    Args:
        pred (torch.Tensor | np.array): The model prediction.
        target (torch.Tensor | np.array): The target of each prediction.
        topk (int | tuple[int]): A prediction counts as correct when the
            target is among its ``topk`` best guesses. Defaults to 1.
        thrs (float | tuple[float], optional): Predictions with scores under
            the thresholds are considered negative. Defaults to None.

    Returns:
        float | list[float] | list[list[float]]: For an integer ``topk`` the
            result for that single k (a value, or a list when ``thrs`` is a
            tuple). For a tuple ``topk`` a list with one such result per k,
            in the same order as ``topk``.

    Raises:
        TypeError: If ``pred`` and ``target`` are not both tensors or both
            numpy arrays.
    """
    assert isinstance(topk, (int, tuple))
    return_single = isinstance(topk, int)
    if return_single:
        topk = (topk, )

    both_tensors = (
        isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor))
    both_arrays = (
        isinstance(pred, np.ndarray) and isinstance(target, np.ndarray))

    if both_tensors:
        res = accuracy_torch(pred, target, topk, thrs)
    elif both_arrays:
        res = accuracy_numpy(pred, target, topk, thrs)
    else:
        raise TypeError(
            f'pred and target should both be torch.Tensor or np.ndarray, '
            f'but got {type(pred)} and {type(target)}.')

    if return_single:
        return res[0]
    return res


class Accuracy(nn.Module):
    """Module wrapper around :func:`accuracy`.

    Args:
        topk (tuple): The criterion used to calculate the accuracy.
            Defaults to (1,).
    """

    def __init__(self, topk=(1, )):
        super().__init__()
        self.topk = topk

    def forward(self, pred, target):
        """Compute the accuracy of ``pred`` against ``target``.

        Args:
            pred (torch.Tensor): Prediction of models.
            target (torch.Tensor): Target for each prediction.

        Returns:
            list[float]: The accuracies under the configured topk criterions.
        """
        criteria = self.topk
        return accuracy(pred, target, criteria)


In [None]:
#@title train (resnest50.py)
# Run the training script from the repository root with the ResNeSt-50 config written above.
%cd /content/mmclassification/
!python tools/train.py /content/mmclassification/configs/resnest50.py

# Test

In [None]:
#@title image_demo.py (printing result instead of plotting)
%%writefile /content/mmclassification/demo/image_demo.py
from argparse import ArgumentParser

from mmcls.apis import inference_model, init_model, show_result_pyplot
import cv2

def main():
    """Classify one image given on the command line and print the result.

    Positional arguments are the image, the config file and the checkpoint
    file; ``--device`` selects the inference device (default ``cuda:0``).
    """
    cli = ArgumentParser()
    cli.add_argument('img', help='Image file')
    cli.add_argument('config', help='Config file')
    cli.add_argument('checkpoint', help='Checkpoint file')
    cli.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    opts = cli.parse_args()

    # Build the model from the config and checkpoint on the chosen device.
    classifier = init_model(opts.config, opts.checkpoint, device=opts.device)
    # Run inference on the single image.
    prediction = inference_model(classifier, opts.img)
    # The result is printed instead of being plotted.
    print("result")
    print(prediction)


if __name__ == '__main__':
    main()


In [None]:
#@title imagenet.py (edit classes if you want to print that)
%%writefile /content/mmclassification/mmcls/datasets/imagenet.py
import os

import numpy as np

from .base_dataset import BaseDataset
from .builder import DATASETS


def has_file_allowed_extension(filename, extensions):
    """Check whether a file name ends with one of the given extensions.

    The file name is lower-cased before the comparison; the extensions are
    used as given.

    Args:
        filename (string): path to a file
        extensions (Iterable[str]): allowed extensions

    Returns:
        bool: True if the lower-cased filename ends with one of ``extensions``
    """
    suffixes = tuple(extensions)
    return filename.lower().endswith(suffixes)


def find_folders(root):
    """Map every sub-folder of ``root`` to a class index.

    Indices are assigned in sorted order of the folder names; plain files in
    ``root`` are ignored.

    Args:
        root (string): root directory of folders

    Returns:
        folder_to_idx (dict): the map from folder name to class idx
    """
    class_folders = sorted(
        entry for entry in os.listdir(root)
        if os.path.isdir(os.path.join(root, entry)))
    return {name: idx for idx, name in enumerate(class_folders)}


def get_samples(root, folder_to_idx, extensions):
    """Make dataset by walking all images under a root.

    Every direct sub-folder of ``root`` is one class. Images are searched
    recursively inside each class folder, and each returned path is relative
    to ``root`` and includes any nested sub-directories, so joining it with
    ``root`` always yields the real file.

    Args:
        root (string): root directory of folders
        folder_to_idx (dict): the map from class name to class idx
        extensions (tuple): allowed extensions

    Returns:
        samples (list): a list of tuple where each element is (image, label)

    Raises:
        KeyError: If a class folder containing images is missing from
            ``folder_to_idx``.
    """
    samples = []
    root = os.path.expanduser(root)
    for folder_name in sorted(os.listdir(root)):
        _dir = os.path.join(root, folder_name)
        if not os.path.isdir(_dir):
            continue

        for dirpath, _, fns in sorted(os.walk(_dir)):
            for fn in sorted(fns):
                if has_file_allowed_extension(fn, extensions):
                    # Relative to root, keeping nested directories; joining
                    # only folder_name and fn would drop them and point at a
                    # file that does not exist.
                    path = os.path.relpath(os.path.join(dirpath, fn), root)
                    item = (path, folder_to_idx[folder_name])
                    samples.append(item)
    return samples


@DATASETS.register_module()
class ImageNet(BaseDataset):
    """`ImageNet <http://www.image-net.org>`_ Dataset.

    This implementation is modified from
    https://github.com/pytorch/vision/blob/master/torchvision/datasets/imagenet.py  # noqa: E501

    Samples come either from the class folders under ``data_prefix`` (when
    ``ann_file`` is None) or from an annotation file with one
    ``<image path> <class index>`` entry per line.
    """

    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
    # Class names in class-index order; edit to match your own dataset.
    CLASSES = [
        'bad',
        'good'
    ]

    def load_annotations(self):
        """Build the list of sample infos for this dataset.

        Returns:
            list[dict]: One dict per sample with the keys ``img_prefix``,
            ``img_info`` (holding ``filename``) and ``gt_label``.

        Raises:
            RuntimeError: If no image is found below ``data_prefix`` in
                folder mode.
            TypeError: If ``ann_file`` is neither a str nor None.
        """
        if self.ann_file is None:
            folder_to_idx = find_folders(self.data_prefix)
            samples = get_samples(
                self.data_prefix,
                folder_to_idx,
                extensions=self.IMG_EXTENSIONS)
            if len(samples) == 0:
                raise RuntimeError('Found 0 files in subfolders of: '
                                   f'{self.data_prefix}. '
                                   'Supported extensions are: '
                                   f'{",".join(self.IMG_EXTENSIONS)}')

            self.folder_to_idx = folder_to_idx
        elif isinstance(self.ann_file, str):
            with open(self.ann_file) as f:
                # Split on the LAST space only, so a path that itself contains
                # spaces stays intact; blank lines are skipped.
                samples = [
                    line.strip().rsplit(' ', 1) for line in f
                    if line.strip()
                ]
        else:
            raise TypeError('ann_file must be a str or None')
        self.samples = samples

        data_infos = []
        for filename, gt_label in self.samples:
            info = {'img_prefix': self.data_prefix}
            info['img_info'] = {'filename': filename}
            info['gt_label'] = np.array(gt_label, dtype=np.int64)
            data_infos.append(info)
        return data_infos


In [None]:
#@title print classification result
# Run the patched demo script on one image with the ResNeSt-50 config and its latest checkpoint.
%cd /content/mmclassification
!python demo/image_demo.py /content/image.jpg \
    /content/mmclassification/configs/resnest50.py \
    /content/mmclassification/work_dirs/resnest50/latest.pth

In [None]:
#@title getting topk metrics
# Runs tools/test.py with the ResNeXt example config and its latest checkpoint.
# The relative script path relies on the working directory being
# /content/mmclassification (set by an earlier %cd).
# For the ResNeSt example, swap in configs/resnest50.py and
# work_dirs/resnest50/latest.pth (the paths used by the demo cell).
!python tools/test.py /content/mmclassification/configs/resnext/resnext50_32x4d_b32x8_imagenet.py \
    /content/mmclassification/work_dirs/resnext50_32x4d_b32x8_imagenet/latest.pth