In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as T
import torchvision.models as models

# Triplet Loss

In [2]:
from scripts import dataloader
# Build the train/val/test loaders through the project's `dataloader` helper.
# Exactly one colour space is active at a time: uncomment the wanted call and
# comment out the others.
# NOTE(review): `train_loader` and `val_loader` are rebound further down, where the
# triplet DataLoaders are created, so the training cells use those objects rather
# than the ones built here — confirm this cell is still needed.
# HSV color space
train_loader, val_loader, test_loader = dataloader.get_loaders(color_space='hsv')
# RGB color space
# train_loader, val_loader, test_loader = dataloader.get_loaders(color_space='rgb')
# Gray color space
# train_loader, val_loader, test_loader = dataloader.get_loaders(color_space='gray')
# LAB color space
# train_loader, val_loader, test_loader = dataloader.get_loaders(color_space='lab')

In [20]:
from sklearn.metrics import roc_auc_score

def metrics(labels, preds):
    """Return the one-vs-rest ROC AUC of `preds` measured against `labels`.

    Thin wrapper around `roc_auc_score` called with ``multi_class='ovr'``.

    Args:
        labels: ground-truth class labels, passed through as `y_true`.
        preds: per-class scores, passed through as `y_score`.

    Returns:
        The value returned by `roc_auc_score`.
    """
    return roc_auc_score(labels, preds, multi_class='ovr')
    

In [4]:
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import random 
import cv2

class TripletDataset(Dataset):
    """Dataset that yields (anchor, positive, negative, anchor_label) triplets.

    Every entry of `dataset` is a path-like object whose parent directory name
    is its class ('carrying', 'threat' or 'normal'). For the anchor at `idx` a
    positive is drawn from the same class and a negative from a different one.

    Args:
        dataset: sequence of image paths (objects exposing `.parent.name`).
        transform: optional callable applied to each of the three RGB images.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.classes = ['carrying', 'threat', 'normal']
        self.carry_data = [x for x in dataset if x.parent.name == self.classes[0]]
        self.threat_data = [x for x in dataset if x.parent.name == self.classes[1]]
        self.normal_data = [x for x in dataset if x.parent.name == self.classes[2]]
        # Shuffled in the same order as before so the global RNG stream is unchanged.
        for bucket in (self.carry_data, self.threat_data, self.normal_data):
            random.shuffle(bucket)
        self.data_dict = {0: self.carry_data, 1: self.threat_data, 2: self.normal_data}
        self.transform = transform

    def __len__(self):
        """Number of anchors, i.e. the number of paths in `dataset`."""
        return len(self.dataset)

    @staticmethod
    def _load_rgb(path):
        """Read `path` with OpenCV and return it converted from BGR to RGB."""
        image = cv2.imread(str(path))
        if image is None:
            # cv2.imread signals an unreadable/missing file by returning None.
            raise FileNotFoundError(f'Could not read image: {path}')
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """Return `(anchor_image, positive_image, negative_image, anchor_label)`.

        Raises:
            ValueError: if the anchor's parent directory is not a known class,
                or if no other class has any sample to use as a negative.
            FileNotFoundError: if one of the three images cannot be read.
        """
        anchor = self.dataset[idx]
        anchor_label = self.classes.index(anchor.parent.name)

        # Prefer a positive that is not the anchor itself; a class with a single
        # sample has no alternative, so fall back to the anchor in that case.
        same_class = self.data_dict[anchor_label]
        other_positives = [p for p in same_class if p != anchor]
        positive_sample = random.choice(other_positives or same_class)

        # Only classes that actually hold samples can provide a negative.
        negative_labels = [
            label for label, paths in self.data_dict.items()
            if label != anchor_label and paths
        ]
        if not negative_labels:
            # ValueError rather than IndexError: an IndexError escaping
            # __getitem__ would silently end plain `for x in dataset` iteration.
            raise ValueError(
                f'No negative sample available for class {self.classes[anchor_label]!r}'
            )
        negative_label = random.choice(negative_labels)
        negative_sample = random.choice(self.data_dict[negative_label])

        # All three images go through the same BGR->RGB conversion so that the
        # channel order matches for anchor, positive and negative.
        anchor_image = self._load_rgb(anchor)
        positive_image = self._load_rgb(positive_sample)
        negative_image = self._load_rgb(negative_sample)

        if self.transform:
            anchor_image = self.transform(anchor_image)
            negative_image = self.transform(negative_image)
            positive_image = self.transform(positive_image)

        return (anchor_image, positive_image, negative_image, anchor_label)


In [5]:
import torchvision.transforms as T

def split_data(data_dir, train_size=0.8, val_size=0.1, seed=1234):
    """Split the files found under `data_dir/*/*` into train/val/test lists.

    Regular files one directory level below `data_dir` are collected (files
    with a `.zip` suffix are skipped), sorted, shuffled with a fixed seed and
    cut into three consecutive slices.

    Args:
        data_dir: root directory containing one sub-directory per class.
        train_size: fraction of the files assigned to the training split.
        val_size: fraction of the files assigned to the validation split.
        seed: seed for the shuffle. The global `random` generator is seeded
            with it, as before, so later `random` calls stay deterministic.

    Returns:
        `(train_data, val_data, test_data)` lists of `Path` objects; the test
        split receives whatever remains after the first two.

    Raises:
        ValueError: if a fraction is negative or the two fractions exceed 1.
    """
    if train_size < 0 or val_size < 0 or train_size + val_size > 1 + 1e-9:
        raise ValueError(
            f'train_size and val_size must be non-negative and sum to at most 1, '
            f'got {train_size} and {val_size}'
        )

    random.seed(seed)
    # glob order depends on the filesystem; sorting first makes the seeded
    # shuffle produce the same split on every machine.
    data = sorted(
        x for x in Path(data_dir).glob('*/*')
        if x.is_file() and x.suffix != '.zip'
    )
    random.shuffle(data)

    n_train = int(len(data) * train_size)
    n_val = int(len(data) * val_size)
    train_data = data[:n_train]
    val_data = data[n_train:n_train + n_val]
    test_data = data[n_train + n_val:]

    return train_data, val_data, test_data

# Per-split preprocessing pipelines. Every pipeline converts the incoming array
# to a PIL image, resizes it to 224x224, turns it into a tensor and normalises
# the three channels; only the training pipeline adds random flip/rotation.
transforms = {
    'train': T.Compose([
        T.ToPILImage(),
        T.Resize((224, 224)),
        T.RandomHorizontalFlip(),
        T.RandomRotation(10),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'val': T.Compose([
        T.ToPILImage(),
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'test': T.Compose([
        T.ToPILImage(),
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}

# Stand-alone aliases for the individual pipelines.
train_transforms = transforms['train']
val_transforms = transforms['val']
test_transforms = transforms['test']


# Split the image paths found under 'data' and wrap the train/val parts in
# triplet datasets with their matching preprocessing pipeline.
train_data, val_data, test_data = split_data('data')
print(len(train_data), len(val_data), len(test_data))
train_dataset = TripletDataset(train_data, transform=transforms['train'])
val_dataset = TripletDataset(val_data, transform=transforms['val'])
# shuffle=True so each epoch visits the training anchors in a new order;
# without it every epoch replays exactly the same sequence of batches.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# Validation order does not affect the metrics, so it stays sequential.
val_loader = DataLoader(val_dataset, batch_size=32)

3954 494 495


In [6]:
import torch 

# Pull a single batch from the training loader and print the shape of the
# anchor, positive and negative tensors as a quick sanity check.
first_batch = next(iter(train_loader))
x_anchor, x_positive, x_negative, y_anchor = first_batch
print(*(batch_part.shape for batch_part in (x_anchor, x_positive, x_negative)))

torch.Size([32, 3, 224, 224]) torch.Size([32, 3, 224, 224]) torch.Size([32, 3, 224, 224])


In [7]:
from torch.utils.tensorboard import SummaryWriter
# Module-level TensorBoard writer created with default arguments; the
# TripletLossTrainer methods below log their scalars through this global name,
# so this cell has to run before any training cell.
writer = SummaryWriter()

In [23]:
from tqdm import tqdm 

class TripletLossTrainer:
    """Train and evaluate an embedding model with a triplet margin loss.

    Train/validation batches are ``(anchor, positive, negative, label)`` tuples.
    The loss is ``nn.TripletMarginLoss(margin=1.0, p=2)`` computed on the three
    model outputs.

    Accuracy and ROC AUC treat the anchor output as per-class scores (argmax /
    softmax over dim 1).
    NOTE(review): the triplet loss never ties an output dimension to a class,
    so these figures are only a rough indicator — confirm this is the metric
    that should drive checkpoint selection.

    Args:
        model: network mapping an image batch to a ``(batch, n_outputs)`` tensor.
        optimizer: optimizer over `model`'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        device: 'cuda' to use the GPU when available, anything else for CPU.
        writer: optional object with `add_scalar(tag, value, step)` and
            `flush()`; when omitted the module-level `writer` is used.
        metric_fn: optional `fn(labels, probabilities) -> float`; when omitted
            the module-level `metrics` function is used.
    """

    def __init__(self, model, optimizer, scheduler, device='cuda', writer=None, metric_fn=None):
        self.model = model
        self.optimizer = optimizer
        self.criterion = nn.TripletMarginLoss(margin=1.0, p=2)
        self.scheduler = scheduler
        self.writer = writer
        self.metric_fn = metric_fn
        # A checkpoint is only written once validation accuracy exceeds this.
        self.best_acc = 0.5
        # Per-epoch history (one entry per epoch / validation run).
        self.train_acc_arr = []
        self.val_acc_arr = []
        self.train_losses = []
        self.val_losses = []
        if (device == 'cuda') and torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

    # ------------------------------------------------------------------ helpers

    def _sink(self):
        """Return the writer to log to: the injected one, else the global `writer`."""
        return self.writer if self.writer is not None else writer

    def _embed_triplet(self, anchor, positive, negative):
        """Move the three batches to the device and return their model outputs."""
        anchor = anchor.to(self.device)
        positive = positive.to(self.device)
        negative = negative.to(self.device)
        return self.model(anchor), self.model(positive), self.model(negative)

    @staticmethod
    def _score_batch(outputs, y):
        """Return `(labels_cpu, softmax_probabilities_cpu, n_correct)` for a batch."""
        outputs = outputs.detach().cpu()
        labels = y.detach().cpu()
        n_correct = (outputs.argmax(dim=1) == labels).sum().item()
        return labels, F.softmax(outputs, dim=1), n_correct

    def _log_roc(self, tag, labels, probs, step):
        """Log ROC AUC over all batches of an epoch, skipping it when undefined."""
        if not labels:
            return
        score_fn = self.metric_fn if self.metric_fn is not None else metrics
        try:
            roc = score_fn(torch.cat(labels).numpy(), torch.cat(probs).numpy())
        except ValueError as err:
            # The score can be undefined (e.g. a class missing from the labels);
            # report it instead of aborting the whole run.
            print(f'{tag} skipped: {err}')
            return
        self._sink().add_scalar(tag, roc, step)

    # --------------------------------------------------------------- public API

    def get_loss(self, anchor, positive, negative):
        """Return the triplet margin loss for one batch of triplets."""
        anchor_embedding, positive_embedding, negative_embedding = self._embed_triplet(
            anchor, positive, negative
        )
        return self.criterion(anchor_embedding, positive_embedding, negative_embedding)

    def train(self, train_loader, val_loader, epochs=10):
        """Run `epochs` epochs of training, validating after each one.

        Per epoch this appends the accuracy to `train_acc_arr` and the summed
        batch loss (a float) to `train_losses`, logs the mean loss, accuracy,
        ROC AUC and learning rate, calls `validate` and steps the scheduler.
        """
        self.model.to(self.device)
        for epoch in range(epochs):
            print(f"EPOCH {epoch}")
            self.model.train()
            # Counters are reset here so every epoch reports its own figures.
            total = 0
            correct = 0
            total_loss = 0.0
            num_batches = 0
            epoch_labels, epoch_probs = [], []
            tq = tqdm(enumerate(train_loader))
            for i, (anchor, positive, negative, y) in tq:
                self.optimizer.zero_grad()
                # One forward pass per image; the anchor output is reused for
                # the accuracy figures instead of running the model again.
                anchor_emb, positive_emb, negative_emb = self._embed_triplet(
                    anchor, positive, negative
                )
                loss = self.criterion(anchor_emb, positive_emb, negative_emb)
                loss.backward()
                self.optimizer.step()

                # .item() detaches the value so the autograd graph can be freed.
                batch_loss = loss.item()
                total_loss += batch_loss
                num_batches += 1

                labels, probs, n_correct = self._score_batch(anchor_emb, y)
                total += labels.size(0)
                correct += n_correct
                epoch_labels.append(labels)
                epoch_probs.append(probs)

                tq.set_postfix(loss=batch_loss, acc=correct/total)
                if i % 100 == 0:
                    print(f'Epoch: {epoch}, Loss: {batch_loss}')

            accuracy = correct / total if total else 0.0
            sink = self._sink()
            sink.add_scalar("Loss/train", total_loss / max(num_batches, 1), epoch)
            self._log_roc("ROC/train", epoch_labels, epoch_probs, epoch)
            sink.add_scalar("Accuracy/train", accuracy, epoch)
            self.train_acc_arr.append(accuracy)
            self.train_losses.append(total_loss)
            self.validate(val_loader, epoch)
            self.scheduler.step()
            # add_hparams needs a metric dict as well; a plain scalar is enough
            # to track the learning rate per epoch.
            sink.add_scalar("hparam/lr", self.scheduler.get_last_lr()[0], epoch)
            print(f'Epoch: {epoch}, Accuracy: {accuracy}')
        self._sink().flush()

    def validate(self, val_loader, epoch):
        """Evaluate on `val_loader`, log the results and checkpoint on improvement.

        Appends the accuracy to `val_acc_arr` and the summed batch loss to
        `val_losses`. When the accuracy exceeds `best_acc` the model state is
        saved to 'best_model.pth'.
        """
        self.model.eval()
        total = 0
        correct = 0
        total_loss = 0.0
        num_batches = 0
        epoch_labels, epoch_probs = [], []
        with torch.no_grad():
            tq = tqdm(enumerate(val_loader))
            for i, (anchor, positive, negative, y) in tq:
                anchor_emb, positive_emb, negative_emb = self._embed_triplet(
                    anchor, positive, negative
                )
                loss = self.criterion(anchor_emb, positive_emb, negative_emb)

                batch_loss = loss.item()
                total_loss += batch_loss
                num_batches += 1

                labels, probs, n_correct = self._score_batch(anchor_emb, y)
                # Counted exactly once per batch.
                total += labels.size(0)
                correct += n_correct
                epoch_labels.append(labels)
                epoch_probs.append(probs)

                if i % 100 == 0:
                    print(f'Validation Loss: {batch_loss}')

        accuracy = correct / total if total else 0.0
        sink = self._sink()
        sink.add_scalar("Loss/val", total_loss / max(num_batches, 1), epoch)
        self._log_roc("ROC/val", epoch_labels, epoch_probs, epoch)
        sink.add_scalar("Accuracy/val", accuracy, epoch)
        print(f'Validation Accuracy: {accuracy}')
        self.val_acc_arr.append(accuracy)
        self.val_losses.append(total_loss)
        if accuracy > self.best_acc:
            self.best_acc = accuracy
            print('Saving model...')
            torch.save(self.model.state_dict(), 'best_model.pth')

    def test(self, test_loader):
        """Print and return the accuracy (in percent) on `test_loader`.

        Batches may be ``(x, y)`` pairs or ``(anchor, positive, negative, y)``
        triplets; the triplet loss is only reported for the latter, since it
        cannot be computed from a single image.
        """
        self.model.eval()
        total = 0
        correct = 0
        with torch.no_grad():
            for i, batch in tqdm(enumerate(test_loader)):
                if len(batch) == 4:
                    anchor, positive, negative, y = batch
                    outputs, positive_emb, negative_emb = self._embed_triplet(
                        anchor, positive, negative
                    )
                    loss = self.criterion(outputs, positive_emb, negative_emb)
                else:
                    x, y = batch
                    outputs = self.model(x.to(self.device))
                    loss = None

                labels, _probs, n_correct = self._score_batch(outputs, y)
                total += labels.size(0)
                correct += n_correct
                if loss is not None and i % 100 == 0:
                    print(f'Test Loss: {loss.item()}')
        accuracy = 100 * correct / total if total else 0.0
        print(f'Accuracy: {accuracy}')
        return accuracy

    def save_model(self, path):
        """Write the model's state dict to `path`."""
        torch.save(self.model.state_dict(), path)

    def load_model(self, path):
        """Load a state dict from `path` into the model.

        `map_location` keeps a checkpoint saved on GPU loadable on a CPU-only
        machine (and vice versa).
        """
        self.model.load_state_dict(torch.load(path, map_location=self.device))


In [None]:
# Training
from torchvision.models import resnet18

# Randomly initialised ResNet-18 with a 3-unit final layer, optimised with Adam
# and a step learning-rate schedule (x0.1 every 5 epochs).
model = resnet18(weights=None, num_classes=3)
model.fc = nn.Linear(in_features=512, out_features=3)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

trainer = TripletLossTrainer(model=model, optimizer=optimizer, scheduler=scheduler)
trainer.train(train_loader, val_loader, epochs=10)

In [9]:
from torchvision.models import resnet18

# Randomly initialised ResNet-18 whose final layer outputs a 3-dimensional vector.
# `weights=None` replaces the deprecated `pretrained=False` spelling.
model = resnet18(weights=None)
model.fc = nn.Linear(512, 3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

trainer = TripletLossTrainer(model, optimizer, scheduler)

# `train` has no return value, so its result must not be unpacked.
trainer.train(train_loader, val_loader, epochs=10)

# Produce `y` / `y_pred` explicitly from one validation batch so the
# inspection cell further down has something to display.
anchor_batch, _positive_batch, _negative_batch, y = next(iter(val_loader))
trainer.model.eval()
with torch.no_grad():
    y_pred = trainer.model(anchor_batch.to(trainer.device)).cpu()



EPOCH 0


0it [00:10, ?it/s]


In [27]:
import numpy as np 

# Show the row-wise softmax of the model outputs held in `y_pred`
# (`y_pred` is assigned by the training cell above and must exist in the kernel).
class_probabilities = y_pred.softmax(dim=1)
print(class_probabilities)

tensor([[0.2649, 0.2284, 0.5067],
        [0.2577, 0.2193, 0.5229],
        [0.2421, 0.1767, 0.5811],
        [0.2484, 0.2293, 0.5223],
        [0.2943, 0.2524, 0.4534],
        [0.2484, 0.2227, 0.5289],
        [0.2867, 0.2380, 0.4753],
        [0.3395, 0.2106, 0.4499],
        [0.2159, 0.2113, 0.5727],
        [0.2752, 0.2088, 0.5160],
        [0.2442, 0.2133, 0.5425],
        [0.2431, 0.2145, 0.5424],
        [0.2842, 0.2169, 0.4989],
        [0.2796, 0.2336, 0.4868],
        [0.2305, 0.2300, 0.5395],
        [0.3362, 0.2321, 0.4317],
        [0.2552, 0.2002, 0.5446],
        [0.1779, 0.2361, 0.5859],
        [0.3064, 0.2664, 0.4272],
        [0.2616, 0.2448, 0.4937],
        [0.2735, 0.1925, 0.5340],
        [0.2875, 0.2206, 0.4919],
        [0.2035, 0.1849, 0.6116],
        [0.2450, 0.2402, 0.5148],
        [0.3234, 0.2142, 0.4624],
        [0.2241, 0.2000, 0.5759],
        [0.2902, 0.2364, 0.4734],
        [0.3075, 0.2013, 0.4912],
        [0.2430, 0.2341, 0.5229],
        [0.279