In [1]:
# # 计算训练集数据均值、方差
# import numpy as np
# import cv2
# import random

# import os
# from tqdm.notebook import tqdm_notebook

# img_root = "/kaggle/input/chest-xray-pneumoniacovid19tuberculosis/train"
# sub_dirs = [os.path.join(img_root, sub_dir) for sub_dir in os.listdir(path=img_root)]

# CNum = 6000  # 挑选多少图片进行计算

# img_h, img_w = 256, 256
# imgs = np.zeros([img_w, img_h, 1, 1])
# means, stdevs = [], []

# data_files = []
# for sub_dir in sub_dirs:
#     data_files.extend([os.path.join(sub_dir, data_type) for data_type in os.listdir(path=sub_dir)])

# random.shuffle(data_files)  # shuffle, 随机挑选图片
# for index in tqdm_notebook(range(CNum)):
#     if index >= len(data_files):
#         break
#     data_file = data_files[index]
#     img = cv2.imread(data_file)
#     img = cv2.resize(img, (img_h, img_w))
#     img = img.transpose(2, 0, 1)
#     img = img.mean(axis=0)
#     img = img[:, :, np.newaxis, np.newaxis]

#     imgs = np.concatenate((imgs, img), axis=3)

# imgs = imgs.astype(np.float32) / 255.

# for i in tqdm_notebook(range(1)):
#     pixels = imgs[:, :, i, :].ravel()  # 拉成一行
#     means.append(np.mean(pixels))
#     stdevs.append(np.std(pixels))

# # cv2 读取的图像格式为BGR，PIL/Skimage读取到的都是RGB不用转
# means.reverse()  # BGR --> RGB
# stdevs.reverse()

# print("normMean = {}".format(means))
# print("normStd = {}".format(stdevs))
# print('transforms.Normalize(normMean = {}, normStd = {})'.format(means, stdevs))

# # normMean = [0.48698336]
# # normStd = [0.23673548]
# # transforms.Normalize(normMean = [0.48698336], normStd = [0.23673548])

In [2]:
# data.py

import os
import torch
import json

from PIL import Image

TRAIN_DATA_PATH = "train"
VAL_DATA_PATH = "val"
TEST_DATA_PATH = "test"

CLASS_NORMAL_KEY = "NORMAL"
CLASS_PNEUMONIA_KEY = "PNEUMONIA"
CLASS_COVID19_KEY = "COVID19"
CLASS_TUBERCULOSIS_KEY = "TURBERCULOSIS"

CLASS_COUNT = 4

TARGET_FIELD_LABELS = "labels"
TARGET_FIELD_IMAGE_ID = "image_id"
TARGET_FIELD_HEIGHT_WIDTH = "height_width"
TARGET_FIELD_IMAGE_DATA = "image_data"

# 定义数据集
class ChestXRayDataset(torch.utils.data.Dataset):
    _SUPPORT_FILE_TYPES = (".jpg", ".jpeg", ".mpo", ".png")

    def __init__(self, data_path, transforms):
        super(ChestXRayDataset, self).__init__()
        self.data_path = data_path if len(data_path) > 0 else "."
        self.sub_dirs = [os.path.join(self.data_path, sub_dir) for sub_dir in os.listdir(path=data_path)]
        # read class_indict
        json_file = "/kaggle/input/chest-x-ray-classes/chest_x_ray_classes.json"
        assert os.path.exists(json_file), "{} file not exist.".format(json_file)

        with open(json_file, 'r') as f:
            self.class_dict = json.load(f)

        self.data_filenames = []
        self.label_type = []
        for sub_dir in self.sub_dirs:
            label_type = os.path.basename(sub_dir)
            data_file_names = [os.path.join(data_path, sub_dir, data_file_name) for data_file_name in os.listdir(path=os.path.join(data_path, sub_dir))
                                if self._is_support_file_type(filename=data_file_name)]
            self.data_filenames.extend(data_file_names)
            self.label_type.extend([self.class_dict[label_type] for _ in range(len(data_file_names))])

        self.transforms = transforms

    def __len__(self):
        return len(self.data_filenames)

    def __getitem__(self, i):
        assert i < len(self.data_filenames), f"Index {i} out of bounds."
        return self._get_data(data_file_name=self.data_filenames[i], label_type=self.label_type[i])

    def _get_data(self, data_file_name, label_type):
        try:
            # data_file_name = 'D:\\workspace\\ai_study\\dataset\\ChestX-Ray\\ChestXRay\\train\\COVID19\\COVID19(436).jpg'
            raw_image_data = Image.open(fp=data_file_name, mode="r")
            if not self._is_file_ext_supported("." + raw_image_data.format):
                raise ValueError(f"Image '{data_file_name}'s' format not JPEG")

            target = {}
            target[TARGET_FIELD_IMAGE_DATA] = raw_image_data
            target[TARGET_FIELD_IMAGE_ID] = data_file_name
            target[TARGET_FIELD_LABELS] = torch.as_tensor([label_type], dtype=torch.int64)
            target[TARGET_FIELD_HEIGHT_WIDTH] = [raw_image_data.height, raw_image_data.width]

            if self.transforms is not None:
                image, target = self.transforms(raw_image_data, target)
            return image, target[TARGET_FIELD_LABELS]
        except:
            print(f"Open {data_file_name} failed.")
            return (None, None)
    
    @staticmethod
    def collate_fn(batch):
        images, targets = tuple(zip(*batch))

        filter_images = []
        filter_targets = []
        for index in range(len(images)):
            if images[index] is not None:
                filter_images.append(images[index])
                filter_targets.append(targets[index])

        return tuple(filter_images), tuple(filter_targets)

    def _is_file_ext_supported(self, file_ext):
        return file_ext.lower() in self._SUPPORT_FILE_TYPES

    def _is_support_file_type(self, filename):
        return os.path.splitext(p=(filename.lower()))[-1] in self._SUPPORT_FILE_TYPES

def get_chest_xray_dataloader(data_root_path,
                              data_use_type,
                              transforms,
                              batch_size,
                              drop_last,
                              shuffle,
                              collate_fn):
    dataset = ChestXRayDataset(os.path.join(data_root_path, data_use_type), transforms)
    # 数据加载器
    return torch.utils.data.DataLoader(dataset=dataset,
                                       batch_size=batch_size,
                                       drop_last=drop_last,
                                       shuffle=shuffle,
                                       collate_fn=collate_fn if collate_fn is not None else dataset.collate_fn)

In [3]:
# util.py

from collections import defaultdict, deque

import torch
from torch import nn
from torch import Tensor

DEFAULT_HIDDEN_SIZE = 768
IMAGE_DEFAULT_INNER_SIZE = 256
PATCH_DEFAULT_SIZE = 16
IMAGE_DEFAULT_CHANNELS = 3

# 注意力计算函数
# [BatchSize, Head_i, SeqLen, Emb_Size]
def attention(Q, K, V, mask, max_len, emb_size):
    # b句话,每句话50个词,每个词编码成32维向量,4个头,每个头分到8维向量
    # Q,K,V = [b, 4, 50, 8]

    # [b, 4, 256 + 1, 192] * [b, 4, 192, 256 + 1] -> [b, 4, 256 + 1, 256 + 1]
    # Q,K矩阵相乘,求每个词相对其他所有词的注意力
    score = torch.matmul(Q, K.permute(0, 1, 3, 2))

    # 除以每个头维数的平方根,做数值缩放
    score /= 8 ** 0.5

    # mask 遮盖,mask是true的地方都被替换成-inf,这样在计算softmax的时候,-inf会被压缩到0
    # mask = [b, 1, 256 + 1, 256 + 1]
    if mask is not None:
        score = score.masked_fill_(mask, -float('inf'))

    score = torch.softmax(score, dim=-1)

    # 以注意力分数乘以V,得到最终的注意力结果
    # [b, 4, 256 + 1, 256 + 1] * [b, 4, 256 + 1, 192] -> [b, 4, 256 + 1, 192]
    score = torch.matmul(score, V)

    # 每个头计算的结果合一
    # [b, 4, 256 + 1, 192] -> [b, 256 + 1, 768]
    score = score.permute(0, 2, 1, 3).reshape(-1, max_len, emb_size)

    return score


# 多头注意力计算层
class MultiHead(nn.Module):
    def __init__(self, max_len, emb_size):
        super().__init__()
        self.max_len = max_len
        self.emb_size = emb_size
        # Q 矩阵
        self.fc_Q = nn.Linear(emb_size, emb_size)
        # K 矩阵
        self.fc_K = nn.Linear(emb_size, emb_size)
        # V 矩阵
        self.fc_V = nn.Linear(emb_size, emb_size)

        self.out_fc = nn.Linear(emb_size, emb_size)
        #
        self.norm = nn.LayerNorm(normalized_shape=emb_size, elementwise_affine=True)

        self.dropout = nn.Dropout(p=0.1)

    def forward(self, Q, K, V, mask):
        # Q, K, V 指的是 embedding + pe 之后的结果
        # b句话,每句话50个词,每个词编码成32维向量
        # Q,K,V = [b, 256 + 1, 768]

        # 批量
        b = Q.shape[0]

        # 保留下原始的Q,后面要做短接用
        clone_Q = Q.clone()

        # 规范化
        Q = self.norm(Q)
        K = self.norm(K)
        V = self.norm(V)

        # 线性运算,维度不变
        # [b, 256 + 1, 768] -> [b, 256 + 1, 768]
        K = self.fc_K(K)
        V = self.fc_V(V)
        Q = self.fc_Q(Q)

        # 拆分成多个头
        # b句话,每句话50个词,每个词编码成32维向量,4个头,每个头分到8维向量
        # [b, 256 + 1, 768] -> [b, 4, 256 + 1, 192]
        Q = Q.reshape(b, self.max_len, 4, self.emb_size // 4).permute(0, 2, 1, 3)
        K = K.reshape(b, self.max_len, 4, self.emb_size // 4).permute(0, 2, 1, 3)
        V = V.reshape(b, self.max_len, 4, self.emb_size // 4).permute(0, 2, 1, 3)

        # 计算注意力
        # [b, 4, 256 + 1, 192] -> [b, 256 + 1, 768]
        score = attention(Q, K, V, mask, self.max_len, self.emb_size)

        # 计算输出,维度不变
        # [b, 256 + 1, 768] -> [b, 256 + 1, 768]
        score = self.dropout(self.out_fc(score))

        # 短接
        score = clone_Q + score
        return score

# 位置编码层
class PatchEmbedding(nn.Module):
    def __init__(self,
                 in_channels=IMAGE_DEFAULT_CHANNELS,
                 patch_size=PATCH_DEFAULT_SIZE,
                 emb_size=DEFAULT_HIDDEN_SIZE,
                 img_size=IMAGE_DEFAULT_INNER_SIZE):
        super(PatchEmbedding, self).__init__()
        self.patch_size = patch_size
        self.projection = nn.Sequential(
            nn.Conv2d(in_channels=in_channels,
                      out_channels=emb_size,
                      kernel_size=patch_size,
                      stride=patch_size)
        )
        self.max_seq_len = (img_size // patch_size) ** 2 + 1
        self.emb_size = emb_size
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.positions = nn.Parameter(torch.randn(self.max_seq_len, emb_size))

    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape
        # [B, C, H, W] -> [B, C', H', W']
        x = self.projection(x)
        # [B, C', H', W'] -> [B, H', W', C']
        x = x.permute(0, 2, 3, 1)
        # [B, H', W', C'] -> [B, PATCHES_COUNT, C']
        x = x.view(b, -1, self.emb_size)
        # [1, 1, C'] -> [B, 1, C']
        cls_tokens = self.cls_token.repeat(b, 1, 1)
        # [B, PATCHES_COUNT, C'] -> [B, PATCHES_COUNT + 1, C']
        x = torch.cat([cls_tokens, x], dim=1)
        # 融入位置编码信息
        x += self.positions

        return x

class ClassificationHead(nn.Module):
    def __init__(self, emb_size= DEFAULT_HIDDEN_SIZE, n_classes=CLASS_COUNT):
        super(ClassificationHead, self).__init__()
        self.classification_head = nn.Sequential(
            nn.LayerNorm(normalized_shape=emb_size),
            nn.Linear(in_features=emb_size, out_features=n_classes)
        )

    def forward(self, x: Tensor) -> Tensor:
        # [B, PATCHES_COUNT, C'] -> [B, C']
        x = x.mean(dim=1)
        return self.classification_head(x)

# 全连接输出层
class FullyConnectedOutput(nn.Module):
    def __init__(self, hidden_size=DEFAULT_HIDDEN_SIZE):
        super().__init__()
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(in_features=hidden_size, out_features=2 * hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=2 * hidden_size, out_features=hidden_size),
            torch.nn.Dropout(p=0.1)
        )

        self.norm = torch.nn.LayerNorm(normalized_shape=hidden_size,
                                       elementwise_affine=True)

    def forward(self, x):
        # 保留下原始的x,后面要做短接用
        clone_x = x.clone()

        # 规范化
        x = self.norm(x)

        # 线性全连接运算
        # [b, 256 + 1, 768] -> [b, 256 + 1, 768]
        out = self.fc(x)

        # 做短接
        out = clone_x + out

        return out

In [4]:
# transforms.py

import random

import torchvision.transforms as t
from torchvision.transforms import functional as F

class Compose(object):
    """组合多个transform函数"""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target=None):
        for trans in self.transforms:
            image, target = trans(image, target)
        return image, target


class ToTensor(object):
    """将PIL图像转为Tensor"""
    def __call__(self, image, target):
        image = F.to_tensor(image).contiguous()
        channel_count = len(image)
        if channel_count > 1:
            image = image.mean(dim=0).unsqueeze(dim=0)
        return image, target


class RandomHorizontalFlip(object):
    """随机水平翻转图像以及bboxes,该方法应放在ToTensor后"""
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            image = image.flip(-1)  # 水平翻转图片
        return image, target


class SSDCropping(object):
    """
    对图像进行裁剪,该方法应放在ToTensor前
    """
    def __init__(self):
        super(SSDCropping, self).__init__()
        self.sample_options = (None, (0.3, 1.0))

    def __call__(self, image, target):
        # 死循环，确保一定会返回结果
        while True:
            mode = random.choice(self.sample_options)
            if mode is None:  # 不做随机裁剪处理
                return image, target

            # Implementation use 5 iteration to find possible candidate
            for _ in range(5):
                # 0.3*0.3 approx. 0.1
                w = random.uniform(0.3, 1.0)
                h = random.uniform(0.3, 1.0)
                if w/h < 0.5 or w/h > 2:  # 保证宽高比例在0.5-2之间
                    continue

                # left 0 ~ wtot - w, top 0 ~ htot - h
                left = random.uniform(0, 1.0 - w)
                top = random.uniform(0, 1.0 - h)

                right = left + w
                bottom = top + h

                htot = target[TARGET_FIELD_HEIGHT_WIDTH][0]
                wtot = target[TARGET_FIELD_HEIGHT_WIDTH][1]

                # 裁剪 patch
                left_idx = int(left * wtot)
                top_idx = int(top * htot)
                right_idx = int(right * wtot)
                bottom_idx = int(bottom * htot)
                image = image.crop((left_idx, top_idx, right_idx, bottom_idx))
                # image.save(f"./output/crop/crop_image{self.count}.jpg")
                # self.count += 1
                return image, target


class Resize(object):
    """对图像进行resize处理,该方法应放在ToTensor前"""
    def __init__(self, size=(256, 256)):
        self.resize = t.Resize(size)
        # self.count = 0

    def __call__(self, image, target):
        image = self.resize(image)
        # image.save(f"./output/resize/resize{self.count}.jpg")
        # self.count += 1
        return image, target


class ColorJitter(object):
    """对图像颜色信息进行随机调整,该方法应放在ToTensor前"""
    def __init__(self, brightness=0.125, contrast=0.5, saturation=0.5, hue=0.05):
        self.trans = t.ColorJitter(brightness, contrast, saturation, hue)
        self.count = 0

    def __call__(self, image, target):
        image = self.trans(image)
        # image.save(f"./output/color_jitter/color_jitter{self.count}.jpg")
        # self.count += 1
        return image, target

# normMean = [0.48698336]
# normStd = [0.23673548]
class Normalization(object):
    """对图像标准化处理,该方法应放在ToTensor后"""
    def __init__(self, mean=None, std=None):
        # TODO: 计算训练集的mean和std
        if mean is None:
            mean = [0.48698336]
        if std is None:
            std = [0.23673548]
        self.normalize = t.Normalize(mean=mean, std=std)

    def __call__(self, image, target):
        image = self.normalize(image)
        return image, target

In [5]:
# model.py

import torch
from torch import nn

# 编码器层
class EncoderLayer(nn.Module):
    def __init__(self, max_len, emb_size):
        super().__init__()
        # 多头注意力
        self.mh = MultiHead(max_len, emb_size)
        # 全连接输出
        self.fc = FullyConnectedOutput()

    def forward(self, x):
        # 计算自注意力,维度不变
        # [b, 256 + 1, 768] -> [b, 256 + 1, 768]
        score = self.mh(x, x, x, mask=None)

        # 全连接输出,维度不变
        # [b, 256 + 1, 768] -> [b, 256 + 1, 768]
        out = self.fc(score)

        return out


class Encoder(torch.nn.Module):
    def __init__(self, max_len, emb_size):
        super().__init__()
        self.layer_1 = EncoderLayer(max_len, emb_size)
        self.layer_2 = EncoderLayer(max_len, emb_size)
        self.layer_3 = EncoderLayer(max_len, emb_size)

    def forward(self, x):
        x = self.layer_1(x)
        x = self.layer_2(x)
        x = self.layer_3(x)
        return x

# 主模型
class TransformerEncoder(torch.nn.Module):
    def __init__(self, max_len, emb_size):
        super().__init__()
        # 位置编码和词嵌入层
        self.encoder = Encoder(max_len, emb_size)

    def forward(self, x):
        # 编码层计算
        # [b, 256 + 1, 768] -> [b, 256 + 1, 768]
        return self.encoder(x)


class Vit(nn.Sequential):
    def __init__(self):
        super(Vit, self).__init__(
            PatchEmbedding(in_channels=1, patch_size=16, emb_size=768, img_size=256),
            TransformerEncoder(max_len=(256 // 16) ** 2 + 1, emb_size=768),
            ClassificationHead(emb_size=768, n_classes=4)
        )

In [None]:
# main.py

import torch
from PIL import Image
import os
import json
import datetime

# 预测函数
def predict(model, img_path):
    if len(img_path):
        print("img_path is null or empty")
        return

    original_img = Image.open(img_path)
    data_transform = Compose([Resize(),
                              ToTensor(),
                              Normalization()])
    img, _ = data_transform(original_img)
    # 改为批量预测
    x = torch.unsqueeze(img, dim=0)

    # x = [1, ]
    model.eval()
    pred = model(x)

    # read class_indict
    json_file = "./chest_x_ray_classes.json"
    assert os.path.exists(json_file), "{} file not exist.".format(json_file)

    with open(json_file, 'r') as f:
        class_dict = json.load(f)

    print("Predict result: " + class_dict[pred])

def main(args):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # 1，数据读取后的处理工作
    #     - 类型转换
    #     - 数据增强
    data_transform = {
        TRAIN_DATA_PATH: Compose([SSDCropping(),  # 图像切割
                                  Resize(),  # 统一大小
                                  ColorJitter(),  # 颜色抖动
                                  ToTensor(),  # 转张量
                                  RandomHorizontalFlip(),  # 水平翻转
                                  Normalization()]),  # 标准化

        VAL_DATA_PATH: Compose([Resize(),
                                ToTensor(),
                                Normalization()]),

        TEST_DATA_PATH: Compose([Resize(),
                                 ToTensor(),
                                 Normalization()])
    }

    # 构建训练数据集
    train_loader = get_chest_xray_dataloader(data_root_path=args.data_path,
                                             data_use_type=TRAIN_DATA_PATH,
                                             transforms=data_transform[TRAIN_DATA_PATH],
                                             batch_size=32,
                                             drop_last=False,
                                             shuffle=True,
                                             collate_fn=None)

    # 构建验证数据集
    val_loader = get_chest_xray_dataloader(data_root_path=args.data_path,
                                           data_use_type=VAL_DATA_PATH,
                                           transforms=data_transform[VAL_DATA_PATH],
                                           batch_size=5,
                                           drop_last=False,
                                           shuffle=True,
                                           collate_fn=None)

    # 构建测试数据集
    test_loader = get_chest_xray_dataloader(data_root_path=args.data_path,
                                            data_use_type=TEST_DATA_PATH,
                                            transforms=data_transform[TEST_DATA_PATH],
                                            batch_size=32,
                                            drop_last=False,
                                            shuffle=True,
                                            collate_fn=None)

    # 构建模型
    model = Vit()
    model.to(device=device)
    model.train()
    # 定义损失函数
    loss_func = torch.nn.CrossEntropyLoss()
    # 定义优化器
    optim = torch.optim.Adam(params=model.parameters(), lr=0.00001)

    if len(args.resume) > 0:
        checkpoint = torch.load(f=args.resume, map_location="cpu")
        model.load_state_dict(state_dict=checkpoint["model"])
        optim.load_state_dict(state_dict=checkpoint["optimizer"])
        args.start_epoch = checkpoint["last_epoch"] + 1
        print("the training process from epoch{}...".format(args.start_epoch))

    weighted_losses = torch.zeros(1).to(device=device)
    for epoch in range(args.start_epoch, args.start_epoch + args.epochs):
        train_acc_list = []
        val_acc_list = []
        test_acc_list = []
        for i, (x, y) in enumerate(train_loader):
            if len(x) <= 0:
                continue
            x = torch.stack(x, dim=0)
            y = torch.stack(y, dim=0)
            # x = [B, 256 + 1, 768]
            x = x.to(device=device)
            # y = [B, 1]
            y = y.to(device=device)
            # pred = [B, 1]
            pred = model(x)

            loss = loss_func(pred, y.reshape(-1))
            weighted_losses = (i * weighted_losses + loss ) / (i + 1)

            optim.zero_grad()
            loss.backward()
            optim.step()

            if i % 20 == 0:
                lr = optim.param_groups[0]['lr']
                print("{} Epoch{}/{}: lr={}, cur_loss={:.4}, weighted_loss={:.4}".format(
                      datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
                      epoch,
                      i,
                      lr,
                      loss.item(),
                      weighted_losses.item()))

            # [B, 1] -> [B]
            pred = pred.argmax(dim=-1).reshape(-1)
            # [B, 1] -> [B]
            y = y.reshape(-1)

            train_acc_list.extend((pred == y).to(dtype=torch.float32).tolist())

        # save weights
        save_files = {
            'model': model.state_dict(),
            'optimizer': optim.state_dict(),
            'last_epoch': epoch}
        torch.save(save_files, "./save_weights/ChestXRay-{}.pth".format(epoch))

        # 验证
        for i, (x, y) in enumerate(val_loader):
            if len(x) <= 0:
                continue
            
            x = torch.stack(x, dim=0)
            y = torch.stack(y, dim=0)
            
            x = x.to(device=device)
            y = y.to(device=device)
            
            pred_val = model(x).argmax(dim=-1).reshape(-1)
            y = y.reshape(-1)
            val_acc_list.extend((pred_val == y).to(dtype=torch.float32).tolist())

        print("Epoch {}: train_acc={:4}/{}, val_acc={:4}/{}, lr={:6}".format(
            epoch,
            sum(train_acc_list) / len(train_acc_list),
            len(train_acc_list),
            sum(val_acc_list) / len(val_acc_list),
            len(val_acc_list),
            optim.param_groups[0]["lr"]))

    # 测试
    for i, (x, y) in enumerate(test_loader):
        if len(x) <= 0:
            continue
            
        x = torch.stack(x, dim=0)
        y = torch.stack(y, dim=0)
            
        x = x.to(device=device)
        y = y.to(device=device)
        
        pred_test = model(x).argmax(dim=-1).reshape(-1)
        y = y.reshape(-1)
        test_acc_list.extend((pred_test == y).to(dtype=torch.float32).tolist())

    print("Epoch {}-{}: test_acc={:4}/{}".format(
          args.start_epoch,
          args.start_epoch + args.epochs,
          sum(test_acc_list) / len(test_acc_list),
          len(test_acc_list)))

if __name__ == '__main__':

    import argparse

    parser = argparse.ArgumentParser(description=__doc__)

    # 检测的目标类别个数，不包括背景(替换：自己的检测类别)
    parser.add_argument('--num_classes', default=4, type=int, help='num_classes')
    # 训练数据集的根目录
    parser.add_argument('--data_path', default='/kaggle/input/chest-xray-pneumoniacovid19tuberculosis', help='dataset')
    # 文件保存地址
    parser.add_argument('--output-dir', default='./save_weights', help='path where to save')
    # 若需要接着上次训练，则指定上次训练保存权重文件地址 /kaggle/working/save_weights/ChestXRay-5.pth
    parser.add_argument('--resume', default='/kaggle/working/save_weights/ChestXRay-24.pth', type=str, help='resume from checkpoint')
    # 指定接着从哪个epoch数开始训练
    parser.add_argument('--start_epoch', default=25, type=int, help='start epoch')
    # 训练的总epoch数
    parser.add_argument('--epochs', default=15, type=int, metavar='N',
                        help='number of total epochs to run')
    # 训练的batch size
    parser.add_argument('--batch_size', default=32, type=int, metavar='N',
                        help='batch size when training.')

    args = parser.parse_args(args=[])

    # 检查保存权重文件夹是否存在，不存在则创建
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
        
    main(args)

the training process from epoch25...
2023-03-19 06:47:13.668940 Epoch25/0: lr=1e-05, cur_loss=0.5735, weighted_loss=0.5735
2023-03-19 06:47:27.861850 Epoch25/20: lr=1e-05, cur_loss=0.3735, weighted_loss=0.4313
2023-03-19 06:47:41.683630 Epoch25/40: lr=1e-05, cur_loss=0.7257, weighted_loss=0.4344
2023-03-19 06:47:54.066411 Epoch25/60: lr=1e-05, cur_loss=0.2973, weighted_loss=0.4217
2023-03-19 06:48:07.384472 Epoch25/80: lr=1e-05, cur_loss=0.274, weighted_loss=0.4079
2023-03-19 06:48:20.670090 Epoch25/100: lr=1e-05, cur_loss=0.2656, weighted_loss=0.4011
2023-03-19 06:48:33.059486 Epoch25/120: lr=1e-05, cur_loss=0.3566, weighted_loss=0.3979
2023-03-19 06:48:47.376959 Epoch25/140: lr=1e-05, cur_loss=0.3241, weighted_loss=0.3909
2023-03-19 06:49:00.459987 Epoch25/160: lr=1e-05, cur_loss=0.3136, weighted_loss=0.3947
2023-03-19 06:49:14.161537 Epoch25/180: lr=1e-05, cur_loss=0.5902, weighted_loss=0.3958
Epoch 25: train_acc=0.8515649699652229/6326, val_acc=0.6578947368421053/38, lr= 1e-05
2023