In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from torch import cuda
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.functional import F


from codes.load_testset import load_testset
from models.u2_net import U2NET_lite
from loss.scloss import SCLoss

In [3]:
device = 'cuda' if cuda.is_available() else 'cpu'

In [4]:
model = U2NET_lite(7).to(device=device)

In [5]:
train_loader, valid_loader = load_testset('dataset/G2_train.h5', 8000, 2000)

torch.Size([8000, 3, 128, 128]) torch.Size([8000, 128, 128])


In [6]:
criterion = CrossEntropyLoss()
sc_loss = SCLoss()
optimizer = Adam(model.parameters(), lr=1e-4)

In [7]:
def train(model, criterion, sc_loss, optimizer, dataloader, num_epochs=1):
    model.train()
    
    for epoch in range(num_epochs):
        epoch_loss = 0
        
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            targets = targets.long()

            # Forward pass
            outputs = model(inputs)
            print(targets.shape, outputs.shape )
            # Main Loss 계산 (CrossEntropyLoss)
            ce_loss = criterion(outputs, targets)
            print(ce_loss)
            # SCLoss 계산
            
            outputs_probs = F.softmax(outputs, dim=1)  # outputs는 [b, 7, 128, 128]에서 확률 분포로 변환됨

            # 타겟을 원핫 인코딩으로 변환
            targets_onehot = F.one_hot(targets, num_classes=outputs.shape[1]).permute(0, 3, 1, 2).float()  # [b, 7, 128, 128]
            
            sc_loss_value = sc_loss(outputs_probs, targets_onehot)
            print(sc_loss_value)
            
            # 두 손실을 합산 (가중치를 조정 가능)
            loss = ce_loss + sc_loss_value
            print(loss)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            print(epoch_loss)


        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss/len(dataloader)}')

In [8]:
train(model, criterion, sc_loss, optimizer, train_loader, num_epochs=1)

torch.Size([32, 128, 128]) torch.Size([32, 7, 128, 128])
tensor(1.9134, device='cuda:0', grad_fn=<NllLoss2DBackward0>)
tensor(1.9134, device='cuda:0', grad_fn=<NllLoss2DBackward0>)
tensor(1.9334, device='cuda:0', grad_fn=<AddBackward0>)
1.9333983659744263
torch.Size([32, 128, 128]) torch.Size([32, 7, 128, 128])
tensor(1.8094, device='cuda:0', grad_fn=<NllLoss2DBackward0>)
tensor(1.8094, device='cuda:0', grad_fn=<NllLoss2DBackward0>)
tensor(1.8300, device='cuda:0', grad_fn=<AddBackward0>)
3.763394594192505
torch.Size([32, 128, 128]) torch.Size([32, 7, 128, 128])
tensor(1.7126, device='cuda:0', grad_fn=<NllLoss2DBackward0>)
tensor(1.7126, device='cuda:0', grad_fn=<NllLoss2DBackward0>)
tensor(1.7330, device='cuda:0', grad_fn=<AddBackward0>)
5.496382355690002
torch.Size([32, 128, 128]) torch.Size([32, 7, 128, 128])
tensor(1.6062, device='cuda:0', grad_fn=<NllLoss2DBackward0>)
tensor(1.6062, device='cuda:0', grad_fn=<NllLoss2DBackward0>)
tensor(1.6262, device='cuda:0', grad_fn=<AddBackward0