## **Imports**

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from torch import optim
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import models
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Subset
from torchvision import utils
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from torchsummary import summary
import time
import os
from tqdm.notebook import tqdm
from warnings import filterwarnings
from torch import nn, optim
from torch.backends import cudnn
cudnn.benchmark = True
from randaugment import RandAugment, ImageNetPolicy
from torch.autograd import Variable
import timm
from timm.data.transforms_factory import create_transform
from timm.data.dataset_factory import create_dataset
from timm.data.mixup import Mixup
import os
filterwarnings('ignore')
plt.rcParams['font.family'] = 'Malgun Gothic' 
os.environ['KMP_DUPLICATE_LIB_OK']='True'

## **Helper functions**

In [3]:
# Helper: read the learning rate currently configured on an optimizer.
def get_lr(opt):
    """Return the learning rate of the optimizer's first parameter group.

    Args:
        opt: Object exposing ``param_groups`` (a sequence of dicts holding an
            ``'lr'`` key), e.g. a ``torch.optim.Optimizer``.

    Returns:
        The ``'lr'`` of the first parameter group, or ``None`` when
        ``param_groups`` is empty.
    """
    return next((group['lr'] for group in opt.param_groups), None)

def imshow(img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo per-channel normalization on a [C, H, W] image and display it.

    Args:
        img: Image laid out as [C, H, W]. A numpy array, or anything
            ``np.asarray`` accepts (e.g. a detached CPU ``torch.Tensor``).
        mean: Per-channel mean subtracted during normalization. Defaults to
            the ImageNet statistics; pass the dataset's own values when the
            images were normalized with different statistics.
        std: Per-channel standard deviation used during normalization
            (defaults to the ImageNet statistics).
    """
    # Convert to numpy first: numpy's transpose accepts a full axis order,
    # whereas torch.Tensor.transpose only swaps two dims.
    img = np.asarray(img)
    # Change [C, H, W] to [H, W, C] (channels last, as matplotlib expects).
    img = img.transpose((1, 2, 0))
    # Invert x_norm = (x - mean) / std; mean/std broadcast over the channel axis.
    img = img * np.asarray(std) + np.asarray(mean)
    # Clip so the pixel values lie in [0, 1].
    img = np.clip(img, 0, 1)
    # Display the image.
    plt.imshow(img)

def train_model(model, train_dataloader, optimizer, loss_func, lr_scheduler, epochs=10, device=torch.device('cuda')):
    """Train ``model`` for ``epochs`` epochs, stepping ``lr_scheduler`` once per epoch.

    For every epoch this prints the learning rate in effect at the start of
    the epoch, then the mean of the per-batch loss values and the wall-clock
    minutes elapsed since training began (cumulative, not per epoch).

    Args:
        model: ``nn.Module`` to train; it is moved to ``device`` first.
        train_dataloader: Iterable of ``(inputs, targets)`` batches that also
            supports ``len()`` (used as the number of batches).
        optimizer: Optimizer over ``model``'s parameters.
        loss_func: Callable ``loss_func(outputs, targets)`` returning a scalar
            loss tensor.
        lr_scheduler: Scheduler whose ``step()`` is called after each epoch.
        epochs: Number of passes over ``train_dataloader``.
        device: Device the model and every batch are moved to.
    """
    model.to(device)
    t0 = time.time()

    for epoch_idx in range(1, epochs + 1):
        model.train()
        lr_now = get_lr(optimizer)
        print('Epoch {}/{}, current lr={}'.format(epoch_idx, epochs, lr_now))

        running_loss = 0
        for inputs, targets in train_dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Forward pass
            loss = loss_func(model(inputs), targets)

            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        lr_scheduler.step()

        mean_loss = running_loss / len(train_dataloader)
        minutes = (time.time() - t0) / 60
        print(f'Loss: {mean_loss:.4f}, Time: {minutes:.6f}')
        print('-------------------------------')
        print('-------------------------------')

# Compute the model's overall top-1 accuracy on a test DataLoader.
def calculate_testset_accuracy(model, test_loader, device=None):
    """Print and return the top-1 accuracy of ``model`` over ``test_loader``.

    Args:
        model: Classifier whose output's dim 1 indexes classes (argmax is
            taken over dim 1).
        test_loader: DataLoader yielding ``(inputs, integer_labels)``; the
            length of its ``dataset`` is the denominator.
        device: Device to run inference on. ``None`` (default) uses the device
            of the model's first parameter, or CPU for a parameter-less model.
            This argument used to be required, so the two-argument call
            ``calculate_testset_accuracy(model, loader)`` raised ``TypeError``.

    Returns:
        Accuracy in [0, 1] as a float, or ``nan`` when the dataset is empty.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()  # switch the model to evaluation mode
    correct = 0

    with torch.no_grad():
        for image_tensors, true_labels in test_loader:
            image_tensors = image_tensors.to(device)
            true_labels = true_labels.to(device)

            preds = model(image_tensors).argmax(dim=1)
            # Accumulate as a tensor to avoid a host sync per batch.
            correct += (preds == true_labels).sum()

    correct = int(correct)
    num_samples = len(test_loader.dataset)
    # An empty dataset previously crashed (int has no .double()); report N/A instead.
    if num_samples == 0:
        print('Test Accuracy: N/A (No samples in test set)')
        return float('nan')

    # Overall accuracy
    accuracy = correct / num_samples
    print(f'Test Accuracy: {accuracy:.4f}')
    return accuracy

# Compute the model's top-1 accuracy separately for each class on a test DataLoader.
def calculate_testset_accuracy_per_class(model, test_loader, num_classes, device=None):
    """Print the top-1 accuracy of ``model`` for every class in ``range(num_classes)``.

    Classes with no test samples are reported as N/A.

    Args:
        model: Classifier whose output's dim 1 indexes classes (argmax is
            taken over dim 1).
        test_loader: DataLoader yielding ``(inputs, integer_labels)``.
        num_classes: Number of classes; labels must lie in ``[0, num_classes)``.
        device: Device to run inference on. ``None`` (default) uses the device
            of the model's first parameter, or CPU for a parameter-less model,
            instead of always assuming CUDA.

    Raises:
        ValueError: If a label falls outside ``[0, num_classes)``. Previously,
            negative labels silently wrapped around list indices.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()  # switch the model to evaluation mode
    correct_per_class = torch.zeros(num_classes, dtype=torch.long, device=device)
    total_per_class = torch.zeros(num_classes, dtype=torch.long, device=device)

    with torch.no_grad():
        for image_tensors, true_labels in test_loader:
            image_tensors = image_tensors.to(device)
            true_labels = true_labels.to(device)
            preds = model(image_tensors).argmax(dim=1)

            if true_labels.numel() and (true_labels.min() < 0 or true_labels.max() >= num_classes):
                raise ValueError(f'labels must be in [0, {num_classes})')

            # Vectorized counting replaces the per-sample Python loop, which
            # forced one device sync per element.
            total_per_class += torch.bincount(true_labels, minlength=num_classes)
            correct_per_class += torch.bincount(true_labels[preds == true_labels], minlength=num_classes)

    for i, (hit, seen) in enumerate(zip(correct_per_class.tolist(), total_per_class.tolist())):
        if seen == 0:
            print(f'Accuracy of class {i} : N/A (No samples in test set)')
        else:
            accuracy = 100 * hit / seen
            print(f'Accuracy of class {i} : {accuracy:.2f}% ({hit}/{seen})')

## **Build datasets and DataLoaders**

In [4]:
# Root folder of the image dataset. NOTE: absolute, machine-specific path — adjust per machine.
DATA_ROOT = 'C:/Users/VDRC/Desktop/pythonFolder/SK-Biology-Ai/data/image_dataset'

# Per-channel normalization statistics of this dataset (used instead of ImageNet's).
# TODO(review): the third std value is identical to the second — this looks like a
# copy-paste slip; recompute the third-channel std from the training images.
DATASET_MEAN = [0.49623227595753333, 0.48377202969644434, 0.39612923273387035]
DATASET_STD = [0.21602388510121484, 0.2127661699292398, 0.2127661699292398]

augment_transform = create_transform(224, is_training=True, auto_augment='rand-m6-mstd0.5', mean=DATASET_MEAN, std=DATASET_STD)
# Evaluation must normalize with the same statistics as training. Without explicit
# mean/std, create_transform falls back to the ImageNet defaults, which shifts
# every validation input relative to what the model was trained on.
eval_transform = create_transform(224, mean=DATASET_MEAN, std=DATASET_STD)

train_dataset = create_dataset(name='', root=f'{DATA_ROOT}/train_image_set', transform=augment_transform)
val_dataset = create_dataset(name='', root=f'{DATA_ROOT}/test_image_set', transform=eval_transform)


In [5]:
# Keyword arguments for timm's Mixup (mixup + CutMix with label smoothing).
mixup_args = dict(
    mixup_alpha=1.,
    cutmix_alpha=1.,
    prob=1,
    switch_prob=0.5,
    mode='batch',
    label_smoothing=0.1,
    num_classes=13,  # must match the model head created with num_classes=13
)

# Mixup instance; MixupDataLoader applies it to every (inputs, targets) training batch.
mixup_params = Mixup(**mixup_args)

class MixupDataLoader(DataLoader):
    """DataLoader that applies a batch-mixing callable (e.g. timm ``Mixup``) to every batch.

    Each batch is moved to ``device`` before mixing, so the yielded tensors
    already live on that device.

    Args:
        *args, **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``.
        mixup_params: Keyword-only callable
            ``(inputs, targets) -> (mixed_inputs, mixed_targets)``, e.g. a
            ``timm.data.mixup.Mixup`` instance.
        device: Device batches are moved to before mixing. ``None`` (default)
            picks ``'cuda'`` when available, otherwise ``'cpu'``. The device was
            previously hard-coded to ``'cuda'``, which crashed on CPU-only machines.
    """

    def __init__(self, *args, mixup_params, device=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.mixup_fn = mixup_params
        self.device = device if device is not None else ('cuda' if torch.cuda.is_available() else 'cpu')

    def __iter__(self):
        for inputs, targets in super().__iter__():
            # timm's Mixup asserts an even batch size. When the dataset size is not
            # a multiple of batch_size (and drop_last is off), the final batch can
            # be odd, so drop its last sample instead of crashing.
            if inputs.size(0) % 2:
                inputs, targets = inputs[:-1], targets[:-1]
            if inputs.size(0) == 0:
                # A single-sample final batch leaves nothing to mix; skip it.
                continue
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            mixed_inputs, mixed_targets = self.mixup_fn(inputs, targets)
            yield mixed_inputs, mixed_targets

In [6]:
# drop_last=True keeps every training batch at exactly batch_size (32, an even number).
# timm's Mixup asserts an even batch size, so a short odd-sized final batch would abort training.
train_dataloader = MixupDataLoader(train_dataset, mixup_params=mixup_params, batch_size=32, shuffle=True, num_workers=6, drop_last=True)
# The validation loader keeps every sample (no mixing, no dropping).
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=6)

## **Set up the model**

In [10]:
model_name = 'resnet101'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Pretrained timm backbone with a 13-class classifier head, placed on `device`.
model = timm.create_model(model_name, pretrained=True, num_classes=13).to(device)

# Optional architecture overview: summary(model, (3, 224, 224))

model.safetensors:   0%|          | 0.00/179M [00:00<?, ?B/s]

In [11]:
# reduction='sum' makes each batch's loss the sum over its samples, so the
# per-epoch "Loss" printed during training scales with the batch size.
loss_func = nn.CrossEntropyLoss(reduction='sum')
opt = optim.Adam(params=model.parameters(), lr=1e-3)
# Multiply the learning rate by 0.8 on every scheduler step (train_model steps once per epoch).
lr_scheduler = StepLR(optimizer=opt, step_size=1, gamma=0.8)

## **train & test**

In [None]:
# Train for 10 epochs on the device chosen during model setup.
train_model(
    model,
    train_dataloader,
    optimizer=opt,
    loss_func=loss_func,
    lr_scheduler=lr_scheduler,
    epochs=10,
    device=device,
)

In [None]:
# Pass the device explicitly so evaluation runs wherever the model lives (CPU or GPU).
calculate_testset_accuracy(model, val_dataloader, device)
calculate_testset_accuracy_per_class(model, val_dataloader, 13, device=device)

In [28]:
# Save the trained weights. torch.save does not create missing directories,
# so make sure the output folder exists first.
weight_dir = '../../weight'
os.makedirs(weight_dir, exist_ok=True)
torch.save(model.state_dict(), f'{weight_dir}/{model_name}.pth')