In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from torch.optim import Adam
from torch import nn
import numpy as np

In [None]:
# Preprocessing applied to every image: resize to 128x128, then convert to a tensor.
TRANSFORM = transforms.Compose(
    [transforms.Resize((128, 128)), transforms.ToTensor()]
)

# Training hyper-parameters shared by all loaders and training loops below.
EPOCHS, BATCH_SIZE = 1, 1

In [None]:
class MyModel(nn.Module):
    """Two MobileNetV2 networks: a domain gate and a class predictor.

    ``domain_model`` produces one logit per image and ``prediction_model``
    produces ``num_classes`` logits per image.

    Which path ``forward`` takes is controlled by two boolean flags:

    * ``train_flow`` and ``domain`` true -> raw domain logits, shape ``(B, 1)``.
    * ``train_flow`` true, ``domain`` false -> raw class logits,
      shape ``(B, num_classes)``.
    * ``train_flow`` false (inference) -> class logits for images whose domain
      logit is positive; every other row is a one-hot vector for a uniformly
      random class. The result is always a ``(B, num_classes)`` tensor.

    Assigning a bool to ``model.train`` is still accepted for existing callers,
    but it is redirected to ``train_flow`` so that ``nn.Module.train()`` (and
    therefore ``model.eval()``) is never shadowed by a plain bool.

    Args:
        train: initial value of the ``train_flow`` flag.
        domain: initial value of the ``domain`` flag.
        num_classes: number of output classes of the prediction head.
    """

    def __init__(self, train: bool=False, domain: bool=False, num_classes: int=10):
        super().__init__()
        # The v0.10.0 hub entry point predates the ``weights=`` keyword;
        # ``pretrained=False`` is its spelling for randomly initialised weights.
        self.prediction_model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=False)
        self.prediction_model.classifier[1] = nn.Linear(self.prediction_model.last_channel, num_classes)
        self.domain_model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=False)
        self.domain_model.classifier[1] = nn.Linear(self.domain_model.last_channel, 1)

        self.num_classes = num_classes
        self.train_flow = bool(train)
        self.domain = domain

    def __setattr__(self, name, value):
        # Legacy callers write ``model.train = True/False``. Storing that bool
        # under the name ``train`` would hide nn.Module.train() and make
        # model.eval() fail, so keep the flag under a separate name.
        if name == 'train' and isinstance(value, bool):
            name = 'train_flow'
        super().__setattr__(name, value)

    def forward(self, img):
        # training flow: expose the raw logits of the head being trained
        if self.train_flow:
            if self.domain:
                return self.domain_model(img).view(-1, 1)
            return self.prediction_model(img)

        # inference flow: gate the class prediction on the domain decision
        domain_logits = self.domain_model(img).view(-1)
        class_logits = self.prediction_model(img)

        # logit > 0 is equivalent to sigmoid(logit) > 0.5, i.e. "in-domain"
        in_domain = domain_logits > 0

        # Out-of-domain rows get a uniformly random class, encoded as one-hot
        # so callers can still take an argmax over dimension 1.
        n_classes = class_logits.shape[1]
        random_classes = torch.randint(
            0, n_classes, (class_logits.shape[0],), device=class_logits.device
        )
        fallback = nn.functional.one_hot(random_classes, n_classes).to(class_logits.dtype)

        return torch.where(in_domain.unsqueeze(1), class_logits, fallback)

In [None]:
class BinaryDataLoader(Dataset):
    """Dataset view that replaces every sample's label with one fixed value.

    Despite the name this is a ``Dataset``, not a ``DataLoader``: it wraps
    ``dataset`` (whose items must be 2-tuples) and yields ``(item, label)``
    where ``label`` is the constant given at construction time.
    """

    def __init__(self, dataset, label):
        self.dataset, self.label = dataset, label

    def __len__(self):
        # same length as the wrapped dataset
        return len(self.dataset)

    def __getitem__(self, index):
        # drop the wrapped dataset's own label and substitute the fixed one
        payload, _original_label = self.dataset[index]
        return payload, self.label

In [None]:
def get_domain_data(in_path, out_path):
    """Build a shuffled loader of domain-labelled images.

    Images under ``in_path`` are labelled 1 and images under ``out_path`` are
    labelled 0; both folders are read with ``ImageFolder`` using ``TRANSFORM``
    and concatenated into a single dataset.

    Returns:
        A ``DataLoader`` over the combined dataset with ``BATCH_SIZE`` batches.
    """
    # (folder, domain label) pairs: in-domain -> 1, out-of-domain -> 0
    labelled_parts = [
        BinaryDataLoader(datasets.ImageFolder(folder, transform=TRANSFORM), domain_label)
        for folder, domain_label in ((in_path, 1), (out_path, 0))
    ]
    return DataLoader(ConcatDataset(labelled_parts), batch_size=BATCH_SIZE, shuffle=True)

In [None]:
def _train_phase(model, params, loader, criterion, phase: str, prepare_labels=None):
    """Run ``EPOCHS`` passes of Adam (lr=1e-4) over ``loader``, updating ``params`` only.

    ``prepare_labels``, when given, is applied to each label batch before it
    is handed to ``criterion``.
    """
    optimizer = Adam(params, lr=1e-4)

    for e in range(EPOCHS):
        print(f'{phase} epoch: {e}')
        for inputs, labels in loader:
            if prepare_labels is not None:
                labels = prepare_labels(labels)

            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()


def learn(path_to_in_domain: str, path_to_out_domain: str):
    """Train a fresh ``MyModel`` in two phases and return it.

    Phase 1 trains the domain head on the combined in/out-of-domain data from
    ``get_domain_data`` with ``BCEWithLogitsLoss``. Phase 2 trains the
    prediction head on the in-domain ``ImageFolder`` with ``CrossEntropyLoss``.

    Args:
        path_to_in_domain: image folder with one sub-directory per class.
        path_to_out_domain: image folder with out-of-domain images.

    Returns:
        The trained model, left with its training-flow flag set to True and
        its domain flag set to False.
    """
    model = MyModel()

    # These two flags are read by MyModel.forward to pick which head runs.
    model.train = True

    # phase 1: domain head, binary target 1 = in-domain, 0 = out-of-domain
    model.domain = True
    _train_phase(
        model,
        model.domain_model.parameters(),
        get_domain_data(path_to_in_domain, path_to_out_domain),
        nn.BCEWithLogitsLoss(),
        phase='domain',
        # BCEWithLogitsLoss needs float targets shaped like the (B, 1) logits
        prepare_labels=lambda labels: labels.view(-1, 1).float(),
    )

    # phase 2: prediction head, class labels come from the in-domain folder
    model.domain = False
    in_data = datasets.ImageFolder(path_to_in_domain, transform=TRANSFORM)
    in_data_loader = DataLoader(in_data, batch_size=BATCH_SIZE, shuffle=True)
    _train_phase(
        model,
        model.prediction_model.parameters(),
        in_data_loader,
        nn.CrossEntropyLoss(),
        phase='prediction',
    )

    return model

In [None]:
def accuracy(path_to_eval_folder: str, model) -> float:
    """Return the fraction of images in ``path_to_eval_folder`` classified correctly.

    The folder is read with ``ImageFolder`` using ``TRANSFORM``; the predicted
    class of each image is the index of the largest value along dimension 1 of
    the model output.
    """
    eval_set = datasets.ImageFolder(path_to_eval_folder, transform=TRANSFORM)
    eval_loader = DataLoader(eval_set, batch_size=BATCH_SIZE, shuffle=True)

    # clear the model's training flag, then put it in evaluation mode
    model.train = False
    model.eval()

    n_correct, n_seen = 0, 0

    # no gradients are needed while scoring
    with torch.no_grad():
        for batch, targets in eval_loader:
            predicted = torch.max(model(batch), 1)[1]

            n_seen += targets.size(0)
            n_correct += (predicted == targets).sum().item()

    return n_correct / n_seen

In [None]:
# Build the model with its default flags and 10 output classes.
# NOTE(review): learn() is not called in the cells shown here, so unless
# training happens elsewhere the model evaluated below has untrained weights — confirm.
model = MyModel()

In [None]:
# Accuracy on the in-domain evaluation folder.
acc = accuracy('A4data/in-domain-eval', model)
print(acc)
# Accuracy on the out-of-domain evaluation folder.
acc_2 = accuracy('A4data/out-domain-eval', model)
print(acc_2)