<a href="https://colab.research.google.com/github/yulinlina/MedMnist/blob/ChestMNIST/ConvNeXt_for_MedMNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Use this notebook to finetune a ConvNeXt-tiny model on CIFAR 10 dataset. The [official ConvNeXt repository](https://github.com/facebookresearch/ConvNeXt) is instrumented with [Weights and Biases](https://wandb.ai/site). You can now easily log your train/test metrics and version control your model checkpoints to Weigths and Biases

# ⚽️ Installation and Setup

The following installation instruction is based on [INSTALL.md](https://github.com/facebookresearch/ConvNeXt/blob/main/INSTALL.md) provided by the official ConvNeXt repository. 

In [1]:
#运行前先在"代码执行程序"中选择"更改运行时类型"为GPU
!pip install -qq torch==1.8.0+cu111 torchvision==0.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
!pip install -qq wandb timm==0.4.12 six tensorboardX

[31mERROR: Could not find a version that satisfies the requirement torch==1.8.0+cu111 (from versions: 1.11.0, 1.11.0+cpu, 1.11.0+cu102, 1.11.0+cu113, 1.11.0+cu115, 1.11.0+rocm4.3.1, 1.11.0+rocm4.5.2, 1.12.0, 1.12.0+cpu, 1.12.0+cu102, 1.12.0+cu113, 1.12.0+cu116, 1.12.0+rocm5.0, 1.12.0+rocm5.1.1, 1.12.1, 1.12.1+cpu, 1.12.1+cu102, 1.12.1+cu113, 1.12.1+cu116, 1.12.1+rocm5.0, 1.12.1+rocm5.1.1, 1.13.0, 1.13.0+cpu, 1.13.0+cu116, 1.13.0+cu117, 1.13.0+cu117.with.pypi.cudnn, 1.13.0+rocm5.1.1, 1.13.0+rocm5.2, 1.13.1, 1.13.1+cpu, 1.13.1+cu116, 1.13.1+cu117, 1.13.1+cu117.with.pypi.cudnn, 1.13.1+rocm5.1.1, 1.13.1+rocm5.2, 2.0.0, 2.0.0+cpu, 2.0.0+cpu.cxx11.abi, 2.0.0+cu117, 2.0.0+cu117.with.pypi.cudnn, 2.0.0+cu118, 2.0.0+rocm5.3, 2.0.0+rocm5.4.2, 2.0.1, 2.0.1+cpu, 2.0.1+cpu.cxx11.abi, 2.0.1+cu117, 2.0.1+cu117.with.pypi.cudnn, 2.0.1+cu118, 2.0.1+rocm5.3, 2.0.1+rocm5.4.2)[0m[31m
[0m[31mERROR: No matching distribution found for torch==1.8.0+cu111[0m[31m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Download the official ConvNeXt respository. 

In [2]:
!git clone https://github.com/facebookresearch/ConvNeXt

Cloning into 'ConvNeXt'...
remote: Enumerating objects: 252, done.[K
remote: Counting objects: 100% (249/249), done.[K
remote: Compressing objects: 100% (118/118), done.[K
remote: Total 252 (delta 129), reused 192 (delta 110), pack-reused 3[K
Receiving objects: 100% (252/252), 69.63 KiB | 2.79 MiB/s, done.
Resolving deltas: 100% (129/129), done.


# 🏀 Download the Dataset

We will be finetuning on CIFAR-10 dataset. To use any custom dataset (CIFAR-10 here) the format of the dataset should be as shown below:

```
/path/to/dataset/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
```



In [3]:
!pip install --upgrade git+https://github.com/MedMNIST/MedMNIST.git

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/MedMNIST/MedMNIST.git
  Cloning https://github.com/MedMNIST/MedMNIST.git to /tmp/pip-req-build-f2n9eic6
  Running command git clone --filter=blob:none --quiet https://github.com/MedMNIST/MedMNIST.git /tmp/pip-req-build-f2n9eic6
  Resolved https://github.com/MedMNIST/MedMNIST.git to commit 18a7564bc1fc3c68adbfeac7590d3949fe91467b
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting fire (from medmnist==2.2.2)
  Downloading fire-0.5.0.tar.gz (88 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.3/88.3 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: medmnist, fire
  Building wheel for medmnist (setup.py) ... [?25l[?25hdone
  Created wheel for medmnist: filename=medmnist-2.2.2-py3-none-any.whl size=21964 sha256=4640324a0f4dac96c22

In [4]:
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms

import medmnist
from medmnist import INFO, Evaluator
data_flag = 'chestmnist'
# data_flag = 'breastmnist'
download = True
info = INFO[data_flag]
task = info['task']
n_channels = info['n_channels']
n_classes = len(info['label'])

DataClass = getattr(medmnist, info['python_class'])
# preprocessing
data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5], std=[.5])
])
# load the data
train_dataset = DataClass(split='train', transform=data_transform, download=download)
test_dataset = DataClass(split='test', transform=data_transform, download=download)

pil_dataset = DataClass(split='train', download=download)

Downloading https://zenodo.org/record/6496656/files/chestmnist.npz?download=1 to /root/.medmnist/chestmnist.npz


100%|██████████| 82802576/82802576 [02:33<00:00, 541086.57it/s]


Using downloaded and verified file: /root/.medmnist/chestmnist.npz
Using downloaded and verified file: /root/.medmnist/chestmnist.npz


In [5]:
!python -m medmnist save --flag=chestmnist --folder=MedMNIST/ --postfix=jpeg

Saving chestmnist train...
100% 78468/78468 [00:14<00:00, 5380.86it/s]
Saving chestmnist val...
100% 11219/11219 [00:02<00:00, 4114.07it/s]
Saving chestmnist test...
100% 22433/22433 [00:04<00:00, 5226.32it/s]


In [17]:
import os
import re

path = '/content/MedMNIST/chestmnist/'
os.chdir(path)

regex = r'test(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)\.jpeg'

for filename in os.listdir('.'):
    match = re.match(regex, filename)
    if match:
        new_name = f"test{match.group(1)}_{match.group(2)}.jpeg"
        os.rename(filename, new_name)

In [18]:
import os
import re

path = '/content/MedMNIST/chestmnist/'
os.chdir(path)

regex = r'train(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)\.jpeg'

for filename in os.listdir('.'):
    match = re.match(regex, filename)
    if match:
        new_name = f"train{match.group(1)}_{match.group(2)}.jpeg"
        os.rename(filename, new_name)

In [19]:
import os
import re

path = '/content/MedMNIST/chestmnist/'
os.chdir(path)

regex = r'val(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)\.jpeg'

for filename in os.listdir('.'):
    match = re.match(regex, filename)
    if match:
        new_name = f"val{match.group(1)}_{match.group(2)}.jpeg"
        os.rename(filename, new_name)

In [22]:
#以下是数据集格式整理脚本，目的是把数据集变成类似CIFAR-10的格式，忽略mv: cannot stat 'val802_6.jpeg': No such file or directory等等输出即可
%cd /content/MedMNIST/chestmnist/
%mkdir train val test
%cd /content/MedMNIST/chestmnist/train/
%mkdir class1 class2
%cd /content/MedMNIST/chestmnist/test/
%mkdir class1 class2
%cd /content/MedMNIST/chestmnist/val/
%mkdir class1 class2
%cd /content/MedMNIST/chestmnist/
%mv train{0..78467}_0.jpeg /content/MedMNIST/chestmnist/train/class1/
%mv train{0..78467}_1.jpeg /content/MedMNIST/chestmnist/train/class2/
%cd /content/MedMNIST/chestmnist/
%mv test{0..11218}_0.jpeg /content/MedMNIST/chestmnist/test/class1/
%mv test{0..11218}_1.jpeg /content/MedMNIST/chestmnist/test/class2/
%cd /content/MedMNIST/chestmnist/
%mv val{0..22432}_0.jpeg /content/MedMNIST/chestmnist/val/class1/
%mv val{0..22432}_1.jpeg /content/MedMNIST/chestmnist/val/class2/

[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m
mv: cannot stat 'val17433_1.jpeg': No such file or directory
mv: cannot stat 'val17434_1.jpeg': No such file or directory
mv: cannot stat 'val17435_1.jpeg': No such file or directory
mv: cannot stat 'val17436_1.jpeg': No such file or directory
mv: cannot stat 'val17437_1.jpeg': No such file or directory
mv: cannot stat 'val17438_1.jpeg': No such file or directory
mv: cannot stat 'val17439_1.jpeg': No such file or directory
mv: cannot stat 'val17440_1.jpeg': No such file or directory
mv: cannot stat 'val17441_1.jpeg': No such file or directory
mv: cannot stat 'val17442_1.jpeg': No such file or directory
mv: cannot stat 'val17443_1.jpeg': No such file or directory
mv: cannot stat 'val17444_1.jpeg': No such file or directory
mv: cannot stat 'val17445_1.jpeg': No such file or directory
mv: cannot stat 'val17446_1.jpeg': No such file or directory
mv: cannot stat 'val17447_1.jpeg': No such file or directory
mv: cannot stat 'val17448_1.jpeg': No such f

# 🏈 Download Pretrained Weights

We will be finetuning the ConvNeXt Tiny model pretrained on ImageNet 1K dataset.

In [23]:
#先检查一下/content/MedMNIST/chestmnist文件夹是否把所有图片归类到train val test
%cd /content/ConvNeXt/
#下面是下载预训练模型，需要用到在Imagenet-1k上预训练模型，否则效果不好(可以去掉试试)
!wget https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth

/content/ConvNeXt
--2023-05-23 06:50:24--  https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 18.160.249.12, 18.160.249.77, 18.160.249.45, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|18.160.249.12|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 265112135 (253M) [binary/octet-stream]
Saving to: ‘convnext_small_22k_224.pth’


2023-05-23 06:50:32 (33.6 MB/s) - ‘convnext_small_22k_224.pth’ saved [265112135/265112135]



# 🎾 Train with Weights and Biases

If you want to log the train and evaluation metrics using Weights and Biases pass `--enable_wandb true`. 

You can also save the finetuned checkpoints as version controlled W&B [Artifacts](https://docs.wandb.ai/guides/artifacts) if you pass `--wandb_ckpt true`.



In [31]:
!python main.py --epochs 30 \
                --model convnext_small \
                --data_set image_folder \
                --data_path /content/MedMNIST/chestmnist/train \
                --eval_data_path /content/MedMNIST/chestmnist/test \
                --nb_classes 2 \
                --num_workers 8 \
                --warmup_epochs 0 \
                --save_ckpt true \
                --output_dir model_ckpt \
                --cutmix 0 \
                --mixup 0 --lr 4e-4 \
                --enable_wandb true --wandb_ckpt true \
                --finetune convnext_small_22k_224.pth 

Not using distributed mode
Namespace(batch_size=64, epochs=30, update_freq=1, model='convnext_small', drop_path=0, input_size=224, layer_scale_init_value=1e-06, model_ema=False, model_ema_decay=0.9999, model_ema_force_cpu=False, model_ema_eval=False, opt='adamw', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.05, weight_decay_end=None, lr=0.0004, layer_decay=1.0, min_lr=1e-06, warmup_epochs=0, warmup_steps=-1, color_jitter=0.4, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', crop_pct=None, reprob=0.25, remode='pixel', recount=1, resplit=False, mixup=0.0, cutmix=0.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', finetune='convnext_small_22k_224.pth', head_init_scale=1.0, model_key='model|module', model_prefix='', data_path='/content/MedMNIST/chestmnist/train', eval_data_path='/content/MedMNIST/chestmnist/test', nb_classes=2, imagenet_default_mean_and_std=True, data_set='image_folder', output_dir='mode

In [None]:
#装载Google dirve
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#保存
%cd /content
%cp -r MedMNIST drive/MyDrive/MedMNIST
%cp -r ConVNeXt drive/MyDrive/ConVNeXt

/content


# 🏐 Conclusion

* **The above setting gives a top-1 accuracy of ~95%.**
* The ConvNeXt repository comes with modern training regimes and is easy to finetune on any dataset. 
* The finetune model achieves competitive results. 

* By passing two arguments you get the following:

  * Repository of all your experiments (train and test metrics) as a [W&B Project](https://docs.wandb.ai/ref/app/pages/project-page). You can easily compare experiments to find the best performing model.
  * Hyperparameters (Configs) used to train individual models. 
  * System (CPU/GPU/Disk) metrics.
  * Model checkpoints saved as W&B Artifacts. They are versioned and easy to share. 

  Check out the associated [W&B run page](https://wandb.ai/ayut/convnext/runs/16vi9e31). $→$