<a href="https://colab.research.google.com/github/wnsports/SportsAnalyticsNotebook/blob/main/SwinUnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/HuCaoFighting/Swin-Unet.git

Cloning into 'Swin-Unet'...
remote: Enumerating objects: 98, done.[K
remote: Counting objects: 100% (40/40), done.[K
remote: Compressing objects: 100% (26/26), done.[K
remote: Total 98 (delta 24), reused 14 (delta 14), pack-reused 58 (from 1)[K
Receiving objects: 100% (98/98), 42.76 KiB | 616.00 KiB/s, done.
Resolving deltas: 100% (40/40), done.


In [3]:
!pip install -q ml-collections SimpleITK timm einops tensorboardX yacs

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/101.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m101.7/101.7 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [35]:
%%writefile /content/Swin-Unet/train_custom.py

import argparse
import glob
from PIL import Image
import torch
import torch.nn as nn
import numpy as np
from scipy import ndimage
from scipy.ndimage.interpolation import zoom
from torch.utils.data import Dataset
import random

def random_rot_flip(image, label):
    k = np.random.randint(0, 4)
    image = np.rot90(image, k)
    label = np.rot90(label, k)
    axis = np.random.randint(0, 2)
    image = np.flip(image, axis=axis).copy()
    label = np.flip(label, axis=axis).copy()
    return image, label


def random_rotate(image, label):
    angle = np.random.randint(-20, 20)
    image = ndimage.rotate(image, angle, order=0, reshape=False)
    label = ndimage.rotate(label, angle, order=0, reshape=False)
    return image, label


class RandomGenerator(object):
    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        if random.random() > 0.5:
            image, label = random_rot_flip(image, label)
        elif random.random() > 0.5:
            image, label = random_rotate(image, label)
        x, y = image.shape[:2]
        if x != self.output_size[0] or y != self.output_size[1]:
            image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=3)  # why not 3?
            label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=3)
        image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)
        label = torch.from_numpy(label.astype(np.float32)).unsqueeze(0)
        sample = {'image': image, 'label': label}
        return sample

class CustomDataset(Dataset):
    def __init__(self, base_dir, transform=None):
        self.transform = transform  # using transform in torch!
        self.data_dir = base_dir
        self.input_files = sorted(glob.glob(self.data_dir + "/inputs/*.jpg"))
        self.output_files = sorted(glob.glob(self.data_dir + "/outputs/*.jpg"))

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        input_image = Image.open(self.input_files[idx])
        output_image = Image.open(self.output_files[idx])

        if self.transform:
            input_image_tensor = self.transform(input_image)
            output_image_tensor = self.transform(output_image)

        sample = {'image': input_image_tensor, 'label': output_image_tensor}

        return sample


# custom datasetを変更して、num_classes = 3にして、trainをmse_lossに変更？

from networks.vision_transformer import SwinUnet
from config import get_config
import torch.nn.functional as F
import torch.optim as optim
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
import logging
import os


parser = argparse.ArgumentParser()
parser.add_argument("--root_path", type=str, default="/content/datasets")
parser.add_argument('--num_classes', type=int,
                    default=9, help='output channel of network')
parser.add_argument('--output_dir', type=str, help='output dir')
parser.add_argument('--max_iterations', type=int,
                    default=30000, help='maximum epoch number to train')
parser.add_argument('--max_epochs', type=int,
                    default=150, help='maximum epoch number to train')
parser.add_argument('--batch_size', type=int,
                    default=24, help='batch_size per gpu')
parser.add_argument('--n_gpu', type=int, default=1, help='total gpu')
parser.add_argument('--deterministic', type=int,  default=1,
                    help='whether use deterministic training')
parser.add_argument('--base_lr', type=float,  default=0.01,
                    help='segmentation network learning rate')
parser.add_argument('--img_size', type=int,
                    default=224, help='input patch size of network input')
parser.add_argument('--seed', type=int,
                    default=1234, help='random seed')
parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs='+',
    )
parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
                    help='no: no cache, '
                            'full: cache all data, '
                            'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
parser.add_argument('--resume', help='resume from checkpoint')
parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
parser.add_argument('--use-checkpoint', action='store_true',
                    help="whether to use gradient checkpointing to save memory")
parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
                    help='mixed precision opt level, if O0, no amp is used')
parser.add_argument('--tag', help='tag of experiment')
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--throughput', action='store_true', help='Test throughput only')

args = parser.parse_args()
config = get_config(args)

device = "cuda"

model = SwinUnet(config, img_size=args.img_size, num_classes=args.num_classes).to(device)


base_lr = args.base_lr
num_classes = args.num_classes
batch_size = args.batch_size * args.n_gpu
# max_iterations = args.max_iterations
db_train = CustomDataset(base_dir=args.root_path,
                               transform=transforms.Compose(
                                   [transforms.Resize(args.img_size), transforms.CenterCrop(args.img_size), transforms.ToTensor()]))
print("The length of train set is: {}".format(len(db_train)))

def worker_init_fn(worker_id):
    random.seed(args.seed + worker_id)


trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True,
                             worker_init_fn=worker_init_fn)
if args.n_gpu > 1:
    model = nn.DataParallel(model)
model.train()


snapshot_path = "/content/snap"
# optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
optimizer = optim.Adam(model.parameters(), lr=base_lr)
writer = SummaryWriter(snapshot_path + '/log')
iter_num = 0
max_epoch = args.max_epochs
max_iterations = args.max_epochs * len(trainloader)  # max_epoch = max_iterations // len(trainloader) + 1
logging.info("{} iterations per epoch. {} max iterations ".format(len(trainloader), max_iterations))
best_performance = 0.0
iterator = tqdm(range(max_epoch), ncols=70)
for epoch_num in iterator:
    for i_batch, sampled_batch in enumerate(trainloader):
        image_batch, label_batch = sampled_batch['image'], sampled_batch['label']
        image_batch, label_batch = image_batch.to(device), label_batch.to(device)
        outputs = model(image_batch)
        loss = F.mse_loss(outputs, label_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_ = base_lr * (1.0 - iter_num / max_iterations) ** 0.9
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_

        iter_num = iter_num + 1
        writer.add_scalar('info/lr', lr_, iter_num)
        writer.add_scalar('info/total_loss', loss, iter_num)

        logging.info('iteration %d : loss : %f' % (iter_num, loss.item()))
        print('iteration %d : loss : %f' % (iter_num, loss.item()))

    save_interval = 50  # int(max_epoch/6)
    if epoch_num > int(max_epoch / 2) and (epoch_num + 1) % save_interval == 0:
        save_mode_path = os.path.join(snapshot_path, 'epoch_' + str(epoch_num) + '.pth')
        torch.save(model.state_dict(), save_mode_path)
        logging.info("save model to {}".format(save_mode_path))

    if epoch_num >= max_epoch - 1:
        save_mode_path = os.path.join(snapshot_path, 'epoch_' + str(epoch_num) + '.pth')
        torch.save(model.state_dict(), save_mode_path)
        logging.info("save model to {}".format(save_mode_path))
        iterator.close()
        break

writer.close()

Overwriting /content/Swin-Unet/train_custom.py


In [None]:
%cd /content/Swin-Unet
# 224の倍数ならいける
# SwinTransformerSysのimg_sizeを調整できるようにすればいけそう
!python train_custom.py --root_path /content/datasets --num_classes 3 --img_size 448 --batch_size 1 --cfg /content/Swin-Unet/configs/swin_tiny_patch4_window7_224_lite.yaml

In [42]:
# inference
%%writefile /content/Swin-Unet/inference.py

import argparse
from networks.vision_transformer import SwinUnet
from config import get_config
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--root_path", type=str, default="/content/datasets")
parser.add_argument('--num_classes', type=int,
                    default=9, help='output channel of network')
parser.add_argument('--output_dir', type=str, help='output dir')
parser.add_argument('--max_iterations', type=int,
                    default=30000, help='maximum epoch number to train')
parser.add_argument('--max_epochs', type=int,
                    default=150, help='maximum epoch number to train')
parser.add_argument('--batch_size', type=int,
                    default=24, help='batch_size per gpu')
parser.add_argument('--n_gpu', type=int, default=1, help='total gpu')
parser.add_argument('--deterministic', type=int,  default=1,
                    help='whether use deterministic training')
parser.add_argument('--base_lr', type=float,  default=0.01,
                    help='segmentation network learning rate')
parser.add_argument('--img_size', type=int,
                    default=224, help='input patch size of network input')
parser.add_argument('--seed', type=int,
                    default=1234, help='random seed')
parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs='+',
    )
parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
                    help='no: no cache, '
                            'full: cache all data, '
                            'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
parser.add_argument('--resume', help='resume from checkpoint')
parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
parser.add_argument('--use-checkpoint', action='store_true',
                    help="whether to use gradient checkpointing to save memory")
parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
                    help='mixed precision opt level, if O0, no amp is used')
parser.add_argument('--tag', help='tag of experiment')
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--throughput', action='store_true', help='Test throughput only')

args = parser.parse_args()
config = get_config(args)

device = "cuda"
model = SwinUnet(config, img_size=args.img_size, num_classes=args.num_classes).to(device)
model.load_state_dict(torch.load("/content/snap/epoch_149.pth"))
model.eval()

# transformしてtensorに直して画像を出力って感じか



Overwriting /content/Swin-Unet/inference.py


In [43]:
%cd /content/Swin-Unet/
!python inference.py --img_size 448 --num_classes 3 --batch_size 1 --cfg /content/Swin-Unet/configs/swin_tiny_patch4_window7_224_lite.yaml

/content/Swin-Unet
=> merge config from /content/Swin-Unet/configs/swin_tiny_patch4_window7_224_lite.yaml
SwinTransformerSys expand initial----depths:[2, 2, 2, 2];depths_decoder:[1, 2, 2, 2];drop_path_rate:0.2;num_classes:3
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
---final upsample expand_first---
  model.load_state_dict(torch.load("/content/snap/epoch_149.pth"))
