In [None]:
import torch
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, Dataset,Subset, random_split
from sklearn.model_selection import KFold
import torchvision
from torchvision import datasets, models
from torchvision import transforms as T
import torchvision.transforms.functional as F
import torch.nn as nn
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
import matplotlib.pyplot as plt
from IPython.display import display
import lightning as L
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
import torchmetrics, argparse

from torchvision.datasets import ImageFolder

from PIL import Image
import os

import multiprocessing
num_workers = multiprocessing.cpu_count()
print(num_workers)
import timm
import wandb



In [None]:
# Authenticate this notebook session with Weights & Biases (used for experiment logging).
wandb.login()

In [10]:
class CFG:
    """Static experiment configuration, read as class attributes (never instantiated here)."""

    # --- run identity -----------------------------------------------------
    ver = 3.1
    MODEL = "vit-base"
    DATASET = "TECNO"

    # --- reproducibility / hardware ---------------------------------------
    seed = 42
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # --- filesystem layout (every path hangs off base_dir) ----------------
    base_dir = "/root/signate_tecno/"
    input_dir = f"{base_dir}input/train"
    test_dir = f"{base_dir}input/test"
    output_dir = f"{base_dir}output/"
    sub_dir = f"{base_dir}submit/"
    log_dir = f"{base_dir}logs/"
    model_dir = f"{base_dir}model/"
    ckpt_dir = f"{base_dir}ckpt/"

    # --- cross-validation --------------------------------------------------
    n_folds = 5

    # --- optimisation ------------------------------------------------------
    optimizer = "SGD"
    learning_rate = 1e-3
    weight_decay = 1e-5
    data_aug = "RandAug"

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# The bare name on the last line makes the notebook display the chosen device.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

## train用のデータセット作成

In [12]:
class ImageDataset(Dataset):
    """Binary image dataset read from a folder-per-class directory tree.

    Every visible sub-directory of ``data_dir`` is treated as a class folder.
    Files inside a folder named ``"hold"`` get label 1; files inside any other
    folder get label 0.

    Args:
        data_dir: Root directory that contains one sub-directory per class.
        transform: Callable applied to each PIL image in ``__getitem__``.
            When ``None``, the default pipeline is used (resize to 336x336,
            random horizontal flip, tensor conversion, normalisation with
            mean/std 0.5 per channel).
        phase: Tag stored as ``self.phase``; this class does not read it.
    """

    def __init__(self, data_dir, transform=None, phase="train"):
        super().__init__()
        self.data_dir = data_dir
        self.phase = phase
        self.image_paths = []
        self.labels = []

        # os.listdir returns entries in arbitrary order; sorting keeps sample
        # indices stable so index-based splits (e.g. K-fold) are reproducible.
        for label in sorted(os.listdir(data_dir)):
            label_dir = os.path.join(data_dir, label)
            # Skip stray files and hidden entries such as ".ipynb_checkpoints".
            if label.startswith(".") or not os.path.isdir(label_dir):
                continue
            target = 1 if label == "hold" else 0
            for image_name in sorted(os.listdir(label_dir)):
                image_path = os.path.join(label_dir, image_name)
                if image_name.startswith(".") or not os.path.isfile(image_path):
                    continue
                self.image_paths.append(image_path)
                self.labels.append(target)

        # Honour a caller-supplied transform; only fall back to the built-in
        # pipeline when none was given (previously the argument was ignored).
        if transform is None:
            transform = T.Compose([
                T.Resize((336, 336)),
                T.RandomHorizontalFlip(p=0.5),
                T.ToTensor(),
                T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
        self.transform = transform

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Load the image at ``index`` and return ``(image, label)``.

        The image is converted to RGB so that grayscale / RGBA / palette files
        yield three channels, which the 3-channel normalisation requires.
        """
        # convert() loads the pixel data and returns a new image, so the file
        # handle can be closed as soon as the ``with`` block exits.
        with Image.open(self.image_paths[index]) as handle:
            image = handle.convert("RGB")
        label = self.labels[index]
        if self.transform:
            image = self.transform(image)
        return image, label

In [13]:
class LitDataModule(L.LightningDataModule):
    """Lightning data module serving one K-fold split of an ``ImageDataset``.

    Args:
        batch_size: Batch size used by both the train and validation loaders.
        data_dir: Root directory handed to ``ImageDataset``.
        data_aug: Stored as ``self.data_aug``; not read by this class.
        num_folds: Number of K-fold splits.
        fold: Zero-based index of the split whose held-out part is used for
            validation.
    """

    def __init__(self, batch_size=128, data_dir="./input", data_aug="RandAug",num_folds = 5, fold = 0):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.data_aug = data_aug
        self.num_folds = num_folds  # number of folds
        self.fold = fold  # index of the current fold

    def setup(self, stage=None):
        """Build ``self.train_dataset`` and ``self.val_dataset`` for the configured fold.

        Raises:
            ValueError: If ``fold`` is not in ``[0, num_folds)``.
        """
        import copy  # local import: only needed here

        if not 0 <= self.fold < self.num_folds:
            raise ValueError(
                f"fold must be in [0, {self.num_folds}), got {self.fold}"
            )

        # Separate pipelines: augmentation for training, deterministic for validation.
        train_transform = T.Compose([
            T.Resize((384, 384)),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomAdjustSharpness(sharpness_factor=2, p = 0.5),
            T.ToTensor(),
            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        val_transform = T.Compose([
            T.Resize((384, 384)),
            T.ToTensor(),
            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # Index the files once, then make a shallow copy for validation. The
        # copy shares the same path/label lists (so indices agree) but carries
        # its own ``transform`` attribute. Wrapping a single dataset object in
        # both Subsets would make the second assignment below overwrite the
        # first, leaving training without augmentation.
        train_base = ImageDataset(data_dir=self.data_dir, transform=None, phase="train")
        val_base = copy.copy(train_base)
        train_base.transform = train_transform
        val_base.transform = val_transform

        num_data = len(train_base)
        kf = KFold(n_splits=self.num_folds, shuffle = True, random_state = 42)
        indices = list(range(num_data))

        # Pick the requested split directly; the fixed random_state keeps the
        # partition identical across data modules built for different folds.
        splits = list(kf.split(indices))
        train_indices, val_indices = splits[self.fold]

        self.train_dataset = Subset(train_base, train_indices)
        self.val_dataset = Subset(val_base, val_indices)

    def train_dataloader(self):
        """Return the shuffled loader over the training subset."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=4, shuffle=True)

    def val_dataloader(self):
        """Return the (unshuffled) loader over the validation subset."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=4)


In [26]:
# Index the training images directly, then build a data module on the same
# directory (default fold settings, batch size 32) and run its fit-stage setup.
ds = ImageDataset(data_dir=CFG.input_dir)
dm = LitDataModule(
    batch_size=32,
    data_dir=CFG.input_dir,
)
dm.prepare_data()
dm.setup(stage="fit")

In [14]:
class ViTNet(L.LightningModule):
    """Two-class classifier on a pretrained timm ViT with a mostly frozen backbone.

    The timm model ``vit_base_patch16_clip_384.laion2b_ft_in1k`` is created
    with pretrained weights and ``num_classes=2``. All parameters are frozen
    except those of ``norm``, ``fc_norm``, ``head_drop`` and ``head``.

    Args:
        learning_rate: Learning rate passed to the optimizer.
        weight_decay: Weight decay passed to the optimizer.
        optimizer_name: One of ``"SGD"``, ``"Adam"`` or ``"AdamW"``.
        data_aug: Stored as ``self.data_aug`` (and in the saved
            hyperparameters); not otherwise read by this class.
    """

    def __init__(self,learning_rate = 1e-3, weight_decay = 1e-5, optimizer_name = "SGD", data_aug = "RandAug"):
        super().__init__()
        self.model = timm.create_model("vit_base_patch16_clip_384.laion2b_ft_in1k" , pretrained = True, num_classes = 2)

        # Freeze every layer first ...
        for param in self.model.parameters():
            param.requires_grad = False

        # ... then re-enable training only for 'norm', 'fc_norm', 'head_drop' and 'head'.
        trainable_modules = (
            self.model.norm,
            self.model.fc_norm,
            self.model.head_drop,
            self.model.head,
        )
        for module in trainable_modules:
            for param in module.parameters():
                param.requires_grad = True

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.optimizer_name = optimizer_name
        self.data_aug = data_aug
        self.save_hyperparameters()
        self.acc = torchmetrics.classification.Accuracy(task= 'binary')
        self.class_acc = torchmetrics.classification.Accuracy(task = 'binary')
        self.loss_fn = nn.CrossEntropyLoss()
        self.predictions = []
        # Per-step training losses of the epoch in progress; emptied in
        # on_train_epoch_end.
        self.training_step_loss = []

    def forward(self,x):
        """Return the wrapped model's output for the input batch ``x``."""
        out = self.model(x)

        return out

    def _eval(self,batch,phase, on_step , on_epoch):
        """Shared train/val step: compute the loss and log loss and accuracy.

        Args:
            batch: ``(x, y)`` pair of inputs and targets.
            phase: Prefix for the logged metric names (``"train"`` or ``"val"``).
            on_step: Forwarded to ``self.log`` for the accuracy metric.
            on_epoch: Forwarded to ``self.log`` for the accuracy metric.

        Returns:
            The cross-entropy loss for the batch.
        """
        x,y = batch
        out = self(x)
        loss = self.loss_fn(out, y)
        preds = torch.argmax(out, dim=1)
        acc = self.acc(preds, y)
        self.log(f"{phase}_loss", loss)
        self.log(f"{phase}_acc", acc, on_step = on_step, on_epoch = on_epoch)
        if phase == "val":
            self.class_acc(preds,y)
            self.log('hp_metric', acc, on_step = False, on_epoch = True,prog_bar = True, logger = True)
        return loss

    def training_step ( self,batch, batch_idx):
        """Run one training step and remember its loss for the epoch summary."""
        loss = self._eval(batch, "train", on_step = False, on_epoch = True)
        # Store a detached copy: keeping the attached tensor would hold every
        # step's autograd graph alive until the list is cleared.
        self.training_step_loss.append(loss.detach())
        return loss

    def validation_step(self, batch, batch_idx):
        """Run one validation step and return its loss."""
        loss = self._eval(batch, "val", on_step = False, on_epoch = True)
        return loss

    def on_train_epoch_end(self) -> None:
        """Log the mean training loss of the finished epoch as ``train_epoch_loss``."""
        if not self.training_step_loss:
            return
        all_loss = torch.stack(self.training_step_loss)
        self.log("train_epoch_loss", all_loss.mean())
        # Reset so the next epoch's mean is not mixed with earlier epochs and
        # the list does not grow for the whole run.
        self.training_step_loss.clear()

    def configure_optimizers(self):
        """Build the optimizer selected by ``self.optimizer_name``.

        Raises:
            ValueError: If ``optimizer_name`` is not ``"SGD"``, ``"Adam"`` or
                ``"AdamW"``.
        """
        optimizer_classes = {
            "SGD": optim.SGD,
            "Adam": optim.Adam,
            "AdamW": optim.AdamW,
        }
        if self.optimizer_name not in optimizer_classes:
            raise ValueError(
                f"Unsupported optimizer_name {self.optimizer_name!r}; "
                f"expected one of {sorted(optimizer_classes)}"
            )
        optimizer_cls = optimizer_classes[self.optimizer_name]
        optimizer = optimizer_cls(self.parameters(), lr = self.learning_rate, weight_decay = self.weight_decay)

        return optimizer

# Build the model from the hyperparameters declared on CFG.
vit_hparams = {
    "learning_rate": CFG.learning_rate,
    "weight_decay": CFG.weight_decay,
    "optimizer_name": CFG.optimizer,
    "data_aug": CFG.data_aug,
}
net = ViTNet(**vit_hparams)


In [None]:
# K-fold cross-validation. Everything that carries state across a training run
# (model weights, early-stopping counters, checkpoint tracker, wandb run) is
# created fresh inside the loop so that folds do not influence each other.
num_folds = CFG.n_folds
for fold in range(num_folds):
    # Same seed for every fold: each fold starts from identical initial
    # weights and is reproducible on its own.
    L.seed_everything(CFG.seed, workers=True)

    data_module = LitDataModule(batch_size=128, data_dir = CFG.input_dir, data_aug=CFG.data_aug, num_folds=num_folds, fold=fold)

    # A new model per fold: reusing one instance would let later folds start
    # from weights already trained on their own validation images.
    net = ViTNet(learning_rate = CFG.learning_rate,
                 weight_decay = CFG.weight_decay,
                 optimizer_name = CFG.optimizer,
                 data_aug = CFG.data_aug)

    # A new callback per fold: a shared instance would carry its best score
    # and patience counter over from the previous fold.
    early_stopping = EarlyStopping(
        monitor='val_loss',
        mode='min',
        patience=3,
    )

    # Configure ModelCheckpoint and the wandb logger with the fold number in their names.
    model_checkpoint = ModelCheckpoint(
        monitor='val_loss',
        dirpath=CFG.ckpt_dir,
        filename=f'{CFG.DATASET}-{CFG.ver}-{CFG.MODEL}-fold{fold}' + '-{epoch:02d}-{val_loss:.2f}',
        save_top_k=3,
        mode='min',
    )
    # Use the logger class from the same `lightning` package as the Trainer
    # rather than the one from the separate `pytorch_lightning` package.
    wandb_logger = pl_loggers.WandbLogger(name = f'vit-base-{CFG.ver}-fold{fold}', save_dir=CFG.log_dir)

    trainer = L.Trainer(default_root_dir=CFG.log_dir,
                        max_epochs=10, logger=wandb_logger,
                        callbacks=[model_checkpoint, early_stopping])

    # Train the model on this fold.
    trainer.fit(net, datamodule=data_module)

    # Close the wandb run so the next fold's logger opens its own run instead
    # of reusing this one.
    wandb.finish()
