In [None]:
# Colab only: mount Google Drive, then change into the project directory.
# Presumably this is so the local packages imported below (codes, models, loss) and the
# relative dataset/checkpoint paths used later resolve against this directory.
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/outta/clothes_segmentation

In [None]:
# Reload imported modules before each cell runs, so edits to the project modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.functional import F
from tqdm.auto import tqdm
import seaborn as sns

from codes.load_testset import load_testset
from codes.show_result import show_segment, show_test_output
from codes.test_F1 import evaluate_F1
from codes.test_mIoU import evaluate_mIoU
from models.u2_net import U2NET_lite
from loss.scloss import SCLoss

In [None]:
# Train on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Instantiate the segmentation network and move it onto the selected device.
# NOTE(review): 7 is presumably the number of output classes (the same 7 is passed to
# evaluate_F1 / evaluate_mIoU during training). Confirm this against U2NET_lite's signature.
model = U2NET_lite(7)
model = model.to(device=device)

In [None]:
# Set this to your own dataset path.
DATASET_PATH = '../outta_3_project/dataset/G2_train.h5'
# NOTE(review): 8000 / 2000 are presumably the train / validation sample counts.
# Confirm this against load_testset.
train_loader, valid_loader = load_testset(DATASET_PATH, 8000, 2000)

In [None]:
# Pick one fixed validation sample to visualise before training and after every epoch.
# train_valid reads the globals `test_input` / `test_label`.
# Only a single batch is needed, so take the first one directly. The previous loop iterated
# (and loaded) the entire validation set just to keep the last batch's 2nd item, and it
# raised IndexError whenever that last batch held a single sample.
sample_inputs, sample_labels = next(iter(valid_loader))
# Prefer the 2nd item of the batch, as before, but fall back to the 1st for a 1-item batch.
sample_idx = 1 if len(sample_inputs) > 1 else 0
test_input = sample_inputs[sample_idx]
test_label = sample_labels[sample_idx]

show_test_output(model, test_input, test_label, device)

In [None]:
# Main pixel-wise classification loss, plus the auxiliary SCLoss term.
# epoch_loop combines them as criterion + L * sc_loss.
criterion = CrossEntropyLoss()
sc_loss = SCLoss()
# No optimizer is built here. Each training-stage cell below creates its own Adam with that
# stage's learning rate right before calling train_valid. The lr=1e-8 optimizer previously
# created in this cell was always replaced before it was ever used.

In [None]:
# Per-epoch histories accumulated across all training stages below:
# train loss, validation loss, F1 score and mIoU.
train, valid, f1, miou = [], [], [], []

In [None]:
def train_valid(model, criterion, sc_loss, L, optimizer, trainloader, validloader, num_epochs=1,
                num_classes=7, patience=3, checkpoint_path='../../../../model_{epoch}.pth'):
    """Train for up to `num_epochs`, validating, scoring and checkpointing after each epoch.

    Each epoch runs one training pass and one validation pass via `epoch_loop`. It then
    computes F1 / mIoU on `validloader`, renders the global `test_input` / `test_label`
    sample, and saves the weights whenever the validation loss matches or beats the best
    value seen so far in this call.

    Args:
        model: network being trained.
        criterion: main loss (CrossEntropyLoss in this notebook).
        sc_loss: auxiliary SCLoss module, weighted by `L` inside `epoch_loop`.
        L: weight of the SCLoss term in the total loss.
        optimizer: optimizer stepping `model`'s parameters.
        trainloader: DataLoader used for the training pass.
        validloader: DataLoader used for the validation pass and for F1 / mIoU scoring.
        num_epochs: maximum number of epochs to run.
        num_classes: class count passed to evaluate_F1 / evaluate_mIoU.
        patience: training stops once the validation loss has failed to improve for more
            than `patience` consecutive epochs. The default of 3 keeps the original `> 3`
            threshold.
        checkpoint_path: format string with an `{epoch}` field, relative to the working
            directory, for the best-so-far weights.
            NOTE(review): `epoch` restarts at 0 on every call, so later training stages
            overwrite earlier stages' files with the same index. Pass a different path per
            stage if those files must be kept.

    Returns:
        (train_losses, valid_losses, f1_scores, miou_scores), with one entry per epoch
        actually run. There may be fewer than `num_epochs` entries if training stopped early.

    Relies on the notebook globals `device`, `test_input` and `test_label`.
    """
    train_losses, valid_losses = [], []
    f1_scores, miou_scores = [], []
    best_valid_loss = float('inf')
    epochs_without_improvement = 0

    for epoch in tqdm(range(num_epochs)):
        train_losses.append(epoch_loop(model, criterion, sc_loss, L, optimizer, trainloader))
        valid_losses.append(epoch_loop(model, criterion, sc_loss, L, optimizer, validloader, valid=True))

        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_losses[-1]}, Valid Loss: {valid_losses[-1]}')

        # Score the loader that was passed in. Previously this read the global `valid_loader`.
        f1_scores.append(evaluate_F1(model, validloader, num_classes, device))
        miou_scores.append(evaluate_mIoU(model, validloader, num_classes, device))

        print(f'Epoch {epoch+1}/{num_epochs}, F1: {f1_scores[-1]}, mioU: {miou_scores[-1]}')

        show_test_output(model, test_input, test_label, device)

        # A tie with the best loss also counts as "best", matching the original
        # `min(valid_losses) == valid_losses[-1]` check.
        if valid_losses[-1] <= best_valid_loss:
            best_valid_loss = valid_losses[-1]
            epochs_without_improvement = 0
            torch.save(model.state_dict(), checkpoint_path.format(epoch=epoch))
        else:
            # Early stopping: the counter resets on every improvement, and the loop actually
            # stops. Previously 'early stop' was only printed and training continued.
            epochs_without_improvement += 1
            if epochs_without_improvement > patience:
                print('early stop')
                break

    return train_losses, valid_losses, f1_scores, miou_scores


In [None]:
def epoch_loop(model, criterion, sc_loss, L, optimizer, dataloader, valid=False):
    """Run one pass over `dataloader` and return the mean per-batch loss.

    The per-batch loss is
        criterion(outputs, targets) + L * sc_loss(softmax(outputs), one_hot(targets)).

    Args:
        model: network producing per-class logits along dim 1.
        criterion: main loss applied to the raw logits (CrossEntropyLoss in this notebook).
        sc_loss: auxiliary loss applied to class probabilities vs. one-hot targets.
        L: weight of the auxiliary loss.
        optimizer: stepped once per batch when training; unused when `valid` is True.
        dataloader: yields (inputs, targets) batches. Targets are presumably integer label
            maps of shape (B, H, W), since one_hot + permute(0, 3, 1, 2) expects 3-D targets.
        valid: if True, evaluate only. The model is put in eval mode, no autograd graph is
            built and no optimizer step is taken.

    Returns:
        Sum of the per-batch losses divided by len(dataloader).

    Relies on the notebook global `device`.
    """
    # train(False) is equivalent to eval().
    model.train(not valid)

    epoch_loss = 0
    # Build the autograd graph only when training. Previously gradients were also tracked
    # during validation, which kept every batch's activations alive for nothing.
    with torch.set_grad_enabled(not valid):
        for inputs, targets in tqdm(dataloader, leave=False):
            inputs = inputs.to(device)
            targets = targets.to(device).long()

            outputs = model(inputs)

            # Main loss: pixel-wise cross-entropy on the raw logits.
            ce_loss = criterion(outputs, targets)

            # Auxiliary SCLoss on probabilities vs. one-hot targets. one_hot appends the class
            # axis last, so permute moves it to dim 1 to line up with `outputs`.
            outputs_probs = F.softmax(outputs, dim=1)
            targets_onehot = F.one_hot(targets, num_classes=outputs.shape[1]).permute(0, 3, 1, 2).float()
            sc_loss_value = sc_loss(outputs_probs, targets_onehot)

            loss = ce_loss + L * sc_loss_value

            if not valid:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            epoch_loss += loss.item()

            # Drop this batch's tensors before the next forward pass. The per-batch
            # torch.cuda.empty_cache() that used to follow is gone: it only hands cached
            # blocks back to the driver, which slows every batch without making more memory
            # available to PyTorch.
            del inputs, targets, outputs, ce_loss, outputs_probs, targets_onehot, sc_loss_value, loss

    return epoch_loss / len(dataloader)

In [None]:
# Stage 1: up to 10 epochs at lr=1e-4, SCLoss weight 0.5.
# Each stage starts a fresh Adam, i.e. with new moment estimates.
optimizer = Adam(model.parameters(), lr=1e-4)
t, v, f, m = train_valid(model, criterion, sc_loss, 0.5, optimizer, train_loader, valid_loader, num_epochs=10)
train += t
valid += v
f1 += f
miou += m  # previously `miou + m`, a no-op that silently dropped the mIoU history

In [None]:
# Stage 2: up to 5 more epochs at lr=1e-6, SCLoss weight 0.5 (fresh Adam state).
optimizer = Adam(model.parameters(), lr=1e-6)
t, v, f, m = train_valid(model, criterion, sc_loss, 0.5, optimizer, train_loader, valid_loader, num_epochs=5)
train += t
valid += v
f1 += f
miou += m  # previously `miou + m`, a no-op that silently dropped the mIoU history

In [None]:
# Stage 3: up to 5 more epochs at lr=1e-8, SCLoss weight 0.5 (fresh Adam state).
# NOTE(review): at lr=1e-8 the weight updates are negligible. Confirm this stage is intended.
optimizer = Adam(model.parameters(), lr=1e-8)
t, v, f, m = train_valid(model, criterion, sc_loss, 0.5, optimizer, train_loader, valid_loader, num_epochs=5)
train += t
valid += v
f1 += f
miou += m  # previously `miou + m`, a no-op that silently dropped the mIoU history

In [None]:
# Loss curves over all training stages. The x axis is the cumulative epoch index, since
# each stage appends its epochs to the same history lists. A mapping gives the lines
# readable legend names; a bare tuple produced the labels 0 / 1.
ax = sns.lineplot(data={'train': train, 'valid': valid})
ax.set(xlabel='epoch', ylabel='loss', title='Train / validation loss');

In [None]:
# Persist only the final weights (the state_dict), written to the current working directory.
torch.save(obj=model.state_dict(), f='model_final.pth')

In [None]:
# NOTE(review): the first cell already changed into /content/drive/MyDrive/outta/clothes_segmentation,
# so this relative %cd targets a nested clothes_segmentation/ directory beneath it. It looks
# like a leftover from an interactive session. Confirm it is needed, or remove it.
%cd clothes_segmentation/