In [1]:
import warnings
warnings.filterwarnings('ignore')

from glob import glob
import pandas as pd
import numpy as np 
from tqdm import tqdm
import cv2

import os
import timm
import random

import albumentations as A
from albumentations.pytorch import transforms, ToTensorV2

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torchvision.transforms as transforms

from sklearn.metrics import f1_score, accuracy_score

In [2]:
# !pip3 install timm albumentations

In [3]:
# Configs
# Single dictionary holding every tunable value read by the cells below.
config = {
    # Seed passed to `random`, `torch` and `torch.cuda` in the seeding cell.
    "SEED": 777,
    # Forwarded as `size` to both datasets, which hand it to the transform
    # factory (images are resized to SIZE x SIZE).
    "SIZE": 384,
    # CSV read by CustomDataset; it must provide the `path` and
    # `encoded_label` columns, and may provide a `kfold` column.
    "CSV": "./train_for_autoencoder.csv",
    
    # Fold id handed to CustomDataset: "train" keeps rows with kfold != FOLD,
    # "validation" keeps rows with kfold == FOLD (only when `kfold` exists).
    # NOTE(review): -1 presumably matches no row, i.e. train on everything
    # and leave validation empty — confirm against the CSV.
    "FOLD": -1,
    # DataLoader batch size (used for both loaders).
    "BATCH_SIZE": 16,
    # Learning rate given to the Adam optimizer.
    "LEARNING_RATE": 0.001,
    # Number of passes over the training loader.
    "EPOCHS": 100,
    # DataLoader worker process count (used for both loaders).
    "N_WORKERS": 4,
    
    # NOTE(review): not referenced by any cell visible here; the model that
    # is actually built comes from the local `UNet` module.
    "MODEL": "tf_efficientnet_b7",
    # Directory that receives the checkpoints.
    "MODEL_SAVE": "./UNet",
    # Checkpoint file-name prefix. The training loop appends "_best.pth" /
    # "_{epoch}.pth", so the trailing underscore here produces names such as
    # "UNet__best.pth" (double underscore).
    "MODEL_SAVE_PREFIX": "UNet_",
    
    # Use the GPU when one is available, otherwise fall back to the CPU.
    "DEVICE": torch.device("cuda" if torch.cuda.is_available() else "cpu")
}

In [4]:
# Create the checkpoint directory (and any missing parents) up front.
# `exist_ok=True` already makes this a no-op when the directory exists, so no
# separate existence check is needed. Unlike the guarded form, this also fails
# immediately if the path exists but is a regular file, instead of letting the
# first `torch.save` fail after a full epoch of training.
os.makedirs(config["MODEL_SAVE"], exist_ok=True)

In [5]:
# Seed every random number generator available in this notebook so a
# Restart & Run All starts from the same state each time.
random.seed(config["SEED"])
# NumPy is imported above but was previously left unseeded; anything that
# draws from its global RNG would otherwise differ between runs.
np.random.seed(config["SEED"])
torch.cuda.manual_seed(config["SEED"])
torch.manual_seed(config["SEED"])
# Release cached GPU memory left over from earlier runs in the same kernel.
torch.cuda.empty_cache()

In [6]:
class CustomDataset(Dataset):
    """Image dataset driven by a CSV file.

    The CSV must contain a `path` column (image file paths). An
    `encoded_label` column is required for the "train" and "validation"
    modes and optional for "test". When a `kfold` column is present it is
    used to split the rows between training and validation.

    Args:
        data_path: Path of the CSV file to read.
        size: Value passed to the `transform` factory for every item.
        transform: Optional callable; it is called as `transform(size)` and
            the result is called as `pipeline(image=image)`, whose
            `'image'` entry replaces the loaded image.
        fold: Fold id. "train" keeps rows with `kfold != fold`,
            "validation" keeps rows with `kfold == fold`.
        mode: "train", "validation" or "test". In "test" mode items carry
            only the image, and no fold filtering is applied.

    Raises:
        KeyError: If `path` is missing, or `encoded_label` is missing in a
            mode other than "test".
    """

    def __init__(self,
                 data_path,
                 size,
                 transform=None,
                 fold=0,
                 mode="train"):
        self.csv = pd.read_csv(data_path)
        if 'kfold' in self.csv:
            if mode == "train":
                self.csv = self.csv[self.csv['kfold'] != fold]
            elif mode == "validation":
                self.csv = self.csv[self.csv['kfold'] == fold]

        self.path = self.csv['path'].to_list()
        # Test items never expose a label, so an unlabeled CSV is acceptable
        # there; every other mode still fails fast on a missing column.
        if mode == "test" and 'encoded_label' not in self.csv:
            self.labels = None
        else:
            self.labels = self.csv['encoded_label'].to_list()
        self.transform = transform
        self.size = size
        self.mode = mode

    def __len__(self):
        """Return the number of rows kept after fold filtering."""
        return len(self.path)

    def __getitem__(self, idx):
        """Load item `idx`.

        Returns:
            `{'image': image}` in "test" mode, otherwise
            `{'image': image, 'label': label}` where `label` is a
            `torch.long` tensor.

        Raises:
            FileNotFoundError: If the image cannot be read from disk.
        """
        # Image
        image_path = self.path[idx]
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None rather than
            # raising; surface the offending path instead of letting
            # cvtColor fail with an opaque assertion error.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        # OpenCV loads images in BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            image = self.transform(self.size)(image=image)['image']

        # Only test mode
        if self.mode == "test":
            return {
                'image': image
            }

        # Label
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        return {
            'image': image,
            'label': label
        }

In [7]:
def create_train_transforms(size=512):
    """Build the training-time albumentations pipeline.

    The image is resized to a `size` x `size` square and then handed to
    `ToTensorV2`. No normalization or augmentation step is part of the
    pipeline.

    Args:
        size: Target height and width passed to `A.Resize`.

    Returns:
        The composed `A.Compose` pipeline.
    """
    pipeline_steps = [
        A.Resize(size, size),
        ToTensorV2(),
    ]
    return A.Compose(pipeline_steps)

In [8]:
def create_validation_transforms(size=128):
    """Build the validation-time albumentations pipeline.

    The image is resized to a `size` x `size` square and then handed to
    `ToTensorV2`. No normalization step is part of the pipeline.

    Args:
        size: Target height and width passed to `A.Resize`.

    Returns:
        The composed `A.Compose` pipeline.
    """
    pipeline_steps = [
        A.Resize(size, size),
        ToTensorV2(),
    ]
    return A.Compose(pipeline_steps)

In [9]:
# Datasets: both read the same CSV and differ only in the transform factory
# and in which side of the fold split they keep.
train_dataset = CustomDataset(
    config["CSV"],
    config["SIZE"],
    transform=create_train_transforms,
    fold=config["FOLD"],
    mode="train",
)
validation_dataset = CustomDataset(
    config["CSV"],
    config["SIZE"],
    transform=create_validation_transforms,
    fold=config["FOLD"],
    mode="validation",
)

# Loaders: training batches are reshuffled every epoch, validation batches
# keep the CSV order.
train_loader = DataLoader(
    train_dataset,
    batch_size=config["BATCH_SIZE"],
    shuffle=True,
    num_workers=config["N_WORKERS"],
)
validation_loader = DataLoader(
    validation_dataset,
    batch_size=config["BATCH_SIZE"],
    shuffle=False,
    num_workers=config["N_WORKERS"],
)

In [10]:
# Build the network from the project-local `UNet` module and move it to the
# configured device.
from UNet import UNet
# NOTE(review): the positional arguments are presumably (input channels,
# output channels) — confirm against UNet.py. If so, the model emits a
# 1-channel prediction while the training cell compares it against the RGB
# input itself, so the MSE loss would rely on broadcasting (and the resulting
# warning is hidden by `warnings.filterwarnings('ignore')`).
model = UNet(3, 1).to(config["DEVICE"])

In [11]:
# Objective: mean squared error between the model output and its target.
criterion = nn.MSELoss()
# Adam over every model parameter, using the configured learning rate.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config["LEARNING_RATE"],
)

In [None]:
# Autoencoder training: the model is asked to reproduce its own input, and the
# MSE between prediction and input is minimized.
#
# NOTE(review): the checkpoint marked "best" is selected on the running
# *training* loss; `validation_loader` is not consulted in this cell.
# NOTE(review): the transform pipelines contain no normalization step, so the
# inputs are presumably raw 0-255 pixel values — consistent with the very
# large logged losses. Confirm whether scaling was intended.
best_loss, best_pred = float('inf'), 0
for epoch in range(config["EPOCHS"]):
    model.train()
    pbar = tqdm(train_loader, total=len(train_loader))
    total_train_loss = 0  # sum of batch losses / number of batches (partial until the epoch ends)
    train_data_cnt = 0    # batches processed so far in this epoch
    train_loss = 0        # running mean of the batch losses seen so far
    for batch in pbar:
        # The loader already yields a tensor; move/cast it with `.to` instead
        # of copy-constructing through `torch.tensor`, which is discouraged
        # for tensor inputs and emits a (here suppressed) UserWarning.
        x = batch['image'].to(device=config["DEVICE"], dtype=torch.float32)

        # Clear gradients *before* the backward pass so that gradients left
        # behind by an interrupted earlier run of this cell cannot leak into
        # the first update.
        optimizer.zero_grad()
        pred = model(x)

        loss = criterion(pred, x)
        loss.backward()
        optimizer.step()

        # `.item()` synchronizes with the device; read the value once.
        batch_loss = loss.item()
        total_train_loss += batch_loss / len(train_loader)
        train_loss = train_loss * train_data_cnt + batch_loss
        train_data_cnt += 1
        train_loss /= train_data_cnt
        pbar.set_postfix({
            "epoch": f"{epoch}/{config['EPOCHS']}",
            "train_loss" : f"{train_loss:.5f}",
            "total_train_loss": f"{total_train_loss:.5f}"
        })

    # Keep a separate copy of the weights with the lowest epoch-mean loss.
    if best_loss > train_loss:
        best_loss = train_loss
        torch.save(model.state_dict(),
                   f"{config['MODEL_SAVE']}/{config['MODEL_SAVE_PREFIX']}_best.pth")

    # Always save a per-epoch checkpoint as well.
    torch.save(model.state_dict(),
               f"{config['MODEL_SAVE']}/{config['MODEL_SAVE_PREFIX']}_{epoch}.pth")
    pbar.close()

100%|██████████| 227/227 [01:37<00:00,  2.33it/s, epoch=0/100, train_loss=13418.45588, total_train_loss=13418.45588]
  6%|▌         | 14/227 [00:06<01:30,  2.35it/s, epoch=1/100, train_loss=12011.38274, total_train_loss=740.79013]